Add edit prediction evals that test related excerpt usage (#50256)

Max Brunsfeld created

I've also fixed a race condition with the programmatic context retrieval
in the CLI, which was causing no excerpts to be fetched for the Rust
examples.

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/evals/vscode--log-object-property.md |  56 +
crates/edit_prediction_cli/evals/zed--add-eprintln.md           |  54 
crates/edit_prediction_cli/evals/zed--change-match-arm.md       |  68 +
crates/edit_prediction_cli/src/retrieve_context.rs              | 143 +-
4 files changed, 213 insertions(+), 108 deletions(-)

Detailed changes

crates/edit_prediction_cli/evals/vscode--log-object-property.md 🔗

@@ -0,0 +1,56 @@
++++
+repository_url = "https://github.com/microsoft/vscode"
+revision = "e28a92fc1fbe9de11eca2f8ad19899334bff8525"
++++
+
+This prediction requires the model to see the `IDiffComputationResult` type definition.
+
+## Edit History
+
+```diff
+--- a/src/vs/editor/browser/widget/diffEditorWidget.ts
++++ b/src/vs/editor/browser/widget/diffEditorWidget.ts
+@@ -1117,6 +1117,7 @@
+ 				&& currentModifiedModel === this._modifiedEditor.getModel()
+ 			) {
+ 				this._setState(editorBrowser.DiffEditorState.DiffComputed);
++				console.log("did quit:")
+ 				this._diffComputationResult = result;
+ 				this._updateDecorationsRunner.schedule();
+ 				this._onDidUpdateDiff.fire();
+```
+
+## Cursor Position
+
+```src/vs/editor/browser/widget/diffEditorWidget.ts
+			if (currentToken === this._diffComputationToken
+				&& currentOriginalModel === this._originalEditor.getModel()
+				&& currentModifiedModel === this._modifiedEditor.getModel()
+			) {
+				this._setState(editorBrowser.DiffEditorState.DiffComputed);
+				console.log("did quit:")
+				//                    ^[CURSOR_POSITION]
+				this._diffComputationResult = result;
+				this._updateDecorationsRunner.schedule();
+				this._onDidUpdateDiff.fire();
+			}
+```
+
+## Expected Patch
+
+```diff
+--- a/src/vs/editor/browser/widget/diffEditorWidget.ts
++++ b/src/vs/editor/browser/widget/diffEditorWidget.ts
+@@ -1115,10 +1115,10 @@
+ 			if (currentToken === this._diffComputationToken
+ 				&& currentOriginalModel === this._originalEditor.getModel()
+ 				&& currentModifiedModel === this._modifiedEditor.getModel()
+ 			) {
+ 				this._setState(editorBrowser.DiffEditorState.DiffComputed);
+-				console.log("did quit:")
++				console.log("did quit:", result.quitEarly)
+ 				this._diffComputationResult = result;
+ 				this._updateDecorationsRunner.schedule();
+ 				this._onDidUpdateDiff.fire();
+ 			}
+```

crates/edit_prediction_cli/evals/zed--add-eprintln.md 🔗

@@ -1,43 +1,37 @@
 +++
 repository_url = "git@github.com:zed-industries/zed"
-revision = "780a87dd98f26816876d12e2728933b17faca78d"
+revision = "b7090c9fae7390a82021b994994c0f587744d96c"
 +++
 
+This example shows the model's preference for making conservative predictions, and ability to place
+the cursor within the predicted output.
+
 ## Edit History
 
 ```diff
 --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs
 +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs
-@@ -206,6 +206,7 @@
-         self.select_next_edit(&Default::default(), window, cx);
-         self.confirm(&Default::default(), window, cx);
-
+@@ -144,7 +144,7 @@
+     fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
 +        epr
-         cx.notify();
-     }
-
+         let next_index = self
+             .ep_store
+             .read(cx)
 ```
 
 ## Cursor Position
 
 ```crates/edit_prediction_ui/src/rate_prediction_modal.rs
-        let current_completion = self
-            .active_prediction
-            .as_ref()
-            .map(|completion| completion.prediction.clone());
-        self.select_completion(current_completion, false, window, cx);
-        self.select_next_edit(&Default::default(), window, cx);
-        self.confirm(&Default::default(), window, cx);
-
+    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
         epr
         // ^[CURSOR_POSITION]
-        cx.notify();
-    }
-
-    pub fn thumbs_down_active(
-        &mut self,
-        _: &ThumbsDownActivePrediction,
-        window: &mut Window,
+        let next_index = self
+            .ep_store
+            .read(cx)
+            .shown_predictions()
+            .skip(self.selected_index)
+            .enumerate()
+            .skip(1) // Skip straight to the next item
 ```
 
 ## Expected Patch
@@ -45,12 +39,16 @@ revision = "780a87dd98f26816876d12e2728933b17faca78d"
 ```diff
 --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs
 +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs
-@@ -201,16 +201,16 @@
-         self.confirm(&Default::default(), window, cx);
-
+@@ -144,14 +144,14 @@
+     fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
 -        epr
 +        eprintln!("");
 #                   ^[CURSOR_POSITION]
-         cx.notify();
-     }
+         let next_index = self
+             .ep_store
+             .read(cx)
+             .shown_predictions()
+             .skip(self.selected_index)
+             .enumerate()
+             .skip(1) // Skip straight to the next item
 ```

crates/edit_prediction_cli/evals/zed--change-match-arm.md 🔗

@@ -0,0 +1,68 @@
++++
+repository_url = "git@github.com:zed-industries/zed"
+revision = "be5763632dccb33470ca233c36ccd9e5e790e3b2"
++++
+
+This prediction requires the model to see the `project::Event` enum.
+
+## Edit History
+
+```diff
+--- a/crates/edit_prediction/src/edit_prediction.rs
++++ b/crates/edit_prediction/src/edit_prediction.rs
+@@ -1035,7 +1035,7 @@
+                     project_state.recent_paths.push_front(path);
+                 }
+             }
+-            project::Event::DiagnosticsUpdated { .. } => {
++            project::Event::Disk { .. } => {
+                 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
+                     self.refresh_prediction_from_diagnostics(
+                         project,
+```
+
+## Cursor Position
+
+```crates/edit_prediction/src/edit_prediction.rs
+                    {
+                        project_state.recent_paths.remove(ix);
+                    }
+                    project_state.recent_paths.push_front(path);
+                }
+            }
+            project::Event::Disk { .. } => {
+                //              ^[CURSOR_POSITION]
+                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
+                    self.refresh_prediction_from_diagnostics(
+                        project,
+```
+
+## Expected Patch
+
+```diff
+--- a/crates/edit_prediction/src/edit_prediction.rs
++++ b/crates/edit_prediction/src/edit_prediction.rs
+@@ -1032,10 +1032,10 @@
+                     project_state.recent_paths.push_front(path);
+                 }
+             }
+-            project::Event::Disk { .. } => {
++            project::Event::DiskBasedDiagnosticsFinished { .. } => {
+                 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
+                     self.refresh_prediction_from_diagnostics(
+                         project,
+```
+
+```diff
+--- a/crates/edit_prediction/src/edit_prediction.rs
++++ b/crates/edit_prediction/src/edit_prediction.rs
+@@ -1032,10 +1032,10 @@
+                     project_state.recent_paths.push_front(path);
+                 }
+             }
+-            project::Event::Disk { .. } => {
++            project::Event::DiskBasedDiagnosticsStarted { .. } => {
+                 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
+                     self.refresh_prediction_from_diagnostics(
+                         project,
+```

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -85,46 +85,79 @@ async fn wait_for_language_servers_to_start(
 ) -> anyhow::Result<()> {
     let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 
-    let (language_server_ids, mut starting_language_server_ids) =
-        buffer.update(cx, |buffer, cx| {
-            lsp_store.update(cx, |lsp_store, cx| {
-                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
-                let starting_ids = ids
-                    .iter()
-                    .copied()
-                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
-                    .collect::<HashSet<_>>();
-                (ids, starting_ids)
-            })
+    // Determine which servers exist for this buffer, and which are still starting.
+    let mut servers_pending_start = HashSet::default();
+    let mut servers_pending_diagnostics = HashSet::default();
+    buffer.update(cx, |buffer, cx| {
+        lsp_store.update(cx, |lsp_store, cx| {
+            let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
+            for &id in &ids {
+                match lsp_store.language_server_statuses.get(&id) {
+                    None => {
+                        servers_pending_start.insert(id);
+                        servers_pending_diagnostics.insert(id);
+                    }
+                    Some(status) if status.has_pending_diagnostic_updates => {
+                        servers_pending_diagnostics.insert(id);
+                    }
+                    Some(_) => {}
+                }
+            }
         });
+    });
 
-    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
+    step_progress.set_substatus(format!(
+        "waiting for {} LSPs",
+        servers_pending_diagnostics.len()
+    ));
 
-    let timeout_duration = if starting_language_server_ids.is_empty() {
+    let timeout_duration = if servers_pending_start.is_empty() {
         Duration::from_secs(30)
     } else {
         Duration::from_secs(60 * 5)
     };
-
     let timeout = cx.background_executor().timer(timeout_duration).shared();
 
-    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
-    let added_subscription = cx.subscribe(project, {
+    let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1));
+    let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1));
+    let subscriptions = [cx.subscribe(&lsp_store, {
         let step_progress = step_progress.clone();
-        move |_, event, _| match event {
-            project::Event::LanguageServerAdded(language_server_id, name, _) => {
+        move |lsp_store, event, cx| match event {
+            project::LspStoreEvent::LanguageServerAdded(id, name, _) => {
                 step_progress.set_substatus(format!("LSP started: {}", name));
-                tx.try_send(*language_server_id).ok();
+                started_tx.try_send(*id).ok();
+            }
+            project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => {
+                let name = lsp_store
+                    .read(cx)
+                    .language_server_adapter_for_id(*language_server_id)
+                    .unwrap()
+                    .name();
+                step_progress.set_substatus(format!("LSP idle: {}", name));
+                diag_tx.try_send(*language_server_id).ok();
+            }
+            project::LspStoreEvent::LanguageServerUpdate {
+                message:
+                    client::proto::update_language_server::Variant::WorkProgress(
+                        client::proto::LspWorkProgress {
+                            message: Some(message),
+                            ..
+                        },
+                    ),
+                ..
+            } => {
+                step_progress.set_substatus(message.clone());
             }
             _ => {}
         }
-    });
+    })];
 
-    while !starting_language_server_ids.is_empty() {
+    // Phase 1: wait for all servers to start.
+    while !servers_pending_start.is_empty() {
         futures::select! {
-            language_server_id = rx.next() => {
-                if let Some(id) = language_server_id {
-                    starting_language_server_ids.remove(&id);
+            id = started_rx.next() => {
+                if let Some(id) = id {
+                    servers_pending_start.remove(&id);
                 }
             },
             _ = timeout.clone().fuse() => {
@@ -133,67 +166,17 @@ async fn wait_for_language_servers_to_start(
         }
     }
 
-    drop(added_subscription);
-
-    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
-    let subscriptions = [
-        cx.subscribe(&lsp_store, {
-            let step_progress = step_progress.clone();
-            move |_, event, _| {
-                if let project::LspStoreEvent::LanguageServerUpdate {
-                    message:
-                        client::proto::update_language_server::Variant::WorkProgress(
-                            client::proto::LspWorkProgress {
-                                message: Some(message),
-                                ..
-                            },
-                        ),
-                    ..
-                } = event
-                {
-                    step_progress.set_substatus(message.clone());
-                }
-            }
-        }),
-        cx.subscribe(project, {
-            let step_progress = step_progress.clone();
-            let lsp_store = lsp_store.clone();
-            move |_, event, cx| match event {
-                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
-                    let lsp_store = lsp_store.read(cx);
-                    let name = lsp_store
-                        .language_server_adapter_for_id(*language_server_id)
-                        .unwrap()
-                        .name();
-                    step_progress.set_substatus(format!("LSP idle: {}", name));
-                    tx.try_send(*language_server_id).ok();
-                }
-                _ => {}
-            }
-        }),
-    ];
-
+    // Save the buffer so the server sees the current content and kicks off diagnostics.
     project
         .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
         .await?;
 
-    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
-        language_server_ids
-            .iter()
-            .copied()
-            .filter(|id| {
-                lsp_store
-                    .language_server_statuses
-                    .get(id)
-                    .is_some_and(|status| status.has_pending_diagnostic_updates)
-            })
-            .collect::<HashSet<_>>()
-    });
-    while !pending_language_server_ids.is_empty() {
+    // Phase 2: wait for all servers to finish their diagnostic pass.
+    while !servers_pending_diagnostics.is_empty() {
         futures::select! {
-            language_server_id = rx.next() => {
-                if let Some(id) = language_server_id {
-                    pending_language_server_ids.remove(&id);
+            id = diag_rx.next() => {
+                if let Some(id) = id {
+                    servers_pending_diagnostics.remove(&id);
                 }
             },
             _ = timeout.clone().fuse() => {