@@ -52,7 +52,7 @@ use ui::{
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
-use workspace::Workspace;
+use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread {
@@ -971,7 +971,22 @@ impl ActiveThread {
ThreadEvent::ShowError(error) => {
self.last_error = Some(error.clone());
}
- ThreadEvent::NewRequest | ThreadEvent::CompletionCanceled => {
+ ThreadEvent::NewRequest => {
+ cx.notify();
+ }
+ ThreadEvent::CompletionCanceled => {
+ self.thread.update(cx, |thread, cx| {
+ thread.project().update(cx, |project, cx| {
+ project.set_agent_location(None, cx);
+ })
+ });
+ self.workspace
+ .update(cx, |workspace, cx| {
+ if workspace.is_being_followed(CollaboratorId::Agent) {
+ workspace.unfollow(CollaboratorId::Agent, window, cx);
+ }
+ })
+ .ok();
cx.notify();
}
ThreadEvent::StreamedCompletion
@@ -3593,3 +3608,163 @@ fn open_editor_at_position(
}
})
}
+
+#[cfg(test)]
+mod tests {
+ use assistant_tool::{ToolRegistry, ToolWorkingSet};
+ use editor::EditorSettings;
+ use fs::FakeFs;
+ use gpui::{AppContext, TestAppContext, VisualTestContext};
+ use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
+ use project::Project;
+ use prompt_store::PromptBuilder;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+ use workspace::CollaboratorId;
+
+ use crate::{ContextLoadResult, thread_store};
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(
+ cx,
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let (cx, _active_thread, workspace, thread, model) =
+ setup_test_environment(cx, project.clone()).await;
+
+ // Insert user message without any context (empty context vector)
+ thread.update(cx, |thread, cx| {
+ thread.insert_user_message(
+ "What is the best way to learn Rust?",
+ ContextLoadResult::default(),
+ None,
+ vec![],
+ cx,
+ );
+ });
+
+ // Stream response to user message
+ thread.update(cx, |thread, cx| {
+ let request = thread.to_completion_request(model.clone(), cx);
+ thread.stream_completion(request, model, cx.active_window(), cx)
+ });
+ // Follow the agent
+ cx.update(|window, cx| {
+ workspace.update(cx, |workspace, cx| {
+ workspace.follow(CollaboratorId::Agent, window, cx);
+ })
+ });
+ assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+
+ // Cancel the current completion
+ thread.update(cx, |thread, cx| {
+ thread.cancel_last_completion(cx.active_window(), cx)
+ });
+
+ cx.executor().run_until_parked();
+
+ // No longer following the agent
+ assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+ }
+
+ fn init_test_settings(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+ AssistantSettings::register(cx);
+ prompt_store::init(cx);
+ thread_store::init(cx);
+ workspace::init_settings(cx);
+ language_model::init_settings(cx);
+ ThemeSettings::register(cx);
+ EditorSettings::register(cx);
+ ToolRegistry::default_global(cx);
+ });
+ }
+
+ // Helper to create a test project with test files
+ async fn create_test_project(
+ cx: &mut TestAppContext,
+ files: serde_json::Value,
+ ) -> Entity<Project> {
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), files).await;
+ Project::test(fs, [path!("/test").as_ref()], cx).await
+ }
+
+ async fn setup_test_environment(
+ cx: &mut TestAppContext,
+ project: Entity<Project>,
+ ) -> (
+ &mut VisualTestContext,
+ Entity<ActiveThread>,
+ Entity<Workspace>,
+ Entity<Thread>,
+ Arc<dyn LanguageModel>,
+ ) {
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let thread_store = cx
+ .update(|_, cx| {
+ ThreadStore::load(
+ project.clone(),
+ cx.new(|_| ToolWorkingSet::default()),
+ None,
+ Arc::new(PromptBuilder::new(None).unwrap()),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ let text_thread_store = cx
+ .update(|_, cx| {
+ TextThreadStore::new(
+ project.clone(),
+ Arc::new(PromptBuilder::new(None).unwrap()),
+ Default::default(),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+ let context_store =
+ cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
+
+ let model = FakeLanguageModel::default();
+ let model: Arc<dyn LanguageModel> = Arc::new(model);
+
+ let language_registry = LanguageRegistry::new(cx.executor());
+ let language_registry = Arc::new(language_registry);
+
+ let active_thread = cx.update(|window, cx| {
+ cx.new(|cx| {
+ ActiveThread::new(
+ thread.clone(),
+ thread_store.clone(),
+ text_thread_store,
+ context_store.clone(),
+ language_registry.clone(),
+ workspace.downgrade(),
+ window,
+ cx,
+ )
+ })
+ });
+
+ (cx, active_thread, workspace, thread, model)
+ }
+}