1//! `ep split` implementation.
2//!
3//! This command splits a JSONL dataset into multiple files based on size specifications,
4//! with optional stratification by a JSON field.
5//!
6//! # Usage
7//!
8//! ```text
9//! ep split [--stratify=<field>] [input.jsonl] <out1>=<size1> <out2>=<size2> ...
10//! ```
11//!
12//! If `input.jsonl` is not provided or is `-`, reads from stdin.
13//!
14//! # Size specifications
15//!
16//! - `80%` - percentage of total examples (lines)
17//! - `100` - approximate absolute count of examples (lines)
18//! - `rest` - all remaining items (only one split can use this)
19//!
20//! # Stratification
21//!
22//! The `--stratify` flag controls how examples are grouped before splitting:
23//!
24//! - `cursor-path` (default): group by the `cursor_path` JSON field
25//! - `repo`: group by the `repository_url` JSON field
26//! - `none`: no grouping, split individual examples
27//!
28//! When stratifying, the split ensures each output file contains examples from
29//! non-overlapping groups. Size specifications always apply to the number of
30//! examples (lines), with whole groups assigned greedily to meet the target.
31//! Examples missing the stratification field are treated as individual groups.
32
33use anyhow::{Context as _, Result, bail};
34use clap::Args;
35use rand::SeedableRng;
36use rand::seq::SliceRandom;
37use serde_json::Value;
38use std::collections::HashMap;
39use std::fs::File;
40use std::io::{self, BufRead, BufReader, BufWriter, Write};
41use std::path::{Path, PathBuf};
42
/// `ep split` CLI args.
//
// Positional arguments (the optional input path and the `<path>=<size>`
// split specs) are deliberately NOT declared here; they are collected
// elsewhere and handed to `run_split` as a raw `&[PathBuf]` list.
// NOTE: comments in this struct are plain `//` on purpose — clap-derive
// turns `///` doc comments into help text, which would change CLI output.
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files with optional stratification",
    after_help = r#"SIZE SPECIFICATIONS:
    <percentage>%    Percentage of total (e.g., 80%)
    <count>          Absolute number (e.g., 100)
    rest             All remaining items (only one output can use this)

    Sizes always apply to examples (lines). When stratifying, whole groups
    are assigned greedily to approximate the target count.

EXAMPLES:
    # Split 80% train, 20% validation (default: stratify by cursor_path)
    ep split input.jsonl train.jsonl=80% valid.jsonl=rest

    # Split into train/valid/test
    ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

    # Stratify by repository_url instead of cursor_path
    ep split --stratify=repo input.jsonl train.jsonl=80% valid.jsonl=rest

    # No stratification (split by individual examples)
    ep split --stratify=none input.jsonl train.jsonl=80% valid.jsonl=rest

    # Read from stdin
    cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

    # Reproducible split with seed
    ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
    Controls how examples are grouped before splitting:
    cursor-path   Group by "cursor_path" field (default)
    repo          Group by "repository_url" field
    none          No grouping, split individual examples

    When stratifying, the split ensures each output file contains examples
    from non-overlapping groups. This prevents data leakage between
    train/test splits.
"#
)]
pub struct SplitArgs {
    /// Random seed for reproducibility
    // NOTE(review): reproducibility also requires a deterministic group
    // order out of `group_lines` before the seeded shuffle — verify when
    // touching that function.
    #[arg(long)]
    pub seed: Option<u64>,

    /// Stratification field for splitting the dataset
    #[arg(long, default_value = "cursor-path")]
    pub stratify: Stratify,
}
94
// Stratification strategy for grouping examples before splitting.
//
// Two distinct string mappings exist for this enum:
//   - clap `ValueEnum` derives the kebab-case CLI names ("cursor-path",
//     "repo", "none") used on the command line;
//   - `strum::Display` (the `serialize` attrs below) controls what the
//     stderr progress messages print, e.g. "Stratifying by cursor_path".
// Plain `//` comments are used here so clap help output is unaffected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, strum::Display)]
pub enum Stratify {
    // Group by the "cursor_path" JSON field (the CLI default).
    #[strum(serialize = "cursor_path")]
    CursorPath,
    // Group by the "repository_url" JSON field.
    #[strum(serialize = "repo")]
    Repo,
    // No grouping: every line is split individually.
    #[strum(serialize = "none")]
    None,
}
104
/// Requested size of one output split, as parsed from `<path>=<size>`.
#[derive(Debug, Clone)]
pub enum SplitSize {
    /// Fraction of the total line count in `0.0..=1.0`
    /// (parsed from `<pct>%`, e.g. `80%` becomes `0.8`).
    Percentage(f64),
    /// Approximate absolute number of lines (parsed from a bare integer).
    Absolute(usize),
    /// Everything left over after all other splits; at most one spec per
    /// invocation may use this.
    Rest,
}
111
/// One parsed `<path>=<size>` output specification.
#[derive(Debug, Clone)]
pub struct SplitSpec {
    /// Destination file; missing parent directories are created on write.
    pub path: PathBuf,
    /// How many lines this split should receive.
    pub size: SplitSize,
}
117
118fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
119 let (path, size_str) = spec
120 .rsplit_once('=')
121 .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
122
123 let size = if size_str == "rest" {
124 SplitSize::Rest
125 } else if size_str.ends_with('%') {
126 let pct_str = size_str.trim_end_matches('%');
127 let pct: f64 = pct_str
128 .parse()
129 .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
130 if !(0.0..=100.0).contains(&pct) {
131 bail!("percentage must be between 0 and 100, got {}", pct);
132 }
133 SplitSize::Percentage(pct / 100.0)
134 } else {
135 let count: usize = size_str
136 .parse()
137 .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
138 SplitSize::Absolute(count)
139 };
140
141 Ok(SplitSpec {
142 path: PathBuf::from(path),
143 size,
144 })
145}
146
147fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
148 let reader: Box<dyn BufRead> = match input {
149 Some(path) => {
150 let file =
151 File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
152 Box::new(BufReader::new(file))
153 }
154 None => Box::new(BufReader::new(io::stdin())),
155 };
156
157 let lines: Vec<String> = reader
158 .lines()
159 .collect::<io::Result<Vec<_>>>()
160 .context("failed to read input lines")?;
161
162 Ok(lines)
163}
164
165fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
166 let mut counts = vec![0usize; specs.len()];
167 let mut remaining = total;
168 let mut rest_index: Option<usize> = None;
169
170 for (i, spec) in specs.iter().enumerate() {
171 match &spec.size {
172 SplitSize::Percentage(pct) => {
173 let count = (total as f64 * pct).round() as usize;
174 counts[i] = count.min(remaining);
175 remaining = remaining.saturating_sub(counts[i]);
176 }
177 SplitSize::Absolute(count) => {
178 counts[i] = (*count).min(remaining);
179 remaining = remaining.saturating_sub(counts[i]);
180 }
181 SplitSize::Rest => {
182 if rest_index.is_some() {
183 bail!("only one split can use 'rest'");
184 }
185 rest_index = Some(i);
186 }
187 }
188 }
189
190 if let Some(idx) = rest_index {
191 counts[idx] = remaining;
192 }
193
194 Ok(counts)
195}
196
197fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
198 if let Some(parent) = path.parent() {
199 if !parent.as_os_str().is_empty() {
200 std::fs::create_dir_all(parent)
201 .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
202 }
203 }
204
205 let file =
206 File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
207 let mut writer = BufWriter::new(file);
208
209 for line in lines {
210 writeln!(writer, "{}", line)
211 .with_context(|| format!("failed to write to '{}'", path.display()))?;
212 }
213
214 writer
215 .flush()
216 .with_context(|| format!("failed to flush '{}'", path.display()))?;
217
218 Ok(())
219}
220
/// Entry point for `ep split`.
///
/// `inputs` holds the raw positional arguments: an optional input path
/// followed by one or more `<path>=<size>` split specs. The whole dataset
/// is read into memory, grouped per `args.stratify`, the groups are
/// shuffled, and then whole groups are assigned greedily to each output
/// until its line target is met; the `rest` split receives the remainder.
///
/// # Errors
///
/// Fails on empty/invalid arguments, unreadable input, bad split specs,
/// multiple `rest` specs, or any output-file write error.
pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
    if inputs.is_empty() {
        bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
    }

    // The first positional is treated as the input file only when it does
    // not look like a split spec (i.e. contains no '='); a bare "-" selects
    // stdin. NOTE(review): an input path that itself contains '=' would be
    // misparsed as a spec — confirm such paths are out of scope.
    let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
        if inputs.first().is_some_and(|p| {
            let s = p.to_string_lossy();
            !s.contains('=')
        }) {
            let first = inputs.first().map(|p| p.as_path());
            let first = if first == Some(Path::new("-")) {
                None
            } else {
                first
            };
            (first, &inputs[1..])
        } else {
            (None, inputs)
        };

    if split_specs_raw.is_empty() {
        bail!("no split specifications provided");
    }

    // Parse all specs up front so argument errors surface before any I/O.
    let specs: Vec<SplitSpec> = split_specs_raw
        .iter()
        .map(|p| parse_split_spec(&p.to_string_lossy()))
        .collect::<Result<Vec<_>>>()?;

    let lines = read_lines_from_input(input_path)?;
    let total_lines = lines.len();

    // Empty input: still create every output file (empty) so downstream
    // tooling can rely on the files existing.
    if total_lines == 0 {
        for spec in &specs {
            write_lines_to_file(&spec.path, &[])?;
        }
        return Ok(());
    }

    let mut grouped_lines = group_lines(&lines, args.stratify);

    if args.stratify != Stratify::None {
        eprintln!(
            "Stratifying by {} ({} unique groups, {} examples)",
            args.stratify,
            grouped_lines.len(),
            total_lines
        );
    } else {
        eprintln!(
            "No stratification, splitting {} examples by line",
            total_lines
        );
    }

    // Seeded RNG gives reproducible shuffles; otherwise use OS entropy.
    let mut rng = match args.seed {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    grouped_lines.shuffle(&mut rng);

    // Targets are expressed in lines; groups are assigned whole, so a
    // split may overshoot its target by at most one group.
    let line_targets = compute_split_counts(&specs, total_lines)?;
    let rest_index = specs.iter().position(|s| matches!(s.size, SplitSize::Rest));
    let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
    let mut group_iter = grouped_lines.into_iter();

    for (split_idx, &target) in line_targets.iter().enumerate() {
        // The 'rest' split is filled last from whatever groups remain.
        if Some(split_idx) == rest_index {
            continue;
        }
        let mut accumulated = 0;
        while accumulated < target {
            if let Some(group) = group_iter.next() {
                accumulated += group.len();
                split_outputs[split_idx].extend(group);
            } else {
                // Ran out of groups before meeting the target; later splits
                // (and 'rest') will simply receive nothing.
                break;
            }
        }
    }

    if let Some(idx) = rest_index {
        for group in group_iter {
            split_outputs[idx].extend(group);
        }
    }

    for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
        write_lines_to_file(&spec.path, output_lines)?;
        eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
    }

    Ok(())
}
317
318/// Groups lines by the specified stratification field.
319///
320/// When `stratify` is `None`, each line becomes its own group.
321/// When a line is missing the stratification field, it is also placed in its own group.
322fn group_lines(lines: &[String], stratify: Stratify) -> Vec<Vec<String>> {
323 if stratify == Stratify::None {
324 return lines.iter().map(|line| vec![line.clone()]).collect();
325 }
326
327 let field = match stratify {
328 Stratify::Repo => "repository_url",
329 Stratify::CursorPath => "cursor_path",
330 Stratify::None => unreachable!(),
331 };
332
333 let mut groups: HashMap<String, Vec<String>> = HashMap::new();
334 let mut ungrouped: Vec<Vec<String>> = Vec::new();
335
336 for line in lines {
337 let key = serde_json::from_str::<Value>(line)
338 .ok()
339 .and_then(|v| v.get(field)?.as_str().map(|s| s.to_string()));
340 match key {
341 Some(key) => groups.entry(key).or_default().push(line.clone()),
342 None => ungrouped.push(vec![line.clone()]),
343 }
344 }
345
346 let mut result: Vec<Vec<String>> = groups.into_values().collect();
347 result.extend(ungrouped);
348 result
349}
350
#[cfg(test)]
mod tests {
    //! Unit tests covering spec parsing, grouping, count computation, and
    //! two end-to-end `run_split` runs against temporary files.

    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `lines` to a fresh temporary JSONL file and returns its handle.
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_parse_split_spec_percentage() {
        let spec = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(spec.path, PathBuf::from("train.jsonl"));
        match spec.size {
            SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
            _ => panic!("expected percentage"),
        }
    }

    #[test]
    fn test_parse_split_spec_absolute() {
        let spec = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(spec.path, PathBuf::from("test.jsonl"));
        match spec.size {
            SplitSize::Absolute(n) => assert_eq!(n, 100),
            _ => panic!("expected absolute"),
        }
    }

    #[test]
    fn test_parse_split_spec_rest() {
        let spec = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(spec.size, SplitSize::Rest));
    }

    #[test]
    fn test_group_lines_none() {
        // Lines need not be valid JSON when stratification is disabled.
        let lines = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let groups = group_lines(&lines, Stratify::None);
        assert_eq!(groups.len(), 3);
        assert!(groups.iter().all(|g| g.len() == 1));
    }

    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Percentage(0.2),
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Percentage(0.8),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Absolute(50),
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        let counts = compute_split_counts(&specs, 100).unwrap();
        assert_eq!(counts, vec![50, 50]);
    }

    #[test]
    fn test_group_lines_by_repo() {
        let lines = vec![
            r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
            r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
            r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
            r#"{"id": 4}"#.to_string(),
        ];

        let groups = group_lines(&lines, Stratify::Repo);

        // Only counts are asserted — group ordering is not part of the
        // contract (the caller shuffles groups anyway).
        let grouped_count: usize = groups.iter().filter(|g| g.len() > 1).count();
        let ungrouped_count: usize = groups.iter().filter(|g| g.len() == 1).count();
        let total_lines: usize = groups.iter().map(|g| g.len()).sum();

        assert_eq!(grouped_count, 1); // repo1 has 2 lines
        assert_eq!(ungrouped_count, 2); // repo2 (1 line) + line without repo
        assert_eq!(total_lines, 4);
    }

    #[test]
    fn test_group_lines_by_cursor_path() {
        let lines = vec![
            r#"{"cursor_path": "src/main.rs", "id": 1}"#.to_string(),
            r#"{"cursor_path": "src/main.rs", "id": 2}"#.to_string(),
            r#"{"cursor_path": "src/lib.rs", "id": 3}"#.to_string(),
        ];

        let groups = group_lines(&lines, Stratify::CursorPath);

        let total_lines: usize = groups.iter().map(|g| g.len()).sum();
        assert_eq!(groups.len(), 2);
        assert_eq!(total_lines, 3);
    }

    #[test]
    fn test_run_split_basic() {
        // Four repos × two lines each; a stratified 50%/rest split must
        // assign whole repos, keeping train/valid repo sets disjoint.
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"repository_url": "repo2", "id": 4}"#,
            r#"{"repository_url": "repo3", "id": 5}"#,
            r#"{"repository_url": "repo3", "id": 6}"#,
            r#"{"repository_url": "repo4", "id": 7}"#,
            r#"{"repository_url": "repo4", "id": 8}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            stratify: Stratify::Repo,
        };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        // No line may be lost or duplicated across the two outputs.
        assert_eq!(train_lines.len() + valid_lines.len(), 8);

        let get_repo = |line: &str| -> Option<String> {
            let value: Value = serde_json::from_str(line).ok()?;
            value
                .get("repository_url")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
        };

        let train_repos: std::collections::HashSet<_> =
            train_lines.iter().filter_map(|l| get_repo(l)).collect();
        let valid_repos: std::collections::HashSet<_> =
            valid_lines.iter().filter_map(|l| get_repo(l)).collect();

        assert!(
            train_repos.is_disjoint(&valid_repos),
            "train and valid should have non-overlapping repos"
        );
    }

    #[test]
    fn test_multiple_rest_fails() {
        let specs = vec![
            SplitSpec {
                path: PathBuf::from("a"),
                size: SplitSize::Rest,
            },
            SplitSpec {
                path: PathBuf::from("b"),
                size: SplitSize::Rest,
            },
        ];
        assert!(compute_split_counts(&specs, 100).is_err());
    }

    #[test]
    fn test_absolute_targets_lines_not_groups() {
        // 5 repos × 3 lines each = 15 total lines.
        // `train=6` should target ~6 lines (2 groups), NOT 6 groups (all 15 lines).
        let input = create_temp_jsonl(&[
            r#"{"repository_url": "r1", "id": 1}"#,
            r#"{"repository_url": "r1", "id": 2}"#,
            r#"{"repository_url": "r1", "id": 3}"#,
            r#"{"repository_url": "r2", "id": 4}"#,
            r#"{"repository_url": "r2", "id": 5}"#,
            r#"{"repository_url": "r2", "id": 6}"#,
            r#"{"repository_url": "r3", "id": 7}"#,
            r#"{"repository_url": "r3", "id": 8}"#,
            r#"{"repository_url": "r3", "id": 9}"#,
            r#"{"repository_url": "r4", "id": 10}"#,
            r#"{"repository_url": "r4", "id": 11}"#,
            r#"{"repository_url": "r4", "id": 12}"#,
            r#"{"repository_url": "r5", "id": 13}"#,
            r#"{"repository_url": "r5", "id": 14}"#,
            r#"{"repository_url": "r5", "id": 15}"#,
        ]);

        let temp_dir = tempfile::tempdir().unwrap();
        let train_path = temp_dir.path().join("train.jsonl");
        let valid_path = temp_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            stratify: Stratify::Repo,
        };
        let inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=6", train_path.display())),
            PathBuf::from(format!("{}=rest", valid_path.display())),
        ];

        run_split(&args, &inputs).unwrap();

        let train_content = std::fs::read_to_string(&train_path).unwrap();
        let valid_content = std::fs::read_to_string(&valid_path).unwrap();

        let train_lines: Vec<&str> = train_content.lines().collect();
        let valid_lines: Vec<&str> = valid_content.lines().collect();

        // With 3-line groups, train should get 2 groups (6 lines) to meet the
        // target of 6, NOT 6 groups (which don't even exist). Valid gets the rest.
        assert_eq!(train_lines.len(), 6);
        assert_eq!(valid_lines.len(), 9);
    }
}