split_dataset.rs

//! `ep split` implementation.
//!
//! This command splits a JSONL dataset into multiple files based on size specifications,
//! with stratification by repository URL (if the field is present).
//!
//! # Usage
//!
//! ```text
//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
//! ```
//!
//! If `input.jsonl` is not provided or is `-`, reads from stdin.
//!
//! # Size specifications
//!
//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
//! - `100` - absolute count of repositories (if stratified) or examples
//! - `rest` - all remaining items (only one split can use this)
//!
//! # Stratification
//!
//! When examples have a `repository_url` field, the split is stratified by repository.
//! This ensures each output file contains examples from non-overlapping repositories.
//! Size specifications apply to the number of repositories, not individual examples.
//!
//! Examples without `repository_url` are distributed proportionally across all outputs.

 28use anyhow::{Context as _, Result, bail};
 29use clap::Args;
 30use rand::SeedableRng;
 31use rand::seq::SliceRandom;
 32use serde_json::Value;
 33use std::collections::HashMap;
 34use std::fs::File;
 35use std::io::{self, BufRead, BufReader, BufWriter, Write};
 36use std::path::{Path, PathBuf};
 37
/// `ep split` CLI args.
///
/// Note: the positional arguments (optional input path plus the
/// `<path>=<size>` split specs) are not declared here — they are handed to
/// `run_split` as a raw `&[PathBuf]` and parsed there.
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
    after_help = r#"SIZE SPECIFICATIONS:
  <percentage>%    Percentage of total (e.g., 80%)
  <count>          Absolute number (e.g., 100)
  rest             All remaining items (only one output can use this)

  When stratifying by repository_url, sizes apply to repositories, not examples.

EXAMPLES:
  # Split 80% train, 20% validation
  ep split input.jsonl train.jsonl=80% valid.jsonl=rest

  # Split into train/valid/test
  ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

  # Use absolute counts (100 repos to train, rest to valid)
  ep split input.jsonl train.jsonl=100 valid.jsonl=rest

  # Read from stdin
  cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

  # Reproducible split with seed
  ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
  When examples have a "repository_url" field, the split ensures each output
  file contains examples from non-overlapping repositories. This prevents
  data leakage between train/test splits.
"#
)]
pub struct SplitArgs {
    /// Random seed for reproducibility
    #[arg(long)]
    pub seed: Option<u64>,
}
 76
/// How large one output split should be.
#[derive(Debug, Clone)]
pub enum SplitSize {
    /// Fraction of the total, stored in `0.0..=1.0` (parsed from `<pct>%`).
    Percentage(f64),
    /// Absolute item count (repositories when stratified, examples otherwise).
    Absolute(usize),
    /// Whatever remains after all other splits are satisfied (at most one).
    Rest,
}
 83
/// One parsed `<path>=<size>` output specification.
#[derive(Debug, Clone)]
pub struct SplitSpec {
    /// Output file the split is written to.
    pub path: PathBuf,
    /// Requested size of this split.
    pub size: SplitSize,
}
 89
 90fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
 91    let (path, size_str) = spec
 92        .rsplit_once('=')
 93        .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
 94
 95    let size = if size_str == "rest" {
 96        SplitSize::Rest
 97    } else if size_str.ends_with('%') {
 98        let pct_str = size_str.trim_end_matches('%');
 99        let pct: f64 = pct_str
100            .parse()
101            .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
102        if !(0.0..=100.0).contains(&pct) {
103            bail!("percentage must be between 0 and 100, got {}", pct);
104        }
105        SplitSize::Percentage(pct / 100.0)
106    } else {
107        let count: usize = size_str
108            .parse()
109            .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
110        SplitSize::Absolute(count)
111    };
112
113    Ok(SplitSpec {
114        path: PathBuf::from(path),
115        size,
116    })
117}
118
119fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
120    let reader: Box<dyn BufRead> = match input {
121        Some(path) => {
122            let file =
123                File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
124            Box::new(BufReader::new(file))
125        }
126        None => Box::new(BufReader::new(io::stdin())),
127    };
128
129    let lines: Vec<String> = reader
130        .lines()
131        .collect::<io::Result<Vec<_>>>()
132        .context("failed to read input lines")?;
133
134    Ok(lines)
135}
136
137fn get_repository_url(line: &str) -> Option<String> {
138    let value: Value = serde_json::from_str(line).ok()?;
139    value
140        .get("repository_url")
141        .and_then(|v| v.as_str())
142        .map(|s| s.to_string())
143}
144
145fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
146    let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
147    let mut without_repo: Vec<String> = Vec::new();
148
149    for line in lines {
150        if let Some(repo_url) = get_repository_url(&line) {
151            by_repo.entry(repo_url).or_default().push(line);
152        } else {
153            without_repo.push(line);
154        }
155    }
156
157    (by_repo, without_repo)
158}
159
160fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
161    let mut counts = vec![0usize; specs.len()];
162    let mut remaining = total;
163    let mut rest_index: Option<usize> = None;
164
165    for (i, spec) in specs.iter().enumerate() {
166        match &spec.size {
167            SplitSize::Percentage(pct) => {
168                let count = (total as f64 * pct).round() as usize;
169                counts[i] = count.min(remaining);
170                remaining = remaining.saturating_sub(counts[i]);
171            }
172            SplitSize::Absolute(count) => {
173                counts[i] = (*count).min(remaining);
174                remaining = remaining.saturating_sub(counts[i]);
175            }
176            SplitSize::Rest => {
177                if rest_index.is_some() {
178                    bail!("only one split can use 'rest'");
179                }
180                rest_index = Some(i);
181            }
182        }
183    }
184
185    if let Some(idx) = rest_index {
186        counts[idx] = remaining;
187    }
188
189    Ok(counts)
190}
191
192fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
193    if let Some(parent) = path.parent() {
194        if !parent.as_os_str().is_empty() {
195            std::fs::create_dir_all(parent)
196                .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
197        }
198    }
199
200    let file =
201        File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
202    let mut writer = BufWriter::new(file);
203
204    for line in lines {
205        writeln!(writer, "{}", line)
206            .with_context(|| format!("failed to write to '{}'", path.display()))?;
207    }
208
209    writer
210        .flush()
211        .with_context(|| format!("failed to flush '{}'", path.display()))?;
212
213    Ok(())
214}
215
/// Entry point for `ep split`.
///
/// `inputs` is the raw positional-argument list: an optional input path
/// (the first argument, iff it contains no `=`; `-` means stdin) followed by
/// one or more `<path>=<size>` split specs.
///
/// When any line has a `repository_url`, repositories are shuffled and
/// assigned whole to splits (no repo straddles two outputs); lines without
/// the field are split proportionally on top. Otherwise lines themselves are
/// shuffled and split.
///
/// # Errors
///
/// Fails on bad arguments, unreadable input, or unwritable outputs.
pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
    if inputs.is_empty() {
        bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
    }

    // First arg is the input file iff it contains no '='.
    // NOTE(review): an input path that happens to contain '=' would be
    // misparsed as a split spec — matches the documented CLI shape, but
    // worth confirming this is intended.
    let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
        if inputs.first().is_some_and(|p| {
            let s = p.to_string_lossy();
            !s.contains('=')
        }) {
            let first = inputs.first().map(|p| p.as_path());
            // "-" is the conventional stdin placeholder.
            let first = if first == Some(Path::new("-")) {
                None
            } else {
                first
            };
            (first, &inputs[1..])
        } else {
            (None, inputs)
        };

    if split_specs_raw.is_empty() {
        bail!("no split specifications provided");
    }

    let specs: Vec<SplitSpec> = split_specs_raw
        .iter()
        .map(|p| parse_split_spec(&p.to_string_lossy()))
        .collect::<Result<Vec<_>>>()?;

    let lines = read_lines_from_input(input_path)?;
    let total_lines = lines.len();

    // Empty input: still create every output file (empty), then stop.
    if total_lines == 0 {
        for spec in &specs {
            write_lines_to_file(&spec.path, &[])?;
        }
        return Ok(());
    }

    let (by_repo, without_repo) = group_lines_by_repo(lines);
    let has_repos = !by_repo.is_empty();

    // Progress/diagnostic output goes to stderr so stdout stays clean.
    if has_repos {
        eprintln!(
            "Stratifying by repository_url ({} unique repositories, {} examples)",
            by_repo.len(),
            total_lines - without_repo.len()
        );
        if !without_repo.is_empty() {
            eprintln!(
                "  + {} examples without repository_url (distributed proportionally)",
                without_repo.len()
            );
        }
    }

    // Seeded RNG for reproducible splits; OS entropy otherwise. The exact
    // sequence of shuffle calls below is part of the reproducibility
    // contract for a given seed.
    let mut rng = match args.seed {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];

    if has_repos {
        // Stratified path: shuffle repository keys, then hand out whole
        // repositories to each split. Sizes apply to repo counts here.
        // NOTE: HashMap key order is nondeterministic, so without a shuffle
        // the assignment would still be arbitrary — the shuffle plus seeded
        // RNG is what makes it *reproducibly* arbitrary.
        let mut repos: Vec<String> = by_repo.keys().cloned().collect();
        repos.shuffle(&mut rng);

        let repo_counts = compute_split_counts(&specs, repos.len())?;

        let mut repo_iter = repos.into_iter();
        for (split_idx, &count) in repo_counts.iter().enumerate() {
            for _ in 0..count {
                if let Some(repo) = repo_iter.next() {
                    if let Some(repo_lines) = by_repo.get(&repo) {
                        // All of a repo's examples land in the same split.
                        split_outputs[split_idx].extend(repo_lines.iter().cloned());
                    }
                }
            }
        }

        // Lines lacking repository_url are split proportionally by line
        // count, independent of the repository assignment above.
        if !without_repo.is_empty() {
            let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
            let mut no_repo_shuffled = without_repo;
            no_repo_shuffled.shuffle(&mut rng);

            let mut line_iter = no_repo_shuffled.into_iter();
            for (split_idx, &count) in no_repo_counts.iter().enumerate() {
                for _ in 0..count {
                    if let Some(line) = line_iter.next() {
                        split_outputs[split_idx].push(line);
                    }
                }
            }
        }
    } else {
        // Unstratified path: no line had repository_url, so `without_repo`
        // holds the entire input; shuffle and split by line count.
        let line_counts = compute_split_counts(&specs, total_lines)?;
        let mut shuffled_lines = without_repo;
        shuffled_lines.shuffle(&mut rng);

        let mut line_iter = shuffled_lines.into_iter();
        for (split_idx, &count) in line_counts.iter().enumerate() {
            for _ in 0..count {
                if let Some(line) = line_iter.next() {
                    split_outputs[split_idx].push(line);
                }
            }
        }
    }

    // NOTE(review): if no spec uses `rest` and the sized specs cover less
    // than the total, leftover items are silently dropped here — confirm
    // that is the intended contract.
    for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
        write_lines_to_file(&spec.path, output_lines)?;
        eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
    }

    Ok(())
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write `lines` to a fresh temp file and return its handle (the file is
    /// deleted when the handle drops).
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_parse_split_spec_percentage() {
        let spec = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(spec.path, PathBuf::from("train.jsonl"));
        match spec.size {
            // Stored as a fraction, so 80% -> 0.8.
            SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
            _ => panic!("expected percentage"),
        }
    }

    #[test]
    fn test_parse_split_spec_absolute() {
        let spec = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(spec.path, PathBuf::from("test.jsonl"));
        match spec.size {
            SplitSize::Absolute(n) => assert_eq!(n, 100),
            _ => panic!("expected absolute"),
        }
    }

    #[test]
    fn test_parse_split_spec_rest() {
        let spec = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(spec.size, SplitSize::Rest));
    }

    #[test]
    fn test_get_repository_url() {
        let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
        assert_eq!(
            get_repository_url(line),
            Some("https://github.com/example/repo".to_string())
        );

        // Missing field -> None.
        let line_no_repo = r#"{"data": 123}"#;
        assert_eq!(get_repository_url(line_no_repo), None);
    }

    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Percentage(0.2),
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        // rest = total minus everything assigned to sized specs.
        assert_eq!(counts, vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Absolute(50),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![50, 50]);
    }

    #[test]
    fn test_group_lines_by_repo() {
        let lines = vec![
            r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
            r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
            r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
            r#"{"id": 4}"#.to_string(),
        ];

        let (by_repo, without_repo) = group_lines_by_repo(lines);

        assert_eq!(by_repo.len(), 2);
        assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
        assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
        assert_eq!(without_repo.len(), 1);
    }

    /// End-to-end: 4 repos x 2 examples, 50%/rest split with a fixed seed.
    /// Checks that no example is lost and no repo straddles both outputs.
    #[test]
    fn test_run_split_basic() {
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"repository_url": "repo2", "id": 4}"#,
            r#"{"repository_url": "repo3", "id": 5}"#,
            r#"{"repository_url": "repo3", "id": 6}"#,
            r#"{"repository_url": "repo4", "id": 7}"#,
            r#"{"repository_url": "repo4", "id": 8}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        // Fixed seed so the split is deterministic across runs.
        let args = SplitArgs { seed: Some(42) };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        // Every input example must land in exactly one output.
        assert_eq!(train_lines.len() + valid_lines.len(), 8);

        let train_repos: std::collections::HashSet<_> = train_lines
            .iter()
            .filter_map(|l| get_repository_url(l))
            .collect();
        let valid_repos: std::collections::HashSet<_> = valid_lines
            .iter()
            .filter_map(|l| get_repository_url(l))
            .collect();

        assert!(
            train_repos.is_disjoint(&valid_repos),
            "train and valid should have non-overlapping repos"
        );
    }

    #[test]
    fn test_multiple_rest_fails() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Rest,
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        assert!(compute_split_counts(&specs, 100).is_err());
    }
}