agent2: Always finalize diffs from the edit tool (#36918)

Created by Ben Brandt and Antonio Scandurra

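Previously, we wouldn't finalize the diff if an error occurred during
editing or the tool call was canceled. The edit tool now registers the
finalization with `util::defer` as soon as the diff is created, and a new
test covers the success, error, and dropped-tool-call cases.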

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/agent2/src/thread.rs                |  24 +++++
crates/agent2/src/tools/edit_file_tool.rs  | 103 +++++++++++++++++++++++
crates/language_model/src/fake_provider.rs |  31 ++++++
3 files changed, 152 insertions(+), 6 deletions(-)

Detailed changes

crates/agent2/src/thread.rs 🔗

@@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
         }
     }
 
+    pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
+        let event = self.0.next().await;
+        if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+            update,
+        )))) = event
+        {
+            update.fields
+        } else {
+            panic!("Expected update fields but got: {:?}", event);
+        }
+    }
+
+    pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
+        let event = self.0.next().await;
+        if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
+            update,
+        )))) = event
+        {
+            update.diff
+        } else {
+            panic!("Expected diff but got: {:?}", event);
+        }
+    }
+
     pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
         let event = self.0.next().await;
         if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
 
             let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
             event_stream.update_diff(diff.clone());
+            let _finalize_diff = util::defer({
+               let diff = diff.downgrade();
+               let mut cx = cx.clone();
+               move || {
+                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
+               }
+            });
 
             let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
             let old_text = cx
@@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
                 })
                 .await;
 
-            diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
-
             let input_path = input.path.display();
             if unified_diff.is_empty() {
                 anyhow::ensure!(
@@ -1545,6 +1550,100 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_diff_finalization(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = project::FakeFs::new(cx.executor());
+        fs.insert_tree("/", json!({"main.rs": ""})).await;
+
+        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
+        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let model = Arc::new(FakeLanguageModel::default());
+        let thread = cx.new(|cx| {
+            Thread::new(
+                project.clone(),
+                cx.new(|_cx| ProjectContext::default()),
+                context_server_registry.clone(),
+                Templates::new(),
+                Some(model.clone()),
+                cx,
+            )
+        });
+
+        // Ensure the diff is finalized after the edit completes.
+        {
+            let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+            let edit = cx.update(|cx| {
+                tool.run(
+                    EditFileToolInput {
+                        display_description: "Edit file".into(),
+                        path: path!("/main.rs").into(),
+                        mode: EditFileMode::Edit,
+                    },
+                    stream_tx,
+                    cx,
+                )
+            });
+            stream_rx.expect_update_fields().await;
+            let diff = stream_rx.expect_diff().await;
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+            cx.run_until_parked();
+            model.end_last_completion_stream();
+            edit.await.unwrap();
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+        }
+
+        // Ensure the diff is finalized if an error occurs while editing.
+        {
+            model.forbid_requests();
+            let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+            let edit = cx.update(|cx| {
+                tool.run(
+                    EditFileToolInput {
+                        display_description: "Edit file".into(),
+                        path: path!("/main.rs").into(),
+                        mode: EditFileMode::Edit,
+                    },
+                    stream_tx,
+                    cx,
+                )
+            });
+            stream_rx.expect_update_fields().await;
+            let diff = stream_rx.expect_diff().await;
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+            edit.await.unwrap_err();
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+            model.allow_requests();
+        }
+
+        // Ensure the diff is finalized if the tool call gets dropped.
+        {
+            let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+            let edit = cx.update(|cx| {
+                tool.run(
+                    EditFileToolInput {
+                        display_description: "Edit file".into(),
+                        path: path!("/main.rs").into(),
+                        mode: EditFileMode::Edit,
+                    },
+                    stream_tx,
+                    cx,
+                )
+            });
+            stream_rx.expect_update_fields().await;
+            let diff = stream_rx.expect_diff().await;
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+            drop(edit);
+            cx.run_until_parked();
+            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+        }
+    }
+
     fn init_test(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);

crates/language_model/src/fake_provider.rs 🔗

@@ -4,12 +4,16 @@ use crate::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, LanguageModelToolChoice,
 };
+use anyhow::anyhow;
 use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 use http_client::Result;
 use parking_lot::Mutex;
 use smol::stream::StreamExt;
-use std::sync::Arc;
+use std::sync::{
+    Arc,
+    atomic::{AtomicBool, Ordering::SeqCst},
+};
 
 #[derive(Clone)]
 pub struct FakeLanguageModelProvider {
@@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
             >,
         )>,
     >,
+    forbid_requests: AtomicBool,
 }
 
 impl Default for FakeLanguageModel {
@@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
             provider_id: LanguageModelProviderId::from("fake".to_string()),
             provider_name: LanguageModelProviderName::from("Fake".to_string()),
             current_completion_txs: Mutex::new(Vec::new()),
+            forbid_requests: AtomicBool::new(false),
         }
     }
 }
 
 impl FakeLanguageModel {
+    pub fn allow_requests(&self) {
+        self.forbid_requests.store(false, SeqCst);
+    }
+
+    pub fn forbid_requests(&self) {
+        self.forbid_requests.store(true, SeqCst);
+    }
+
     pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
         self.current_completion_txs
             .lock()
@@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        let (tx, rx) = mpsc::unbounded();
-        self.current_completion_txs.lock().push((request, tx));
-        async move { Ok(rx.boxed()) }.boxed()
+        if self.forbid_requests.load(SeqCst) {
+            async move {
+                Err(LanguageModelCompletionError::Other(anyhow!(
+                    "requests are forbidden"
+                )))
+            }
+            .boxed()
+        } else {
+            let (tx, rx) = mpsc::unbounded();
+            self.current_completion_txs.lock().push((request, tx));
+            async move { Ok(rx.boxed()) }.boxed()
+        }
     }
 
     fn as_fake(&self) -> &Self {