agent_ui: Use selected agent for new threads (#52888)

Ben Brandt , Bennet Bo Fenner , and MrSubidubi created

Persist the last used agent globally as a fallback for new
workspaces, keep per-workspace selections independent. This should mean
"new thread" should grab whatever agent you are currently looking at,
and won't leak across projects.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- agent: Prefer the currently used agent per-project when creating a new
thread.

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: MrSubidubi <dev@bahn.sh>

Change summary

crates/agent_ui/src/agent_panel.rs | 386 ++++++++++++++++++++++++-------
1 file changed, 295 insertions(+), 91 deletions(-)

Detailed changes

crates/agent_ui/src/agent_panel.rs 🔗

@@ -86,6 +86,30 @@ use zed_actions::{
 
 const AGENT_PANEL_KEY: &str = "agent_panel";
 const RECENTLY_UPDATED_MENU_LIMIT: usize = 6;
+const LAST_USED_AGENT_KEY: &str = "agent_panel__last_used_external_agent";
+
+#[derive(Serialize, Deserialize)]
+struct LastUsedAgent {
+    agent: Agent,
+}
+
+/// Reads the most recently used agent across all workspaces. Used as a fallback
+/// when opening a workspace that has no per-workspace agent preference yet.
+fn read_global_last_used_agent(kvp: &KeyValueStore) -> Option<Agent> {
+    kvp.read_kvp(LAST_USED_AGENT_KEY)
+        .log_err()
+        .flatten()
+        .and_then(|json| serde_json::from_str::<LastUsedAgent>(&json).log_err())
+        .map(|entry| entry.agent)
+}
+
+async fn write_global_last_used_agent(kvp: KeyValueStore, agent: Agent) {
+    if let Some(json) = serde_json::to_string(&LastUsedAgent { agent }).log_err() {
+        kvp.write_kvp(LAST_USED_AGENT_KEY.to_string(), json)
+            .await
+            .log_err();
+    }
+}
 
 fn read_serialized_panel(
     workspace_id: workspace::WorkspaceId,
@@ -665,13 +689,18 @@ impl AgentPanel {
                 .ok()
                 .flatten();
 
-            let serialized_panel = cx
+            let (serialized_panel, global_last_used_agent) = cx
                 .background_spawn(async move {
-                    kvp.and_then(|kvp| {
-                        workspace_id
-                            .and_then(|id| read_serialized_panel(id, &kvp))
-                            .or_else(|| read_legacy_serialized_panel(&kvp))
-                    })
+                    match kvp {
+                        Some(kvp) => {
+                            let panel = workspace_id
+                                .and_then(|id| read_serialized_panel(id, &kvp))
+                                .or_else(|| read_legacy_serialized_panel(&kvp));
+                            let global_agent = read_global_last_used_agent(&kvp);
+                            (panel, global_agent)
+                        }
+                        None => (None, None),
+                    }
                 })
                 .await;
 
@@ -710,10 +739,21 @@ impl AgentPanel {
                 let panel =
                     cx.new(|cx| Self::new(workspace, prompt_store, window, cx));
 
-                if let Some(serialized_panel) = &serialized_panel {
-                    panel.update(cx, |panel, cx| {
+                panel.update(cx, |panel, cx| {
+                    let is_via_collab = panel.project.read(cx).is_via_collab();
+
+                    // Only apply a non-native global fallback to local projects.
+                    // Collab workspaces only support NativeAgent, so inheriting a
+                    // custom agent would cause set_active → new_agent_thread_inner
+                    // to bypass the collab guard in external_thread.
+                    let global_fallback = global_last_used_agent
+                        .filter(|agent| !is_via_collab || agent.is_native());
+
+                    if let Some(serialized_panel) = &serialized_panel {
                         if let Some(selected_agent) = serialized_panel.selected_agent.clone() {
                             panel.selected_agent = selected_agent;
+                        } else if let Some(agent) = global_fallback {
+                            panel.selected_agent = agent;
                         }
                         if let Some(start_thread_in) = serialized_panel.start_thread_in {
                             let is_worktree_flag_enabled =
@@ -734,9 +774,11 @@ impl AgentPanel {
                                 );
                             }
                         }
-                        cx.notify();
-                    });
-                }
+                    } else if let Some(agent) = global_fallback {
+                        panel.selected_agent = agent;
+                    }
+                    cx.notify();
+                });
 
                 if let Some(thread_info) = last_active_thread {
                     let agent = thread_info.agent_type.clone();
@@ -1069,85 +1111,30 @@ impl AgentPanel {
         let workspace = self.workspace.clone();
         let project = self.project.clone();
         let fs = self.fs.clone();
-        let is_via_collab = self.project.read(cx).is_via_collab();
-
-        const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent";
-
-        #[derive(Serialize, Deserialize)]
-        struct LastUsedExternalAgent {
-            agent: crate::Agent,
-        }
-
         let thread_store = self.thread_store.clone();
-        let kvp = KeyValueStore::global(cx);
-
-        if let Some(agent) = agent_choice {
-            cx.background_spawn({
-                let agent = agent.clone();
-                let kvp = kvp;
-                async move {
-                    if let Some(serialized) =
-                        serde_json::to_string(&LastUsedExternalAgent { agent }).log_err()
-                    {
-                        kvp.write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized)
-                            .await
-                            .log_err();
-                    }
-                }
-            })
-            .detach();
-
-            let server = agent.server(fs, thread_store);
-            self.create_agent_thread(
-                server,
-                resume_session_id,
-                work_dirs,
-                title,
-                initial_content,
-                workspace,
-                project,
-                agent,
-                focus,
-                window,
-                cx,
-            );
-        } else {
-            cx.spawn_in(window, async move |this, cx| {
-                let ext_agent = if is_via_collab {
-                    Agent::NativeAgent
-                } else {
-                    cx.background_spawn(async move { kvp.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) })
-                        .await
-                        .log_err()
-                        .flatten()
-                        .and_then(|value| {
-                            serde_json::from_str::<LastUsedExternalAgent>(&value).log_err()
-                        })
-                        .map(|agent| agent.agent)
-                        .unwrap_or(Agent::NativeAgent)
-                };
 
-                let server = ext_agent.server(fs, thread_store);
-                this.update_in(cx, |agent_panel, window, cx| {
-                    agent_panel.create_agent_thread(
-                        server,
-                        resume_session_id,
-                        work_dirs,
-                        title,
-                        initial_content,
-                        workspace,
-                        project,
-                        ext_agent,
-                        focus,
-                        window,
-                        cx,
-                    );
-                })?;
+        let agent = agent_choice.unwrap_or_else(|| {
+            if self.project.read(cx).is_via_collab() {
+                Agent::NativeAgent
+            } else {
+                self.selected_agent.clone()
+            }
+        });
 
-                anyhow::Ok(())
-            })
-            .detach_and_log_err(cx);
-        }
+        let server = agent.server(fs, thread_store);
+        self.create_agent_thread(
+            server,
+            resume_session_id,
+            work_dirs,
+            title,
+            initial_content,
+            workspace,
+            project,
+            agent,
+            focus,
+            window,
+            cx,
+        );
     }
 
     fn deploy_rules_library(
@@ -2102,15 +2089,25 @@ impl AgentPanel {
         initial_content: Option<AgentInitialContent>,
         workspace: WeakEntity<Workspace>,
         project: Entity<Project>,
-        ext_agent: Agent,
+        agent: Agent,
         focus: bool,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        if self.selected_agent != ext_agent {
-            self.selected_agent = ext_agent.clone();
+        if self.selected_agent != agent {
+            self.selected_agent = agent.clone();
             self.serialize(cx);
         }
+
+        cx.background_spawn({
+            let kvp = KeyValueStore::global(cx);
+            let agent = agent.clone();
+            async move {
+                write_global_last_used_agent(kvp, agent).await;
+            }
+        })
+        .detach();
+
         let thread_store = server
             .clone()
             .downcast::<agent::NativeAgentServer>()
@@ -2123,7 +2120,7 @@ impl AgentPanel {
             crate::ConversationView::new(
                 server,
                 connection_store,
-                ext_agent,
+                agent,
                 resume_session_id,
                 work_dirs,
                 title,
@@ -5611,4 +5608,211 @@ mod tests {
             "Thread A work_dirs should revert to only /project_a after removing /project_b"
         );
     }
+
+    #[gpui::test]
+    async fn test_new_workspace_inherits_global_last_used_agent(cx: &mut TestAppContext) {
+        init_test(cx);
+        cx.update(|cx| {
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            agent::ThreadStore::init_global(cx);
+            language_model::LanguageModelRegistry::test(cx);
+            // Use an isolated DB so parallel tests can't overwrite our global key.
+            cx.set_global(db::AppDatabase::test_new());
+        });
+
+        let custom_agent = Agent::Custom {
+            id: "my-preferred-agent".into(),
+        };
+
+        // Write a known agent to the global KVP to simulate a user who has
+        // previously used this agent in another workspace.
+        let kvp = cx.update(|cx| KeyValueStore::global(cx));
+        write_global_last_used_agent(kvp, custom_agent.clone()).await;
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [], cx).await;
+
+        let multi_workspace =
+            cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+
+        let workspace = multi_workspace
+            .read_with(cx, |multi_workspace, _cx| {
+                multi_workspace.workspace().clone()
+            })
+            .unwrap();
+
+        workspace.update(cx, |workspace, _cx| {
+            workspace.set_random_database_id();
+        });
+
+        let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx);
+
+        // Load the panel via `load()`, which reads the global fallback
+        // asynchronously when no per-workspace state exists.
+        let async_cx = cx.update(|window, cx| window.to_async(cx));
+        let panel = AgentPanel::load(workspace.downgrade(), async_cx)
+            .await
+            .expect("panel load should succeed");
+        cx.run_until_parked();
+
+        panel.read_with(cx, |panel, _cx| {
+            assert_eq!(
+                panel.selected_agent, custom_agent,
+                "new workspace should inherit the global last-used agent"
+            );
+        });
+    }
+
+    #[gpui::test]
+    async fn test_workspaces_maintain_independent_agent_selection(cx: &mut TestAppContext) {
+        init_test(cx);
+        cx.update(|cx| {
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            agent::ThreadStore::init_global(cx);
+            language_model::LanguageModelRegistry::test(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        let project_a = Project::test(fs.clone(), [], cx).await;
+        let project_b = Project::test(fs, [], cx).await;
+
+        let multi_workspace =
+            cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
+
+        let workspace_a = multi_workspace
+            .read_with(cx, |multi_workspace, _cx| {
+                multi_workspace.workspace().clone()
+            })
+            .unwrap();
+
+        let workspace_b = multi_workspace
+            .update(cx, |multi_workspace, window, cx| {
+                multi_workspace.test_add_workspace(project_b.clone(), window, cx)
+            })
+            .unwrap();
+
+        workspace_a.update(cx, |workspace, _cx| {
+            workspace.set_random_database_id();
+        });
+        workspace_b.update(cx, |workspace, _cx| {
+            workspace.set_random_database_id();
+        });
+
+        let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx);
+
+        let agent_a = Agent::Custom {
+            id: "agent-alpha".into(),
+        };
+        let agent_b = Agent::Custom {
+            id: "agent-beta".into(),
+        };
+
+        // Set up workspace A with agent_a
+        let panel_a = workspace_a.update_in(cx, |workspace, window, cx| {
+            cx.new(|cx| AgentPanel::new(workspace, None, window, cx))
+        });
+        panel_a.update(cx, |panel, _cx| {
+            panel.selected_agent = agent_a.clone();
+        });
+
+        // Set up workspace B with agent_b
+        let panel_b = workspace_b.update_in(cx, |workspace, window, cx| {
+            cx.new(|cx| AgentPanel::new(workspace, None, window, cx))
+        });
+        panel_b.update(cx, |panel, _cx| {
+            panel.selected_agent = agent_b.clone();
+        });
+
+        // Serialize both panels
+        panel_a.update(cx, |panel, cx| panel.serialize(cx));
+        panel_b.update(cx, |panel, cx| panel.serialize(cx));
+        cx.run_until_parked();
+
+        // Load fresh panels from serialized state and verify independence
+        let async_cx = cx.update(|window, cx| window.to_async(cx));
+        let loaded_a = AgentPanel::load(workspace_a.downgrade(), async_cx)
+            .await
+            .expect("panel A load should succeed");
+        cx.run_until_parked();
+
+        let async_cx = cx.update(|window, cx| window.to_async(cx));
+        let loaded_b = AgentPanel::load(workspace_b.downgrade(), async_cx)
+            .await
+            .expect("panel B load should succeed");
+        cx.run_until_parked();
+
+        loaded_a.read_with(cx, |panel, _cx| {
+            assert_eq!(
+                panel.selected_agent, agent_a,
+                "workspace A should restore agent-alpha, not agent-beta"
+            );
+        });
+
+        loaded_b.read_with(cx, |panel, _cx| {
+            assert_eq!(
+                panel.selected_agent, agent_b,
+                "workspace B should restore agent-beta, not agent-alpha"
+            );
+        });
+    }
+
+    #[gpui::test]
+    async fn test_new_thread_uses_workspace_selected_agent(cx: &mut TestAppContext) {
+        init_test(cx);
+        cx.update(|cx| {
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            agent::ThreadStore::init_global(cx);
+            language_model::LanguageModelRegistry::test(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [], cx).await;
+
+        let multi_workspace =
+            cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+
+        let workspace = multi_workspace
+            .read_with(cx, |multi_workspace, _cx| {
+                multi_workspace.workspace().clone()
+            })
+            .unwrap();
+
+        workspace.update(cx, |workspace, _cx| {
+            workspace.set_random_database_id();
+        });
+
+        let cx = &mut VisualTestContext::from_window(multi_workspace.into(), cx);
+
+        let custom_agent = Agent::Custom {
+            id: "my-custom-agent".into(),
+        };
+
+        let panel = workspace.update_in(cx, |workspace, window, cx| {
+            let panel = cx.new(|cx| AgentPanel::new(workspace, None, window, cx));
+            workspace.add_panel(panel.clone(), window, cx);
+            panel
+        });
+
+        // Set selected_agent to a custom agent
+        panel.update(cx, |panel, _cx| {
+            panel.selected_agent = custom_agent.clone();
+        });
+
+        // Call new_thread, which internally calls external_thread(None, ...)
+        // This resolves the agent from self.selected_agent
+        panel.update_in(cx, |panel, window, cx| {
+            panel.new_thread(&NewThread, window, cx);
+        });
+
+        panel.read_with(cx, |panel, _cx| {
+            assert_eq!(
+                panel.selected_agent, custom_agent,
+                "selected_agent should remain the custom agent after new_thread"
+            );
+            assert!(
+                panel.active_conversation_view().is_some(),
+                "a thread should have been created"
+            );
+        });
+    }
 }