ep: Stratify by cursor_path by default (#50111)

Oleksiy Syvokon created

Also, an absolute size such as `ep split train=100` now targets roughly 100
lines (examples), not 100 groups (repos or cursor_paths).


Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/main.rs          |   5 
crates/edit_prediction_cli/src/score.rs         |  42 +
crates/edit_prediction_cli/src/split_dataset.rs | 318 +++++++++++-------
3 files changed, 227 insertions(+), 138 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -294,6 +294,9 @@ struct EvalArgs {
     /// Path to write summary scores as JSON
     #[clap(long)]
     summary_json: Option<PathBuf>,
+    /// Print all individual example lines (default: up to 20)
+    #[clap(long)]
+    verbose: bool,
 }
 
 #[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
@@ -1238,7 +1241,7 @@ fn main() {
                 match &command {
                     Command::Eval(args) => {
                         let examples = finished_examples.lock().unwrap();
-                        score::print_report(&examples);
+                        score::print_report(&examples, args.verbose);
                         if let Some(summary_path) = &args.summary_json {
                             score::write_summary_json(&examples, summary_path)?;
                         }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -217,7 +217,8 @@ fn compute_cursor_metrics(
     }
 }
 
-pub fn print_report(examples: &[Example]) {
+pub fn print_report(examples: &[Example], verbose: bool) {
+    const MAX_EXAMPLES_DEFAULT: usize = 20;
     use crate::metrics::ClassificationMetrics;
 
     const LINE_WIDTH: usize = 101;
@@ -250,6 +251,9 @@ pub fn print_report(examples: &[Example]) {
     let mut patch_deleted_tokens: Vec<usize> = Vec::new();
     let mut predictions_with_patch: usize = 0;
 
+    let mut printed_lines: usize = 0;
+    let mut skipped_lines: usize = 0;
+
     for example in examples {
         for (score_idx, score) in example.score.iter().enumerate() {
             let exact_lines = ClassificationMetrics {
@@ -284,18 +288,23 @@ pub fn print_report(examples: &[Example]) {
                 (None, _) => "-".to_string(),
             };
 
-            println!(
-                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
-                truncate_name(&example.spec.name, 40),
-                score.delta_chr_f,
-                score.braces_disbalance,
-                exact_lines.f1() * 100.0,
-                score.reversal_ratio * 100.0,
-                qa_reverts_str,
-                qa_conf_str,
-                cursor_str,
-                wrong_er_str
-            );
+            if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
+                println!(
+                    "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
+                    truncate_name(&example.spec.name, 40),
+                    score.delta_chr_f,
+                    score.braces_disbalance,
+                    exact_lines.f1() * 100.0,
+                    score.reversal_ratio * 100.0,
+                    qa_reverts_str,
+                    qa_conf_str,
+                    cursor_str,
+                    wrong_er_str
+                );
+                printed_lines += 1;
+            } else {
+                skipped_lines += 1;
+            }
 
             all_delta_chr_f_scores.push(score.delta_chr_f);
             all_reversal_ratios.push(score.reversal_ratio);
@@ -358,6 +367,13 @@ pub fn print_report(examples: &[Example]) {
         }
     }
 
+    if skipped_lines > 0 {
+        println!(
+            "{:<40} (use --verbose to see all {} examples)",
+            format!("... and {} more", skipped_lines),
+            printed_lines + skipped_lines
+        );
+    }
     println!("{}", separator);
 
     if !all_delta_chr_f_scores.is_empty() {

crates/edit_prediction_cli/src/split_dataset.rs 🔗

@@ -1,29 +1,34 @@
 //! `ep split` implementation.
 //!
 //! This command splits a JSONL dataset into multiple files based on size specifications,
-//! with stratification by repository URL (if the field is present).
+//! with optional stratification by a JSON field.
 //!
 //! # Usage
 //!
 //! ```text
-//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
+//! ep split [--stratify=<field>] [input.jsonl] <out1>=<size1> <out2>=<size2> ...
 //! ```
 //!
 //! If `input.jsonl` is not provided or is `-`, reads from stdin.
 //!
 //! # Size specifications
 //!
-//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
-//! - `100` - absolute count of repositories (if stratified) or examples
+//! - `80%` - percentage of total examples (lines)
+//! - `100` - approximate absolute count of examples (lines)
 //! - `rest` - all remaining items (only one split can use this)
 //!
 //! # Stratification
 //!
-//! When examples have a `repository_url` field, the split is stratified by repository.
-//! This ensures each output file contains examples from non-overlapping repositories.
-//! Size specifications apply to the number of repositories, not individual examples.
+//! The `--stratify` flag controls how examples are grouped before splitting:
 //!
-//! Examples without `repository_url` are distributed proportionally across all outputs.
+//! - `cursor-path` (default): group by the `cursor_path` JSON field
+//! - `repo`: group by the `repository_url` JSON field
+//! - `none`: no grouping, split individual examples
+//!
+//! When stratifying, the split ensures each output file contains examples from
+//! non-overlapping groups. Size specifications always apply to the number of
+//! examples (lines), with whole groups assigned greedily to meet the target.
+//! Each example that lacks a string value for the stratification field forms its own group.
 
 use anyhow::{Context as _, Result, bail};
 use clap::Args;
@@ -38,23 +43,27 @@ use std::path::{Path, PathBuf};
 /// `ep split` CLI args.
 #[derive(Debug, Args, Clone)]
 #[command(
-    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
+    about = "Split a JSONL dataset into multiple files with optional stratification",
     after_help = r#"SIZE SPECIFICATIONS:
   <percentage>%    Percentage of total (e.g., 80%)
   <count>          Absolute number (e.g., 100)
   rest             All remaining items (only one output can use this)
 
-  When stratifying by repository_url, sizes apply to repositories, not examples.
+  Sizes always apply to examples (lines). When stratifying, whole groups
+  are assigned greedily to approximate the target count.
 
 EXAMPLES:
-  # Split 80% train, 20% validation
+  # Split 80% train, 20% validation (default: stratify by cursor_path)
   ep split input.jsonl train.jsonl=80% valid.jsonl=rest
 
   # Split into train/valid/test
   ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest
 
-  # Use absolute counts (100 repos to train, rest to valid)
-  ep split input.jsonl train.jsonl=100 valid.jsonl=rest
+  # Stratify by repository_url instead of cursor_path
+  ep split --stratify=repo input.jsonl train.jsonl=80% valid.jsonl=rest
+
+  # No stratification (split by individual examples)
+  ep split --stratify=none input.jsonl train.jsonl=80% valid.jsonl=rest
 
   # Read from stdin
   cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest
@@ -62,14 +71,15 @@ EXAMPLES:
   # Reproducible split with seed
   ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
 
-  # Disable stratification (split by examples, not repositories)
-  ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest
-
 STRATIFICATION:
-  When examples have a "repository_url" field, the split ensures each output
-  file contains examples from non-overlapping repositories. This prevents
-  data leakage between train/test splits. Use --no-stratify to disable this
-  behavior and split by individual examples instead.
+  Controls how examples are grouped before splitting:
+    cursor-path  Group by "cursor_path" field (default)
+    repo         Group by "repository_url" field
+    none         No grouping, split individual examples
+
+  When stratifying, the split ensures each output file contains examples
+  from non-overlapping groups. This prevents data leakage between
+  train/test splits.
 "#
 )]
 pub struct SplitArgs {
@@ -77,9 +87,19 @@ pub struct SplitArgs {
     #[arg(long)]
     pub seed: Option<u64>,
 
-    /// Disable stratification by repository_url (split by examples instead)
-    #[arg(long)]
-    pub no_stratify: bool,
+    /// Stratification field for splitting the dataset
+    #[arg(long, default_value = "cursor-path")]
+    pub stratify: Stratify,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, strum::Display)]
+pub enum Stratify {
+    #[strum(serialize = "cursor_path")]
+    CursorPath,
+    #[strum(serialize = "repo")]
+    Repo,
+    #[strum(serialize = "none")]
+    None,
 }
 
 #[derive(Debug, Clone)]
@@ -142,29 +162,6 @@ fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
     Ok(lines)
 }
 
-fn get_repository_url(line: &str) -> Option<String> {
-    let value: Value = serde_json::from_str(line).ok()?;
-    value
-        .get("repository_url")
-        .and_then(|v| v.as_str())
-        .map(|s| s.to_string())
-}
-
-fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
-    let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
-    let mut without_repo: Vec<String> = Vec::new();
-
-    for line in lines {
-        if let Some(repo_url) = get_repository_url(&line) {
-            by_repo.entry(repo_url).or_default().push(line);
-        } else {
-            without_repo.push(line);
-        }
-    }
-
-    (by_repo, without_repo)
-}
-
 fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
     let mut counts = vec![0usize; specs.len()];
     let mut remaining = total;
@@ -261,26 +258,20 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
         return Ok(());
     }
 
-    let (by_repo, without_repo) = group_lines_by_repo(lines);
-    let has_repos = !by_repo.is_empty() && !args.no_stratify;
+    let mut grouped_lines = group_lines(&lines, args.stratify);
 
-    if args.no_stratify && !by_repo.is_empty() {
+    if args.stratify != Stratify::None {
         eprintln!(
-            "Stratification disabled (--no-stratify), splitting {} examples by line",
+            "Stratifying by {} ({} unique groups, {} examples)",
+            args.stratify,
+            grouped_lines.len(),
             total_lines
         );
-    } else if has_repos {
+    } else {
         eprintln!(
-            "Stratifying by repository_url ({} unique repositories, {} examples)",
-            by_repo.len(),
-            total_lines - without_repo.len()
+            "No stratification, splitting {} examples by line",
+            total_lines
         );
-        if !without_repo.is_empty() {
-            eprintln!(
-                "  + {} examples without repository_url (distributed proportionally)",
-                without_repo.len()
-            );
-        }
     }
 
     let mut rng = match args.seed {
@@ -288,53 +279,31 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
         None => rand::rngs::StdRng::from_os_rng(),
     };
 
-    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
-
-    if has_repos {
-        let mut repos: Vec<String> = by_repo.keys().cloned().collect();
-        repos.shuffle(&mut rng);
+    grouped_lines.shuffle(&mut rng);
 
-        let repo_counts = compute_split_counts(&specs, repos.len())?;
+    let line_targets = compute_split_counts(&specs, total_lines)?;
+    let rest_index = specs.iter().position(|s| matches!(s.size, SplitSize::Rest));
+    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
+    let mut group_iter = grouped_lines.into_iter();
 
-        let mut repo_iter = repos.into_iter();
-        for (split_idx, &count) in repo_counts.iter().enumerate() {
-            for _ in 0..count {
-                if let Some(repo) = repo_iter.next() {
-                    if let Some(repo_lines) = by_repo.get(&repo) {
-                        split_outputs[split_idx].extend(repo_lines.iter().cloned());
-                    }
-                }
-            }
+    for (split_idx, &target) in line_targets.iter().enumerate() {
+        if Some(split_idx) == rest_index {
+            continue;
         }
-
-        if !without_repo.is_empty() {
-            let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
-            let mut no_repo_shuffled = without_repo;
-            no_repo_shuffled.shuffle(&mut rng);
-
-            let mut line_iter = no_repo_shuffled.into_iter();
-            for (split_idx, &count) in no_repo_counts.iter().enumerate() {
-                for _ in 0..count {
-                    if let Some(line) = line_iter.next() {
-                        split_outputs[split_idx].push(line);
-                    }
-                }
+        let mut accumulated = 0;
+        while accumulated < target {
+            if let Some(group) = group_iter.next() {
+                accumulated += group.len();
+                split_outputs[split_idx].extend(group);
+            } else {
+                break;
             }
         }
-    } else {
-        let line_counts = compute_split_counts(&specs, total_lines)?;
-        let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
-        all_lines.extend(without_repo);
-        all_lines.shuffle(&mut rng);
-
-        let mut line_iter = all_lines.into_iter();
+    }
 
-        for (split_idx, &count) in line_counts.iter().enumerate() {
-            for _ in 0..count {
-                if let Some(line) = line_iter.next() {
-                    split_outputs[split_idx].push(line);
-                }
-            }
+    if let Some(idx) = rest_index {
+        for group in group_iter {
+            split_outputs[idx].extend(group);
         }
     }
 
@@ -346,6 +315,39 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
     Ok(())
 }
 
+/// Groups lines by the specified stratification field.
+///
+/// When `stratify` is `Stratify::None`, each line becomes its own single-line group.
+/// A line that is not valid JSON or lacks a string value for the field also gets its own group.
+fn group_lines(lines: &[String], stratify: Stratify) -> Vec<Vec<String>> {
+    if stratify == Stratify::None {
+        return lines.iter().map(|line| vec![line.clone()]).collect();
+    }
+
+    let field = match stratify {
+        Stratify::Repo => "repository_url",
+        Stratify::CursorPath => "cursor_path",
+        Stratify::None => unreachable!(),
+    };
+
+    let mut groups: HashMap<String, Vec<String>> = HashMap::new();
+    let mut ungrouped: Vec<Vec<String>> = Vec::new();
+
+    for line in lines {
+        let key = serde_json::from_str::<Value>(line)
+            .ok()
+            .and_then(|v| v.get(field)?.as_str().map(|s| s.to_string()));
+        match key {
+            Some(key) => groups.entry(key).or_default().push(line.clone()),
+            None => ungrouped.push(vec![line.clone()]),
+        }
+    }
+
+    let mut result: Vec<Vec<String>> = groups.into_values().collect();
+    result.extend(ungrouped);
+    result
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -389,15 +391,11 @@ mod tests {
     }
 
     #[test]
-    fn test_get_repository_url() {
-        let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
-        assert_eq!(
-            get_repository_url(line),
-            Some("https://github.com/example/repo".to_string())
-        );
-
-        let line_no_repo = r#"{"data": 123}"#;
-        assert_eq!(get_repository_url(line_no_repo), None);
+    fn test_group_lines_none() {
+        let lines = vec!["a".to_string(), "b".to_string(), "c".to_string()];
+        let groups = group_lines(&lines, Stratify::None);
+        assert_eq!(groups.len(), 3);
+        assert!(groups.iter().all(|g| g.len() == 1));
     }
 
     #[test]
@@ -457,12 +455,30 @@ mod tests {
             r#"{"id": 4}"#.to_string(),
         ];
 
-        let (by_repo, without_repo) = group_lines_by_repo(lines);
+        let groups = group_lines(&lines, Stratify::Repo);
+
+        let grouped_count: usize = groups.iter().filter(|g| g.len() > 1).count();
+        let ungrouped_count: usize = groups.iter().filter(|g| g.len() == 1).count();
+        let total_lines: usize = groups.iter().map(|g| g.len()).sum();
 
-        assert_eq!(by_repo.len(), 2);
-        assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
-        assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
-        assert_eq!(without_repo.len(), 1);
+        assert_eq!(grouped_count, 1); // repo1 has 2 lines
+        assert_eq!(ungrouped_count, 2); // repo2 (1 line) + line without repo
+        assert_eq!(total_lines, 4);
+    }
+
+    #[test]
+    fn test_group_lines_by_cursor_path() {
+        let lines = vec![
+            r#"{"cursor_path": "src/main.rs", "id": 1}"#.to_string(),
+            r#"{"cursor_path": "src/main.rs", "id": 2}"#.to_string(),
+            r#"{"cursor_path": "src/lib.rs", "id": 3}"#.to_string(),
+        ];
+
+        let groups = group_lines(&lines, Stratify::CursorPath);
+
+        let total_lines: usize = groups.iter().map(|g| g.len()).sum();
+        assert_eq!(groups.len(), 2);
+        assert_eq!(total_lines, 3);
     }
 
     #[test]
@@ -484,7 +500,7 @@ mod tests {
 
         let args = SplitArgs {
             seed: Some(42),
-            no_stratify: false,
+            stratify: Stratify::Repo,
         };
         let inputs = vec![
             input.path().to_path_buf(),
@@ -502,14 +518,18 @@ mod tests {
 
         assert_eq!(train_lines.len() + valid_lines.len(), 8);
 
-        let train_repos: std::collections::HashSet<_> = train_lines
-            .iter()
-            .filter_map(|l| get_repository_url(l))
-            .collect();
-        let valid_repos: std::collections::HashSet<_> = valid_lines
-            .iter()
-            .filter_map(|l| get_repository_url(l))
-            .collect();
+        let get_repo = |line: &str| -> Option<String> {
+            let value: Value = serde_json::from_str(line).ok()?;
+            value
+                .get("repository_url")
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_string())
+        };
+
+        let train_repos: std::collections::HashSet<_> =
+            train_lines.iter().filter_map(|l| get_repo(l)).collect();
+        let valid_repos: std::collections::HashSet<_> =
+            valid_lines.iter().filter_map(|l| get_repo(l)).collect();
 
         assert!(
             train_repos.is_disjoint(&valid_repos),
@@ -531,4 +551,54 @@ mod tests {
         ];
         assert!(compute_split_counts(&specs, 100).is_err());
     }
+
+    #[test]
+    fn test_absolute_targets_lines_not_groups() {
+        // 5 repos × 3 lines each = 15 total lines.
+        // `train=6` should target ~6 lines (2 groups), NOT 6 groups (all 15 lines).
+        let input = create_temp_jsonl(&[
+            r#"{"repository_url": "r1", "id": 1}"#,
+            r#"{"repository_url": "r1", "id": 2}"#,
+            r#"{"repository_url": "r1", "id": 3}"#,
+            r#"{"repository_url": "r2", "id": 4}"#,
+            r#"{"repository_url": "r2", "id": 5}"#,
+            r#"{"repository_url": "r2", "id": 6}"#,
+            r#"{"repository_url": "r3", "id": 7}"#,
+            r#"{"repository_url": "r3", "id": 8}"#,
+            r#"{"repository_url": "r3", "id": 9}"#,
+            r#"{"repository_url": "r4", "id": 10}"#,
+            r#"{"repository_url": "r4", "id": 11}"#,
+            r#"{"repository_url": "r4", "id": 12}"#,
+            r#"{"repository_url": "r5", "id": 13}"#,
+            r#"{"repository_url": "r5", "id": 14}"#,
+            r#"{"repository_url": "r5", "id": 15}"#,
+        ]);
+
+        let temp_dir = tempfile::tempdir().unwrap();
+        let train_path = temp_dir.path().join("train.jsonl");
+        let valid_path = temp_dir.path().join("valid.jsonl");
+
+        let args = SplitArgs {
+            seed: Some(42),
+            stratify: Stratify::Repo,
+        };
+        let inputs = vec![
+            input.path().to_path_buf(),
+            PathBuf::from(format!("{}=6", train_path.display())),
+            PathBuf::from(format!("{}=rest", valid_path.display())),
+        ];
+
+        run_split(&args, &inputs).unwrap();
+
+        let train_content = std::fs::read_to_string(&train_path).unwrap();
+        let valid_content = std::fs::read_to_string(&valid_path).unwrap();
+
+        let train_lines: Vec<&str> = train_content.lines().collect();
+        let valid_lines: Vec<&str> = valid_content.lines().collect();
+
+        // With 3-line groups, train should get 2 groups (6 lines) to meet the
+        // target of 6, NOT 6 groups (which don't even exist). Valid gets the rest.
+        assert_eq!(train_lines.len(), 6);
+        assert_eq!(valid_lines.len(), 9);
+    }
 }