diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 03f94a4dc47388c9a56169f2be0280af33dc6f1d..a6a0b2e3145cefbe7dd84a88733fe5d865b6364b 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -294,6 +294,9 @@ struct EvalArgs { /// Path to write summary scores as JSON #[clap(long)] summary_json: Option, + /// Print all individual example lines (default: up to 20) + #[clap(long)] + verbose: bool, } #[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)] @@ -1238,7 +1241,7 @@ fn main() { match &command { Command::Eval(args) => { let examples = finished_examples.lock().unwrap(); - score::print_report(&examples); + score::print_report(&examples, args.verbose); if let Some(summary_path) = &args.summary_json { score::write_summary_json(&examples, summary_path)?; } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 8436dc4a4b26206eb41bafd5b9de8645cb0abb5e..b6f745114f6dd2a091b95b724ee53869a04a8c4e 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -217,7 +217,8 @@ fn compute_cursor_metrics( } } -pub fn print_report(examples: &[Example]) { +pub fn print_report(examples: &[Example], verbose: bool) { + const MAX_EXAMPLES_DEFAULT: usize = 20; use crate::metrics::ClassificationMetrics; const LINE_WIDTH: usize = 101; @@ -250,6 +251,9 @@ pub fn print_report(examples: &[Example]) { let mut patch_deleted_tokens: Vec = Vec::new(); let mut predictions_with_patch: usize = 0; + let mut printed_lines: usize = 0; + let mut skipped_lines: usize = 0; + for example in examples { for (score_idx, score) in example.score.iter().enumerate() { let exact_lines = ClassificationMetrics { @@ -284,18 +288,23 @@ pub fn print_report(examples: &[Example]) { (None, _) => "-".to_string(), }; - println!( - "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}", - truncate_name(&example.spec.name, 40), - 
score.delta_chr_f, - score.braces_disbalance, - exact_lines.f1() * 100.0, - score.reversal_ratio * 100.0, - qa_reverts_str, - qa_conf_str, - cursor_str, - wrong_er_str - ); + if verbose || printed_lines < MAX_EXAMPLES_DEFAULT { + println!( + "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}", + truncate_name(&example.spec.name, 40), + score.delta_chr_f, + score.braces_disbalance, + exact_lines.f1() * 100.0, + score.reversal_ratio * 100.0, + qa_reverts_str, + qa_conf_str, + cursor_str, + wrong_er_str + ); + printed_lines += 1; + } else { + skipped_lines += 1; + } all_delta_chr_f_scores.push(score.delta_chr_f); all_reversal_ratios.push(score.reversal_ratio); @@ -358,6 +367,13 @@ pub fn print_report(examples: &[Example]) { } } + if skipped_lines > 0 { + println!( + "{:<40} (use --verbose to see all {} examples)", + format!("... and {} more", skipped_lines), + printed_lines + skipped_lines + ); + } println!("{}", separator); if !all_delta_chr_f_scores.is_empty() { diff --git a/crates/edit_prediction_cli/src/split_dataset.rs b/crates/edit_prediction_cli/src/split_dataset.rs index b34d7c14c6646442359459ef8d4450dae0b9c40e..f1e0a672695cb940f3c368f71fec3b16a64524a1 100644 --- a/crates/edit_prediction_cli/src/split_dataset.rs +++ b/crates/edit_prediction_cli/src/split_dataset.rs @@ -1,29 +1,34 @@ //! `ep split` implementation. //! //! This command splits a JSONL dataset into multiple files based on size specifications, -//! with stratification by repository URL (if the field is present). +//! with optional stratification by a JSON field. //! //! # Usage //! //! ```text -//! ep split [input.jsonl] = = ... +//! ep split [--stratify=] [input.jsonl] = = ... //! ``` //! //! If `input.jsonl` is not provided or is `-`, reads from stdin. //! //! # Size specifications //! -//! - `80%` - percentage of total (repositories if stratified, examples otherwise) -//! - `100` - absolute count of repositories (if stratified) or examples +//! 
- `80%` - percentage of total examples (lines) +//! - `100` - approximate absolute count of examples (lines) //! - `rest` - all remaining items (only one split can use this) //! //! # Stratification //! -//! When examples have a `repository_url` field, the split is stratified by repository. -//! This ensures each output file contains examples from non-overlapping repositories. -//! Size specifications apply to the number of repositories, not individual examples. +//! The `--stratify` flag controls how examples are grouped before splitting: //! -//! Examples without `repository_url` are distributed proportionally across all outputs. +//! - `cursor-path` (default): group by the `cursor_path` JSON field +//! - `repo`: group by the `repository_url` JSON field +//! - `none`: no grouping, split individual examples +//! +//! When stratifying, the split ensures each output file contains examples from +//! non-overlapping groups. Size specifications always apply to the number of +//! examples (lines), with whole groups assigned greedily to meet the target. +//! Examples missing the stratification field are treated as individual groups. use anyhow::{Context as _, Result, bail}; use clap::Args; @@ -38,23 +43,27 @@ use std::path::{Path, PathBuf}; /// `ep split` CLI args. #[derive(Debug, Args, Clone)] #[command( - about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)", + about = "Split a JSONL dataset into multiple files with optional stratification", after_help = r#"SIZE SPECIFICATIONS: % Percentage of total (e.g., 80%) Absolute number (e.g., 100) rest All remaining items (only one output can use this) - When stratifying by repository_url, sizes apply to repositories, not examples. + Sizes always apply to examples (lines). When stratifying, whole groups + are assigned greedily to approximate the target count. 
EXAMPLES: - # Split 80% train, 20% validation + # Split 80% train, 20% validation (default: stratify by cursor_path) ep split input.jsonl train.jsonl=80% valid.jsonl=rest # Split into train/valid/test ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest - # Use absolute counts (100 repos to train, rest to valid) - ep split input.jsonl train.jsonl=100 valid.jsonl=rest + # Stratify by repository_url instead of cursor_path + ep split --stratify=repo input.jsonl train.jsonl=80% valid.jsonl=rest + + # No stratification (split by individual examples) + ep split --stratify=none input.jsonl train.jsonl=80% valid.jsonl=rest # Read from stdin cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest @@ -62,14 +71,15 @@ EXAMPLES: # Reproducible split with seed ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest - # Disable stratification (split by examples, not repositories) - ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest - STRATIFICATION: - When examples have a "repository_url" field, the split ensures each output - file contains examples from non-overlapping repositories. This prevents - data leakage between train/test splits. Use --no-stratify to disable this - behavior and split by individual examples instead. + Controls how examples are grouped before splitting: + cursor-path Group by "cursor_path" field (default) + repo Group by "repository_url" field + none No grouping, split individual examples + + When stratifying, the split ensures each output file contains examples + from non-overlapping groups. This prevents data leakage between + train/test splits. 
"# )] pub struct SplitArgs { @@ -77,9 +87,19 @@ pub struct SplitArgs { #[arg(long)] pub seed: Option, - /// Disable stratification by repository_url (split by examples instead) - #[arg(long)] - pub no_stratify: bool, + /// Stratification field for splitting the dataset + #[arg(long, default_value = "cursor-path")] + pub stratify: Stratify, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, strum::Display)] +pub enum Stratify { + #[strum(serialize = "cursor_path")] + CursorPath, + #[strum(serialize = "repo")] + Repo, + #[strum(serialize = "none")] + None, } #[derive(Debug, Clone)] @@ -142,29 +162,6 @@ fn read_lines_from_input(input: Option<&Path>) -> Result> { Ok(lines) } -fn get_repository_url(line: &str) -> Option { - let value: Value = serde_json::from_str(line).ok()?; - value - .get("repository_url") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) -} - -fn group_lines_by_repo(lines: Vec) -> (HashMap>, Vec) { - let mut by_repo: HashMap> = HashMap::new(); - let mut without_repo: Vec = Vec::new(); - - for line in lines { - if let Some(repo_url) = get_repository_url(&line) { - by_repo.entry(repo_url).or_default().push(line); - } else { - without_repo.push(line); - } - } - - (by_repo, without_repo) -} - fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result> { let mut counts = vec![0usize; specs.len()]; let mut remaining = total; @@ -261,26 +258,20 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> { return Ok(()); } - let (by_repo, without_repo) = group_lines_by_repo(lines); - let has_repos = !by_repo.is_empty() && !args.no_stratify; + let mut grouped_lines = group_lines(&lines, args.stratify); - if args.no_stratify && !by_repo.is_empty() { + if args.stratify != Stratify::None { eprintln!( - "Stratification disabled (--no-stratify), splitting {} examples by line", + "Stratifying by {} ({} unique groups, {} examples)", + args.stratify, + grouped_lines.len(), total_lines ); - } else if has_repos { + } else { 
eprintln!( - "Stratifying by repository_url ({} unique repositories, {} examples)", - by_repo.len(), - total_lines - without_repo.len() + "No stratification, splitting {} examples by line", + total_lines ); - if !without_repo.is_empty() { - eprintln!( - " + {} examples without repository_url (distributed proportionally)", - without_repo.len() - ); - } } let mut rng = match args.seed { @@ -288,53 +279,31 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> { None => rand::rngs::StdRng::from_os_rng(), }; - let mut split_outputs: Vec> = vec![Vec::new(); specs.len()]; - - if has_repos { - let mut repos: Vec = by_repo.keys().cloned().collect(); - repos.shuffle(&mut rng); + grouped_lines.shuffle(&mut rng); - let repo_counts = compute_split_counts(&specs, repos.len())?; + let line_targets = compute_split_counts(&specs, total_lines)?; + let rest_index = specs.iter().position(|s| matches!(s.size, SplitSize::Rest)); + let mut split_outputs: Vec> = vec![Vec::new(); specs.len()]; + let mut group_iter = grouped_lines.into_iter(); - let mut repo_iter = repos.into_iter(); - for (split_idx, &count) in repo_counts.iter().enumerate() { - for _ in 0..count { - if let Some(repo) = repo_iter.next() { - if let Some(repo_lines) = by_repo.get(&repo) { - split_outputs[split_idx].extend(repo_lines.iter().cloned()); - } - } - } + for (split_idx, &target) in line_targets.iter().enumerate() { + if Some(split_idx) == rest_index { + continue; } - - if !without_repo.is_empty() { - let no_repo_counts = compute_split_counts(&specs, without_repo.len())?; - let mut no_repo_shuffled = without_repo; - no_repo_shuffled.shuffle(&mut rng); - - let mut line_iter = no_repo_shuffled.into_iter(); - for (split_idx, &count) in no_repo_counts.iter().enumerate() { - for _ in 0..count { - if let Some(line) = line_iter.next() { - split_outputs[split_idx].push(line); - } - } + let mut accumulated = 0; + while accumulated < target { + if let Some(group) = group_iter.next() { + accumulated += 
group.len(); + split_outputs[split_idx].extend(group); + } else { + break; } } - } else { - let line_counts = compute_split_counts(&specs, total_lines)?; - let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect(); - all_lines.extend(without_repo); - all_lines.shuffle(&mut rng); - - let mut line_iter = all_lines.into_iter(); + } - for (split_idx, &count) in line_counts.iter().enumerate() { - for _ in 0..count { - if let Some(line) = line_iter.next() { - split_outputs[split_idx].push(line); - } - } + if let Some(idx) = rest_index { + for group in group_iter { + split_outputs[idx].extend(group); } } @@ -346,6 +315,39 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> { Ok(()) } +/// Groups lines by the specified stratification field. +/// +/// When `stratify` is `None`, each line becomes its own group. +/// When a line is missing the stratification field, it is also placed in its own group. +fn group_lines(lines: &[String], stratify: Stratify) -> Vec<Vec<String>> { + if stratify == Stratify::None { + return lines.iter().map(|line| vec![line.clone()]).collect(); + } + + let field = match stratify { + Stratify::Repo => "repository_url", + Stratify::CursorPath => "cursor_path", + Stratify::None => unreachable!(), + }; + + let mut groups: HashMap<String, Vec<String>> = HashMap::new(); + let mut ungrouped: Vec<Vec<String>> = Vec::new(); + + for line in lines { + let key = serde_json::from_str::<Value>(line) + .ok() + .and_then(|v| v.get(field)?.as_str().map(|s| s.to_string())); + match key { + Some(key) => groups.entry(key).or_default().push(line.clone()), + None => ungrouped.push(vec![line.clone()]), + } + } + + let mut result: Vec<Vec<String>> = groups.into_values().collect(); + result.extend(ungrouped); + result +} + #[cfg(test)] mod tests { use super::*; @@ -389,15 +391,11 @@ mod tests { } #[test] - fn test_get_repository_url() { - let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#; - assert_eq!( - get_repository_url(line), - 
Some("https://github.com/example/repo".to_string()) - ); - - let line_no_repo = r#"{"data": 123}"#; - assert_eq!(get_repository_url(line_no_repo), None); + fn test_group_lines_none() { + let lines = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let groups = group_lines(&lines, Stratify::None); + assert_eq!(groups.len(), 3); + assert!(groups.iter().all(|g| g.len() == 1)); } #[test] @@ -457,12 +455,30 @@ mod tests { r#"{"id": 4}"#.to_string(), ]; - let (by_repo, without_repo) = group_lines_by_repo(lines); + let groups = group_lines(&lines, Stratify::Repo); + + let grouped_count: usize = groups.iter().filter(|g| g.len() > 1).count(); + let ungrouped_count: usize = groups.iter().filter(|g| g.len() == 1).count(); + let total_lines: usize = groups.iter().map(|g| g.len()).sum(); - assert_eq!(by_repo.len(), 2); - assert_eq!(by_repo.get("repo1").unwrap().len(), 2); - assert_eq!(by_repo.get("repo2").unwrap().len(), 1); - assert_eq!(without_repo.len(), 1); + assert_eq!(grouped_count, 1); // repo1 has 2 lines + assert_eq!(ungrouped_count, 2); // repo2 (1 line) + line without repo + assert_eq!(total_lines, 4); + } + + #[test] + fn test_group_lines_by_cursor_path() { + let lines = vec![ + r#"{"cursor_path": "src/main.rs", "id": 1}"#.to_string(), + r#"{"cursor_path": "src/main.rs", "id": 2}"#.to_string(), + r#"{"cursor_path": "src/lib.rs", "id": 3}"#.to_string(), + ]; + + let groups = group_lines(&lines, Stratify::CursorPath); + + let total_lines: usize = groups.iter().map(|g| g.len()).sum(); + assert_eq!(groups.len(), 2); + assert_eq!(total_lines, 3); } #[test] @@ -484,7 +500,7 @@ mod tests { let args = SplitArgs { seed: Some(42), - no_stratify: false, + stratify: Stratify::Repo, }; let inputs = vec![ input.path().to_path_buf(), @@ -502,14 +518,18 @@ mod tests { assert_eq!(train_lines.len() + valid_lines.len(), 8); - let train_repos: std::collections::HashSet<_> = train_lines - .iter() - .filter_map(|l| get_repository_url(l)) - .collect(); - let valid_repos: 
std::collections::HashSet<_> = valid_lines - .iter() - .filter_map(|l| get_repository_url(l)) - .collect(); + let get_repo = |line: &str| -> Option<String> { + let value: Value = serde_json::from_str(line).ok()?; + value + .get("repository_url") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }; + + let train_repos: std::collections::HashSet<_> = + train_lines.iter().filter_map(|l| get_repo(l)).collect(); + let valid_repos: std::collections::HashSet<_> = + valid_lines.iter().filter_map(|l| get_repo(l)).collect(); assert!( train_repos.is_disjoint(&valid_repos), @@ -531,4 +551,54 @@ mod tests { ]; assert!(compute_split_counts(&specs, 100).is_err()); } + + #[test] + fn test_absolute_targets_lines_not_groups() { + // 5 repos × 3 lines each = 15 total lines. + // `train=6` should target ~6 lines (2 groups), NOT 6 groups (all 15 lines). + let input = create_temp_jsonl(&[ + r#"{"repository_url": "r1", "id": 1}"#, + r#"{"repository_url": "r1", "id": 2}"#, + r#"{"repository_url": "r1", "id": 3}"#, + r#"{"repository_url": "r2", "id": 4}"#, + r#"{"repository_url": "r2", "id": 5}"#, + r#"{"repository_url": "r2", "id": 6}"#, + r#"{"repository_url": "r3", "id": 7}"#, + r#"{"repository_url": "r3", "id": 8}"#, + r#"{"repository_url": "r3", "id": 9}"#, + r#"{"repository_url": "r4", "id": 10}"#, + r#"{"repository_url": "r4", "id": 11}"#, + r#"{"repository_url": "r4", "id": 12}"#, + r#"{"repository_url": "r5", "id": 13}"#, + r#"{"repository_url": "r5", "id": 14}"#, + r#"{"repository_url": "r5", "id": 15}"#, + ]); + + let temp_dir = tempfile::tempdir().unwrap(); + let train_path = temp_dir.path().join("train.jsonl"); + let valid_path = temp_dir.path().join("valid.jsonl"); + + let args = SplitArgs { + seed: Some(42), + stratify: Stratify::Repo, + }; + let inputs = vec![ + input.path().to_path_buf(), + PathBuf::from(format!("{}=6", train_path.display())), + PathBuf::from(format!("{}=rest", valid_path.display())), + ]; + + run_split(&args, &inputs).unwrap(); + + let train_content = 
std::fs::read_to_string(&train_path).unwrap(); + let valid_content = std::fs::read_to_string(&valid_path).unwrap(); + + let train_lines: Vec<&str> = train_content.lines().collect(); + let valid_lines: Vec<&str> = valid_content.lines().collect(); + + // With 3-line groups, train should get 2 groups (6 lines) to meet the + // target of 6, NOT 6 groups (which don't even exist). Valid gets the rest. + assert_eq!(train_lines.len(), 6); + assert_eq!(valid_lines.len(), 9); + } }