subagent_tool.rs

  1use acp_thread::{AcpThread, AgentConnection, UserMessageId};
  2use action_log::ActionLog;
  3use agent_client_protocol as acp;
  4use anyhow::{Result, anyhow};
  5use collections::{BTreeMap, HashSet};
  6use futures::{FutureExt, channel::mpsc};
  7use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
  8use language_model::LanguageModelToolUseId;
  9use project::Project;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use smol::stream::StreamExt;
 13use std::any::Any;
 14use std::path::Path;
 15use std::rc::Rc;
 16use std::sync::Arc;
 17use std::time::Duration;
 18use util::ResultExt;
 19use watch;
 20
 21use crate::{
 22    AgentTool, AnyAgentTool, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext, Thread,
 23    ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
 24};
 25
 26/// When a subagent's remaining context window falls below this fraction (25%),
 27/// the "context running out" prompt is sent to encourage the subagent to wrap up.
 28const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
 29
// NOTE(review): the `///` doc comments on this struct and its fields are not
// just rustdoc. The `JsonSchema` derive (schemars) emits them as schema
// `description`s, which presumably reach the model as the tool description.
// Treat them as prompt text. Two numbers in them are hard-coded:
// - "Maximum 8 subagents" must be kept in sync with `MAX_PARALLEL_SUBAGENTS`
//   (TODO confirm that constant is 8).
// - "25% remaining" mirrors `CONTEXT_LOW_THRESHOLD` (0.25).
/// Spawns a subagent with its own context window to perform a delegated task.
///
/// Use this tool when you want to do any of the following:
/// - Perform an investigation where all you need to know is the outcome, not the research that led to that outcome.
/// - Complete a self-contained task where you need to know if it succeeded or failed (and how), but none of its intermediate output.
/// - Run multiple tasks in parallel that would take significantly longer to run sequentially.
///
/// You control what the subagent does by providing:
/// 1. A task prompt describing what the subagent should do
/// 2. A summary prompt that tells the subagent how to summarize its work when done
/// 3. A "context running out" prompt for when the subagent is low on tokens
///
/// Each subagent has access to the same tools you do. You can optionally restrict
/// which tools each subagent can use.
///
/// Note:
/// - Maximum 8 subagents can run in parallel
/// - Subagents cannot use tools you don't have access to
/// - If spawning multiple subagents that might write to the filesystem, provide
///   guidance on how to avoid conflicts (e.g. assign each to different directories)
/// - Instruct subagents to be concise in their summaries to conserve your context
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SubagentToolInput {
    /// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives")
    pub label: String,

    /// The initial prompt that tells the subagent what task to perform.
    /// Be specific about what you want the subagent to accomplish.
    pub task_prompt: String,

    /// The prompt sent to the subagent when it completes its task, asking it
    /// to summarize what it did and return results. This summary becomes the
    /// tool result you receive.
    ///
    /// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons."
    pub summary_prompt: String,

    /// The prompt sent if the subagent is running low on context (25% remaining).
    /// Should instruct it to stop and summarize progress so far, plus what's left undone.
    ///
    /// Example: "Context is running low. Stop and summarize your progress so far,
    /// and list what remains to be investigated."
    pub context_low_prompt: String,

    // `#[serde(default)]` lets callers omit this field; it then deserializes as `None`.
    /// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is
    /// asked to summarize and return. No timeout by default.
    #[serde(default)]
    pub timeout_ms: Option<u64>,

    // `None` means "no restriction". Names are checked against the parent's
    // tool set before the subagent is spawned.
    /// Optional: List of tool names the subagent is allowed to use.
    /// If not provided, the subagent can use all tools available to the parent.
    /// Tools listed here must be a subset of the parent's available tools.
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,
}
 85
 86/// Tool that spawns a subagent thread to work on a task.
 87pub struct SubagentTool {
 88    parent_thread: WeakEntity<Thread>,
 89    current_depth: u8,
 90}
 91
 92impl SubagentTool {
 93    pub fn new(parent_thread: WeakEntity<Thread>, current_depth: u8) -> Self {
 94        Self {
 95            parent_thread,
 96            current_depth,
 97        }
 98    }
 99
100    pub fn validate_allowed_tools(
101        &self,
102        allowed_tools: &Option<Vec<String>>,
103        cx: &App,
104    ) -> Result<()> {
105        let Some(allowed_tools) = allowed_tools else {
106            return Ok(());
107        };
108
109        let invalid_tools: Vec<_> = self.parent_thread.read_with(cx, |thread, _cx| {
110            allowed_tools
111                .iter()
112                .filter(|tool| !thread.tools.contains_key(tool.as_str()))
113                .map(|s| format!("'{s}'"))
114                .collect()
115        })?;
116
117        if !invalid_tools.is_empty() {
118            return Err(anyhow!(
119                "The following tools do not exist: {}",
120                invalid_tools.join(", ")
121            ));
122        }
123
124        Ok(())
125    }
126}
127
128impl AgentTool for SubagentTool {
129    type Input = SubagentToolInput;
130    type Output = String;
131
132    const NAME: &'static str = acp_thread::SUBAGENT_TOOL_NAME;
133
134    fn kind() -> acp::ToolKind {
135        acp::ToolKind::Other
136    }
137
138    fn initial_title(
139        &self,
140        input: Result<Self::Input, serde_json::Value>,
141        _cx: &mut App,
142    ) -> SharedString {
143        input
144            .map(|i| i.label.into())
145            .unwrap_or_else(|_| "Subagent".into())
146    }
147
148    fn run(
149        self: Arc<Self>,
150        input: Self::Input,
151        event_stream: ToolCallEventStream,
152        cx: &mut App,
153    ) -> Task<Result<String>> {
154        if self.current_depth >= MAX_SUBAGENT_DEPTH {
155            return Task::ready(Err(anyhow!(
156                "Maximum subagent depth ({}) reached",
157                MAX_SUBAGENT_DEPTH
158            )));
159        }
160
161        if let Err(e) = self.validate_allowed_tools(&input.allowed_tools, cx) {
162            return Task::ready(Err(e));
163        }
164
165        let Some(parent_thread_entity) = self.parent_thread.upgrade() else {
166            return Task::ready(Err(anyhow!(
167                "Parent thread no longer exists (subagent depth={})",
168                self.current_depth + 1
169            )));
170        };
171        let parent_thread = parent_thread_entity.read(cx);
172
173        let running_count = parent_thread.running_subagent_count();
174        if running_count >= MAX_PARALLEL_SUBAGENTS {
175            return Task::ready(Err(anyhow!(
176                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
177                MAX_PARALLEL_SUBAGENTS
178            )));
179        }
180
181        let parent_model = parent_thread.model().cloned();
182        let Some(model) = parent_model else {
183            return Task::ready(Err(anyhow!("No model configured")));
184        };
185
186        let parent_thread_id = parent_thread.id().clone();
187        let project = parent_thread.project.clone();
188        let project_context = parent_thread.project_context().clone();
189        let context_server_registry = parent_thread.context_server_registry.clone();
190        let templates = parent_thread.templates.clone();
191        let parent_tools = parent_thread.tools.clone();
192        let current_depth = self.current_depth;
193        let parent_thread_weak = self.parent_thread.clone();
194
195        cx.spawn(async move |cx| {
196            let subagent_context = SubagentContext {
197                parent_thread_id: parent_thread_id.clone(),
198                tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()),
199                depth: current_depth + 1,
200                summary_prompt: input.summary_prompt.clone(),
201                context_low_prompt: input.context_low_prompt.clone(),
202            };
203
204            // Determine which tools this subagent gets
205            let subagent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> =
206                if let Some(ref allowed) = input.allowed_tools {
207                    let allowed_set: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
208                    parent_tools
209                        .iter()
210                        .filter(|(name, _)| allowed_set.contains(name.as_ref()))
211                        .map(|(name, tool)| (name.clone(), tool.clone()))
212                        .collect()
213                } else {
214                    parent_tools.clone()
215                };
216
217            let subagent_thread: Entity<Thread> = cx.new(|cx| {
218                Thread::new_subagent(
219                    project.clone(),
220                    project_context.clone(),
221                    context_server_registry.clone(),
222                    templates.clone(),
223                    model.clone(),
224                    subagent_context,
225                    subagent_tools,
226                    cx,
227                )
228            });
229
230            let subagent_weak = subagent_thread.downgrade();
231
232            let acp_thread: Entity<AcpThread> = cx.new(|cx| {
233                let session_id = subagent_thread.read(cx).id().clone();
234                let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
235                let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
236                AcpThread::new(
237                    &input.label,
238                    connection,
239                    project.clone(),
240                    action_log,
241                    session_id,
242                    watch::Receiver::constant(acp::PromptCapabilities::new()),
243                    cx,
244                )
245            });
246
247            event_stream.update_subagent_thread(acp_thread.clone());
248
249            let mut user_stop_rx: watch::Receiver<bool> =
250                acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
251
252            if let Some(parent) = parent_thread_weak.upgrade() {
253                parent.update(cx, |thread, _cx| {
254                    thread.register_running_subagent(subagent_weak.clone());
255                });
256            }
257
258            // Helper to wait for user stop signal on the subagent card
259            let wait_for_user_stop = async {
260                loop {
261                    if *user_stop_rx.borrow() {
262                        return;
263                    }
264                    if user_stop_rx.changed().await.is_err() {
265                        std::future::pending::<()>().await;
266                    }
267                }
268            };
269
270            // Run the subagent, handling cancellation from both:
271            // 1. Parent turn cancellation (event_stream.cancelled_by_user)
272            // 2. Direct user stop on subagent card (user_stop_rx)
273            let result = futures::select! {
274                result = run_subagent(
275                    &subagent_thread,
276                    &acp_thread,
277                    input.task_prompt,
278                    input.timeout_ms,
279                    cx,
280                ).fuse() => result,
281                _ = event_stream.cancelled_by_user().fuse() => {
282                    let _ = subagent_thread.update(cx, |thread, cx| {
283                        thread.cancel(cx).detach();
284                    });
285                    Err(anyhow!("Subagent cancelled by user"))
286                }
287                _ = wait_for_user_stop.fuse() => {
288                    let _ = subagent_thread.update(cx, |thread, cx| {
289                        thread.cancel(cx).detach();
290                    });
291                    Err(anyhow!("Subagent stopped by user"))
292                }
293            };
294
295            if let Some(parent) = parent_thread_weak.upgrade() {
296                let _ = parent.update(cx, |thread, _cx| {
297                    thread.unregister_running_subagent(&subagent_weak);
298                });
299            }
300
301            result
302        })
303    }
304}
305
306async fn run_subagent(
307    subagent_thread: &Entity<Thread>,
308    acp_thread: &Entity<AcpThread>,
309    task_prompt: String,
310    timeout_ms: Option<u64>,
311    cx: &mut AsyncApp,
312) -> Result<String> {
313    let mut events_rx =
314        subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
315
316    let acp_thread_weak = acp_thread.downgrade();
317
318    let timed_out = if let Some(timeout) = timeout_ms {
319        forward_events_with_timeout(
320            &mut events_rx,
321            &acp_thread_weak,
322            Duration::from_millis(timeout),
323            cx,
324        )
325        .await
326    } else {
327        forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
328        false
329    };
330
331    let should_interrupt =
332        timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
333
334    if should_interrupt {
335        let mut summary_rx =
336            subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
337        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
338    } else {
339        let mut summary_rx =
340            subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
341        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
342    }
343
344    Ok(extract_last_message(subagent_thread, cx))
345}
346
347async fn forward_events_until_stop(
348    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
349    acp_thread: &WeakEntity<AcpThread>,
350    cx: &mut AsyncApp,
351) {
352    while let Some(event) = events_rx.next().await {
353        match event {
354            Ok(ThreadEvent::Stop(_)) => break,
355            Ok(event) => {
356                forward_event_to_acp_thread(event, acp_thread, cx);
357            }
358            Err(_) => break,
359        }
360    }
361}
362
363async fn forward_events_with_timeout(
364    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
365    acp_thread: &WeakEntity<AcpThread>,
366    timeout: Duration,
367    cx: &mut AsyncApp,
368) -> bool {
369    use futures::future::{self, Either};
370
371    let deadline = std::time::Instant::now() + timeout;
372
373    loop {
374        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
375        if remaining.is_zero() {
376            return true;
377        }
378
379        let timeout_future = cx.background_executor().timer(remaining);
380        let event_future = events_rx.next();
381
382        match future::select(event_future, timeout_future).await {
383            Either::Left((event, _)) => match event {
384                Some(Ok(ThreadEvent::Stop(_))) => return false,
385                Some(Ok(event)) => {
386                    forward_event_to_acp_thread(event, acp_thread, cx);
387                }
388                Some(Err(_)) => return false,
389                None => return false,
390            },
391            Either::Right((_, _)) => return true,
392        }
393    }
394}
395
396fn forward_event_to_acp_thread(
397    event: ThreadEvent,
398    acp_thread: &WeakEntity<AcpThread>,
399    cx: &mut AsyncApp,
400) {
401    match event {
402        ThreadEvent::UserMessage(message) => {
403            acp_thread
404                .update(cx, |thread, cx| {
405                    for content in message.content {
406                        thread.push_user_content_block(
407                            Some(message.id.clone()),
408                            content.into(),
409                            cx,
410                        );
411                    }
412                })
413                .log_err();
414        }
415        ThreadEvent::AgentText(text) => {
416            acp_thread
417                .update(cx, |thread, cx| {
418                    thread.push_assistant_content_block(text.into(), false, cx)
419                })
420                .log_err();
421        }
422        ThreadEvent::AgentThinking(text) => {
423            acp_thread
424                .update(cx, |thread, cx| {
425                    thread.push_assistant_content_block(text.into(), true, cx)
426                })
427                .log_err();
428        }
429        ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
430            tool_call,
431            options,
432            response,
433            ..
434        }) => {
435            let outcome_task = acp_thread.update(cx, |thread, cx| {
436                thread.request_tool_call_authorization(tool_call, options, true, cx)
437            });
438            if let Ok(Ok(task)) = outcome_task {
439                cx.background_spawn(async move {
440                    if let acp::RequestPermissionOutcome::Selected(
441                        acp::SelectedPermissionOutcome { option_id, .. },
442                    ) = task.await
443                    {
444                        response.send(option_id).ok();
445                    }
446                })
447                .detach();
448            }
449        }
450        ThreadEvent::ToolCall(tool_call) => {
451            acp_thread
452                .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
453                .log_err();
454        }
455        ThreadEvent::ToolCallUpdate(update) => {
456            acp_thread
457                .update(cx, |thread, cx| thread.update_tool_call(update, cx))
458                .log_err();
459        }
460        ThreadEvent::Retry(status) => {
461            acp_thread
462                .update(cx, |thread, cx| thread.update_retry_status(status, cx))
463                .log_err();
464        }
465        ThreadEvent::Stop(_) => {}
466    }
467}
468
469fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
470    thread.read_with(cx, |thread, _| {
471        if let Some(usage) = thread.latest_token_usage() {
472            let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
473            remaining_ratio <= threshold
474        } else {
475            false
476        }
477    })
478}
479
480fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
481    thread.read_with(cx, |thread, _| {
482        thread
483            .last_message()
484            .map(|m| m.to_markdown())
485            .unwrap_or_else(|| "No response from subagent".to_string())
486    })
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::LanguageModelToolSchemaFormat;

    #[test]
    fn test_subagent_tool_input_json_schema_is_valid() {
        let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
        let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");
        let properties = schema_json
            .get("properties")
            .expect("schema should have properties");

        for field in [
            "label",
            "task_prompt",
            "summary_prompt",
            "context_low_prompt",
            "timeout_ms",
            "allowed_tools",
        ] {
            assert!(properties.get(field).is_some(), "should have {field} field");
        }
    }

    /// `timeout_ms` and `allowed_tools` are `#[serde(default)]`, so a model may
    /// omit them entirely; they must then deserialize as `None`.
    #[test]
    fn test_subagent_tool_input_optional_fields_default_to_none() {
        let input: SubagentToolInput = serde_json::from_value(serde_json::json!({
            "label": "Researching alternatives",
            "task_prompt": "Find alternatives to X",
            "summary_prompt": "Summarize what you found",
            "context_low_prompt": "Stop and summarize your progress"
        }))
        .expect("input without optional fields should deserialize");

        assert_eq!(input.label, "Researching alternatives");
        assert_eq!(input.timeout_ms, None);
        assert_eq!(input.allowed_tools, None);
    }

    /// The prompts have no default, so omitting them must be rejected.
    #[test]
    fn test_subagent_tool_input_requires_prompts() {
        let result = serde_json::from_value::<SubagentToolInput>(serde_json::json!({
            "label": "Missing prompts"
        }));
        assert!(result.is_err(), "prompts should be required fields");
    }

    #[test]
    fn test_subagent_tool_name() {
        assert_eq!(SubagentTool::NAME, "subagent");
    }

    #[test]
    fn test_subagent_tool_kind() {
        assert_eq!(SubagentTool::kind(), acp::ToolKind::Other);
    }
}
538
539struct SubagentDisplayConnection;
540
541impl AgentConnection for SubagentDisplayConnection {
542    fn telemetry_id(&self) -> SharedString {
543        acp_thread::SUBAGENT_TOOL_NAME.into()
544    }
545
546    fn auth_methods(&self) -> &[acp::AuthMethod] {
547        &[]
548    }
549
550    fn new_thread(
551        self: Rc<Self>,
552        _project: Entity<Project>,
553        _cwd: &Path,
554        _cx: &mut App,
555    ) -> Task<Result<Entity<AcpThread>>> {
556        unimplemented!("SubagentDisplayConnection does not support new_thread")
557    }
558
559    fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
560        unimplemented!("SubagentDisplayConnection does not support authenticate")
561    }
562
563    fn prompt(
564        &self,
565        _id: Option<UserMessageId>,
566        _params: acp::PromptRequest,
567        _cx: &mut App,
568    ) -> Task<Result<acp::PromptResponse>> {
569        unimplemented!("SubagentDisplayConnection does not support prompt")
570    }
571
572    fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
573
574    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
575        self
576    }
577}