subagent_tool.rs

  1use acp_thread::{AcpThread, AgentConnection, UserMessageId};
  2use action_log::ActionLog;
  3use agent_client_protocol as acp;
  4use anyhow::{Result, anyhow};
  5use collections::{BTreeMap, HashSet};
  6use futures::{FutureExt, channel::mpsc};
  7use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
  8use language_model::LanguageModelToolUseId;
  9use project::Project;
 10use prompt_store::ProjectContext;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use smol::stream::StreamExt;
 14use std::any::Any;
 15use std::path::Path;
 16use std::rc::Rc;
 17use std::sync::Arc;
 18use std::time::Duration;
 19use util::ResultExt;
 20use watch;
 21
 22use crate::{
 23    AgentTool, AnyAgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH,
 24    SubagentContext, Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
 25};
 26
 27/// When a subagent's remaining context window falls below this fraction (25%),
 28/// the "context running out" prompt is sent to encourage the subagent to wrap up.
 29const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
 30
// NOTE: The `///` comments on this struct and its fields are behavior, not just
// documentation: `#[derive(JsonSchema)]` turns them into the `description` entries of
// the tool's input schema (see `SubagentTool::input_schema` in the tests), which is
// presumably what the model reads as instructions for this tool. Treat edits to them
// as prompt changes. The figures quoted there should stay in sync with
// `MAX_PARALLEL_SUBAGENTS` and `CONTEXT_LOW_THRESHOLD` (0.25, i.e. "25%").
/// Spawns a subagent with its own context window to perform a delegated task.
///
/// Use this tool when you want to do any of the following:
/// - Perform an investigation where all you need to know is the outcome, not the research that led to that outcome.
/// - Complete a self-contained task where you need to know if it succeeded or failed (and how), but none of its intermediate output.
/// - Run multiple tasks in parallel that would take significantly longer to run sequentially.
///
/// You control what the subagent does by providing:
/// 1. A task prompt describing what the subagent should do
/// 2. A summary prompt that tells the subagent how to summarize its work when done
/// 3. A "context running out" prompt for when the subagent is low on tokens
///
/// Each subagent has access to the same tools you do. You can optionally restrict
/// which tools each subagent can use.
///
/// Note:
/// - Maximum 8 subagents can run in parallel
/// - Subagents cannot use tools you don't have access to
/// - If spawning multiple subagents that might write to the filesystem, provide
///   guidance on how to avoid conflicts (e.g. assign each to different directories)
/// - Instruct subagents to be concise in their summaries to conserve your context
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SubagentToolInput {
    // Used as the subagent card title and as the tool call's initial title.
    /// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives")
    pub label: String,

    // Submitted as the subagent's first user message in `run_subagent`.
    /// The initial prompt that tells the subagent what task to perform.
    /// Be specific about what you want the subagent to accomplish.
    pub task_prompt: String,

    // Carried into `SubagentContext::summary_prompt` for the subagent thread.
    /// The prompt sent to the subagent when it completes its task, asking it
    /// to summarize what it did and return results. This summary becomes the
    /// tool result you receive.
    ///
    /// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons."
    pub summary_prompt: String,

    // Carried into `SubagentContext::context_low_prompt`. The "25%" above mirrors
    // `CONTEXT_LOW_THRESHOLD`, which `check_context_low` compares against.
    /// The prompt sent if the subagent is running low on context (25% remaining).
    /// Should instruct it to stop and summarize progress so far, plus what's left undone.
    ///
    /// Example: "Context is running low. Stop and summarize your progress so far,
    /// and list what remains to be investigated."
    pub context_low_prompt: String,

    // Enforced by `forward_events_with_timeout`; on expiry `run_subagent` switches to
    // `interrupt_for_summary`.
    /// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is
    /// asked to summarize and return. No timeout by default.
    #[serde(default)]
    pub timeout_ms: Option<u64>,

    // Checked against the parent's tools by `SubagentTool::validate_allowed_tools`
    // before any subagent is created.
    /// Optional: List of tool names the subagent is allowed to use.
    /// If not provided, the subagent can use all tools available to the parent.
    /// Tools listed here must be a subset of the parent's available tools.
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,
}
 86
/// Tool that spawns a subagent thread to work on a task.
///
/// Holds everything needed to build a subagent `Thread` on demand, plus the nesting
/// depth of the thread that owns this tool (compared against `MAX_SUBAGENT_DEPTH`
/// before each spawn).
pub struct SubagentTool {
    /// The thread that owns this tool. It is held weakly; `run` fails with
    /// "Parent thread no longer exists" if it can no longer be upgraded.
    parent_thread: WeakEntity<Thread>,
    /// Project handed to each subagent `Thread` and to its display `AcpThread`.
    project: Entity<Project>,
    /// Project context passed through to `Thread::new_subagent`.
    project_context: Entity<ProjectContext>,
    /// Context-server registry passed through to `Thread::new_subagent`.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Templates passed through to `Thread::new_subagent`.
    templates: Arc<Templates>,
    /// Subagent depth of the owning thread; new subagents are created at `current_depth + 1`.
    current_depth: u8,
    /// The tools available to the parent thread, captured before SubagentTool was added.
    /// Subagents inherit from this set (or a subset via `allowed_tools` in the config).
    /// This is captured early so subagents don't get the subagent tool themselves.
    parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
}
100
101impl SubagentTool {
102    pub fn new(
103        parent_thread: WeakEntity<Thread>,
104        project: Entity<Project>,
105        project_context: Entity<ProjectContext>,
106        context_server_registry: Entity<ContextServerRegistry>,
107        templates: Arc<Templates>,
108        current_depth: u8,
109        parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
110    ) -> Self {
111        Self {
112            parent_thread,
113            project,
114            project_context,
115            context_server_registry,
116            templates,
117            current_depth,
118            parent_tools,
119        }
120    }
121
122    pub fn validate_allowed_tools(&self, allowed_tools: &Option<Vec<String>>) -> Result<()> {
123        let Some(tools) = allowed_tools else {
124            return Ok(());
125        };
126
127        let invalid_tools: Vec<&str> = tools
128            .iter()
129            .filter(|tool| !self.parent_tools.contains_key(tool.as_str()))
130            .map(|s| s.as_str())
131            .collect();
132
133        if !invalid_tools.is_empty() {
134            return Err(anyhow!(
135                "The following tools do not exist: {}",
136                invalid_tools
137                    .iter()
138                    .map(|t| format!("'{}'", t))
139                    .collect::<Vec<_>>()
140                    .join(", ")
141            ));
142        }
143
144        Ok(())
145    }
146}
147
148impl AgentTool for SubagentTool {
149    type Input = SubagentToolInput;
150    type Output = String;
151
152    fn name() -> &'static str {
153        acp_thread::SUBAGENT_TOOL_NAME
154    }
155
156    fn kind() -> acp::ToolKind {
157        acp::ToolKind::Other
158    }
159
160    fn initial_title(
161        &self,
162        input: Result<Self::Input, serde_json::Value>,
163        _cx: &mut App,
164    ) -> SharedString {
165        input
166            .map(|i| i.label.into())
167            .unwrap_or_else(|_| "Subagent".into())
168    }
169
170    fn run(
171        self: Arc<Self>,
172        input: Self::Input,
173        event_stream: ToolCallEventStream,
174        cx: &mut App,
175    ) -> Task<Result<String>> {
176        if self.current_depth >= MAX_SUBAGENT_DEPTH {
177            return Task::ready(Err(anyhow!(
178                "Maximum subagent depth ({}) reached",
179                MAX_SUBAGENT_DEPTH
180            )));
181        }
182
183        if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) {
184            return Task::ready(Err(e));
185        }
186
187        let Some(parent_thread) = self.parent_thread.upgrade() else {
188            return Task::ready(Err(anyhow!(
189                "Parent thread no longer exists (subagent depth={})",
190                self.current_depth + 1
191            )));
192        };
193
194        let running_count = parent_thread.read(cx).running_subagent_count();
195        if running_count >= MAX_PARALLEL_SUBAGENTS {
196            return Task::ready(Err(anyhow!(
197                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
198                MAX_PARALLEL_SUBAGENTS
199            )));
200        }
201
202        let parent_model = parent_thread.read(cx).model().cloned();
203        let Some(model) = parent_model else {
204            return Task::ready(Err(anyhow!("No model configured")));
205        };
206
207        let parent_thread_id = parent_thread.read(cx).id().clone();
208        let project = self.project.clone();
209        let project_context = self.project_context.clone();
210        let context_server_registry = self.context_server_registry.clone();
211        let templates = self.templates.clone();
212        let parent_tools = self.parent_tools.clone();
213        let current_depth = self.current_depth;
214        let parent_thread_weak = self.parent_thread.clone();
215
216        cx.spawn(async move |cx| {
217            let subagent_context = SubagentContext {
218                parent_thread_id: parent_thread_id.clone(),
219                tool_use_id: LanguageModelToolUseId::from(uuid::Uuid::new_v4().to_string()),
220                depth: current_depth + 1,
221                summary_prompt: input.summary_prompt.clone(),
222                context_low_prompt: input.context_low_prompt.clone(),
223            };
224
225            // Determine which tools this subagent gets
226            let subagent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> =
227                if let Some(ref allowed) = input.allowed_tools {
228                    let allowed_set: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
229                    parent_tools
230                        .iter()
231                        .filter(|(name, _)| allowed_set.contains(name.as_ref()))
232                        .map(|(name, tool)| (name.clone(), tool.clone()))
233                        .collect()
234                } else {
235                    parent_tools.clone()
236                };
237
238            let subagent_thread: Entity<Thread> = cx.new(|cx| {
239                Thread::new_subagent(
240                    project.clone(),
241                    project_context.clone(),
242                    context_server_registry.clone(),
243                    templates.clone(),
244                    model.clone(),
245                    subagent_context,
246                    subagent_tools,
247                    cx,
248                )
249            });
250
251            let subagent_weak = subagent_thread.downgrade();
252
253            let acp_thread: Entity<AcpThread> = cx.new(|cx| {
254                let session_id = subagent_thread.read(cx).id().clone();
255                let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
256                let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
257                AcpThread::new(
258                    &input.label,
259                    connection,
260                    project.clone(),
261                    action_log,
262                    session_id,
263                    watch::Receiver::constant(acp::PromptCapabilities::new()),
264                    cx,
265                )
266            });
267
268            event_stream.update_subagent_thread(acp_thread.clone());
269
270            let mut user_stop_rx: watch::Receiver<bool> =
271                acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
272
273            if let Some(parent) = parent_thread_weak.upgrade() {
274                parent.update(cx, |thread, _cx| {
275                    thread.register_running_subagent(subagent_weak.clone());
276                });
277            }
278
279            // Helper to wait for user stop signal on the subagent card
280            let wait_for_user_stop = async {
281                loop {
282                    if *user_stop_rx.borrow() {
283                        return;
284                    }
285                    if user_stop_rx.changed().await.is_err() {
286                        std::future::pending::<()>().await;
287                    }
288                }
289            };
290
291            // Run the subagent, handling cancellation from both:
292            // 1. Parent turn cancellation (event_stream.cancelled_by_user)
293            // 2. Direct user stop on subagent card (user_stop_rx)
294            let result = futures::select! {
295                result = run_subagent(
296                    &subagent_thread,
297                    &acp_thread,
298                    input.task_prompt,
299                    input.timeout_ms,
300                    cx,
301                ).fuse() => result,
302                _ = event_stream.cancelled_by_user().fuse() => {
303                    let _ = subagent_thread.update(cx, |thread, cx| {
304                        thread.cancel(cx).detach();
305                    });
306                    Err(anyhow!("Subagent cancelled by user"))
307                }
308                _ = wait_for_user_stop.fuse() => {
309                    let _ = subagent_thread.update(cx, |thread, cx| {
310                        thread.cancel(cx).detach();
311                    });
312                    Err(anyhow!("Subagent stopped by user"))
313                }
314            };
315
316            if let Some(parent) = parent_thread_weak.upgrade() {
317                let _ = parent.update(cx, |thread, _cx| {
318                    thread.unregister_running_subagent(&subagent_weak);
319                });
320            }
321
322            result
323        })
324    }
325}
326
327async fn run_subagent(
328    subagent_thread: &Entity<Thread>,
329    acp_thread: &Entity<AcpThread>,
330    task_prompt: String,
331    timeout_ms: Option<u64>,
332    cx: &mut AsyncApp,
333) -> Result<String> {
334    let mut events_rx =
335        subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
336
337    let acp_thread_weak = acp_thread.downgrade();
338
339    let timed_out = if let Some(timeout) = timeout_ms {
340        forward_events_with_timeout(
341            &mut events_rx,
342            &acp_thread_weak,
343            Duration::from_millis(timeout),
344            cx,
345        )
346        .await
347    } else {
348        forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
349        false
350    };
351
352    let should_interrupt =
353        timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
354
355    if should_interrupt {
356        let mut summary_rx =
357            subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
358        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
359    } else {
360        let mut summary_rx =
361            subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
362        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
363    }
364
365    Ok(extract_last_message(subagent_thread, cx))
366}
367
368async fn forward_events_until_stop(
369    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
370    acp_thread: &WeakEntity<AcpThread>,
371    cx: &mut AsyncApp,
372) {
373    while let Some(event) = events_rx.next().await {
374        match event {
375            Ok(ThreadEvent::Stop(_)) => break,
376            Ok(event) => {
377                forward_event_to_acp_thread(event, acp_thread, cx);
378            }
379            Err(_) => break,
380        }
381    }
382}
383
384async fn forward_events_with_timeout(
385    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
386    acp_thread: &WeakEntity<AcpThread>,
387    timeout: Duration,
388    cx: &mut AsyncApp,
389) -> bool {
390    use futures::future::{self, Either};
391
392    let deadline = std::time::Instant::now() + timeout;
393
394    loop {
395        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
396        if remaining.is_zero() {
397            return true;
398        }
399
400        let timeout_future = cx.background_executor().timer(remaining);
401        let event_future = events_rx.next();
402
403        match future::select(event_future, timeout_future).await {
404            Either::Left((event, _)) => match event {
405                Some(Ok(ThreadEvent::Stop(_))) => return false,
406                Some(Ok(event)) => {
407                    forward_event_to_acp_thread(event, acp_thread, cx);
408                }
409                Some(Err(_)) => return false,
410                None => return false,
411            },
412            Either::Right((_, _)) => return true,
413        }
414    }
415}
416
417fn forward_event_to_acp_thread(
418    event: ThreadEvent,
419    acp_thread: &WeakEntity<AcpThread>,
420    cx: &mut AsyncApp,
421) {
422    match event {
423        ThreadEvent::UserMessage(message) => {
424            acp_thread
425                .update(cx, |thread, cx| {
426                    for content in message.content {
427                        thread.push_user_content_block(
428                            Some(message.id.clone()),
429                            content.into(),
430                            cx,
431                        );
432                    }
433                })
434                .log_err();
435        }
436        ThreadEvent::AgentText(text) => {
437            acp_thread
438                .update(cx, |thread, cx| {
439                    thread.push_assistant_content_block(text.into(), false, cx)
440                })
441                .log_err();
442        }
443        ThreadEvent::AgentThinking(text) => {
444            acp_thread
445                .update(cx, |thread, cx| {
446                    thread.push_assistant_content_block(text.into(), true, cx)
447                })
448                .log_err();
449        }
450        ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
451            tool_call,
452            options,
453            response,
454            ..
455        }) => {
456            let outcome_task = acp_thread.update(cx, |thread, cx| {
457                thread.request_tool_call_authorization(tool_call, options, true, cx)
458            });
459            if let Ok(Ok(task)) = outcome_task {
460                cx.background_spawn(async move {
461                    if let acp::RequestPermissionOutcome::Selected(
462                        acp::SelectedPermissionOutcome { option_id, .. },
463                    ) = task.await
464                    {
465                        response.send(option_id).ok();
466                    }
467                })
468                .detach();
469            }
470        }
471        ThreadEvent::ToolCall(tool_call) => {
472            acp_thread
473                .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
474                .log_err();
475        }
476        ThreadEvent::ToolCallUpdate(update) => {
477            acp_thread
478                .update(cx, |thread, cx| thread.update_tool_call(update, cx))
479                .log_err();
480        }
481        ThreadEvent::Retry(status) => {
482            acp_thread
483                .update(cx, |thread, cx| thread.update_retry_status(status, cx))
484                .log_err();
485        }
486        ThreadEvent::Stop(_) => {}
487    }
488}
489
490fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
491    thread.read_with(cx, |thread, _| {
492        if let Some(usage) = thread.latest_token_usage() {
493            let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
494            remaining_ratio <= threshold
495        } else {
496            false
497        }
498    })
499}
500
501fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
502    thread.read_with(cx, |thread, _| {
503        thread
504            .last_message()
505            .map(|m| m.to_markdown())
506            .unwrap_or_else(|| "No response from subagent".to_string())
507    })
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::LanguageModelToolSchemaFormat;

    /// Every public input field must appear in the generated JSON schema.
    #[test]
    fn test_subagent_tool_input_json_schema_is_valid() {
        let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
        let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");

        let properties = schema_json.get("properties");
        assert!(properties.is_some(), "schema should have properties");
        let properties = properties.unwrap();

        let expected_fields = [
            "label",
            "task_prompt",
            "summary_prompt",
            "context_low_prompt",
            "timeout_ms",
            "allowed_tools",
        ];
        for field in expected_fields {
            assert!(
                properties.get(field).is_some(),
                "should have {field} field"
            );
        }
    }

    #[test]
    fn test_subagent_tool_name() {
        assert_eq!(SubagentTool::name(), "subagent");
    }

    #[test]
    fn test_subagent_tool_kind() {
        assert_eq!(SubagentTool::kind(), acp::ToolKind::Other);
    }
}
559
560struct SubagentDisplayConnection;
561
562impl AgentConnection for SubagentDisplayConnection {
563    fn telemetry_id(&self) -> SharedString {
564        "subagent".into()
565    }
566
567    fn auth_methods(&self) -> &[acp::AuthMethod] {
568        &[]
569    }
570
571    fn new_thread(
572        self: Rc<Self>,
573        _project: Entity<Project>,
574        _cwd: &Path,
575        _cx: &mut App,
576    ) -> Task<Result<Entity<AcpThread>>> {
577        unimplemented!("SubagentDisplayConnection does not support new_thread")
578    }
579
580    fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
581        unimplemented!("SubagentDisplayConnection does not support authenticate")
582    }
583
584    fn prompt(
585        &self,
586        _id: Option<UserMessageId>,
587        _params: acp::PromptRequest,
588        _cx: &mut App,
589    ) -> Task<Result<acp::PromptResponse>> {
590        unimplemented!("SubagentDisplayConnection does not support prompt")
591    }
592
593    fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
594
595    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
596        self
597    }
598}