Add `ep truncate-patch` command to prepare the commits evalset (#48204)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/main.rs                    | 15 +
crates/edit_prediction_cli/src/truncate_expected_patch.rs | 75 +++++++++
2 files changed, 90 insertions(+)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -23,6 +23,7 @@ mod score;
 mod split_commit;
 mod split_dataset;
 mod synthesize;
+mod truncate_expected_patch;
 mod word_diff;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 use collections::HashSet;
@@ -54,6 +55,7 @@ use crate::score::run_scoring;
 use crate::split_commit::SplitCommitArgs;
 use crate::split_dataset::SplitArgs;
 use crate::synthesize::{SynthesizeConfig, run_synthesize};
+use crate::truncate_expected_patch::TruncatePatchArgs;
 
 #[derive(Parser, Debug)]
 #[command(name = "ep")]
@@ -193,6 +195,8 @@ enum Command {
     Clean,
     /// Generate an evaluation example by splitting a chronologically-ordered commit
     SplitCommit(SplitCommitArgs),
+    /// Truncate expected patch by the given criteria
+    TruncatePatch(TruncatePatchArgs),
     /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
     Split(SplitArgs),
     /// Filter a JSONL dataset by programming language (based on cursor_path extension)
@@ -233,6 +237,7 @@ impl Display for Command {
             }
             Command::Clean => write!(f, "clean"),
             Command::SplitCommit(_) => write!(f, "split-commit"),
+            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
             Command::Split(_) => write!(f, "split"),
             Command::FilterLanguages(_) => write!(f, "filter-languages"),
             Command::ImportBatch(args) => {
@@ -745,6 +750,15 @@ fn main() {
             }
             return;
         }
+        Command::TruncatePatch(truncate_args) => {
+            if let Err(error) =
+                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
+            {
+                eprintln!("{error:#}");
+                std::process::exit(1);
+            }
+            return;
+        }
         Command::Split(split_args) => {
             if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                 eprintln!("{error:#}");
@@ -937,6 +951,7 @@ fn main() {
                                         | Command::Synthesize(_)
                                         | Command::SplitCommit(_)
                                         | Command::Split(_)
+                                        | Command::TruncatePatch(_)
                                         | Command::FilterLanguages(_)
                                         | Command::ImportBatch(_) => {
                                             unreachable!()

crates/edit_prediction_cli/src/truncate_expected_patch.rs 🔗

@@ -0,0 +1,75 @@
+use crate::example::{Example, read_example_files};
+use crate::reorder_patch::{Hunk, Patch, PatchLine};
+use clap::Args;
+use std::path::PathBuf;
+
+// Command-line arguments for the `ep truncate-patch` subcommand, consumed by
+// `run_truncate_expected_patch`. The `///` comments on the fields below are the
+// user-facing help text, so they are kept short.
+#[derive(Args, Debug, Clone)]
+pub struct TruncatePatchArgs {
+    // No default is declared for this option. `run_one_input` uses it as the budget of
+    // group markers it may encounter before dropping all remaining hunks.
+    /// Number of logical groups ahead to leave
+    #[arg(long)]
+    pub num_groups: usize,
+
+    // When set, `run_one_input` keeps only hunks whose filename equals the example's
+    // cursor path.
+    /// Leave only edits in the file under the cursor
+    #[arg(long, default_value_t = false)]
+    pub current_file_only: bool,
+}
+
+/// Entry point for `ep truncate-patch`: truncates the expected patch of every example read
+/// from `inputs` and prints each resulting example to stdout as one line of JSON.
+///
+/// When `inputs` is empty, the single placeholder path `-` is passed to
+/// `read_example_files` instead (presumably meaning stdin — confirm against that helper).
+///
+/// Stops at, and returns, the first error from truncation or serialization; examples
+/// printed before the failure stay on stdout.
+pub fn run_truncate_expected_patch(
+    args: &TruncatePatchArgs,
+    inputs: &[PathBuf],
+) -> anyhow::Result<()> {
+    let stdin_placeholder = [PathBuf::from("-")];
+    let sources: &[PathBuf] = if inputs.is_empty() {
+        &stdin_placeholder
+    } else {
+        inputs
+    };
+
+    let mut loaded = read_example_files(sources);
+    for example in &mut loaded {
+        run_one_input(example, args)?;
+
+        let json_line = serde_json::to_string(&example)?;
+        println!("{json_line}");
+    }
+    Ok(())
+}
+
+/// Truncates the first expected patch of `example` in place.
+///
+/// Hunks are visited in order with a budget of `args.num_groups`. A hunk that contains a
+/// group marker (see `starts_new_group`) is still kept but spends one unit of the budget;
+/// once the budget reaches zero every later hunk is dropped. With
+/// `args.current_file_only`, surviving hunks are further limited to those whose filename
+/// equals the example's cursor path (hunks in other files still spend budget). Finally the
+/// patch header and every `//`-prefixed garbage line are removed, and the rendered patch
+/// replaces the original first expected patch.
+///
+/// # Errors
+///
+/// Returns an error if the example has no expected patches.
+fn run_one_input(example: &mut Example, args: &TruncatePatchArgs) -> anyhow::Result<()> {
+    // Report a regular error instead of panicking on an out-of-bounds index.
+    let expected_patch = example
+        .spec
+        .expected_patches
+        .first()
+        .ok_or_else(|| anyhow::anyhow!("example has no expected patches to truncate"))?;
+    let mut patch = Patch::parse_unified_diff(expected_patch);
+
+    // Rendered once up front: the cursor path is the same for every hunk.
+    let cursor_file = example.spec.cursor_path.display().to_string();
+    let mut groups_left = args.num_groups;
+
+    patch.hunks.retain(|hunk| {
+        if groups_left == 0 {
+            return false;
+        }
+        // A hunk carrying a group marker is kept, but it spends one unit of the budget.
+        if starts_new_group(hunk) {
+            groups_left -= 1;
+        }
+
+        // The file filter runs after counting, so hunks in other files still spend budget.
+        !args.current_file_only || hunk.filename == cursor_file
+    });
+
+    // Drop the patch header and every `//`-prefixed garbage line so that no group
+    // header comments remain in the output.
+    patch.header = String::new();
+    for hunk in &mut patch.hunks {
+        hunk.lines
+            .retain(|line| !matches!(line, PatchLine::Garbage(text) if text.starts_with("//")));
+    }
+
+    // Safe to index: the patch list was checked to be non-empty above.
+    example.spec.expected_patches[0] = patch.to_string();
+
+    Ok(())
+}
+
+/// Reports whether `hunk` contains a group marker, i.e. at least one
+/// `PatchLine::Garbage` line whose content begins with `///`.
+fn starts_new_group(hunk: &Hunk) -> bool {
+    hunk.lines
+        .iter()
+        .any(|line| matches!(line, PatchLine::Garbage(text) if text.starts_with("///")))
+}