Fix some issues with edit prediction CLI (#46197)

Max Brunsfeld created

* Added `--repo` and `--name` flags for running only examples with a
specific name, or repo (substring matching)
* Fixed a race condition that caused hangs when running multiple
examples at the same repo and sha
* Fixed a bug where scoring was completely wrong because I had passed
the arguments to `apply_diff_to_string` in the wrong order
* The current evals now run quickly and without errors.

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/udiff.rs                |  40 +---
crates/edit_prediction_cli/src/main.rs             | 144 +++++++++------
crates/edit_prediction_cli/src/retrieve_context.rs |  20 +
crates/edit_prediction_cli/src/score.rs            |   4 
crates/edit_prediction_cli/src/synthesize.rs       |   4 
5 files changed, 117 insertions(+), 95 deletions(-)

Detailed changes

crates/edit_prediction/src/udiff.rs 🔗

@@ -12,7 +12,7 @@ use collections::HashMap;
 use gpui::{AsyncApp, Entity};
 use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
 use postage::stream::Stream as _;
-use project::{Project, ProjectPath};
+use project::Project;
 use util::{paths::PathStyle, rel_path::RelPath};
 use worktree::Worktree;
 
@@ -58,36 +58,20 @@ pub async fn apply_diff(
             } => {
                 let buffer = match current_file {
                     None => {
-                        if is_new_file {
-                            let rel_path = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)
-                                .unwrap()
-                                .into_arc();
-
+                        let buffer = if is_new_file {
                             project
+                                .update(cx, |project, cx| project.create_buffer(true, cx))?
+                                .await?
+                        } else {
+                            let project_path = project
                                 .update(cx, |project, cx| {
-                                    project.create_entry(
-                                        ProjectPath {
-                                            worktree_id: worktree.read(cx).id(),
-                                            path: rel_path,
-                                        },
-                                        false,
-                                        cx,
-                                    )
+                                    project.find_project_path(path.as_ref(), cx)
                                 })?
-                                .await?;
-                        }
-
-                        let project_path = project
-                            .update(cx, |project, cx| {
-                                project.find_project_path(path.as_ref(), cx)
-                            })?
-                            .with_context(|| {
-                                format!("Failed to find project path in diff: {}", path)
-                            })?;
-
-                        let buffer = project
-                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-                            .await?;
+                                .with_context(|| format!("no such path: {}", path))?;
+                            project
+                                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+                                .await?
+                        };
                         included_files.insert(path.to_string(), buffer.clone());
                         current_file = Some(buffer);
                         current_file.as_ref().unwrap()

crates/edit_prediction_cli/src/main.rs 🔗

@@ -44,6 +44,12 @@ struct EpArgs {
     max_parallelism: usize,
     #[clap(long, global = true)]
     limit: Option<usize>,
+    /// Filter examples by name
+    #[clap(long, global = true)]
+    name: Option<String>,
+    /// Filter examples by repository
+    #[clap(long, global = true)]
+    repo: Option<String>,
     #[command(subcommand)]
     command: Option<Command>,
     #[clap(global = true, help = INPUTS_HELP)]
@@ -249,6 +255,7 @@ async fn load_examples(
     }
 
     let mut examples = read_example_files(&file_inputs);
+
     let total_steps = examples.len() + captured_after_timestamps.len();
     Progress::global().set_total_steps(total_steps);
 
@@ -275,6 +282,13 @@ async fn load_examples(
 
     crate::example::sort_examples_by_repo_and_rev(&mut examples);
 
+    if let Some(name_filter) = &args.name {
+        examples.retain(|example| example.spec.name.contains(name_filter));
+    }
+    if let Some(repo_filter) = &args.repo {
+        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
+    }
+
     if let Some(limit) = args.limit {
         if examples.len() > limit {
             examples.truncate(limit);
@@ -413,62 +427,16 @@ fn main() {
                             }
                             .await;
 
-                            if let Err(e) = result {
-                                Progress::global().increment_failed();
-                                let failed_example_path =
-                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
-                                app_state
-                                    .fs
-                                    .write(
-                                        &failed_example_path,
-                                        &serde_json::to_vec_pretty(&example).unwrap(),
-                                    )
-                                    .await
-                                    .unwrap();
-                                let err_path = FAILED_EXAMPLES_DIR
-                                    .join(format!("{}_err.txt", example.spec.name));
-                                app_state
-                                    .fs
-                                    .write(&err_path, format!("{e:?}").as_bytes())
-                                    .await
-                                    .unwrap();
-
-                                let file_path = example
-                                    .repo_name()
-                                    .unwrap()
-                                    .worktree_path()
-                                    .join(&example.spec.cursor_path);
-
-                                let msg = format!(
-                                    indoc::indoc! {"
-                                        While processing \"{}\":
-
-                                        {:?}
-
-                                        Written to: \x1b[36m{}\x1b[0m
-
-                                        Cursor File: \x1b[36m{}\x1b[0m
-
-                                        Explore this example data with:
-                                            fx \x1b[36m{}\x1b[0m
-
-                                        Re-run this example with:
-                                            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
-                                    "},
-                                    example.spec.name,
-                                    e,
-                                    err_path.display(),
-                                    file_path.display(),
-                                    failed_example_path.display(),
-                                    command,
-                                    failed_example_path.display(),
-                                );
-                                if args.failfast || failfast_on_single_example {
-                                    Progress::global().finalize();
-                                    panic!("{}", msg);
-                                } else {
-                                    log::error!("{}", msg);
-                                }
+                            if let Err(error) = result {
+                                handle_error(
+                                    error,
+                                    &args,
+                                    &command,
+                                    &app_state,
+                                    failfast_on_single_example,
+                                    example,
+                                )
+                                .await;
                             }
                         }
                     });
@@ -499,3 +467,67 @@ fn main() {
         .detach();
     });
 }
+
+async fn handle_error(
+    error: anyhow::Error,
+    args: &EpArgs,
+    command: &Command,
+    app_state: &Arc<headless::EpAppState>,
+    failfast_on_single_example: bool,
+    example: &Example,
+) {
+    Progress::global().increment_failed();
+    let example_name = example.spec.filename();
+    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
+    app_state
+        .fs
+        .write(
+            &failed_example_path,
+            &serde_json::to_vec_pretty(&example).unwrap(),
+        )
+        .await
+        .unwrap();
+    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
+    app_state
+        .fs
+        .write(&err_path, format!("{error:?}").as_bytes())
+        .await
+        .unwrap();
+
+    let file_path = example
+        .repo_name()
+        .unwrap()
+        .worktree_path()
+        .join(&example.spec.cursor_path);
+
+    let msg = format!(
+        indoc::indoc! {"
+            While processing \"{}\":
+
+            {:?}
+
+            Written to: \x1b[36m{}\x1b[0m
+
+            Cursor File: \x1b[36m{}\x1b[0m
+
+            Explore this example data with:
+            fx \x1b[36m{}\x1b[0m
+
+            Re-run this example with:
+            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
+        "},
+        example.spec.name,
+        error,
+        err_path.display(),
+        file_path.display(),
+        failed_example_path.display(),
+        command,
+        failed_example_path.display(),
+    );
+    if args.failfast || failfast_on_single_example {
+        Progress::global().finalize();
+        panic!("{}", msg);
+    } else {
+        log::error!("{}", msg);
+    }
+}

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -125,12 +125,6 @@ async fn wait_for_language_servers_to_start(
 
     drop(added_subscription);
 
-    if !language_server_ids.is_empty() {
-        project
-            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
-            .detach();
-    }
-
     let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let subscriptions = [
         cx.subscribe(&lsp_store, {
@@ -153,6 +147,7 @@ async fn wait_for_language_servers_to_start(
         }),
         cx.subscribe(project, {
             let step_progress = step_progress.clone();
+            let lsp_store = lsp_store.clone();
             move |_, event, cx| match event {
                 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
                     let lsp_store = lsp_store.read(cx);
@@ -172,7 +167,18 @@ async fn wait_for_language_servers_to_start(
         .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
         .await?;
 
-    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
+    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
+        language_server_ids
+            .iter()
+            .copied()
+            .filter(|id| {
+                lsp_store
+                    .language_server_statuses
+                    .get(id)
+                    .is_some_and(|status| status.has_pending_diagnostic_updates)
+            })
+            .collect::<HashSet<_>>()
+    })?;
     while !pending_language_server_ids.is_empty() {
         futures::select! {
             language_server_id = rx.next() => {

crates/edit_prediction_cli/src/score.rs 🔗

@@ -34,14 +34,14 @@ pub async fn run_scoring(
         .expected_patches
         .iter()
         .map(|patch| {
-            apply_diff_to_string(original_text, patch)
+            apply_diff_to_string(patch, original_text)
                 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
         })
         .collect::<Result<Vec<_>, _>>()?;
 
     let mut scores = vec![];
     for prediction in &example.predictions {
-        let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
+        let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
             Ok(text) => text,
             Err(_) => {
                 scores.push(ExampleScore { delta_chr_f: 0.0 });

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -687,7 +687,7 @@ async fn build_example(
 
     // Validate expected patch applies to intermediate state
     let expected_patch_with_header = ensure_diff_header(expected_patch, &cursor_file);
-    apply_diff_to_string(&intermediate_state, &expected_patch_with_header)
+    apply_diff_to_string(&expected_patch_with_header, &intermediate_state)
         .map_err(|e| format!("Expected patch failed to apply: {}", e))?;
 
     // Find where the expected patch edits would apply in the intermediate state
@@ -764,7 +764,7 @@ fn apply_edit_history_to_content(
         return Ok(content.to_string());
     }
 
-    apply_diff_to_string(content, &file_diff)
+    apply_diff_to_string(&file_diff, content)
         .map_err(|e| format!("Failed to apply edit history: {}", e))
 }