diff --git a/Cargo.lock b/Cargo.lock
index 611165099c6e615174f437f0768f00850b43f14f..ccf1edcdbf85a11e5c874813a229ada281071fbc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5308,6 +5308,7 @@ dependencies = [
  "smol",
  "sqlez",
  "sqlez_macros",
+ "tempfile",
  "terminal_view",
  "util",
  "wasmtime",
diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml
index 36f264c70ed579865b3af6f25ac1d6690c89603d..874f344865e2517f3622f44ae0e08462cf40df5a 100644
--- a/crates/edit_prediction_cli/Cargo.toml
+++ b/crates/edit_prediction_cli/Cargo.toml
@@ -71,3 +71,4 @@ indoc.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 project = { workspace = true, features = ["test-support"] }
 pretty_assertions.workspace = true
+tempfile.workspace = true
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index c12c2ae2aadb1f308da4db9a5682379392287223..e64246a2af7fb278a6ff5f0d6bfa1db6943d64fe 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -14,6 +14,7 @@ mod reorder_patch;
 mod retrieve_context;
 mod score;
 mod split_commit;
+mod split_dataset;
 mod synthesize;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 use collections::HashSet;
@@ -40,6 +41,7 @@ use crate::progress::Progress;
 use crate::retrieve_context::run_context_retrieval;
 use crate::score::run_scoring;
 use crate::split_commit::SplitCommitArgs;
+use crate::split_dataset::SplitArgs;
 use crate::synthesize::{SynthesizeConfig, run_synthesize};
 
 #[derive(Parser, Debug)]
@@ -124,6 +126,8 @@ enum Command {
     Clean,
     /// Generate an evaluation example by splitting a chronologically-ordered commit
     SplitCommit(SplitCommitArgs),
+    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
+    Split(SplitArgs),
 }
 
 impl Display for Command {
@@ -178,6 +182,7 @@
             }
             Command::Clean => write!(f, "clean"),
             Command::SplitCommit(_) => write!(f, "split-commit"),
+            Command::Split(_) => write!(f, "split"),
         }
     }
 }
@@ -416,6 +421,13 @@ fn main() {
             }
             return;
         }
+        Command::Split(split_args) => {
+            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
+                eprintln!("{error:#}");
+                std::process::exit(1);
+            }
+            return;
+        }
         _ => {}
     }
 
@@ -509,7 +521,8 @@
         }
         Command::Clean
         | Command::Synthesize(_)
-        | Command::SplitCommit(_) => {
+        | Command::SplitCommit(_)
+        | Command::Split(_) => {
             unreachable!()
         }
 
@@ -603,17 +616,12 @@ async fn handle_error(
         indoc::indoc! {"
             While processing \"{}\":
 
-            {:?}
-
-            Written to: \x1b[36m{}\x1b[0m
-
-            Cursor File: \x1b[36m{}\x1b[0m
-
-            Explore this example data with:
-            fx \x1b[36m{}\x1b[0m
+            \x1b[31m{:?}\x1b[0m
 
-            Re-run this example with:
-            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
+            Example: \x1b[36m{}\x1b[0m
+            Error file: \x1b[36m{}\x1b[0m
+            Cursor file: \x1b[36m{}\x1b[0m
+            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
         "},
         example.spec.name,
         error,
diff --git a/crates/edit_prediction_cli/src/split_dataset.rs b/crates/edit_prediction_cli/src/split_dataset.rs
new file mode 100644
index 0000000000000000000000000000000000000000..d4ff39be87d341ec1204bb6f668006050d1afac3
--- /dev/null
+++ b/crates/edit_prediction_cli/src/split_dataset.rs
@@ -0,0 +1,516 @@
+//! `ep split` implementation.
+//!
+//! This command splits a JSONL dataset into multiple files based on size specifications,
+//! with stratification by repository URL (if the field is present).
+//!
+//! # Usage
+//!
+//! ```text
+//! ep split [input.jsonl] <path1>=<size1> <path2>=<size2> ...
+//! ```
+//!
+//! If `input.jsonl` is not provided or is `-`, reads from stdin.
+//!
+//! # Size specifications
+//!
+//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
+//! - `100` - absolute count of repositories (if stratified) or examples
+//! - `rest` - all remaining items (only one split can use this)
+//!
+//! # Stratification
+//!
+//! When examples have a `repository_url` field, the split is stratified by repository.
+//! This ensures each output file contains examples from non-overlapping repositories.
+//! Size specifications apply to the number of repositories, not individual examples.
+//!
+//! Examples without `repository_url` are distributed proportionally across all outputs.
+
+use anyhow::{Context as _, Result, bail};
+use clap::Args;
+use rand::SeedableRng;
+use rand::seq::SliceRandom;
+use serde_json::Value;
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::{self, BufRead, BufReader, BufWriter, Write};
+use std::path::{Path, PathBuf};
+
+/// `ep split` CLI args.
+#[derive(Debug, Args, Clone)]
+#[command(
+    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
+    after_help = r#"SIZE SPECIFICATIONS:
+  <N>%   Percentage of total (e.g., 80%)
+  <N>    Absolute number (e.g., 100)
+  rest   All remaining items (only one output can use this)
+
+  When stratifying by repository_url, sizes apply to repositories, not examples.
+
+EXAMPLES:
+  # Split 80% train, 20% validation
+  ep split input.jsonl train.jsonl=80% valid.jsonl=rest
+
+  # Split into train/valid/test
+  ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest
+
+  # Use absolute counts (100 repos to train, rest to valid)
+  ep split input.jsonl train.jsonl=100 valid.jsonl=rest
+
+  # Read from stdin
+  cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest
+
+  # Reproducible split with seed
+  ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
+
+STRATIFICATION:
+  When examples have a "repository_url" field, the split ensures each output
+  file contains examples from non-overlapping repositories. This prevents
+  data leakage between train/test splits.
+"# +)] +pub struct SplitArgs { + /// Random seed for reproducibility + #[arg(long)] + pub seed: Option, +} + +#[derive(Debug, Clone)] +pub enum SplitSize { + Percentage(f64), + Absolute(usize), + Rest, +} + +#[derive(Debug, Clone)] +pub struct SplitSpec { + pub path: PathBuf, + pub size: SplitSize, +} + +fn parse_split_spec(spec: &str) -> Result { + let (path, size_str) = spec + .rsplit_once('=') + .with_context(|| format!("invalid split spec '{}': expected =", spec))?; + + let size = if size_str == "rest" { + SplitSize::Rest + } else if size_str.ends_with('%') { + let pct_str = size_str.trim_end_matches('%'); + let pct: f64 = pct_str + .parse() + .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?; + if !(0.0..=100.0).contains(&pct) { + bail!("percentage must be between 0 and 100, got {}", pct); + } + SplitSize::Percentage(pct / 100.0) + } else { + let count: usize = size_str + .parse() + .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?; + SplitSize::Absolute(count) + }; + + Ok(SplitSpec { + path: PathBuf::from(path), + size, + }) +} + +fn read_lines_from_input(input: Option<&Path>) -> Result> { + let reader: Box = match input { + Some(path) => { + let file = + File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?; + Box::new(BufReader::new(file)) + } + None => Box::new(BufReader::new(io::stdin())), + }; + + let lines: Vec = reader + .lines() + .collect::>>() + .context("failed to read input lines")?; + + Ok(lines) +} + +fn get_repository_url(line: &str) -> Option { + let value: Value = serde_json::from_str(line).ok()?; + value + .get("repository_url") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) +} + +fn group_lines_by_repo(lines: Vec) -> (HashMap>, Vec) { + let mut by_repo: HashMap> = HashMap::new(); + let mut without_repo: Vec = Vec::new(); + + for line in lines { + if let Some(repo_url) = get_repository_url(&line) { + by_repo.entry(repo_url).or_default().push(line); + } else { + without_repo.push(line); + } + } + + (by_repo, without_repo) +} + +fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result> { + let mut counts = vec![0usize; specs.len()]; + let mut remaining = total; + let mut rest_index: Option = None; + + for (i, spec) in specs.iter().enumerate() { + match &spec.size { + SplitSize::Percentage(pct) => { + let count = (total as f64 * pct).round() as usize; + counts[i] = count.min(remaining); + remaining = remaining.saturating_sub(counts[i]); + } + SplitSize::Absolute(count) => { + counts[i] = (*count).min(remaining); + remaining = remaining.saturating_sub(counts[i]); + } + SplitSize::Rest => { + if rest_index.is_some() { + bail!("only one split can use 'rest'"); + } + rest_index = Some(i); + } + } + } + + if let Some(idx) = rest_index { + counts[idx] = remaining; + } + + Ok(counts) +} + +fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> { + if let Some(parent) = path.parent() { + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent) + .with_context(|| format!("failed to create directory '{}'", parent.display()))?; + } + } + + let file = + File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?; + let mut writer = BufWriter::new(file); + + for line in lines { + writeln!(writer, "{}", line) + .with_context(|| format!("failed to write to '{}'", path.display()))?; + } + + writer + .flush() + .with_context(|| format!("failed to flush '{}'", path.display()))?; + + Ok(()) +} + +pub fn run_split(args: &SplitArgs, inputs: 
&[PathBuf]) -> Result<()> { + if inputs.is_empty() { + bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest"); + } + + let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) = + if inputs.first().is_some_and(|p| { + let s = p.to_string_lossy(); + !s.contains('=') + }) { + let first = inputs.first().map(|p| p.as_path()); + let first = if first == Some(Path::new("-")) { + None + } else { + first + }; + (first, &inputs[1..]) + } else { + (None, inputs) + }; + + if split_specs_raw.is_empty() { + bail!("no split specifications provided"); + } + + let specs: Vec = split_specs_raw + .iter() + .map(|p| parse_split_spec(&p.to_string_lossy())) + .collect::>>()?; + + let lines = read_lines_from_input(input_path)?; + let total_lines = lines.len(); + + if total_lines == 0 { + for spec in &specs { + write_lines_to_file(&spec.path, &[])?; + } + return Ok(()); + } + + let (by_repo, without_repo) = group_lines_by_repo(lines); + let has_repos = !by_repo.is_empty(); + + if has_repos { + eprintln!( + "Stratifying by repository_url ({} unique repositories, {} examples)", + by_repo.len(), + total_lines - without_repo.len() + ); + if !without_repo.is_empty() { + eprintln!( + " + {} examples without repository_url (distributed proportionally)", + without_repo.len() + ); + } + } + + let mut rng = match args.seed { + Some(seed) => rand::rngs::StdRng::seed_from_u64(seed), + None => rand::rngs::StdRng::from_os_rng(), + }; + + let mut split_outputs: Vec> = vec![Vec::new(); specs.len()]; + + if has_repos { + let mut repos: Vec = by_repo.keys().cloned().collect(); + repos.shuffle(&mut rng); + + let repo_counts = compute_split_counts(&specs, repos.len())?; + + let mut repo_iter = repos.into_iter(); + for (split_idx, &count) in repo_counts.iter().enumerate() { + for _ in 0..count { + if let Some(repo) = repo_iter.next() { + if let Some(repo_lines) = by_repo.get(&repo) { + split_outputs[split_idx].extend(repo_lines.iter().cloned()); + } + } + } + } + + if !without_repo.is_empty() { + let no_repo_counts = compute_split_counts(&specs, without_repo.len())?; + let mut no_repo_shuffled = without_repo; + no_repo_shuffled.shuffle(&mut rng); + + let mut line_iter = no_repo_shuffled.into_iter(); + for (split_idx, &count) in no_repo_counts.iter().enumerate() { + for _ in 0..count { + if let Some(line) = line_iter.next() { + split_outputs[split_idx].push(line); + } + } + } + } + } else { + let line_counts = compute_split_counts(&specs, total_lines)?; + let mut shuffled_lines = without_repo; + shuffled_lines.shuffle(&mut rng); + + let mut line_iter = shuffled_lines.into_iter(); + for (split_idx, &count) in line_counts.iter().enumerate() { + for _ in 0..count { + if let Some(line) = line_iter.next() { + split_outputs[split_idx].push(line); + } + } + } + } + + for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) { + write_lines_to_file(&spec.path, output_lines)?; + eprintln!("{}: {} examples", spec.path.display(), output_lines.len()); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile { + let mut file = NamedTempFile::new().unwrap(); + for line in lines { + writeln!(file, "{}", line).unwrap(); + } + file.flush().unwrap(); + file + } + + #[test] + fn test_parse_split_spec_percentage() { + let spec = parse_split_spec("train.jsonl=80%").unwrap(); + assert_eq!(spec.path, PathBuf::from("train.jsonl")); + match spec.size { + SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 
+            _ => panic!("expected percentage"),
+        }
+    }
+
+    #[test]
+    fn test_parse_split_spec_absolute() {
+        let spec = parse_split_spec("test.jsonl=100").unwrap();
+        assert_eq!(spec.path, PathBuf::from("test.jsonl"));
+        match spec.size {
+            SplitSize::Absolute(n) => assert_eq!(n, 100),
+            _ => panic!("expected absolute"),
+        }
+    }
+
+    #[test]
+    fn test_parse_split_spec_rest() {
+        let spec = parse_split_spec("valid.jsonl=rest").unwrap();
+        assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
+        assert!(matches!(spec.size, SplitSize::Rest));
+    }
+
+    #[test]
+    fn test_get_repository_url() {
+        let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
+        assert_eq!(
+            get_repository_url(line),
+            Some("https://github.com/example/repo".to_string())
+        );
+
+        let line_no_repo = r#"{"data": 123}"#;
+        assert_eq!(get_repository_url(line_no_repo), None);
+    }
+
+    #[test]
+    fn test_compute_split_counts_percentage() {
+        let specs = vec![
+            SplitSpec {
+                path: PathBuf::from("a"),
+                size: SplitSize::Percentage(0.8),
+            },
+            SplitSpec {
+                path: PathBuf::from("b"),
+                size: SplitSize::Percentage(0.2),
+            },
+        ];
+        let counts = compute_split_counts(&specs, 100).unwrap();
+        assert_eq!(counts, vec![80, 20]);
+    }
+
+    #[test]
+    fn test_compute_split_counts_with_rest() {
+        let specs = vec![
+            SplitSpec {
+                path: PathBuf::from("a"),
+                size: SplitSize::Percentage(0.8),
+            },
+            SplitSpec {
+                path: PathBuf::from("b"),
+                size: SplitSize::Rest,
+            },
+        ];
+        let counts = compute_split_counts(&specs, 100).unwrap();
+        assert_eq!(counts, vec![80, 20]);
+    }
+
+    #[test]
+    fn test_compute_split_counts_absolute() {
+        let specs = vec![
+            SplitSpec {
+                path: PathBuf::from("a"),
+                size: SplitSize::Absolute(50),
+            },
+            SplitSpec {
+                path: PathBuf::from("b"),
+                size: SplitSize::Rest,
+            },
+        ];
+        let counts = compute_split_counts(&specs, 100).unwrap();
+        assert_eq!(counts, vec![50, 50]);
+    }
+
+    #[test]
+    fn test_group_lines_by_repo() {
+        let lines = vec![
+            r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
+            r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
+            r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
+            r#"{"id": 4}"#.to_string(),
+        ];
+
+        let (by_repo, without_repo) = group_lines_by_repo(lines);
+
+        assert_eq!(by_repo.len(), 2);
+        assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
+        assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
+        assert_eq!(without_repo.len(), 1);
+    }
+
+    #[test]
+    fn test_run_split_basic() {
+        let input = create_temp_jsonl(&[
+            r#"{"repository_url": "repo1", "id": 1}"#,
+            r#"{"repository_url": "repo1", "id": 2}"#,
+            r#"{"repository_url": "repo2", "id": 3}"#,
+            r#"{"repository_url": "repo2", "id": 4}"#,
+            r#"{"repository_url": "repo3", "id": 5}"#,
+            r#"{"repository_url": "repo3", "id": 6}"#,
+            r#"{"repository_url": "repo4", "id": 7}"#,
+            r#"{"repository_url": "repo4", "id": 8}"#,
+        ]);
+
+        let temp_dir = tempfile::tempdir().unwrap();
+        let train_path = temp_dir.path().join("train.jsonl");
+        let valid_path = temp_dir.path().join("valid.jsonl");
+
+        let args = SplitArgs { seed: Some(42) };
+        let inputs = vec![
+            input.path().to_path_buf(),
+            PathBuf::from(format!("{}=50%", train_path.display())),
+            PathBuf::from(format!("{}=rest", valid_path.display())),
+        ];
+
+        run_split(&args, &inputs).unwrap();
+
+        let train_content = std::fs::read_to_string(&train_path).unwrap();
+        let valid_content = std::fs::read_to_string(&valid_path).unwrap();
+
+        let train_lines: Vec<&str> = train_content.lines().collect();
+        let valid_lines: Vec<&str> = valid_content.lines().collect();
+
+        assert_eq!(train_lines.len() + valid_lines.len(), 8);
+
+        let train_repos: std::collections::HashSet<_> = train_lines
+            .iter()
+            .filter_map(|l| get_repository_url(l))
+            .collect();
+        let valid_repos: std::collections::HashSet<_> = valid_lines
+            .iter()
+            .filter_map(|l| get_repository_url(l))
+            .collect();
+
+        assert!(
+            train_repos.is_disjoint(&valid_repos),
+            "train and valid should have non-overlapping repos"
+        );
+    }
+
+    #[test]
+    fn test_multiple_rest_fails() {
+        let specs = vec![
+            SplitSpec {
+                path: PathBuf::from("a"),
+                size: SplitSize::Rest,
+            },
+            SplitSpec {
+                path: PathBuf::from("b"),
+                size: SplitSize::Rest,
+            },
+        ];
+        assert!(compute_split_counts(&specs, 100).is_err());
+    }
+}
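A note for reviewers on rounding: `compute_split_counts` rounds each `Percentage` spec with `f64::round` and only a `Rest` spec absorbs the remainder, so when the total doesn't divide evenly the rounded splits can come out slightly larger or smaller than the naive percentage. A minimal sketch of a unit test pinning that behavior down (hypothetical, not part of this diff; it would live inside the `tests` module above, where the private helpers are in scope):

```rust
    // Hypothetical addition to split_dataset.rs's `tests` module: with 3 items,
    // 50% computes 3 * 0.5 = 1.5, which f64::round rounds half-away-from-zero
    // to 2, leaving 1 item for the `rest` split.
    #[test]
    fn test_compute_split_counts_rounding() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.5),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 3).unwrap();
        assert_eq!(counts, vec![2, 1]);
    }
```

Because percentage counts are clamped to the remaining total and `rest` takes whatever is left, the computed counts always sum to exactly the input total.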