@@ -1241,6 +1241,9 @@ impl ActiveThread {
return;
};
+ // Cancel any ongoing streaming when user starts editing a previous message
+ self.cancel_last_completion(window, cx);
+
let editor = crate::message_editor::create_editor(
self.workspace.clone(),
self.context_store.downgrade(),
@@ -3464,3 +3467,146 @@ fn open_editor_at_position(
}
})
}
+
+#[cfg(test)]
+mod tests {
+ use assistant_tool::{ToolRegistry, ToolWorkingSet};
+ use context_server::ContextServerSettings;
+ use editor::EditorSettings;
+ use fs::FakeFs;
+ use gpui::{TestAppContext, VisualTestContext};
+ use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
+ use project::Project;
+ use prompt_store::PromptBuilder;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ use crate::{ContextLoadResult, thread_store};
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_current_completion_cancelled_when_message_edited(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(
+ cx,
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let (cx, active_thread, thread, model) = setup_test_environment(cx, project.clone()).await;
+
+ // Insert user message without any context (empty context vector)
+ let message = thread.update(cx, |thread, cx| {
+ let message_id = thread.insert_user_message(
+ "What is the best way to learn Rust?",
+ ContextLoadResult::default(),
+ None,
+ vec![],
+ cx,
+ );
+ thread
+ .message(message_id)
+ .expect("message should exist")
+ .clone()
+ });
+
+ // Stream response to user message
+ thread.update(cx, |thread, cx| {
+ let request = thread.to_completion_request(model.clone(), cx);
+ thread.stream_completion(request, model, cx.active_window(), cx)
+ });
+ let generating = thread.update(cx, |thread, _cx| thread.is_generating());
+ assert!(generating, "There should be one pending completion");
+
+ // Edit the previous message
+ active_thread.update_in(cx, |active_thread, window, cx| {
+ active_thread.start_editing_message(message.id, &message.segments, window, cx);
+ });
+
+ // Check that the stream was cancelled
+ let generating = thread.update(cx, |thread, _cx| thread.is_generating());
+ assert!(!generating, "The completion should have been cancelled");
+ }
+
+ fn init_test_settings(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+ AssistantSettings::register(cx);
+ prompt_store::init(cx);
+ thread_store::init(cx);
+ workspace::init_settings(cx);
+ language_model::init_settings(cx);
+ ThemeSettings::register(cx);
+ ContextServerSettings::register(cx);
+ EditorSettings::register(cx);
+ ToolRegistry::default_global(cx);
+ });
+ }
+
+ // Helper to create a test project with test files
+ async fn create_test_project(
+ cx: &mut TestAppContext,
+ files: serde_json::Value,
+ ) -> Entity<Project> {
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), files).await;
+ Project::test(fs, [path!("/test").as_ref()], cx).await
+ }
+
+ async fn setup_test_environment(
+ cx: &mut TestAppContext,
+ project: Entity<Project>,
+ ) -> (
+ &mut VisualTestContext,
+ Entity<ActiveThread>,
+ Entity<Thread>,
+ Arc<dyn LanguageModel>,
+ ) {
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let thread_store = cx
+ .update(|_, cx| {
+ ThreadStore::load(
+ project.clone(),
+ cx.new(|_| ToolWorkingSet::default()),
+ None,
+ Arc::new(PromptBuilder::new(None).unwrap()),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+ let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
+
+ let model = FakeLanguageModel::default();
+ let model: Arc<dyn LanguageModel> = Arc::new(model);
+
+ let language_registry = LanguageRegistry::new(cx.executor());
+ let language_registry = Arc::new(language_registry);
+
+ let active_thread = cx.update(|window, cx| {
+ cx.new(|cx| {
+ ActiveThread::new(
+ thread.clone(),
+ thread_store.clone(),
+ context_store.clone(),
+ language_registry.clone(),
+ workspace.downgrade(),
+ window,
+ cx,
+ )
+ })
+ });
+
+ (cx, active_thread, thread, model)
+ }
+}