Add `ep split --no-stratify` option (#48036)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/split_dataset.rs | 32 ++++++++++++++----
1 file changed, 25 insertions(+), 7 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/split_dataset.rs 🔗

@@ -62,16 +62,24 @@ EXAMPLES:
   # Reproducible split with seed
   ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
 
+  # Disable stratification (split by examples, not repositories)
+  ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest
+
 STRATIFICATION:
   When examples have a "repository_url" field, the split ensures each output
   file contains examples from non-overlapping repositories. This prevents
-  data leakage between train/test splits.
+  data leakage between train/test splits. Use --no-stratify to disable this
+  behavior and split by individual examples instead.
 "#
 )]
 pub struct SplitArgs {
     /// Random seed for reproducibility
     #[arg(long)]
     pub seed: Option<u64>,
+
+    /// Disable stratification by repository_url (split by examples instead)
+    #[arg(long)]
+    pub no_stratify: bool,
 }
 
 #[derive(Debug, Clone)]
@@ -254,9 +262,14 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
     }
 
     let (by_repo, without_repo) = group_lines_by_repo(lines);
-    let has_repos = !by_repo.is_empty();
+    let has_repos = !by_repo.is_empty() && !args.no_stratify;
 
-    if has_repos {
+    if args.no_stratify && !by_repo.is_empty() {
+        eprintln!(
+            "Stratification disabled (--no-stratify), splitting {} examples by line",
+            total_lines
+        );
+    } else if has_repos {
         eprintln!(
             "Stratifying by repository_url ({} unique repositories, {} examples)",
             by_repo.len(),
@@ -310,10 +323,12 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
         }
     } else {
         let line_counts = compute_split_counts(&specs, total_lines)?;
-        let mut shuffled_lines = without_repo;
-        shuffled_lines.shuffle(&mut rng);
+        let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
+        all_lines.extend(without_repo);
+        all_lines.shuffle(&mut rng);
+
+        let mut line_iter = all_lines.into_iter();
 
-        let mut line_iter = shuffled_lines.into_iter();
         for (split_idx, &count) in line_counts.iter().enumerate() {
             for _ in 0..count {
                 if let Some(line) = line_iter.next() {
@@ -467,7 +482,10 @@ mod tests {
         let train_path = temp_dir.path().join("train.jsonl");
         let valid_path = temp_dir.path().join("valid.jsonl");
 
-        let args = SplitArgs { seed: Some(42) };
+        let args = SplitArgs {
+            seed: Some(42),
+            no_stratify: false,
+        };
         let inputs = vec![
             input.path().to_path_buf(),
             PathBuf::from(format!("{}=50%", train_path.display())),