split_dataset.rs

  1//! `ep split` implementation.
  2//!
  3//! This command splits a JSONL dataset into multiple files based on size specifications,
  4//! with stratification by repository URL (if the field is present).
  5//!
  6//! # Usage
  7//!
  8//! ```text
  9//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
 10//! ```
 11//!
 12//! If `input.jsonl` is not provided or is `-`, reads from stdin.
 13//!
 14//! # Size specifications
 15//!
 16//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
 17//! - `100` - absolute count of repositories (if stratified) or examples
 18//! - `rest` - all remaining items (only one split can use this)
 19//!
 20//! # Stratification
 21//!
 22//! When examples have a `repository_url` field, the split is stratified by repository.
 23//! This ensures each output file contains examples from non-overlapping repositories.
 24//! Size specifications apply to the number of repositories, not individual examples.
 25//!
 26//! Examples without `repository_url` are distributed proportionally across all outputs.
 27
 28use anyhow::{Context as _, Result, bail};
 29use clap::Args;
 30use rand::SeedableRng;
 31use rand::seq::SliceRandom;
 32use serde_json::Value;
 33use std::collections::HashMap;
 34use std::fs::File;
 35use std::io::{self, BufRead, BufReader, BufWriter, Write};
 36use std::path::{Path, PathBuf};
 37
/// `ep split` CLI args.
// NOTE: with `#[derive(Args)]`, clap uses the `///` doc comments in this struct as
// `--help` text, so maintainer notes here are plain `//` comments instead.
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
    after_help = r#"SIZE SPECIFICATIONS:
  <percentage>%    Percentage of total (e.g., 80%)
  <count>          Absolute number (e.g., 100)
  rest             All remaining items (only one output can use this)

  When stratifying by repository_url, sizes apply to repositories, not examples.

EXAMPLES:
  # Split 80% train, 20% validation
  ep split input.jsonl train.jsonl=80% valid.jsonl=rest

  # Split into train/valid/test
  ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

  # Use absolute counts (100 repos to train, rest to valid)
  ep split input.jsonl train.jsonl=100 valid.jsonl=rest

  # Read from stdin
  cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

  # Reproducible split with seed
  ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

  # Disable stratification (split by examples, not repositories)
  ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
  When examples have a "repository_url" field, the split ensures each output
  file contains examples from non-overlapping repositories. This prevents
  data leakage between train/test splits. Use --no-stratify to disable this
  behavior and split by individual examples instead.
"#
)]
pub struct SplitArgs {
    /// Random seed for reproducibility
    // When `None`, `run_split` seeds its `StdRng` from the OS, so every run shuffles differently.
    #[arg(long)]
    pub seed: Option<u64>,

    /// Disable stratification by repository_url (split by examples instead)
    // Surfaces on the command line as `--no-stratify` (see the examples in `after_help`).
    #[arg(long)]
    pub no_stratify: bool,
}
 84
 85#[derive(Debug, Clone)]
 86pub enum SplitSize {
 87    Percentage(f64),
 88    Absolute(usize),
 89    Rest,
 90}
 91
 92#[derive(Debug, Clone)]
 93pub struct SplitSpec {
 94    pub path: PathBuf,
 95    pub size: SplitSize,
 96}
 97
 98fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
 99    let (path, size_str) = spec
100        .rsplit_once('=')
101        .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
102
103    let size = if size_str == "rest" {
104        SplitSize::Rest
105    } else if size_str.ends_with('%') {
106        let pct_str = size_str.trim_end_matches('%');
107        let pct: f64 = pct_str
108            .parse()
109            .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
110        if !(0.0..=100.0).contains(&pct) {
111            bail!("percentage must be between 0 and 100, got {}", pct);
112        }
113        SplitSize::Percentage(pct / 100.0)
114    } else {
115        let count: usize = size_str
116            .parse()
117            .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
118        SplitSize::Absolute(count)
119    };
120
121    Ok(SplitSpec {
122        path: PathBuf::from(path),
123        size,
124    })
125}
126
127fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
128    let reader: Box<dyn BufRead> = match input {
129        Some(path) => {
130            let file =
131                File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
132            Box::new(BufReader::new(file))
133        }
134        None => Box::new(BufReader::new(io::stdin())),
135    };
136
137    let lines: Vec<String> = reader
138        .lines()
139        .collect::<io::Result<Vec<_>>>()
140        .context("failed to read input lines")?;
141
142    Ok(lines)
143}
144
145fn get_repository_url(line: &str) -> Option<String> {
146    let value: Value = serde_json::from_str(line).ok()?;
147    value
148        .get("repository_url")
149        .and_then(|v| v.as_str())
150        .map(|s| s.to_string())
151}
152
153fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
154    let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
155    let mut without_repo: Vec<String> = Vec::new();
156
157    for line in lines {
158        if let Some(repo_url) = get_repository_url(&line) {
159            by_repo.entry(repo_url).or_default().push(line);
160        } else {
161            without_repo.push(line);
162        }
163    }
164
165    (by_repo, without_repo)
166}
167
168fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
169    let mut counts = vec![0usize; specs.len()];
170    let mut remaining = total;
171    let mut rest_index: Option<usize> = None;
172
173    for (i, spec) in specs.iter().enumerate() {
174        match &spec.size {
175            SplitSize::Percentage(pct) => {
176                let count = (total as f64 * pct).round() as usize;
177                counts[i] = count.min(remaining);
178                remaining = remaining.saturating_sub(counts[i]);
179            }
180            SplitSize::Absolute(count) => {
181                counts[i] = (*count).min(remaining);
182                remaining = remaining.saturating_sub(counts[i]);
183            }
184            SplitSize::Rest => {
185                if rest_index.is_some() {
186                    bail!("only one split can use 'rest'");
187                }
188                rest_index = Some(i);
189            }
190        }
191    }
192
193    if let Some(idx) = rest_index {
194        counts[idx] = remaining;
195    }
196
197    Ok(counts)
198}
199
200fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
201    if let Some(parent) = path.parent() {
202        if !parent.as_os_str().is_empty() {
203            std::fs::create_dir_all(parent)
204                .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
205        }
206    }
207
208    let file =
209        File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
210    let mut writer = BufWriter::new(file);
211
212    for line in lines {
213        writeln!(writer, "{}", line)
214            .with_context(|| format!("failed to write to '{}'", path.display()))?;
215    }
216
217    writer
218        .flush()
219        .with_context(|| format!("failed to flush '{}'", path.display()))?;
220
221    Ok(())
222}
223
224pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
225    if inputs.is_empty() {
226        bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
227    }
228
229    let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
230        if inputs.first().is_some_and(|p| {
231            let s = p.to_string_lossy();
232            !s.contains('=')
233        }) {
234            let first = inputs.first().map(|p| p.as_path());
235            let first = if first == Some(Path::new("-")) {
236                None
237            } else {
238                first
239            };
240            (first, &inputs[1..])
241        } else {
242            (None, inputs)
243        };
244
245    if split_specs_raw.is_empty() {
246        bail!("no split specifications provided");
247    }
248
249    let specs: Vec<SplitSpec> = split_specs_raw
250        .iter()
251        .map(|p| parse_split_spec(&p.to_string_lossy()))
252        .collect::<Result<Vec<_>>>()?;
253
254    let lines = read_lines_from_input(input_path)?;
255    let total_lines = lines.len();
256
257    if total_lines == 0 {
258        for spec in &specs {
259            write_lines_to_file(&spec.path, &[])?;
260        }
261        return Ok(());
262    }
263
264    let (by_repo, without_repo) = group_lines_by_repo(lines);
265    let has_repos = !by_repo.is_empty() && !args.no_stratify;
266
267    if args.no_stratify && !by_repo.is_empty() {
268        eprintln!(
269            "Stratification disabled (--no-stratify), splitting {} examples by line",
270            total_lines
271        );
272    } else if has_repos {
273        eprintln!(
274            "Stratifying by repository_url ({} unique repositories, {} examples)",
275            by_repo.len(),
276            total_lines - without_repo.len()
277        );
278        if !without_repo.is_empty() {
279            eprintln!(
280                "  + {} examples without repository_url (distributed proportionally)",
281                without_repo.len()
282            );
283        }
284    }
285
286    let mut rng = match args.seed {
287        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
288        None => rand::rngs::StdRng::from_os_rng(),
289    };
290
291    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
292
293    if has_repos {
294        let mut repos: Vec<String> = by_repo.keys().cloned().collect();
295        repos.shuffle(&mut rng);
296
297        let repo_counts = compute_split_counts(&specs, repos.len())?;
298
299        let mut repo_iter = repos.into_iter();
300        for (split_idx, &count) in repo_counts.iter().enumerate() {
301            for _ in 0..count {
302                if let Some(repo) = repo_iter.next() {
303                    if let Some(repo_lines) = by_repo.get(&repo) {
304                        split_outputs[split_idx].extend(repo_lines.iter().cloned());
305                    }
306                }
307            }
308        }
309
310        if !without_repo.is_empty() {
311            let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
312            let mut no_repo_shuffled = without_repo;
313            no_repo_shuffled.shuffle(&mut rng);
314
315            let mut line_iter = no_repo_shuffled.into_iter();
316            for (split_idx, &count) in no_repo_counts.iter().enumerate() {
317                for _ in 0..count {
318                    if let Some(line) = line_iter.next() {
319                        split_outputs[split_idx].push(line);
320                    }
321                }
322            }
323        }
324    } else {
325        let line_counts = compute_split_counts(&specs, total_lines)?;
326        let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
327        all_lines.extend(without_repo);
328        all_lines.shuffle(&mut rng);
329
330        let mut line_iter = all_lines.into_iter();
331
332        for (split_idx, &count) in line_counts.iter().enumerate() {
333            for _ in 0..count {
334                if let Some(line) = line_iter.next() {
335                    split_outputs[split_idx].push(line);
336                }
337            }
338        }
339    }
340
341    for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
342        write_lines_to_file(&spec.path, output_lines)?;
343        eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
344    }
345
346    Ok(())
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `lines` into a fresh temporary file, one record per line.
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        lines.iter().for_each(|l| writeln!(tmp, "{}", l).unwrap());
        tmp.flush().unwrap();
        tmp
    }

    /// Shorthand for building a `SplitSpec` in tests.
    fn spec(path: &str, size: SplitSize) -> SplitSpec {
        SplitSpec {
            path: PathBuf::from(path),
            size,
        }
    }

    /// Collects the distinct repository URLs that appear in `lines`.
    fn repos_in(lines: &[&str]) -> HashSet<String> {
        lines.iter().filter_map(|l| get_repository_url(l)).collect()
    }

    #[test]
    fn test_parse_split_spec_percentage() {
        let parsed = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(parsed.path, PathBuf::from("train.jsonl"));
        let SplitSize::Percentage(fraction) = parsed.size else {
            panic!("expected percentage");
        };
        assert!((fraction - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_parse_split_spec_absolute() {
        let parsed = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(parsed.path, PathBuf::from("test.jsonl"));
        assert!(
            matches!(parsed.size, SplitSize::Absolute(100)),
            "expected absolute"
        );
    }

    #[test]
    fn test_parse_split_spec_rest() {
        let parsed = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(parsed.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(parsed.size, SplitSize::Rest));
    }

    #[test]
    fn test_get_repository_url() {
        let with_repo = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
        assert_eq!(
            get_repository_url(with_repo).as_deref(),
            Some("https://github.com/example/repo")
        );

        assert_eq!(get_repository_url(r#"{"data": 123}"#), None);
    }

    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = [
            spec("a", SplitSize::Percentage(0.8)),
            spec("b", SplitSize::Percentage(0.2)),
        ];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = [
            spec("a", SplitSize::Percentage(0.8)),
            spec("b", SplitSize::Rest),
        ];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = [spec("a", SplitSize::Absolute(50)), spec("b", SplitSize::Rest)];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![50, 50]);
    }

    #[test]
    fn test_group_lines_by_repo() {
        let lines: Vec<String> = [
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"id": 4}"#,
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let (by_repo, without_repo) = group_lines_by_repo(lines);

        assert_eq!(by_repo.len(), 2);
        assert_eq!(by_repo["repo1"].len(), 2);
        assert_eq!(by_repo["repo2"].len(), 1);
        assert_eq!(without_repo.len(), 1);
    }

    #[test]
    fn test_run_split_basic() {
        // Eight records, two per repository: repo1..repo4.
        let records: Vec<String> = (1..=8)
            .map(|id| format!(r#"{{"repository_url": "repo{}", "id": {}}}"#, (id + 1) / 2, id))
            .collect();
        let record_refs: Vec<&str> = records.iter().map(String::as_str).collect();
        let input = create_temp_jsonl(&record_refs);

        let out_dir = tempfile::tempdir().unwrap();
        let train_path = out_dir.path().join("train.jsonl");
        let valid_path = out_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            no_stratify: false,
        };
        let inputs = [
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();
        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        assert_eq!(train_lines.len() + valid_lines.len(), 8);
        assert!(
            repos_in(&train_lines).is_disjoint(&repos_in(&valid_lines)),
            "train and valid should have non-overlapping repos"
        );
    }

    #[test]
    fn test_multiple_rest_fails() {
        let specs = [spec("a", SplitSize::Rest), spec("b", SplitSize::Rest)];
        assert!(compute_split_counts(&specs, 100).is_err());
    }
}