@@ -217,7 +217,8 @@ fn compute_cursor_metrics(
}
}
-pub fn print_report(examples: &[Example]) {
+pub fn print_report(examples: &[Example], verbose: bool) {
+ const MAX_EXAMPLES_DEFAULT: usize = 20;
use crate::metrics::ClassificationMetrics;
const LINE_WIDTH: usize = 101;
@@ -250,6 +251,9 @@ pub fn print_report(examples: &[Example]) {
let mut patch_deleted_tokens: Vec<usize> = Vec::new();
let mut predictions_with_patch: usize = 0;
+ let mut printed_lines: usize = 0;
+ let mut skipped_lines: usize = 0;
+
for example in examples {
for (score_idx, score) in example.score.iter().enumerate() {
let exact_lines = ClassificationMetrics {
@@ -284,18 +288,23 @@ pub fn print_report(examples: &[Example]) {
(None, _) => "-".to_string(),
};
- println!(
- "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
- truncate_name(&example.spec.name, 40),
- score.delta_chr_f,
- score.braces_disbalance,
- exact_lines.f1() * 100.0,
- score.reversal_ratio * 100.0,
- qa_reverts_str,
- qa_conf_str,
- cursor_str,
- wrong_er_str
- );
+ if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
+ println!(
+ "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
+ truncate_name(&example.spec.name, 40),
+ score.delta_chr_f,
+ score.braces_disbalance,
+ exact_lines.f1() * 100.0,
+ score.reversal_ratio * 100.0,
+ qa_reverts_str,
+ qa_conf_str,
+ cursor_str,
+ wrong_er_str
+ );
+ printed_lines += 1;
+ } else {
+ skipped_lines += 1;
+ }
all_delta_chr_f_scores.push(score.delta_chr_f);
all_reversal_ratios.push(score.reversal_ratio);
@@ -358,6 +367,13 @@ pub fn print_report(examples: &[Example]) {
}
}
+ if skipped_lines > 0 {
+ println!(
+ "{:<40} (use --verbose to see all {} examples)",
+ format!("... and {} more", skipped_lines),
+ printed_lines + skipped_lines
+ );
+ }
println!("{}", separator);
if !all_delta_chr_f_scores.is_empty() {
@@ -1,29 +1,34 @@
//! `ep split` implementation.
//!
//! This command splits a JSONL dataset into multiple files based on size specifications,
-//! with stratification by repository URL (if the field is present).
+//! with optional stratification by a JSON field.
//!
//! # Usage
//!
//! ```text
-//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
+//! ep split [--stratify=<field>] [input.jsonl] <out1>=<size1> <out2>=<size2> ...
//! ```
//!
//! If `input.jsonl` is not provided or is `-`, reads from stdin.
//!
//! # Size specifications
//!
-//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
-//! - `100` - absolute count of repositories (if stratified) or examples
+//! - `80%` - percentage of total examples (lines)
+//! - `100` - approximate absolute count of examples (lines)
//! - `rest` - all remaining items (only one split can use this)
//!
//! # Stratification
//!
-//! When examples have a `repository_url` field, the split is stratified by repository.
-//! This ensures each output file contains examples from non-overlapping repositories.
-//! Size specifications apply to the number of repositories, not individual examples.
+//! The `--stratify` flag controls how examples are grouped before splitting:
//!
-//! Examples without `repository_url` are distributed proportionally across all outputs.
+//! - `cursor-path` (default): group by the `cursor_path` JSON field
+//! - `repo`: group by the `repository_url` JSON field
+//! - `none`: no grouping, split individual examples
+//!
+//! When stratifying, the split ensures each output file contains examples from
+//! non-overlapping groups. Size specifications always apply to the number of
+//! examples (lines); whole groups are assigned greedily until the target is met or exceeded.
+//! Lines that are not valid JSON or lack a string-valued field each form their own group.
use anyhow::{Context as _, Result, bail};
use clap::Args;
@@ -38,23 +43,27 @@ use std::path::{Path, PathBuf};
/// `ep split` CLI args.
#[derive(Debug, Args, Clone)]
#[command(
- about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
+ about = "Split a JSONL dataset into multiple files with optional stratification",
after_help = r#"SIZE SPECIFICATIONS:
<percentage>% Percentage of total (e.g., 80%)
<count> Absolute number (e.g., 100)
rest All remaining items (only one output can use this)
- When stratifying by repository_url, sizes apply to repositories, not examples.
+ Sizes always apply to examples (lines). When stratifying, whole groups
+ are assigned greedily to approximate the target count.
EXAMPLES:
- # Split 80% train, 20% validation
+ # Split 80% train, 20% validation (default: stratify by cursor_path)
ep split input.jsonl train.jsonl=80% valid.jsonl=rest
# Split into train/valid/test
ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest
- # Use absolute counts (100 repos to train, rest to valid)
- ep split input.jsonl train.jsonl=100 valid.jsonl=rest
+ # Stratify by repository_url instead of cursor_path
+ ep split --stratify=repo input.jsonl train.jsonl=80% valid.jsonl=rest
+
+ # No stratification (split by individual examples)
+ ep split --stratify=none input.jsonl train.jsonl=80% valid.jsonl=rest
# Read from stdin
cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest
@@ -62,14 +71,15 @@ EXAMPLES:
# Reproducible split with seed
ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
- # Disable stratification (split by examples, not repositories)
- ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest
-
STRATIFICATION:
- When examples have a "repository_url" field, the split ensures each output
- file contains examples from non-overlapping repositories. This prevents
- data leakage between train/test splits. Use --no-stratify to disable this
- behavior and split by individual examples instead.
+ Controls how examples are grouped before splitting:
+ cursor-path Group by "cursor_path" field (default)
+ repo Group by "repository_url" field
+ none No grouping, split individual examples
+
+ When stratifying, the split ensures each output file contains examples
+ from non-overlapping groups. This prevents data leakage between
+ train/test splits.
"#
)]
pub struct SplitArgs {
@@ -77,9 +87,19 @@ pub struct SplitArgs {
#[arg(long)]
pub seed: Option<u64>,
- /// Disable stratification by repository_url (split by examples instead)
- #[arg(long)]
- pub no_stratify: bool,
+ /// Stratification field for splitting the dataset
+ #[arg(long, default_value = "cursor-path")]
+ pub stratify: Stratify,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, strum::Display)]
+pub enum Stratify {
+ #[strum(serialize = "cursor_path")]
+ CursorPath,
+ #[strum(serialize = "repo")]
+ Repo,
+ #[strum(serialize = "none")]
+ None,
}
#[derive(Debug, Clone)]
@@ -142,29 +162,6 @@ fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
Ok(lines)
}
-fn get_repository_url(line: &str) -> Option<String> {
- let value: Value = serde_json::from_str(line).ok()?;
- value
- .get("repository_url")
- .and_then(|v| v.as_str())
- .map(|s| s.to_string())
-}
-
-fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
- let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
- let mut without_repo: Vec<String> = Vec::new();
-
- for line in lines {
- if let Some(repo_url) = get_repository_url(&line) {
- by_repo.entry(repo_url).or_default().push(line);
- } else {
- without_repo.push(line);
- }
- }
-
- (by_repo, without_repo)
-}
-
fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
let mut counts = vec![0usize; specs.len()];
let mut remaining = total;
@@ -261,26 +258,20 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
return Ok(());
}
- let (by_repo, without_repo) = group_lines_by_repo(lines);
- let has_repos = !by_repo.is_empty() && !args.no_stratify;
+ let mut grouped_lines = group_lines(&lines, args.stratify);
- if args.no_stratify && !by_repo.is_empty() {
+ if args.stratify != Stratify::None {
eprintln!(
- "Stratification disabled (--no-stratify), splitting {} examples by line",
+ "Stratifying by {} ({} unique groups, {} examples)",
+ args.stratify,
+ grouped_lines.len(),
total_lines
);
- } else if has_repos {
+ } else {
eprintln!(
- "Stratifying by repository_url ({} unique repositories, {} examples)",
- by_repo.len(),
- total_lines - without_repo.len()
+ "No stratification, splitting {} examples by line",
+ total_lines
);
- if !without_repo.is_empty() {
- eprintln!(
- " + {} examples without repository_url (distributed proportionally)",
- without_repo.len()
- );
- }
}
let mut rng = match args.seed {
@@ -288,53 +279,31 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
None => rand::rngs::StdRng::from_os_rng(),
};
- let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
-
- if has_repos {
- let mut repos: Vec<String> = by_repo.keys().cloned().collect();
- repos.shuffle(&mut rng);
+ grouped_lines.shuffle(&mut rng);
- let repo_counts = compute_split_counts(&specs, repos.len())?;
+ let line_targets = compute_split_counts(&specs, total_lines)?;
+ let rest_index = specs.iter().position(|s| matches!(s.size, SplitSize::Rest));
+ let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
+ let mut group_iter = grouped_lines.into_iter();
- let mut repo_iter = repos.into_iter();
- for (split_idx, &count) in repo_counts.iter().enumerate() {
- for _ in 0..count {
- if let Some(repo) = repo_iter.next() {
- if let Some(repo_lines) = by_repo.get(&repo) {
- split_outputs[split_idx].extend(repo_lines.iter().cloned());
- }
- }
- }
+ for (split_idx, &target) in line_targets.iter().enumerate() {
+ if Some(split_idx) == rest_index {
+ continue;
}
-
- if !without_repo.is_empty() {
- let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
- let mut no_repo_shuffled = without_repo;
- no_repo_shuffled.shuffle(&mut rng);
-
- let mut line_iter = no_repo_shuffled.into_iter();
- for (split_idx, &count) in no_repo_counts.iter().enumerate() {
- for _ in 0..count {
- if let Some(line) = line_iter.next() {
- split_outputs[split_idx].push(line);
- }
- }
+ let mut accumulated = 0;
+ while accumulated < target {
+ if let Some(group) = group_iter.next() {
+ accumulated += group.len();
+ split_outputs[split_idx].extend(group);
+ } else {
+ break;
}
}
- } else {
- let line_counts = compute_split_counts(&specs, total_lines)?;
- let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
- all_lines.extend(without_repo);
- all_lines.shuffle(&mut rng);
-
- let mut line_iter = all_lines.into_iter();
+ }
- for (split_idx, &count) in line_counts.iter().enumerate() {
- for _ in 0..count {
- if let Some(line) = line_iter.next() {
- split_outputs[split_idx].push(line);
- }
- }
+ if let Some(idx) = rest_index {
+ for group in group_iter {
+ split_outputs[idx].extend(group);
}
}
@@ -346,6 +315,39 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
Ok(())
}
+/// Groups lines by the specified stratification field.
+///
+/// When `stratify` is `None`, each line becomes its own group. A line that is not valid
+/// JSON, or that lacks a string value for the field, is also placed in its own group.
+///
+/// Groups are returned in first-appearance order, not `HashMap` iteration order (which is
+/// randomized per process), so a seeded shuffle of the result is reproducible across runs.
+fn group_lines(lines: &[String], stratify: Stratify) -> Vec<Vec<String>> {
+    let field = match stratify {
+        Stratify::Repo => "repository_url",
+        Stratify::CursorPath => "cursor_path",
+        Stratify::None => return lines.iter().map(|line| vec![line.clone()]).collect(),
+    };
+
+    // Maps each key to the index of its group in `result`.
+    let mut group_index: HashMap<String, usize> = HashMap::new();
+    let mut result: Vec<Vec<String>> = Vec::new();
+    for line in lines {
+        let key = serde_json::from_str::<Value>(line)
+            .ok()
+            .and_then(|v| v.get(field)?.as_str().map(|s| s.to_string()));
+        let slot = match key {
+            Some(key) => *group_index.entry(key).or_insert(result.len()),
+            None => result.len(), // no key: the line gets a group of its own
+        };
+        if slot == result.len() {
+            result.push(Vec::new());
+        }
+        result[slot].push(line.clone());
+    }
+    result
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -389,15 +391,11 @@ mod tests {
}
#[test]
- fn test_get_repository_url() {
- let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
- assert_eq!(
- get_repository_url(line),
- Some("https://github.com/example/repo".to_string())
- );
-
- let line_no_repo = r#"{"data": 123}"#;
- assert_eq!(get_repository_url(line_no_repo), None);
+ fn test_group_lines_none() {
+ let lines = vec!["a".to_string(), "b".to_string(), "c".to_string()];
+ let groups = group_lines(&lines, Stratify::None);
+ assert_eq!(groups.len(), 3);
+ assert!(groups.iter().all(|g| g.len() == 1));
}
#[test]
@@ -457,12 +455,30 @@ mod tests {
r#"{"id": 4}"#.to_string(),
];
- let (by_repo, without_repo) = group_lines_by_repo(lines);
+ let groups = group_lines(&lines, Stratify::Repo);
+
+ let grouped_count: usize = groups.iter().filter(|g| g.len() > 1).count();
+ let ungrouped_count: usize = groups.iter().filter(|g| g.len() == 1).count();
+ let total_lines: usize = groups.iter().map(|g| g.len()).sum();
- assert_eq!(by_repo.len(), 2);
- assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
- assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
- assert_eq!(without_repo.len(), 1);
+ assert_eq!(grouped_count, 1); // repo1 has 2 lines
+ assert_eq!(ungrouped_count, 2); // repo2 (1 line) + line without repo
+ assert_eq!(total_lines, 4);
+ }
+
+ #[test]
+ fn test_group_lines_by_cursor_path() {
+ let lines = vec![
+ r#"{"cursor_path": "src/main.rs", "id": 1}"#.to_string(),
+ r#"{"cursor_path": "src/main.rs", "id": 2}"#.to_string(),
+ r#"{"cursor_path": "src/lib.rs", "id": 3}"#.to_string(),
+ ];
+
+ let groups = group_lines(&lines, Stratify::CursorPath);
+
+ let total_lines: usize = groups.iter().map(|g| g.len()).sum();
+ assert_eq!(groups.len(), 2);
+ assert_eq!(total_lines, 3);
}
#[test]
@@ -484,7 +500,7 @@ mod tests {
let args = SplitArgs {
seed: Some(42),
- no_stratify: false,
+ stratify: Stratify::Repo,
};
let inputs = vec![
input.path().to_path_buf(),
@@ -502,14 +518,18 @@ mod tests {
assert_eq!(train_lines.len() + valid_lines.len(), 8);
- let train_repos: std::collections::HashSet<_> = train_lines
- .iter()
- .filter_map(|l| get_repository_url(l))
- .collect();
- let valid_repos: std::collections::HashSet<_> = valid_lines
- .iter()
- .filter_map(|l| get_repository_url(l))
- .collect();
+ let get_repo = |line: &str| -> Option<String> {
+ let value: Value = serde_json::from_str(line).ok()?;
+ value
+ .get("repository_url")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string())
+ };
+
+ let train_repos: std::collections::HashSet<_> =
+ train_lines.iter().filter_map(|l| get_repo(l)).collect();
+ let valid_repos: std::collections::HashSet<_> =
+ valid_lines.iter().filter_map(|l| get_repo(l)).collect();
assert!(
train_repos.is_disjoint(&valid_repos),
@@ -531,4 +551,54 @@ mod tests {
];
assert!(compute_split_counts(&specs, 100).is_err());
}
+
+ #[test]
+ fn test_absolute_targets_lines_not_groups() {
+ // 5 repos × 3 lines each = 15 total lines.
+ // `train=6` should target ~6 lines (2 groups), NOT 6 groups (all 15 lines).
+ let input = create_temp_jsonl(&[
+ r#"{"repository_url": "r1", "id": 1}"#,
+ r#"{"repository_url": "r1", "id": 2}"#,
+ r#"{"repository_url": "r1", "id": 3}"#,
+ r#"{"repository_url": "r2", "id": 4}"#,
+ r#"{"repository_url": "r2", "id": 5}"#,
+ r#"{"repository_url": "r2", "id": 6}"#,
+ r#"{"repository_url": "r3", "id": 7}"#,
+ r#"{"repository_url": "r3", "id": 8}"#,
+ r#"{"repository_url": "r3", "id": 9}"#,
+ r#"{"repository_url": "r4", "id": 10}"#,
+ r#"{"repository_url": "r4", "id": 11}"#,
+ r#"{"repository_url": "r4", "id": 12}"#,
+ r#"{"repository_url": "r5", "id": 13}"#,
+ r#"{"repository_url": "r5", "id": 14}"#,
+ r#"{"repository_url": "r5", "id": 15}"#,
+ ]);
+
+ let temp_dir = tempfile::tempdir().unwrap();
+ let train_path = temp_dir.path().join("train.jsonl");
+ let valid_path = temp_dir.path().join("valid.jsonl");
+
+ let args = SplitArgs {
+ seed: Some(42),
+ stratify: Stratify::Repo,
+ };
+ let inputs = vec![
+ input.path().to_path_buf(),
+ PathBuf::from(format!("{}=6", train_path.display())),
+ PathBuf::from(format!("{}=rest", valid_path.display())),
+ ];
+
+ run_split(&args, &inputs).unwrap();
+
+ let train_content = std::fs::read_to_string(&train_path).unwrap();
+ let valid_content = std::fs::read_to_string(&valid_path).unwrap();
+
+ let train_lines: Vec<&str> = train_content.lines().collect();
+ let valid_lines: Vec<&str> = valid_content.lines().collect();
+
+ // With 3-line groups, train should get 2 groups (6 lines) to meet the
+ // target of 6, NOT 6 groups (which don't even exist). Valid gets the rest.
+ assert_eq!(train_lines.len(), 6);
+ assert_eq!(valid_lines.len(), 9);
+ }
}