From 7aa8c742d8480b3cc8cca36bffb99bf561e37b77 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Fri, 30 Jan 2026 18:59:46 +0200 Subject: [PATCH] Add `ep split --no-stratify` option (#48036) Release Notes: - N/A --- .../edit_prediction_cli/src/split_dataset.rs | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/crates/edit_prediction_cli/src/split_dataset.rs b/crates/edit_prediction_cli/src/split_dataset.rs index d4ff39be87d341ec1204bb6f668006050d1afac3..b34d7c14c6646442359459ef8d4450dae0b9c40e 100644 --- a/crates/edit_prediction_cli/src/split_dataset.rs +++ b/crates/edit_prediction_cli/src/split_dataset.rs @@ -62,16 +62,24 @@ EXAMPLES: # Reproducible split with seed ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest + # Disable stratification (split by examples, not repositories) + ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest + STRATIFICATION: When examples have a "repository_url" field, the split ensures each output file contains examples from non-overlapping repositories. This prevents - data leakage between train/test splits. + data leakage between train/test splits. Use --no-stratify to disable this + behavior and split by individual examples instead. "# )] pub struct SplitArgs { /// Random seed for reproducibility #[arg(long)] pub seed: Option, + + /// Disable stratification by repository_url (split by examples instead) + #[arg(long)] + pub no_stratify: bool, } #[derive(Debug, Clone)] @@ -254,9 +262,14 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> { } let (by_repo, without_repo) = group_lines_by_repo(lines); - let has_repos = !by_repo.is_empty(); + let has_repos = !by_repo.is_empty() && !args.no_stratify; - if has_repos { + if args.no_stratify && !by_repo.is_empty() { + eprintln!( + "Stratification disabled (--no-stratify), splitting {} examples by line", + total_lines + ); + } else if has_repos { eprintln!( "Stratifying by repository_url ({} unique repositories, {} examples)", by_repo.len(), @@ -310,10 +323,12 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> { } } else { let line_counts = compute_split_counts(&specs, total_lines)?; - let mut shuffled_lines = without_repo; - shuffled_lines.shuffle(&mut rng); + let mut all_lines: Vec = by_repo.into_values().flatten().collect(); + all_lines.extend(without_repo); + all_lines.shuffle(&mut rng); + + let mut line_iter = all_lines.into_iter(); - let mut line_iter = shuffled_lines.into_iter(); for (split_idx, &count) in line_counts.iter().enumerate() { for _ in 0..count { if let Some(line) = line_iter.next() { @@ -467,7 +482,10 @@ mod tests { let train_path = temp_dir.path().join("train.jsonl"); let valid_path = temp_dir.path().join("valid.jsonl"); - let args = SplitArgs { seed: Some(42) }; + let args = SplitArgs { + seed: Some(42), + no_stratify: false, + }; let inputs = vec![ input.path().to_path_buf(), PathBuf::from(format!("{}=50%", train_path.display())),