diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 7f0e1b89af6dd95d8b1bc219a55e17e5bceb0e8b..840c443b11a679cae349a33ffe0a2323445dceea 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -964,6 +964,9 @@ pub struct AcpThread { terminals: HashMap>, pending_terminal_output: HashMap>>, pending_terminal_exit: HashMap, + // subagent cancellation fields + user_stopped: Arc, + user_stop_tx: watch::Sender, } impl From<&AcpThread> for ActionLogTelemetry { @@ -1179,6 +1182,8 @@ impl AcpThread { } }); + let (user_stop_tx, _user_stop_rx) = watch::channel(false); + Self { action_log, shared_buffers: Default::default(), @@ -1195,6 +1200,8 @@ impl AcpThread { terminals: HashMap::default(), pending_terminal_output: HashMap::default(), pending_terminal_exit: HashMap::default(), + user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), + user_stop_tx, } } @@ -1202,6 +1209,21 @@ impl AcpThread { self.prompt_capabilities.clone() } + /// Marks this thread as stopped by user action and signals any listeners. + pub fn stop_by_user(&mut self) { + self.user_stopped + .store(true, std::sync::atomic::Ordering::SeqCst); + self.user_stop_tx.send(true).ok(); + } + + pub fn was_stopped_by_user(&self) -> bool { + self.user_stopped.load(std::sync::atomic::Ordering::SeqCst) + } + + pub fn user_stop_receiver(&self) -> watch::Receiver { + self.user_stop_tx.receiver() + } + pub fn connection(&self) -> &Rc { &self.connection } diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 1fe3df375fa1bdecb906f1b963e71a3f0cecfd56..2cd807b1cdc6c1715a17e28df51f26b0b78d3f60 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -4314,14 +4314,12 @@ async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( SubagentToolInput { - subagents: vec![crate::SubagentConfig { - label: "Long running task".to_string(), - task_prompt: "Do a very long task that takes forever".to_string(), - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - timeout_ms: None, - allowed_tools: None, - }], + label: "Long running task".to_string(), + task_prompt: "Do a very long task that takes forever".to_string(), + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + timeout_ms: None, + allowed_tools: None, }, event_stream.clone(), cx, @@ -4608,15 +4606,8 @@ async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) { parent_tools, )); - let subagent_configs = vec![crate::SubagentConfig { - label: "Test".to_string(), - task_prompt: "Do something".to_string(), - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - timeout_ms: None, - allowed_tools: Some(vec!["nonexistent_tool".to_string()]), - }]; - let result = tool.validate_subagents(&subagent_configs); + let allowed_tools = Some(vec!["nonexistent_tool".to_string()]); + let result = tool.validate_allowed_tools(&allowed_tools); assert!(result.is_err(), "should reject unknown tool"); let err_msg = result.unwrap_err().to_string(); assert!( @@ -4938,14 +4929,12 @@ async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) { let result = cx.update(|cx| { tool.run( SubagentToolInput { - subagents: vec![crate::SubagentConfig { - label: "Test".to_string(), - task_prompt: "Do something".to_string(), - summary_prompt: "Summarize".to_string(), - context_low_prompt: "Context low".to_string(), - timeout_ms: None, - allowed_tools: None, - }], + label: "Test".to_string(), + task_prompt: "Do something".to_string(), + summary_prompt: "Summarize".to_string(), + context_low_prompt: "Context low".to_string(), + timeout_ms: None, + allowed_tools: None, }, event_stream, cx, @@ -5016,14 +5005,12 @@ async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( SubagentToolInput { - subagents: vec![crate::SubagentConfig { - label: "Research task".to_string(), - task_prompt: "Find all TODOs in the codebase".to_string(), - summary_prompt: "Summarize what you found".to_string(), - context_low_prompt: "Context low, wrap up".to_string(), - timeout_ms: None, - allowed_tools: None, - }], + label: "Research task".to_string(), + task_prompt: "Find all TODOs in the codebase".to_string(), + summary_prompt: "Summarize what you found".to_string(), + context_low_prompt: "Context low, wrap up".to_string(), + timeout_ms: None, + allowed_tools: None, }, event_stream, cx, diff --git a/crates/agent/src/tools/subagent_tool.rs b/crates/agent/src/tools/subagent_tool.rs index fd4174e463810071904f18fd7206bf2f2e9cff7c..8176a2bca2b2c4dd109835f5c53ef53f6b16c95d 100644 --- a/crates/agent/src/tools/subagent_tool.rs +++ b/crates/agent/src/tools/subagent_tool.rs @@ -28,15 +28,14 @@ use crate::{ /// the "context running out" prompt is sent to encourage the subagent to wrap up. const CONTEXT_LOW_THRESHOLD: f32 = 0.25; -/// Spawns one or more subagents with their own context windows to perform delegated tasks. -/// Multiple subagents run in parallel. +/// Spawns a subagent with its own context window to perform a delegated task. /// /// Use this tool when you want to do any of the following: /// - Perform an investigation where all you need to know is the outcome, not the research that led to that outcome. /// - Complete a self-contained task where you need to know if it succeeded or failed (and how), but none of its intermediate output. /// - Run multiple tasks in parallel that would take significantly longer to run sequentially. /// -/// You control what each subagent does by providing: +/// You control what the subagent does by providing: /// 1. A task prompt describing what the subagent should do /// 2. A summary prompt that tells the subagent how to summarize its work when done /// 3. A "context running out" prompt for when the subagent is low on tokens @@ -50,17 +49,8 @@ const CONTEXT_LOW_THRESHOLD: f32 = 0.25; /// - If spawning multiple subagents that might write to the filesystem, provide /// guidance on how to avoid conflicts (e.g. assign each to different directories) /// - Instruct subagents to be concise in their summaries to conserve your context -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct SubagentToolInput { - /// The list of subagents to spawn. At least one is required. - /// All subagents run in parallel and their results are collected. - #[schemars(length(min = 1, max = 8))] - pub subagents: Vec, -} - -/// Configuration for a single subagent. #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] -pub struct SubagentConfig { +pub struct SubagentToolInput { /// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives") pub label: String, @@ -94,7 +84,7 @@ pub struct SubagentConfig { pub allowed_tools: Option>, } -/// Tool that spawns subagent threads to work on tasks in parallel. +/// Tool that spawns a subagent thread to work on a task. pub struct SubagentTool { parent_thread: WeakEntity, project: Entity, @@ -129,37 +119,21 @@ impl SubagentTool { } } - pub fn validate_subagents(&self, subagents: &[SubagentConfig]) -> Result<()> { - if subagents.is_empty() { - return Err(anyhow!("At least one subagent configuration is required")); - } - - if subagents.len() > MAX_PARALLEL_SUBAGENTS { - return Err(anyhow!( - "Maximum {} subagents can be spawned at once, but {} were requested", - MAX_PARALLEL_SUBAGENTS, - subagents.len() - )); - } + pub fn validate_allowed_tools(&self, allowed_tools: &Option>) -> Result<()> { + let Some(tools) = allowed_tools else { + return Ok(()); + }; - // Collect all invalid tools across all subagents - let mut all_invalid_tools: Vec = Vec::new(); - for config in subagents { - if let Some(ref tools) = config.allowed_tools { - for tool in tools { - if !self.parent_tools.contains_key(tool.as_str()) - && !all_invalid_tools.contains(tool) - { - all_invalid_tools.push(tool.clone()); - } - } - } - } + let invalid_tools: Vec<&str> = tools + .iter() + .filter(|tool| !self.parent_tools.contains_key(tool.as_str())) + .map(|s| s.as_str()) + .collect(); - if !all_invalid_tools.is_empty() { + if !invalid_tools.is_empty() { return Err(anyhow!( "The following tools do not exist: {}", - all_invalid_tools + invalid_tools .iter() .map(|t| format!("'{}'", t)) .collect::>() @@ -189,14 +163,8 @@ impl AgentTool for SubagentTool { _cx: &mut App, ) -> SharedString { input - .map(|i| { - if i.subagents.len() == 1 { - i.subagents[0].label.clone().into() - } else { - format!("{} subagents", i.subagents.len()).into() - } - }) - .unwrap_or_else(|_| "Subagents".into()) + .map(|i| i.label.into()) + .unwrap_or_else(|_| "Subagent".into()) } fn run( @@ -212,7 +180,7 @@ impl AgentTool for SubagentTool { ))); } - if let Err(e) = self.validate_subagents(&input.subagents) { + if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) { return Task::ready(Err(e)); } @@ -224,23 +192,13 @@ impl AgentTool for SubagentTool { }; let running_count = parent_thread.read(cx).running_subagent_count(); - let available_slots = MAX_PARALLEL_SUBAGENTS.saturating_sub(running_count); - if available_slots == 0 { + if running_count >= MAX_PARALLEL_SUBAGENTS { return Task::ready(Err(anyhow!( "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.", MAX_PARALLEL_SUBAGENTS ))); } - if input.subagents.len() > available_slots { - return Task::ready(Err(anyhow!( - "Cannot spawn {} subagents: only {} slots available (max {} parallel)", - input.subagents.len(), - available_slots, - MAX_PARALLEL_SUBAGENTS - ))); - } - let parent_model = parent_thread.read(cx).model().cloned(); let Some(model) = parent_model else { return Task::ready(Err(anyhow!("No model configured"))); @@ -255,151 +213,113 @@ impl AgentTool for SubagentTool { let current_depth = self.current_depth; let parent_thread_weak = self.parent_thread.clone(); - // Spawn all subagents in parallel - let subagent_configs = input.subagents; - cx.spawn(async move |cx| { - // Create all subagent threads upfront so we can track them for cancellation - let mut subagent_data: Vec<( - String, // label - Entity, // subagent thread - Entity, // acp thread for display - String, // task prompt - Option, // timeout - )> = Vec::new(); - - for config in subagent_configs { - let subagent_context = SubagentContext { - parent_thread_id: parent_thread_id.clone(), - tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()), - depth: current_depth + 1, - summary_prompt: config.summary_prompt.clone(), - context_low_prompt: config.context_low_prompt.clone(), + let subagent_context = SubagentContext { + parent_thread_id: parent_thread_id.clone(), + tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()), + depth: current_depth + 1, + summary_prompt: input.summary_prompt.clone(), + context_low_prompt: input.context_low_prompt.clone(), + }; + + // Determine which tools this subagent gets + let subagent_tools: BTreeMap> = + if let Some(ref allowed) = input.allowed_tools { + let allowed_set: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect(); + parent_tools + .iter() + .filter(|(name, _)| allowed_set.contains(name.as_ref())) + .map(|(name, tool)| (name.clone(), tool.clone())) + .collect() + } else { + parent_tools.clone() }; - // Determine which tools this subagent gets - let subagent_tools: BTreeMap> = - if let Some(ref allowed) = config.allowed_tools { - let allowed_set: HashSet<&str> = - allowed.iter().map(|s| s.as_str()).collect(); - parent_tools - .iter() - .filter(|(name, _)| allowed_set.contains(name.as_ref())) - .map(|(name, tool)| (name.clone(), tool.clone())) - .collect() - } else { - parent_tools.clone() - }; - - let label = config.label.clone(); - let task_prompt = config.task_prompt.clone(); - let timeout_ms = config.timeout_ms; - - let subagent_thread: Entity = cx.new(|cx| { - Thread::new_subagent( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - templates.clone(), - model.clone(), - subagent_context, - subagent_tools, - cx, - ) - }); + let subagent_thread: Entity = cx.new(|cx| { + Thread::new_subagent( + project.clone(), + project_context.clone(), + context_server_registry.clone(), + templates.clone(), + model.clone(), + subagent_context, + subagent_tools, + cx, + ) + }); - let subagent_weak = subagent_thread.downgrade(); - - let acp_thread: Entity = cx.new(|cx| { - let session_id = subagent_thread.read(cx).id().clone(); - let action_log: Entity = cx.new(|_| ActionLog::new(project.clone())); - let connection: Rc = Rc::new(SubagentDisplayConnection); - AcpThread::new( - &label, - connection, - project.clone(), - action_log, - session_id, - watch::Receiver::constant(acp::PromptCapabilities::new()), - cx, - ) - }); + let subagent_weak = subagent_thread.downgrade(); + + let acp_thread: Entity = cx.new(|cx| { + let session_id = subagent_thread.read(cx).id().clone(); + let action_log: Entity = cx.new(|_| ActionLog::new(project.clone())); + let connection: Rc = Rc::new(SubagentDisplayConnection); + AcpThread::new( + &input.label, + connection, + project.clone(), + action_log, + session_id, + watch::Receiver::constant(acp::PromptCapabilities::new()), + cx, + ) + }); - event_stream.update_subagent_thread(acp_thread.clone()); + event_stream.update_subagent_thread(acp_thread.clone()); - if let Some(parent) = parent_thread_weak.upgrade() { - parent.update(cx, |thread, _cx| { - thread.register_running_subagent(subagent_weak.clone()); - }); - } + let mut user_stop_rx: watch::Receiver = + acp_thread.update(cx, |thread, _| thread.user_stop_receiver()); - subagent_data.push((label, subagent_thread, acp_thread, task_prompt, timeout_ms)); + if let Some(parent) = parent_thread_weak.upgrade() { + parent.update(cx, |thread, _cx| { + thread.register_running_subagent(subagent_weak.clone()); + }); } - // Collect weak refs for cancellation cleanup - let subagent_threads: Vec> = subagent_data - .iter() - .map(|(_, thread, _, _, _)| thread.downgrade()) - .collect(); - - // Spawn tasks for each subagent - let tasks: Vec<_> = subagent_data - .into_iter() - .map( - |(label, subagent_thread, acp_thread, task_prompt, timeout_ms)| { - let parent_thread_weak = parent_thread_weak.clone(); - cx.spawn(async move |cx| { - let subagent_weak = subagent_thread.downgrade(); - - let result = run_subagent( - &subagent_thread, - &acp_thread, - task_prompt, - timeout_ms, - cx, - ) - .await; - - if let Some(parent) = parent_thread_weak.upgrade() { - let _ = parent.update(cx, |thread, _cx| { - thread.unregister_running_subagent(&subagent_weak); - }); - } - - (label, result) - }) - }, - ) - .collect(); - - // Wait for all subagents to complete, or cancellation - let results: Vec<(String, Result)> = futures::select! { - results = futures::future::join_all(tasks).fuse() => results, - _ = event_stream.cancelled_by_user().fuse() => { - // Cancel all running subagents - for subagent_weak in &subagent_threads { - if let Some(subagent) = subagent_weak.upgrade() { - let _ = subagent.update(cx, |thread, cx| { - thread.cancel(cx).detach(); - }); - } + // Helper to wait for user stop signal on the subagent card + let wait_for_user_stop = async { + loop { + if *user_stop_rx.borrow() { + return; + } + if user_stop_rx.changed().await.is_err() { + std::future::pending::<()>().await; } - anyhow::bail!("Subagent tool cancelled by user"); } }; - // Format the combined results - let mut output = String::new(); - for (label, result) in &results { - output.push_str(&format!("## {}\n\n", label)); - match result { - Ok(summary) => output.push_str(&summary), - Err(e) => output.push_str(&format!("Error: {}", e)), + // Run the subagent, handling cancellation from both: + // 1. Parent turn cancellation (event_stream.cancelled_by_user) + // 2. Direct user stop on subagent card (user_stop_rx) + let result = futures::select! { + result = run_subagent( + &subagent_thread, + &acp_thread, + input.task_prompt, + input.timeout_ms, + cx, + ).fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + let _ = subagent_thread.update(cx, |thread, cx| { + thread.cancel(cx).detach(); + }); + Err(anyhow!("Subagent cancelled by user")) } - output.push_str("\n\n"); + _ = wait_for_user_stop.fuse() => { + let _ = subagent_thread.update(cx, |thread, cx| { + thread.cancel(cx).detach(); + }); + Err(anyhow!("Subagent stopped by user")) + } + }; + + if let Some(parent) = parent_thread_weak.upgrade() { + let _ = parent.update(cx, |thread, _cx| { + thread.unregister_running_subagent(&subagent_weak); + }); } - Ok(output.trim().to_string()) + result }) } } @@ -603,52 +523,26 @@ mod tests { ); let properties = schema_json.get("properties").unwrap(); + assert!(properties.get("label").is_some(), "should have label field"); assert!( - properties.get("subagents").is_some(), - "should have subagents field" - ); - - let subagents_schema = properties.get("subagents").unwrap(); - assert!( - subagents_schema.get("items").is_some(), - "subagents should have items schema" - ); - - // The items use a $ref to definitions/SubagentConfig, so we need to look up - // the actual schema in the definitions section - let definitions = schema_json - .get("definitions") - .expect("schema should have definitions"); - let subagent_config_schema = definitions - .get("SubagentConfig") - .expect("definitions should have SubagentConfig"); - let item_properties = subagent_config_schema - .get("properties") - .expect("SubagentConfig should have properties"); - - assert!( - item_properties.get("label").is_some(), - "subagent item should have label field" - ); - assert!( - item_properties.get("task_prompt").is_some(), - "subagent item should have task_prompt field" + properties.get("task_prompt").is_some(), + "should have task_prompt field" ); assert!( - item_properties.get("summary_prompt").is_some(), - "subagent item should have summary_prompt field" + properties.get("summary_prompt").is_some(), + "should have summary_prompt field" ); assert!( - item_properties.get("context_low_prompt").is_some(), - "subagent item should have context_low_prompt field" + properties.get("context_low_prompt").is_some(), + "should have context_low_prompt field" ); assert!( - item_properties.get("timeout_ms").is_some(), - "subagent item should have timeout_ms field" + properties.get("timeout_ms").is_some(), + "should have timeout_ms field" ); assert!( - item_properties.get("allowed_tools").is_some(), - "subagent item should have allowed_tools field" + properties.get("allowed_tools").is_some(), + "should have allowed_tools field" ); } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 68c62e4ef5500ad22af2b9a285c90572afa039bf..92e2a92c16cb86cd3f640981e3d1928054866a77 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -52,9 +52,9 @@ use text::{Anchor, ToPoint as _}; use theme::{AgentFontSize, ThemeSettings}; use ui::{ Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, CopyButton, DecoratedIcon, - DiffStat, Disclosure, Divider, DividerColor, IconDecoration, IconDecorationKind, KeyBinding, - PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*, - right_click_menu, + DiffStat, Disclosure, Divider, DividerColor, IconButtonShape, IconDecoration, + IconDecorationKind, KeyBinding, PopoverMenu, PopoverMenuHandle, SpinnerLabel, TintColor, + Tooltip, WithScrollbar, prelude::*, right_click_menu, }; use util::defer; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; @@ -3827,30 +3827,70 @@ impl AcpThreadView { ) }), ) - .when(has_expandable_content, |this| { - this.child( - Disclosure::new( - SharedString::from(format!( - "subagent-disclosure-inner-{}-{}", - entry_ix, context_ix - )), - is_expanded, - ) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .visible_on_hover(card_header_id) - .on_click(cx.listener({ - move |this, _, _, cx| { - if this.expanded_subagents.contains(&session_id) { - this.expanded_subagents.remove(&session_id); + .child( + h_flex() + .gap_1p5() + .when(is_running, |buttons| { + buttons.child( + Button::new( + SharedString::from(format!( + "stop-subagent-{}-{}", + entry_ix, context_ix + )), + "Stop", + ) + .icon(IconName::Stop) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .label_size(LabelSize::Small) + .tooltip(Tooltip::text("Stop this subagent")) + .on_click({ + let thread = thread.clone(); + cx.listener(move |_this, _event, _window, cx| { + thread.update(cx, |thread, _cx| { + thread.stop_by_user(); + }); + }) + }), + ) + }) + .child( + IconButton::new( + SharedString::from(format!( + "subagent-disclosure-{}-{}", + entry_ix, context_ix + )), + if is_expanded { + IconName::ChevronUp } else { - this.expanded_subagents.insert(session_id.clone()); - } - cx.notify(); - } - })), - ) - }), + IconName::ChevronDown + }, + ) + .shape(IconButtonShape::Square) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .disabled(!has_expandable_content) + .when(has_expandable_content, |button| { + button.on_click(cx.listener({ + move |this, _, _, cx| { + if this.expanded_subagents.contains(&session_id) { + this.expanded_subagents.remove(&session_id); + } else { + this.expanded_subagents.insert(session_id.clone()); + } + cx.notify(); + } + })) + }) + .when( + !has_expandable_content, + |button| { + button.tooltip(Tooltip::text("Waiting for content...")) + }, + ), + ), + ), ) .when(is_expanded, |this| { this.child(