From 7a6674d5dca0215dbf46a91a208fb9327efd0025 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Tue, 10 Feb 2026 12:26:01 +0100 Subject: [PATCH] agent: Move subagent spawning to `ThreadEnvironment` (#48381) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TODO - [x] Cancellation - [x] Show subagent card as soon as tool_name == "subagent" - [x] Keybinding for closing subagent full screen view - [x] Only fire subagent notifications when appropriate - [x] Fix tests Release Notes: - N/A --------- Co-authored-by: Cameron Co-authored-by: Tom Houlé Co-authored-by: cameron Co-authored-by: Danilo Leal Co-authored-by: Ben Brandt --- assets/keymaps/default-linux.json | 6 + assets/keymaps/default-macos.json | 6 + assets/keymaps/default-windows.json | 6 + crates/acp_thread/src/acp_thread.rs | 134 +- crates/acp_thread/src/connection.rs | 15 +- crates/agent/src/agent.rs | 344 +++- crates/agent/src/db.rs | 109 +- .../agent/src/tests/edit_file_thread_test.rs | 447 +----- crates/agent/src/tests/mod.rs | 1405 ++++++----------- crates/agent/src/thread.rs | 387 +++-- crates/agent/src/thread_store.rs | 8 +- crates/agent/src/tools/edit_file_tool.rs | 16 - crates/agent/src/tools/read_file_tool.rs | 15 - crates/agent/src/tools/subagent_tool.rs | 590 ++----- crates/agent_servers/src/acp.rs | 5 +- crates/agent_servers/src/e2e_tests.rs | 2 +- crates/agent_ui/src/acp/entry_view_state.rs | 5 +- crates/agent_ui/src/acp/thread_view.rs | 794 ++++++---- .../src/acp/thread_view/active_thread.rs | 497 +++--- crates/agent_ui/src/agent_diff.rs | 5 +- crates/agent_ui/src/agent_panel.rs | 10 +- crates/agent_ui_v2/src/agent_thread_pane.rs | 2 +- crates/eval/src/example.rs | 3 + crates/eval/src/instance.rs | 14 +- crates/zed/src/visual_test_runner.rs | 354 +---- 25 files changed, 2176 insertions(+), 3003 deletions(-) diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 43c58411a6c4b4140a59c55a24d37716f0ab1ad3..49eeebc91407727a60f9926f70a172caa5decce7 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -300,6 +300,12 @@ "ctrl-enter": "menu::Confirm", }, }, + { + "context": "AcpThread", + "bindings": { + "ctrl--": "pane::GoBack", + }, + }, { "context": "AcpThread > Editor", "use_key_equivalents": true, diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 813f7442b6cb3ce93a130be01e0043c3ca025d9a..751f5e6547a4c1268ac9cd1849eef5e8704b5542 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -346,6 +346,12 @@ "cmd-enter": "menu::Confirm", }, }, + { + "context": "AcpThread", + "bindings": { + "ctrl--": "pane::GoBack", + }, + }, { "context": "AcpThread > Editor", "use_key_equivalents": true, diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index ae00aff39ef3529fa906f18d3a9a28e8fa6b688c..b8b9262bcdad58e54899ae0faae1b837479e1c3a 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -302,6 +302,12 @@ "ctrl-enter": "menu::Confirm", }, }, + { + "context": "AcpThread", + "bindings": { + "ctrl--": "pane::GoBack", + }, + }, { "context": "AcpThread > Editor", "use_key_equivalents": true, diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a5503e8b6b81d7e532c0ac076426e5498720f20c..aad35f6999a7ec024c9fcdff317deaf4b3ca539d 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -9,8 +9,8 @@ use agent_settings::AgentSettings; /// This is a workaround since ACP's ToolCall doesn't have a dedicated name field. pub const TOOL_NAME_META_KEY: &str = "tool_name"; -/// The tool name for subagent spawning -pub const SUBAGENT_TOOL_NAME: &str = "subagent"; +/// Key used in ACP ToolCall meta to store the session id when a subagent is spawned. +pub const SUBAGENT_SESSION_ID_META_KEY: &str = "subagent_session_id"; /// Helper to extract tool name from ACP meta pub fn tool_name_from_meta(meta: &Option) -> Option { @@ -20,6 +20,14 @@ pub fn tool_name_from_meta(meta: &Option) -> Option { .map(|s| SharedString::from(s.to_owned())) } +/// Helper to extract subagent session id from ACP meta +pub fn subagent_session_id_from_meta(meta: &Option) -> Option { + meta.as_ref() + .and_then(|m| m.get(SUBAGENT_SESSION_ID_META_KEY)) + .and_then(|v| v.as_str()) + .map(|s| acp::SessionId::from(s.to_string())) +} + /// Helper to create meta with tool name pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta { acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())]) @@ -216,6 +224,7 @@ pub struct ToolCall { pub raw_input_markdown: Option>, pub raw_output: Option, pub tool_name: Option, + pub subagent_session_id: Option, } impl ToolCall { @@ -254,6 +263,8 @@ impl ToolCall { let tool_name = tool_name_from_meta(&tool_call.meta); + let subagent_session = subagent_session_id_from_meta(&tool_call.meta); + let result = Self { id: tool_call.tool_call_id, label: cx @@ -267,6 +278,7 @@ impl ToolCall { raw_input_markdown, raw_output: tool_call.raw_output, tool_name, + subagent_session_id: subagent_session, }; Ok(result) } @@ -274,6 +286,7 @@ impl ToolCall { fn update_fields( &mut self, fields: acp::ToolCallUpdateFields, + meta: Option, language_registry: Arc, path_style: PathStyle, terminals: &HashMap>, @@ -298,6 +311,10 @@ impl ToolCall { self.status = status.into(); } + if let Some(subagent_session_id) = subagent_session_id_from_meta(&meta) { + self.subagent_session_id = Some(subagent_session_id); + } + if let Some(title) = title { self.label.update(cx, |label, cx| { if self.kind == acp::ToolKind::Execute { @@ -366,7 +383,6 @@ impl ToolCall { ToolCallContent::Diff(diff) => Some(diff), ToolCallContent::ContentBlock(_) => None, ToolCallContent::Terminal(_) => None, - ToolCallContent::SubagentThread(_) => None, }) } @@ -375,24 +391,12 @@ impl ToolCall { ToolCallContent::Terminal(terminal) => Some(terminal), ToolCallContent::ContentBlock(_) => None, ToolCallContent::Diff(_) => None, - ToolCallContent::SubagentThread(_) => None, - }) - } - - pub fn subagent_thread(&self) -> Option<&Entity> { - self.content.iter().find_map(|content| match content { - ToolCallContent::SubagentThread(thread) => Some(thread), - _ => None, }) } pub fn is_subagent(&self) -> bool { - matches!(self.kind, acp::ToolKind::Other) - && self - .tool_name - .as_ref() - .map(|n| n.as_ref() == SUBAGENT_TOOL_NAME) - .unwrap_or(false) + self.tool_name.as_ref().is_some_and(|s| s == "subagent") + || self.subagent_session_id.is_some() } pub fn to_markdown(&self, cx: &App) -> String { @@ -688,7 +692,6 @@ pub enum ToolCallContent { ContentBlock(ContentBlock), Diff(Entity), Terminal(Entity), - SubagentThread(Entity), } impl ToolCallContent { @@ -760,7 +763,6 @@ impl ToolCallContent { Self::ContentBlock(content) => content.to_markdown(cx).to_string(), Self::Diff(diff) => diff.read(cx).to_markdown(cx), Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx), - Self::SubagentThread(thread) => thread.read(cx).to_markdown(cx), } } @@ -770,13 +772,6 @@ impl ToolCallContent { _ => None, } } - - pub fn subagent_thread(&self) -> Option<&Entity> { - match self { - Self::SubagentThread(thread) => Some(thread), - _ => None, - } - } } #[derive(Debug, PartialEq)] @@ -784,7 +779,6 @@ pub enum ToolCallUpdate { UpdateFields(acp::ToolCallUpdate), UpdateDiff(ToolCallUpdateDiff), UpdateTerminal(ToolCallUpdateTerminal), - UpdateSubagentThread(ToolCallUpdateSubagentThread), } impl ToolCallUpdate { @@ -793,7 +787,6 @@ impl ToolCallUpdate { Self::UpdateFields(update) => &update.tool_call_id, Self::UpdateDiff(diff) => &diff.id, Self::UpdateTerminal(terminal) => &terminal.id, - Self::UpdateSubagentThread(subagent) => &subagent.id, } } } @@ -828,18 +821,6 @@ pub struct ToolCallUpdateTerminal { pub terminal: Entity, } -impl From for ToolCallUpdate { - fn from(subagent: ToolCallUpdateSubagentThread) -> Self { - Self::UpdateSubagentThread(subagent) - } -} - -#[derive(Debug, PartialEq)] -pub struct ToolCallUpdateSubagentThread { - pub id: acp::ToolCallId, - pub thread: Entity, -} - #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -949,6 +930,7 @@ pub struct RetryStatus { } pub struct AcpThread { + parent_session_id: Option, title: SharedString, entries: Vec, plan: Plan, @@ -987,6 +969,7 @@ pub enum AcpThreadEvent { EntriesRemoved(Range), ToolAuthorizationRequired, Retry(RetryStatus), + SubagentSpawned(acp::SessionId), Stopped, Error, LoadError(LoadError), @@ -1163,6 +1146,7 @@ impl Error for LoadError {} impl AcpThread { pub fn new( + parent_session_id: Option, title: impl Into, connection: Rc, project: Entity, @@ -1185,6 +1169,7 @@ impl AcpThread { let (user_stop_tx, _user_stop_rx) = watch::channel(false); Self { + parent_session_id, action_log, shared_buffers: Default::default(), entries: Default::default(), @@ -1205,6 +1190,10 @@ impl AcpThread { } } + pub fn parent_session_id(&self) -> Option<&acp::SessionId> { + self.parent_session_id.as_ref() + } + pub fn prompt_capabilities(&self) -> acp::PromptCapabilities { self.prompt_capabilities.clone() } @@ -1214,6 +1203,7 @@ impl AcpThread { self.user_stopped .store(true, std::sync::atomic::Ordering::SeqCst); self.user_stop_tx.send(true).ok(); + self.send_task.take(); } pub fn was_stopped_by_user(&self) -> bool { @@ -1479,6 +1469,10 @@ impl AcpThread { Task::ready(Ok(())) } + pub fn subagent_spawned(&mut self, session_id: acp::SessionId, cx: &mut Context) { + cx.emit(AcpThreadEvent::SubagentSpawned(session_id)); + } + pub fn update_token_usage(&mut self, usage: Option, cx: &mut Context) { self.token_usage = usage; cx.emit(AcpThreadEvent::TokenUsageUpdated); @@ -1518,6 +1512,7 @@ impl AcpThread { raw_input_markdown: None, raw_output: None, tool_name: None, + subagent_session_id: None, }; self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx); return Ok(()); @@ -1530,7 +1525,14 @@ impl AcpThread { match update { ToolCallUpdate::UpdateFields(update) => { let location_updated = update.fields.locations.is_some(); - call.update_fields(update.fields, languages, path_style, &self.terminals, cx)?; + call.update_fields( + update.fields, + update.meta, + languages, + path_style, + &self.terminals, + cx, + )?; if location_updated { self.resolve_locations(update.tool_call_id, cx); } @@ -1544,16 +1546,6 @@ impl AcpThread { call.content .push(ToolCallContent::Terminal(update.terminal)); } - ToolCallUpdate::UpdateSubagentThread(update) => { - debug_assert!( - !call.content.iter().any(|c| { - matches!(c, ToolCallContent::SubagentThread(existing) if existing == &update.thread) - }), - "Duplicate SubagentThread update for the same AcpThread entity" - ); - call.content - .push(ToolCallContent::SubagentThread(update.thread)); - } } cx.emit(AcpThreadEvent::EntryUpdated(ix)); @@ -1605,6 +1597,7 @@ impl AcpThread { call.update_fields( update.fields, + update.meta, language_registry, path_style, &self.terminals, @@ -2631,7 +2624,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2695,7 +2688,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, std::path::Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2783,7 +2776,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project.clone(), Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2894,7 +2887,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2988,7 +2981,7 @@ mod tests { )); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3069,7 +3062,7 @@ mod tests { .unwrap(); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -3110,7 +3103,7 @@ mod tests { let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -3185,7 +3178,7 @@ mod tests { let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -3259,7 +3252,7 @@ mod tests { let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -3307,7 +3300,7 @@ mod tests { })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3398,7 +3391,7 @@ mod tests { })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3457,7 +3450,7 @@ mod tests { } })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3630,7 +3623,7 @@ mod tests { })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3706,7 +3699,7 @@ mod tests { })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3779,7 +3772,7 @@ mod tests { } })); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -3906,7 +3899,7 @@ mod tests { &self.auth_methods } - fn new_thread( + fn new_session( self: Rc, project: Entity, _cwd: &Path, @@ -3922,6 +3915,7 @@ mod tests { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { AcpThread::new( + None, "Test", self.clone(), project, @@ -4011,7 +4005,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -4077,7 +4071,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -4390,7 +4384,7 @@ mod tests { )); let thread = cx - .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) .await .unwrap(); diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 371b97393f67ad4016055b78c54e4b7006fe375b..102b5baa5a9fc5cbb445371e5138ebe7b31d83c4 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -30,7 +30,7 @@ impl UserMessageId { pub trait AgentConnection { fn telemetry_id(&self) -> SharedString; - fn new_thread( + fn new_session( self: Rc, project: Entity, cwd: &Path, @@ -53,6 +53,16 @@ pub trait AgentConnection { Task::ready(Err(anyhow::Error::msg("Loading sessions is not supported"))) } + /// Whether this agent supports closing existing sessions. + fn supports_close_session(&self, _cx: &App) -> bool { + false + } + + /// Close an existing session. Allows the agent to free the session from memory. + fn close_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow::Error::msg("Closing sessions is not supported"))) + } + /// Whether this agent supports resuming existing sessions without loading history. fn supports_resume_session(&self, _cx: &App) -> bool { false @@ -598,7 +608,7 @@ mod test_support { Some(self.model_selector_impl()) } - fn new_thread( + fn new_session( self: Rc, project: Entity, _cwd: &Path, @@ -608,6 +618,7 @@ mod test_support { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { AcpThread::new( + None, "Test", self.clone(), project, diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 0a36db8c206414485b94d825b96ec42f04d1c983..c3e0c28df822eed664c666a0a8ca7a4f4d4a178c 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -33,7 +33,7 @@ use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; use futures::channel::{mpsc, oneshot}; use futures::future::Shared; -use futures::{StreamExt, future}; +use futures::{FutureExt as _, StreamExt as _, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; @@ -49,6 +49,7 @@ use std::any::Any; use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::Arc; +use std::time::Duration; use util::ResultExt; use util::rel_path::RelPath; @@ -67,7 +68,7 @@ struct Session { /// The internal thread that processes messages thread: Entity, /// The ACP thread that handles protocol communication - acp_thread: WeakEntity, + acp_thread: Entity, pending_save: Task<()>, _subscriptions: Vec, } @@ -333,24 +334,27 @@ impl NativeAgent { ) }); - self.register_session(thread, cx) + self.register_session(thread, None, cx) } fn register_session( &mut self, thread_handle: Entity, + allowed_tool_names: Option>, cx: &mut Context, ) -> Entity { let connection = Rc::new(NativeAgentConnection(cx.entity())); let thread = thread_handle.read(cx); let session_id = thread.id().clone(); + let parent_session_id = thread.parent_thread_id(); let title = thread.title(); let project = thread.project.clone(); let action_log = thread.action_log.clone(); let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); let acp_thread = cx.new(|cx| { acp_thread::AcpThread::new( + parent_session_id, title, connection, project.clone(), @@ -364,20 +368,20 @@ impl NativeAgent { let registry = LanguageModelRegistry::read_global(cx); let summarization_model = registry.thread_summary_model().map(|c| c.model); + let weak = cx.weak_entity(); thread_handle.update(cx, |thread, cx| { thread.set_summarization_model(summarization_model, cx); thread.add_default_tools( - Rc::new(AcpThreadEnvironment { + allowed_tool_names, + Rc::new(NativeThreadEnvironment { acp_thread: acp_thread.downgrade(), + agent: weak, }) as _, cx, ) }); let subscriptions = vec![ - cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), cx.subscribe(&thread_handle, Self::handle_thread_title_updated), cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), cx.observe(&thread_handle, move |this, thread, cx| { @@ -389,7 +393,7 @@ impl NativeAgent { session_id, Session { thread: thread_handle, - acp_thread: acp_thread.downgrade(), + acp_thread: acp_thread.clone(), _subscriptions: subscriptions, pending_save: Task::ready(()), }, @@ -580,7 +584,7 @@ impl NativeAgent { return; }; let thread = thread.downgrade(); - let acp_thread = session.acp_thread.clone(); + let acp_thread = session.acp_thread.downgrade(); cx.spawn(async move |_, cx| { let title = thread.read_with(cx, |thread, _| thread.title())?; let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; @@ -598,12 +602,9 @@ impl NativeAgent { let Some(session) = self.sessions.get(thread.read(cx).id()) else { return; }; - session - .acp_thread - .update(cx, |acp_thread, cx| { - acp_thread.update_token_usage(usage.0.clone(), cx); - }) - .ok(); + session.acp_thread.update(cx, |acp_thread, cx| { + acp_thread.update_token_usage(usage.0.clone(), cx); + }); } fn handle_project_event( @@ -689,18 +690,16 @@ impl NativeAgent { fn update_available_commands(&self, cx: &mut Context) { let available_commands = self.build_available_commands(cx); for session in self.sessions.values() { - if let Some(acp_thread) = session.acp_thread.upgrade() { - acp_thread.update(cx, |thread, cx| { - thread - .handle_session_update( - acp::SessionUpdate::AvailableCommandsUpdate( - acp::AvailableCommandsUpdate::new(available_commands.clone()), - ), - cx, - ) - .log_err(); - }); - } + session.acp_thread.update(cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AvailableCommandsUpdate( + acp::AvailableCommandsUpdate::new(available_commands.clone()), + ), + cx, + ) + .log_err(); + }); } } @@ -796,11 +795,16 @@ impl NativeAgent { id: acp::SessionId, cx: &mut Context, ) -> Task>> { + if let Some(session) = self.sessions.get(&id) { + return Task::ready(Ok(session.acp_thread.clone())); + } + let task = self.load_thread(id, cx); cx.spawn(async move |this, cx| { let thread = task.await?; - let acp_thread = - this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?; + let acp_thread = this.update(cx, |this, cx| { + this.register_session(thread.clone(), None, cx) + })?; let events = thread.update(cx, |thread, cx| thread.replay(cx)); cx.update(|cx| { NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) @@ -906,7 +910,7 @@ impl NativeAgent { true, cx, ); - })?; + }); thread.update(cx, |thread, cx| { thread.push_acp_user_block(id, [block], path_style, cx); @@ -920,7 +924,7 @@ impl NativeAgent { true, cx, ); - })?; + }); thread.update(cx, |thread, cx| { thread.push_acp_agent_block(block, cx); @@ -941,7 +945,11 @@ impl NativeAgent { })?; cx.update(|cx| { - NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx) + NativeAgentConnection::handle_thread_events( + response_stream, + acp_thread.downgrade(), + cx, + ) }) .await }) @@ -986,7 +994,7 @@ impl NativeAgentConnection { Ok(stream) => stream, Err(err) => return Task::ready(Err(err)), }; - Self::handle_thread_events(response_stream, acp_thread, cx) + Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx) } fn handle_thread_events( @@ -1057,6 +1065,11 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } + ThreadEvent::SubagentSpawned(session_id) => { + acp_thread.update(cx, |thread, cx| { + thread.subagent_spawned(session_id, cx); + })?; + } ThreadEvent::Retry(status) => { acp_thread.update(cx, |thread, cx| { thread.update_retry_status(status, cx) @@ -1222,7 +1235,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { "zed".into() } - fn new_thread( + fn new_session( self: Rc, project: Entity, cwd: &Path, @@ -1249,6 +1262,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .update(cx, |agent, cx| agent.open_thread(session.session_id, cx)) } + fn supports_close_session(&self, _cx: &App) -> bool { + true + } + + fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task> { + self.0.update(cx, |agent, _cx| { + agent.sessions.remove(session_id); + }); + Task::ready(Ok(())) + } + fn auth_methods(&self) -> &[acp::AuthMethod] { &[] // No auth for in-process } @@ -1363,7 +1387,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { agent.sessions.get(session_id).map(|session| { Rc::new(NativeAgentSessionTruncate { thread: session.thread.clone(), - acp_thread: session.acp_thread.clone(), + acp_thread: session.acp_thread.downgrade(), }) as _ }) }) @@ -1551,11 +1575,120 @@ impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { } } -pub struct AcpThreadEnvironment { +pub struct NativeThreadEnvironment { + agent: WeakEntity, acp_thread: WeakEntity, } -impl ThreadEnvironment for AcpThreadEnvironment { +impl NativeThreadEnvironment { + pub(crate) fn create_subagent_thread( + agent: WeakEntity, + parent_thread_entity: Entity, + label: String, + initial_prompt: String, + timeout: Option, + allowed_tools: Option>, + cx: &mut App, + ) -> Result> { + let parent_thread = parent_thread_entity.read(cx); + let current_depth = parent_thread.depth(); + + if current_depth >= MAX_SUBAGENT_DEPTH { + return Err(anyhow!( + "Maximum subagent depth ({}) reached", + MAX_SUBAGENT_DEPTH + )); + } + + let running_count = parent_thread.running_subagent_count(); + if running_count >= MAX_PARALLEL_SUBAGENTS { + return Err(anyhow!( + "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.", + MAX_PARALLEL_SUBAGENTS + )); + } + + let allowed_tools = match allowed_tools { + Some(tools) => { + let parent_tool_names: std::collections::HashSet<&str> = + parent_thread.tools.keys().map(|s| s.as_str()).collect(); + Some( + tools + .into_iter() + .filter(|t| parent_tool_names.contains(t.as_str())) + .collect::>(), + ) + } + None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()), + }; + + let subagent_thread: Entity = cx.new(|cx| { + let mut thread = Thread::new_subagent(&parent_thread_entity, cx); + thread.set_title(label.into(), cx); + thread + }); + + let session_id = subagent_thread.read(cx).id().clone(); + + let acp_thread = agent.update(cx, |agent, cx| { + agent.register_session( + subagent_thread.clone(), + allowed_tools + .as_ref() + .map(|v| v.iter().map(|s| s.as_str()).collect()), + cx, + ) + })?; + + parent_thread_entity.update(cx, |parent_thread, _cx| { + parent_thread.register_running_subagent(subagent_thread.downgrade()) + }); + + let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx)); + + let timeout_timer = timeout.map(|d| cx.background_executor().timer(d)); + let wait_for_prompt_to_complete = cx + .background_spawn(async move { + if let Some(timer) = timeout_timer { + futures::select! { + _ = timer.fuse() => SubagentInitialPromptResult::Timeout, + _ = task.fuse() => SubagentInitialPromptResult::Completed, + } + } else { + task.await.log_err(); + SubagentInitialPromptResult::Completed + } + }) + .shared(); + + let mut user_stop_rx: watch::Receiver = + acp_thread.update(cx, |thread, _| thread.user_stop_receiver()); + + let user_cancelled = cx + .background_spawn(async move { + loop { + if *user_stop_rx.borrow() { + return; + } + if user_stop_rx.changed().await.is_err() { + std::future::pending::<()>().await; + } + } + }) + .shared(); + + Ok(Rc::new(NativeSubagentHandle { + session_id, + subagent_thread, + parent_thread: parent_thread_entity.downgrade(), + acp_thread, + wait_for_prompt_to_complete, + user_cancelled, + }) as _) + } +} + +impl ThreadEnvironment for NativeThreadEnvironment { fn create_terminal( &self, command: String, @@ -1588,6 +1721,98 @@ impl ThreadEnvironment for AcpThreadEnvironment { Ok(Rc::new(handle) as _) }) } + + fn create_subagent( + &self, + parent_thread_entity: Entity, + label: String, + initial_prompt: String, + timeout: Option, + allowed_tools: Option>, + cx: &mut App, + ) -> Result> { + Self::create_subagent_thread( + self.agent.clone(), + parent_thread_entity, + label, + initial_prompt, + timeout, + allowed_tools, + cx, + ) + } +} + +#[derive(Debug, Clone, Copy)] +enum SubagentInitialPromptResult { + Completed, + Timeout, +} + +pub struct NativeSubagentHandle { + session_id: acp::SessionId, + parent_thread: WeakEntity, + subagent_thread: Entity, + acp_thread: Entity, + wait_for_prompt_to_complete: Shared>, + user_cancelled: Shared>, +} + +impl SubagentHandle for NativeSubagentHandle { + fn id(&self) -> acp::SessionId { + self.session_id.clone() + } + + fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task> { + let thread = self.subagent_thread.clone(); + let acp_thread = self.acp_thread.clone(); + let wait_for_prompt = self.wait_for_prompt_to_complete.clone(); + + let wait_for_summary_task = cx.spawn(async move |cx| { + let timed_out = match wait_for_prompt.await { + SubagentInitialPromptResult::Completed => false, + SubagentInitialPromptResult::Timeout => true, + }; + + let summary_prompt = if timed_out { + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt) + } else { + summary_prompt + }; + + acp_thread + .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx)) + .await?; + + thread.read_with(cx, |thread, _cx| { + thread + .last_message() + .map(|m| m.to_markdown()) + .context("No response from subagent") + }) + }); + + let user_cancelled = self.user_cancelled.clone(); + let thread = self.subagent_thread.clone(); + let subagent_session_id = self.session_id.clone(); + let parent_thread = self.parent_thread.clone(); + cx.spawn(async move |cx| { + let result = futures::select! { + result = wait_for_summary_task.fuse() => result, + _ = user_cancelled.fuse() => { + thread.update(cx, |thread, cx| thread.cancel(cx).detach()); + Err(anyhow!("User cancelled")) + }, + }; + parent_thread + .update(cx, |parent_thread, cx| { + parent_thread.unregister_running_subagent(&subagent_session_id, cx) + }) + .ok(); + result + }) + } } pub struct AcpTerminalHandle { @@ -1730,7 +1955,7 @@ mod internal_tests { // Create a thread/session let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -1808,7 +2033,7 @@ mod internal_tests { // Create a thread/session let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -1908,7 +2133,7 @@ mod internal_tests { let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -2024,7 +2249,7 @@ mod internal_tests { .update(|cx| { connection .clone() - .new_thread(project.clone(), Path::new("/a"), cx) + .new_session(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -2057,11 +2282,12 @@ mod internal_tests { send.await.unwrap(); cx.run_until_parked(); - // Drop the thread so it can be reloaded from disk. - cx.update(|_| { - drop(thread); - drop(acp_thread); - }); + // Close the session so it can be reloaded from disk. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + drop(thread); + drop(acp_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); @@ -2130,7 +2356,7 @@ mod internal_tests { .update(|cx| { connection .clone() - .new_thread(project.clone(), Path::new("/a"), cx) + .new_session(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); @@ -2163,11 +2389,12 @@ mod internal_tests { send.await.unwrap(); cx.run_until_parked(); - // Drop the thread so it can be reloaded from disk. - cx.update(|_| { - drop(thread); - drop(acp_thread); - }); + // Close the session so it can be reloaded from disk. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + drop(thread); + drop(acp_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); @@ -2225,7 +2452,7 @@ mod internal_tests { .update(|cx| { connection .clone() - .new_thread(project.clone(), Path::new(""), cx) + .new_session(project.clone(), Path::new(""), cx) }) .await .unwrap(); @@ -2294,11 +2521,12 @@ mod internal_tests { cx.run_until_parked(); - // Drop the ACP thread, which should cause the session to be dropped as well. - cx.update(|_| { - drop(thread); - drop(acp_thread); - }); + // Close the session so it can be reloaded from disk. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + drop(thread); + drop(acp_thread); agent.read_with(cx, |agent, _| { assert_eq!(agent.sessions.keys().cloned().collect::>(), []); }); diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index e407c2e94f0655f08fb59c1deb3ebb574bc3a758..fa4b37dba3e789b499bfe5db4f0b76ccf12e5a09 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -26,6 +26,7 @@ pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { pub id: acp::SessionId, + pub parent_session_id: Option, #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, @@ -50,6 +51,8 @@ pub struct DbThread { pub profile: Option, #[serde(default)] pub imported: bool, + #[serde(default)] + pub subagent_context: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -87,6 +90,7 @@ impl SharedThread { model: self.model, profile: None, imported: true, + subagent_context: None, } } @@ -260,6 +264,7 @@ impl DbThread { model: thread.model, profile: thread.profile, imported: false, + subagent_context: None, }) } } @@ -357,6 +362,13 @@ impl ThreadsDatabase { "})?() .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + if let Ok(mut s) = connection.exec(indoc! {" + ALTER TABLE threads ADD COLUMN parent_id TEXT + "}) + { + s().ok(); + } + let db = Self { executor, connection: Arc::new(Mutex::new(connection)), @@ -381,6 +393,10 @@ impl ThreadsDatabase { let title = thread.title.to_string(); let updated_at = thread.updated_at.to_rfc3339(); + let parent_id = thread + .subagent_context + .as_ref() + .map(|ctx| ctx.parent_thread_id.0.clone()); let json_data = serde_json::to_string(&SerializedThread { thread, version: DbThread::VERSION, @@ -392,11 +408,11 @@ impl ThreadsDatabase { let data_type = DataType::Zstd; let data = compressed; - let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {" - INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + let mut insert = connection.exec_bound::<(Arc, Option>, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?) "})?; - insert((id.0, title, updated_at, data_type, data))?; + insert((id.0, parent_id, title, updated_at, data_type, data))?; Ok(()) } @@ -407,17 +423,18 @@ impl ThreadsDatabase { self.executor.spawn(async move { let connection = connection.lock(); - let mut select = - connection.select_bound::<(), (Arc, String, String)>(indoc! {" - SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + let mut select = connection + .select_bound::<(), (Arc, Option>, String, String)>(indoc! {" + SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC "})?; let rows = select(())?; let mut threads = Vec::new(); - for (id, summary, updated_at) in rows { + for (id, parent_id, summary, updated_at) in rows { threads.push(DbThreadMetadata { id: acp::SessionId::new(id), + parent_session_id: parent_id.map(acp::SessionId::new), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), }); @@ -552,6 +569,7 @@ mod tests { model: None, profile: None, imported: false, + subagent_context: None, } } @@ -618,4 +636,81 @@ mod tests { Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap() ); } + + #[test] + fn test_subagent_context_defaults_to_none() { + let json = r#"{ + "title": "Old Thread", + "messages": [], + "updated_at": "2024-01-01T00:00:00Z" + }"#; + + let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize"); + + assert!( + db_thread.subagent_context.is_none(), + "Legacy threads without subagent_context should default to None" + ); + } + + #[gpui::test] + async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) { + let database = ThreadsDatabase::new(cx.executor()).unwrap(); + + let parent_id = session_id("parent-thread"); + let child_id = session_id("child-thread"); + + let mut child_thread = make_thread( + "Subagent Thread", + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + ); + child_thread.subagent_context = Some(crate::SubagentContext { + parent_thread_id: parent_id.clone(), + depth: 2, + }); + + database + .save_thread(child_id.clone(), child_thread) + .await + .unwrap(); + + let loaded = database + .load_thread(child_id) + .await + .unwrap() + .expect("thread should exist"); + + let context = loaded + .subagent_context + .expect("subagent_context should be restored"); + assert_eq!(context.parent_thread_id, parent_id); + assert_eq!(context.depth, 2); + } + + #[gpui::test] + async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) { + let database = ThreadsDatabase::new(cx.executor()).unwrap(); + + let thread_id = session_id("regular-thread"); + let thread = make_thread( + "Regular Thread", + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + ); + + database + .save_thread(thread_id.clone(), thread) + .await + .unwrap(); + + let loaded = database + .load_thread(thread_id) + .await + .unwrap() + .expect("thread should exist"); + + assert!( + loaded.subagent_context.is_none(), + "Regular threads should have no subagent_context" + ); + } } diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index e016889a97a69a265c10a022c70a66ec52aae450..068c0270cf7057790d3665f7f1fac59d1d3f1d07 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -1,15 +1,14 @@ use super::*; use crate::{AgentTool, EditFileTool, ReadFileTool}; use acp_thread::UserMessageId; -use action_log::ActionLog; use fs::FakeFs; use language_model::{ - LanguageModelCompletionEvent, LanguageModelToolUse, MessageContent, StopReason, + LanguageModelCompletionEvent, LanguageModelToolUse, StopReason, fake_provider::FakeLanguageModel, }; use prompt_store::ProjectContext; use serde_json::json; -use std::{collections::BTreeMap, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use util::path; #[gpui::test] @@ -50,17 +49,23 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); // Add just the tools we need for this test let language_registry = project.read(cx).languages().clone(); - thread.add_tool(crate::ReadFileTool::new( - cx.weak_entity(), - project.clone(), - thread.action_log().clone(), - )); - thread.add_tool(crate::EditFileTool::new( - project.clone(), - cx.weak_entity(), - language_registry, - crate::Templates::new(), - )); + thread.add_tool( + crate::ReadFileTool::new( + cx.weak_entity(), + project.clone(), + thread.action_log().clone(), + ), + None, + ); + thread.add_tool( + crate::EditFileTool::new( + project.clone(), + cx.weak_entity(), + language_registry, + crate::Templates::new(), + ), + None, + ); thread }); @@ -203,417 +208,3 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); }); } - -#[gpui::test] -async fn test_subagent_uses_read_file_tool(cx: &mut TestAppContext) { - // This test verifies that subagents can successfully use the read_file tool - // through the full thread flow, and that tools are properly rebound to use - // the subagent's thread ID instead of the parent's. - super::init_test(cx); - super::always_allow_tools(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "src": { - "lib.rs": "pub fn hello() -> &'static str {\n \"Hello from lib!\"\n}\n" - } - }), - ) - .await; - - let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let fake_model = model.as_fake(); - - // Create subagent context - let subagent_context = crate::SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"), - depth: 1, - summary_prompt: "Summarize what you found".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - // Create parent tools that will be passed to the subagent - // This simulates how the subagent_tool passes tools to new_subagent - let parent_tools: BTreeMap> = { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - // Create a "fake" parent thread reference - this should get rebound - let fake_parent_thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)), - crate::Templates::new(), - Some(model.clone()), - cx, - ) - }); - let mut tools: BTreeMap> = - BTreeMap::new(); - tools.insert( - ReadFileTool::NAME.into(), - crate::ReadFileTool::new(fake_parent_thread.downgrade(), project.clone(), action_log) - .erase(), - ); - tools - }; - - // Create subagent - tools should be rebound to use subagent's thread - let subagent = cx.new(|cx| { - crate::Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - crate::Templates::new(), - model.clone(), - subagent_context, - parent_tools, - cx, - ) - }); - - // Get the subagent's thread ID - let _subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string()); - - // Verify the subagent has the read_file tool - subagent.read_with(cx, |thread, _| { - assert!( - thread.has_registered_tool(ReadFileTool::NAME), - "subagent should have read_file tool" - ); - }); - - // Submit a user message to the subagent - subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Read the file src/lib.rs", cx) - }) - .unwrap(); - cx.run_until_parked(); - - // Simulate the model calling the read_file tool - let read_tool_use = LanguageModelToolUse { - id: "read_tool_1".into(), - name: ReadFileTool::NAME.into(), - raw_input: json!({"path": "project/src/lib.rs"}).to_string(), - input: json!({"path": "project/src/lib.rs"}), - is_input_complete: true, - thought_signature: None, - }; - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // Wait for the tool to complete and the model to be called again with tool results - let deadline = std::time::Instant::now() + Duration::from_secs(5); - while fake_model.pending_completions().is_empty() { - if std::time::Instant::now() >= deadline { - panic!("Timed out waiting for model to be called after read_file tool completion"); - } - cx.run_until_parked(); - cx.background_executor - .timer(Duration::from_millis(10)) - .await; - } - - // Verify the tool result was sent back to the model - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "Model should have been called with tool result" - ); - - let last_request = pending.last().unwrap(); - let tool_result = last_request.messages.iter().find_map(|m| { - m.content.iter().find_map(|c| match c { - MessageContent::ToolResult(result) => Some(result), - _ => None, - }) - }); - assert!( - tool_result.is_some(), - "Tool result should be in the messages sent back to the model" - ); - - // Verify the tool result contains the file content - let result = tool_result.unwrap(); - let result_text = match &result.content { - language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), - _ => panic!("expected text content in tool result"), - }; - assert!( - result_text.contains("Hello from lib!"), - "Tool result should contain file content, got: {}", - result_text - ); - - // Verify the subagent is ready for more input (tool completed, model called again) - // This test verifies the subagent can successfully use read_file tool. - // The summary flow is tested separately in test_subagent_returns_summary_on_completion. -} - -#[gpui::test] -async fn test_subagent_uses_edit_file_tool(cx: &mut TestAppContext) { - // This test verifies that subagents can successfully use the edit_file tool - // through the full thread flow, including the edit agent's model request. - // It also verifies that the edit agent uses the subagent's thread ID, not the parent's. - super::init_test(cx); - super::always_allow_tools(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "src": { - "config.rs": "pub const VERSION: &str = \"1.0.0\";\n" - } - }), - ) - .await; - - let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let fake_model = model.as_fake(); - - // Create a "parent" thread to simulate the real scenario where tools are inherited - let parent_thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)), - crate::Templates::new(), - Some(model.clone()), - cx, - ) - }); - let parent_thread_id = parent_thread.read_with(cx, |thread, _| thread.id().to_string()); - - // Create parent tools that reference the parent thread - let parent_tools: BTreeMap> = { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let language_registry = project.read_with(cx, |p, _| p.languages().clone()); - let mut tools: BTreeMap> = - BTreeMap::new(); - tools.insert( - ReadFileTool::NAME.into(), - crate::ReadFileTool::new(parent_thread.downgrade(), project.clone(), action_log) - .erase(), - ); - tools.insert( - EditFileTool::NAME.into(), - crate::EditFileTool::new( - project.clone(), - parent_thread.downgrade(), - language_registry, - crate::Templates::new(), - ) - .erase(), - ); - tools - }; - - // Create subagent context - let subagent_context = crate::SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"), - depth: 1, - summary_prompt: "Summarize what you changed".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - // Create subagent - tools should be rebound to use subagent's thread - let subagent = cx.new(|cx| { - crate::Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - crate::Templates::new(), - model.clone(), - subagent_context, - parent_tools, - cx, - ) - }); - - // Get the subagent's thread ID - it should be different from parent - let subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string()); - assert_ne!( - parent_thread_id, subagent_thread_id, - "Subagent should have a different thread ID than parent" - ); - - // Verify the subagent has the tools - subagent.read_with(cx, |thread, _| { - assert!( - thread.has_registered_tool(ReadFileTool::NAME), - "subagent should have read_file tool" - ); - assert!( - thread.has_registered_tool(EditFileTool::NAME), - "subagent should have edit_file tool" - ); - }); - - // Submit a user message to the subagent - subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Update the version in config.rs to 2.0.0", cx) - }) - .unwrap(); - cx.run_until_parked(); - - // First, model calls read_file to see the current content - let read_tool_use = LanguageModelToolUse { - id: "read_tool_1".into(), - name: ReadFileTool::NAME.into(), - raw_input: json!({"path": "project/src/config.rs"}).to_string(), - input: json!({"path": "project/src/config.rs"}), - is_input_complete: true, - thought_signature: None, - }; - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // Wait for the read tool to complete and model to be called again - let deadline = std::time::Instant::now() + Duration::from_secs(5); - while fake_model.pending_completions().is_empty() { - if std::time::Instant::now() >= deadline { - panic!("Timed out waiting for model to be called after read_file tool"); - } - cx.run_until_parked(); - cx.background_executor - .timer(Duration::from_millis(10)) - .await; - } - - // Model responds and calls edit_file - fake_model.send_last_completion_stream_text_chunk("I'll update the version now."); - let edit_tool_use = LanguageModelToolUse { - id: "edit_tool_1".into(), - name: EditFileTool::NAME.into(), - raw_input: json!({ - "display_description": "Update version to 2.0.0", - "path": "project/src/config.rs", - "mode": "edit" - }) - .to_string(), - input: json!({ - "display_description": "Update version to 2.0.0", - "path": "project/src/config.rs", - "mode": "edit" - }), - is_input_complete: true, - thought_signature: None, - }; - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // The edit_file tool creates an EditAgent which makes its own model request. - // Wait for that request. - let deadline = std::time::Instant::now() + Duration::from_secs(5); - while fake_model.pending_completions().is_empty() { - if std::time::Instant::now() >= deadline { - panic!( - "Timed out waiting for edit agent completion request in subagent. Pending: {}", - fake_model.pending_completions().len() - ); - } - cx.run_until_parked(); - cx.background_executor - .timer(Duration::from_millis(10)) - .await; - } - - // Verify the edit agent's request uses the SUBAGENT's thread ID, not the parent's - let pending = fake_model.pending_completions(); - let edit_agent_request = pending.last().unwrap(); - let edit_agent_thread_id = edit_agent_request.thread_id.as_ref().unwrap(); - std::assert_eq!( - edit_agent_thread_id, - &subagent_thread_id, - "Edit agent should use subagent's thread ID, not parent's. Got: {}, expected: {}", - edit_agent_thread_id, - subagent_thread_id - ); - std::assert_ne!( - edit_agent_thread_id, - &parent_thread_id, - "Edit agent should NOT use parent's thread ID" - ); - - // Send the edit agent's response with the XML format it expects - let edit_response = "pub const VERSION: &str = \"1.0.0\";\npub const VERSION: &str = \"2.0.0\";"; - fake_model.send_last_completion_stream_text_chunk(edit_response); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // Wait for the edit to complete and the thread to call the model again with tool results - let deadline = std::time::Instant::now() + Duration::from_secs(5); - while fake_model.pending_completions().is_empty() { - if std::time::Instant::now() >= deadline { - panic!("Timed out waiting for model to be called after edit completion in subagent"); - } - cx.run_until_parked(); - cx.background_executor - .timer(Duration::from_millis(10)) - .await; - } - - // Verify the file was edited - let file_content = fs - .load(path!("/project/src/config.rs").as_ref()) - .await - .expect("file should exist"); - assert!( - file_content.contains("2.0.0"), - "File should have been edited to contain new version. Content: {}", - file_content - ); - assert!( - !file_content.contains("1.0.0"), - "Old version should be replaced. Content: {}", - file_content - ); - - // Verify the tool result was sent back to the model - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "Model should have been called with tool result" - ); - - let last_request = pending.last().unwrap(); - let has_tool_result = last_request.messages.iter().any(|m| { - m.content - .iter() - .any(|c| matches!(c, MessageContent::ToolResult(_))) - }); - assert!( - has_tool_result, - "Tool result should be in the messages sent back to the model" - ); -} diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 8a6f13cd28ac696b938c10189ac3b74d43828628..d165b888ccb3f7350e14d95d3eed2daa269d004e 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -155,8 +155,51 @@ impl crate::TerminalHandle for FakeTerminalHandle { } } +struct FakeSubagentHandle { + session_id: acp::SessionId, + wait_for_summary_task: Shared>, +} + +impl FakeSubagentHandle { + fn new_never_completes(cx: &App) -> Self { + Self { + session_id: acp::SessionId::new("subagent-id"), + wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(), + } + } +} + +impl SubagentHandle for FakeSubagentHandle { + fn id(&self) -> acp::SessionId { + self.session_id.clone() + } + + fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task> { + let task = self.wait_for_summary_task.clone(); + cx.background_spawn(async move { Ok(task.await) }) + } +} + +#[derive(Default)] struct FakeThreadEnvironment { - handle: Rc, + terminal_handle: Option>, + subagent_handle: Option>, +} + +impl FakeThreadEnvironment { + pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self { + Self { + terminal_handle: Some(terminal_handle.into()), + ..self + } + } + + pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self { + Self { + subagent_handle: Some(subagent_handle.into()), + ..self + } + } } impl crate::ThreadEnvironment for FakeThreadEnvironment { @@ -167,7 +210,27 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment { _output_byte_limit: Option, _cx: &mut AsyncApp, ) -> Task>> { - Task::ready(Ok(self.handle.clone() as Rc)) + let handle = self + .terminal_handle + .clone() + .expect("Terminal handle not available on FakeThreadEnvironment"); + Task::ready(Ok(handle as Rc)) + } + + fn create_subagent( + &self, + _parent_thread: Entity, + _label: String, + _initial_prompt: String, + _timeout_ms: Option, + _allowed_tools: Option>, + _cx: &mut App, + ) -> Result> { + Ok(self + .subagent_handle + .clone() + .expect("Subagent handle not available on FakeThreadEnvironment") + as Rc) } } @@ -200,6 +263,18 @@ impl crate::ThreadEnvironment for MultiTerminalEnvironment { self.handles.borrow_mut().push(handle.clone()); Task::ready(Ok(handle as Rc)) } + + fn create_subagent( + &self, + _parent_thread: Entity, + _label: String, + _initial_prompt: String, + _timeout: Option, + _allowed_tools: Option>, + _cx: &mut App, + ) -> Result> { + unimplemented!() + } } fn always_allow_tools(cx: &mut TestAppContext) { @@ -228,14 +303,8 @@ async fn test_echo(cx: &mut TestAppContext) { let events = events.collect().await; thread.update(cx, |thread, _cx| { - assert_eq!( - thread.last_message().unwrap().to_markdown(), - indoc! {" - ## Assistant - - Hello - "} - ) + assert_eq!(thread.last_message().unwrap().role(), Role::Assistant); + assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n") }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } @@ -248,10 +317,10 @@ async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) { let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); #[allow(clippy::arc_with_non_send_sync)] let tool = Arc::new(crate::TerminalTool::new(project, environment)); @@ -315,10 +384,10 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp let fs = FakeFs::new(cx.executor()); let project = Project::test(fs, [], cx).await; - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); #[allow(clippy::arc_with_non_send_sync)] let tool = Arc::new(crate::TerminalTool::new(project, environment)); @@ -387,11 +456,10 @@ async fn test_thinking(cx: &mut TestAppContext) { let events = events.collect().await; thread.update(cx, |thread, _cx| { + assert_eq!(thread.last_message().unwrap().role(), Role::Assistant); assert_eq!( thread.last_message().unwrap().to_markdown(), indoc! {" - ## Assistant - Think Hello "} @@ -413,7 +481,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) { project_context.update(cx, |project_context, _cx| { project_context.shell = "test-shell".into() }); - thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, _| thread.add_tool(EchoTool, None)); thread .update(cx, |thread, cx| { thread.send(UserMessageId::new(), ["abc"], cx) @@ -549,7 +617,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { cx.run_until_parked(); // Simulate a tool call and verify that the latest tool result is cached - thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, _| thread.add_tool(EchoTool, None)); thread .update(cx, |thread, cx| { thread.send(UserMessageId::new(), ["Use the echo tool"], cx) @@ -635,7 +703,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let events = thread .update(cx, |thread, cx| { - thread.add_tool(EchoTool); + thread.add_tool(EchoTool, None); thread.send( UserMessageId::new(), ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], @@ -651,7 +719,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.remove_tool(&EchoTool::NAME); - thread.add_tool(DelayTool); + thread.add_tool(DelayTool, None); thread.send( UserMessageId::new(), [ @@ -695,7 +763,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(WordListTool); + thread.add_tool(WordListTool, None); thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) }) .unwrap(); @@ -746,7 +814,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(ToolRequiringPermission); + thread.add_tool(ToolRequiringPermission, None); thread.send(UserMessageId::new(), ["abc"], cx) }) .unwrap(); @@ -1087,7 +1155,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { // Test concurrent tool calls with different delay times let events = thread .update(cx, |thread, cx| { - thread.add_tool(DelayTool); + thread.add_tool(DelayTool, None); thread.send( UserMessageId::new(), [ @@ -1132,9 +1200,9 @@ async fn test_profiles(cx: &mut TestAppContext) { let fake_model = model.as_fake(); thread.update(cx, |thread, _cx| { - thread.add_tool(DelayTool); - thread.add_tool(EchoTool); - thread.add_tool(InfiniteTool); + thread.add_tool(DelayTool, None); + thread.add_tool(EchoTool, None); + thread.add_tool(InfiniteTool, None); }); // Override profiles and wait for settings to be loaded. @@ -1300,7 +1368,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { // Send again after adding the echo tool, ensuring the name collision is resolved. let events = thread.update(cx, |thread, cx| { - thread.add_tool(EchoTool); + thread.add_tool(EchoTool, None); thread.send(UserMessageId::new(), ["Go"], cx).unwrap() }); cx.run_until_parked(); @@ -1409,11 +1477,11 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test".into()), cx); - thread.add_tool(EchoTool); - thread.add_tool(DelayTool); - thread.add_tool(WordListTool); - thread.add_tool(ToolRequiringPermission); - thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool, None); + thread.add_tool(DelayTool, None); + thread.add_tool(WordListTool, None); + thread.add_tool(ToolRequiringPermission, None); + thread.add_tool(InfiniteTool, None); }); // Set up multiple context servers with some overlapping tool names @@ -1543,8 +1611,8 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(InfiniteTool); - thread.add_tool(EchoTool); + thread.add_tool(InfiniteTool, None); + thread.add_tool(EchoTool, None); thread.send( UserMessageId::new(), ["Call the echo tool, then call the infinite tool, then explain their output"], @@ -1628,17 +1696,17 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext always_allow_tools(cx); let fake_model = model.as_fake(); - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(crate::TerminalTool::new( - thread.project().clone(), - environment, - )); + thread.add_tool( + crate::TerminalTool::new(thread.project().clone(), environment), + None, + ); thread.send(UserMessageId::new(), ["run a command"], cx) }) .unwrap(); @@ -1732,7 +1800,7 @@ async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppC let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(tool); + thread.add_tool(tool, None); thread.send( UserMessageId::new(), ["call the cancellation aware tool"], @@ -1910,18 +1978,18 @@ async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) { always_allow_tools(cx); let fake_model = model.as_fake(); - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); let message_id = UserMessageId::new(); let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(crate::TerminalTool::new( - thread.project().clone(), - environment, - )); + thread.add_tool( + crate::TerminalTool::new(thread.project().clone(), environment), + None, + ); thread.send(message_id.clone(), ["run a command"], cx) }) .unwrap(); @@ -1982,10 +2050,10 @@ async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(crate::TerminalTool::new( - thread.project().clone(), - environment.clone(), - )); + thread.add_tool( + crate::TerminalTool::new(thread.project().clone(), environment.clone()), + None, + ); thread.send(UserMessageId::new(), ["run multiple commands"], cx) }) .unwrap(); @@ -2088,17 +2156,17 @@ async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppCon always_allow_tools(cx); let fake_model = model.as_fake(); - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(crate::TerminalTool::new( - thread.project().clone(), - environment, - )); + thread.add_tool( + crate::TerminalTool::new(thread.project().clone(), environment), + None, + ); thread.send(UserMessageId::new(), ["run a command"], cx) }) .unwrap(); @@ -2182,17 +2250,17 @@ async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) { always_allow_tools(cx); let fake_model = model.as_fake(); - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); + let handle = environment.terminal_handle.clone().unwrap(); let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(crate::TerminalTool::new( - thread.project().clone(), - environment, - )); + thread.add_tool( + crate::TerminalTool::new(thread.project().clone(), environment), + None, + ); thread.send(UserMessageId::new(), ["run a command with timeout"], cx) }) .unwrap(); @@ -2673,8 +2741,8 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { let _events = thread .update(cx, |thread, cx| { - thread.add_tool(ToolRequiringPermission); - thread.add_tool(EchoTool); + thread.add_tool(ToolRequiringPermission, None); + thread.add_tool(EchoTool, None); thread.send(UserMessageId::new(), ["Hey!"], cx) }) .unwrap(); @@ -2788,7 +2856,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| connection_rc.new_thread(project, cwd, cx)) + .update(|cx| connection_rc.new_session(project, cwd, cx)) .await .expect("new_thread should succeed"); @@ -2855,9 +2923,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) { cx.update(|cx| connection.cancel(&session_id, cx)); request.await.expect("prompt should fail gracefully"); - // Ensure that dropping the ACP thread causes the native thread to be - // dropped as well. - cx.update(|_| drop(acp_thread)); + // Explicitly close the session and drop the ACP thread. + cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx)) + .await + .unwrap(); + drop(acp_thread); let result = cx .update(|cx| { connection.prompt( @@ -2878,7 +2948,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { #[gpui::test] async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; - thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); + thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool, None)); let fake_model = model.as_fake(); let mut events = thread @@ -3080,7 +3150,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { - thread.add_tool(EchoTool); + thread.add_tool(EchoTool, None); thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) }) .unwrap(); @@ -3652,10 +3722,9 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { // Test 1: Deny rule blocks command { - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3704,10 +3773,10 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny) { - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default() + .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0)) + })); cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3762,10 +3831,10 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { // Test 3: always_allow_tool_actions=true overrides always_confirm patterns { - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default() + .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0)) + })); cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3808,10 +3877,10 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { // Test 4: always_allow_tool_actions=true overrides default_mode: Deny { - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0))); - let environment = Rc::new(FakeThreadEnvironment { - handle: handle.clone(), - }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default() + .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0)) + })); cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3868,8 +3937,9 @@ async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAp cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { handle }); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); let thread = cx.new(|cx| { let mut thread = Thread::new( @@ -3880,7 +3950,7 @@ async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAp Some(model), cx, ); - thread.add_default_tools(environment, cx); + thread.add_default_tools(None, environment, cx); thread }); @@ -3893,7 +3963,7 @@ async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAp } #[gpui::test] -async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) { +async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) { init_test(cx); cx.update(|cx| { @@ -3909,31 +3979,29 @@ async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) { cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let subagent = cx.new(|cx| { - Thread::new_subagent( + let parent_thread = cx.new(|cx| { + Thread::new( project.clone(), project_context, context_server_registry, Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ) }); - subagent.read_with(cx, |thread, _| { - assert!(thread.is_subagent()); - assert_eq!(thread.depth(), 1); - assert!(thread.model().is_some()); + let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx)); + subagent_thread.read_with(cx, |subagent_thread, cx| { + assert!(subagent_thread.is_subagent()); + assert_eq!(subagent_thread.depth(), 1); + assert_eq!( + subagent_thread.model().map(|model| model.id()), + Some(model.id()) + ); + assert_eq!( + subagent_thread.parent_thread_id(), + Some(parent_thread.read(cx).id().clone()) + ); }); } @@ -3953,34 +4021,32 @@ async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppCont let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let model = Arc::new(FakeLanguageModel::default()); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx)) + })); - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: MAX_SUBAGENT_DEPTH, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); - let environment = Rc::new(FakeThreadEnvironment { handle }); - - let deep_subagent = cx.new(|cx| { - let mut thread = Thread::new_subagent( + let deep_parent_thread = cx.new(|cx| { + let mut thread = Thread::new( project.clone(), project_context, context_server_registry, Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ); - thread.add_default_tools(environment, cx); + thread.set_subagent_context(SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + depth: MAX_SUBAGENT_DEPTH - 1, + }); + thread + }); + let deep_subagent_thread = cx.new(|cx| { + let mut thread = Thread::new_subagent(&deep_parent_thread, cx); + thread.add_default_tools(None, environment, cx); thread }); - deep_subagent.read_with(cx, |thread, _| { + deep_subagent_thread.read_with(cx, |thread, _| { assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH); assert!( !thread.has_registered_tool(SubagentTool::NAME), @@ -3989,209 +4055,6 @@ async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppCont }); } -#[gpui::test] -async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize your work".to_string(), - context_low_prompt: "Context low, wrap up".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - - let subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ) - }); - - let task_prompt = "Find all TODO comments in the codebase"; - subagent - .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx)) - .unwrap(); - cx.run_until_parked(); - - let pending = fake_model.pending_completions(); - assert_eq!(pending.len(), 1, "should have one pending completion"); - - let messages = &pending[0].messages; - let user_messages: Vec<_> = messages - .iter() - .filter(|m| m.role == language_model::Role::User) - .collect(); - assert_eq!(user_messages.len(), 1, "should have one user message"); - - let content = &user_messages[0].content[0]; - assert!( - content.to_str().unwrap().contains("TODO"), - "task prompt should be in user message" - ); -} - -#[gpui::test] -async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Please summarize what you found".to_string(), - context_low_prompt: "Context low, wrap up".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - - let subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ) - }); - - subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Do some work", cx) - }) - .unwrap(); - cx.run_until_parked(); - - fake_model.send_last_completion_stream_text_chunk("I did the work"); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - subagent - .update(cx, |thread, cx| thread.request_final_summary(cx)) - .unwrap(); - cx.run_until_parked(); - - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "should have pending completion for summary" - ); - - let messages = &pending.last().unwrap().messages; - let user_messages: Vec<_> = messages - .iter() - .filter(|m| m.role == language_model::Role::User) - .collect(); - - let last_user = user_messages.last().unwrap(); - assert!( - last_user.content[0].to_str().unwrap().contains("summarize"), - "summary prompt should be sent" - ); -} - -#[gpui::test] -async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) { - init_test(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let subagent = cx.new(|cx| { - let mut thread = Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ); - thread.add_tool(EchoTool); - thread.add_tool(DelayTool); - thread.add_tool(WordListTool); - thread - }); - - subagent.read_with(cx, |thread, _| { - assert!(thread.has_registered_tool("echo")); - assert!(thread.has_registered_tool("delay")); - assert!(thread.has_registered_tool("word_list")); - }); - - let allowed: collections::HashSet = - vec!["echo".into()].into_iter().collect(); - - subagent.update(cx, |thread, _cx| { - thread.restrict_tools(&allowed); - }); - - subagent.read_with(cx, |thread, _| { - assert!( - thread.has_registered_tool("echo"), - "echo should still be available" - ); - assert!( - !thread.has_registered_tool("delay"), - "delay should be removed" - ); - assert!( - !thread.has_registered_tool("word_list"), - "word_list should be removed" - ); - }); -} - #[gpui::test] async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) { init_test(cx); @@ -4220,33 +4083,16 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) { ) }); - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ) - }); + let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx)); parent.update(cx, |thread, _cx| { thread.register_running_subagent(subagent.downgrade()); }); subagent - .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Do work".to_string()], cx) + }) .unwrap(); cx.run_until_parked(); @@ -4285,6 +4131,9 @@ async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let model = Arc::new(FakeLanguageModel::default()); + let environment = Rc::new(cx.update(|cx| { + FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx)) + })); let parent = cx.new(|cx| { Thread::new( @@ -4298,7 +4147,7 @@ async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { }); #[allow(clippy::arc_with_non_send_sync)] - let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0)); + let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment)); let (event_stream, _rx, mut cancellation_tx) = crate::ToolCallEventStream::test_with_cancellation(); @@ -4310,7 +4159,6 @@ async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { label: "Long running task".to_string(), task_prompt: "Do a very long task that takes forever".to_string(), summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), timeout_ms: None, allowed_tools: None, }, @@ -4343,405 +4191,286 @@ async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); +async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) { + init_test(cx); + always_allow_tools(cx); cx.update(|cx| { cx.update_flags(true, vec!["subagents".to_string()]); }); - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - - let subagent = cx.new(|cx| { - Thread::new_subagent( + cx.update(LanguageModelRegistry::test); + let model = Arc::new(FakeLanguageModel::default()); + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let native_agent = NativeAgent::new( + project.clone(), + thread_store, + Templates::new(), + None, + fs, + &mut cx.to_async(), + ) + .await + .unwrap(); + let parent_thread = cx.new(|cx| { + Thread::new( project.clone(), project_context, context_server_registry, Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ) }); - subagent - .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) - .unwrap(); - cx.run_until_parked(); + let mut handles = Vec::new(); + for _ in 0..MAX_PARALLEL_SUBAGENTS { + let handle = cx + .update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "some task".to_string(), + None, + None, + cx, + ) + }) + .expect("Expected to be able to create subagent thread"); + handles.push(handle); + } - subagent.read_with(cx, |thread, _| { - assert!(!thread.is_turn_complete(), "turn should be in progress"); + let result = cx.update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "some task".to_string(), + None, + None, + cx, + ) }); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + format!( + "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.", + MAX_PARALLEL_SUBAGENTS + ) + ); +} - fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey { - provider: LanguageModelProviderName::from("Fake".to_string()), - }); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); +#[gpui::test] +async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) { + init_test(cx); - subagent.read_with(cx, |thread, _| { - assert!( - thread.is_turn_complete(), - "turn should be complete after non-retryable error" - ); - }); -} - -#[gpui::test] -async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); + always_allow_tools(cx); cx.update(|cx| { cx.update_flags(true, vec!["subagents".to_string()]); }); - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize your work".to_string(), - context_low_prompt: "Context low, stop and summarize".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - - let subagent = cx.new(|cx| { - Thread::new_subagent( + cx.update(LanguageModelRegistry::test); + let model = Arc::new(FakeLanguageModel::default()); + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let native_agent = NativeAgent::new( + project.clone(), + thread_store, + Templates::new(), + None, + fs, + &mut cx.to_async(), + ) + .await + .unwrap(); + let parent_thread = cx.new(|cx| { + Thread::new( project.clone(), - project_context.clone(), - context_server_registry.clone(), + project_context, + context_server_registry, Templates::new(), - model.clone(), - subagent_context.clone(), - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ) }); - subagent.update(cx, |thread, _| { - thread.add_tool(EchoTool); - }); - - subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Do some work", cx) + let subagent_handle = cx + .update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "task prompt".to_string(), + Some(Duration::from_millis(10)), + None, + cx, + ) }) - .unwrap(); - cx.run_until_parked(); - - fake_model.send_last_completion_stream_text_chunk("Working on it..."); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); + .expect("Failed to create subagent"); - let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx)); - assert!( - interrupt_result.is_ok(), - "interrupt_for_summary should succeed" - ); + let summary_task = + subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async()); cx.run_until_parked(); - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "should have pending completion for interrupted summary" - ); - - let messages = &pending.last().unwrap().messages; - let user_messages: Vec<_> = messages - .iter() - .filter(|m| m.role == language_model::Role::User) - .collect(); - - let last_user = user_messages.last().unwrap(); - let content_str = last_user.content[0].to_str().unwrap(); - assert!( - content_str.contains("Context low") || content_str.contains("stop and summarize"), - "context_low_prompt should be sent when interrupting: got {:?}", - content_str - ); -} - -#[gpui::test] -async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + { + let messages = model.pending_completions().last().unwrap().messages.clone(); + // Ensure that model received a system prompt + assert_eq!(messages[0].role, Role::System); + // Ensure that model received a task prompt + assert_eq!(messages[1].role, Role::User); + assert_eq!( + messages[1].content, + vec![MessageContent::Text("task prompt".to_string())] + ); + } - let subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context, - context_server_registry, - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ) - }); + model.send_last_completion_stream_text_chunk("Some task response..."); + model.end_last_completion_stream(); - subagent - .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) - .unwrap(); cx.run_until_parked(); - let max_tokens = model.max_token_count(); - let high_usage = language_model::TokenUsage { - input_tokens: (max_tokens as f64 * 0.80) as u64, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }; - - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage)); - fake_model.send_last_completion_stream_text_chunk("Working..."); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); + { + let messages = model.pending_completions().last().unwrap().messages.clone(); + assert_eq!(messages[2].role, Role::Assistant); + assert_eq!( + messages[2].content, + vec![MessageContent::Text("Some task response...".to_string())] + ); + // Ensure that model received a summary prompt + assert_eq!(messages[3].role, Role::User); + assert_eq!( + messages[3].content, + vec![MessageContent::Text("summary prompt".to_string())] + ); + } - let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage()); - assert!(usage.is_some(), "should have token usage after completion"); + model.send_last_completion_stream_text_chunk("Some summary..."); + model.end_last_completion_stream(); - let usage = usage.unwrap(); - let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32); - assert!( - remaining_ratio <= 0.25, - "remaining ratio should be at or below 25% (got {}%), indicating context is low", - remaining_ratio * 100.0 - ); + let result = summary_task.await; + assert_eq!(result.unwrap(), "Some summary...\n"); } #[gpui::test] -async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) { +async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded( + cx: &mut TestAppContext, +) { init_test(cx); + always_allow_tools(cx); + cx.update(|cx| { cx.update_flags(true, vec!["subagents".to_string()]); }); let fs = FakeFs::new(cx.executor()); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + cx.update(LanguageModelRegistry::test); let model = Arc::new(FakeLanguageModel::default()); - - let parent = cx.new(|cx| { - let mut thread = Thread::new( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ); - thread.add_tool(EchoTool); - thread - }); - - #[allow(clippy::arc_with_non_send_sync)] - let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0)); - - let allowed_tools = Some(vec!["nonexistent_tool".to_string()]); - let result = cx.read(|cx| tool.validate_allowed_tools(&allowed_tools, cx)); - - assert!(result.is_err(), "should reject unknown tool"); - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("nonexistent_tool"), - "error should mention the invalid tool name: {}", - err_msg - ); - assert!( - err_msg.contains("do not exist"), - "error should explain the tool does not exist: {}", - err_msg - ); -} - -#[gpui::test] -async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; - let fake_model = model.as_fake(); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let project = thread.read_with(cx, |t, _| t.project.clone()); - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - - let subagent = cx.new(|cx| { - Thread::new_subagent( + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let native_agent = NativeAgent::new( + project.clone(), + thread_store, + Templates::new(), + None, + fs, + &mut cx.to_async(), + ) + .await + .unwrap(); + let parent_thread = cx.new(|cx| { + Thread::new( project.clone(), project_context, context_server_registry, Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ) }); - subagent - .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) - .unwrap(); - cx.run_until_parked(); + let subagent_handle = cx + .update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "task prompt".to_string(), + Some(Duration::from_millis(100)), + None, + cx, + ) + }) + .expect("Failed to create subagent"); + + let summary_task = + subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async()); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); cx.run_until_parked(); - subagent.read_with(cx, |thread, _| { - assert!( - thread.is_turn_complete(), - "turn should complete even with empty response" + { + let messages = model.pending_completions().last().unwrap().messages.clone(); + // Ensure that model received a system prompt + assert_eq!(messages[0].role, Role::System); + // Ensure that model received a task prompt + assert_eq!( + messages[1].content, + vec![MessageContent::Text("task prompt".to_string())] ); - }); -} - -#[gpui::test] -async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) { - init_test(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - - let depth_1_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("root-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let depth_1_subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - model.clone(), - depth_1_context, - std::collections::BTreeMap::new(), - cx, - ) - }); - - depth_1_subagent.read_with(cx, |thread, _| { - assert_eq!(thread.depth(), 1); - assert!(thread.is_subagent()); - }); - - let depth_2_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"), - depth: 2, - summary_prompt: "Summarize depth 2".to_string(), - context_low_prompt: "Context low depth 2".to_string(), - }; + } - let depth_2_subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - model.clone(), - depth_2_context, - std::collections::BTreeMap::new(), - cx, - ) - }); + // Don't complete the initial model stream — let the timeout expire instead. + cx.executor().advance_clock(Duration::from_millis(200)); + cx.run_until_parked(); - depth_2_subagent.read_with(cx, |thread, _| { - assert_eq!(thread.depth(), 2); - assert!(thread.is_subagent()); - }); + // After the timeout fires, the thread is cancelled and context_low_prompt is sent + // instead of the summary_prompt. + { + let messages = model.pending_completions().last().unwrap().messages.clone(); + let last_user_message = messages + .iter() + .rev() + .find(|m| m.role == Role::User) + .unwrap(); + assert_eq!( + last_user_message.content, + vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())] + ); + } - depth_2_subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Nested task", cx) - }) - .unwrap(); - cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("Some context low response..."); + model.end_last_completion_stream(); - let pending = model.as_fake().pending_completions(); - assert!( - !pending.is_empty(), - "depth-2 subagent should be able to submit messages" - ); + let result = summary_task.await; + assert_eq!(result.unwrap(), "Some context low response...\n"); } #[gpui::test] -async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) { +async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) { init_test(cx); + always_allow_tools(cx); cx.update(|cx| { @@ -4750,179 +4479,71 @@ async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) { let fs = FakeFs::new(cx.executor()); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + cx.update(LanguageModelRegistry::test); let model = Arc::new(FakeLanguageModel::default()); - let fake_model = model.as_fake(); - - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), - depth: 1, - summary_prompt: "Summarize what you did".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let subagent = cx.new(|cx| { - let mut thread = Thread::new_subagent( + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let native_agent = NativeAgent::new( + project.clone(), + thread_store, + Templates::new(), + None, + fs, + &mut cx.to_async(), + ) + .await + .unwrap(); + let parent_thread = cx.new(|cx| { + let mut thread = Thread::new( project.clone(), project_context, context_server_registry, Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), + Some(model.clone()), cx, ); - thread.add_tool(EchoTool); + thread.add_tool(ListDirectoryTool::new(project.clone()), None); + thread.add_tool(GrepTool::new(project.clone()), None); thread }); - subagent.read_with(cx, |thread, _| { - assert!( - thread.has_registered_tool("echo"), - "subagent should have echo tool" - ); - }); - - subagent - .update(cx, |thread, cx| { - thread.submit_user_message("Use the echo tool to echo 'hello world'", cx) + let _subagent_handle = cx + .update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "task prompt".to_string(), + Some(Duration::from_millis(10)), + None, + cx, + ) }) - .unwrap(); - cx.run_until_parked(); + .expect("Failed to create subagent"); - let tool_use = LanguageModelToolUse { - id: "tool_call_1".into(), - name: EchoTool::NAME.into(), - raw_input: json!({"text": "hello world"}).to_string(), - input: json!({"text": "hello world"}), - is_input_complete: true, - thought_signature: None, - }; - fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); - fake_model.end_last_completion_stream(); cx.run_until_parked(); - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "should have pending completion after tool use" - ); - - let last_completion = pending.last().unwrap(); - let has_tool_result = last_completion.messages.iter().any(|m| { - m.content - .iter() - .any(|c| matches!(c, MessageContent::ToolResult(_))) - }); - assert!( - has_tool_result, - "tool result should be in the messages sent back to the model" - ); + let tools = model + .pending_completions() + .last() + .unwrap() + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect::>(); + assert_eq!(tools.len(), 2); + assert!(tools.contains(&"grep".to_string())); + assert!(tools.contains(&"list_directory".to_string())); } #[gpui::test] -async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) { +async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) { init_test(cx); - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - - let parent = cx.new(|cx| { - Thread::new( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - - let mut subagents = Vec::new(); - for i in 0..MAX_PARALLEL_SUBAGENTS { - let subagent_context = SubagentContext { - parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), - tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)), - depth: 1, - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - }; - - let subagent = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - model.clone(), - subagent_context, - std::collections::BTreeMap::new(), - cx, - ) - }); - - parent.update(cx, |thread, _cx| { - thread.register_running_subagent(subagent.downgrade()); - }); - subagents.push(subagent); - } - - parent.read_with(cx, |thread, _| { - assert_eq!( - thread.running_subagent_count(), - MAX_PARALLEL_SUBAGENTS, - "should have MAX_PARALLEL_SUBAGENTS registered" - ); - }); - - #[allow(clippy::arc_with_non_send_sync)] - let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0)); - - let (event_stream, _rx) = crate::ToolCallEventStream::test(); - - let result = cx.update(|cx| { - tool.run( - SubagentToolInput { - label: "Test".to_string(), - task_prompt: "Do something".to_string(), - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - timeout_ms: None, - allowed_tools: None, - }, - event_stream, - cx, - ) - }); - - let err = result.await.unwrap_err(); - assert!( - err.to_string().contains("Maximum parallel subagents"), - "should reject when max parallel subagents reached: {}", - err - ); - - drop(subagents); -} - -#[gpui::test] -async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) { - init_test(cx); always_allow_tools(cx); cx.update(|cx| { @@ -4931,105 +4552,63 @@ async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) { let fs = FakeFs::new(cx.executor()); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let project_context = cx.new(|_cx| ProjectContext::default()); let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + cx.update(LanguageModelRegistry::test); let model = Arc::new(FakeLanguageModel::default()); - let fake_model = model.as_fake(); - - let parent = cx.new(|cx| { + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let native_agent = NativeAgent::new( + project.clone(), + thread_store, + Templates::new(), + None, + fs, + &mut cx.to_async(), + ) + .await + .unwrap(); + let parent_thread = cx.new(|cx| { let mut thread = Thread::new( project.clone(), - project_context.clone(), - context_server_registry.clone(), + project_context, + context_server_registry, Templates::new(), Some(model.clone()), cx, ); - thread.add_tool(EchoTool); + thread.add_tool(ListDirectoryTool::new(project.clone()), None); + thread.add_tool(GrepTool::new(project.clone()), None); thread }); - #[allow(clippy::arc_with_non_send_sync)] - let tool = Arc::new(SubagentTool::new(parent.downgrade(), 0)); - - let (event_stream, _rx) = crate::ToolCallEventStream::test(); - - let task = cx.update(|cx| { - tool.run( - SubagentToolInput { - label: "Research task".to_string(), - task_prompt: "Find all TODOs in the codebase".to_string(), - summary_prompt: "Summarize what you found".to_string(), - context_low_prompt: "Context low, wrap up".to_string(), - timeout_ms: None, - allowed_tools: None, - }, - event_stream, - cx, - ) - }); - - cx.run_until_parked(); - - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "subagent should have started and sent a completion request" - ); - - let first_completion = &pending[0]; - let has_task_prompt = first_completion.messages.iter().any(|m| { - m.role == language_model::Role::User - && m.content - .iter() - .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false)) - }); - assert!(has_task_prompt, "task prompt should be sent to subagent"); - - fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase."); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "should have pending completion for summary request" - ); - - let last_completion = pending.last().unwrap(); - let has_summary_prompt = last_completion.messages.iter().any(|m| { - m.role == language_model::Role::User - && m.content.iter().any(|c| { - c.to_str() - .map(|s| s.contains("Summarize") || s.contains("summarize")) - .unwrap_or(false) - }) - }); - assert!( - has_summary_prompt, - "summary prompt should be sent after task completion" - ); + let _subagent_handle = cx + .update(|cx| { + NativeThreadEnvironment::create_subagent_thread( + native_agent.downgrade(), + parent_thread.clone(), + "some title".to_string(), + "task prompt".to_string(), + Some(Duration::from_millis(10)), + Some(vec!["grep".to_string()]), + cx, + ) + }) + .expect("Failed to create subagent"); - fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files."); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); cx.run_until_parked(); - let result = task.await; - assert!(result.is_ok(), "subagent tool should complete successfully"); - - let summary = result.unwrap(); - assert!( - summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"), - "summary should contain subagent's response: {}", - summary - ); + let tools = model + .pending_completions() + .last() + .unwrap() + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect::>(); + assert_eq!(tools, vec!["grep"]); } #[gpui::test] @@ -5620,7 +5199,7 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) { // Add a tool so we can simulate tool calls thread.update(cx, |thread, _cx| { - thread.add_tool(EchoTool); + thread.add_tool(EchoTool, None); }); // Start a turn by sending a message diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 24d661dbf038a9a12ed528c7f52e72aa72387c3e..6acc71b3dcb7ec582512b71557c6bd7d077c3f62 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -63,22 +63,13 @@ pub const MAX_SUBAGENT_DEPTH: u8 = 4; pub const MAX_PARALLEL_SUBAGENTS: usize = 8; /// Context passed to a subagent thread for lifecycle management -#[derive(Clone)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SubagentContext { /// ID of the parent thread pub parent_thread_id: acp::SessionId, - /// ID of the tool call that spawned this subagent - pub tool_use_id: LanguageModelToolUseId, - /// Current depth level (0 = root agent, 1 = first-level subagent, etc.) pub depth: u8, - - /// Prompt to send when subagent completes successfully - pub summary_prompt: String, - - /// Prompt to send when context is running low (≤25% remaining) - pub context_low_prompt: String, } /// The ID of the user prompt that initiated a request. @@ -179,7 +170,7 @@ pub enum UserMessageContent { impl UserMessage { pub fn to_markdown(&self) -> String { - let mut markdown = String::from("## User\n\n"); + let mut markdown = String::new(); for content in &self.content { match content { @@ -431,7 +422,7 @@ fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive>) -> impl AgentMessage { pub fn to_markdown(&self) -> String { - let mut markdown = String::from("## Assistant\n\n"); + let mut markdown = String::new(); for content in &self.content { match content { @@ -587,6 +578,11 @@ pub trait TerminalHandle { fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result; } +pub trait SubagentHandle { + fn id(&self) -> acp::SessionId; + fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task>; +} + pub trait ThreadEnvironment { fn create_terminal( &self, @@ -595,6 +591,16 @@ pub trait ThreadEnvironment { output_byte_limit: Option, cx: &mut AsyncApp, ) -> Task>>; + + fn create_subagent( + &self, + parent_thread: Entity, + label: String, + initial_prompt: String, + timeout: Option, + allowed_tools: Option>, + cx: &mut App, + ) -> Result>; } #[derive(Debug)] @@ -605,6 +611,7 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + SubagentSpawned(acp::SessionId), Retry(acp_thread::RetryStatus), Stop(acp::StopReason), } @@ -827,6 +834,27 @@ impl Thread { .embedded_context(true) } + pub fn new_subagent(parent_thread: &Entity, cx: &mut Context) -> Self { + let project = parent_thread.read(cx).project.clone(); + let project_context = parent_thread.read(cx).project_context.clone(); + let context_server_registry = parent_thread.read(cx).context_server_registry.clone(); + let templates = parent_thread.read(cx).templates.clone(); + let model = parent_thread.read(cx).model().cloned(); + let mut thread = Self::new( + project, + project_context, + context_server_registry, + templates, + model, + cx, + ); + thread.subagent_context = Some(SubagentContext { + parent_thread_id: parent_thread.read(cx).id().clone(), + depth: parent_thread.read(cx).depth() + 1, + }); + thread + } + pub fn new( project: Entity, project_context: Entity, @@ -889,78 +917,6 @@ impl Thread { } } - pub fn new_subagent( - project: Entity, - project_context: Entity, - context_server_registry: Entity, - templates: Arc, - model: Arc, - subagent_context: SubagentContext, - parent_tools: BTreeMap>, - cx: &mut Context, - ) -> Self { - let settings = AgentSettings::get_global(cx); - let profile_id = settings.default_profile.clone(); - let enable_thinking = settings - .default_model - .as_ref() - .is_some_and(|model| model.enable_thinking); - let thinking_effort = settings - .default_model - .as_ref() - .and_then(|model| model.effort.clone()); - let action_log = cx.new(|_cx| ActionLog::new(project.clone())); - let (prompt_capabilities_tx, prompt_capabilities_rx) = - watch::channel(Self::prompt_capabilities(Some(model.as_ref()))); - - // Rebind tools that hold thread references to use this subagent's thread - // instead of the parent's thread. This is critical for tools like EditFileTool - // that make model requests using the thread's ID. - let weak_self = cx.weak_entity(); - let tools: BTreeMap> = parent_tools - .into_iter() - .map(|(name, tool)| { - let rebound = tool.rebind_thread(weak_self.clone()).unwrap_or(tool); - (name, rebound) - }) - .collect(); - - Self { - id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()), - prompt_id: PromptId::new(), - updated_at: Utc::now(), - title: None, - pending_title_generation: None, - pending_summary_generation: None, - summary: None, - messages: Vec::new(), - user_store: project.read(cx).user_store(), - running_turn: None, - has_queued_message: false, - pending_message: None, - tools, - request_token_usage: HashMap::default(), - cumulative_token_usage: TokenUsage::default(), - initial_project_snapshot: Task::ready(None).shared(), - context_server_registry, - profile_id, - project_context, - templates, - model: Some(model), - summarization_model: None, - thinking_enabled: enable_thinking, - thinking_effort, - prompt_capabilities_tx, - prompt_capabilities_rx, - project, - action_log, - file_read_times: HashMap::default(), - imported: false, - subagent_context: Some(subagent_context), - running_subagents: Vec::new(), - } - } - pub fn id(&self) -> &acp::SessionId { &self.id } @@ -1077,6 +1033,7 @@ impl Thread { }), ) .raw_output(output), + None, ); } @@ -1167,7 +1124,7 @@ impl Thread { prompt_capabilities_rx, file_read_times: HashMap::default(), imported: db_thread.imported, - subagent_context: None, + subagent_context: db_thread.subagent_context, running_subagents: Vec::new(), } } @@ -1188,6 +1145,7 @@ impl Thread { }), profile: Some(self.profile_id.clone()), imported: self.imported, + subagent_context: self.subagent_context.clone(), }; cx.background_spawn(async move { @@ -1286,53 +1244,106 @@ impl Thread { pub fn add_default_tools( &mut self, + allowed_tool_names: Option>, environment: Rc, cx: &mut Context, ) { let language_registry = self.project.read(cx).languages().clone(); - self.add_tool(CopyPathTool::new(self.project.clone())); - self.add_tool(CreateDirectoryTool::new(self.project.clone())); - self.add_tool(DeletePathTool::new( - self.project.clone(), - self.action_log.clone(), - )); - self.add_tool(DiagnosticsTool::new(self.project.clone())); - self.add_tool(EditFileTool::new( - self.project.clone(), - cx.weak_entity(), - language_registry.clone(), - Templates::new(), - )); - self.add_tool(StreamingEditFileTool::new( - self.project.clone(), - cx.weak_entity(), - language_registry, - Templates::new(), - )); - self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); - self.add_tool(FindPathTool::new(self.project.clone())); - self.add_tool(GrepTool::new(self.project.clone())); - self.add_tool(ListDirectoryTool::new(self.project.clone())); - self.add_tool(MovePathTool::new(self.project.clone())); - self.add_tool(NowTool); - self.add_tool(OpenTool::new(self.project.clone())); - self.add_tool(ReadFileTool::new( - cx.weak_entity(), - self.project.clone(), - self.action_log.clone(), - )); - self.add_tool(SaveFileTool::new(self.project.clone())); - self.add_tool(RestoreFileFromDiskTool::new(self.project.clone())); - self.add_tool(TerminalTool::new(self.project.clone(), environment)); - self.add_tool(ThinkingTool); - self.add_tool(WebSearchTool); + self.add_tool( + CopyPathTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + CreateDirectoryTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + DeletePathTool::new(self.project.clone(), self.action_log.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + DiagnosticsTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + EditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry.clone(), + Templates::new(), + ), + allowed_tool_names.as_ref(), + ); + self.add_tool( + StreamingEditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry, + Templates::new(), + ), + allowed_tool_names.as_ref(), + ); + self.add_tool( + FetchTool::new(self.project.read(cx).client().http_client()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + FindPathTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + GrepTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + ListDirectoryTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + MovePathTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool(NowTool, allowed_tool_names.as_ref()); + self.add_tool( + OpenTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + ReadFileTool::new( + cx.weak_entity(), + self.project.clone(), + self.action_log.clone(), + ), + allowed_tool_names.as_ref(), + ); + self.add_tool( + SaveFileTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + RestoreFileFromDiskTool::new(self.project.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool( + TerminalTool::new(self.project.clone(), environment.clone()), + allowed_tool_names.as_ref(), + ); + self.add_tool(ThinkingTool, allowed_tool_names.as_ref()); + self.add_tool(WebSearchTool, allowed_tool_names.as_ref()); if cx.has_flag::() && self.depth() < MAX_SUBAGENT_DEPTH { - self.add_tool(SubagentTool::new(cx.weak_entity(), self.depth())); + self.add_tool( + SubagentTool::new(cx.weak_entity(), environment), + allowed_tool_names.as_ref(), + ); } } - pub fn add_tool(&mut self, tool: T) { + pub fn add_tool(&mut self, tool: T, allowed_tool_names: Option<&Vec<&str>>) { + if allowed_tool_names.is_some_and(|tool_names| !tool_names.contains(&T::NAME)) { + return; + } + debug_assert!( !self.tools.contains_key(T::NAME), "Duplicate tool name: {}", @@ -1345,10 +1356,6 @@ impl Thread { self.tools.remove(name).is_some() } - pub fn restrict_tools(&mut self, allowed: &collections::HashSet) { - self.tools.retain(|name, _| allowed.contains(name)); - } - pub fn profile(&self) -> &AgentProfileId { &self.profile_id } @@ -1778,6 +1785,7 @@ impl Thread { acp::ToolCallStatus::Completed }) .raw_output(tool_result.output.clone()), + None, ); this.update(cx, |this, _cx| { this.pending_message() @@ -2048,6 +2056,7 @@ impl Thread { .title(title.as_str()) .kind(kind) .raw_input(tool_use.input.clone()), + None, ); } @@ -2472,13 +2481,19 @@ impl Thread { self.tools.keys().cloned().collect() } - pub fn register_running_subagent(&mut self, subagent: WeakEntity) { + pub(crate) fn register_running_subagent(&mut self, subagent: WeakEntity) { self.running_subagents.push(subagent); } - pub fn unregister_running_subagent(&mut self, subagent: &WeakEntity) { - self.running_subagents - .retain(|s| s.entity_id() != subagent.entity_id()); + pub(crate) fn unregister_running_subagent( + &mut self, + subagent_session_id: &acp::SessionId, + cx: &App, + ) { + self.running_subagents.retain(|s| { + s.upgrade() + .map_or(false, |s| s.read(cx).id() != subagent_session_id) + }); } pub fn running_subagent_count(&self) -> usize { @@ -2492,51 +2507,23 @@ impl Thread { self.subagent_context.is_some() } - pub fn depth(&self) -> u8 { - self.subagent_context.as_ref().map(|c| c.depth).unwrap_or(0) - } - - pub fn is_turn_complete(&self) -> bool { - self.running_turn.is_none() + pub fn parent_thread_id(&self) -> Option { + self.subagent_context + .as_ref() + .map(|c| c.parent_thread_id.clone()) } - pub fn submit_user_message( - &mut self, - content: impl Into, - cx: &mut Context, - ) -> Result>> { - let content = content.into(); - self.messages.push(Message::User(UserMessage { - id: UserMessageId::new(), - content: vec![UserMessageContent::Text(content)], - })); - cx.notify(); - self.send_existing(cx) + pub fn depth(&self) -> u8 { + self.subagent_context.as_ref().map(|c| c.depth).unwrap_or(0) } - pub fn interrupt_for_summary( - &mut self, - cx: &mut Context, - ) -> Result>> { - let context = self - .subagent_context - .as_ref() - .context("Not a subagent thread")?; - let prompt = context.context_low_prompt.clone(); - self.cancel(cx).detach(); - self.submit_user_message(prompt, cx) + #[cfg(any(test, feature = "test-support"))] + pub fn set_subagent_context(&mut self, context: SubagentContext) { + self.subagent_context = Some(context); } - pub fn request_final_summary( - &mut self, - cx: &mut Context, - ) -> Result>> { - let context = self - .subagent_context - .as_ref() - .context("Not a subagent thread")?; - let prompt = context.summary_prompt.clone(); - self.submit_user_message(prompt, cx) + pub fn is_turn_complete(&self) -> bool { + self.running_turn.is_none() } fn build_request_messages( @@ -2584,11 +2571,16 @@ impl Thread { if ix > 0 { markdown.push('\n'); } + match message { + Message::User(_) => markdown.push_str("## User\n\n"), + Message::Agent(_) => markdown.push_str("## Assistant\n\n"), + Message::Resume => {} + } markdown.push_str(&message.to_markdown()); } if let Some(message) = self.pending_message.as_ref() { - markdown.push('\n'); + markdown.push_str("\n## Assistant\n\n"); markdown.push_str(&message.to_markdown()); } @@ -2795,15 +2787,6 @@ where fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } - - /// Create a new instance of this tool bound to a different thread. - /// This is used when creating subagents, so that tools like EditFileTool - /// that hold a thread reference will use the subagent's thread instead - /// of the parent's thread. - /// Returns None if the tool doesn't need rebinding (most tools). - fn rebind_thread(&self, _new_thread: WeakEntity) -> Option> { - None - } } pub struct Erased(T); @@ -2835,14 +2818,6 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Result<()>; - /// Create a new instance of this tool bound to a different thread. - /// This is used when creating subagents, so that tools like EditFileTool - /// that hold a thread reference will use the subagent's thread instead - /// of the parent's thread. - /// Returns None if the tool doesn't need rebinding (most tools). - fn rebind_thread(&self, _new_thread: WeakEntity) -> Option> { - None - } } impl AnyAgentTool for Erased> @@ -2906,10 +2881,6 @@ where let output = serde_json::from_value(output)?; self.0.replay(input, output, event_stream, cx) } - - fn rebind_thread(&self, new_thread: WeakEntity) -> Option> { - self.0.rebind_thread(new_thread) - } } #[derive(Clone)] @@ -2970,10 +2941,13 @@ impl ThreadEventStream { &self, tool_use_id: &LanguageModelToolUseId, fields: acp::ToolCallUpdateFields, + meta: Option, ) { self.0 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( - acp::ToolCallUpdate::new(tool_use_id.to_string(), fields).into(), + acp::ToolCallUpdate::new(tool_use_id.to_string(), fields) + .meta(meta) + .into(), ))) .ok(); } @@ -3081,7 +3055,16 @@ impl ToolCallEventStream { pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) { self.stream - .update_tool_call_fields(&self.tool_use_id, fields); + .update_tool_call_fields(&self.tool_use_id, fields, None); + } + + pub fn update_fields_with_meta( + &self, + fields: acp::ToolCallUpdateFields, + meta: Option, + ) { + self.stream + .update_tool_call_fields(&self.tool_use_id, fields, meta); } pub fn update_diff(&self, diff: Entity) { @@ -3097,16 +3080,10 @@ impl ToolCallEventStream { .ok(); } - pub fn update_subagent_thread(&self, thread: Entity) { + pub fn subagent_spawned(&self, id: acp::SessionId) { self.stream .0 - .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( - acp_thread::ToolCallUpdateSubagentThread { - id: acp::ToolCallId::new(self.tool_use_id.to_string()), - thread, - } - .into(), - ))) + .unbounded_send(Ok(ThreadEvent::SubagentSpawned(id))) .ok(); } @@ -3421,6 +3398,12 @@ impl From<&str> for UserMessageContent { } } +impl From for UserMessageContent { + fn from(text: String) -> Self { + Self::Text(text) + } +} + impl UserMessageContent { pub fn from_content_block(value: acp::ContentBlock, path_style: PathStyle) -> Self { match value { diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cd7c03de546f2fff2e603c59739816640d24e9f5..83548b69d126462fac1766df5d0ec5bb931be493 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -114,7 +114,12 @@ impl ThreadStore { let database_connection = ThreadsDatabase::connect(cx); cx.spawn(async move |this, cx| { let database = database_connection.await.map_err(|err| anyhow!(err))?; - let threads = database.list_threads().await?; + let threads = database + .list_threads() + .await? + .into_iter() + .filter(|thread| thread.parent_session_id.is_none()) + .collect::>(); this.update(cx, |this, cx| { this.threads = threads; cx.notify(); @@ -156,6 +161,7 @@ mod tests { model: None, profile: None, imported: false, + subagent_context: None, } } diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index d5a6a51433cbfc66c1c735bb59afd4a6ca072d5e..9fdfd8c726bd2073cdf26db01da5d15b0768ba48 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -146,15 +146,6 @@ impl EditFileTool { } } - pub fn with_thread(&self, new_thread: WeakEntity) -> Self { - Self { - project: self.project.clone(), - thread: new_thread, - language_registry: self.language_registry.clone(), - templates: self.templates.clone(), - } - } - fn authorize( &self, input: &EditFileToolInput, @@ -665,13 +656,6 @@ impl AgentTool for EditFileTool { })); Ok(()) } - - fn rebind_thread( - &self, - new_thread: gpui::WeakEntity, - ) -> Option> { - Some(self.with_thread(new_thread).erase()) - } } /// Validate that the file path is valid, meaning: diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index f6395742331ab17947c6885e186aab31bb8c826c..e0f1df0ca8662712a5e6967629740efd3cc89677 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -65,14 +65,6 @@ impl ReadFileTool { action_log, } } - - pub fn with_thread(&self, new_thread: WeakEntity) -> Self { - Self { - thread: new_thread, - project: self.project.clone(), - action_log: self.action_log.clone(), - } - } } impl AgentTool for ReadFileTool { @@ -314,13 +306,6 @@ impl AgentTool for ReadFileTool { result }) } - - fn rebind_thread( - &self, - new_thread: WeakEntity, - ) -> Option> { - Some(self.with_thread(new_thread).erase()) - } } #[cfg(test)] diff --git a/crates/agent/src/tools/subagent_tool.rs b/crates/agent/src/tools/subagent_tool.rs index ec7fa937b7e9ec3168f107e3a7bb50e8cf948da4..14edb0113724520dd5057e33f909cddb6182666c 100644 --- a/crates/agent/src/tools/subagent_tool.rs +++ b/crates/agent/src/tools/subagent_tool.rs @@ -1,31 +1,15 @@ -use acp_thread::{AcpThread, AgentConnection, UserMessageId}; -use action_log::ActionLog; +use acp_thread::SUBAGENT_SESSION_ID_META_KEY; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; -use collections::{BTreeMap, HashSet}; -use futures::{FutureExt, channel::mpsc}; -use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity}; -use language_model::LanguageModelToolUseId; -use project::Project; +use futures::FutureExt as _; +use gpui::{App, Entity, SharedString, Task, WeakEntity}; +use language_model::LanguageModelToolResultContent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use smol::stream::StreamExt; -use std::any::Any; -use std::path::Path; -use std::rc::Rc; use std::sync::Arc; -use std::time::Duration; -use util::ResultExt; -use watch; +use std::{rc::Rc, time::Duration}; -use crate::{ - AgentTool, AnyAgentTool, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext, Thread, - ThreadEvent, ToolCallAuthorization, ToolCallEventStream, -}; - -/// When a subagent's remaining context window falls below this fraction (25%), -/// the "context running out" prompt is sent to encourage the subagent to wrap up. -const CONTEXT_LOW_THRESHOLD: f32 = 0.25; +use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream}; /// Spawns a subagent with its own context window to perform a delegated task. /// @@ -64,13 +48,6 @@ pub struct SubagentToolInput { /// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons." pub summary_prompt: String, - /// The prompt sent if the subagent is running low on context (25% remaining). - /// Should instruct it to stop and summarize progress so far, plus what's left undone. - /// - /// Example: "Context is running low. Stop and summarize your progress so far, - /// and list what remains to be investigated." - pub context_low_prompt: String, - /// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is /// asked to summarize and return. No timeout by default. #[serde(default)] @@ -83,36 +60,47 @@ pub struct SubagentToolInput { pub allowed_tools: Option>, } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SubagentToolOutput { + pub subagent_session_id: acp::SessionId, + pub summary: String, +} + +impl From for LanguageModelToolResultContent { + fn from(output: SubagentToolOutput) -> Self { + output.summary.into() + } +} + /// Tool that spawns a subagent thread to work on a task. pub struct SubagentTool { parent_thread: WeakEntity, - current_depth: u8, + environment: Rc, } impl SubagentTool { - pub fn new(parent_thread: WeakEntity, current_depth: u8) -> Self { + pub fn new(parent_thread: WeakEntity, environment: Rc) -> Self { Self { parent_thread, - current_depth, + environment, } } - pub fn validate_allowed_tools( - &self, + fn validate_allowed_tools( allowed_tools: &Option>, + parent_thread: &Entity, cx: &App, ) -> Result<()> { let Some(allowed_tools) = allowed_tools else { return Ok(()); }; - let invalid_tools: Vec<_> = self.parent_thread.read_with(cx, |thread, _cx| { - allowed_tools - .iter() - .filter(|tool| !thread.tools.contains_key(tool.as_str())) - .map(|s| format!("'{s}'")) - .collect() - })?; + let thread = parent_thread.read(cx); + let invalid_tools: Vec<_> = allowed_tools + .iter() + .filter(|tool| !thread.tools.contains_key(tool.as_str())) + .map(|s| format!("'{s}'")) + .collect::>(); if !invalid_tools.is_empty() { return Err(anyhow!( @@ -127,9 +115,9 @@ impl SubagentTool { impl AgentTool for SubagentTool { type Input = SubagentToolInput; - type Output = String; + type Output = SubagentToolOutput; - const NAME: &'static str = acp_thread::SUBAGENT_TOOL_NAME; + const NAME: &'static str = "subagent"; fn kind() -> acp::ToolKind { acp::ToolKind::Other @@ -150,428 +138,156 @@ impl AgentTool for SubagentTool { input: Self::Input, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> { - if self.current_depth >= MAX_SUBAGENT_DEPTH { - return Task::ready(Err(anyhow!( - "Maximum subagent depth ({}) reached", - MAX_SUBAGENT_DEPTH - ))); - } + ) -> Task> { + let Some(parent_thread_entity) = self.parent_thread.upgrade() else { + return Task::ready(Err(anyhow!("Parent thread no longer exists"))); + }; - if let Err(e) = self.validate_allowed_tools(&input.allowed_tools, cx) { + if let Err(e) = + Self::validate_allowed_tools(&input.allowed_tools, &parent_thread_entity, cx) + { return Task::ready(Err(e)); } - let Some(parent_thread_entity) = self.parent_thread.upgrade() else { - return Task::ready(Err(anyhow!( - "Parent thread no longer exists (subagent depth={})", - self.current_depth + 1 - ))); + let subagent = match self.environment.create_subagent( + parent_thread_entity, + input.label, + input.task_prompt, + input.timeout_ms.map(|ms| Duration::from_millis(ms)), + input.allowed_tools, + cx, + ) { + Ok(subagent) => subagent, + Err(err) => return Task::ready(Err(err)), }; - let parent_thread = parent_thread_entity.read(cx); - - let running_count = parent_thread.running_subagent_count(); - if running_count >= MAX_PARALLEL_SUBAGENTS { - return Task::ready(Err(anyhow!( - "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.", - MAX_PARALLEL_SUBAGENTS - ))); - } - let parent_model = parent_thread.model().cloned(); - let Some(model) = parent_model else { - return Task::ready(Err(anyhow!("No model configured"))); - }; + let subagent_session_id = subagent.id(); - let parent_thread_id = parent_thread.id().clone(); - let project = parent_thread.project.clone(); - let project_context = parent_thread.project_context().clone(); - let context_server_registry = parent_thread.context_server_registry.clone(); - let templates = parent_thread.templates.clone(); - let parent_tools = parent_thread.tools.clone(); - let current_depth = self.current_depth; - let parent_thread_weak = self.parent_thread.clone(); + event_stream.subagent_spawned(subagent_session_id.clone()); + let meta = acp::Meta::from_iter([( + SUBAGENT_SESSION_ID_META_KEY.into(), + subagent_session_id.to_string().into(), + )]); + event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); cx.spawn(async move |cx| { - let subagent_context = SubagentContext { - parent_thread_id: parent_thread_id.clone(), - tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()), - depth: current_depth + 1, - summary_prompt: input.summary_prompt.clone(), - context_low_prompt: input.context_low_prompt.clone(), - }; - - // Determine which tools this subagent gets - let subagent_tools: BTreeMap> = - if let Some(ref allowed) = input.allowed_tools { - let allowed_set: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect(); - parent_tools - .iter() - .filter(|(name, _)| allowed_set.contains(name.as_ref())) - .map(|(name, tool)| (name.clone(), tool.clone())) - .collect() - } else { - parent_tools.clone() - }; - - let subagent_thread: Entity = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - templates.clone(), - model.clone(), - subagent_context, - subagent_tools, - cx, - ) - }); - - let subagent_weak = subagent_thread.downgrade(); - - let acp_thread: Entity = cx.new(|cx| { - let session_id = subagent_thread.read(cx).id().clone(); - let action_log: Entity = cx.new(|_| ActionLog::new(project.clone())); - let connection: Rc = Rc::new(SubagentDisplayConnection); - AcpThread::new( - &input.label, - connection, - project.clone(), - action_log, - session_id, - watch::Receiver::constant(acp::PromptCapabilities::new()), - cx, - ) - }); - - event_stream.update_subagent_thread(acp_thread.clone()); - - let mut user_stop_rx: watch::Receiver = - acp_thread.update(cx, |thread, _| thread.user_stop_receiver()); - - if let Some(parent) = parent_thread_weak.upgrade() { - parent.update(cx, |thread, _cx| { - thread.register_running_subagent(subagent_weak.clone()); - }); - } + let summary_task = subagent.wait_for_summary(input.summary_prompt, cx); - // Helper to wait for user stop signal on the subagent card - let wait_for_user_stop = async { - loop { - if *user_stop_rx.borrow() { - return; - } - if user_stop_rx.changed().await.is_err() { - std::future::pending::<()>().await; - } - } - }; - - // Run the subagent, handling cancellation from both: - // 1. Parent turn cancellation (event_stream.cancelled_by_user) - // 2. Direct user stop on subagent card (user_stop_rx) - let result = futures::select! { - result = run_subagent( - &subagent_thread, - &acp_thread, - input.task_prompt, - input.timeout_ms, - cx, - ).fuse() => result, + futures::select_biased! { + summary = summary_task.fuse() => summary.map(|summary| SubagentToolOutput { + summary, + subagent_session_id, + }), _ = event_stream.cancelled_by_user().fuse() => { - let _ = subagent_thread.update(cx, |thread, cx| { - thread.cancel(cx).detach(); - }); - Err(anyhow!("Subagent cancelled by user")) - } - _ = wait_for_user_stop.fuse() => { - let _ = subagent_thread.update(cx, |thread, cx| { - thread.cancel(cx).detach(); - }); - Err(anyhow!("Subagent stopped by user")) + Err(anyhow!("Subagent was cancelled by user")) } - }; - - if let Some(parent) = parent_thread_weak.upgrade() { - let _ = parent.update(cx, |thread, _cx| { - thread.unregister_running_subagent(&subagent_weak); - }); } - - result }) } -} - -async fn run_subagent( - subagent_thread: &Entity, - acp_thread: &Entity, - task_prompt: String, - timeout_ms: Option, - cx: &mut AsyncApp, -) -> Result { - let mut events_rx = - subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?; - - let acp_thread_weak = acp_thread.downgrade(); - - let timed_out = if let Some(timeout) = timeout_ms { - forward_events_with_timeout( - &mut events_rx, - &acp_thread_weak, - Duration::from_millis(timeout), - cx, - ) - .await - } else { - forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await; - false - }; - - let should_interrupt = - timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx); - - if should_interrupt { - let mut summary_rx = - subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?; - forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await; - } else { - let mut summary_rx = - subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?; - forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await; - } - - Ok(extract_last_message(subagent_thread, cx)) -} -async fn forward_events_until_stop( - events_rx: &mut mpsc::UnboundedReceiver>, - acp_thread: &WeakEntity, - cx: &mut AsyncApp, -) { - while let Some(event) = events_rx.next().await { - match event { - Ok(ThreadEvent::Stop(_)) => break, - Ok(event) => { - forward_event_to_acp_thread(event, acp_thread, cx); - } - Err(_) => break, - } - } -} - -async fn forward_events_with_timeout( - events_rx: &mut mpsc::UnboundedReceiver>, - acp_thread: &WeakEntity, - timeout: Duration, - cx: &mut AsyncApp, -) -> bool { - use futures::future::{self, Either}; - - let deadline = std::time::Instant::now() + timeout; - - loop { - let remaining = deadline.saturating_duration_since(std::time::Instant::now()); - if remaining.is_zero() { - return true; - } - - let timeout_future = cx.background_executor().timer(remaining); - let event_future = events_rx.next(); - - match future::select(event_future, timeout_future).await { - Either::Left((event, _)) => match event { - Some(Ok(ThreadEvent::Stop(_))) => return false, - Some(Ok(event)) => { - forward_event_to_acp_thread(event, acp_thread, cx); - } - Some(Err(_)) => return false, - None => return false, - }, - Either::Right((_, _)) => return true, - } - } -} - -fn forward_event_to_acp_thread( - event: ThreadEvent, - acp_thread: &WeakEntity, - cx: &mut AsyncApp, -) { - match event { - ThreadEvent::UserMessage(message) => { - acp_thread - .update(cx, |thread, cx| { - for content in message.content { - thread.push_user_content_block( - Some(message.id.clone()), - content.into(), - cx, - ); - } - }) - .log_err(); - } - ThreadEvent::AgentText(text) => { - acp_thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block(text.into(), false, cx) - }) - .log_err(); - } - ThreadEvent::AgentThinking(text) => { - acp_thread - .update(cx, |thread, cx| { - thread.push_assistant_content_block(text.into(), true, cx) - }) - .log_err(); - } - ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { - tool_call, - options, - response, - .. - }) => { - let outcome_task = acp_thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, options, true, cx) - }); - if let Ok(Ok(task)) = outcome_task { - cx.background_spawn(async move { - if let acp::RequestPermissionOutcome::Selected( - acp::SelectedPermissionOutcome { option_id, .. }, - ) = task.await - { - response.send(option_id).ok(); - } - }) - .detach(); - } - } - ThreadEvent::ToolCall(tool_call) => { - acp_thread - .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx)) - .log_err(); - } - ThreadEvent::ToolCallUpdate(update) => { - acp_thread - .update(cx, |thread, cx| thread.update_tool_call(update, cx)) - .log_err(); - } - ThreadEvent::Retry(status) => { - acp_thread - .update(cx, |thread, cx| thread.update_retry_status(status, cx)) - .log_err(); - } - ThreadEvent::Stop(_) => {} + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + event_stream.subagent_spawned(output.subagent_session_id.clone()); + let meta = acp::Meta::from_iter([( + SUBAGENT_SESSION_ID_META_KEY.into(), + output.subagent_session_id.to_string().into(), + )]); + event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); + Ok(()) } } -fn check_context_low(thread: &Entity, threshold: f32, cx: &mut AsyncApp) -> bool { - thread.read_with(cx, |thread, _| { - if let Some(usage) = thread.latest_token_usage() { - let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32); - remaining_ratio <= threshold - } else { - false - } - }) -} - -fn extract_last_message(thread: &Entity, cx: &mut AsyncApp) -> String { - thread.read_with(cx, |thread, _| { - thread - .last_message() - .map(|m| m.to_markdown()) - .unwrap_or_else(|| "No response from subagent".to_string()) - }) -} - #[cfg(test)] mod tests { use super::*; - use language_model::LanguageModelToolSchemaFormat; - - #[test] - fn test_subagent_tool_input_json_schema_is_valid() { - let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema); - let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON"); - - assert!( - schema_json.get("properties").is_some(), - "schema should have properties" - ); - let properties = schema_json.get("properties").unwrap(); - - assert!(properties.get("label").is_some(), "should have label field"); - assert!( - properties.get("task_prompt").is_some(), - "should have task_prompt field" - ); - assert!( - properties.get("summary_prompt").is_some(), - "should have summary_prompt field" - ); - assert!( - properties.get("context_low_prompt").is_some(), - "should have context_low_prompt field" - ); - assert!( - properties.get("timeout_ms").is_some(), - "should have timeout_ms field" - ); - assert!( - properties.get("allowed_tools").is_some(), - "should have allowed_tools field" - ); - } - - #[test] - fn test_subagent_tool_name() { - assert_eq!(SubagentTool::NAME, "subagent"); - } - - #[test] - fn test_subagent_tool_kind() { - assert_eq!(SubagentTool::kind(), acp::ToolKind::Other); - } -} - -struct SubagentDisplayConnection; - -impl AgentConnection for SubagentDisplayConnection { - fn telemetry_id(&self) -> SharedString { - acp_thread::SUBAGENT_TOOL_NAME.into() + use crate::{ContextServerRegistry, Templates, Thread}; + use fs::FakeFs; + use gpui::{AppContext as _, TestAppContext}; + use project::Project; + use prompt_store::ProjectContext; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + async fn create_thread_with_tools(cx: &mut TestAppContext) -> Entity { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = + project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + cx.new(|cx| { + let mut thread = Thread::new( + project, + project_context, + context_server_registry, + Templates::new(), + None, + cx, + ); + thread.add_tool(crate::NowTool, None); + thread.add_tool(crate::ThinkingTool, None); + thread + }) } - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } + #[gpui::test] + async fn test_validate_allowed_tools_succeeds_for_valid_tools(cx: &mut TestAppContext) { + let thread = create_thread_with_tools(cx).await; - fn new_thread( - self: Rc, - _project: Entity, - _cwd: &Path, - _cx: &mut App, - ) -> Task>> { - unimplemented!("SubagentDisplayConnection does not support new_thread") - } + cx.update(|cx| { + assert!(SubagentTool::validate_allowed_tools(&None, &thread, cx).is_ok()); - fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task> { - unimplemented!("SubagentDisplayConnection does not support authenticate") - } + let valid_tools = Some(vec!["now".to_string()]); + assert!(SubagentTool::validate_allowed_tools(&valid_tools, &thread, cx).is_ok()); - fn prompt( - &self, - _id: Option, - _params: acp::PromptRequest, - _cx: &mut App, - ) -> Task> { - unimplemented!("SubagentDisplayConnection does not support prompt") + let both_tools = Some(vec!["now".to_string(), "thinking".to_string()]); + assert!(SubagentTool::validate_allowed_tools(&both_tools, &thread, cx).is_ok()); + }); } - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} - - fn into_any(self: Rc) -> Rc { - self + #[gpui::test] + async fn test_validate_allowed_tools_fails_for_unknown_tools(cx: &mut TestAppContext) { + let thread = create_thread_with_tools(cx).await; + + cx.update(|cx| { + let unknown_tools = Some(vec!["nonexistent_tool".to_string()]); + let result = SubagentTool::validate_allowed_tools(&unknown_tools, &thread, cx); + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!( + error_message.contains("'nonexistent_tool'"), + "Expected error to mention the invalid tool name, got: {error_message}" + ); + + let mixed_tools = Some(vec![ + "now".to_string(), + "fake_tool_a".to_string(), + "fake_tool_b".to_string(), + ]); + let result = SubagentTool::validate_allowed_tools(&mixed_tools, &thread, cx); + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!( + error_message.contains("'fake_tool_a'") && error_message.contains("'fake_tool_b'"), + "Expected error to mention both invalid tool names, got: {error_message}" + ); + assert!( + !error_message.contains("'now'"), + "Expected error to not mention valid tool 'now', got: {error_message}" + ); + }); } } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 971467cc23ee0b4d629827da559a6082644cc0e5..873a80f84d26d7cfd6f60defdf8be387b64db0c3 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -365,7 +365,7 @@ impl AgentConnection for AcpConnection { self.telemetry_id.clone() } - fn new_thread( + fn new_session( self: Rc, project: Entity, cwd: &Path, @@ -558,6 +558,7 @@ impl AgentConnection for AcpConnection { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread: Entity = cx.new(|cx| { AcpThread::new( + None, self.server_name.clone(), self.clone(), project, @@ -615,6 +616,7 @@ impl AgentConnection for AcpConnection { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread: Entity = cx.new(|cx| { AcpThread::new( + None, self.server_name.clone(), self.clone(), project, @@ -688,6 +690,7 @@ impl AgentConnection for AcpConnection { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread: Entity = cx.new(|cx| { AcpThread::new( + None, self.server_name.clone(), self.clone(), project, diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 50740231b5cbe3352d6a0fa9ed6cf87da9f04c5f..b5e4a40dfb4360bd3d43df177d1e760b07f47f9f 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -449,7 +449,7 @@ pub async fn new_test_thread( .await .unwrap(); - cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx)) + cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx)) .await .unwrap() } diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 2d6b1e7148891020d77654f97ccc2e281557f384..7db45461d0db7ec994b7a63810d25f79c2f98560 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -75,6 +75,7 @@ impl EntryViewState { match thread_entry { AgentThreadEntry::UserMessage(message) => { let has_id = message.id.is_some(); + let is_subagent = thread.read(cx).parent_session_id().is_some(); let chunks = message.chunks.clone(); if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) { if !editor.focus_handle(cx).is_focused(window) { @@ -103,7 +104,7 @@ impl EntryViewState { window, cx, ); - if !has_id { + if !has_id || is_subagent { editor.set_read_only(true, cx); } editor.set_message(chunks, window, cx); @@ -446,7 +447,7 @@ mod tests { .update(|_, cx| { connection .clone() - .new_thread(project.clone(), Path::new(path!("/project")), cx) + .new_session(project.clone(), Path::new(path!("/project")), cx) }) .await .unwrap(); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e294a08d14c2c993e4fb73e05e0e3eb001860c0e..bc0b14a2dddf9701ad0b48e6c8f7745e4036cea3 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -51,9 +51,9 @@ use text::{Anchor, ToPoint as _}; use theme::AgentFontSize; use ui::{ Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, CopyButton, DecoratedIcon, - DiffStat, Disclosure, Divider, DividerColor, IconButtonShape, IconDecoration, - IconDecorationKind, KeyBinding, PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, - Tooltip, WithScrollbar, prelude::*, right_click_menu, + DiffStat, Disclosure, Divider, DividerColor, IconDecoration, IconDecorationKind, KeyBinding, + PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*, + right_click_menu, }; use util::defer; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; @@ -178,13 +178,35 @@ pub struct AcpServerView { } impl AcpServerView { - pub fn as_active_thread(&self) -> Option> { + pub fn active_thread(&self) -> Option> { match &self.server_state { ServerState::Connected(connected) => Some(connected.current.clone()), _ => None, } } + pub fn parent_thread(&self, cx: &App) -> Option> { + match &self.server_state { + ServerState::Connected(connected) => { + let mut current = connected.current.clone(); + while let Some(parent_id) = current.read(cx).parent_id.clone() { + if let Some(parent) = connected.threads.get(&parent_id) { + current = parent.clone(); + } else { + break; + } + } + Some(current) + } + _ => None, + } + } + + pub fn thread_view(&self, session_id: &acp::SessionId) -> Option> { + let connected = self.as_connected()?; + connected.threads.get(session_id).cloned() + } + pub fn as_connected(&self) -> Option<&ConnectedServerState> { match &self.server_state { ServerState::Connected(connected) => Some(connected), @@ -198,6 +220,23 @@ impl AcpServerView { _ => None, } } + + pub fn navigate_to_session( + &mut self, + session_id: acp::SessionId, + window: &mut Window, + cx: &mut Context, + ) { + let Some(connected) = self.as_connected_mut() else { + return; + }; + + connected.navigate_to_session(session_id); + if let Some(view) = self.active_thread() { + view.focus_handle(cx).focus(window, cx); + } + cx.notify(); + } } enum ServerState { @@ -211,6 +250,7 @@ enum ServerState { pub struct ConnectedServerState { auth_state: AuthState, current: Entity, + threads: HashMap>, connection: Rc, } @@ -240,6 +280,23 @@ impl ConnectedServerState { pub fn has_thread_error(&self, cx: &App) -> bool { self.current.read(cx).thread_error.is_some() } + + pub fn navigate_to_session(&mut self, session_id: acp::SessionId) { + if let Some(session) = self.threads.get(&session_id) { + self.current = session.clone(); + } + } + + pub fn close_all_sessions(&self, cx: &mut App) -> Task<()> { + let tasks = self + .threads + .keys() + .map(|id| self.connection.close_session(id, cx)); + let task = futures::future::join_all(tasks); + cx.background_spawn(async move { + task.await; + }) + } } impl AcpServerView { @@ -255,9 +312,6 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) -> Self { - let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); - let available_commands = Rc::new(RefCell::new(vec![])); - let agent_server_store = project.read(cx).agent_server_store().clone(); let subscriptions = vec![ cx.observe_global_in::(window, Self::agent_ui_font_size_changed), @@ -270,6 +324,9 @@ impl AcpServerView { ]; cx.on_release(|this, cx| { + if let Some(connected) = this.as_connected() { + connected.close_all_sessions(cx).detach(); + } for window in this.notifications.drain(..) { window .update(cx, |_, window, _| { @@ -280,23 +337,17 @@ impl AcpServerView { }) .detach(); - let workspace_for_state = workspace.clone(); - let project_for_state = project.clone(); - Self { agent: agent.clone(), agent_server_store, workspace, - project, + project: project.clone(), thread_store, prompt_store, server_state: Self::initial_state( agent.clone(), resume_thread, - workspace_for_state, - project_for_state, - prompt_capabilities, - available_commands, + project, initial_content, window, cx, @@ -311,30 +362,38 @@ impl AcpServerView { } } - fn reset(&mut self, window: &mut Window, cx: &mut Context) { - let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); - let available_commands = Rc::new(RefCell::new(vec![])); + fn set_server_state(&mut self, state: ServerState, cx: &mut Context) { + if let Some(connected) = self.as_connected() { + connected.close_all_sessions(cx).detach(); + } + + self.server_state = state; + cx.notify(); + } + fn reset(&mut self, window: &mut Window, cx: &mut Context) { let resume_thread_metadata = self - .as_active_thread() + .active_thread() .and_then(|thread| thread.read(cx).resume_thread_metadata.clone()); - self.server_state = Self::initial_state( + let state = Self::initial_state( self.agent.clone(), resume_thread_metadata, - self.workspace.clone(), self.project.clone(), - prompt_capabilities.clone(), - available_commands.clone(), None, window, cx, ); + self.set_server_state(state, cx); if let Some(connected) = self.as_connected() { connected.current.update(cx, |this, cx| { this.message_editor.update(cx, |editor, cx| { - editor.set_command_state(prompt_capabilities, available_commands, cx); + editor.set_command_state( + this.prompt_capabilities.clone(), + this.available_commands.clone(), + cx, + ); }); }); } @@ -344,10 +403,7 @@ impl AcpServerView { fn initial_state( agent: Rc, resume_thread: Option, - workspace: WeakEntity, project: Entity, - prompt_capabilities: Rc>, - available_commands: Rc>>, initial_content: Option, window: &mut Window, cx: &mut Context, @@ -400,7 +456,7 @@ impl AcpServerView { this.update_in(cx, |this, window, cx| { if err.downcast_ref::().is_some() { this.handle_load_error(err, window, cx); - } else if let Some(active) = this.as_active_thread() { + } else if let Some(active) = this.active_thread() { active.update(cx, |active, cx| active.handle_any_thread_error(err, cx)); } cx.notify(); @@ -445,7 +501,7 @@ impl AcpServerView { cx.update(|_, cx| { connection .clone() - .new_thread(project.clone(), fallback_cwd.as_ref(), cx) + .new_session(project.clone(), fallback_cwd.as_ref(), cx) }) .log_err() }; @@ -471,181 +527,15 @@ impl AcpServerView { this.update_in(cx, |this, window, cx| { match result { Ok(thread) => { - let action_log = thread.read(cx).action_log().clone(); - - prompt_capabilities.replace(thread.read(cx).prompt_capabilities()); - - let entry_view_state = cx.new(|_| { - EntryViewState::new( - this.workspace.clone(), - this.project.downgrade(), - this.thread_store.clone(), - this.history.downgrade(), - this.prompt_store.clone(), - prompt_capabilities.clone(), - available_commands.clone(), - this.agent.name(), - ) - }); - - let count = thread.read(cx).entries().len(); - let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); - entry_view_state.update(cx, |view_state, cx| { - for ix in 0..count { - view_state.sync_entry(ix, &thread, window, cx); - } - list_state.splice_focusable( - 0..0, - (0..count).map(|ix| view_state.entry(ix)?.focus_handle(cx)), - ); - }); - - AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); - - let connection = thread.read(cx).connection().clone(); - let session_id = thread.read(cx).session_id().clone(); - let session_list = if connection.supports_session_history(cx) { - connection.session_list(cx) - } else { - None - }; - this.history.update(cx, |history, cx| { - history.set_session_list(session_list, cx); - }); - - // Check for config options first - // Config options take precedence over legacy mode/model selectors - // (feature flag gating happens at the data layer) - let config_options_provider = - connection.session_config_options(&session_id, cx); - - let config_options_view; - let mode_selector; - let model_selector; - if let Some(config_options) = config_options_provider { - // Use config options - don't create mode_selector or model_selector - let agent_server = this.agent.clone(); - let fs = this.project.read(cx).fs().clone(); - config_options_view = Some(cx.new(|cx| { - ConfigOptionsView::new(config_options, agent_server, fs, window, cx) - })); - model_selector = None; - mode_selector = None; - } else { - // Fall back to legacy mode/model selectors - config_options_view = None; - model_selector = - connection.model_selector(&session_id).map(|selector| { - let agent_server = this.agent.clone(); - let fs = this.project.read(cx).fs().clone(); - cx.new(|cx| { - AcpModelSelectorPopover::new( - selector, - agent_server, - fs, - PopoverMenuHandle::default(), - this.focus_handle(cx), - window, - cx, - ) - }) - }); - - mode_selector = - connection - .session_modes(&session_id, cx) - .map(|session_modes| { - let fs = this.project.read(cx).fs().clone(); - let focus_handle = this.focus_handle(cx); - cx.new(|_cx| { - ModeSelector::new( - session_modes, - this.agent.clone(), - fs, - focus_handle, - ) - }) - }); - } - - let mut subscriptions = vec![ - cx.subscribe_in(&thread, window, Self::handle_thread_event), - cx.observe(&action_log, |_, _, cx| cx.notify()), - // cx.subscribe_in( - // &entry_view_state, - // window, - // Self::handle_entry_view_event, - // ), - ]; - - let title_editor = - if thread.update(cx, |thread, cx| thread.can_set_title(cx)) { - let editor = cx.new(|cx| { - let mut editor = Editor::single_line(window, cx); - editor.set_text(thread.read(cx).title(), window, cx); - editor - }); - subscriptions.push(cx.subscribe_in( - &editor, - window, - Self::handle_title_editor_event, - )); - Some(editor) - } else { - None - }; - - let profile_selector: Option> = - connection.clone().downcast(); - let profile_selector = profile_selector - .and_then(|native_connection| native_connection.thread(&session_id, cx)) - .map(|native_thread| { - cx.new(|cx| { - ProfileSelector::new( - ::global(cx), - Arc::new(native_thread), - this.focus_handle(cx), - cx, - ) - }) - }); - - let agent_display_name = this - .agent_server_store - .read(cx) - .agent_display_name(&ExternalAgentServerName(agent.name())) - .unwrap_or_else(|| agent.name()); - - let weak = cx.weak_entity(); - let current = cx.new(|cx| { - AcpThreadView::new( - thread, - this.login.clone(), - weak, - agent.name(), - agent_display_name, - workspace.clone(), - entry_view_state, - title_editor, - config_options_view, - mode_selector, - model_selector, - profile_selector, - list_state, - prompt_capabilities, - available_commands, - resumed_without_history, - resume_thread.clone(), - project.downgrade(), - this.thread_store.clone(), - this.history.clone(), - this.prompt_store.clone(), - initial_content, - subscriptions, - window, - cx, - ) - }); + let current = this.new_thread_view( + None, + thread, + resumed_without_history, + resume_thread, + initial_content, + window, + cx, + ); if this.focus_handle.contains_focused(window, cx) { current @@ -655,13 +545,18 @@ impl AcpServerView { .focus(window, cx); } - this.server_state = ServerState::Connected(ConnectedServerState { - connection, - auth_state: AuthState::Ok, - current, - }); - - cx.notify(); + this.set_server_state( + ServerState::Connected(ConnectedServerState { + connection, + auth_state: AuthState::Ok, + current: current.clone(), + threads: HashMap::from_iter([( + current.read(cx).thread.read(cx).session_id().clone(), + current, + )]), + }), + cx, + ); } Err(err) => { this.handle_load_error(err, window, cx); @@ -675,7 +570,7 @@ impl AcpServerView { while let Ok(new_version) = new_version_available_rx.recv().await { if let Some(new_version) = new_version { this.update(cx, |this, cx| { - if let Some(thread) = this.as_active_thread() { + if let Some(thread) = this.active_thread() { thread.update(cx, |thread, _cx| { thread.new_server_version_available = Some(new_version.into()); }); @@ -709,6 +604,211 @@ impl AcpServerView { ServerState::Loading(loading_view) } + fn new_thread_view( + &self, + parent_id: Option, + thread: Entity, + resumed_without_history: bool, + resume_thread: Option, + initial_content: Option, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + let agent_name = self.agent.name(); + let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default())); + let available_commands = Rc::new(RefCell::new(vec![])); + + let action_log = thread.read(cx).action_log().clone(); + + prompt_capabilities.replace(thread.read(cx).prompt_capabilities()); + + let entry_view_state = cx.new(|_| { + EntryViewState::new( + self.workspace.clone(), + self.project.downgrade(), + self.thread_store.clone(), + self.history.downgrade(), + self.prompt_store.clone(), + prompt_capabilities.clone(), + available_commands.clone(), + self.agent.name(), + ) + }); + + let count = thread.read(cx).entries().len(); + let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); + entry_view_state.update(cx, |view_state, cx| { + for ix in 0..count { + view_state.sync_entry(ix, &thread, window, cx); + } + list_state.splice_focusable( + 0..0, + (0..count).map(|ix| view_state.entry(ix)?.focus_handle(cx)), + ); + }); + + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); + + let connection = thread.read(cx).connection().clone(); + let session_id = thread.read(cx).session_id().clone(); + let session_list = if connection.supports_session_history(cx) { + connection.session_list(cx) + } else { + None + }; + self.history.update(cx, |history, cx| { + history.set_session_list(session_list, cx); + }); + + // Check for config options first + // Config options take precedence over legacy mode/model selectors + // (feature flag gating happens at the data layer) + let config_options_provider = connection.session_config_options(&session_id, cx); + + let config_options_view; + let mode_selector; + let model_selector; + if let Some(config_options) = config_options_provider { + // Use config options - don't create mode_selector or model_selector + let agent_server = self.agent.clone(); + let fs = self.project.read(cx).fs().clone(); + config_options_view = + Some(cx.new(|cx| { + ConfigOptionsView::new(config_options, agent_server, fs, window, cx) + })); + model_selector = None; + mode_selector = None; + } else { + // Fall back to legacy mode/model selectors + config_options_view = None; + model_selector = connection.model_selector(&session_id).map(|selector| { + let agent_server = self.agent.clone(); + let fs = self.project.read(cx).fs().clone(); + cx.new(|cx| { + AcpModelSelectorPopover::new( + selector, + agent_server, + fs, + PopoverMenuHandle::default(), + self.focus_handle(cx), + window, + cx, + ) + }) + }); + + mode_selector = connection + .session_modes(&session_id, cx) + .map(|session_modes| { + let fs = self.project.read(cx).fs().clone(); + let focus_handle = self.focus_handle(cx); + cx.new(|_cx| { + ModeSelector::new(session_modes, self.agent.clone(), fs, focus_handle) + }) + }); + } + + let mut subscriptions = vec![ + cx.subscribe_in(&thread, window, Self::handle_thread_event), + cx.observe(&action_log, |_, _, cx| cx.notify()), + ]; + + let parent_session_id = thread.read(cx).session_id().clone(); + let subagent_sessions = thread + .read(cx) + .entries() + .iter() + .filter_map(|entry| match entry { + AgentThreadEntry::ToolCall(call) => call.subagent_session_id.clone(), + _ => None, + }) + .collect::>(); + + if !subagent_sessions.is_empty() { + cx.spawn_in(window, async move |this, cx| { + this.update_in(cx, |this, window, cx| { + for subagent_id in subagent_sessions { + this.load_subagent_session( + subagent_id, + parent_session_id.clone(), + window, + cx, + ); + } + }) + }) + .detach(); + } + + let title_editor = if thread.update(cx, |thread, cx| thread.can_set_title(cx)) { + let editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_text(thread.read(cx).title(), window, cx); + editor + }); + subscriptions.push(cx.subscribe_in(&editor, window, Self::handle_title_editor_event)); + Some(editor) + } else { + None + }; + + let profile_selector: Option> = + connection.clone().downcast(); + let profile_selector = profile_selector + .and_then(|native_connection| native_connection.thread(&session_id, cx)) + .map(|native_thread| { + cx.new(|cx| { + ProfileSelector::new( + ::global(cx), + Arc::new(native_thread), + self.focus_handle(cx), + cx, + ) + }) + }); + + let agent_display_name = self + .agent_server_store + .read(cx) + .agent_display_name(&ExternalAgentServerName(agent_name.clone())) + .unwrap_or_else(|| agent_name.clone()); + + let agent_icon = self.agent.logo(); + + let weak = cx.weak_entity(); + cx.new(|cx| { + AcpThreadView::new( + parent_id, + thread, + self.login.clone(), + weak, + agent_icon, + agent_name, + agent_display_name, + self.workspace.clone(), + entry_view_state, + title_editor, + config_options_view, + mode_selector, + model_selector, + profile_selector, + list_state, + prompt_capabilities, + available_commands, + resumed_without_history, + resume_thread, + self.project.downgrade(), + self.thread_store.clone(), + self.history.clone(), + self.prompt_store.clone(), + initial_content, + subscriptions, + window, + cx, + ) + }) + } + fn handle_auth_required( this: WeakEntity, err: AuthRequired, @@ -804,8 +904,7 @@ impl AcpServerView { LoadError::Other(format!("{:#}", err).into()) }; self.emit_load_error_telemetry(&load_error); - self.server_state = ServerState::LoadError(load_error); - cx.notify(); + self.set_server_state(ServerState::LoadError(load_error), cx); } fn handle_agent_servers_updated( @@ -827,7 +926,7 @@ impl AcpServerView { }; if should_retry { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { active.clear_thread_error(cx); }); @@ -856,7 +955,7 @@ impl AcpServerView { } pub fn cancel_generation(&mut self, cx: &mut Context) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { active.cancel_generation(cx); }); @@ -870,7 +969,7 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { active.handle_title_editor_event(title_editor, event, window, cx); }); @@ -882,7 +981,7 @@ impl AcpServerView { } fn update_turn_tokens(&mut self, cx: &mut Context) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { active.update_turn_tokens(cx); }); @@ -896,7 +995,7 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { active.send_queued_message_at_index(index, is_send_now, window, cx); }); @@ -910,11 +1009,13 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { + let thread_id = thread.read(cx).session_id().clone(); + let is_subagent = thread.read(cx).parent_session_id().is_some(); match event { AcpThreadEvent::NewEntry => { let len = thread.read(cx).entries().len(); let index = len - 1; - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { let entry_view_state = active.read(cx).entry_view_state.clone(); let list_state = active.read(cx).list_state.clone(); entry_view_state.update(cx, |view_state, cx| { @@ -930,7 +1031,7 @@ impl AcpServerView { } AcpThreadEvent::EntryUpdated(index) => { if let Some(entry_view_state) = self - .as_active_thread() + .thread_view(&thread_id) .map(|active| active.read(cx).entry_view_state.clone()) { entry_view_state.update(cx, |view_state, cx| { @@ -939,29 +1040,39 @@ impl AcpServerView { } } AcpThreadEvent::EntriesRemoved(range) => { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { let entry_view_state = active.read(cx).entry_view_state.clone(); let list_state = active.read(cx).list_state.clone(); entry_view_state.update(cx, |view_state, _cx| view_state.remove(range.clone())); list_state.splice(range.clone(), 0); } } + AcpThreadEvent::SubagentSpawned(session_id) => self.load_subagent_session( + session_id.clone(), + thread.read(cx).session_id().clone(), + window, + cx, + ), AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } AcpThreadEvent::Retry(retry) => { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { active.update(cx, |active, _cx| { active.thread_retry_status = Some(retry.clone()); }); } } AcpThreadEvent::Stopped => { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { active.update(cx, |active, _cx| { active.thread_retry_status.take(); }); } + if is_subagent { + return; + } + let used_tools = thread.read(cx).used_tools_since_last_user_message(); self.notify_with_sound( if used_tools { @@ -974,7 +1085,7 @@ impl AcpServerView { cx, ); - let should_send_queued = if let Some(active) = self.as_active_thread() { + let should_send_queued = if let Some(active) = self.active_thread() { active.update(cx, |active, cx| { if active.skip_queue_processing_count > 0 { active.skip_queue_processing_count -= 1; @@ -1005,29 +1116,33 @@ impl AcpServerView { } AcpThreadEvent::Refusal => { let error = ThreadError::Refusal; - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { active.update(cx, |active, cx| { active.handle_thread_error(error, cx); active.thread_retry_status.take(); }); } - let model_or_agent_name = self.current_model_name(cx); - let notification_message = - format!("{} refused to respond to this request", model_or_agent_name); - self.notify_with_sound(¬ification_message, IconName::Warning, window, cx); + if !is_subagent { + let model_or_agent_name = self.current_model_name(cx); + let notification_message = + format!("{} refused to respond to this request", model_or_agent_name); + self.notify_with_sound(¬ification_message, IconName::Warning, window, cx); + } } AcpThreadEvent::Error => { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { active.update(cx, |active, _cx| { active.thread_retry_status.take(); }); } - self.notify_with_sound( - "Agent stopped due to an error", - IconName::Warning, - window, - cx, - ); + if !is_subagent { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + } } AcpThreadEvent::LoadError(error) => { match &self.server_state { @@ -1044,12 +1159,12 @@ impl AcpServerView { } _ => {} } - self.server_state = ServerState::LoadError(error.clone()); + self.set_server_state(ServerState::LoadError(error.clone()), cx); } AcpThreadEvent::TitleUpdated => { let title = thread.read(cx).title(); if let Some(title_editor) = self - .as_active_thread() + .thread_view(&thread_id) .and_then(|active| active.read(cx).title_editor.clone()) { title_editor.update(cx, |editor, cx| { @@ -1061,7 +1176,7 @@ impl AcpServerView { self.history.update(cx, |history, cx| history.refresh(cx)); } AcpThreadEvent::PromptCapabilitiesUpdated => { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.thread_view(&thread_id) { active.update(cx, |active, _cx| { active .prompt_capabilities @@ -1088,7 +1203,7 @@ impl AcpServerView { } let has_commands = !available_commands.is_empty(); - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, _cx| { active.available_commands.replace(available_commands); }); @@ -1100,7 +1215,7 @@ impl AcpServerView { .agent_display_name(&ExternalAgentServerName(self.agent.name())) .unwrap_or_else(|| self.agent.name()); - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { let new_placeholder = placeholder_text(agent_display_name.as_ref(), has_commands); active.update(cx, |active, cx| { @@ -1244,7 +1359,7 @@ impl AcpServerView { { pending_auth_method.take(); } - if let Some(active) = this.as_active_thread() { + if let Some(active) = this.active_thread() { active.update(cx, |active, cx| { active.handle_any_thread_error(err, cx); }) @@ -1359,7 +1474,7 @@ impl AcpServerView { { pending_auth_method.take(); } - if let Some(active) = this.as_active_thread() { + if let Some(active) = this.active_thread() { active.update(cx, |active, cx| active.handle_any_thread_error(err, cx)); } } else { @@ -1372,6 +1487,63 @@ impl AcpServerView { })); } + fn load_subagent_session( + &mut self, + subagent_id: acp::SessionId, + parent_id: acp::SessionId, + window: &mut Window, + cx: &mut Context, + ) { + let Some(connected) = self.as_connected() else { + return; + }; + if connected.threads.contains_key(&subagent_id) + || !connected.connection.supports_load_session(cx) + { + return; + } + let root_dir = self + .project + .read(cx) + .worktrees(cx) + .filter_map(|worktree| { + if worktree.read(cx).is_single_file() { + Some(worktree.read(cx).abs_path().parent()?.into()) + } else { + Some(worktree.read(cx).abs_path()) + } + }) + .next(); + let cwd = root_dir.unwrap_or_else(|| paths::home_dir().as_path().into()); + + let subagent_thread_task = connected.connection.clone().load_session( + AgentSessionInfo::new(subagent_id.clone()), + self.project.clone(), + &cwd, + cx, + ); + + cx.spawn_in(window, async move |this, cx| { + let subagent_thread = subagent_thread_task.await?; + this.update_in(cx, |this, window, cx| { + let view = this.new_thread_view( + Some(parent_id), + subagent_thread, + false, + None, + None, + window, + cx, + ); + let Some(connected) = this.as_connected_mut() else { + return; + }; + connected.threads.insert(subagent_id, view); + }) + }) + .detach(); + } + fn spawn_external_agent_login( login: task::SpawnInTerminal, workspace: Entity, @@ -1492,7 +1664,7 @@ impl AcpServerView { } pub fn has_user_submitted_prompt(&self, cx: &App) -> bool { - self.as_active_thread().is_some_and(|active| { + self.active_thread().is_some_and(|active| { active .read(cx) .thread @@ -1636,7 +1808,7 @@ impl AcpServerView { thread: &Entity, cx: &mut Context, ) { - let Some(active_thread) = self.as_active_thread() else { + let Some(active_thread) = self.active_thread() else { return; }; @@ -1790,18 +1962,18 @@ impl AcpServerView { &self, cx: &App, ) -> Option> { - let acp_thread = self.as_active_thread()?.read(cx).thread.read(cx); + let acp_thread = self.active_thread()?.read(cx).thread.read(cx); acp_thread.connection().clone().downcast() } pub(crate) fn as_native_thread(&self, cx: &App) -> Option> { - let acp_thread = self.as_active_thread()?.read(cx).thread.read(cx); + let acp_thread = self.active_thread()?.read(cx).thread.read(cx); self.as_native_connection(cx)? .thread(acp_thread.session_id(), cx) } fn queued_messages_len(&self, cx: &App) -> usize { - self.as_active_thread() + self.active_thread() .map(|thread| thread.read(cx).local_queued_messages.len()) .unwrap_or_default() } @@ -1813,7 +1985,7 @@ impl AcpServerView { tracked_buffers: Vec>, cx: &mut Context, ) -> bool { - match self.as_active_thread() { + match self.active_thread() { Some(thread) => thread.update(cx, |thread, _cx| { if index < thread.local_queued_messages.len() { thread.local_queued_messages[index] = QueuedMessage { @@ -1830,7 +2002,7 @@ impl AcpServerView { } fn queued_message_contents(&self, cx: &App) -> Vec> { - match self.as_active_thread() { + match self.active_thread() { None => Vec::new(), Some(thread) => thread .read(cx) @@ -1842,7 +2014,7 @@ impl AcpServerView { } fn save_queued_message_at_index(&mut self, index: usize, cx: &mut Context) { - let editor = match self.as_active_thread() { + let editor = match self.active_thread() { Some(thread) => thread.read(cx).queued_message_editors.get(index).cloned(), None => None, }; @@ -1876,7 +2048,7 @@ impl AcpServerView { let project = self.project.downgrade(); let history = self.history.downgrade(); - let Some(thread) = self.as_active_thread() else { + let Some(thread) = self.active_thread() else { return; }; let prompt_capabilities = thread.read(cx).prompt_capabilities.clone(); @@ -1961,7 +2133,7 @@ impl AcpServerView { }); } - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, _cx| { active.last_synced_queue_length = needed_count; }); @@ -2140,7 +2312,7 @@ impl AcpServerView { fn agent_ui_font_size_changed(&mut self, _window: &mut Window, cx: &mut Context) { if let Some(entry_view_state) = self - .as_active_thread() + .active_thread() .map(|active| active.read(cx).entry_view_state.clone()) { entry_view_state.update(cx, |entry_view_state, cx| { @@ -2156,7 +2328,7 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - if let Some(active_thread) = self.as_active_thread() { + if let Some(active_thread) = self.active_thread() { active_thread.update(cx, |thread, cx| { thread.message_editor.update(cx, |editor, cx| { editor.insert_dragged_files(paths, added_worktrees, window, cx); @@ -2168,7 +2340,7 @@ impl AcpServerView { /// Inserts the selected text into the message editor or the message being /// edited, if any. pub(crate) fn insert_selections(&self, window: &mut Window, cx: &mut Context) { - if let Some(active_thread) = self.as_active_thread() { + if let Some(active_thread) = self.active_thread() { active_thread.update(cx, |thread, cx| { thread.active_editor(cx).update(cx, |editor, cx| { editor.insert_selections(window, cx); @@ -2184,7 +2356,7 @@ impl AcpServerView { window: &mut Window, cx: &mut Context, ) { - if let Some(active_thread) = self.as_active_thread() { + if let Some(active_thread) = self.active_thread() { active_thread.update(cx, |thread, cx| { thread.message_editor.update(cx, |editor, cx| { editor.insert_terminal_crease(text, window, cx); @@ -2198,7 +2370,7 @@ impl AcpServerView { // For ACP agents, use the agent name (e.g., "Claude Code", "Gemini CLI") // This provides better clarity about what refused the request if self.as_native_connection(cx).is_some() { - self.as_active_thread() + self.active_thread() .and_then(|active| active.read(cx).model_selector.clone()) .and_then(|selector| selector.read(cx).active_model(cx)) .map(|model| model.name.clone()) @@ -2217,7 +2389,7 @@ impl AcpServerView { pub(crate) fn reauthenticate(&mut self, window: &mut Window, cx: &mut Context) { let agent_name = self.agent.name(); - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, cx| active.clear_thread_error(cx)); } let this = cx.weak_entity(); @@ -2257,7 +2429,7 @@ fn placeholder_text(agent_name: &str, has_commands: bool) -> String { impl Focusable for AcpServerView { fn focus_handle(&self, cx: &App) -> FocusHandle { - match self.as_active_thread() { + match self.active_thread() { Some(thread) => thread.read(cx).focus_handle(cx), None => self.focus_handle.clone(), } @@ -2269,7 +2441,7 @@ impl AcpServerView { /// Expands a tool call so its content is visible. /// This is primarily useful for visual testing. pub fn expand_tool_call(&mut self, tool_call_id: acp::ToolCallId, cx: &mut Context) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, _cx| { active.expanded_tool_calls.insert(tool_call_id); }); @@ -2280,7 +2452,7 @@ impl AcpServerView { /// Expands a subagent card so its content is visible. /// This is primarily useful for visual testing. pub fn expand_subagent(&mut self, session_id: acp::SessionId, cx: &mut Context) { - if let Some(active) = self.as_active_thread() { + if let Some(active) = self.active_thread() { active.update(cx, |active, _cx| { active.expanded_subagents.insert(session_id); }); @@ -2294,8 +2466,8 @@ impl Render for AcpServerView { self.sync_queued_message_editors(window, cx); v_flex() + .track_focus(&self.focus_handle(cx)) .size_full() - .track_focus(&self.focus_handle) .bg(cx.theme().colors().panel_background) .child(match &self.server_state { ServerState::Loading { .. } => v_flex() @@ -2548,7 +2720,7 @@ pub(crate) mod tests { cx.run_until_parked(); thread_view.read_with(cx, |view, cx| { - let state = view.as_active_thread().unwrap(); + let state = view.active_thread().unwrap(); assert!(state.read(cx).resumed_without_history); assert_eq!(state.read(cx).list_state.item_count(), 0); }); @@ -2572,7 +2744,7 @@ pub(crate) mod tests { // Check that the refusal error is set thread_view.read_with(cx, |thread_view, cx| { - let state = thread_view.as_active_thread().unwrap(); + let state = thread_view.active_thread().unwrap(); assert!( matches!(state.read(cx).thread_error, Some(ThreadError::Refusal)), "Expected refusal error to be set" @@ -2918,7 +3090,7 @@ pub(crate) mod tests { "resume-only".into() } - fn new_thread( + fn new_session( self: Rc, project: Entity, _cwd: &Path, @@ -2927,6 +3099,7 @@ pub(crate) mod tests { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { AcpThread::new( + None, "ResumeOnlyAgentConnection", self.clone(), project, @@ -2958,6 +3131,7 @@ pub(crate) mod tests { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { AcpThread::new( + None, "ResumeOnlyAgentConnection", self.clone(), project, @@ -3011,7 +3185,7 @@ pub(crate) mod tests { "saboteur".into() } - fn new_thread( + fn new_session( self: Rc, project: Entity, _cwd: &Path, @@ -3020,6 +3194,7 @@ pub(crate) mod tests { Task::ready(Ok(cx.new(|cx| { let action_log = cx.new(|_| ActionLog::new(project.clone())); AcpThread::new( + None, "SaboteurAgentConnection", self, project, @@ -3075,7 +3250,7 @@ pub(crate) mod tests { "refusal".into() } - fn new_thread( + fn new_session( self: Rc, project: Entity, _cwd: &Path, @@ -3084,6 +3259,7 @@ pub(crate) mod tests { Task::ready(Ok(cx.new(|cx| { let action_log = cx.new(|_| ActionLog::new(project.clone())); AcpThread::new( + None, "RefusalAgentConnection", self, project, @@ -3198,7 +3374,7 @@ pub(crate) mod tests { let thread = thread_view .read_with(cx, |view, cx| { - view.as_active_thread().map(|r| r.read(cx).thread.clone()) + view.active_thread().map(|r| r.read(cx).thread.clone()) }) .unwrap(); @@ -3224,7 +3400,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { let entry_view_state = view - .as_active_thread() + .active_thread() .map(|active| active.read(cx).entry_view_state.clone()) .unwrap(); entry_view_state.read_with(cx, |entry_view_state, _| { @@ -3265,7 +3441,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { let entry_view_state = view - .as_active_thread() + .active_thread() .unwrap() .read(cx) .entry_view_state @@ -3303,7 +3479,7 @@ pub(crate) mod tests { }); thread_view.read_with(cx, |view, cx| { - let active = view.as_active_thread().unwrap(); + let active = view.active_thread().unwrap(); active .read(cx) .entry_view_state @@ -3340,7 +3516,7 @@ pub(crate) mod tests { let thread = thread_view .read_with(cx, |view, cx| { - view.as_active_thread().map(|r| r.read(cx).thread.clone()) + view.active_thread().map(|r| r.read(cx).thread.clone()) }) .unwrap(); @@ -3413,12 +3589,12 @@ pub(crate) mod tests { let user_message_editor = thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), None ); - view.as_active_thread() + view.active_thread() .map(|active| &active.read(cx).entry_view_state) .as_ref() .unwrap() @@ -3434,7 +3610,7 @@ pub(crate) mod tests { cx.focus(&user_message_editor); thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -3452,7 +3628,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), None ); @@ -3480,7 +3656,7 @@ pub(crate) mod tests { let thread = cx.read(|cx| { thread_view .read(cx) - .as_active_thread() + .active_thread() .unwrap() .read(cx) .thread @@ -3524,12 +3700,12 @@ pub(crate) mod tests { let user_message_editor = thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), None ); assert_eq!( - view.as_active_thread() + view.active_thread() .unwrap() .read(cx) .thread @@ -3539,7 +3715,7 @@ pub(crate) mod tests { 2 ); - view.as_active_thread() + view.active_thread() .map(|active| &active.read(cx).entry_view_state) .as_ref() .unwrap() @@ -3572,13 +3748,13 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), None ); let entries = view - .as_active_thread() + .active_thread() .unwrap() .read(cx) .thread @@ -3595,7 +3771,7 @@ pub(crate) mod tests { ); let entry_view_state = view - .as_active_thread() + .active_thread() .map(|active| &active.read(cx).entry_view_state) .unwrap(); let new_editor = entry_view_state.read_with(cx, |state, _cx| { @@ -3626,11 +3802,11 @@ pub(crate) mod tests { cx.run_until_parked(); let (user_message_editor, session_id) = thread_view.read_with(cx, |view, cx| { - let thread = view.as_active_thread().unwrap().read(cx).thread.read(cx); + let thread = view.active_thread().unwrap().read(cx).thread.read(cx); assert_eq!(thread.entries().len(), 1); let editor = view - .as_active_thread() + .active_thread() .map(|active| &active.read(cx).entry_view_state) .as_ref() .unwrap() @@ -3649,7 +3825,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -3662,7 +3838,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -3680,7 +3856,7 @@ pub(crate) mod tests { thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -3694,7 +3870,7 @@ pub(crate) mod tests { assert_eq!( thread_view .read(cx) - .as_active_thread() + .active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -3728,7 +3904,7 @@ pub(crate) mod tests { let (thread, session_id) = thread_view.read_with(cx, |view, cx| { let thread = view - .as_active_thread() + .active_thread() .as_ref() .unwrap() .read(cx) @@ -3822,7 +3998,7 @@ pub(crate) mod tests { add_to_workspace(thread_view.clone(), cx); let thread = thread_view.read_with(cx, |view, cx| { - view.as_active_thread().unwrap().read(cx).thread.clone() + view.active_thread().unwrap().read(cx).thread.clone() }); thread.read_with(cx, |thread, _cx| { @@ -3862,7 +4038,7 @@ pub(crate) mod tests { active_thread(&thread_view, cx).update_in(cx, |view, window, cx| view.send(window, cx)); let (thread, session_id) = thread_view.read_with(cx, |view, cx| { - let thread = view.as_active_thread().unwrap().read(cx).thread.clone(); + let thread = view.active_thread().unwrap().read(cx).thread.clone(); (thread.clone(), thread.read(cx).session_id().clone()) }); @@ -3994,7 +4170,7 @@ pub(crate) mod tests { let user_message_editor = thread_view.read_with(cx, |thread_view, cx| { thread_view - .as_active_thread() + .active_thread() .map(|active| &active.read(cx).entry_view_state) .as_ref() .unwrap() @@ -4009,7 +4185,7 @@ pub(crate) mod tests { cx.focus(&user_message_editor); thread_view.read_with(cx, |view, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -4048,7 +4224,7 @@ pub(crate) mod tests { thread_view.update_in(cx, |view, window, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), Some(0) ); @@ -4107,7 +4283,7 @@ pub(crate) mod tests { thread_view.update_in(cx, |view, window, cx| { assert_eq!( - view.as_active_thread() + view.active_thread() .and_then(|active| active.read(cx).editing_message), None ); @@ -4167,7 +4343,7 @@ pub(crate) mod tests { // Verify the tool call is in WaitingForConfirmation state with the expected options thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4274,7 +4450,7 @@ pub(crate) mod tests { // Verify the options thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4362,7 +4538,7 @@ pub(crate) mod tests { // Verify the options thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4451,7 +4627,7 @@ pub(crate) mod tests { // Verify only 2 options (no pattern button when command doesn't match pattern) thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4549,7 +4725,7 @@ pub(crate) mod tests { // Verify tool call is waiting for confirmation thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4579,7 +4755,7 @@ pub(crate) mod tests { // Verify tool call is no longer waiting for confirmation (was authorized) thread_view.read_with(cx, |thread_view, cx| { - let thread = thread_view.as_active_thread().expect("Thread should exist").read(cx).thread.clone(); + let thread = thread_view.active_thread().expect("Thread should exist").read(cx).thread.clone(); let thread = thread.read(cx); let tool_call = thread.first_tool_awaiting_confirmation(); assert!( @@ -4664,7 +4840,7 @@ pub(crate) mod tests { // Verify tool call was authorized thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4721,7 +4897,7 @@ pub(crate) mod tests { // Verify default granularity is the last option (index 2 = "Only this time") thread_view.read_with(cx, |thread_view, cx| { - let state = thread_view.as_active_thread().unwrap(); + let state = thread_view.active_thread().unwrap(); let selected = state .read(cx) .selected_permission_granularity @@ -4748,7 +4924,7 @@ pub(crate) mod tests { // Verify the selection was updated thread_view.read_with(cx, |thread_view, cx| { - let state = thread_view.as_active_thread().unwrap(); + let state = thread_view.active_thread().unwrap(); let selected = state .read(cx) .selected_permission_granularity @@ -4845,7 +5021,7 @@ pub(crate) mod tests { // Verify tool call was authorized thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread @@ -4911,7 +5087,7 @@ pub(crate) mod tests { // Verify tool call was rejected (no longer waiting for confirmation) thread_view.read_with(cx, |thread_view, cx| { let thread = thread_view - .as_active_thread() + .active_thread() .expect("Thread should exist") .read(cx) .thread diff --git a/crates/agent_ui/src/acp/thread_view/active_thread.rs b/crates/agent_ui/src/acp/thread_view/active_thread.rs index b329cae4ed4821ebfbe6678b121bcba4e83feabe..d8817bc2bf24d1a87905487165ff006dd3340862 100644 --- a/crates/agent_ui/src/acp/thread_view/active_thread.rs +++ b/crates/agent_ui/src/acp/thread_view/active_thread.rs @@ -1,7 +1,7 @@ use gpui::{Corner, List}; use language_model::LanguageModelEffortLevel; use settings::update_settings_file; -use ui::{ButtonLike, SplitButton, SplitButtonStyle}; +use ui::{ButtonLike, SplitButton, SplitButtonStyle, Tab}; use super::*; @@ -167,10 +167,13 @@ impl DiffStats { pub struct AcpThreadView { pub id: acp::SessionId, + pub parent_id: Option, pub login: Option, // is some <=> Active | Unauthenticated pub thread: Entity, pub server_view: WeakEntity, + pub agent_icon: IconName, pub agent_name: SharedString, + pub focus_handle: FocusHandle, pub workspace: WeakEntity, pub entry_view_state: Entity, pub title_editor: Option>, @@ -234,7 +237,11 @@ pub struct AcpThreadView { } impl Focusable for AcpThreadView { fn focus_handle(&self, cx: &App) -> FocusHandle { - self.active_editor(cx).focus_handle(cx) + if self.parent_id.is_some() { + self.focus_handle.clone() + } else { + self.active_editor(cx).focus_handle(cx) + } } } @@ -250,9 +257,11 @@ pub struct TurnFields { impl AcpThreadView { pub fn new( + parent_id: Option, thread: Entity, login: Option, server_view: WeakEntity, + agent_icon: IconName, agent_name: SharedString, agent_display_name: SharedString, workspace: WeakEntity, @@ -339,9 +348,12 @@ impl AcpThreadView { Self { id, + parent_id, + focus_handle: cx.focus_handle(), thread, login, server_view, + agent_icon, agent_name, workspace, entry_view_state, @@ -448,6 +460,10 @@ impl AcpThreadView { } } + fn is_subagent(&self) -> bool { + self.parent_id.is_some() + } + /// Returns the currently active editor, either for a message that is being /// edited or the editor for a new message. pub(crate) fn active_editor(&self, cx: &App) -> Entity { @@ -1456,7 +1472,6 @@ impl AcpThreadView { let client = project.read(cx).client(); let session_id = self.thread.read(cx).session_id().clone(); - cx.spawn_in(window, async move |this, cx| { let response = client .request(proto::GetSharedAgentThread { @@ -2281,11 +2296,51 @@ impl AcpThreadView { ) } + pub(crate) fn render_subagent_titlebar(&mut self, cx: &mut Context) -> Option
{ + let Some(parent_session_id) = self.parent_id.clone() else { + return None; + }; + + let title = self.thread.read(cx).title(); + let server_view = self.server_view.clone(); + + Some( + h_flex() + .h(Tab::container_height(cx)) + .pl_2() + .pr_1p5() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .bg(cx.theme().colors().editor_background.opacity(0.2)) + .child(Label::new(title).color(Color::Muted)) + .child( + IconButton::new("minimize_subagent", IconName::Minimize) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Minimize Subagent")) + .on_click(move |_, window, cx| { + let _ = server_view.update(cx, |server_view, cx| { + server_view.navigate_to_session( + parent_session_id.clone(), + window, + cx, + ); + }); + }), + ), + ) + } + pub(crate) fn render_message_editor( &mut self, window: &mut Window, cx: &mut Context, ) -> AnyElement { + if self.is_subagent() { + return div().into_any_element(); + } + let focus_handle = self.message_editor.focus_handle(cx); let editor_bg_color = cx.theme().colors().editor_background; let editor_expanded = self.editor_expanded; @@ -3234,6 +3289,14 @@ impl AcpThreadView { .is_some_and(|checkpoint| checkpoint.show); let agent_name = self.agent_name.clone(); + let is_subagent = self.is_subagent(); + + let non_editable_icon = || { + IconButton::new("non_editable", IconName::PencilUnavailable) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .style(ButtonStyle::Transparent) + }; v_flex() .id(("user_message", entry_ix)) @@ -3283,22 +3346,28 @@ impl AcpThreadView { .py_3() .px_2() .rounded_md() - .shadow_md() .bg(cx.theme().colors().editor_background) .border_1() .when(is_indented, |this| { this.py_2().px_2().shadow_sm() }) - .when(editing && !editor_focus, |this| this.border_dashed()) .border_color(cx.theme().colors().border) - .map(|this|{ + .map(|this| { + if is_subagent { + return this.border_dashed(); + } if editing && editor_focus { - this.border_color(focus_border) - } else if message.id.is_some() { - this.hover(|s| s.border_color(focus_border.opacity(0.8))) - } else { - this + return this.border_color(focus_border); + } + if editing && !editor_focus { + return this.border_dashed() + } + if message.id.is_some() { + return this.shadow_md().hover(|s| { + s.border_color(focus_border.opacity(0.8)) + }); } + this }) .text_xs() .child(editor.clone().into_any_element()) @@ -3316,7 +3385,20 @@ impl AcpThreadView { .overflow_hidden(); let is_loading_contents = self.is_loading_contents; - if message.id.is_some() { + if is_subagent { + this.child( + base_container.border_dashed().child( + non_editable_icon().tooltip(move |_, cx| { + Tooltip::with_meta( + "Unavailable Editing", + None, + "Editing subagent messages is currently not supported.", + cx, + ) + }), + ), + ) + } else if message.id.is_some() { this.child( base_container .child( @@ -3356,10 +3438,7 @@ impl AcpThreadView { base_container .border_dashed() .child( - IconButton::new("editing_unavailable", IconName::PencilUnavailable) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .style(ButtonStyle::Transparent) + non_editable_icon() .tooltip(Tooltip::element({ move |_, _| { v_flex() @@ -4555,11 +4634,16 @@ impl AcpThreadView { let is_edit = matches!(tool_call.kind, acp::ToolKind::Edit) || tool_call.diffs().next().is_some(); - let is_subagent = tool_call.is_subagent(); // For subagent tool calls, render the subagent cards directly without wrapper - if is_subagent { - return self.render_subagent_tool_call(entry_ix, tool_call, window, cx); + if tool_call.is_subagent() { + return self.render_subagent_tool_call( + entry_ix, + tool_call, + tool_call.subagent_session_id.clone(), + window, + cx, + ); } let is_cancelled_edit = is_edit && matches!(tool_call.status, ToolCallStatus::Canceled); @@ -5243,6 +5327,7 @@ impl AcpThreadView { ) -> Div { let has_location = tool_call.locations.len() == 1; let is_file = tool_call.kind == acp::ToolKind::Edit && has_location; + let is_subagent_tool_call = tool_call.is_subagent(); let file_icon = if has_location { FileIcons::get_icon(&tool_call.locations[0].path, cx) @@ -5274,25 +5359,27 @@ impl AcpThreadView { .into_any_element() } else if is_file { div().child(file_icon).into_any_element() - } else { - div() - .child( - Icon::new(match tool_call.kind { - acp::ToolKind::Read => IconName::ToolSearch, - acp::ToolKind::Edit => IconName::ToolPencil, - acp::ToolKind::Delete => IconName::ToolDeleteFile, - acp::ToolKind::Move => IconName::ArrowRightLeft, - acp::ToolKind::Search => IconName::ToolSearch, - acp::ToolKind::Execute => IconName::ToolTerminal, - acp::ToolKind::Think => IconName::ToolThink, - acp::ToolKind::Fetch => IconName::ToolWeb, - acp::ToolKind::SwitchMode => IconName::ArrowRightLeft, - acp::ToolKind::Other | _ => IconName::ToolHammer, - }) - .size(IconSize::Small) - .color(Color::Muted), - ) + } else if is_subagent_tool_call { + Icon::new(self.agent_icon) + .size(IconSize::Small) + .color(Color::Muted) .into_any_element() + } else { + Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolSearch, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolThink, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::SwitchMode => IconName::ArrowRightLeft, + acp::ToolKind::Other | _ => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted) + .into_any_element() }; let gradient_overlay = { @@ -5478,10 +5565,6 @@ impl AcpThreadView { ToolCallContent::Terminal(terminal) => { self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } - ToolCallContent::SubagentThread(_thread) => { - // Subagent threads are rendered by render_subagent_tool_call, not here - Empty.into_any_element() - } } } @@ -5715,54 +5798,56 @@ impl AcpThreadView { &self, entry_ix: usize, tool_call: &ToolCall, + subagent_session_id: Option, window: &Window, cx: &Context, ) -> Div { - let subagent_threads: Vec<_> = tool_call - .content - .iter() - .filter_map(|c| c.subagent_thread().cloned()) - .collect(); - let tool_call_status = &tool_call.status; - v_flex() - .mx_5() - .my_1p5() - .gap_3() - .children( - subagent_threads - .into_iter() - .enumerate() - .map(|(context_ix, thread)| { - self.render_subagent_card( - entry_ix, - context_ix, - &thread, - tool_call_status, - window, - cx, - ) - }), - ) + let subagent_thread_view = subagent_session_id.and_then(|id| { + self.server_view + .upgrade() + .and_then(|server_view| server_view.read(cx).as_connected()) + .and_then(|connected| connected.threads.get(&id)) + }); + + let content = self.render_subagent_card( + entry_ix, + 0, + subagent_thread_view, + tool_call_status, + window, + cx, + ); + + v_flex().mx_5().my_1p5().gap_3().child(content) } fn render_subagent_card( &self, entry_ix: usize, context_ix: usize, - thread: &Entity, + thread_view: Option<&Entity>, tool_call_status: &ToolCallStatus, window: &Window, cx: &Context, ) -> AnyElement { - let thread_read = thread.read(cx); - let session_id = thread_read.session_id().clone(); - let title = thread_read.title(); - let action_log = thread_read.action_log(); - let changed_buffers = action_log.read(cx).changed_buffers(cx); - - let is_expanded = self.expanded_subagents.contains(&session_id); + let thread = thread_view + .as_ref() + .map(|view| view.read(cx).thread.clone()); + let session_id = thread + .as_ref() + .map(|thread| thread.read(cx).session_id().clone()); + let action_log = thread.as_ref().map(|thread| thread.read(cx).action_log()); + let changed_buffers = action_log + .map(|log| log.read(cx).changed_buffers(cx)) + .unwrap_or_default(); + + let is_expanded = if let Some(session_id) = &session_id { + self.expanded_subagents.contains(session_id) + } else { + false + }; let files_changed = changed_buffers.len(); let diff_stats = DiffStats::all_files(&changed_buffers, cx); @@ -5775,9 +5860,20 @@ impl AcpThreadView { ToolCallStatus::Canceled | ToolCallStatus::Failed | ToolCallStatus::Rejected ); - let card_header_id = - SharedString::from(format!("subagent-header-{}-{}", entry_ix, context_ix)); - let diff_stat_id = SharedString::from(format!("subagent-diff-{}-{}", entry_ix, context_ix)); + let title = thread + .as_ref() + .map(|t| t.read(cx).title()) + .unwrap_or_else(|| { + if is_canceled_or_failed { + "Subagent Canceled" + } else { + "Spawning Subagent…" + } + .into() + }); + + let card_header_id = format!("subagent-header-{}-{}", entry_ix, context_ix); + let diff_stat_id = format!("subagent-diff-{}-{}", entry_ix, context_ix); let icon = h_flex().w_4().justify_center().child(if is_running { SpinnerLabel::new() @@ -5795,15 +5891,17 @@ impl AcpThreadView { .into_any_element() }); - let has_expandable_content = thread_read.entries().iter().rev().any(|entry| { - if let AgentThreadEntry::AssistantMessage(msg) = entry { - msg.chunks.iter().any(|chunk| match chunk { - AssistantMessageChunk::Message { block } => block.markdown().is_some(), - AssistantMessageChunk::Thought { block } => block.markdown().is_some(), - }) - } else { - false - } + let has_expandable_content = thread.as_ref().map_or(false, |thread| { + thread.read(cx).entries().iter().rev().any(|entry| { + if let AgentThreadEntry::AssistantMessage(msg) = entry { + msg.chunks.iter().any(|chunk| match chunk { + AssistantMessageChunk::Message { block } => block.markdown().is_some(), + AssistantMessageChunk::Thought { block } => block.markdown().is_some(), + }) + } else { + false + } + }) }); v_flex() @@ -5815,8 +5913,8 @@ impl AcpThreadView { .child( h_flex() .group(&card_header_id) - .py_1() - .px_1p5() + .p_1() + .pl_1p5() .w_full() .gap_1() .justify_between() @@ -5825,11 +5923,7 @@ impl AcpThreadView { h_flex() .gap_1p5() .child(icon) - .child( - Label::new(title.to_string()) - .size(LabelSize::Small) - .color(Color::Default), - ) + .child(Label::new(title.to_string()).size(LabelSize::Small)) .when(files_changed > 0, |this| { this.child( h_flex() @@ -5851,95 +5945,126 @@ impl AcpThreadView { ) }), ) - .child( - h_flex() - .gap_1p5() - .when(is_running, |buttons| { - buttons.child( - Button::new( - SharedString::from(format!( - "stop-subagent-{}-{}", - entry_ix, context_ix - )), - "Stop", + .when_some(session_id, |this, session_id| { + this.child( + h_flex() + .when(has_expandable_content, |this| { + this.child( + IconButton::new( + format!( + "subagent-disclosure-{}-{}", + entry_ix, context_ix + ), + if is_expanded { + IconName::ChevronUp + } else { + IconName::ChevronDown + }, + ) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .disabled(!has_expandable_content) + .visible_on_hover(card_header_id.clone()) + .on_click( + cx.listener({ + let session_id = session_id.clone(); + move |this, _, _, cx| { + if this.expanded_subagents.contains(&session_id) + { + this.expanded_subagents.remove(&session_id); + } else { + this.expanded_subagents + .insert(session_id.clone()); + } + cx.notify(); + } + }), + ), + ) + }) + .child( + IconButton::new( + format!("expand-subagent-{}-{}", entry_ix, context_ix), + IconName::Maximize, ) - .icon(IconName::Stop) - .icon_position(IconPosition::Start) + .icon_color(Color::Muted) .icon_size(IconSize::Small) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .tooltip(Tooltip::text("Stop this subagent")) - .on_click({ - let thread = thread.clone(); - cx.listener(move |_this, _event, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.stop_by_user(); - }); - }) - }), - ) - }) - .child( - IconButton::new( - SharedString::from(format!( - "subagent-disclosure-{}-{}", - entry_ix, context_ix + .tooltip(Tooltip::text("Expand Subagent")) + .visible_on_hover(card_header_id) + .on_click(cx.listener( + move |this, _event, window, cx| { + this.server_view + .update(cx, |this, cx| { + this.navigate_to_session( + session_id.clone(), + window, + cx, + ); + }) + .ok(); + }, )), - if is_expanded { - IconName::ChevronUp - } else { - IconName::ChevronDown - }, ) - .shape(IconButtonShape::Square) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .disabled(!has_expandable_content) - .when(has_expandable_content, |button| { - button.on_click(cx.listener({ - move |this, _, _, cx| { - if this.expanded_subagents.contains(&session_id) { - this.expanded_subagents.remove(&session_id); - } else { - this.expanded_subagents.insert(session_id.clone()); - } - cx.notify(); - } - })) - }) - .when( - !has_expandable_content, - |button| { - button.tooltip(Tooltip::text("Waiting for content...")) - }, - ), - ), - ), + .when(is_running, |buttons| { + buttons.child( + IconButton::new( + format!("stop-subagent-{}-{}", entry_ix, context_ix), + IconName::Stop, + ) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .tooltip(Tooltip::text("Stop Subagent")) + .when_some( + thread_view + .as_ref() + .map(|view| view.read(cx).thread.clone()), + |this, thread| { + this.on_click(cx.listener( + move |_this, _event, _window, cx| { + thread.update(cx, |thread, _cx| { + thread.stop_by_user(); + }); + }, + )) + }, + ), + ) + }), + ) + }), ) - .when(is_expanded, |this| { - this.child( - self.render_subagent_expanded_content(entry_ix, context_ix, thread, window, cx), + .when_some(thread_view, |this, thread_view| { + let thread = &thread_view.read(cx).thread; + this.when(is_expanded, |this| { + this.child( + self.render_subagent_expanded_content( + entry_ix, context_ix, thread, window, cx, + ), + ) + }) + .children( + thread + .read(cx) + .first_tool_awaiting_confirmation() + .and_then(|tc| { + if let ToolCallStatus::WaitingForConfirmation { options, .. } = + &tc.status + { + Some(self.render_subagent_pending_tool_call( + entry_ix, + context_ix, + thread.clone(), + tc, + options, + window, + cx, + )) + } else { + None + } + }), ) }) - .children( - thread_read - .first_tool_awaiting_confirmation() - .and_then(|tc| { - if let ToolCallStatus::WaitingForConfirmation { options, .. } = &tc.status { - Some(self.render_subagent_pending_tool_call( - entry_ix, - context_ix, - thread.clone(), - tc, - options, - window, - cx, - )) - } else { - None - } - }), - ) .into_any_element() } @@ -6841,6 +6966,7 @@ impl AcpThreadView { } fn render_new_version_callout(&self, version: &SharedString, cx: &mut Context) -> Div { + let server_view = self.server_view.clone(); v_flex().w_full().justify_end().child( h_flex() .p_2() @@ -6865,11 +6991,11 @@ impl AcpThreadView { Button::new("update-button", format!("Update to v{}", version)) .label_size(LabelSize::Small) .style(ButtonStyle::Tinted(TintColor::Accent)) - .on_click(cx.listener(|this, _, window, cx| { - this.server_view + .on_click(move |_, window, cx| { + server_view .update(cx, |view, cx| view.reset(window, cx)) .ok(); - })), + }), ), ) } @@ -7028,8 +7154,20 @@ impl Render for AcpThreadView { v_flex() .key_context("AcpThread") + .track_focus(&self.focus_handle(cx)) .on_action(cx.listener(|this, _: &menu::Cancel, _, cx| { - this.cancel_generation(cx); + if this.parent_id.is_none() { + this.cancel_generation(cx); + } + })) + .on_action(cx.listener(|this, _: &workspace::GoBack, window, cx| { + if let Some(parent_session_id) = this.parent_id.clone() { + this.server_view + .update(cx, |view, cx| { + view.navigate_to_session(parent_session_id, window, cx); + }) + .ok(); + } })) .on_action(cx.listener(Self::keep_all)) .on_action(cx.listener(Self::reject_all)) @@ -7153,6 +7291,7 @@ impl Render for AcpThreadView { } })) .size_full() + .children(self.render_subagent_titlebar(cx)) .child(conversation) .children(self.render_activity_bar(window, cx)) .when(self.show_codex_windows_warning, |this| { diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 850822679d2828b96ba6218c4d48e570764d6de6..c5bdaaf91bc3cfc633e5ed9812ae9a1154b5e659 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1360,6 +1360,7 @@ impl AgentDiff { } AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated + | AcpThreadEvent::SubagentSpawned(_) | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::PromptCapabilitiesUpdated @@ -1761,7 +1762,7 @@ mod tests { .update(|cx| { connection .clone() - .new_thread(project.clone(), Path::new(path!("/test")), cx) + .new_session(project.clone(), Path::new(path!("/test")), cx) }) .await .unwrap(); @@ -1942,7 +1943,7 @@ mod tests { .update(|_, cx| { connection .clone() - .new_thread(project.clone(), Path::new(path!("/test")), cx) + .new_session(project.clone(), Path::new(path!("/test")), cx) }) .await .unwrap(); diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index bd9c31a983b723c222987544561cea82a97bad2b..ccfc0cd7073b08249a9bdc07cf3525f92e689e9a 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -157,7 +157,7 @@ pub fn init(cx: &mut App) { .and_then(|thread_view| { thread_view .read(cx) - .as_active_thread() + .active_thread() .map(|r| r.read(cx).thread.clone()) }); @@ -922,7 +922,7 @@ impl AgentPanel { return; }; - let Some(active_thread) = thread_view.read(cx).as_active_thread() else { + let Some(active_thread) = thread_view.read(cx).active_thread() else { return; }; @@ -1195,7 +1195,7 @@ impl AgentPanel { ) { if let Some(workspace) = self.workspace.upgrade() && let Some(thread_view) = self.active_thread_view() - && let Some(active_thread) = thread_view.read(cx).as_active_thread() + && let Some(active_thread) = thread_view.read(cx).active_thread() { active_thread.update(cx, |thread, cx| { thread @@ -1423,7 +1423,7 @@ impl AgentPanel { match &self.active_view { ActiveView::AgentThread { thread_view, .. } => thread_view .read(cx) - .as_active_thread() + .active_thread() .map(|r| r.read(cx).thread.clone()), _ => None, } @@ -1851,7 +1851,7 @@ impl AgentPanel { if let Some(title_editor) = thread_view .read(cx) - .as_active_thread() + .parent_thread(cx) .and_then(|r| r.read(cx).title_editor.clone()) { let container = div() diff --git a/crates/agent_ui_v2/src/agent_thread_pane.rs b/crates/agent_ui_v2/src/agent_thread_pane.rs index 8959d45721981aa9955cf79d9330ce38e9255ba4..c6ae3f0ca525b2df5810a8b11c65438428d05a3f 100644 --- a/crates/agent_ui_v2/src/agent_thread_pane.rs +++ b/crates/agent_ui_v2/src/agent_thread_pane.rs @@ -142,7 +142,7 @@ impl AgentThreadPane { fn title(&self, cx: &App) -> SharedString { if let Some(active_thread_view) = &self.thread_view { let thread_view = active_thread_view.view.read(cx); - if let Some(ready) = thread_view.as_active_thread() { + if let Some(ready) = thread_view.active_thread() { let title = ready.read(cx).thread.read(cx).title(); if !title.is_empty() { return title; diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 249b936e1bc14d332d19bd1a2d8f1b986068be3f..f8171177e9ba141451390aa65a583d5094c884d2 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -328,6 +328,9 @@ impl ExampleContext { "{}Bug: Tool confirmation should not be required in eval", log_prefix ), + ThreadEvent::SubagentSpawned(session) => { + println!("{log_prefix} Got subagent spawn: {session:?}"); + } ThreadEvent::Retry(status) => { println!("{log_prefix} Got retry: {status:?}"); } diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 41234bc02c6cd911b31637e6ab67346e4b9677c0..77f28f1d67b9a7c1029633776fa1a18c0270920f 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -323,7 +323,7 @@ impl ExampleInstance { }; thread.update(cx, |thread, cx| { - thread.add_default_tools(Rc::new(EvalThreadEnvironment { + thread.add_default_tools(None, Rc::new(EvalThreadEnvironment { project: project.clone(), }), cx); thread.set_profile(meta.profile_id.clone(), cx); @@ -679,6 +679,18 @@ impl agent::ThreadEnvironment for EvalThreadEnvironment { Ok(Rc::new(EvalTerminalHandle { terminal }) as Rc) }) } + + fn create_subagent( + &self, + _parent_thread: Entity, + _label: String, + _initial_prompt: String, + _timeout_ms: Option, + _allowed_tools: Option>, + _cx: &mut App, + ) -> Result> { + unimplemented!() + } } struct LanguageModelInterceptor { diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 2e4567b9bfdc794eb394f78e193b4345d93c6ac5..f167454e9ab4e2b6e329925c0f2ee4c9c29951eb 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -69,7 +69,6 @@ use { time::Duration, }, util::ResultExt as _, - watch, workspace::{AppState, Workspace}, zed_actions::OpenSettingsAt, }; @@ -465,26 +464,6 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()> } } - // Run Test 4: Subagent Cards visual tests - #[cfg(feature = "visual-tests")] - { - println!("\n--- Test 4: subagent_cards (running, completed, expanded) ---"); - match run_subagent_visual_tests(app_state.clone(), &mut cx, update_baseline) { - Ok(TestResult::Passed) => { - println!("✓ subagent_cards: PASSED"); - passed += 1; - } - Ok(TestResult::BaselineUpdated(_)) => { - println!("✓ subagent_cards: Baselines updated"); - updated += 1; - } - Err(e) => { - eprintln!("✗ subagent_cards: FAILED - {}", e); - failed += 1; - } - } - } - // Run Test 5: Breakpoint Hover visual tests println!("\n--- Test 5: breakpoint_hover (3 variants) ---"); match run_breakpoint_hover_visual_tests(app_state.clone(), &mut cx, update_baseline) { @@ -1927,337 +1906,6 @@ impl AgentServer for StubAgentServer { } } -#[cfg(all(target_os = "macos", feature = "visual-tests"))] -fn run_subagent_visual_tests( - app_state: Arc, - cx: &mut VisualTestAppContext, - update_baseline: bool, -) -> Result { - use acp_thread::{ - AcpThread, SUBAGENT_TOOL_NAME, ToolCallUpdateSubagentThread, meta_with_tool_name, - }; - use agent_ui::AgentPanel; - - // Create a temporary project directory - let temp_dir = tempfile::tempdir()?; - let temp_path = temp_dir.keep(); - let canonical_temp = temp_path.canonicalize()?; - let project_path = canonical_temp.join("project"); - std::fs::create_dir_all(&project_path)?; - - // Create a project - let project = cx.update(|cx| { - project::Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - project::LocalProjectFlags { - init_worktree_trust: false, - ..Default::default() - }, - cx, - ) - }); - - // Add the test directory as a worktree - let add_worktree_task = project.update(cx, |project, cx| { - project.find_or_create_worktree(&project_path, true, cx) - }); - - cx.foreground_executor - .block_test(add_worktree_task) - .log_err(); - - cx.run_until_parked(); - - // Create stub connection - we'll manually inject the subagent content - let connection = StubAgentConnection::new(); - - // Create a subagent tool call (in progress state) - let tool_call = acp::ToolCall::new("subagent-tool-1", "2 subagents") - .kind(acp::ToolKind::Other) - .meta(meta_with_tool_name(SUBAGENT_TOOL_NAME)) - .status(acp::ToolCallStatus::InProgress); - - connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]); - - let stub_agent: Rc = Rc::new(StubAgentServer::new(connection.clone())); - - // Create a window sized for the agent panel - let window_size = size(px(600.0), px(700.0)); - let bounds = Bounds { - origin: point(px(0.0), px(0.0)), - size: window_size, - }; - - let workspace_window: WindowHandle = cx - .update(|cx| { - cx.open_window( - WindowOptions { - window_bounds: Some(WindowBounds::Windowed(bounds)), - focus: false, - show: false, - ..Default::default() - }, - |window, cx| { - cx.new(|cx| { - Workspace::new(None, project.clone(), app_state.clone(), window, cx) - }) - }, - ) - }) - .context("Failed to open agent window")?; - - cx.run_until_parked(); - - // Load the AgentPanel - let (weak_workspace, async_window_cx) = workspace_window - .update(cx, |workspace, window, cx| { - (workspace.weak_handle(), window.to_async(cx)) - }) - .context("Failed to get workspace handle")?; - - let prompt_builder = - cx.update(|cx| prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx)); - let panel = cx - .foreground_executor - .block_test(AgentPanel::load( - weak_workspace, - prompt_builder, - async_window_cx, - )) - .context("Failed to load AgentPanel")?; - - cx.update_window(workspace_window.into(), |_, _window, cx| { - workspace_window - .update(cx, |workspace, window, cx| { - workspace.add_panel(panel.clone(), window, cx); - workspace.open_panel::(window, cx); - }) - .log_err(); - })?; - - cx.run_until_parked(); - - // Open the stub thread - cx.update_window(workspace_window.into(), |_, window, cx| { - panel.update(cx, |panel: &mut agent_ui::AgentPanel, cx| { - panel.open_external_thread_with_server(stub_agent.clone(), window, cx); - }); - })?; - - cx.run_until_parked(); - - // Get the thread view and send a message to trigger the subagent tool call - let thread_view = cx - .read(|cx| panel.read(cx).active_thread_view_for_tests().cloned()) - .ok_or_else(|| anyhow::anyhow!("No active thread view"))?; - - let thread = cx - .read(|cx| { - thread_view - .read(cx) - .as_active_thread() - .map(|active| active.read(cx).thread.clone()) - }) - .ok_or_else(|| anyhow::anyhow!("Thread not available"))?; - - // Send the message to trigger the subagent response - let send_future = thread.update(cx, |thread: &mut acp_thread::AcpThread, cx| { - thread.send(vec!["Run two subagents".into()], cx) - }); - - cx.foreground_executor.block_test(send_future).log_err(); - - cx.run_until_parked(); - - // Get the tool call ID - let tool_call_id = cx - .read(|cx| { - thread.read(cx).entries().iter().find_map(|entry| { - if let acp_thread::AgentThreadEntry::ToolCall(tool_call) = entry { - Some(tool_call.id.clone()) - } else { - None - } - }) - }) - .ok_or_else(|| anyhow::anyhow!("Expected a ToolCall entry in thread"))?; - - // Create two subagent AcpThreads and inject them - let subagent1 = cx.update(|cx| { - let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); - let session_id = acp::SessionId::new("subagent-1"); - cx.new(|cx| { - let mut thread = AcpThread::new( - "Exploring test-repo", - Rc::new(connection.clone()), - project.clone(), - action_log, - session_id, - watch::Receiver::constant(acp::PromptCapabilities::new()), - cx, - ); - // Add some content to this subagent - thread.push_assistant_content_block( - "## Summary of test-repo\n\nThis is a test repository with:\n\n- **Files:** test.txt\n- **Purpose:** Testing".into(), - false, - cx, - ); - thread - }) - }); - - let subagent2 = cx.update(|cx| { - let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); - let session_id = acp::SessionId::new("subagent-2"); - cx.new(|cx| { - let mut thread = AcpThread::new( - "Exploring test-worktree", - Rc::new(connection.clone()), - project.clone(), - action_log, - session_id, - watch::Receiver::constant(acp::PromptCapabilities::new()), - cx, - ); - // Add some content to this subagent - thread.push_assistant_content_block( - "## Summary of test-worktree\n\nThis directory contains:\n\n- A single `config.json` file\n- Basic project setup".into(), - false, - cx, - ); - thread - }) - }); - - // Inject subagent threads into the tool call - thread.update(cx, |thread: &mut acp_thread::AcpThread, cx| { - thread - .update_tool_call( - ToolCallUpdateSubagentThread { - id: tool_call_id.clone(), - thread: subagent1, - }, - cx, - ) - .log_err(); - thread - .update_tool_call( - ToolCallUpdateSubagentThread { - id: tool_call_id.clone(), - thread: subagent2, - }, - cx, - ) - .log_err(); - }); - - cx.run_until_parked(); - - cx.update_window(workspace_window.into(), |_, window, _cx| { - window.refresh(); - })?; - - cx.run_until_parked(); - - // Capture subagents in RUNNING state (tool call still in progress) - let running_result = run_visual_test( - "subagent_cards_running", - workspace_window.into(), - cx, - update_baseline, - )?; - - // Now mark the tool call as completed by updating it through the thread - thread.update(cx, |thread: &mut acp_thread::AcpThread, cx| { - thread - .handle_session_update( - acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate::new( - tool_call_id.clone(), - acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::Completed), - )), - cx, - ) - .log_err(); - }); - - cx.run_until_parked(); - - cx.update_window(workspace_window.into(), |_, window, _cx| { - window.refresh(); - })?; - - cx.run_until_parked(); - - // Capture subagents in COMPLETED state - let completed_result = run_visual_test( - "subagent_cards_completed", - workspace_window.into(), - cx, - update_baseline, - )?; - - // Expand the first subagent - thread_view.update(cx, |view: &mut agent_ui::acp::AcpServerView, cx| { - view.expand_subagent(acp::SessionId::new("subagent-1"), cx); - }); - - cx.run_until_parked(); - - cx.update_window(workspace_window.into(), |_, window, _cx| { - window.refresh(); - })?; - - cx.run_until_parked(); - - // Capture subagent in EXPANDED state - let expanded_result = run_visual_test( - "subagent_cards_expanded", - workspace_window.into(), - cx, - update_baseline, - )?; - - // Cleanup - workspace_window - .update(cx, |workspace, _window, cx| { - let project = workspace.project().clone(); - project.update(cx, |project, cx| { - let worktree_ids: Vec<_> = - project.worktrees(cx).map(|wt| wt.read(cx).id()).collect(); - for id in worktree_ids { - project.remove_worktree(id, cx); - } - }); - }) - .log_err(); - - cx.run_until_parked(); - - cx.update_window(workspace_window.into(), |_, window, _cx| { - window.remove_window(); - }) - .log_err(); - - cx.run_until_parked(); - - for _ in 0..15 { - cx.advance_clock(Duration::from_millis(100)); - cx.run_until_parked(); - } - - match (&running_result, &completed_result, &expanded_result) { - (TestResult::Passed, TestResult::Passed, TestResult::Passed) => Ok(TestResult::Passed), - (TestResult::BaselineUpdated(p), _, _) - | (_, TestResult::BaselineUpdated(p), _) - | (_, _, TestResult::BaselineUpdated(p)) => Ok(TestResult::BaselineUpdated(p.clone())), - } -} - #[cfg(all(target_os = "macos", feature = "visual-tests"))] fn run_agent_thread_view_test( app_state: Arc, @@ -2471,7 +2119,7 @@ fn run_agent_thread_view_test( .read(|cx| { thread_view .read(cx) - .as_active_thread() + .active_thread() .map(|active| active.read(cx).thread.clone()) }) .ok_or_else(|| anyhow::anyhow!("Thread not available"))?;