From b3a8816c0e26b3d1ab63fba71c9bf832b8a71936 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 11 Jun 2025 11:36:21 +0200 Subject: [PATCH] agent: Add completion cancellation when editing messages (#32533) When editing a message, cancel any in-progress completion before starting a new request to prevent overlapping model responses. Release Notes: - agent: Fixed previous completion not cancelling when editing a previous message --- crates/agent/src/active_thread.rs | 111 +++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index eff74f17869495b31ed50d18894685951db02811..24061488274c7baba21430ea119baf42fc884751 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1605,6 +1605,7 @@ impl ActiveThread { this.thread.update(cx, |thread, cx| { thread.advance_prompt_id(); + thread.cancel_last_completion(Some(window.window_handle()), cx); thread.send_to_model( model.model, CompletionIntent::UserPrompt, @@ -3706,7 +3707,7 @@ mod tests { use util::path; use workspace::CollaboratorId; - use crate::{ContextLoadResult, thread_store}; + use crate::{ContextLoadResult, thread::MessageSegment, thread_store}; use super::*; @@ -3840,6 +3841,114 @@ mod tests { }); } + #[gpui::test] + async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (cx, active_thread, _, thread, model) = + setup_test_environment(cx, project.clone()).await; + + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: Arc::new(FakeLanguageModelProvider), + model: model.clone(), + }), + cx, + ); + }); + }); + + // Track thread events to verify cancellation + let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new())); + let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new())); + + let _subscription = cx.update(|_, cx| { + let cancellation_events = cancellation_events.clone(); + let new_request_events = new_request_events.clone(); + cx.subscribe( + &thread, + move |_thread, event: &ThreadEvent, _cx| match event { + ThreadEvent::CompletionCanceled => { + cancellation_events.lock().unwrap().push(()); + } + ThreadEvent::NewRequest => { + new_request_events.lock().unwrap().push(()); + } + _ => {} + }, + ) + }); + + // Insert a user message and start streaming a response + let message = thread.update(cx, |thread, cx| { + let message_id = thread.insert_user_message( + "Hello, how are you?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + thread.advance_prompt_id(); + thread.send_to_model( + model.clone(), + CompletionIntent::UserPrompt, + cx.active_window(), + cx, + ); + thread.message(message_id).cloned().unwrap() + }); + + cx.run_until_parked(); + + // Verify that a completion is in progress + assert!(cx.read(|cx| thread.read(cx).is_generating())); + assert_eq!(new_request_events.lock().unwrap().len(), 1); + + // Edit the message while the completion is still running + active_thread.update_in(cx, |active_thread, window, cx| { + active_thread.start_editing_message( + message.id, + message.segments.as_slice(), + message.creases.as_slice(), + window, + cx, + ); + let editor = active_thread + .editing_message + .as_ref() + .unwrap() + .1 + .editor + .clone(); + editor.update(cx, |editor, cx| { + editor.set_text("What is the weather like?", window, cx); + }); + active_thread.confirm_editing_message(&Default::default(), window, cx); + }); + + cx.run_until_parked(); + + // Verify that the previous completion was cancelled + assert_eq!(cancellation_events.lock().unwrap().len(), 1); + + // Verify that a new request was started after cancellation + assert_eq!(new_request_events.lock().unwrap().len(), 2); + + // Verify that the edited message contains the new text + let edited_message = + thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap()); + match &edited_message.segments[0] { + MessageSegment::Text(text) => { + assert_eq!(text, "What is the weather like?"); + } + _ => panic!("Expected text segment"), + } + } + fn init_test_settings(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx);