Fix edit history clearing bug in ep (#47017)

Max Brunsfeld , Oleksiy Syvokon , Agus Zubiaga , Ben Kunkle , and Zed Zippy created

We were including changes due to Buffer.reload in the edit history.

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

crates/edit_prediction/src/udiff.rs                |  16 
crates/edit_prediction_cli/src/example.rs          |   5 
crates/edit_prediction_cli/src/format_prompt.rs    |   7 
crates/edit_prediction_cli/src/headless.rs         |   4 
crates/edit_prediction_cli/src/load_project.rs     |  19 
crates/edit_prediction_cli/src/main.rs             | 201 +++++++++------
crates/edit_prediction_cli/src/predict.rs          |  12 
crates/edit_prediction_cli/src/progress.rs         |  39 ++
crates/edit_prediction_cli/src/retrieve_context.rs |  24 +
crates/edit_prediction_cli/src/score.rs            |   7 
10 files changed, 204 insertions(+), 130 deletions(-)

Detailed changes

crates/edit_prediction/src/udiff.rs 🔗

@@ -112,10 +112,12 @@ pub async fn apply_diff(
                 };
 
                 buffer.read_with(cx, |buffer, _| {
-                    edits.extend(
-                        resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice(), status)
-                            .with_context(|| format!("Diff:\n{diff_str}"))?,
-                    );
+                    edits.extend(resolve_hunk_edits_in_buffer(
+                        hunk,
+                        buffer,
+                        ranges.as_slice(),
+                        status,
+                    )?);
                     anyhow::Ok(())
                 })?;
             }
@@ -648,11 +650,7 @@ fn resolve_hunk_edits_in_buffer(
         })
         .ok_or_else(|| {
             if candidates.is_empty() {
-                anyhow!(
-                    "Failed to match context:\n\n```\n{}```\n\nBuffer contents:\n\n```\n{}```",
-                    hunk.context,
-                    buffer.text()
-                )
+                anyhow!("Failed to match context:\n\n```\n{}```\n", hunk.context,)
             } else {
                 anyhow!("Context is not unique enough:\n{}", hunk.context)
             }

crates/edit_prediction_cli/src/example.rs 🔗

@@ -11,6 +11,7 @@ use project::Project;
 use serde::{Deserialize, Serialize};
 use std::{
     borrow::Cow,
+    collections::VecDeque,
     io::Read,
     path::{Path, PathBuf},
     sync::Arc,
@@ -214,9 +215,9 @@ pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
     });
 }
 
-pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
+pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
     let mut examples_by_repo = HashMap::default();
-    for example in examples.iter_mut() {
+    for example in examples {
         examples_by_repo
             .entry(example.spec.repository_url.clone())
             .or_insert_with(Vec::new)

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     FormatPromptArgs, PredictionProvider,
     example::{Example, ExamplePrompt},
     headless::EpAppState,
-    progress::{Progress, Step},
+    progress::{ExampleProgress, Step},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::{Context as _, Result};
@@ -18,11 +18,12 @@ pub async fn run_format_prompt(
     example: &mut Example,
     args: &FormatPromptArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     cx: AsyncApp,
 ) -> Result<()> {
-    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
+    let step_progress = example_progress.start(Step::FormatPrompt);
 
     let prompt_inputs = example
         .prompt_inputs

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -37,6 +37,10 @@ impl ProjectCache {
     pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
         self.0.lock().unwrap().get(repository_url).cloned()
     }
+
+    pub fn remove(&self, repository_url: &String) {
+        self.0.lock().unwrap().remove(repository_url);
+    }
 }
 
 pub fn init(cx: &mut App) -> EpAppState {

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     example::{Example, ExamplePromptInputs, ExampleState},
     git,
     headless::EpAppState,
-    progress::{InfoStyle, Progress, Step, StepProgress},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 };
 use anyhow::{Context as _, Result};
 use edit_prediction::{
@@ -18,13 +18,14 @@ use std::{fs, path::PathBuf, sync::Arc};
 pub async fn run_load_project(
     example: &mut Example,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> Result<()> {
     if example.state.is_some() {
         return Ok(());
     }
 
-    let progress = Progress::global().start(Step::LoadProject, &example.spec.name);
+    let progress = example_progress.start(Step::LoadProject);
 
     let project = setup_project(example, &app_state, &progress, &mut cx).await?;
 
@@ -160,15 +161,11 @@ async fn cursor_position(
 
         let mut matches = text.match_indices(&cursor_excerpt);
         let (excerpt_offset, _) = matches.next().with_context(|| {
-            format!(
-                "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
-                example.spec.name
-            )
+            format!("Cursor excerpt did not exist in buffer:\n\n{cursor_excerpt}\n",)
         })?;
         anyhow::ensure!(
             matches.next().is_none(),
-            "More than one cursor position match found for {}",
-            &example.spec.name
+            "More than one cursor position match found",
         );
         Ok(excerpt_offset)
     })?;
@@ -193,9 +190,6 @@ async fn setup_project(
     let worktree_path = setup_worktree(example, step_progress).await?;
 
     if let Some(project) = app_state.project_cache.get(&example.spec.repository_url) {
-        ep_store.update(cx, |ep_store, _| {
-            ep_store.clear_history_for_project(&project);
-        });
         let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone());
         let buffers = buffer_store.read_with(cx, |buffer_store, _| {
             buffer_store.buffers().collect::<Vec<_>>()
@@ -203,6 +197,9 @@ async fn setup_project(
         for buffer in buffers {
             buffer.update(cx, |buffer, cx| buffer.reload(cx)).await.ok();
         }
+        ep_store.update(cx, |ep_store, _| {
+            ep_store.clear_history_for_project(&project);
+        });
         return Ok(project);
     }
 

crates/edit_prediction_cli/src/main.rs 🔗

@@ -30,6 +30,7 @@ use std::fmt::Display;
 use std::fs::{File, OpenOptions};
 use std::hash::{Hash, Hasher};
 use std::io::{BufRead, BufReader, BufWriter, Write};
+use std::sync::Mutex;
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
@@ -496,7 +497,7 @@ fn main() {
 
         cx.spawn(async move |cx| {
             let result = async {
-                let mut examples = load_examples(
+                let examples = load_examples(
                     app_state.client.http_client(),
                     &args,
                     output.as_ref(),
@@ -536,97 +537,133 @@ fn main() {
                         None
                     };
 
-                let mut grouped_examples = group_examples_by_repo(&mut examples);
-                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
-
-                for example_batch in example_batches {
-                    let futures = example_batch.into_iter().map(|repo_examples| async {
-                        for example in repo_examples.iter_mut() {
-                            let result = async {
-                                match &command {
-                                    Command::ParseExample => {}
-                                    Command::LoadProject => {
-                                        run_load_project(example, app_state.clone(), cx.clone())
+                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
+                let finished_examples = Mutex::new(Vec::new());
+
+                let mut tasks = Vec::new();
+                for _ in 0..args.max_parallelism {
+                    tasks.push(async {
+                        loop {
+                            let Some(mut repo_examples) =
+                                grouped_examples.lock().unwrap().pop_front()
+                            else {
+                                break;
+                            };
+                            for example in &mut repo_examples {
+                                let example_progress =
+                                    Progress::global().start_group(&example.spec.name);
+
+                                let result = async {
+                                    match &command {
+                                        Command::ParseExample => {}
+                                        Command::LoadProject => {
+                                            run_load_project(
+                                                example,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
                                             .await?;
-                                    }
-                                    Command::Context => {
-                                        run_context_retrieval(
-                                            example,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::FormatPrompt(args) => {
-                                        run_format_prompt(
-                                            example,
-                                            args,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::Predict(args) => {
-                                        run_prediction(
-                                            example,
-                                            args,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::Distill => {
-                                        run_distill(example).await?;
-                                    }
-                                    Command::Score(args) | Command::Eval(args) => {
-                                        run_scoring(example, &args, app_state.clone(), cx.clone())
+                                        }
+                                        Command::Context => {
+                                            run_context_retrieval(
+                                                example,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
                                             .await?;
+                                        }
+                                        Command::FormatPrompt(args) => {
+                                            run_format_prompt(
+                                                example,
+                                                args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Predict(args) => {
+                                            run_prediction(
+                                                example,
+                                                args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Distill => {
+                                            run_distill(example).await?;
+                                        }
+                                        Command::Score(args) | Command::Eval(args) => {
+                                            run_scoring(
+                                                example,
+                                                &args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Clean
+                                        | Command::Synthesize(_)
+                                        | Command::SplitCommit(_)
+                                        | Command::Split(_) => {
+                                            unreachable!()
+                                        }
                                     }
-                                    Command::Clean
-                                    | Command::Synthesize(_)
-                                    | Command::SplitCommit(_)
-                                    | Command::Split(_) => {
-                                        unreachable!()
-                                    }
+                                    anyhow::Ok(())
                                 }
-                                anyhow::Ok(())
-                            }
-                            .await;
-
-                            let failed = if let Err(error) = result {
-                                handle_error(
-                                    error,
-                                    &args,
-                                    &command,
-                                    &app_state,
-                                    failfast_on_single_example,
-                                    example,
-                                )
                                 .await;
-                                true
-                            } else {
-                                false
-                            };
 
-                            let should_write = !failed || args.failed == FailedHandling::Keep;
-                            if should_write {
-                                if let Some(ref mut sender) = output_sender.clone() {
-                                    let line = serde_json::to_string(example).unwrap();
-                                    sender
-                                        .send(line)
-                                        .await
-                                        .expect("Failed to send to output writer");
-                                } else if args.output.is_none()
-                                    && !matches!(command, Command::Eval(_))
-                                {
-                                    let line = serde_json::to_string(example).unwrap();
-                                    println!("{}", line);
+                                let failed = if let Err(error) = result {
+                                    handle_error(
+                                        error,
+                                        &args,
+                                        &command,
+                                        &app_state,
+                                        failfast_on_single_example,
+                                        &example,
+                                    )
+                                    .await;
+                                    true
+                                } else {
+                                    false
+                                };
+
+                                let should_write = !failed || args.failed == FailedHandling::Keep;
+                                if should_write {
+                                    if let Some(ref mut sender) = output_sender.clone() {
+                                        let line = serde_json::to_string(&example).unwrap();
+                                        sender
+                                            .send(line)
+                                            .await
+                                            .expect("Failed to send to output writer");
+                                    } else if args.output.is_none()
+                                        && !matches!(command, Command::Eval(_))
+                                    {
+                                        let line = serde_json::to_string(&example).unwrap();
+                                        println!("{}", line);
+                                    }
                                 }
                             }
+
+                            app_state
+                                .project_cache
+                                .remove(&repo_examples.first().unwrap().spec.repository_url);
+                            for example in &mut repo_examples {
+                                example.state.take();
+                            }
+                            finished_examples
+                                .lock()
+                                .unwrap()
+                                .extend_from_slice(&repo_examples);
                         }
                     });
-                    futures::future::join_all(futures).await;
                 }
+                futures::future::join_all(tasks).await;
 
                 Progress::global().finalize();
 
@@ -638,7 +675,7 @@ fn main() {
                 }
 
                 match &command {
-                    Command::Eval(_) => score::print_report(&examples),
+                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                     _ => (),
                 };
 

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
     headless::EpAppState,
     load_project::run_load_project,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
-    progress::{InfoStyle, Progress, Step},
+    progress::{ExampleProgress, InfoStyle, Step},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -28,6 +28,7 @@ pub async fn run_prediction(
     example: &mut Example,
     args: &PredictArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
     let provider = args.provider;
@@ -41,17 +42,18 @@ pub async fn run_prediction(
         }
     }
 
-    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
         args.provider
     {
-        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
+        let _step_progress = example_progress.start(Step::Predict);
 
         run_format_prompt(
             example,
             &FormatPromptArgs { provider },
             app_state.clone(),
+            example_progress,
             cx,
         )
         .await?;
@@ -60,9 +62,9 @@ pub async fn run_prediction(
         return predict_anthropic(example, repetition_count, version, batched).await;
     }
 
-    run_load_project(example, app_state.clone(), cx.clone()).await?;
+    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);
+    let step_progress = example_progress.start(Step::Predict);
 
     if matches!(
         provider,

crates/edit_prediction_cli/src/progress.rs 🔗

@@ -22,6 +22,7 @@ struct ProgressInner {
     max_example_name_len: usize,
     status_lines_displayed: usize,
     total_examples: usize,
+    completed_examples: usize,
     failed_examples: usize,
     last_line_is_logging: bool,
     ticker: Option<std::thread::JoinHandle<()>>,
@@ -101,12 +102,14 @@ impl Progress {
                     inner: Mutex::new(ProgressInner {
                         completed: Vec::new(),
                         in_progress: HashMap::new(),
-                        is_tty: std::env::var("NO_COLOR").is_err()
-                            && std::io::stderr().is_terminal(),
+                        is_tty: std::env::var("COLOR").is_ok()
+                            || (std::env::var("NO_COLOR").is_err()
+                                && std::io::stderr().is_terminal()),
                         terminal_width: get_terminal_width(),
                         max_example_name_len: 0,
                         status_lines_displayed: 0,
                         total_examples: 0,
+                        completed_examples: 0,
                         failed_examples: 0,
                         last_line_is_logging: false,
                         ticker: None,
@@ -119,6 +122,18 @@ impl Progress {
             .clone()
     }
 
+    pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
+        ExampleProgress {
+            progress: self.clone(),
+            example_name: example_name.to_string(),
+        }
+    }
+
+    fn increment_completed(&self) {
+        let mut inner = self.inner.lock().unwrap();
+        inner.completed_examples += 1;
+    }
+
     pub fn set_total_examples(&self, total: usize) {
         let mut inner = self.inner.lock().unwrap();
         inner.total_examples = total;
@@ -247,7 +262,6 @@ impl Progress {
             for _ in 0..inner.status_lines_displayed {
                 eprint!("\x1b[A\x1b[K");
             }
-            let _ = std::io::stderr().flush();
             inner.status_lines_displayed = 0;
         }
     }
@@ -317,7 +331,7 @@ impl Progress {
         let dim = "\x1b[2m";
 
         // Build the done/in-progress/total label
-        let done_count = inner.completed.len();
+        let done_count = inner.completed_examples;
         let in_progress_count = inner.in_progress.len();
         let failed_count = inner.failed_examples;
 
@@ -427,6 +441,23 @@ impl Progress {
     }
 }
 
+pub struct ExampleProgress {
+    progress: Arc<Progress>,
+    example_name: String,
+}
+
+impl ExampleProgress {
+    pub fn start(&self, step: Step) -> StepProgress {
+        self.progress.start(step, &self.example_name)
+    }
+}
+
+impl Drop for ExampleProgress {
+    fn drop(&mut self) {
+        self.progress.increment_completed();
+    }
+}
+
 pub struct StepProgress {
     progress: Arc<Progress>,
     step: Step,

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     example::Example,
     headless::EpAppState,
     load_project::run_load_project,
-    progress::{InfoStyle, Progress, Step, StepProgress},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 };
 use anyhow::Context as _;
 use collections::HashSet;
@@ -17,6 +17,7 @@ use std::time::Duration;
 pub async fn run_context_retrieval(
     example: &mut Example,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
     if example
@@ -27,11 +28,9 @@ pub async fn run_context_retrieval(
         return Ok(());
     }
 
-    run_load_project(example, app_state.clone(), cx.clone()).await?;
+    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress: Arc<StepProgress> = Progress::global()
-        .start(Step::Context, &example.spec.name)
-        .into();
+    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 
     let state = example.state.as_ref().unwrap();
     let project = state.project.clone();
@@ -96,10 +95,13 @@ async fn wait_for_language_servers_to_start(
 
     step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 
-    let timeout = cx
-        .background_executor()
-        .timer(Duration::from_secs(60 * 5))
-        .shared();
+    let timeout_duration = if starting_language_server_ids.is_empty() {
+        Duration::from_secs(30)
+    } else {
+        Duration::from_secs(60 * 5)
+    };
+
+    let timeout = cx.background_executor().timer(timeout_duration).shared();
 
     let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let added_subscription = cx.subscribe(project, {
@@ -121,7 +123,7 @@ async fn wait_for_language_servers_to_start(
                 }
             },
             _ = timeout.clone().fuse() => {
-                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
+                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
             }
         }
     }
@@ -190,7 +192,7 @@ async fn wait_for_language_servers_to_start(
                 }
             },
             _ = timeout.clone().fuse() => {
-                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
+                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
             }
         }
     }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     headless::EpAppState,
     metrics,
     predict::run_prediction,
-    progress::{Progress, Step},
+    progress::{ExampleProgress, Step},
 };
 use anyhow::Context as _;
 use edit_prediction::udiff::apply_diff_to_string;
@@ -15,11 +15,12 @@ pub async fn run_scoring(
     example: &mut Example,
     args: &PredictArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    run_prediction(example, args, app_state, cx).await?;
+    run_prediction(example, args, app_state, example_progress, cx).await?;
 
-    let progress = Progress::global().start(Step::Score, &example.spec.name);
+    let progress = example_progress.start(Step::Score);
 
     progress.set_substatus("applying patches");
     let original_text = &example.prompt_inputs.as_ref().unwrap().content;