diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index c26f4f218ce2ae656fdd2dec85eb4389bb2f7c8d..e174365ea3afbc521d8fb62e4b9c2df192e453b6 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -4,6 +4,26 @@ mod mention; mod terminal; use agent_settings::AgentSettings; + +/// Key used in ACP ToolCall meta to store the tool's programmatic name. +/// This is a workaround since ACP's ToolCall doesn't have a dedicated name field. +pub const TOOL_NAME_META_KEY: &str = "tool_name"; + +/// The tool name for subagent spawning +pub const SUBAGENT_TOOL_NAME: &str = "subagent"; + +/// Helper to extract tool name from ACP meta +pub fn tool_name_from_meta(meta: &Option) -> Option { + meta.as_ref() + .and_then(|m| m.get(TOOL_NAME_META_KEY)) + .and_then(|v| v.as_str()) + .map(|s| SharedString::from(s.to_owned())) +} + +/// Helper to create meta with tool name +pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta { + acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())]) +} use collections::HashSet; pub use connection::*; pub use diff::*; @@ -195,6 +215,7 @@ pub struct ToolCall { pub raw_input: Option, pub raw_input_markdown: Option>, pub raw_output: Option, + pub tool_name: Option, } impl ToolCall { @@ -229,6 +250,8 @@ impl ToolCall { .as_ref() .and_then(|input| markdown_for_raw_output(input, &language_registry, cx)); + let tool_name = tool_name_from_meta(&tool_call.meta); + let result = Self { id: tool_call.tool_call_id, label: cx @@ -241,6 +264,7 @@ impl ToolCall { raw_input: tool_call.raw_input, raw_input_markdown, raw_output: tool_call.raw_output, + tool_name, }; Ok(result) } @@ -338,6 +362,7 @@ impl ToolCall { ToolCallContent::Diff(diff) => Some(diff), ToolCallContent::ContentBlock(_) => None, ToolCallContent::Terminal(_) => None, + ToolCallContent::SubagentThread(_) => None, }) } @@ -346,9 +371,26 @@ impl ToolCall { ToolCallContent::Terminal(terminal) => Some(terminal), ToolCallContent::ContentBlock(_) => None, ToolCallContent::Diff(_) => None, + ToolCallContent::SubagentThread(_) => None, }) } + pub fn subagent_thread(&self) -> Option<&Entity> { + self.content.iter().find_map(|content| match content { + ToolCallContent::SubagentThread(thread) => Some(thread), + _ => None, + }) + } + + pub fn is_subagent(&self) -> bool { + matches!(self.kind, acp::ToolKind::Other) + && self + .tool_name + .as_ref() + .map(|n| n.as_ref() == SUBAGENT_TOOL_NAME) + .unwrap_or(false) + } + fn to_markdown(&self, cx: &App) -> String { let mut markdown = format!( "**Tool Call: {}**\nStatus: {}\n\n", @@ -642,6 +684,7 @@ pub enum ToolCallContent { ContentBlock(ContentBlock), Diff(Entity), Terminal(Entity), + SubagentThread(Entity), } impl ToolCallContent { @@ -713,6 +756,7 @@ impl ToolCallContent { Self::ContentBlock(content) => content.to_markdown(cx).to_string(), Self::Diff(diff) => diff.read(cx).to_markdown(cx), Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx), + Self::SubagentThread(thread) => thread.read(cx).to_markdown(cx), } } @@ -722,6 +766,13 @@ impl ToolCallContent { _ => None, } } + + pub fn subagent_thread(&self) -> Option<&Entity> { + match self { + Self::SubagentThread(thread) => Some(thread), + _ => None, + } + } } #[derive(Debug, PartialEq)] @@ -729,6 +780,7 @@ pub enum ToolCallUpdate { UpdateFields(acp::ToolCallUpdate), UpdateDiff(ToolCallUpdateDiff), UpdateTerminal(ToolCallUpdateTerminal), + UpdateSubagentThread(ToolCallUpdateSubagentThread), } impl ToolCallUpdate { @@ -737,6 +789,7 @@ impl ToolCallUpdate { Self::UpdateFields(update) => &update.tool_call_id, Self::UpdateDiff(diff) => &diff.id, Self::UpdateTerminal(terminal) => &terminal.id, + Self::UpdateSubagentThread(subagent) => &subagent.id, } } } @@ -771,6 +824,18 @@ pub struct ToolCallUpdateTerminal { pub terminal: Entity, } +impl From for ToolCallUpdate { + fn from(subagent: ToolCallUpdateSubagentThread) -> Self { + Self::UpdateSubagentThread(subagent) + } +} + +#[derive(Debug, PartialEq)] +pub struct ToolCallUpdateSubagentThread { + pub id: acp::ToolCallId, + pub thread: Entity, +} + #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -1425,6 +1490,7 @@ impl AcpThread { raw_input: None, raw_input_markdown: None, raw_output: None, + tool_name: None, }; self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx); return Ok(()); @@ -1451,6 +1517,11 @@ impl AcpThread { call.content .push(ToolCallContent::Terminal(update.terminal)); } + ToolCallUpdate::UpdateSubagentThread(update) => { + call.content.clear(); + call.content + .push(ToolCallContent::SubagentThread(update.thread)); + } } cx.emit(AcpThreadEvent::EntryUpdated(ix)); diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 3ed5afb38c3e23ad38a4f0bf84cb6b1b4e63ac6b..a0bea1deb7372fc51b1e718cf6e51303e9c44239 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -7,6 +7,7 @@ use client::{Client, UserStore}; use cloud_llm_client::CompletionIntent; use collections::IndexMap; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; +use feature_flags::FeatureFlagAppExt as _; use fs::{FakeFs, Fs}; use futures::{ FutureExt as _, StreamExt, @@ -343,7 +344,9 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp "expected tool call update to include terminal content" ); - smol::Timer::after(Duration::from_millis(25)).await; + cx.background_executor + .timer(Duration::from_millis(25)) + .await; assert!( !handle.was_killed(), @@ -2989,10 +2992,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { acp::ToolCall::new("1", "Thinking") .kind(acp::ToolKind::Think) .raw_input(json!({})) - .meta(acp::Meta::from_iter([( - "tool_name".into(), - "thinking".into() - )])) + .meta(acp_thread::meta_with_tool_name("thinking")) ); let update = expect_tool_call_update_fields(&mut events).await; assert_eq!( @@ -3927,6 +3927,1163 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { handle }); + + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + Some(model), + cx, + ); + thread.add_default_tools(environment, cx); + thread + }); + + thread.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("subagent"), + "subagent tool should be present when feature flag is enabled" + ); + }); +} + +#[gpui::test] +async fn test_subagent_tool_is_absent_when_feature_flag_disabled(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(false, vec![]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { handle }); + + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + Some(model), + cx, + ); + thread.add_default_tools(environment, cx); + thread + }); + + thread.read_with(cx, |thread, _| { + assert!( + !thread.has_registered_tool("subagent"), + "subagent tool should not be present when feature flag is disabled" + ); + }); +} + +#[gpui::test] +async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + subagent.read_with(cx, |thread, _| { + assert!(thread.is_subagent()); + assert_eq!(thread.depth(), 1); + assert!(thread.model().is_some()); + }); +} + +#[gpui::test] +async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: MAX_SUBAGENT_DEPTH, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { handle }); + + let deep_subagent = cx.new(|cx| { + let mut thread = Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ); + thread.add_default_tools(environment, cx); + thread + }); + + deep_subagent.read_with(cx, |thread, _| { + assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH); + assert!( + !thread.has_registered_tool("subagent"), + "subagent tool should not be present at max depth" + ); + }); +} + +#[gpui::test] +async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize your work".to_string(), + context_low_prompt: "Context low, wrap up".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + let task_prompt = "Find all TODO comments in the codebase"; + subagent + .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx)) + .unwrap(); + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert_eq!(pending.len(), 1, "should have one pending completion"); + + let messages = &pending[0].messages; + let user_messages: Vec<_> = messages + .iter() + .filter(|m| m.role == language_model::Role::User) + .collect(); + assert_eq!(user_messages.len(), 1, "should have one user message"); + + let content = &user_messages[0].content[0]; + assert!( + content.to_str().unwrap().contains("TODO"), + "task prompt should be in user message" + ); +} + +#[gpui::test] +async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Please summarize what you found".to_string(), + context_low_prompt: "Context low, wrap up".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Do some work", cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("I did the work"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + subagent + .update(cx, |thread, cx| thread.request_final_summary(cx)) + .unwrap(); + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "should have pending completion for summary" + ); + + let messages = &pending.last().unwrap().messages; + let user_messages: Vec<_> = messages + .iter() + .filter(|m| m.role == language_model::Role::User) + .collect(); + + let last_user = user_messages.last().unwrap(); + assert!( + last_user.content[0].to_str().unwrap().contains("summarize"), + "summary prompt should be sent" + ); +} + +#[gpui::test] +async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let subagent = cx.new(|cx| { + let mut thread = Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ); + thread.add_tool(EchoTool); + thread.add_tool(DelayTool); + thread.add_tool(WordListTool); + thread + }); + + subagent.read_with(cx, |thread, _| { + assert!(thread.has_registered_tool("echo")); + assert!(thread.has_registered_tool("delay")); + assert!(thread.has_registered_tool("word_list")); + }); + + let allowed: collections::HashSet = + vec!["echo".into()].into_iter().collect(); + + subagent.update(cx, |thread, _cx| { + thread.restrict_tools(&allowed); + }); + + subagent.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("echo"), + "echo should still be available" + ); + assert!( + !thread.has_registered_tool("delay"), + "delay should be removed" + ); + assert!( + !thread.has_registered_tool("word_list"), + "word_list should be removed" + ); + }); +} + +#[gpui::test] +async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let parent = cx.new(|cx| { + Thread::new( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + parent.update(cx, |thread, _cx| { + thread.register_running_subagent(subagent.downgrade()); + }); + + subagent + .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) + .unwrap(); + cx.run_until_parked(); + + subagent.read_with(cx, |thread, _| { + assert!(!thread.is_turn_complete(), "subagent should be running"); + }); + + parent.update(cx, |thread, cx| { + thread.cancel(cx).detach(); + }); + + subagent.read_with(cx, |thread, _| { + assert!( + thread.is_turn_complete(), + "subagent should be cancelled when parent cancels" + ); + }); +} + +#[gpui::test] +async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + subagent + .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) + .unwrap(); + cx.run_until_parked(); + + subagent.read_with(cx, |thread, _| { + assert!(!thread.is_turn_complete(), "turn should be in progress"); + }); + + fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey { + provider: LanguageModelProviderName::from("Fake".to_string()), + }); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + subagent.read_with(cx, |thread, _| { + assert!( + thread.is_turn_complete(), + "turn should be complete after non-retryable error" + ); + }); +} + +#[gpui::test] +async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize your work".to_string(), + context_low_prompt: "Context low, stop and summarize".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + model.clone(), + subagent_context.clone(), + cx, + ) + }); + + subagent.update(cx, |thread, _| { + thread.add_tool(EchoTool); + }); + + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Do some work", cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Working on it..."); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx)); + assert!( + interrupt_result.is_ok(), + "interrupt_for_summary should succeed" + ); + + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "should have pending completion for interrupted summary" + ); + + let messages = &pending.last().unwrap().messages; + let user_messages: Vec<_> = messages + .iter() + .filter(|m| m.role == language_model::Role::User) + .collect(); + + let last_user = user_messages.last().unwrap(); + let content_str = last_user.content[0].to_str().unwrap(); + assert!( + content_str.contains("Context low") || content_str.contains("stop and summarize"), + "context_low_prompt should be sent when interrupting: got {:?}", + content_str + ); +} + +#[gpui::test] +async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + subagent + .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) + .unwrap(); + cx.run_until_parked(); + + let max_tokens = model.max_token_count(); + let high_usage = language_model::TokenUsage { + input_tokens: (max_tokens as f64 * 0.80) as u64, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage)); + fake_model.send_last_completion_stream_text_chunk("Working..."); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage()); + assert!(usage.is_some(), "should have token usage after completion"); + + let usage = usage.unwrap(); + let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32); + assert!( + remaining_ratio <= 0.25, + "remaining ratio should be at or below 25% (got {}%), indicating context is low", + remaining_ratio * 100.0 + ); +} + +#[gpui::test] +async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let parent = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ); + thread.add_tool(EchoTool); + thread + }); + + let parent_tool_names: Vec = vec!["echo".into()]; + + let tool = Arc::new(SubagentTool::new( + parent.downgrade(), + project, + project_context, + context_server_registry, + Templates::new(), + 0, + parent_tool_names, + )); + + let result = tool.validate_allowed_tools(&Some(vec!["nonexistent_tool".to_string()])); + assert!(result.is_err(), "should reject unknown tool"); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("nonexistent_tool"), + "error should mention the invalid tool name: {}", + err_msg + ); + assert!( + err_msg.contains("not available"), + "error should explain the tool is not available: {}", + err_msg + ); +} + +#[gpui::test] +async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let project = thread.read_with(cx, |t, _| t.project.clone()); + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + subagent + .update(cx, |thread, cx| thread.submit_user_message("Do work", cx)) + .unwrap(); + cx.run_until_parked(); + + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + subagent.read_with(cx, |thread, _| { + assert!( + thread.is_turn_complete(), + "turn should complete even with empty response" + ); + }); +} + +#[gpui::test] +async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let depth_1_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("root-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let depth_1_subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + model.clone(), + depth_1_context, + cx, + ) + }); + + depth_1_subagent.read_with(cx, |thread, _| { + assert_eq!(thread.depth(), 1); + assert!(thread.is_subagent()); + }); + + let depth_2_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"), + depth: 2, + summary_prompt: "Summarize depth 2".to_string(), + context_low_prompt: "Context low depth 2".to_string(), + }; + + let depth_2_subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + model.clone(), + depth_2_context, + cx, + ) + }); + + depth_2_subagent.read_with(cx, |thread, _| { + assert_eq!(thread.depth(), 2); + assert!(thread.is_subagent()); + }); + + depth_2_subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Nested task", cx) + }) + .unwrap(); + cx.run_until_parked(); + + let pending = model.as_fake().pending_completions(); + assert!( + !pending.is_empty(), + "depth-2 subagent should be able to submit messages" + ); +} + +#[gpui::test] +async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) { + init_test(cx); + always_allow_tools(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let fake_model = model.as_fake(); + + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"), + depth: 1, + summary_prompt: "Summarize what you did".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let subagent = cx.new(|cx| { + let mut thread = Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + Templates::new(), + model.clone(), + subagent_context, + cx, + ); + thread.add_tool(EchoTool); + thread + }); + + subagent.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("echo"), + "subagent should have echo tool" + ); + }); + + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Use the echo tool to echo 'hello world'", cx) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_call_1".into(), + name: EchoTool::name().into(), + raw_input: json!({"text": "hello world"}).to_string(), + input: json!({"text": "hello world"}), + is_input_complete: true, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "should have pending completion after tool use" + ); + + let last_completion = pending.last().unwrap(); + let has_tool_result = last_completion.messages.iter().any(|m| { + m.content + .iter() + .any(|c| matches!(c, MessageContent::ToolResult(_))) + }); + assert!( + has_tool_result, + "tool result should be in the messages sent back to the model" + ); +} + +#[gpui::test] +async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + + let parent = cx.new(|cx| { + Thread::new( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + let mut subagents = Vec::new(); + for i in 0..MAX_PARALLEL_SUBAGENTS { + let subagent_context = SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)), + depth: 1, + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + let subagent = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + model.clone(), + subagent_context, + cx, + ) + }); + + parent.update(cx, |thread, _cx| { + thread.register_running_subagent(subagent.downgrade()); + }); + subagents.push(subagent); + } + + parent.read_with(cx, |thread, _| { + assert_eq!( + thread.running_subagent_count(), + MAX_PARALLEL_SUBAGENTS, + "should have MAX_PARALLEL_SUBAGENTS registered" + ); + }); + + let parent_tool_names: Vec = vec![]; + + let tool = Arc::new(SubagentTool::new( + parent.downgrade(), + project.clone(), + project_context, + context_server_registry, + Templates::new(), + 0, + parent_tool_names, + )); + + let (event_stream, _rx) = crate::ToolCallEventStream::test(); + + let result = cx.update(|cx| { + tool.run( + SubagentToolInput { + label: "Test".to_string(), + task_prompt: "Do something".to_string(), + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + timeout_ms: None, + allowed_tools: None, + }, + event_stream, + cx, + ) + }); + + let err = result.await.unwrap_err(); + assert!( + err.to_string().contains("Maximum parallel subagents"), + "should reject when max parallel subagents reached: {}", + err + ); + + drop(subagents); +} + +#[gpui::test] +async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) { + init_test(cx); + always_allow_tools(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let fake_model = model.as_fake(); + + let parent = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ); + thread.add_tool(EchoTool); + thread + }); + + let parent_tool_names: Vec = vec!["echo".into()]; + + let tool = Arc::new(SubagentTool::new( + parent.downgrade(), + project.clone(), + project_context, + context_server_registry, + Templates::new(), + 0, + parent_tool_names, + )); + + let (event_stream, _rx) = crate::ToolCallEventStream::test(); + + let task = cx.update(|cx| { + tool.run( + SubagentToolInput { + label: "Research task".to_string(), + task_prompt: "Find all TODOs in the codebase".to_string(), + summary_prompt: "Summarize what you found".to_string(), + context_low_prompt: "Context low, wrap up".to_string(), + timeout_ms: None, + allowed_tools: None, + }, + event_stream, + cx, + ) + }); + + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "subagent should have started and sent a completion request" + ); + + let first_completion = &pending[0]; + let has_task_prompt = first_completion.messages.iter().any(|m| { + m.role == language_model::Role::User + && m.content + .iter() + .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false)) + }); + assert!(has_task_prompt, "task prompt should be sent to subagent"); + + fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase."); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "should have pending completion for summary request" + ); + + let last_completion = pending.last().unwrap(); + let has_summary_prompt = last_completion.messages.iter().any(|m| { + m.role == language_model::Role::User + && m.content.iter().any(|c| { + c.to_str() + .map(|s| s.contains("Summarize") || s.contains("summarize")) + .unwrap_or(false) + }) + }); + assert!( + has_summary_prompt, + "summary prompt should be sent after task completion" + ); + + fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files."); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let result = task.await; + assert!(result.is_ok(), "subagent tool should complete successfully"); + + let summary = result.unwrap(); + assert!( + summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"), + "summary should contain subagent's response: {}", + summary + ); +} + #[gpui::test] async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 9a579b12e4ca2b7d8724cad5b7987dcccb08b2ee..3cbda35c3d7102912f418f0cec269d8b4f45aefc 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -58,6 +58,27 @@ use uuid::Uuid; const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; pub const MAX_TOOL_NAME_LENGTH: usize = 64; +pub const MAX_SUBAGENT_DEPTH: u8 = 4; +pub const MAX_PARALLEL_SUBAGENTS: usize = 8; + +/// Context passed to a subagent thread for lifecycle management +#[derive(Clone)] +pub struct SubagentContext { + /// ID of the parent thread + pub parent_thread_id: acp::SessionId, + + /// ID of the tool call that spawned this subagent + pub tool_use_id: LanguageModelToolUseId, + + /// Current depth level (0 = root agent, 1 = first-level subagent, etc.) + pub depth: u8, + + /// Prompt to send when subagent completes successfully + pub summary_prompt: String, + + /// Prompt to send when context is running low (≤25% remaining) + pub context_low_prompt: String, +} /// The ID of the user prompt that initiated a request. /// @@ -626,6 +647,10 @@ pub struct Thread { pub(crate) file_read_times: HashMap, /// True if this thread was imported from a shared thread and can be synced. imported: bool, + /// If this is a subagent thread, contains context about the parent + subagent_context: Option, + /// Weak references to running subagent threads for cancellation propagation + running_subagents: Vec>, } impl Thread { @@ -683,6 +708,56 @@ impl Thread { action_log, file_read_times: HashMap::default(), imported: false, + subagent_context: None, + running_subagents: Vec::new(), + } + } + + pub fn new_subagent( + project: Entity, + project_context: Entity, + context_server_registry: Entity, + templates: Arc, + model: Arc, + subagent_context: SubagentContext, + cx: &mut Context, + ) -> Self { + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); + let action_log = cx.new(|_cx| ActionLog::new(project.clone())); + let (prompt_capabilities_tx, prompt_capabilities_rx) = + watch::channel(Self::prompt_capabilities(Some(model.as_ref()))); + Self { + id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()), + prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: None, + pending_title_generation: None, + pending_summary_generation: None, + summary: None, + messages: Vec::new(), + user_store: project.read(cx).user_store(), + completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: HashMap::default(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: Task::ready(None).shared(), + context_server_registry, + profile_id, + project_context, + templates, + model: Some(model), + summarization_model: None, + prompt_capabilities_tx, + prompt_capabilities_rx, + project, + action_log, + file_read_times: HashMap::default(), + imported: false, + subagent_context: Some(subagent_context), + running_subagents: Vec::new(), } } @@ -880,6 +955,8 @@ impl Thread { prompt_capabilities_rx, file_read_times: HashMap::default(), imported: db_thread.imported, + subagent_context: None, + running_subagents: Vec::new(), } } @@ -984,7 +1061,6 @@ impl Thread { cx.notify() } - #[cfg(any(test, feature = "test-support"))] pub fn last_message(&self) -> Option { if let Some(message) = self.pending_message.clone() { Some(Message::Agent(message)) @@ -1030,8 +1106,17 @@ impl Thread { self.add_tool(ThinkingTool); self.add_tool(WebSearchTool); - if cx.has_flag::() { - self.add_tool(SubagentTool::new()); + if cx.has_flag::() && self.depth() < MAX_SUBAGENT_DEPTH { + let tool_names = self.registered_tool_names(); + self.add_tool(SubagentTool::new( + cx.weak_entity(), + self.project.clone(), + self.project_context.clone(), + self.context_server_registry.clone(), + self.templates.clone(), + self.depth(), + tool_names, + )); } } @@ -1043,6 +1128,10 @@ impl Thread { self.tools.remove(name).is_some() } + pub fn restrict_tools(&mut self, allowed: &collections::HashSet) { + self.tools.retain(|name, _| allowed.contains(name)); + } + pub fn profile(&self) -> &AgentProfileId { &self.profile_id } @@ -1061,6 +1150,12 @@ impl Thread { } pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + for subagent in self.running_subagents.drain(..) { + if let Some(subagent) = subagent.upgrade() { + subagent.update(cx, |thread, cx| thread.cancel(cx)).detach(); + } + } + let Some(running_turn) = self.running_turn.take() else { self.flush_pending_message(cx); return Task::ready(()); @@ -2138,6 +2233,82 @@ impl Thread { .is_some_and(|turn| turn.tools.contains_key(name)) } + #[cfg(any(test, feature = "test-support"))] + pub fn has_registered_tool(&self, name: &str) -> bool { + self.tools.contains_key(name) + } + + pub fn registered_tool_names(&self) -> Vec { + self.tools.keys().cloned().collect() + } + + pub fn register_running_subagent(&mut self, subagent: WeakEntity) { + self.running_subagents.push(subagent); + } + + pub fn unregister_running_subagent(&mut self, subagent: &WeakEntity) { + self.running_subagents + .retain(|s| s.entity_id() != subagent.entity_id()); + } + + pub fn running_subagent_count(&self) -> usize { + self.running_subagents + .iter() + .filter(|s| s.upgrade().is_some()) + .count() + } + + pub fn is_subagent(&self) -> bool { + self.subagent_context.is_some() + } + + pub fn depth(&self) -> u8 { + self.subagent_context.as_ref().map(|c| c.depth).unwrap_or(0) + } + + pub fn is_turn_complete(&self) -> bool { + self.running_turn.is_none() + } + + pub fn submit_user_message( + &mut self, + content: impl Into, + cx: &mut Context, + ) -> Result>> { + let content = content.into(); + self.messages.push(Message::User(UserMessage { + id: UserMessageId::new(), + content: vec![UserMessageContent::Text(content)], + })); + cx.notify(); + self.send_existing(cx) + } + + pub fn interrupt_for_summary( + &mut self, + cx: &mut Context, + ) -> Result>> { + let context = self + .subagent_context + .as_ref() + .context("Not a subagent thread")?; + let prompt = context.context_low_prompt.clone(); + self.cancel(cx).detach(); + self.submit_user_message(prompt, cx) + } + + pub fn request_final_summary( + &mut self, + cx: &mut Context, + ) -> Result>> { + let context = self + .subagent_context + .as_ref() + .context("Not a subagent thread")?; + let prompt = context.summary_prompt.clone(); + self.submit_user_message(prompt, cx) + } + fn build_request_messages( &self, available_tools: Vec, @@ -2546,10 +2717,7 @@ impl ThreadEventStream { acp::ToolCall::new(id.to_string(), title) .kind(kind) .raw_input(input) - .meta(acp::Meta::from_iter([( - "tool_name".into(), - tool_name.into(), - )])) + .meta(acp_thread::meta_with_tool_name(tool_name)) } fn update_tool_call_fields( @@ -2645,6 +2813,10 @@ impl ToolCallEventStream { *self.cancellation_rx.clone().borrow() } + pub fn tool_use_id(&self) -> &LanguageModelToolUseId { + &self.tool_use_id + } + pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) { self.stream .update_tool_call_fields(&self.tool_use_id, fields); @@ -2663,6 +2835,19 @@ impl ToolCallEventStream { .ok(); } + pub fn update_subagent_thread(&self, thread: Entity) { + self.stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( + acp_thread::ToolCallUpdateSubagentThread { + id: acp::ToolCallId::new(self.tool_use_id.to_string()), + thread, + } + .into(), + ))) + .ok(); + } + pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { return Task::ready(Ok(())); diff --git a/crates/agent/src/tools/subagent_tool.rs b/crates/agent/src/tools/subagent_tool.rs index 376bce844fbf67ab79bff4d7611cc7710defda13..e8d650f7d0d6507b976262cb8b0b8973a25658a2 100644 --- a/crates/agent/src/tools/subagent_tool.rs +++ b/crates/agent/src/tools/subagent_tool.rs @@ -1,11 +1,31 @@ +use acp_thread::{AcpThread, AgentConnection, UserMessageId}; +use action_log::ActionLog; use agent_client_protocol as acp; -use anyhow::Result; -use gpui::{App, SharedString, Task}; +use anyhow::{Result, anyhow}; +use collections::HashSet; +use futures::channel::mpsc; +use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity}; +use project::Project; +use prompt_store::ProjectContext; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::any::Any; +use std::path::Path; +use std::rc::Rc; use std::sync::Arc; +use std::time::Duration; +use util::ResultExt; +use watch; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{ + AgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext, + Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream, +}; + +/// When a subagent's remaining context window falls below this fraction (25%), +/// the "context running out" prompt is sent to encourage the subagent to wrap up. +const CONTEXT_LOW_THRESHOLD: f32 = 0.25; /// Spawns a subagent with its own context window to perform a delegated task. /// @@ -63,11 +83,50 @@ pub struct SubagentToolInput { pub allowed_tools: Option>, } -pub struct SubagentTool; +pub struct SubagentTool { + parent_thread: WeakEntity, + project: Entity, + project_context: Entity, + context_server_registry: Entity, + templates: Arc, + current_depth: u8, + parent_tool_names: HashSet, +} impl SubagentTool { - pub fn new() -> Self { - Self + pub fn new( + parent_thread: WeakEntity, + project: Entity, + project_context: Entity, + context_server_registry: Entity, + templates: Arc, + current_depth: u8, + parent_tool_names: Vec, + ) -> Self { + Self { + parent_thread, + project, + project_context, + context_server_registry, + templates, + current_depth, + parent_tool_names: parent_tool_names.into_iter().collect(), + } + } + + pub fn validate_allowed_tools(&self, allowed_tools: &Option>) -> Result<()> { + if let Some(tools) = allowed_tools { + for tool in tools { + if !self.parent_tool_names.contains(tool.as_str()) { + return Err(anyhow!( + "Tool '{}' is not available to the parent agent. Available tools: {:?}", + tool, + self.parent_tool_names.iter().collect::>() + )); + } + } + } + Ok(()) } } @@ -76,7 +135,7 @@ impl AgentTool for SubagentTool { type Output = String; fn name() -> &'static str { - "subagent" + acp_thread::SUBAGENT_TOOL_NAME } fn kind() -> acp::ToolKind { @@ -88,22 +147,405 @@ impl AgentTool for SubagentTool { input: Result, _cx: &mut App, ) -> SharedString { - match input { - Ok(input) => format!("Subagent: {}", input.label).into(), - Err(_) => "Subagent".into(), - } + input + .map(|i| i.label.into()) + .unwrap_or_else(|_| "Subagent".into()) } fn run( self: Arc, input: Self::Input, event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - event_stream.update_fields( - acp::ToolCallUpdateFields::new() - .content(vec![format!("Starting subagent: {}", input.label).into()]), + if self.current_depth >= MAX_SUBAGENT_DEPTH { + return Task::ready(Err(anyhow!( + "Maximum subagent depth ({}) reached", + MAX_SUBAGENT_DEPTH + ))); + } + + if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) { + return Task::ready(Err(e)); + } + + let Some(parent_thread) = self.parent_thread.upgrade() else { + return Task::ready(Err(anyhow!( + "Parent thread no longer exists (subagent depth={})", + self.current_depth + 1 + ))); + }; + + let running_count = parent_thread.read(cx).running_subagent_count(); + if running_count >= MAX_PARALLEL_SUBAGENTS { + return Task::ready(Err(anyhow!( + "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.", + MAX_PARALLEL_SUBAGENTS + ))); + } + + let parent_thread_id = parent_thread.read(cx).id().clone(); + let parent_model = parent_thread.read(cx).model().cloned(); + let tool_use_id = event_stream.tool_use_id().clone(); + + let Some(model) = parent_model else { + return Task::ready(Err(anyhow!("No model configured"))); + }; + + let subagent_context = SubagentContext { + parent_thread_id, + tool_use_id, + depth: self.current_depth + 1, + summary_prompt: input.summary_prompt.clone(), + context_low_prompt: input.context_low_prompt.clone(), + }; + + let project = self.project.clone(); + let project_context = self.project_context.clone(); + let context_server_registry = self.context_server_registry.clone(); + let templates = self.templates.clone(); + let task_prompt = input.task_prompt; + let timeout_ms = input.timeout_ms; + let allowed_tools: Option> = input + .allowed_tools + .map(|tools| tools.into_iter().map(SharedString::from).collect()); + + let parent_thread = self.parent_thread.clone(); + + cx.spawn(async move |cx| { + let subagent_thread: Entity = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + templates.clone(), + model, + subagent_context, + cx, + ) + }); + + let subagent_weak = subagent_thread.downgrade(); + + let acp_thread: Entity = cx.new(|cx| { + let session_id = subagent_thread.read(cx).id().clone(); + let action_log: Entity = cx.new(|_| ActionLog::new(project.clone())); + let connection: Rc = Rc::new(SubagentDisplayConnection); + AcpThread::new( + "Subagent", + connection, + project.clone(), + action_log, + session_id, + watch::Receiver::constant(acp::PromptCapabilities::new()), + cx, + ) + }); + + event_stream.update_subagent_thread(acp_thread.clone()); + + if let Some(parent) = parent_thread.upgrade() { + parent.update(cx, |thread, _cx| { + thread.register_running_subagent(subagent_weak.clone()); + }); + } + + let result = run_subagent( + &subagent_thread, + &acp_thread, + allowed_tools, + task_prompt, + timeout_ms, + cx, + ) + .await; + + if let Some(parent) = parent_thread.upgrade() { + let _ = parent.update(cx, |thread, _cx| { + thread.unregister_running_subagent(&subagent_weak); + }); + } + + result + }) + } +} + +async fn run_subagent( + subagent_thread: &Entity, + acp_thread: &Entity, + allowed_tools: Option>, + task_prompt: String, + timeout_ms: Option, + cx: &mut AsyncApp, +) -> Result { + if let Some(ref allowed) = allowed_tools { + subagent_thread.update(cx, |thread, _cx| { + thread.restrict_tools(allowed); + }); + } + + let mut events_rx = + subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?; + + let acp_thread_weak = acp_thread.downgrade(); + + let timed_out = if let Some(timeout) = timeout_ms { + forward_events_with_timeout( + &mut events_rx, + &acp_thread_weak, + Duration::from_millis(timeout), + cx, + ) + .await + } else { + forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await; + false + }; + + let should_interrupt = + timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx); + + if should_interrupt { + let mut summary_rx = + subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?; + forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await; + } else { + let mut summary_rx = + subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?; + forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await; + } + + Ok(extract_last_message(subagent_thread, cx)) +} + +async fn forward_events_until_stop( + events_rx: &mut mpsc::UnboundedReceiver>, + acp_thread: &WeakEntity, + cx: &mut AsyncApp, +) { + while let Some(event) = events_rx.next().await { + match event { + Ok(ThreadEvent::Stop(_)) => break, + Ok(event) => { + forward_event_to_acp_thread(event, acp_thread, cx); + } + Err(_) => break, + } + } +} + +async fn forward_events_with_timeout( + events_rx: &mut mpsc::UnboundedReceiver>, + acp_thread: &WeakEntity, + timeout: Duration, + cx: &mut AsyncApp, +) -> bool { + use futures::future::{self, Either}; + + let deadline = std::time::Instant::now() + timeout; + + loop { + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + return true; + } + + let timeout_future = cx.background_executor().timer(remaining); + let event_future = events_rx.next(); + + match future::select(event_future, timeout_future).await { + Either::Left((event, _)) => match event { + Some(Ok(ThreadEvent::Stop(_))) => return false, + Some(Ok(event)) => { + forward_event_to_acp_thread(event, acp_thread, cx); + } + Some(Err(_)) => return false, + None => return false, + }, + Either::Right((_, _)) => return true, + } + } +} + +fn forward_event_to_acp_thread( + event: ThreadEvent, + acp_thread: &WeakEntity, + cx: &mut AsyncApp, +) { + match event { + ThreadEvent::UserMessage(message) => { + acp_thread + .update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + }) + .log_err(); + } + ThreadEvent::AgentText(text) => { + acp_thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(text.into(), false, cx) + }) + .log_err(); + } + ThreadEvent::AgentThinking(text) => { + acp_thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block(text.into(), true, cx) + }) + .log_err(); + } + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let outcome_task = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, options, true, cx) + }); + if let Ok(Ok(task)) = outcome_task { + cx.background_spawn(async move { + if let acp::RequestPermissionOutcome::Selected( + acp::SelectedPermissionOutcome { option_id, .. }, + ) = task.await + { + response.send(option_id).ok(); + } + }) + .detach(); + } + } + ThreadEvent::ToolCall(tool_call) => { + acp_thread + .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx)) + .log_err(); + } + ThreadEvent::ToolCallUpdate(update) => { + acp_thread + .update(cx, |thread, cx| thread.update_tool_call(update, cx)) + .log_err(); + } + ThreadEvent::Retry(status) => { + acp_thread + .update(cx, |thread, cx| thread.update_retry_status(status, cx)) + .log_err(); + } + ThreadEvent::Stop(_) => {} + } +} + +fn check_context_low(thread: &Entity, threshold: f32, cx: &mut AsyncApp) -> bool { + thread.read_with(cx, |thread, _| { + if let Some(usage) = thread.latest_token_usage() { + let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32); + remaining_ratio <= threshold + } else { + false + } + }) +} + +fn extract_last_message(thread: &Entity, cx: &mut AsyncApp) -> String { + thread.read_with(cx, |thread, _| { + thread + .last_message() + .map(|m| m.to_markdown()) + .unwrap_or_else(|| "No response from subagent".to_string()) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use language_model::LanguageModelToolSchemaFormat; + + #[test] + fn test_subagent_tool_input_json_schema_is_valid() { + let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema); + let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON"); + + assert!( + schema_json.get("properties").is_some(), + "schema should have properties" + ); + let properties = schema_json.get("properties").unwrap(); + + assert!(properties.get("label").is_some(), "should have label field"); + assert!( + properties.get("task_prompt").is_some(), + "should have task_prompt field" + ); + assert!( + properties.get("summary_prompt").is_some(), + "should have summary_prompt field" + ); + assert!( + properties.get("context_low_prompt").is_some(), + "should have context_low_prompt field" ); - Task::ready(Ok("Subagent tool not yet implemented.".to_string())) + assert!( + properties.get("timeout_ms").is_some(), + "should have timeout_ms field" + ); + assert!( + properties.get("allowed_tools").is_some(), + "should have allowed_tools field" + ); + } + + #[test] + fn test_subagent_tool_name() { + assert_eq!(SubagentTool::name(), "subagent"); + } + + #[test] + fn test_subagent_tool_kind() { + assert_eq!(SubagentTool::kind(), acp::ToolKind::Other); + } +} + +struct SubagentDisplayConnection; + +impl AgentConnection for SubagentDisplayConnection { + fn telemetry_id(&self) -> SharedString { + "subagent".into() + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + _project: Entity, + _cwd: &Path, + _cx: &mut App, + ) -> Task>> { + unimplemented!("SubagentDisplayConnection does not support new_thread") + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task> { + unimplemented!("SubagentDisplayConnection does not support authenticate") + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + unimplemented!("SubagentDisplayConnection does not support prompt") + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} + + fn into_any(self: Rc) -> Rc { + self } } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 9a9c912408370c469bf6ef0362891cafdf522be5..d39c453df06783dd45cfd56c6c0bc980c0f0d605 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -3309,6 +3309,12 @@ impl AcpThreadView { ToolCallContent::Terminal(terminal) => { self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } + ToolCallContent::SubagentThread(_thread) => { + // The subagent's AcpThread entity stores the subagent's conversation + // (messages, tool calls, etc.) but we don't render it here. The entity + // is used for serialization (e.g., to_markdown) and data storage, not display. + Empty.into_any_element() + } } } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 32c649e98c3abdda091fdb895a087eda685c41cc..249b936e1bc14d332d19bd1a2d8f1b986068be3f 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -255,7 +255,7 @@ impl ExampleContext { ThreadEvent::ToolCall(tool_call) => { let meta = tool_call.meta.expect("Missing meta field in tool_call"); let tool_name = meta - .get("tool_name") + .get(acp_thread::TOOL_NAME_META_KEY) .expect("Missing tool_name field in meta") .as_str() .expect("Unknown tool_name content in meta");