1//! `ep split` implementation.
2//!
3//! This command splits a JSONL dataset into multiple files based on size specifications,
4//! with stratification by repository URL (if the field is present).
5//!
6//! # Usage
7//!
8//! ```text
9//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
10//! ```
11//!
12//! If `input.jsonl` is not provided or is `-`, reads from stdin.
13//!
14//! # Size specifications
15//!
16//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
17//! - `100` - absolute count of repositories (if stratified) or examples
18//! - `rest` - all remaining items (only one split can use this)
19//!
20//! # Stratification
21//!
22//! When examples have a `repository_url` field, the split is stratified by repository.
23//! This ensures each output file contains examples from non-overlapping repositories.
24//! Size specifications apply to the number of repositories, not individual examples.
25//!
26//! Examples without `repository_url` are distributed proportionally across all outputs.
27
28use anyhow::{Context as _, Result, bail};
29use clap::Args;
30use rand::SeedableRng;
31use rand::seq::SliceRandom;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::fs::File;
35use std::io::{self, BufRead, BufReader, BufWriter, Write};
36use std::path::{Path, PathBuf};
37
/// `ep split` CLI args.
// NOTE: the `about` and `after_help` strings below are user-facing CLI output
// rendered verbatim by clap; do not reword them casually.
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
    after_help = r#"SIZE SPECIFICATIONS:
 <percentage>% Percentage of total (e.g., 80%)
 <count> Absolute number (e.g., 100)
 rest All remaining items (only one output can use this)

 When stratifying by repository_url, sizes apply to repositories, not examples.

EXAMPLES:
 # Split 80% train, 20% validation
 ep split input.jsonl train.jsonl=80% valid.jsonl=rest

 # Split into train/valid/test
 ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

 # Use absolute counts (100 repos to train, rest to valid)
 ep split input.jsonl train.jsonl=100 valid.jsonl=rest

 # Read from stdin
 cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

 # Reproducible split with seed
 ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
 When examples have a "repository_url" field, the split ensures each output
 file contains examples from non-overlapping repositories. This prevents
 data leakage between train/test splits.
"#
)]
// Positional arguments (the optional input path and the <path>=<size> specs)
// are not declared here; they are passed to `run_split` as `inputs` by the
// caller — presumably the top-level CLI dispatcher (TODO confirm).
pub struct SplitArgs {
    /// Random seed for reproducibility
    #[arg(long)]
    pub seed: Option<u64>,
}
76
/// How large one output split should be (see `parse_split_spec`).
#[derive(Debug, Clone)]
pub enum SplitSize {
    /// Fraction of the total, stored as `0.0..=1.0` (parsed from e.g. `80%`).
    Percentage(f64),
    /// Absolute item count: repositories when stratified, examples otherwise.
    Absolute(usize),
    /// Everything left after the other splits; at most one spec may use this.
    Rest,
}
83
/// One parsed `<path>=<size>` output specification.
#[derive(Debug, Clone)]
pub struct SplitSpec {
    /// Output file this split is written to.
    pub path: PathBuf,
    /// Requested size of this split.
    pub size: SplitSize,
}
89
90fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
91 let (path, size_str) = spec
92 .rsplit_once('=')
93 .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
94
95 let size = if size_str == "rest" {
96 SplitSize::Rest
97 } else if size_str.ends_with('%') {
98 let pct_str = size_str.trim_end_matches('%');
99 let pct: f64 = pct_str
100 .parse()
101 .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
102 if !(0.0..=100.0).contains(&pct) {
103 bail!("percentage must be between 0 and 100, got {}", pct);
104 }
105 SplitSize::Percentage(pct / 100.0)
106 } else {
107 let count: usize = size_str
108 .parse()
109 .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
110 SplitSize::Absolute(count)
111 };
112
113 Ok(SplitSpec {
114 path: PathBuf::from(path),
115 size,
116 })
117}
118
119fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
120 let reader: Box<dyn BufRead> = match input {
121 Some(path) => {
122 let file =
123 File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
124 Box::new(BufReader::new(file))
125 }
126 None => Box::new(BufReader::new(io::stdin())),
127 };
128
129 let lines: Vec<String> = reader
130 .lines()
131 .collect::<io::Result<Vec<_>>>()
132 .context("failed to read input lines")?;
133
134 Ok(lines)
135}
136
137fn get_repository_url(line: &str) -> Option<String> {
138 let value: Value = serde_json::from_str(line).ok()?;
139 value
140 .get("repository_url")
141 .and_then(|v| v.as_str())
142 .map(|s| s.to_string())
143}
144
145fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
146 let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
147 let mut without_repo: Vec<String> = Vec::new();
148
149 for line in lines {
150 if let Some(repo_url) = get_repository_url(&line) {
151 by_repo.entry(repo_url).or_default().push(line);
152 } else {
153 without_repo.push(line);
154 }
155 }
156
157 (by_repo, without_repo)
158}
159
160fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
161 let mut counts = vec![0usize; specs.len()];
162 let mut remaining = total;
163 let mut rest_index: Option<usize> = None;
164
165 for (i, spec) in specs.iter().enumerate() {
166 match &spec.size {
167 SplitSize::Percentage(pct) => {
168 let count = (total as f64 * pct).round() as usize;
169 counts[i] = count.min(remaining);
170 remaining = remaining.saturating_sub(counts[i]);
171 }
172 SplitSize::Absolute(count) => {
173 counts[i] = (*count).min(remaining);
174 remaining = remaining.saturating_sub(counts[i]);
175 }
176 SplitSize::Rest => {
177 if rest_index.is_some() {
178 bail!("only one split can use 'rest'");
179 }
180 rest_index = Some(i);
181 }
182 }
183 }
184
185 if let Some(idx) = rest_index {
186 counts[idx] = remaining;
187 }
188
189 Ok(counts)
190}
191
192fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
193 if let Some(parent) = path.parent() {
194 if !parent.as_os_str().is_empty() {
195 std::fs::create_dir_all(parent)
196 .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
197 }
198 }
199
200 let file =
201 File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
202 let mut writer = BufWriter::new(file);
203
204 for line in lines {
205 writeln!(writer, "{}", line)
206 .with_context(|| format!("failed to write to '{}'", path.display()))?;
207 }
208
209 writer
210 .flush()
211 .with_context(|| format!("failed to flush '{}'", path.display()))?;
212
213 Ok(())
214}
215
/// Entry point for `ep split`.
///
/// `inputs` holds the positional CLI arguments: an optional input path
/// followed by one or more `<path>=<size>` specs. The first argument is
/// treated as the input file only when it contains no `=`; `-` (or omitting
/// it) selects stdin.
///
/// # Errors
/// Fails on malformed specs, unreadable input, or unwritable outputs.
pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
    if inputs.is_empty() {
        bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
    }

    // Classify the first positional: anything without '=' is the input file
    // (with "-" aliased to stdin); otherwise all positionals are specs.
    let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
        if inputs.first().is_some_and(|p| {
            let s = p.to_string_lossy();
            !s.contains('=')
        }) {
            let first = inputs.first().map(|p| p.as_path());
            let first = if first == Some(Path::new("-")) {
                None
            } else {
                first
            };
            (first, &inputs[1..])
        } else {
            (None, inputs)
        };

    if split_specs_raw.is_empty() {
        bail!("no split specifications provided");
    }

    let specs: Vec<SplitSpec> = split_specs_raw
        .iter()
        .map(|p| parse_split_spec(&p.to_string_lossy()))
        .collect::<Result<Vec<_>>>()?;

    let lines = read_lines_from_input(input_path)?;
    let total_lines = lines.len();

    // Empty input: still create every output file (empty) so downstream
    // tooling can rely on the paths existing.
    if total_lines == 0 {
        for spec in &specs {
            write_lines_to_file(&spec.path, &[])?;
        }
        return Ok(());
    }

    // Stratify whenever at least one example carries a repository_url.
    let (by_repo, without_repo) = group_lines_by_repo(lines);
    let has_repos = !by_repo.is_empty();

    if has_repos {
        eprintln!(
            "Stratifying by repository_url ({} unique repositories, {} examples)",
            by_repo.len(),
            total_lines - without_repo.len()
        );
        if !without_repo.is_empty() {
            eprintln!(
                " + {} examples without repository_url (distributed proportionally)",
                without_repo.len()
            );
        }
    }

    // Seeded RNG gives reproducible splits; otherwise seed from the OS.
    let mut rng = match args.seed {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];

    if has_repos {
        // Stratified path: shuffle repositories and hand out whole repos to
        // each split, so outputs never share a repository.
        let mut repos: Vec<String> = by_repo.keys().cloned().collect();
        repos.shuffle(&mut rng);

        let repo_counts = compute_split_counts(&specs, repos.len())?;

        let mut repo_iter = repos.into_iter();
        for (split_idx, &count) in repo_counts.iter().enumerate() {
            for _ in 0..count {
                if let Some(repo) = repo_iter.next() {
                    if let Some(repo_lines) = by_repo.get(&repo) {
                        split_outputs[split_idx].extend(repo_lines.iter().cloned());
                    }
                }
            }
        }

        // Lines lacking repository_url are shuffled and distributed using
        // the same size specs, applied to their own count.
        if !without_repo.is_empty() {
            let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
            let mut no_repo_shuffled = without_repo;
            no_repo_shuffled.shuffle(&mut rng);

            let mut line_iter = no_repo_shuffled.into_iter();
            for (split_idx, &count) in no_repo_counts.iter().enumerate() {
                for _ in 0..count {
                    if let Some(line) = line_iter.next() {
                        split_outputs[split_idx].push(line);
                    }
                }
            }
        }
    } else {
        // Unstratified path: with no repos, `without_repo` holds every line.
        let line_counts = compute_split_counts(&specs, total_lines)?;
        let mut shuffled_lines = without_repo;
        shuffled_lines.shuffle(&mut rng);

        let mut line_iter = shuffled_lines.into_iter();
        for (split_idx, &count) in line_counts.iter().enumerate() {
            for _ in 0..count {
                if let Some(line) = line_iter.next() {
                    split_outputs[split_idx].push(line);
                }
            }
        }
    }

    // Write each split and report its size on stderr.
    for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
        write_lines_to_file(&spec.path, output_lines)?;
        eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
    }

    Ok(())
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write the given lines into a fresh temp file and return its handle.
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    // "80%" parses to Percentage(0.8); compared with a tolerance because the
    // stored value is an f64.
    #[test]
    fn test_parse_split_spec_percentage() {
        let spec = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(spec.path, PathBuf::from("train.jsonl"));
        match spec.size {
            SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
            _ => panic!("expected percentage"),
        }
    }

    // A bare integer parses to Absolute.
    #[test]
    fn test_parse_split_spec_absolute() {
        let spec = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(spec.path, PathBuf::from("test.jsonl"));
        match spec.size {
            SplitSize::Absolute(n) => assert_eq!(n, 100),
            _ => panic!("expected absolute"),
        }
    }

    // The literal "rest" parses to Rest.
    #[test]
    fn test_parse_split_spec_rest() {
        let spec = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(spec.size, SplitSize::Rest));
    }

    // String-valued repository_url is extracted; a missing field yields None.
    #[test]
    fn test_get_repository_url() {
        let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
        assert_eq!(
            get_repository_url(line),
            Some("https://github.com/example/repo".to_string())
        );

        let line_no_repo = r#"{"data": 123}"#;
        assert_eq!(get_repository_url(line_no_repo), None);
    }

    // Two percentages that sum to 100% split the total exactly.
    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Percentage(0.2),
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    // Rest receives whatever the percentage splits leave behind.
    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    // Absolute count takes its exact amount; Rest absorbs the remainder.
    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Absolute(50),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![50, 50]);
    }

    // Lines group by repository_url; the line lacking the field lands in the
    // without-repo bucket.
    #[test]
    fn test_group_lines_by_repo() {
        let lines = vec![
            r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
            r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
            r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
            r#"{"id": 4}"#.to_string(),
        ];

        let (by_repo, without_repo) = group_lines_by_repo(lines);

        assert_eq!(by_repo.len(), 2);
        assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
        assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
        assert_eq!(without_repo.len(), 1);
    }

    // End-to-end: 4 repos x 2 examples, seeded 50%/rest split. Checks that no
    // example is lost and that train/valid repositories do not overlap
    // (the stratification guarantee).
    #[test]
    fn test_run_split_basic() {
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"repository_url": "repo2", "id": 4}"#,
            r#"{"repository_url": "repo3", "id": 5}"#,
            r#"{"repository_url": "repo3", "id": 6}"#,
            r#"{"repository_url": "repo4", "id": 7}"#,
            r#"{"repository_url": "repo4", "id": 8}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        let args = SplitArgs { seed: Some(42) };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        assert_eq!(train_lines.len() + valid_lines.len(), 8);

        let train_repos: std::collections::HashSet<_> = train_lines
            .iter()
            .filter_map(|l| get_repository_url(l))
            .collect();
        let valid_repos: std::collections::HashSet<_> = valid_lines
            .iter()
            .filter_map(|l| get_repository_url(l))
            .collect();

        assert!(
            train_repos.is_disjoint(&valid_repos),
            "train and valid should have non-overlapping repos"
        );
    }

    // Two Rest specs are rejected.
    #[test]
    fn test_multiple_rest_fails() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Rest,
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        assert!(compute_split_counts(&specs, 100).is_err());
    }
}