@@ -1605,6 +1605,7 @@ impl ActiveThread {
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
+ thread.cancel_last_completion(Some(window.window_handle()), cx);
thread.send_to_model(
model.model,
CompletionIntent::UserPrompt,
@@ -3706,7 +3707,7 @@ mod tests {
use util::path;
use workspace::CollaboratorId;
- use crate::{ContextLoadResult, thread_store};
+ use crate::{ContextLoadResult, thread::MessageSegment, thread_store};
use super::*;
@@ -3840,6 +3841,114 @@ mod tests {
});
}
+ #[gpui::test]
+ async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(cx, json!({})).await;
+
+ let (cx, active_thread, _, thread, model) =
+ setup_test_environment(cx, project.clone()).await;
+
+ cx.update(|_, cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.set_default_model(
+ Some(ConfiguredModel {
+ provider: Arc::new(FakeLanguageModelProvider),
+ model: model.clone(),
+ }),
+ cx,
+ );
+ });
+ });
+
+ // Track thread events to verify cancellation
+ let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+ let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+
+ let _subscription = cx.update(|_, cx| {
+ let cancellation_events = cancellation_events.clone();
+ let new_request_events = new_request_events.clone();
+ cx.subscribe(
+ &thread,
+ move |_thread, event: &ThreadEvent, _cx| match event {
+ ThreadEvent::CompletionCanceled => {
+ cancellation_events.lock().unwrap().push(());
+ }
+ ThreadEvent::NewRequest => {
+ new_request_events.lock().unwrap().push(());
+ }
+ _ => {}
+ },
+ )
+ });
+
+ // Insert a user message and start streaming a response
+ let message = thread.update(cx, |thread, cx| {
+ let message_id = thread.insert_user_message(
+ "Hello, how are you?",
+ ContextLoadResult::default(),
+ None,
+ vec![],
+ cx,
+ );
+ thread.advance_prompt_id();
+ thread.send_to_model(
+ model.clone(),
+ CompletionIntent::UserPrompt,
+ cx.active_window(),
+ cx,
+ );
+ thread.message(message_id).cloned().unwrap()
+ });
+
+ cx.run_until_parked();
+
+ // Verify that a completion is in progress
+ assert!(cx.read(|cx| thread.read(cx).is_generating()));
+ assert_eq!(new_request_events.lock().unwrap().len(), 1);
+
+ // Edit the message while the completion is still running
+ active_thread.update_in(cx, |active_thread, window, cx| {
+ active_thread.start_editing_message(
+ message.id,
+ message.segments.as_slice(),
+ message.creases.as_slice(),
+ window,
+ cx,
+ );
+ let editor = active_thread
+ .editing_message
+ .as_ref()
+ .unwrap()
+ .1
+ .editor
+ .clone();
+ editor.update(cx, |editor, cx| {
+ editor.set_text("What is the weather like?", window, cx);
+ });
+ active_thread.confirm_editing_message(&Default::default(), window, cx);
+ });
+
+ cx.run_until_parked();
+
+ // Verify that the previous completion was cancelled
+ assert_eq!(cancellation_events.lock().unwrap().len(), 1);
+
+ // Verify that a new request was started after cancellation
+ assert_eq!(new_request_events.lock().unwrap().len(), 2);
+
+ // Verify that the edited message contains the new text
+ let edited_message =
+ thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
+ match &edited_message.segments[0] {
+ MessageSegment::Text(text) => {
+ assert_eq!(text, "What is the weather like?");
+ }
+ _ => panic!("Expected text segment"),
+ }
+ }
+
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);