agent: Add completion cancellation when editing messages (#32533)

Ben Brandt created

When editing a message, cancel any in-progress completion before
starting a new request to prevent overlapping model responses.

Release Notes:

- agent: Fixed previous completion not cancelling when editing a
previous message

Change summary

crates/agent/src/active_thread.rs | 111 ++++++++++++++++++++++++++++++++
1 file changed, 110 insertions(+), 1 deletion(-)

Detailed changes

crates/agent/src/active_thread.rs 🔗

@@ -1605,6 +1605,7 @@ impl ActiveThread {
 
                         this.thread.update(cx, |thread, cx| {
                             thread.advance_prompt_id();
+                            thread.cancel_last_completion(Some(window.window_handle()), cx);
                             thread.send_to_model(
                                 model.model,
                                 CompletionIntent::UserPrompt,
@@ -3706,7 +3707,7 @@ mod tests {
     use util::path;
     use workspace::CollaboratorId;
 
-    use crate::{ContextLoadResult, thread_store};
+    use crate::{ContextLoadResult, thread::MessageSegment, thread_store};
 
     use super::*;
 
@@ -3840,6 +3841,114 @@ mod tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(cx, json!({})).await;
+
+        let (cx, active_thread, _, thread, model) =
+            setup_test_environment(cx, project.clone()).await;
+
+        cx.update(|_, cx| {
+            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+                registry.set_default_model(
+                    Some(ConfiguredModel {
+                        provider: Arc::new(FakeLanguageModelProvider),
+                        model: model.clone(),
+                    }),
+                    cx,
+                );
+            });
+        });
+
+        // Track thread events to verify cancellation
+        let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+        let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+
+        let _subscription = cx.update(|_, cx| {
+            let cancellation_events = cancellation_events.clone();
+            let new_request_events = new_request_events.clone();
+            cx.subscribe(
+                &thread,
+                move |_thread, event: &ThreadEvent, _cx| match event {
+                    ThreadEvent::CompletionCanceled => {
+                        cancellation_events.lock().unwrap().push(());
+                    }
+                    ThreadEvent::NewRequest => {
+                        new_request_events.lock().unwrap().push(());
+                    }
+                    _ => {}
+                },
+            )
+        });
+
+        // Insert a user message and start streaming a response
+        let message = thread.update(cx, |thread, cx| {
+            let message_id = thread.insert_user_message(
+                "Hello, how are you?",
+                ContextLoadResult::default(),
+                None,
+                vec![],
+                cx,
+            );
+            thread.advance_prompt_id();
+            thread.send_to_model(
+                model.clone(),
+                CompletionIntent::UserPrompt,
+                cx.active_window(),
+                cx,
+            );
+            thread.message(message_id).cloned().unwrap()
+        });
+
+        cx.run_until_parked();
+
+        // Verify that a completion is in progress
+        assert!(cx.read(|cx| thread.read(cx).is_generating()));
+        assert_eq!(new_request_events.lock().unwrap().len(), 1);
+
+        // Edit the message while the completion is still running
+        active_thread.update_in(cx, |active_thread, window, cx| {
+            active_thread.start_editing_message(
+                message.id,
+                message.segments.as_slice(),
+                message.creases.as_slice(),
+                window,
+                cx,
+            );
+            let editor = active_thread
+                .editing_message
+                .as_ref()
+                .unwrap()
+                .1
+                .editor
+                .clone();
+            editor.update(cx, |editor, cx| {
+                editor.set_text("What is the weather like?", window, cx);
+            });
+            active_thread.confirm_editing_message(&Default::default(), window, cx);
+        });
+
+        cx.run_until_parked();
+
+        // Verify that the previous completion was cancelled
+        assert_eq!(cancellation_events.lock().unwrap().len(), 1);
+
+        // Verify that a new request was started after cancellation
+        assert_eq!(new_request_events.lock().unwrap().len(), 2);
+
+        // Verify that the edited message contains the new text
+        let edited_message =
+            thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
+        match &edited_message.segments[0] {
+            MessageSegment::Text(text) => {
+                assert_eq!(text, "What is the weather like?");
+            }
+            _ => panic!("Expected text segment"),
+        }
+    }
+
     fn init_test_settings(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);