From d4e89f9587f66ca277fc662dbeb45324d91952bb Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 26 Feb 2026 15:36:29 -0800 Subject: [PATCH] Add edit prediction evals that test related excerpt usage (#50256) I've also fixed a race condition with the programmatic context retrieval in the CLI, which was causing no excerpts to be fetched for the Rust examples. Release Notes: - N/A --- .../evals/vscode--log-object-property.md | 56 +++++++ .../evals/zed--add-eprintln.md | 54 ++++--- .../evals/zed--change-match-arm.md | 68 +++++++++ .../src/retrieve_context.rs | 143 ++++++++---------- 4 files changed, 213 insertions(+), 108 deletions(-) create mode 100644 crates/edit_prediction_cli/evals/vscode--log-object-property.md create mode 100644 crates/edit_prediction_cli/evals/zed--change-match-arm.md diff --git a/crates/edit_prediction_cli/evals/vscode--log-object-property.md b/crates/edit_prediction_cli/evals/vscode--log-object-property.md new file mode 100644 index 0000000000000000000000000000000000000000..1c60b84f0107c54ea8bd89084dccbfdf785fb932 --- /dev/null +++ b/crates/edit_prediction_cli/evals/vscode--log-object-property.md @@ -0,0 +1,56 @@ ++++ +repository_url = "https://github.com/microsoft/vscode" +revision = "e28a92fc1fbe9de11eca2f8ad19899334bff8525" ++++ + +This prediction requires the model to see the `IDiffComputationResult` type definition. 
+ +## Edit History + +```diff +--- a/src/vs/editor/browser/widget/diffEditorWidget.ts ++++ b/src/vs/editor/browser/widget/diffEditorWidget.ts +@@ -1117,6 +1117,7 @@ + && currentModifiedModel === this._modifiedEditor.getModel() + ) { + this._setState(editorBrowser.DiffEditorState.DiffComputed); ++ console.log("did quit:") + this._diffComputationResult = result; + this._updateDecorationsRunner.schedule(); + this._onDidUpdateDiff.fire(); +``` + +## Cursor Position + +```src/vs/editor/browser/widget/diffEditorWidget.ts + if (currentToken === this._diffComputationToken + && currentOriginalModel === this._originalEditor.getModel() + && currentModifiedModel === this._modifiedEditor.getModel() + ) { + this._setState(editorBrowser.DiffEditorState.DiffComputed); + console.log("did quit:") + // ^[CURSOR_POSITION] + this._diffComputationResult = result; + this._updateDecorationsRunner.schedule(); + this._onDidUpdateDiff.fire(); + } +``` + +## Expected Patch + +```diff +--- a/src/vs/editor/browser/widget/diffEditorWidget.ts ++++ b/src/vs/editor/browser/widget/diffEditorWidget.ts +@@ -1115,10 +1115,10 @@ + if (currentToken === this._diffComputationToken + && currentOriginalModel === this._originalEditor.getModel() + && currentModifiedModel === this._modifiedEditor.getModel() + ) { + this._setState(editorBrowser.DiffEditorState.DiffComputed); +- console.log("did quit:") ++ console.log("did quit:", result.quitEarly) + this._diffComputationResult = result; + this._updateDecorationsRunner.schedule(); + this._onDidUpdateDiff.fire(); + } +``` diff --git a/crates/edit_prediction_cli/evals/zed--add-eprintln.md b/crates/edit_prediction_cli/evals/zed--add-eprintln.md index d4252810b5f97df0991de3015c19e12138e8a27b..467bfd5151996bc98d00145bfebef62f89c5e37e 100644 --- a/crates/edit_prediction_cli/evals/zed--add-eprintln.md +++ b/crates/edit_prediction_cli/evals/zed--add-eprintln.md @@ -1,43 +1,37 @@ +++ repository_url = "git@github.com:zed-industries/zed" -revision = 
"780a87dd98f26816876d12e2728933b17faca78d" +revision = "b7090c9fae7390a82021b994994c0f587744d96c" +++ +This example shows the model's preference for making conservative predictions, and ability to place +the cursor within the predicted output. + ## Edit History ```diff --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs -@@ -206,6 +206,7 @@ - self.select_next_edit(&Default::default(), window, cx); - self.confirm(&Default::default(), window, cx); - +@@ -144,7 +144,7 @@ + fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) { + epr - cx.notify(); - } - + let next_index = self + .ep_store + .read(cx) ``` ## Cursor Position ```crates/edit_prediction_ui/src/rate_prediction_modal.rs - let current_completion = self - .active_prediction - .as_ref() - .map(|completion| completion.prediction.clone()); - self.select_completion(current_completion, false, window, cx); - self.select_next_edit(&Default::default(), window, cx); - self.confirm(&Default::default(), window, cx); - + fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) { epr // ^[CURSOR_POSITION] - cx.notify(); - } - - pub fn thumbs_down_active( - &mut self, - _: &ThumbsDownActivePrediction, - window: &mut Window, + let next_index = self + .ep_store + .read(cx) + .shown_predictions() + .skip(self.selected_index) + .enumerate() + .skip(1) // Skip straight to the next item ``` ## Expected Patch @@ -45,12 +39,16 @@ revision = "780a87dd98f26816876d12e2728933b17faca78d" ```diff --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs -@@ -201,16 +201,16 @@ - self.confirm(&Default::default(), window, cx); - +@@ -144,14 +144,14 @@ + fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) { - epr + eprintln!(""); # ^[CURSOR_POSITION] - cx.notify(); - } + let next_index = self + .ep_store + .read(cx) + 
.shown_predictions() + .skip(self.selected_index) + .enumerate() + .skip(1) // Skip straight to the next item ``` diff --git a/crates/edit_prediction_cli/evals/zed--change-match-arm.md b/crates/edit_prediction_cli/evals/zed--change-match-arm.md new file mode 100644 index 0000000000000000000000000000000000000000..042e2730cc352d9c90739a3fe3ea20438755896b --- /dev/null +++ b/crates/edit_prediction_cli/evals/zed--change-match-arm.md @@ -0,0 +1,68 @@ ++++ +repository_url = "git@github.com:zed-industries/zed" +revision = "be5763632dccb33470ca233c36ccd9e5e790e3b2" ++++ + +This prediction requires the model to see the `project::Event` enum. + +## Edit History + +```diff +--- a/crates/edit_prediction/src/edit_prediction.rs ++++ b/crates/edit_prediction/src/edit_prediction.rs +@@ -1035,7 +1035,7 @@ + project_state.recent_paths.push_front(path); + } + } +- project::Event::DiagnosticsUpdated { .. } => { ++ project::Event::Disk { .. } => { + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics( + project, +``` + +## Cursor Position + +```crates/edit_prediction/src/edit_prediction.rs + { + project_state.recent_paths.remove(ix); + } + project_state.recent_paths.push_front(path); + } + } + project::Event::Disk { .. } => { + // ^[CURSOR_POSITION] + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics( + project, +``` + +## Expected Patch + +```diff +--- a/crates/edit_prediction/src/edit_prediction.rs ++++ b/crates/edit_prediction/src/edit_prediction.rs +@@ -1032,10 +1032,10 @@ + project_state.recent_paths.push_front(path); + } + } +- project::Event::Disk { .. } => { ++ project::Event::DiskBasedDiagnosticsFinished { .. } => { + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics( + project, +``` + +```diff +--- a/crates/edit_prediction/src/edit_prediction.rs ++++ b/crates/edit_prediction/src/edit_prediction.rs +@@ -1032,10 +1032,10 @@ + project_state.recent_paths.push_front(path); + } + } +- project::Event::Disk { .. 
} => { ++ project::Event::DiskBasedDiagnosticsStarted { .. } => { + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics( + project, +``` diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index 18ee3c1b0ec1456b02bb145c98e669b777048385..a5fb00b39a67a15a7afcced897b4d109f1f3406f 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -85,46 +85,79 @@ async fn wait_for_language_servers_to_start( ) -> anyhow::Result<()> { let lsp_store = project.read_with(cx, |project, _| project.lsp_store()); - let (language_server_ids, mut starting_language_server_ids) = - buffer.update(cx, |buffer, cx| { - lsp_store.update(cx, |lsp_store, cx| { - let ids = lsp_store.language_servers_for_local_buffer(buffer, cx); - let starting_ids = ids - .iter() - .copied() - .filter(|id| !lsp_store.language_server_statuses.contains_key(&id)) - .collect::>(); - (ids, starting_ids) - }) + // Determine which servers exist for this buffer, and which are still starting. 
+ let mut servers_pending_start = HashSet::default(); + let mut servers_pending_diagnostics = HashSet::default(); + buffer.update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + let ids = lsp_store.language_servers_for_local_buffer(buffer, cx); + for &id in &ids { + match lsp_store.language_server_statuses.get(&id) { + None => { + servers_pending_start.insert(id); + servers_pending_diagnostics.insert(id); + } + Some(status) if status.has_pending_diagnostic_updates => { + servers_pending_diagnostics.insert(id); + } + Some(_) => {} + } + } }); + }); - step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len())); + step_progress.set_substatus(format!( + "waiting for {} LSPs", + servers_pending_diagnostics.len() + )); - let timeout_duration = if starting_language_server_ids.is_empty() { + let timeout_duration = if servers_pending_start.is_empty() { Duration::from_secs(30) } else { Duration::from_secs(60 * 5) }; - let timeout = cx.background_executor().timer(timeout_duration).shared(); - let (mut tx, mut rx) = mpsc::channel(language_server_ids.len()); - let added_subscription = cx.subscribe(project, { + let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1)); + let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1)); + let subscriptions = [cx.subscribe(&lsp_store, { let step_progress = step_progress.clone(); - move |_, event, _| match event { - project::Event::LanguageServerAdded(language_server_id, name, _) => { + move |lsp_store, event, cx| match event { + project::LspStoreEvent::LanguageServerAdded(id, name, _) => { step_progress.set_substatus(format!("LSP started: {}", name)); - tx.try_send(*language_server_id).ok(); + started_tx.try_send(*id).ok(); + } + project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => { + let name = lsp_store + .read(cx) + .language_server_adapter_for_id(*language_server_id) + .unwrap() + .name(); + 
step_progress.set_substatus(format!("LSP idle: {}", name)); + diag_tx.try_send(*language_server_id).ok(); + } + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + client::proto::LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => { + step_progress.set_substatus(message.clone()); } _ => {} } - }); + })]; - while !starting_language_server_ids.is_empty() { + // Phase 1: wait for all servers to start. + while !servers_pending_start.is_empty() { futures::select! { - language_server_id = rx.next() => { - if let Some(id) = language_server_id { - starting_language_server_ids.remove(&id); + id = started_rx.next() => { + if let Some(id) = id { + servers_pending_start.remove(&id); } }, _ = timeout.clone().fuse() => { @@ -133,67 +166,17 @@ async fn wait_for_language_servers_to_start( } } - drop(added_subscription); - - let (mut tx, mut rx) = mpsc::channel(language_server_ids.len()); - let subscriptions = [ - cx.subscribe(&lsp_store, { - let step_progress = step_progress.clone(); - move |_, event, _| { - if let project::LspStoreEvent::LanguageServerUpdate { - message: - client::proto::update_language_server::Variant::WorkProgress( - client::proto::LspWorkProgress { - message: Some(message), - .. - }, - ), - .. - } = event - { - step_progress.set_substatus(message.clone()); - } - } - }), - cx.subscribe(project, { - let step_progress = step_progress.clone(); - let lsp_store = lsp_store.clone(); - move |_, event, cx| match event { - project::Event::DiskBasedDiagnosticsFinished { language_server_id } => { - let lsp_store = lsp_store.read(cx); - let name = lsp_store - .language_server_adapter_for_id(*language_server_id) - .unwrap() - .name(); - step_progress.set_substatus(format!("LSP idle: {}", name)); - tx.try_send(*language_server_id).ok(); - } - _ => {} - } - }), - ]; - + // Save the buffer so the server sees the current content and kicks off diagnostics. 
project .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) .await?; - let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| { - language_server_ids - .iter() - .copied() - .filter(|id| { - lsp_store - .language_server_statuses - .get(id) - .is_some_and(|status| status.has_pending_diagnostic_updates) - }) - .collect::<HashSet<_>>() - }); - while !pending_language_server_ids.is_empty() { + // Phase 2: wait for all servers to finish their diagnostic pass. + while !servers_pending_diagnostics.is_empty() { futures::select! { - language_server_id = rx.next() => { - if let Some(id) = language_server_id { - pending_language_server_ids.remove(&id); + id = diag_rx.next() => { + if let Some(id) = id { + servers_pending_diagnostics.remove(&id); } }, _ = timeout.clone().fuse() => {