agent: Unfollow agent on completion cancellation (#31258)

Ben Brandt created

Handle unfollowing agent and clearing agent location when completion
is canceled.

Release Notes:

- N/A

Change summary

crates/agent/src/active_thread.rs | 179 ++++++++++++++++++++++++++++++++
1 file changed, 177 insertions(+), 2 deletions(-)

Detailed changes

crates/agent/src/active_thread.rs 🔗

@@ -52,7 +52,7 @@ use ui::{
 };
 use util::ResultExt as _;
 use util::markdown::MarkdownCodeBlock;
-use workspace::Workspace;
+use workspace::{CollaboratorId, Workspace};
 use zed_actions::assistant::OpenRulesLibrary;
 
 pub struct ActiveThread {
@@ -971,7 +971,22 @@ impl ActiveThread {
             ThreadEvent::ShowError(error) => {
                 self.last_error = Some(error.clone());
             }
-            ThreadEvent::NewRequest | ThreadEvent::CompletionCanceled => {
+            ThreadEvent::NewRequest => {
+                cx.notify();
+            }
+            ThreadEvent::CompletionCanceled => {
+                self.thread.update(cx, |thread, cx| {
+                    thread.project().update(cx, |project, cx| {
+                        project.set_agent_location(None, cx);
+                    })
+                });
+                self.workspace
+                    .update(cx, |workspace, cx| {
+                        if workspace.is_being_followed(CollaboratorId::Agent) {
+                            workspace.unfollow(CollaboratorId::Agent, window, cx);
+                        }
+                    })
+                    .ok();
                 cx.notify();
             }
             ThreadEvent::StreamedCompletion
@@ -3593,3 +3608,163 @@ fn open_editor_at_position(
         }
     })
 }
+
+#[cfg(test)]
+mod tests {
+    use assistant_tool::{ToolRegistry, ToolWorkingSet};
+    use editor::EditorSettings;
+    use fs::FakeFs;
+    use gpui::{AppContext, TestAppContext, VisualTestContext};
+    use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
+    use project::Project;
+    use prompt_store::PromptBuilder;
+    use serde_json::json;
+    use settings::SettingsStore;
+    use util::path;
+    use workspace::CollaboratorId;
+
+    use crate::{ContextLoadResult, thread_store};
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (cx, _active_thread, workspace, thread, model) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Insert user message without any context (empty context vector)
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message(
+                "What is the best way to learn Rust?",
+                ContextLoadResult::default(),
+                None,
+                vec![],
+                cx,
+            );
+        });
+
+        // Stream response to user message
+        thread.update(cx, |thread, cx| {
+            let request = thread.to_completion_request(model.clone(), cx);
+            thread.stream_completion(request, model, cx.active_window(), cx)
+        });
+        // Follow the agent
+        cx.update(|window, cx| {
+            workspace.update(cx, |workspace, cx| {
+                workspace.follow(CollaboratorId::Agent, window, cx);
+            })
+        });
+        assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+
+        // Cancel the current completion
+        thread.update(cx, |thread, cx| {
+            thread.cancel_last_completion(cx.active_window(), cx)
+        });
+
+        cx.executor().run_until_parked();
+
+        // No longer following the agent
+        assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+    }
+
+    fn init_test_settings(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+            AssistantSettings::register(cx);
+            prompt_store::init(cx);
+            thread_store::init(cx);
+            workspace::init_settings(cx);
+            language_model::init_settings(cx);
+            ThemeSettings::register(cx);
+            EditorSettings::register(cx);
+            ToolRegistry::default_global(cx);
+        });
+    }
+
+    // Helper to create a test project with test files
+    async fn create_test_project(
+        cx: &mut TestAppContext,
+        files: serde_json::Value,
+    ) -> Entity<Project> {
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/test"), files).await;
+        Project::test(fs, [path!("/test").as_ref()], cx).await
+    }
+
+    async fn setup_test_environment(
+        cx: &mut TestAppContext,
+        project: Entity<Project>,
+    ) -> (
+        &mut VisualTestContext,
+        Entity<ActiveThread>,
+        Entity<Workspace>,
+        Entity<Thread>,
+        Arc<dyn LanguageModel>,
+    ) {
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let thread_store = cx
+            .update(|_, cx| {
+                ThreadStore::load(
+                    project.clone(),
+                    cx.new(|_| ToolWorkingSet::default()),
+                    None,
+                    Arc::new(PromptBuilder::new(None).unwrap()),
+                    cx,
+                )
+            })
+            .await
+            .unwrap();
+
+        let text_thread_store = cx
+            .update(|_, cx| {
+                TextThreadStore::new(
+                    project.clone(),
+                    Arc::new(PromptBuilder::new(None).unwrap()),
+                    Default::default(),
+                    cx,
+                )
+            })
+            .await
+            .unwrap();
+
+        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+        let context_store =
+            cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
+
+        let model = FakeLanguageModel::default();
+        let model: Arc<dyn LanguageModel> = Arc::new(model);
+
+        let language_registry = LanguageRegistry::new(cx.executor());
+        let language_registry = Arc::new(language_registry);
+
+        let active_thread = cx.update(|window, cx| {
+            cx.new(|cx| {
+                ActiveThread::new(
+                    thread.clone(),
+                    thread_store.clone(),
+                    text_thread_store,
+                    context_store.clone(),
+                    language_registry.clone(),
+                    workspace.downgrade(),
+                    window,
+                    cx,
+                )
+            })
+        });
+
+        (cx, active_thread, workspace, thread, model)
+    }
+}