diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index e704a4630f054ae340f5b6d9df6a118a4e6362df..0454a2772d758fe1bfd928027a5f760636088081 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1241,6 +1241,9 @@ impl ActiveThread { return; }; + // Cancel any ongoing streaming when user starts editing a previous message + self.cancel_last_completion(window, cx); + let editor = crate::message_editor::create_editor( self.workspace.clone(), self.context_store.downgrade(), @@ -3464,3 +3467,146 @@ fn open_editor_at_position( } }) } + +#[cfg(test)] +mod tests { + use assistant_tool::{ToolRegistry, ToolWorkingSet}; + use context_server::ContextServerSettings; + use editor::EditorSettings; + use fs::FakeFs; + use gpui::{TestAppContext, VisualTestContext}; + use language_model::{LanguageModel, fake_provider::FakeLanguageModel}; + use project::Project; + use prompt_store::PromptBuilder; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use crate::{ContextLoadResult, thread_store}; + + use super::*; + + #[gpui::test] + async fn test_current_completion_cancelled_when_message_edited(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (cx, active_thread, thread, model) = setup_test_environment(cx, project.clone()).await; + + // Insert user message without any context (empty context vector) + let message = thread.update(cx, |thread, cx| { + let message_id = thread.insert_user_message( + "What is the best way to learn Rust?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + thread + .message(message_id) + .expect("message should exist") + .clone() + }); + + // Stream response to user message + thread.update(cx, |thread, cx| { + let request = thread.to_completion_request(model.clone(), cx); + thread.stream_completion(request, model, cx.active_window(), cx) + }); + let generating = thread.update(cx, |thread, _cx| thread.is_generating()); + assert!(generating, "There should be one pending completion"); + + // Edit the previous message + active_thread.update_in(cx, |active_thread, window, cx| { + active_thread.start_editing_message(message.id, &message.segments, window, cx); + }); + + // Check that the stream was cancelled + let generating = thread.update(cx, |thread, _cx| thread.is_generating()); + assert!(!generating, "The completion should have been cancelled"); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AssistantSettings::register(cx); + prompt_store::init(cx); + thread_store::init(cx); + workspace::init_settings(cx); + language_model::init_settings(cx); + ThemeSettings::register(cx); + ContextServerSettings::register(cx); + EditorSettings::register(cx); + ToolRegistry::default_global(cx); + }); + } + + // Helper to create a test project with test files + async fn create_test_project( + cx: &mut TestAppContext, + files: serde_json::Value, + ) -> Entity { + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), files).await; + Project::test(fs, [path!("/test").as_ref()], cx).await + } + + async fn setup_test_environment( + cx: &mut TestAppContext, + project: Entity, + ) -> ( + &mut VisualTestContext, + Entity, + Entity, + Arc, + ) { + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = cx + .update(|_, cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .await + .unwrap(); + + let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); + + let model = FakeLanguageModel::default(); + let model: Arc = Arc::new(model); + + let language_registry = LanguageRegistry::new(cx.executor()); + let language_registry = Arc::new(language_registry); + + let active_thread = cx.update(|window, cx| { + cx.new(|cx| { + ActiveThread::new( + thread.clone(), + thread_store.clone(), + context_store.clone(), + language_registry.clone(), + workspace.downgrade(), + window, + cx, + ) + }) + }); + + (cx, active_thread, thread, model) + } +}