@@ -62,16 +62,24 @@ EXAMPLES:
# Reproducible split with seed
ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
+ # Disable stratification (split by examples, not repositories)
+ ep split --no-stratify input.jsonl train.jsonl=80% valid.jsonl=rest
+
STRATIFICATION:
When examples have a "repository_url" field, the split ensures each output
file contains examples from non-overlapping repositories. This prevents
- data leakage between train/test splits.
+ data leakage between train/test splits. Use --no-stratify to disable this
+ behavior and split by individual examples instead.
"#
)]
pub struct SplitArgs {
/// Random seed for reproducibility
#[arg(long)]
pub seed: Option<u64>,
+
+ /// Disable stratification by repository_url (split by examples instead)
+ #[arg(long)]
+ pub no_stratify: bool,
}
#[derive(Debug, Clone)]
@@ -254,9 +262,14 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
}
let (by_repo, without_repo) = group_lines_by_repo(lines);
- let has_repos = !by_repo.is_empty();
+ let has_repos = !by_repo.is_empty() && !args.no_stratify;
- if has_repos {
+ if args.no_stratify && !by_repo.is_empty() {
+ eprintln!(
+ "Stratification disabled (--no-stratify), splitting {} examples by line",
+ total_lines
+ );
+ } else if has_repos {
eprintln!(
"Stratifying by repository_url ({} unique repositories, {} examples)",
by_repo.len(),
@@ -310,10 +323,12 @@ pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
}
} else {
let line_counts = compute_split_counts(&specs, total_lines)?;
- let mut shuffled_lines = without_repo;
- shuffled_lines.shuffle(&mut rng);
+ let mut all_lines: Vec<String> = by_repo.into_values().flatten().collect();
+ all_lines.extend(without_repo);
+ all_lines.shuffle(&mut rng);
+
+ let mut line_iter = all_lines.into_iter();
- let mut line_iter = shuffled_lines.into_iter();
for (split_idx, &count) in line_counts.iter().enumerate() {
for _ in 0..count {
if let Some(line) = line_iter.next() {
@@ -467,7 +482,10 @@ mod tests {
let train_path = temp_dir.path().join("train.jsonl");
let valid_path = temp_dir.path().join("valid.jsonl");
- let args = SplitArgs { seed: Some(42) };
+ let args = SplitArgs {
+ seed: Some(42),
+ no_stratify: false,
+ };
let inputs = vec![
input.path().to_path_buf(),
PathBuf::from(format!("{}=50%", train_path.display())),