diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 18dc4c2d6300f7e1a069aa4ab1ce962d12ac70f2..e22103634047fa306e0eff79f9f3146f1cda19c8 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -23,6 +23,7 @@ mod score; mod split_commit; mod split_dataset; mod synthesize; +mod truncate_expected_patch; mod word_diff; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use collections::HashSet; @@ -54,6 +55,7 @@ use crate::score::run_scoring; use crate::split_commit::SplitCommitArgs; use crate::split_dataset::SplitArgs; use crate::synthesize::{SynthesizeConfig, run_synthesize}; +use crate::truncate_expected_patch::TruncatePatchArgs; #[derive(Parser, Debug)] #[command(name = "ep")] @@ -193,6 +195,8 @@ enum Command { Clean, /// Generate an evaluation example by splitting a chronologically-ordered commit SplitCommit(SplitCommitArgs), + /// Truncate expected patch by the given criteria + TruncatePatch(TruncatePatchArgs), /// Split a JSONL dataset into multiple files (stratified by repository_url if present) Split(SplitArgs), /// Filter a JSONL dataset by programming language (based on cursor_path extension) @@ -233,6 +237,7 @@ impl Display for Command { } Command::Clean => write!(f, "clean"), Command::SplitCommit(_) => write!(f, "split-commit"), + Command::TruncatePatch(_) => write!(f, "truncate-patch"), Command::Split(_) => write!(f, "split"), Command::FilterLanguages(_) => write!(f, "filter-languages"), Command::ImportBatch(args) => { @@ -745,6 +750,15 @@ fn main() { } return; } + Command::TruncatePatch(truncate_args) => { + if let Err(error) = + truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs) + { + eprintln!("{error:#}"); + std::process::exit(1); + } + return; + } Command::Split(split_args) => { if let Err(error) = split_dataset::run_split(split_args, &args.inputs) { eprintln!("{error:#}"); @@ -937,6 +951,7 @@ fn main() { | Command::Synthesize(_) | Command::SplitCommit(_) | Command::Split(_) + | Command::TruncatePatch(_) | Command::FilterLanguages(_) | Command::ImportBatch(_) => { unreachable!() diff --git a/crates/edit_prediction_cli/src/truncate_expected_patch.rs b/crates/edit_prediction_cli/src/truncate_expected_patch.rs new file mode 100644 index 0000000000000000000000000000000000000000..78811ee0fe711c87ffdfdc0f9ef7c016cd115d7a --- /dev/null +++ b/crates/edit_prediction_cli/src/truncate_expected_patch.rs @@ -0,0 +1,75 @@ +use crate::example::{Example, read_example_files}; +use crate::reorder_patch::{Hunk, Patch, PatchLine}; +use clap::Args; +use std::path::PathBuf; + +#[derive(Args, Debug, Clone)] +pub struct TruncatePatchArgs { + /// Number of logical groups ahead to leave + #[arg(long)] + pub num_groups: usize, + + /// Leave only edits in the file under the cursor + #[arg(long, default_value_t = false)] + pub current_file_only: bool, +} + +pub fn run_truncate_expected_patch( + args: &TruncatePatchArgs, + inputs: &[PathBuf], +) -> anyhow::Result<()> { + let stdin_path = PathBuf::from("-"); + let inputs = if inputs.is_empty() { + std::slice::from_ref(&stdin_path) + } else { + inputs + }; + + let mut examples = read_example_files(inputs); + for example in &mut examples { + run_one_input(example, args)?; + + println!("{}", serde_json::to_string(&example)?); + } + Ok(()) +} + +fn run_one_input(example: &mut Example, args: &TruncatePatchArgs) -> anyhow::Result<()> { + let mut patch = Patch::parse_unified_diff(&example.spec.expected_patches[0]); + let mut groups_left = args.num_groups; + + patch.hunks.retain(|hunk| { + if groups_left == 0 { + return false; + } + if starts_new_group(hunk) { + groups_left -= 1; + } + + if args.current_file_only { + return hunk.filename == example.spec.cursor_path.display().to_string(); + } + + true + }); + + // Remove all group headers + patch.header = String::new(); + patch.hunks.iter_mut().for_each(|hunk| { + hunk.lines.retain(|line| match line { + PatchLine::Garbage(line) => !line.starts_with("//"), + _ => true, + }); + }); + + example.spec.expected_patches[0] = patch.to_string(); + + Ok(()) +} + +fn starts_new_group(hunk: &Hunk) -> bool { + hunk.lines.iter().any(|line| match line { + PatchLine::Garbage(content) => content.starts_with("///"), + _ => false, + }) +}