@@ -23,6 +23,7 @@ mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
+mod truncate_expected_patch;
mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
@@ -54,6 +55,7 @@ use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
+use crate::truncate_expected_patch::TruncatePatchArgs;
#[derive(Parser, Debug)]
#[command(name = "ep")]
@@ -193,6 +195,8 @@ enum Command {
Clean,
/// Generate an evaluation example by splitting a chronologically-ordered commit
SplitCommit(SplitCommitArgs),
+ /// Truncate expected patch by the given criteria
+ TruncatePatch(TruncatePatchArgs),
/// Split a JSONL dataset into multiple files (stratified by repository_url if present)
Split(SplitArgs),
/// Filter a JSONL dataset by programming language (based on cursor_path extension)
@@ -233,6 +237,7 @@ impl Display for Command {
}
Command::Clean => write!(f, "clean"),
Command::SplitCommit(_) => write!(f, "split-commit"),
+ Command::TruncatePatch(_) => write!(f, "truncate-patch"),
Command::Split(_) => write!(f, "split"),
Command::FilterLanguages(_) => write!(f, "filter-languages"),
Command::ImportBatch(args) => {
@@ -745,6 +750,15 @@ fn main() {
}
return;
}
+ Command::TruncatePatch(truncate_args) => {
+ if let Err(error) =
+ truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
+ {
+ eprintln!("{error:#}");
+ std::process::exit(1);
+ }
+ return;
+ }
Command::Split(split_args) => {
if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
eprintln!("{error:#}");
@@ -937,6 +951,7 @@ fn main() {
| Command::Synthesize(_)
| Command::SplitCommit(_)
| Command::Split(_)
+ | Command::TruncatePatch(_)
| Command::FilterLanguages(_)
| Command::ImportBatch(_) => {
unreachable!()
@@ -0,0 +1,75 @@
+use crate::example::{Example, read_example_files};
+use crate::reorder_patch::{Hunk, Patch, PatchLine};
+use clap::Args;
+use std::path::PathBuf;
+
+// Command-line options for the `truncate-patch` subcommand; a reference to this
+// struct is passed to `run_truncate_expected_patch`.
+//
+// NOTE: plain `//` comments are used here on purpose. The `///` comments on the
+// fields are picked up by clap's derive as user-facing help text, so change them
+// only when the CLI help itself should change.
+#[derive(Args, Debug, Clone)]
+pub struct TruncatePatchArgs {
+ // Initial value of the group budget in `run_one_input`: the budget drops by
+ // one for every kept hunk that starts a new group, and hunks are discarded
+ // once it reaches zero.
+ /// Number of logical groups ahead to leave
+ #[arg(long)]
+ pub num_groups: usize,
+
+ // When set, `run_one_input` keeps only hunks whose `filename` equals the
+ // example's `cursor_path` rendered as a string.
+ /// Leave only edits in the file under the cursor
+ #[arg(long, default_value_t = false)]
+ pub current_file_only: bool,
+}
+
+/// Entry point of the `truncate-patch` subcommand.
+///
+/// Loads every example from `inputs`, truncates its expected patch according to
+/// `args`, and prints each resulting example to stdout as one JSON line.
+///
+/// When `inputs` is empty the single path `-` is used instead (presumably
+/// interpreted as stdin by `read_example_files` — confirm there).
+///
+/// # Errors
+///
+/// Propagates any error returned by `run_one_input` or by JSON serialization.
+/// Examples printed before the failure stay on stdout.
+pub fn run_truncate_expected_patch(
+    args: &TruncatePatchArgs,
+    inputs: &[PathBuf],
+) -> anyhow::Result<()> {
+    // Fallback input list used only when the caller supplied no paths.
+    let fallback = [PathBuf::from("-")];
+    let sources: &[PathBuf] = match inputs {
+        [] => fallback.as_slice(),
+        given => given,
+    };
+
+    let mut examples = read_example_files(sources);
+    for example in examples.iter_mut() {
+        run_one_input(example, args)?;
+
+        let json = serde_json::to_string(&*example)?;
+        println!("{}", json);
+    }
+    Ok(())
+}
+
+/// Truncates the first expected patch of `example` in place.
+///
+/// Only `example.spec.expected_patches[0]` is processed; any further expected
+/// patches are left untouched. The steps are:
+///
+/// 1. Parse the patch with `Patch::parse_unified_diff`.
+/// 2. Walk the hunks in order with a budget of `args.num_groups`. Every kept
+///    hunk for which `starts_new_group` is true consumes one unit of budget;
+///    once the budget is zero all remaining hunks are dropped.
+/// 3. If `args.current_file_only` is set, additionally drop hunks whose
+///    `filename` differs from the example's `cursor_path`.
+/// 4. Clear the patch header and strip every `PatchLine::Garbage` line that
+///    starts with `//` from the surviving hunks.
+/// 5. Write the re-serialized patch back into the example.
+///
+/// # Errors
+///
+/// Returns an error when the example has no expected patches, instead of
+/// panicking on an out-of-bounds index.
+fn run_one_input(example: &mut Example, args: &TruncatePatchArgs) -> anyhow::Result<()> {
+    let expected_patch = example
+        .spec
+        .expected_patches
+        .first()
+        .ok_or_else(|| anyhow::anyhow!("example has no expected patch to truncate"))?;
+    let mut patch = Patch::parse_unified_diff(expected_patch);
+
+    // Render the cursor path once, and only when the filter is requested,
+    // rather than allocating a fresh string for every hunk.
+    let cursor_file = args
+        .current_file_only
+        .then(|| example.spec.cursor_path.display().to_string());
+
+    let mut groups_left = args.num_groups;
+    patch.hunks.retain(|hunk| {
+        if groups_left == 0 {
+            return false;
+        }
+        // The budget is consumed before the file filter runs, so a
+        // group-starting hunk counts even if the filter then discards it.
+        if starts_new_group(hunk) {
+            groups_left -= 1;
+        }
+
+        match &cursor_file {
+            Some(file) => hunk.filename == *file,
+            None => true,
+        }
+    });
+
+    // Remove all group headers: the patch-level header plus any `//`-prefixed
+    // garbage lines that were attached to the remaining hunks.
+    patch.header = String::new();
+    for hunk in patch.hunks.iter_mut() {
+        hunk.lines
+            .retain(|line| !matches!(line, PatchLine::Garbage(text) if text.starts_with("//")));
+    }
+
+    // Non-emptiness was established above, so this always finds the slot.
+    if let Some(slot) = example.spec.expected_patches.first_mut() {
+        *slot = patch.to_string();
+    }
+
+    Ok(())
+}
+
+/// Reports whether `hunk` contains at least one `PatchLine::Garbage` line whose
+/// text begins with `///`.
+///
+/// Judging by the function name, such a line is what marks the beginning of a
+/// logical group; the marker format itself is defined outside this file.
+fn starts_new_group(hunk: &Hunk) -> bool {
+    hunk.lines
+        .iter()
+        .any(|line| matches!(line, PatchLine::Garbage(text) if text.starts_with("///")))
+}