1//! `ep split` implementation.
2//!
3//! This command splits a JSONL dataset into multiple files based on size specifications,
4//! with stratification by repository URL (if the field is present).
5//!
6//! # Usage
7//!
8//! ```text
9//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
10//! ```
11//!
12//! If `input.jsonl` is not provided or is `-`, reads from stdin.
13//!
14//! # Size specifications
15//!
16//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
17//! - `100` - absolute count of repositories (if stratified) or examples
18//! - `rest` - all remaining items (only one split can use this)
19//!
20//! # Stratification
21//!
22//! When examples have a `repository_url` field, the split is stratified by repository.
23//! This ensures each output file contains examples from non-overlapping repositories.
24//! Size specifications apply to the number of repositories, not individual examples.
25//!
26//! Examples without `repository_url` are distributed proportionally across all outputs.
27
28use anyhow::{Context as _, Result, bail};
29use clap::Args;
30use rand::SeedableRng;
31use rand::seq::SliceRandom;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::fs::File;
35use std::io::{self, BufRead, BufReader, BufWriter, Write};
36use std::path::{Path, PathBuf};
37
38/// `ep split` CLI args.
/// `ep split` CLI args.
///
/// Only the flags live here; the positional arguments (optional input path
/// plus `<path>=<size>` specs) are passed separately to `run_split` as a
/// `&[PathBuf]` — presumably collected by the parent command (TODO confirm
/// against the caller).
#[derive(Debug, Args, Clone)]
#[command(
    about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
    after_help = r#"SIZE SPECIFICATIONS:
    <percentage>%    Percentage of total (e.g., 80%)
    <count>          Absolute number (e.g., 100)
    rest             All remaining items (only one output can use this)

    When stratifying by repository_url, sizes apply to repositories, not examples.

EXAMPLES:
    # Split 80% train, 20% validation
    ep split input.jsonl train.jsonl=80% valid.jsonl=rest

    # Split into train/valid/test
    ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest

    # Use absolute counts (100 repos to train, rest to valid)
    ep split input.jsonl train.jsonl=100 valid.jsonl=rest

    # Read from stdin
    cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest

    # Reproducible split with seed
    ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest

    # Disable stratification (split by examples, not repositories)
    ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest

STRATIFICATION:
    When examples have a "repository_url" field, the split ensures each output
    file contains examples from non-overlapping repositories. This prevents
    data leakage between train/test splits. Use --no-stratify to disable this
    behavior and split by individual examples instead.
"#
)]
pub struct SplitArgs {
    /// Random seed for reproducibility
    #[arg(long)]
    pub seed: Option<u64>,

    /// Disable stratification by repository_url (split by examples instead)
    #[arg(long)]
    pub no_stratify: bool,
}
84
/// How much of the dataset a single output file should receive.
#[derive(Debug, Clone)]
pub enum SplitSize {
    /// Fraction of the total, stored as `0.0..=1.0` (parsed from e.g. `80%`).
    Percentage(f64),
    /// Absolute number of items (repositories when stratified, lines otherwise).
    Absolute(usize),
    /// Everything left over after all other splits are satisfied.
    Rest,
}
91
/// One parsed `<path>=<size>` output specification.
#[derive(Debug, Clone)]
pub struct SplitSpec {
    /// Output file path (parent directories are created on write).
    pub path: PathBuf,
    /// How many items this output should receive.
    pub size: SplitSize,
}
97
98fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
99 let (path, size_str) = spec
100 .rsplit_once('=')
101 .with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
102
103 let size = if size_str == "rest" {
104 SplitSize::Rest
105 } else if size_str.ends_with('%') {
106 let pct_str = size_str.trim_end_matches('%');
107 let pct: f64 = pct_str
108 .parse()
109 .with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
110 if !(0.0..=100.0).contains(&pct) {
111 bail!("percentage must be between 0 and 100, got {}", pct);
112 }
113 SplitSize::Percentage(pct / 100.0)
114 } else {
115 let count: usize = size_str
116 .parse()
117 .with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
118 SplitSize::Absolute(count)
119 };
120
121 Ok(SplitSpec {
122 path: PathBuf::from(path),
123 size,
124 })
125}
126
127fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
128 let reader: Box<dyn BufRead> = match input {
129 Some(path) => {
130 let file =
131 File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
132 Box::new(BufReader::new(file))
133 }
134 None => Box::new(BufReader::new(io::stdin())),
135 };
136
137 let lines: Vec<String> = reader
138 .lines()
139 .collect::<io::Result<Vec<_>>>()
140 .context("failed to read input lines")?;
141
142 Ok(lines)
143}
144
145fn get_repository_url(line: &str) -> Option<String> {
146 let value: Value = serde_json::from_str(line).ok()?;
147 value
148 .get("repository_url")
149 .and_then(|v| v.as_str())
150 .map(|s| s.to_string())
151}
152
153fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
154 let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
155 let mut without_repo: Vec<String> = Vec::new();
156
157 for line in lines {
158 if let Some(repo_url) = get_repository_url(&line) {
159 by_repo.entry(repo_url).or_default().push(line);
160 } else {
161 without_repo.push(line);
162 }
163 }
164
165 (by_repo, without_repo)
166}
167
168fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
169 let mut counts = vec![0usize; specs.len()];
170 let mut remaining = total;
171 let mut rest_index: Option<usize> = None;
172
173 for (i, spec) in specs.iter().enumerate() {
174 match &spec.size {
175 SplitSize::Percentage(pct) => {
176 let count = (total as f64 * pct).round() as usize;
177 counts[i] = count.min(remaining);
178 remaining = remaining.saturating_sub(counts[i]);
179 }
180 SplitSize::Absolute(count) => {
181 counts[i] = (*count).min(remaining);
182 remaining = remaining.saturating_sub(counts[i]);
183 }
184 SplitSize::Rest => {
185 if rest_index.is_some() {
186 bail!("only one split can use 'rest'");
187 }
188 rest_index = Some(i);
189 }
190 }
191 }
192
193 if let Some(idx) = rest_index {
194 counts[idx] = remaining;
195 }
196
197 Ok(counts)
198}
199
200fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
201 if let Some(parent) = path.parent() {
202 if !parent.as_os_str().is_empty() {
203 std::fs::create_dir_all(parent)
204 .with_context(|| format!("failed to create directory '{}'", parent.display()))?;
205 }
206 }
207
208 let file =
209 File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
210 let mut writer = BufWriter::new(file);
211
212 for line in lines {
213 writeln!(writer, "{}", line)
214 .with_context(|| format!("failed to write to '{}'", path.display()))?;
215 }
216
217 writer
218 .flush()
219 .with_context(|| format!("failed to flush '{}'", path.display()))?;
220
221 Ok(())
222}
223
224pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
225 if inputs.is_empty() {
226 bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
227 }
228
229 let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
230 if inputs.first().is_some_and(|p| {
231 let s = p.to_string_lossy();
232 !s.contains('=')
233 }) {
234 let first = inputs.first().map(|p| p.as_path());
235 let first = if first == Some(Path::new("-")) {
236 None
237 } else {
238 first
239 };
240 (first, &inputs[1..])
241 } else {
242 (None, inputs)
243 };
244
245 if split_specs_raw.is_empty() {
246 bail!("no split specifications provided");
247 }
248
249 let specs: Vec<SplitSpec> = split_specs_raw
250 .iter()
251 .map(|p| parse_split_spec(&p.to_string_lossy()))
252 .collect::<Result<Vec<_>>>()?;
253
254 let lines = read_lines_from_input(input_path)?;
255 let total_lines = lines.len();
256
257 if total_lines == 0 {
258 for spec in &specs {
259 write_lines_to_file(&spec.path, &[])?;
260 }
261 return Ok(());
262 }
263
264 let (by_repo, without_repo) = group_lines_by_repo(lines);
265 let has_repos = !by_repo.is_empty() && !args.no_stratify;
266
267 if args.no_stratify && !by_repo.is_empty() {
268 eprintln!(
269 "Stratification disabled (--no-stratify), splitting {} examples by line",
270 total_lines
271 );
272 } else if has_repos {
273 eprintln!(
274 "Stratifying by repository_url ({} unique repositories, {} examples)",
275 by_repo.len(),
276 total_lines - without_repo.len()
277 );
278 if !without_repo.is_empty() {
279 eprintln!(
280 " + {} examples without repository_url (distributed proportionally)",
281 without_repo.len()
282 );
283 }
284 }
285
286 let mut rng = match args.seed {
287 Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
288 None => rand::rngs::StdRng::from_os_rng(),
289 };
290
291 let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
292
293 if has_repos {
294 let mut repos: Vec<String> = by_repo.keys().cloned().collect();
295 repos.shuffle(&mut rng);
296
297 let repo_counts = compute_split_counts(&specs, repos.len())?;
298
299 let mut repo_iter = repos.into_iter();
300 for (split_idx, &count) in repo_counts.iter().enumerate() {
301 for _ in 0..count {
302 if let Some(repo) = repo_iter.next() {
303 if let Some(repo_lines) = by_repo.get(&repo) {
304 split_outputs[split_idx].extend(repo_lines.iter().cloned());
305 }
306 }
307 }
308 }
309
310 if !without_repo.is_empty() {
311 let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
312 let mut no_repo_shuffled = without_repo;
313 no_repo_shuffled.shuffle(&mut rng);
314
315 let mut line_iter = no_repo_shuffled.into_iter();
316 for (split_idx, &count) in no_repo_counts.iter().enumerate() {
317 for _ in 0..count {
318 if let Some(line) = line_iter.next() {
319 split_outputs[split_idx].push(line);
320 }
321 }
322 }
323 }
324 } else {
325 let line_counts = compute_split_counts(&specs, total_lines)?;
326 let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
327 all_lines.extend(without_repo);
328 all_lines.shuffle(&mut rng);
329
330 let mut line_iter = all_lines.into_iter();
331
332 for (split_idx, &count) in line_counts.iter().enumerate() {
333 for _ in 0..count {
334 if let Some(line) = line_iter.next() {
335 split_outputs[split_idx].push(line);
336 }
337 }
338 }
339 }
340
341 for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
342 write_lines_to_file(&spec.path, output_lines)?;
343 eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
344 }
345
346 Ok(())
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Dump `lines` into a fresh temporary JSONL file and return its handle.
    fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        for l in lines {
            writeln!(tmp, "{}", l).unwrap();
        }
        tmp.flush().unwrap();
        tmp
    }

    /// Shorthand for building a `SplitSpec` in tests.
    fn spec(path: &str, size: SplitSize) -> SplitSpec {
        SplitSpec {
            path: PathBuf::from(path),
            size,
        }
    }

    #[test]
    fn test_parse_split_spec_percentage() {
        let parsed = parse_split_spec("train.jsonl=80%").unwrap();
        assert_eq!(parsed.path, PathBuf::from("train.jsonl"));
        let SplitSize::Percentage(frac) = parsed.size else {
            panic!("expected percentage");
        };
        assert!((frac - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_parse_split_spec_absolute() {
        let parsed = parse_split_spec("test.jsonl=100").unwrap();
        assert_eq!(parsed.path, PathBuf::from("test.jsonl"));
        assert!(matches!(parsed.size, SplitSize::Absolute(100)));
    }

    #[test]
    fn test_parse_split_spec_rest() {
        let parsed = parse_split_spec("valid.jsonl=rest").unwrap();
        assert_eq!(parsed.path, PathBuf::from("valid.jsonl"));
        assert!(matches!(parsed.size, SplitSize::Rest));
    }

    #[test]
    fn test_get_repository_url() {
        let with_repo = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
        assert_eq!(
            get_repository_url(with_repo),
            Some("https://github.com/example/repo".to_string())
        );
        assert_eq!(get_repository_url(r#"{"data": 123}"#), None);
    }

    #[test]
    fn test_compute_split_counts_percentage() {
        let specs = vec![
            spec("a", SplitSize::Percentage(0.8)),
            spec("b", SplitSize::Percentage(0.2)),
        ];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_with_rest() {
        let specs = vec![
            spec("a", SplitSize::Percentage(0.8)),
            spec("b", SplitSize::Rest),
        ];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![80, 20]);
    }

    #[test]
    fn test_compute_split_counts_absolute() {
        let specs = vec![
            spec("a", SplitSize::Absolute(50)),
            spec("b", SplitSize::Rest),
        ];
        assert_eq!(compute_split_counts(&specs, 100).unwrap(), vec![50, 50]);
    }

    #[test]
    fn test_group_lines_by_repo() {
        let input: Vec<String> = [
            r#"{"repository_url": "repo1", "id": 1}"#,
            r#"{"repository_url": "repo1", "id": 2}"#,
            r#"{"repository_url": "repo2", "id": 3}"#,
            r#"{"id": 4}"#,
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let (by_repo, without_repo) = group_lines_by_repo(input);

        assert_eq!(by_repo.len(), 2);
        assert_eq!(by_repo["repo1"].len(), 2);
        assert_eq!(by_repo["repo2"].len(), 1);
        assert_eq!(without_repo.len(), 1);
    }

    #[test]
    fn test_run_split_basic() {
        // Eight examples across four repositories, two examples per repo.
        let records: Vec<String> = (1..=8)
            .map(|id| format!(r#"{{"repository_url": "repo{}", "id": {}}}"#, (id + 1) / 2, id))
            .collect();
        let record_refs: Vec<&str> = records.iter().map(String::as_str).collect();
        let input = create_temp_jsonl(&record_refs);

        let out_dir = tempfile::tempdir().unwrap();
        let train = out_dir.path().join("train.jsonl");
        let valid = out_dir.path().join("valid.jsonl");

        let args = SplitArgs {
            seed: Some(42),
            no_stratify: false,
        };
        let cli_inputs = vec![
            input.path().to_path_buf(),
            PathBuf::from(format!("{}=50%", train.display())),
            PathBuf::from(format!("{}=rest", valid.display())),
        ];

        run_split(&args, &cli_inputs).unwrap();

        let train_lines: Vec<String> = std::fs::read_to_string(&train)
            .unwrap()
            .lines()
            .map(str::to_string)
            .collect();
        let valid_lines: Vec<String> = std::fs::read_to_string(&valid)
            .unwrap()
            .lines()
            .map(str::to_string)
            .collect();

        // Every example lands in exactly one output.
        assert_eq!(train_lines.len() + valid_lines.len(), 8);

        // Stratification: no repository appears in both outputs.
        let repos_of = |lines: &[String]| -> std::collections::HashSet<String> {
            lines.iter().filter_map(|l| get_repository_url(l)).collect()
        };
        assert!(
            repos_of(&train_lines).is_disjoint(&repos_of(&valid_lines)),
            "train and valid should have non-overlapping repos"
        );
    }

    #[test]
    fn test_multiple_rest_fails() {
        let specs = vec![spec("a", SplitSize::Rest), spec("b", SplitSize::Rest)];
        assert!(compute_split_counts(&specs, 100).is_err());
    }
}