subagent_tool.rs

  1use acp_thread::{AcpThread, AgentConnection, UserMessageId};
  2use action_log::ActionLog;
  3use agent_client_protocol as acp;
  4use anyhow::{Result, anyhow};
  5use collections::{BTreeMap, HashSet};
  6use futures::{FutureExt, channel::mpsc};
  7use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
  8use language_model::LanguageModelToolUseId;
  9use project::Project;
 10use prompt_store::ProjectContext;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use smol::stream::StreamExt;
 14use std::any::Any;
 15use std::path::Path;
 16use std::rc::Rc;
 17use std::sync::Arc;
 18use std::time::Duration;
 19use util::ResultExt;
 20use watch;
 21
 22use crate::{
 23    AgentTool, AnyAgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH,
 24    SubagentContext, Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
 25};
 26
/// Fraction of the context window (25%) at or below which a subagent is treated as
/// running low on context.
///
/// `run_subagent` passes this to `check_context_low` after event forwarding for the
/// task prompt has returned. When the remaining fraction is at or below this value
/// (the comparison is `<=`), the subagent is interrupted for a summary instead of
/// being asked for its regular final summary — presumably using the configured
/// "context running out" prompt; confirm in `Thread::interrupt_for_summary`.
const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
 30
// NOTE(review): `JsonSchema` derives typically surface `///` doc comments as schema
// descriptions, so the text below is likely part of the tool description the model
// reads — confirm before rewording it.
/// Spawns one or more subagents with their own context windows to perform delegated tasks.
/// Multiple subagents run in parallel.
///
/// Use this tool when you want to do any of the following:
/// - Perform an investigation where all you need to know is the outcome, not the research that led to that outcome.
/// - Complete a self-contained task where you need to know if it succeeded or failed (and how), but none of its intermediate output.
/// - Run multiple tasks in parallel that would take significantly longer to run sequentially.
///
/// You control what each subagent does by providing:
/// 1. A task prompt describing what the subagent should do
/// 2. A summary prompt that tells the subagent how to summarize its work when done
/// 3. A "context running out" prompt for when the subagent is low on tokens
///
/// Each subagent has access to the same tools you do. You can optionally restrict
/// which tools each subagent can use.
///
/// Note:
/// - Maximum 8 subagents can run in parallel
/// - Subagents cannot use tools you don't have access to
/// - If spawning multiple subagents that might write to the filesystem, provide
///   guidance on how to avoid conflicts (e.g. assign each to different directories)
/// - Instruct subagents to be concise in their summaries to conserve your context
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SubagentToolInput {
    /// The list of subagents to spawn. At least one is required.
    /// All subagents run in parallel and their results are collected.
    // The `max = 8` bound here and the "Maximum 8" note above duplicate the runtime
    // limit that `SubagentTool::validate_subagents` enforces through
    // `MAX_PARALLEL_SUBAGENTS` (presumably 8 — keep the three in sync).
    #[schemars(length(min = 1, max = 8))]
    pub subagents: Vec<SubagentConfig>,
}
 60
// NOTE(review): as with `SubagentToolInput`, the `///` comments on this type are
// likely emitted into the tool's JSON schema by the `JsonSchema` derive, so the
// maintainer notes below use `//` to stay out of the schema — confirm.
/// Configuration for a single subagent.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SubagentConfig {
    /// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives")
    // Also passed as the first argument to `AcpThread::new` for the display thread and
    // used as the `## <label>` heading of this subagent's section in the tool result.
    pub label: String,

    /// The initial prompt that tells the subagent what task to perform.
    /// Be specific about what you want the subagent to accomplish.
    // Submitted to the subagent thread via `submit_user_message` in `run_subagent`.
    pub task_prompt: String,

    /// The prompt sent to the subagent when it completes its task, asking it
    /// to summarize what it did and return results. This summary becomes the
    /// tool result you receive.
    ///
    /// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons."
    // Copied into the `SubagentContext` handed to `Thread::new_subagent`.
    pub summary_prompt: String,

    /// The prompt sent if the subagent is running low on context (25% remaining).
    /// Should instruct it to stop and summarize progress so far, plus what's left undone.
    ///
    /// Example: "Context is running low. Stop and summarize your progress so far,
    /// and list what remains to be investigated."
    // Copied into the `SubagentContext`; the 25% figure corresponds to
    // `CONTEXT_LOW_THRESHOLD`.
    pub context_low_prompt: String,

    /// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is
    /// asked to summarize and return. No timeout by default.
    // When set, `run_subagent` forwards events with a deadline and, on expiry, calls
    // `interrupt_for_summary` on the subagent thread.
    #[serde(default)]
    pub timeout_ms: Option<u64>,

    /// Optional: List of tool names the subagent is allowed to use.
    /// If not provided, the subagent can use all tools available to the parent.
    /// Tools listed here must be a subset of the parent's available tools.
    // Names missing from the parent's tool map are rejected up front by
    // `SubagentTool::validate_subagents`.
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,
}
 96
 97/// Tool that spawns subagent threads to work on tasks in parallel.
 98pub struct SubagentTool {
 99    parent_thread: WeakEntity<Thread>,
100    project: Entity<Project>,
101    project_context: Entity<ProjectContext>,
102    context_server_registry: Entity<ContextServerRegistry>,
103    templates: Arc<Templates>,
104    current_depth: u8,
105    /// The tools available to the parent thread, captured before SubagentTool was added.
106    /// Subagents inherit from this set (or a subset via `allowed_tools` in the config).
107    /// This is captured early so subagents don't get the subagent tool themselves.
108    parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
109}
110
111impl SubagentTool {
112    pub fn new(
113        parent_thread: WeakEntity<Thread>,
114        project: Entity<Project>,
115        project_context: Entity<ProjectContext>,
116        context_server_registry: Entity<ContextServerRegistry>,
117        templates: Arc<Templates>,
118        current_depth: u8,
119        parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
120    ) -> Self {
121        Self {
122            parent_thread,
123            project,
124            project_context,
125            context_server_registry,
126            templates,
127            current_depth,
128            parent_tools,
129        }
130    }
131
132    pub fn validate_subagents(&self, subagents: &[SubagentConfig]) -> Result<()> {
133        if subagents.is_empty() {
134            return Err(anyhow!("At least one subagent configuration is required"));
135        }
136
137        if subagents.len() > MAX_PARALLEL_SUBAGENTS {
138            return Err(anyhow!(
139                "Maximum {} subagents can be spawned at once, but {} were requested",
140                MAX_PARALLEL_SUBAGENTS,
141                subagents.len()
142            ));
143        }
144
145        // Collect all invalid tools across all subagents
146        let mut all_invalid_tools: Vec<String> = Vec::new();
147        for config in subagents {
148            if let Some(ref tools) = config.allowed_tools {
149                for tool in tools {
150                    if !self.parent_tools.contains_key(tool.as_str())
151                        && !all_invalid_tools.contains(tool)
152                    {
153                        all_invalid_tools.push(tool.clone());
154                    }
155                }
156            }
157        }
158
159        if !all_invalid_tools.is_empty() {
160            return Err(anyhow!(
161                "The following tools do not exist: {}",
162                all_invalid_tools
163                    .iter()
164                    .map(|t| format!("'{}'", t))
165                    .collect::<Vec<_>>()
166                    .join(", ")
167            ));
168        }
169
170        Ok(())
171    }
172}
173
174impl AgentTool for SubagentTool {
175    type Input = SubagentToolInput;
176    type Output = String;
177
178    fn name() -> &'static str {
179        acp_thread::SUBAGENT_TOOL_NAME
180    }
181
182    fn kind() -> acp::ToolKind {
183        acp::ToolKind::Other
184    }
185
186    fn initial_title(
187        &self,
188        input: Result<Self::Input, serde_json::Value>,
189        _cx: &mut App,
190    ) -> SharedString {
191        input
192            .map(|i| {
193                if i.subagents.len() == 1 {
194                    i.subagents[0].label.clone().into()
195                } else {
196                    format!("{} subagents", i.subagents.len()).into()
197                }
198            })
199            .unwrap_or_else(|_| "Subagents".into())
200    }
201
202    fn run(
203        self: Arc<Self>,
204        input: Self::Input,
205        event_stream: ToolCallEventStream,
206        cx: &mut App,
207    ) -> Task<Result<String>> {
208        if self.current_depth >= MAX_SUBAGENT_DEPTH {
209            return Task::ready(Err(anyhow!(
210                "Maximum subagent depth ({}) reached",
211                MAX_SUBAGENT_DEPTH
212            )));
213        }
214
215        if let Err(e) = self.validate_subagents(&input.subagents) {
216            return Task::ready(Err(e));
217        }
218
219        let Some(parent_thread) = self.parent_thread.upgrade() else {
220            return Task::ready(Err(anyhow!(
221                "Parent thread no longer exists (subagent depth={})",
222                self.current_depth + 1
223            )));
224        };
225
226        let running_count = parent_thread.read(cx).running_subagent_count();
227        let available_slots = MAX_PARALLEL_SUBAGENTS.saturating_sub(running_count);
228        if available_slots == 0 {
229            return Task::ready(Err(anyhow!(
230                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
231                MAX_PARALLEL_SUBAGENTS
232            )));
233        }
234
235        if input.subagents.len() > available_slots {
236            return Task::ready(Err(anyhow!(
237                "Cannot spawn {} subagents: only {} slots available (max {} parallel)",
238                input.subagents.len(),
239                available_slots,
240                MAX_PARALLEL_SUBAGENTS
241            )));
242        }
243
244        let parent_model = parent_thread.read(cx).model().cloned();
245        let Some(model) = parent_model else {
246            return Task::ready(Err(anyhow!("No model configured")));
247        };
248
249        let parent_thread_id = parent_thread.read(cx).id().clone();
250        let project = self.project.clone();
251        let project_context = self.project_context.clone();
252        let context_server_registry = self.context_server_registry.clone();
253        let templates = self.templates.clone();
254        let parent_tools = self.parent_tools.clone();
255        let current_depth = self.current_depth;
256        let parent_thread_weak = self.parent_thread.clone();
257
258        // Spawn all subagents in parallel
259        let subagent_configs = input.subagents;
260
261        cx.spawn(async move |cx| {
262            // Create all subagent threads upfront so we can track them for cancellation
263            let mut subagent_data: Vec<(
264                String,            // label
265                Entity<Thread>,    // subagent thread
266                Entity<AcpThread>, // acp thread for display
267                String,            // task prompt
268                Option<u64>,       // timeout
269            )> = Vec::new();
270
271            for config in subagent_configs {
272                let subagent_context = SubagentContext {
273                    parent_thread_id: parent_thread_id.clone(),
274                    tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()),
275                    depth: current_depth + 1,
276                    summary_prompt: config.summary_prompt.clone(),
277                    context_low_prompt: config.context_low_prompt.clone(),
278                };
279
280                // Determine which tools this subagent gets
281                let subagent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> =
282                    if let Some(ref allowed) = config.allowed_tools {
283                        let allowed_set: HashSet<&str> =
284                            allowed.iter().map(|s| s.as_str()).collect();
285                        parent_tools
286                            .iter()
287                            .filter(|(name, _)| allowed_set.contains(name.as_ref()))
288                            .map(|(name, tool)| (name.clone(), tool.clone()))
289                            .collect()
290                    } else {
291                        parent_tools.clone()
292                    };
293
294                let label = config.label.clone();
295                let task_prompt = config.task_prompt.clone();
296                let timeout_ms = config.timeout_ms;
297
298                let subagent_thread: Entity<Thread> = cx.new(|cx| {
299                    Thread::new_subagent(
300                        project.clone(),
301                        project_context.clone(),
302                        context_server_registry.clone(),
303                        templates.clone(),
304                        model.clone(),
305                        subagent_context,
306                        subagent_tools,
307                        cx,
308                    )
309                });
310
311                let subagent_weak = subagent_thread.downgrade();
312
313                let acp_thread: Entity<AcpThread> = cx.new(|cx| {
314                    let session_id = subagent_thread.read(cx).id().clone();
315                    let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
316                    let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
317                    AcpThread::new(
318                        &label,
319                        connection,
320                        project.clone(),
321                        action_log,
322                        session_id,
323                        watch::Receiver::constant(acp::PromptCapabilities::new()),
324                        cx,
325                    )
326                });
327
328                event_stream.update_subagent_thread(acp_thread.clone());
329
330                if let Some(parent) = parent_thread_weak.upgrade() {
331                    parent.update(cx, |thread, _cx| {
332                        thread.register_running_subagent(subagent_weak.clone());
333                    });
334                }
335
336                subagent_data.push((label, subagent_thread, acp_thread, task_prompt, timeout_ms));
337            }
338
339            // Collect weak refs for cancellation cleanup
340            let subagent_threads: Vec<WeakEntity<Thread>> = subagent_data
341                .iter()
342                .map(|(_, thread, _, _, _)| thread.downgrade())
343                .collect();
344
345            // Spawn tasks for each subagent
346            let tasks: Vec<_> = subagent_data
347                .into_iter()
348                .map(
349                    |(label, subagent_thread, acp_thread, task_prompt, timeout_ms)| {
350                        let parent_thread_weak = parent_thread_weak.clone();
351                        cx.spawn(async move |cx| {
352                            let subagent_weak = subagent_thread.downgrade();
353
354                            let result = run_subagent(
355                                &subagent_thread,
356                                &acp_thread,
357                                task_prompt,
358                                timeout_ms,
359                                cx,
360                            )
361                            .await;
362
363                            if let Some(parent) = parent_thread_weak.upgrade() {
364                                let _ = parent.update(cx, |thread, _cx| {
365                                    thread.unregister_running_subagent(&subagent_weak);
366                                });
367                            }
368
369                            (label, result)
370                        })
371                    },
372                )
373                .collect();
374
375            // Wait for all subagents to complete, or cancellation
376            let results: Vec<(String, Result<String>)> = futures::select! {
377                results = futures::future::join_all(tasks).fuse() => results,
378                _ = event_stream.cancelled_by_user().fuse() => {
379                    // Cancel all running subagents
380                    for subagent_weak in &subagent_threads {
381                        if let Some(subagent) = subagent_weak.upgrade() {
382                            let _ = subagent.update(cx, |thread, cx| {
383                                thread.cancel(cx).detach();
384                            });
385                        }
386                    }
387                    anyhow::bail!("Subagent tool cancelled by user");
388                }
389            };
390
391            // Format the combined results
392            let mut output = String::new();
393            for (label, result) in &results {
394                output.push_str(&format!("## {}\n\n", label));
395                match result {
396                    Ok(summary) => output.push_str(&summary),
397                    Err(e) => output.push_str(&format!("Error: {}", e)),
398                }
399                output.push_str("\n\n");
400            }
401
402            Ok(output.trim().to_string())
403        })
404    }
405}
406
407async fn run_subagent(
408    subagent_thread: &Entity<Thread>,
409    acp_thread: &Entity<AcpThread>,
410    task_prompt: String,
411    timeout_ms: Option<u64>,
412    cx: &mut AsyncApp,
413) -> Result<String> {
414    let mut events_rx =
415        subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
416
417    let acp_thread_weak = acp_thread.downgrade();
418
419    let timed_out = if let Some(timeout) = timeout_ms {
420        forward_events_with_timeout(
421            &mut events_rx,
422            &acp_thread_weak,
423            Duration::from_millis(timeout),
424            cx,
425        )
426        .await
427    } else {
428        forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
429        false
430    };
431
432    let should_interrupt =
433        timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
434
435    if should_interrupt {
436        let mut summary_rx =
437            subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
438        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
439    } else {
440        let mut summary_rx =
441            subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
442        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
443    }
444
445    Ok(extract_last_message(subagent_thread, cx))
446}
447
448async fn forward_events_until_stop(
449    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
450    acp_thread: &WeakEntity<AcpThread>,
451    cx: &mut AsyncApp,
452) {
453    while let Some(event) = events_rx.next().await {
454        match event {
455            Ok(ThreadEvent::Stop(_)) => break,
456            Ok(event) => {
457                forward_event_to_acp_thread(event, acp_thread, cx);
458            }
459            Err(_) => break,
460        }
461    }
462}
463
464async fn forward_events_with_timeout(
465    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
466    acp_thread: &WeakEntity<AcpThread>,
467    timeout: Duration,
468    cx: &mut AsyncApp,
469) -> bool {
470    use futures::future::{self, Either};
471
472    let deadline = std::time::Instant::now() + timeout;
473
474    loop {
475        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
476        if remaining.is_zero() {
477            return true;
478        }
479
480        let timeout_future = cx.background_executor().timer(remaining);
481        let event_future = events_rx.next();
482
483        match future::select(event_future, timeout_future).await {
484            Either::Left((event, _)) => match event {
485                Some(Ok(ThreadEvent::Stop(_))) => return false,
486                Some(Ok(event)) => {
487                    forward_event_to_acp_thread(event, acp_thread, cx);
488                }
489                Some(Err(_)) => return false,
490                None => return false,
491            },
492            Either::Right((_, _)) => return true,
493        }
494    }
495}
496
497fn forward_event_to_acp_thread(
498    event: ThreadEvent,
499    acp_thread: &WeakEntity<AcpThread>,
500    cx: &mut AsyncApp,
501) {
502    match event {
503        ThreadEvent::UserMessage(message) => {
504            acp_thread
505                .update(cx, |thread, cx| {
506                    for content in message.content {
507                        thread.push_user_content_block(
508                            Some(message.id.clone()),
509                            content.into(),
510                            cx,
511                        );
512                    }
513                })
514                .log_err();
515        }
516        ThreadEvent::AgentText(text) => {
517            acp_thread
518                .update(cx, |thread, cx| {
519                    thread.push_assistant_content_block(text.into(), false, cx)
520                })
521                .log_err();
522        }
523        ThreadEvent::AgentThinking(text) => {
524            acp_thread
525                .update(cx, |thread, cx| {
526                    thread.push_assistant_content_block(text.into(), true, cx)
527                })
528                .log_err();
529        }
530        ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
531            tool_call,
532            options,
533            response,
534            ..
535        }) => {
536            let outcome_task = acp_thread.update(cx, |thread, cx| {
537                thread.request_tool_call_authorization(tool_call, options, true, cx)
538            });
539            if let Ok(Ok(task)) = outcome_task {
540                cx.background_spawn(async move {
541                    if let acp::RequestPermissionOutcome::Selected(
542                        acp::SelectedPermissionOutcome { option_id, .. },
543                    ) = task.await
544                    {
545                        response.send(option_id).ok();
546                    }
547                })
548                .detach();
549            }
550        }
551        ThreadEvent::ToolCall(tool_call) => {
552            acp_thread
553                .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
554                .log_err();
555        }
556        ThreadEvent::ToolCallUpdate(update) => {
557            acp_thread
558                .update(cx, |thread, cx| thread.update_tool_call(update, cx))
559                .log_err();
560        }
561        ThreadEvent::Retry(status) => {
562            acp_thread
563                .update(cx, |thread, cx| thread.update_retry_status(status, cx))
564                .log_err();
565        }
566        ThreadEvent::Stop(_) => {}
567    }
568}
569
570fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
571    thread.read_with(cx, |thread, _| {
572        if let Some(usage) = thread.latest_token_usage() {
573            let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
574            remaining_ratio <= threshold
575        } else {
576            false
577        }
578    })
579}
580
581fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
582    thread.read_with(cx, |thread, _| {
583        thread
584            .last_message()
585            .map(|m| m.to_markdown())
586            .unwrap_or_else(|| "No response from subagent".to_string())
587    })
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::LanguageModelToolSchemaFormat;

    /// The generated input schema exposes `subagents` as an array whose item
    /// definition carries every `SubagentConfig` field.
    #[test]
    fn test_subagent_tool_input_json_schema_is_valid() {
        let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
        let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");

        let properties = schema_json.get("properties");
        assert!(properties.is_some(), "schema should have properties");
        let properties = properties.unwrap();

        let subagents_schema = properties.get("subagents");
        assert!(subagents_schema.is_some(), "should have subagents field");
        let subagents_schema = subagents_schema.unwrap();

        assert!(
            subagents_schema.get("items").is_some(),
            "subagents should have items schema"
        );

        // The items use a $ref to definitions/SubagentConfig, so the field list has
        // to be looked up in the definitions section rather than inline.
        let item_properties = schema_json
            .get("definitions")
            .expect("schema should have definitions")
            .get("SubagentConfig")
            .expect("definitions should have SubagentConfig")
            .get("properties")
            .expect("SubagentConfig should have properties");

        let expected_fields = [
            "label",
            "task_prompt",
            "summary_prompt",
            "context_low_prompt",
            "timeout_ms",
            "allowed_tools",
        ];
        for field in expected_fields {
            assert!(
                item_properties.get(field).is_some(),
                "subagent item should have {} field",
                field
            );
        }
    }

    /// The tool is registered under the name "subagent".
    #[test]
    fn test_subagent_tool_name() {
        let name = SubagentTool::name();
        assert_eq!(name, "subagent");
    }

    /// The tool reports the generic `Other` kind.
    #[test]
    fn test_subagent_tool_kind() {
        let kind = SubagentTool::kind();
        assert_eq!(kind, acp::ToolKind::Other);
    }
}
665
666struct SubagentDisplayConnection;
667
668impl AgentConnection for SubagentDisplayConnection {
669    fn telemetry_id(&self) -> SharedString {
670        "subagent".into()
671    }
672
673    fn auth_methods(&self) -> &[acp::AuthMethod] {
674        &[]
675    }
676
677    fn new_thread(
678        self: Rc<Self>,
679        _project: Entity<Project>,
680        _cwd: &Path,
681        _cx: &mut App,
682    ) -> Task<Result<Entity<AcpThread>>> {
683        unimplemented!("SubagentDisplayConnection does not support new_thread")
684    }
685
686    fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
687        unimplemented!("SubagentDisplayConnection does not support authenticate")
688    }
689
690    fn prompt(
691        &self,
692        _id: Option<UserMessageId>,
693        _params: acp::PromptRequest,
694        _cx: &mut App,
695    ) -> Task<Result<acp::PromptResponse>> {
696        unimplemented!("SubagentDisplayConnection does not support prompt")
697    }
698
699    fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
700
701    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
702        self
703    }
704}