Cargo.lock 🔗
@@ -5308,6 +5308,7 @@ dependencies = [
"smol",
"sqlez",
"sqlez_macros",
+ "tempfile",
"terminal_view",
"util",
"wasmtime",
Oleksiy Syvokon created
Adds a new `ep split` command that splits JSONL datasets into multiple
output files with stratification by `repository_url` when present.
Example usage:
ep split input.jsonl train.jsonl=80% valid.jsonl=rest
Release Notes:
- N/A
Cargo.lock | 1
crates/edit_prediction_cli/Cargo.toml | 1
crates/edit_prediction_cli/src/main.rs | 30
crates/edit_prediction_cli/src/split_dataset.rs | 516 +++++++++++++++++++
4 files changed, 537 insertions(+), 11 deletions(-)
@@ -5308,6 +5308,7 @@ dependencies = [
"smol",
"sqlez",
"sqlez_macros",
+ "tempfile",
"terminal_view",
"util",
"wasmtime",
@@ -71,3 +71,4 @@ indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
+tempfile.workspace = true
@@ -14,6 +14,7 @@ mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
+mod split_dataset;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
@@ -40,6 +41,7 @@ use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
+use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
#[derive(Parser, Debug)]
@@ -124,6 +126,8 @@ enum Command {
Clean,
/// Generate an evaluation example by splitting a chronologically-ordered commit
SplitCommit(SplitCommitArgs),
+ /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
+ Split(SplitArgs),
}
impl Display for Command {
@@ -178,6 +182,7 @@ impl Display for Command {
}
Command::Clean => write!(f, "clean"),
Command::SplitCommit(_) => write!(f, "split-commit"),
+ Command::Split(_) => write!(f, "split"),
}
}
}
@@ -416,6 +421,13 @@ fn main() {
}
return;
}
+ Command::Split(split_args) => {
+ if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
+ eprintln!("{error:#}");
+ std::process::exit(1);
+ }
+ return;
+ }
_ => {}
}
@@ -509,7 +521,8 @@ fn main() {
}
Command::Clean
| Command::Synthesize(_)
- | Command::SplitCommit(_) => {
+ | Command::SplitCommit(_)
+ | Command::Split(_) => {
unreachable!()
}
}
@@ -603,17 +616,12 @@ async fn handle_error(
indoc::indoc! {"
While processing \"{}\":
- {:?}
-
- Written to: \x1b[36m{}\x1b[0m
-
- Cursor File: \x1b[36m{}\x1b[0m
-
- Explore this example data with:
- fx \x1b[36m{}\x1b[0m
+ \x1b[31m{:?}\x1b[0m
- Re-run this example with:
- cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
+ Example: \x1b[36m{}\x1b[0m
+ Error file: \x1b[36m{}\x1b[0m
+ Cursor file: \x1b[36m{}\x1b[0m
+ Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.spec.name,
error,
@@ -0,0 +1,516 @@
+//! `ep split` implementation.
+//!
+//! This command splits a JSONL dataset into multiple files based on size specifications,
+//! with stratification by repository URL (if the field is present).
+//!
+//! # Usage
+//!
+//! ```text
+//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
+//! ```
+//!
+//! If `input.jsonl` is not provided or is `-`, reads from stdin.
+//!
+//! # Size specifications
+//!
+//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
+//! - `100` - absolute count of repositories (if stratified) or examples
+//! - `rest` - all remaining items (only one split can use this)
+//!
+//! # Stratification
+//!
+//! When examples have a `repository_url` field, the split is stratified by repository.
+//! This ensures each output file contains examples from non-overlapping repositories.
+//! Size specifications apply to the number of repositories, not individual examples.
+//!
+//! Examples without `repository_url` are distributed proportionally across all outputs.
+
+use anyhow::{Context as _, Result, bail};
+use clap::Args;
+use rand::SeedableRng;
+use rand::seq::SliceRandom;
+use serde_json::Value;
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::{self, BufRead, BufReader, BufWriter, Write};
+use std::path::{Path, PathBuf};
+
+/// `ep split` CLI args.
+#[derive(Debug, Args, Clone)]
+#[command(
+ about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
+ after_help = r#"SIZE SPECIFICATIONS:
+ <percentage>% Percentage of total (e.g., 80%)
+ <count> Absolute number (e.g., 100)
+ rest All remaining items (only one output can use this)
+
+ When stratifying by repository_url, sizes apply to repositories, not examples.
+
+EXAMPLES:
+ # Split 80% train, 20% validation
+ ep split input.jsonl train.jsonl=80% valid.jsonl=rest
+
+ # Split into train/valid/test
+ ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest
+
+ # Use absolute counts (100 repos to train, rest to valid)
+ ep split input.jsonl train.jsonl=100 valid.jsonl=rest
+
+ # Read from stdin
+ cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest
+
+ # Reproducible split with seed
+ ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
+
+STRATIFICATION:
+ When examples have a "repository_url" field, the split ensures each output
+ file contains examples from non-overlapping repositories. This prevents
+ data leakage between train/test splits.
+"#
+)]
+pub struct SplitArgs {
+ /// Random seed for reproducibility
+ #[arg(long)]
+ pub seed: Option<u64>,
+}
+
+#[derive(Debug, Clone)]
+pub enum SplitSize {
+ Percentage(f64),
+ Absolute(usize),
+ Rest,
+}
+
+#[derive(Debug, Clone)]
+pub struct SplitSpec {
+ pub path: PathBuf,
+ pub size: SplitSize,
+}
+
+fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
+ let (path, size_str) = spec
+ .rsplit_once('=')
+ .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
+
+ let size = if size_str == "rest" {
+ SplitSize::Rest
+ } else if size_str.ends_with('%') {
+ let pct_str = size_str.trim_end_matches('%');
+ let pct: f64 = pct_str
+ .parse()
+ .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
+ if !(0.0..=100.0).contains(&pct) {
+ bail!("percentage must be between 0 and 100, got {}", pct);
+ }
+ SplitSize::Percentage(pct / 100.0)
+ } else {
+ let count: usize = size_str
+ .parse()
+ .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
+ SplitSize::Absolute(count)
+ };
+
+ Ok(SplitSpec {
+ path: PathBuf::from(path),
+ size,
+ })
+}
+
+fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
+ let reader: Box<dyn BufRead> = match input {
+ Some(path) => {
+ let file =
+ File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
+ Box::new(BufReader::new(file))
+ }
+ None => Box::new(BufReader::new(io::stdin())),
+ };
+
+ let lines: Vec<String> = reader
+ .lines()
+ .collect::<io::Result<Vec<_>>>()
+ .context("failed to read input lines")?;
+
+ Ok(lines)
+}
+
+fn get_repository_url(line: &str) -> Option<String> {
+ let value: Value = serde_json::from_str(line).ok()?;
+ value
+ .get("repository_url")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string())
+}
+
+fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
+ let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
+ let mut without_repo: Vec<String> = Vec::new();
+
+ for line in lines {
+ if let Some(repo_url) = get_repository_url(&line) {
+ by_repo.entry(repo_url).or_default().push(line);
+ } else {
+ without_repo.push(line);
+ }
+ }
+
+ (by_repo, without_repo)
+}
+
+fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
+ let mut counts = vec![0usize; specs.len()];
+ let mut remaining = total;
+ let mut rest_index: Option<usize> = None;
+
+ for (i, spec) in specs.iter().enumerate() {
+ match &spec.size {
+ SplitSize::Percentage(pct) => {
+ let count = (total as f64 * pct).round() as usize;
+ counts[i] = count.min(remaining);
+ remaining = remaining.saturating_sub(counts[i]);
+ }
+ SplitSize::Absolute(count) => {
+ counts[i] = (*count).min(remaining);
+ remaining = remaining.saturating_sub(counts[i]);
+ }
+ SplitSize::Rest => {
+ if rest_index.is_some() {
+ bail!("only one split can use 'rest'");
+ }
+ rest_index = Some(i);
+ }
+ }
+ }
+
+ if let Some(idx) = rest_index {
+ counts[idx] = remaining;
+ }
+
+ Ok(counts)
+}
+
+fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
+ if let Some(parent) = path.parent() {
+ if !parent.as_os_str().is_empty() {
+ std::fs::create_dir_all(parent)
+ .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
+ }
+ }
+
+ let file =
+ File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
+ let mut writer = BufWriter::new(file);
+
+ for line in lines {
+ writeln!(writer, "{}", line)
+ .with_context(|| format!("failed to write to '{}'", path.display()))?;
+ }
+
+ writer
+ .flush()
+ .with_context(|| format!("failed to flush '{}'", path.display()))?;
+
+ Ok(())
+}
+
+pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
+ if inputs.is_empty() {
+ bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
+ }
+
+ let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
+ if inputs.first().is_some_and(|p| {
+ let s = p.to_string_lossy();
+ !s.contains('=')
+ }) {
+ let first = inputs.first().map(|p| p.as_path());
+ let first = if first == Some(Path::new("-")) {
+ None
+ } else {
+ first
+ };
+ (first, &inputs[1..])
+ } else {
+ (None, inputs)
+ };
+
+ if split_specs_raw.is_empty() {
+ bail!("no split specifications provided");
+ }
+
+ let specs: Vec<SplitSpec> = split_specs_raw
+ .iter()
+ .map(|p| parse_split_spec(&p.to_string_lossy()))
+ .collect::<Result<Vec<_>>>()?;
+
+ let lines = read_lines_from_input(input_path)?;
+ let total_lines = lines.len();
+
+ if total_lines == 0 {
+ for spec in &specs {
+ write_lines_to_file(&spec.path, &[])?;
+ }
+ return Ok(());
+ }
+
+ let (by_repo, without_repo) = group_lines_by_repo(lines);
+ let has_repos = !by_repo.is_empty();
+
+ if has_repos {
+ eprintln!(
+ "Stratifying by repository_url ({} unique repositories, {} examples)",
+ by_repo.len(),
+ total_lines - without_repo.len()
+ );
+ if !without_repo.is_empty() {
+ eprintln!(
+ " + {} examples without repository_url (distributed proportionally)",
+ without_repo.len()
+ );
+ }
+ }
+
+ let mut rng = match args.seed {
+ Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
+ None => rand::rngs::StdRng::from_os_rng(),
+ };
+
+ let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
+
+ if has_repos {
+ let mut repos: Vec<String> = by_repo.keys().cloned().collect();
+ repos.shuffle(&mut rng);
+
+ let repo_counts = compute_split_counts(&specs, repos.len())?;
+
+ let mut repo_iter = repos.into_iter();
+ for (split_idx, &count) in repo_counts.iter().enumerate() {
+ for _ in 0..count {
+ if let Some(repo) = repo_iter.next() {
+ if let Some(repo_lines) = by_repo.get(&repo) {
+ split_outputs[split_idx].extend(repo_lines.iter().cloned());
+ }
+ }
+ }
+ }
+
+ if !without_repo.is_empty() {
+ let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
+ let mut no_repo_shuffled = without_repo;
+ no_repo_shuffled.shuffle(&mut rng);
+
+ let mut line_iter = no_repo_shuffled.into_iter();
+ for (split_idx, &count) in no_repo_counts.iter().enumerate() {
+ for _ in 0..count {
+ if let Some(line) = line_iter.next() {
+ split_outputs[split_idx].push(line);
+ }
+ }
+ }
+ }
+ } else {
+ let line_counts = compute_split_counts(&specs, total_lines)?;
+ let mut shuffled_lines = without_repo;
+ shuffled_lines.shuffle(&mut rng);
+
+ let mut line_iter = shuffled_lines.into_iter();
+ for (split_idx, &count) in line_counts.iter().enumerate() {
+ for _ in 0..count {
+ if let Some(line) = line_iter.next() {
+ split_outputs[split_idx].push(line);
+ }
+ }
+ }
+ }
+
+ for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
+ write_lines_to_file(&spec.path, output_lines)?;
+ eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::io::Write;
+ use tempfile::NamedTempFile;
+
+ fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
+ let mut file = NamedTempFile::new().unwrap();
+ for line in lines {
+ writeln!(file, "{}", line).unwrap();
+ }
+ file.flush().unwrap();
+ file
+ }
+
+ #[test]
+ fn test_parse_split_spec_percentage() {
+ let spec = parse_split_spec("train.jsonl=80%").unwrap();
+ assert_eq!(spec.path, PathBuf::from("train.jsonl"));
+ match spec.size {
+ SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
+ _ => panic!("expected percentage"),
+ }
+ }
+
+ #[test]
+ fn test_parse_split_spec_absolute() {
+ let spec = parse_split_spec("test.jsonl=100").unwrap();
+ assert_eq!(spec.path, PathBuf::from("test.jsonl"));
+ match spec.size {
+ SplitSize::Absolute(n) => assert_eq!(n, 100),
+ _ => panic!("expected absolute"),
+ }
+ }
+
+ #[test]
+ fn test_parse_split_spec_rest() {
+ let spec = parse_split_spec("valid.jsonl=rest").unwrap();
+ assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
+ assert!(matches!(spec.size, SplitSize::Rest));
+ }
+
+ #[test]
+ fn test_get_repository_url() {
+ let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
+ assert_eq!(
+ get_repository_url(line),
+ Some("https://github.com/example/repo".to_string())
+ );
+
+ let line_no_repo = r#"{"data": 123}"#;
+ assert_eq!(get_repository_url(line_no_repo), None);
+ }
+
+ #[test]
+ fn test_compute_split_counts_percentage() {
+ let specs = vec![
+ SplitSpec {
+ path: PathBuf::from("a"),
+ size: SplitSize::Percentage(0.8),
+ },
+ SplitSpec {
+ path: PathBuf::from("b"),
+ size: SplitSize::Percentage(0.2),
+ },
+ ];
+ let counts = compute_split_counts(&specs, 100).unwrap();
+ assert_eq!(counts, vec![80, 20]);
+ }
+
+ #[test]
+ fn test_compute_split_counts_with_rest() {
+ let specs = vec![
+ SplitSpec {
+ path: PathBuf::from("a"),
+ size: SplitSize::Percentage(0.8),
+ },
+ SplitSpec {
+ path: PathBuf::from("b"),
+ size: SplitSize::Rest,
+ },
+ ];
+ let counts = compute_split_counts(&specs, 100).unwrap();
+ assert_eq!(counts, vec![80, 20]);
+ }
+
+ #[test]
+ fn test_compute_split_counts_absolute() {
+ let specs = vec![
+ SplitSpec {
+ path: PathBuf::from("a"),
+ size: SplitSize::Absolute(50),
+ },
+ SplitSpec {
+ path: PathBuf::from("b"),
+ size: SplitSize::Rest,
+ },
+ ];
+ let counts = compute_split_counts(&specs, 100).unwrap();
+ assert_eq!(counts, vec![50, 50]);
+ }
+
+ #[test]
+ fn test_group_lines_by_repo() {
+ let lines = vec![
+ r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
+ r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
+ r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
+ r#"{"id": 4}"#.to_string(),
+ ];
+
+ let (by_repo, without_repo) = group_lines_by_repo(lines);
+
+ assert_eq!(by_repo.len(), 2);
+ assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
+ assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
+ assert_eq!(without_repo.len(), 1);
+ }
+
+ #[test]
+ fn test_run_split_basic() {
+ let input = create_temp_jsonl(&[
+ r#"{"repository_url": "repo1", "id": 1}"#,
+ r#"{"repository_url": "repo1", "id": 2}"#,
+ r#"{"repository_url": "repo2", "id": 3}"#,
+ r#"{"repository_url": "repo2", "id": 4}"#,
+ r#"{"repository_url": "repo3", "id": 5}"#,
+ r#"{"repository_url": "repo3", "id": 6}"#,
+ r#"{"repository_url": "repo4", "id": 7}"#,
+ r#"{"repository_url": "repo4", "id": 8}"#,
+ ]);
+
+ let temp_dir = tempfile::tempdir().unwrap();
+ let train_path = temp_dir.path().join("train.jsonl");
+ let valid_path = temp_dir.path().join("valid.jsonl");
+
+ let args = SplitArgs { seed: Some(42) };
+ let inputs = vec![
+ input.path().to_path_buf(),
+ PathBuf::from(format!("{}=50%", train_path.display())),
+ PathBuf::from(format!("{}=rest", valid_path.display())),
+ ];
+
+ run_split(&args, &inputs).unwrap();
+
+ let train_content = std::fs::read_to_string(&train_path).unwrap();
+ let valid_content = std::fs::read_to_string(&valid_path).unwrap();
+
+ let train_lines: Vec<&str> = train_content.lines().collect();
+ let valid_lines: Vec<&str> = valid_content.lines().collect();
+
+ assert_eq!(train_lines.len() + valid_lines.len(), 8);
+
+ let train_repos: std::collections::HashSet<_> = train_lines
+ .iter()
+ .filter_map(|l| get_repository_url(l))
+ .collect();
+ let valid_repos: std::collections::HashSet<_> = valid_lines
+ .iter()
+ .filter_map(|l| get_repository_url(l))
+ .collect();
+
+ assert!(
+ train_repos.is_disjoint(&valid_repos),
+ "train and valid should have non-overlapping repos"
+ );
+ }
+
+ #[test]
+ fn test_multiple_rest_fails() {
+ let specs = vec![
+ SplitSpec {
+ path: PathBuf::from("a"),
+ size: SplitSize::Rest,
+ },
+ SplitSpec {
+ path: PathBuf::from("b"),
+ size: SplitSize::Rest,
+ },
+ ];
+ assert!(compute_split_counts(&specs, 100).is_err());
+ }
+}