split_dataset.rs

  1//! `ep split` implementation.
  2//!
  3//! This command splits a JSONL dataset into multiple files based on size specifications,
  4//! with optional stratification by a JSON field.
  5//!
  6//! # Usage
  7//!
  8//! ```text
  9//! ep split [--stratify=<field>] [input.jsonl] <out1>=<size1> <out2>=<size2> ...
 10//! ```
 11//!
 12//! If `input.jsonl` is not provided or is `-`, reads from stdin.
 13//!
 14//! # Size specifications
 15//!
 16//! - `80%` - percentage of total examples (lines)
 17//! - `100` - approximate absolute count of examples (lines)
 18//! - `rest` - all remaining items (only one split can use this)
 19//!
 20//! # Stratification
 21//!
 22//! The `--stratify` flag controls how examples are grouped before splitting:
 23//!
 24//! - `cursor-path` (default): group by the `cursor_path` JSON field
 25//! - `repo`: group by the `repository_url` JSON field
 26//! - `none`: no grouping, split individual examples
 27//!
 28//! When stratifying, the split ensures each output file contains examples from
 29//! non-overlapping groups. Size specifications always apply to the number of
 30//! examples (lines), with whole groups assigned greedily to meet the target.
 31//! Examples missing the stratification field are treated as individual groups.
 32
 33use anyhow::{Context as _, Result, bail};
 34use clap::Args;
 35use rand::SeedableRng;
 36use rand::seq::SliceRandom;
 37use serde_json::Value;
 38use std::collections::HashMap;
 39use std::fs::File;
 40use std::io::{self, BufRead, BufReader, BufWriter, Write};
 41use std::path::{Path, PathBuf};
 42
/// `ep split` CLI args.
// NOTE: the field `///` doc comments below double as clap's `--help` text,
// so they are left untouched; review annotations use plain `//` comments.
// The input path and `<path>=<size>` specs are positional and arrive in
// `run_split` via its `inputs` parameter, not through this struct.
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files with optional stratification",
    after_help = r#"SIZE SPECIFICATIONS:
  <percentage>%    Percentage of total (e.g., 80%)
  <count>          Absolute number (e.g., 100)
  rest             All remaining items (only one output can use this)

  Sizes always apply to examples (lines). When stratifying, whole groups
  are assigned greedily to approximate the target count.

EXAMPLES:
  # Split 80% train, 20% validation (default: stratify by cursor_path)
  ep split input.jsonl train.jsonl=80% valid.jsonl=rest

  # Split into train/valid/test
  ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

  # Stratify by repository_url instead of cursor_path
  ep split --stratify=repo input.jsonl train.jsonl=80% valid.jsonl=rest

  # No stratification (split by individual examples)
  ep split --stratify=none input.jsonl train.jsonl=80% valid.jsonl=rest

  # Read from stdin
  cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

  # Reproducible split with seed
  ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
  Controls how examples are grouped before splitting:
    cursor-path  Group by "cursor_path" field (default)
    repo         Group by "repository_url" field
    none         No grouping, split individual examples

  When stratifying, the split ensures each output file contains examples
  from non-overlapping groups. This prevents data leakage between
  train/test splits.
"#
)]
pub struct SplitArgs {
    /// Random seed for reproducibility
    // `None` falls back to OS entropy (`StdRng::from_os_rng` in `run_split`).
    #[arg(long)]
    pub seed: Option<u64>,

    /// Stratification field for splitting the dataset
    // Parsed by clap's ValueEnum: accepts `cursor-path` (default), `repo`, `none`.
    #[arg(long, default_value = "cursor-path")]
    pub stratify: Stratify,
}
 94
/// Grouping strategy applied before splitting (see `group_lines`).
///
/// `clap::ValueEnum` derives the CLI spellings (`cursor-path`, `repo`,
/// `none`); the `strum` serializations control how the chosen variant is
/// rendered in the progress message printed by `run_split`.
// NOTE: variant `///` docs would feed clap's value help, so review notes
// stay in `//` comments to keep `--help` output unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, strum::Display)]
pub enum Stratify {
    // Group examples by the "cursor_path" JSON field (the default).
    #[strum(serialize = "cursor_path")]
    CursorPath,
    // Group examples by the "repository_url" JSON field.
    #[strum(serialize = "repo")]
    Repo,
    // No grouping: every line becomes its own group.
    #[strum(serialize = "none")]
    None,
}
104
/// Parsed size portion of a `<path>=<size>` split specification.
#[derive(Debug, Clone)]
pub enum SplitSize {
    /// Fraction of the total examples in `[0.0, 1.0]` (parsed from `<pct>%`).
    Percentage(f64),
    /// Approximate absolute number of examples (lines).
    Absolute(usize),
    /// Everything left after the other splits are filled; `compute_split_counts`
    /// enforces that at most one spec uses this.
    Rest,
}
111
/// One output file plus the amount of data it should receive.
#[derive(Debug, Clone)]
pub struct SplitSpec {
    /// Output JSONL path (parent directories are created by `write_lines_to_file`).
    pub path: PathBuf,
    /// Requested size for this output.
    pub size: SplitSize,
}
117
118fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
119    let (path, size_str) = spec
120        .rsplit_once('=')
121        .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
122
123    let size = if size_str == "rest" {
124        SplitSize::Rest
125    } else if size_str.ends_with('%') {
126        let pct_str = size_str.trim_end_matches('%');
127        let pct: f64 = pct_str
128            .parse()
129            .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
130        if !(0.0..=100.0).contains(&pct) {
131            bail!("percentage must be between 0 and 100, got {}", pct);
132        }
133        SplitSize::Percentage(pct / 100.0)
134    } else {
135        let count: usize = size_str
136            .parse()
137            .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
138        SplitSize::Absolute(count)
139    };
140
141    Ok(SplitSpec {
142        path: PathBuf::from(path),
143        size,
144    })
145}
146
147fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
148    let reader: Box<dyn BufRead> = match input {
149        Some(path) => {
150            let file =
151                File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
152            Box::new(BufReader::new(file))
153        }
154        None => Box::new(BufReader::new(io::stdin())),
155    };
156
157    let lines: Vec<String> = reader
158        .lines()
159        .collect::<io::Result<Vec<_>>>()
160        .context("failed to read input lines")?;
161
162    Ok(lines)
163}
164
165fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
166    let mut counts = vec![0usize; specs.len()];
167    let mut remaining = total;
168    let mut rest_index: Option<usize> = None;
169
170    for (i, spec) in specs.iter().enumerate() {
171        match &spec.size {
172            SplitSize::Percentage(pct) => {
173                let count = (total as f64 * pct).round() as usize;
174                counts[i] = count.min(remaining);
175                remaining = remaining.saturating_sub(counts[i]);
176            }
177            SplitSize::Absolute(count) => {
178                counts[i] = (*count).min(remaining);
179                remaining = remaining.saturating_sub(counts[i]);
180            }
181            SplitSize::Rest => {
182                if rest_index.is_some() {
183                    bail!("only one split can use 'rest'");
184                }
185                rest_index = Some(i);
186            }
187        }
188    }
189
190    if let Some(idx) = rest_index {
191        counts[idx] = remaining;
192    }
193
194    Ok(counts)
195}
196
197fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
198    if let Some(parent) = path.parent() {
199        if !parent.as_os_str().is_empty() {
200            std::fs::create_dir_all(parent)
201                .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
202        }
203    }
204
205    let file =
206        File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
207    let mut writer = BufWriter::new(file);
208
209    for line in lines {
210        writeln!(writer, "{}", line)
211            .with_context(|| format!("failed to write to '{}'", path.display()))?;
212    }
213
214    writer
215        .flush()
216        .with_context(|| format!("failed to flush '{}'", path.display()))?;
217
218    Ok(())
219}
220
/// Entry point for `ep split`.
///
/// `inputs` is the raw positional argument list: an optional input path
/// (any first argument without an `=`; `-` selects stdin) followed by one
/// or more `<path>=<size>` split specs.
///
/// # Errors
///
/// Fails on missing arguments, unparseable specs, duplicate `rest` specs,
/// and any I/O error while reading input or writing the outputs.
pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
    if inputs.is_empty() {
        bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
    }

    // The first positional is treated as the input file only when it does
    // not look like a `<path>=<size>` spec; "-" explicitly means stdin.
    let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
        if inputs.first().is_some_and(|p| {
            let s = p.to_string_lossy();
            !s.contains('=')
        }) {
            let first = inputs.first().map(|p| p.as_path());
            let first = if first == Some(Path::new("-")) {
                None
            } else {
                first
            };
            (first, &inputs[1..])
        } else {
            (None, inputs)
        };

    if split_specs_raw.is_empty() {
        bail!("no split specifications provided");
    }

    let specs: Vec<SplitSpec> = split_specs_raw
        .iter()
        .map(|p| parse_split_spec(&p.to_string_lossy()))
        .collect::<Result<Vec<_>>>()?;

    let lines = read_lines_from_input(input_path)?;
    let total_lines = lines.len();

    // Empty input: still create every output file (empty), then stop.
    if total_lines == 0 {
        for spec in &specs {
            write_lines_to_file(&spec.path, &[])?;
        }
        return Ok(());
    }

    let mut grouped_lines = group_lines(&lines, args.stratify);

    if args.stratify != Stratify::None {
        eprintln!(
            "Stratifying by {} ({} unique groups, {} examples)",
            args.stratify,
            grouped_lines.len(),
            total_lines
        );
    } else {
        eprintln!(
            "No stratification, splitting {} examples by line",
            total_lines
        );
    }

    // Seeded RNG for reproducibility; otherwise seed from OS entropy.
    // NOTE(review): reproducibility also requires group_lines to return a
    // deterministic group order — verify before relying on --seed.
    let mut rng = match args.seed {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    grouped_lines.shuffle(&mut rng);

    let line_targets = compute_split_counts(&specs, total_lines)?;
    let rest_index = specs.iter().position(|s| matches!(s.size, SplitSize::Rest));
    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
    let mut group_iter = grouped_lines.into_iter();

    // Greedily assign whole groups to each non-`rest` split until its line
    // target is met or exceeded; groups are never broken apart, so counts
    // are approximate when stratifying.
    for (split_idx, &target) in line_targets.iter().enumerate() {
        if Some(split_idx) == rest_index {
            continue;
        }
        let mut accumulated = 0;
        while accumulated < target {
            if let Some(group) = group_iter.next() {
                accumulated += group.len();
                split_outputs[split_idx].extend(group);
            } else {
                break;
            }
        }
    }

    // The `rest` split absorbs every remaining group. NOTE(review): when no
    // spec uses `rest` and the targets sum to less than the total, leftover
    // groups are silently dropped — confirm this is intended.
    if let Some(idx) = rest_index {
        for group in group_iter {
            split_outputs[idx].extend(group);
        }
    }

    for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
        write_lines_to_file(&spec.path, output_lines)?;
        eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
    }

    Ok(())
}
317
318/// Groups lines by the specified stratification field.
319///
320/// When `stratify` is `None`, each line becomes its own group.
321/// When a line is missing the stratification field, it is also placed in its own group.
322fn group_lines(lines: &[String], stratify: Stratify) -> Vec<Vec<String>> {
323    if stratify == Stratify::None {
324        return lines.iter().map(|line| vec![line.clone()]).collect();
325    }
326
327    let field = match stratify {
328        Stratify::Repo => "repository_url",
329        Stratify::CursorPath => "cursor_path",
330        Stratify::None => unreachable!(),
331    };
332
333    let mut groups: HashMap<String, Vec<String>> = HashMap::new();
334    let mut ungrouped: Vec<Vec<String>> = Vec::new();
335
336    for line in lines {
337        let key = serde_json::from_str::<Value>(line)
338            .ok()
339            .and_then(|v| v.get(field)?.as_str().map(|s| s.to_string()));
340        match key {
341            Some(key) => groups.entry(key).or_default().push(line.clone()),
342            None => ungrouped.push(vec![line.clone()]),
343        }
344    }
345
346    let mut result: Vec<Vec<String>> = groups.into_values().collect();
347    result.extend(ungrouped);
348    result
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `lines` to a fresh temp file, one per line, and returns the
    /// handle (the file is deleted when the handle drops).
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    // "train.jsonl=80%" parses to a Percentage of 0.8.
    #[test]
    fn test_parse_split_spec_percentage() {
        let spec = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(spec.path, PathBuf::from("train.jsonl"));
        match spec.size {
            SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
            _ => panic!("expected percentage"),
        }
    }

    // A bare integer parses to an Absolute count.
    #[test]
    fn test_parse_split_spec_absolute() {
        let spec = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(spec.path, PathBuf::from("test.jsonl"));
        match spec.size {
            SplitSize::Absolute(n) => assert_eq!(n, 100),
            _ => panic!("expected absolute"),
        }
    }

    // The literal "rest" parses to SplitSize::Rest.
    #[test]
    fn test_parse_split_spec_rest() {
        let spec = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(spec.size, SplitSize::Rest));
    }

    // With Stratify::None every line becomes its own singleton group.
    #[test]
    fn test_group_lines_none() {
        let lines = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let groups = group_lines(&lines, Stratify::None);
        assert_eq!(groups.len(), 3);
        assert!(groups.iter().all(|g| g.len() == 1));
    }

    // Two percentages that sum to 100% cover the whole dataset exactly.
    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Percentage(0.2),
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    // `rest` receives whatever the percentage split leaves over.
    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    // Absolute counts combine with `rest` the same way percentages do.
    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Absolute(50),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![50, 50]);
    }

    // Lines sharing a repository_url group together; a missing field or a
    // one-line repo each yield a singleton group.
    #[test]
    fn test_group_lines_by_repo() {
        let lines = vec![
            r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
            r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
            r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
            r#"{"id": 4}"#.to_string(),
        ];

        let groups = group_lines(&lines, Stratify::Repo);

        let grouped_count: usize = groups.iter().filter(|g| g.len() > 1).count();
        let ungrouped_count: usize = groups.iter().filter(|g| g.len() == 1).count();
        let total_lines: usize = groups.iter().map(|g| g.len()).sum();

        assert_eq!(grouped_count, 1); // repo1 has 2 lines
        assert_eq!(ungrouped_count, 2); // repo2 (1 line) + line without repo
        assert_eq!(total_lines, 4);
    }

    // Grouping by cursor_path: two distinct paths, no lines lost.
    #[test]
    fn test_group_lines_by_cursor_path() {
        let lines = vec![
            r#"{"cursor_path": "src/main.rs", "id": 1}"#.to_string(),
            r#"{"cursor_path": "src/main.rs", "id": 2}"#.to_string(),
            r#"{"cursor_path": "src/lib.rs", "id": 3}"#.to_string(),
        ];

        let groups = group_lines(&lines, Stratify::CursorPath);

        let total_lines: usize = groups.iter().map(|g| g.len()).sum();
        assert_eq!(groups.len(), 2);
        assert_eq!(total_lines, 3);
    }

    // End-to-end: stratified split keeps each repo entirely in one output
    // (no leakage) and loses no lines.
    #[test]
    fn test_run_split_basic() {
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"repository_url": "repo2", "id": 4}"#,
            r#"{"repository_url": "repo3", "id": 5}"#,
            r#"{"repository_url": "repo3", "id": 6}"#,
            r#"{"repository_url": "repo4", "id": 7}"#,
            r#"{"repository_url": "repo4", "id": 8}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            stratify: Stratify::Repo,
        };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        assert_eq!(train_lines.len() + valid_lines.len(), 8);

        // Extracts the repository_url of a JSONL line, if present.
        let get_repo = |line: &str| -> Option<String> {
            let value: Value = serde_json::from_str(line).ok()?;
            value
                .get("repository_url")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
        };

        let train_repos: std::collections::HashSet<_> =
            train_lines.iter().filter_map(|l| get_repo(l)).collect();
        let valid_repos: std::collections::HashSet<_> =
            valid_lines.iter().filter_map(|l| get_repo(l)).collect();

        assert!(
            train_repos.is_disjoint(&valid_repos),
            "train and valid should have non-overlapping repos"
        );
    }

    // Two `rest` specs must be rejected by compute_split_counts.
    #[test]
    fn test_multiple_rest_fails() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Rest,
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        assert!(compute_split_counts(&specs, 100).is_err());
    }

    // Regression guard: absolute sizes target LINE counts, not group counts.
    #[test]
    fn test_absolute_targets_lines_not_groups() {
        // 5 repos × 3 lines each = 15 total lines.
        // `train=6` should target ~6 lines (2 groups), NOT 6 groups (all 15 lines).
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "r1", "id": 1}"#,
            r#"{"repository_url": "r1", "id": 2}"#,
            r#"{"repository_url": "r1", "id": 3}"#,
            r#"{"repository_url": "r2", "id": 4}"#,
            r#"{"repository_url": "r2", "id": 5}"#,
            r#"{"repository_url": "r2", "id": 6}"#,
            r#"{"repository_url": "r3", "id": 7}"#,
            r#"{"repository_url": "r3", "id": 8}"#,
            r#"{"repository_url": "r3", "id": 9}"#,
            r#"{"repository_url": "r4", "id": 10}"#,
            r#"{"repository_url": "r4", "id": 11}"#,
            r#"{"repository_url": "r4", "id": 12}"#,
            r#"{"repository_url": "r5", "id": 13}"#,
            r#"{"repository_url": "r5", "id": 14}"#,
            r#"{"repository_url": "r5", "id": 15}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            stratify: Stratify::Repo,
        };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=6", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        // With 3-line groups, train should get 2 groups (6 lines) to meet the
        // target of 6, NOT 6 groups (which don't even exist). Valid gets the rest.
        assert_eq!(train_lines.len(), 6);
        assert_eq!(valid_lines.len(), 9);
    }
}