subagent_tool.rs

  1use acp_thread::{AcpThread, AgentConnection, UserMessageId};
  2use action_log::ActionLog;
  3use agent_client_protocol as acp;
  4use anyhow::{Result, anyhow};
  5use collections::HashSet;
  6use futures::channel::mpsc;
  7use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
  8use project::Project;
  9use prompt_store::ProjectContext;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use smol::stream::StreamExt;
 13use std::any::Any;
 14use std::path::Path;
 15use std::rc::Rc;
 16use std::sync::Arc;
 17use std::time::Duration;
 18use util::ResultExt;
 19use watch;
 20
 21use crate::{
 22    AgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext,
 23    Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
 24};
 25
/// Fraction of the context window (25%) at or below which a subagent is treated as
/// running low on context.
///
/// `run_subagent` evaluates this once the task turn has stopped; when the remaining
/// fraction is at or below this value the subagent is interrupted for a summary
/// instead of being asked for its normal final summary.
const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
 29
// NOTE(review): the `///` comments on this type and its fields are presumably
// exported as descriptions in the generated JSON schema (via the `JsonSchema`
// derive) and therefore read by the model — treat edits to them as prompt
// changes. TODO confirm.
// NOTE(review): "Maximum 8 subagents" below is hard-coded text; confirm it
// matches `MAX_PARALLEL_SUBAGENTS`, which is what `run` actually enforces.
/// Spawns a subagent with its own context window to perform a delegated task.
///
/// Use this tool when you need to:
/// - Perform research that would consume too many tokens in the main context
/// - Execute a complex subtask independently
/// - Run multiple parallel investigations
///
/// You control what the subagent does by providing:
/// 1. A task prompt describing what the subagent should do
/// 2. A summary prompt that tells the subagent how to summarize its work when done
/// 3. A "context running out" prompt for when the subagent is low on tokens
///
/// The subagent has access to the same tools you do. You can optionally restrict
/// which tools the subagent can use.
///
/// IMPORTANT:
/// - Maximum 8 subagents can be spawned per turn
/// - Subagents cannot use tools you don't have access to
/// - If spawning multiple subagents that might write to the filesystem, provide
///   guidance on how to avoid conflicts (e.g., assign each to different directories)
/// - Instruct subagents to be concise in their summaries to conserve your context
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SubagentToolInput {
    /// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives")
    pub label: String,

    /// The initial prompt that tells the subagent what task to perform.
    /// Be specific about what you want the subagent to accomplish.
    pub task_prompt: String,

    /// The prompt sent to the subagent when it completes its task, asking it
    /// to summarize what it did and return results. This summary becomes the
    /// tool result you receive.
    ///
    /// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons."
    pub summary_prompt: String,

    /// The prompt sent if the subagent is running low on context (25% remaining).
    /// Should instruct it to stop and summarize progress so far, plus what's left undone.
    ///
    /// Example: "Context is running low. Stop and summarize your progress so far,
    /// and list what remains to be investigated."
    pub context_low_prompt: String,

    /// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is
    /// asked to summarize and return. No timeout by default.
    // `serde(default)` lets callers omit the field; it then deserializes to `None`.
    #[serde(default)]
    pub timeout_ms: Option<u64>,

    /// Optional: List of tool names the subagent is allowed to use.
    /// If not provided, the subagent can use all tools available to the parent.
    /// Tools listed here must be a subset of the parent's available tools.
    // `serde(default)` lets callers omit the field; it then deserializes to `None`.
    // The subset requirement is enforced by `SubagentTool::validate_allowed_tools`.
    #[serde(default)]
    pub allowed_tools: Option<Vec<String>>,
}
 85
/// Agent tool that lets a thread delegate a task to a freshly created subagent thread.
pub struct SubagentTool {
    /// Weak handle to the thread that owns this tool. It is upgraded in `run` to read
    /// the running-subagent count, thread id and model, and to register/unregister
    /// the spawned subagent.
    parent_thread: WeakEntity<Thread>,
    /// Project handed to the subagent thread and to its display `AcpThread`/`ActionLog`.
    project: Entity<Project>,
    /// Project context forwarded to `Thread::new_subagent`.
    project_context: Entity<ProjectContext>,
    /// Context server registry forwarded to `Thread::new_subagent`.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Templates forwarded to `Thread::new_subagent`.
    templates: Arc<Templates>,
    /// Nesting depth of the owning thread. `run` refuses to spawn once this reaches
    /// `MAX_SUBAGENT_DEPTH`; spawned subagents get `current_depth + 1`.
    current_depth: u8,
    /// Names of the tools available to the parent; `allowed_tools` must be a subset.
    parent_tool_names: HashSet<SharedString>,
}
 95
 96impl SubagentTool {
 97    pub fn new(
 98        parent_thread: WeakEntity<Thread>,
 99        project: Entity<Project>,
100        project_context: Entity<ProjectContext>,
101        context_server_registry: Entity<ContextServerRegistry>,
102        templates: Arc<Templates>,
103        current_depth: u8,
104        parent_tool_names: Vec<SharedString>,
105    ) -> Self {
106        Self {
107            parent_thread,
108            project,
109            project_context,
110            context_server_registry,
111            templates,
112            current_depth,
113            parent_tool_names: parent_tool_names.into_iter().collect(),
114        }
115    }
116
117    pub fn validate_allowed_tools(&self, allowed_tools: &Option<Vec<String>>) -> Result<()> {
118        if let Some(tools) = allowed_tools {
119            for tool in tools {
120                if !self.parent_tool_names.contains(tool.as_str()) {
121                    return Err(anyhow!(
122                        "Tool '{}' is not available to the parent agent. Available tools: {:?}",
123                        tool,
124                        self.parent_tool_names.iter().collect::<Vec<_>>()
125                    ));
126                }
127            }
128        }
129        Ok(())
130    }
131}
132
133impl AgentTool for SubagentTool {
134    type Input = SubagentToolInput;
135    type Output = String;
136
137    fn name() -> &'static str {
138        acp_thread::SUBAGENT_TOOL_NAME
139    }
140
141    fn kind() -> acp::ToolKind {
142        acp::ToolKind::Other
143    }
144
145    fn initial_title(
146        &self,
147        input: Result<Self::Input, serde_json::Value>,
148        _cx: &mut App,
149    ) -> SharedString {
150        input
151            .map(|i| i.label.into())
152            .unwrap_or_else(|_| "Subagent".into())
153    }
154
155    fn run(
156        self: Arc<Self>,
157        input: Self::Input,
158        event_stream: ToolCallEventStream,
159        cx: &mut App,
160    ) -> Task<Result<String>> {
161        if self.current_depth >= MAX_SUBAGENT_DEPTH {
162            return Task::ready(Err(anyhow!(
163                "Maximum subagent depth ({}) reached",
164                MAX_SUBAGENT_DEPTH
165            )));
166        }
167
168        if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) {
169            return Task::ready(Err(e));
170        }
171
172        let Some(parent_thread) = self.parent_thread.upgrade() else {
173            return Task::ready(Err(anyhow!(
174                "Parent thread no longer exists (subagent depth={})",
175                self.current_depth + 1
176            )));
177        };
178
179        let running_count = parent_thread.read(cx).running_subagent_count();
180        if running_count >= MAX_PARALLEL_SUBAGENTS {
181            return Task::ready(Err(anyhow!(
182                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
183                MAX_PARALLEL_SUBAGENTS
184            )));
185        }
186
187        let parent_thread_id = parent_thread.read(cx).id().clone();
188        let parent_model = parent_thread.read(cx).model().cloned();
189        let tool_use_id = event_stream.tool_use_id().clone();
190
191        let Some(model) = parent_model else {
192            return Task::ready(Err(anyhow!("No model configured")));
193        };
194
195        let subagent_context = SubagentContext {
196            parent_thread_id,
197            tool_use_id,
198            depth: self.current_depth + 1,
199            summary_prompt: input.summary_prompt.clone(),
200            context_low_prompt: input.context_low_prompt.clone(),
201        };
202
203        let project = self.project.clone();
204        let project_context = self.project_context.clone();
205        let context_server_registry = self.context_server_registry.clone();
206        let templates = self.templates.clone();
207        let task_prompt = input.task_prompt;
208        let timeout_ms = input.timeout_ms;
209        let allowed_tools: Option<HashSet<SharedString>> = input
210            .allowed_tools
211            .map(|tools| tools.into_iter().map(SharedString::from).collect());
212
213        let parent_thread = self.parent_thread.clone();
214
215        cx.spawn(async move |cx| {
216            let subagent_thread: Entity<Thread> = cx.new(|cx| {
217                Thread::new_subagent(
218                    project.clone(),
219                    project_context.clone(),
220                    context_server_registry.clone(),
221                    templates.clone(),
222                    model,
223                    subagent_context,
224                    cx,
225                )
226            });
227
228            let subagent_weak = subagent_thread.downgrade();
229
230            let acp_thread: Entity<AcpThread> = cx.new(|cx| {
231                let session_id = subagent_thread.read(cx).id().clone();
232                let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
233                let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
234                AcpThread::new(
235                    "Subagent",
236                    connection,
237                    project.clone(),
238                    action_log,
239                    session_id,
240                    watch::Receiver::constant(acp::PromptCapabilities::new()),
241                    cx,
242                )
243            });
244
245            event_stream.update_subagent_thread(acp_thread.clone());
246
247            if let Some(parent) = parent_thread.upgrade() {
248                parent.update(cx, |thread, _cx| {
249                    thread.register_running_subagent(subagent_weak.clone());
250                });
251            }
252
253            let result = run_subagent(
254                &subagent_thread,
255                &acp_thread,
256                allowed_tools,
257                task_prompt,
258                timeout_ms,
259                cx,
260            )
261            .await;
262
263            if let Some(parent) = parent_thread.upgrade() {
264                let _ = parent.update(cx, |thread, _cx| {
265                    thread.unregister_running_subagent(&subagent_weak);
266                });
267            }
268
269            result
270        })
271    }
272}
273
274async fn run_subagent(
275    subagent_thread: &Entity<Thread>,
276    acp_thread: &Entity<AcpThread>,
277    allowed_tools: Option<HashSet<SharedString>>,
278    task_prompt: String,
279    timeout_ms: Option<u64>,
280    cx: &mut AsyncApp,
281) -> Result<String> {
282    if let Some(ref allowed) = allowed_tools {
283        subagent_thread.update(cx, |thread, _cx| {
284            thread.restrict_tools(allowed);
285        });
286    }
287
288    let mut events_rx =
289        subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
290
291    let acp_thread_weak = acp_thread.downgrade();
292
293    let timed_out = if let Some(timeout) = timeout_ms {
294        forward_events_with_timeout(
295            &mut events_rx,
296            &acp_thread_weak,
297            Duration::from_millis(timeout),
298            cx,
299        )
300        .await
301    } else {
302        forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
303        false
304    };
305
306    let should_interrupt =
307        timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
308
309    if should_interrupt {
310        let mut summary_rx =
311            subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
312        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
313    } else {
314        let mut summary_rx =
315            subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
316        forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
317    }
318
319    Ok(extract_last_message(subagent_thread, cx))
320}
321
322async fn forward_events_until_stop(
323    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
324    acp_thread: &WeakEntity<AcpThread>,
325    cx: &mut AsyncApp,
326) {
327    while let Some(event) = events_rx.next().await {
328        match event {
329            Ok(ThreadEvent::Stop(_)) => break,
330            Ok(event) => {
331                forward_event_to_acp_thread(event, acp_thread, cx);
332            }
333            Err(_) => break,
334        }
335    }
336}
337
338async fn forward_events_with_timeout(
339    events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
340    acp_thread: &WeakEntity<AcpThread>,
341    timeout: Duration,
342    cx: &mut AsyncApp,
343) -> bool {
344    use futures::future::{self, Either};
345
346    let deadline = std::time::Instant::now() + timeout;
347
348    loop {
349        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
350        if remaining.is_zero() {
351            return true;
352        }
353
354        let timeout_future = cx.background_executor().timer(remaining);
355        let event_future = events_rx.next();
356
357        match future::select(event_future, timeout_future).await {
358            Either::Left((event, _)) => match event {
359                Some(Ok(ThreadEvent::Stop(_))) => return false,
360                Some(Ok(event)) => {
361                    forward_event_to_acp_thread(event, acp_thread, cx);
362                }
363                Some(Err(_)) => return false,
364                None => return false,
365            },
366            Either::Right((_, _)) => return true,
367        }
368    }
369}
370
371fn forward_event_to_acp_thread(
372    event: ThreadEvent,
373    acp_thread: &WeakEntity<AcpThread>,
374    cx: &mut AsyncApp,
375) {
376    match event {
377        ThreadEvent::UserMessage(message) => {
378            acp_thread
379                .update(cx, |thread, cx| {
380                    for content in message.content {
381                        thread.push_user_content_block(
382                            Some(message.id.clone()),
383                            content.into(),
384                            cx,
385                        );
386                    }
387                })
388                .log_err();
389        }
390        ThreadEvent::AgentText(text) => {
391            acp_thread
392                .update(cx, |thread, cx| {
393                    thread.push_assistant_content_block(text.into(), false, cx)
394                })
395                .log_err();
396        }
397        ThreadEvent::AgentThinking(text) => {
398            acp_thread
399                .update(cx, |thread, cx| {
400                    thread.push_assistant_content_block(text.into(), true, cx)
401                })
402                .log_err();
403        }
404        ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
405            tool_call,
406            options,
407            response,
408        }) => {
409            let outcome_task = acp_thread.update(cx, |thread, cx| {
410                thread.request_tool_call_authorization(tool_call, options, true, cx)
411            });
412            if let Ok(Ok(task)) = outcome_task {
413                cx.background_spawn(async move {
414                    if let acp::RequestPermissionOutcome::Selected(
415                        acp::SelectedPermissionOutcome { option_id, .. },
416                    ) = task.await
417                    {
418                        response.send(option_id).ok();
419                    }
420                })
421                .detach();
422            }
423        }
424        ThreadEvent::ToolCall(tool_call) => {
425            acp_thread
426                .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
427                .log_err();
428        }
429        ThreadEvent::ToolCallUpdate(update) => {
430            acp_thread
431                .update(cx, |thread, cx| thread.update_tool_call(update, cx))
432                .log_err();
433        }
434        ThreadEvent::Retry(status) => {
435            acp_thread
436                .update(cx, |thread, cx| thread.update_retry_status(status, cx))
437                .log_err();
438        }
439        ThreadEvent::Stop(_) => {}
440    }
441}
442
443fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
444    thread.read_with(cx, |thread, _| {
445        if let Some(usage) = thread.latest_token_usage() {
446            let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
447            remaining_ratio <= threshold
448        } else {
449            false
450        }
451    })
452}
453
454fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
455    thread.read_with(cx, |thread, _| {
456        thread
457            .last_message()
458            .map(|m| m.to_markdown())
459            .unwrap_or_else(|| "No response from subagent".to_string())
460    })
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::LanguageModelToolSchemaFormat;

    /// The generated input schema must expose every `SubagentToolInput` field.
    #[test]
    fn test_subagent_tool_input_json_schema_is_valid() {
        let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
        let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");

        let Some(properties) = schema_json.get("properties") else {
            panic!("schema should have properties");
        };

        // Checked in declaration order so the first missing field is the one reported.
        let expected_fields = [
            "label",
            "task_prompt",
            "summary_prompt",
            "context_low_prompt",
            "timeout_ms",
            "allowed_tools",
        ];
        for field in expected_fields {
            assert!(
                properties.get(field).is_some(),
                "should have {field} field"
            );
        }
    }

    /// The tool is registered under the name "subagent".
    #[test]
    fn test_subagent_tool_name() {
        assert_eq!(SubagentTool::name(), "subagent");
    }

    /// The tool reports the generic `Other` kind.
    #[test]
    fn test_subagent_tool_kind() {
        assert_eq!(SubagentTool::kind(), acp::ToolKind::Other);
    }
}
512
513struct SubagentDisplayConnection;
514
515impl AgentConnection for SubagentDisplayConnection {
516    fn telemetry_id(&self) -> SharedString {
517        "subagent".into()
518    }
519
520    fn auth_methods(&self) -> &[acp::AuthMethod] {
521        &[]
522    }
523
524    fn new_thread(
525        self: Rc<Self>,
526        _project: Entity<Project>,
527        _cwd: &Path,
528        _cx: &mut App,
529    ) -> Task<Result<Entity<AcpThread>>> {
530        unimplemented!("SubagentDisplayConnection does not support new_thread")
531    }
532
533    fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
534        unimplemented!("SubagentDisplayConnection does not support authenticate")
535    }
536
537    fn prompt(
538        &self,
539        _id: Option<UserMessageId>,
540        _params: acp::PromptRequest,
541        _cx: &mut App,
542    ) -> Task<Result<acp::PromptResponse>> {
543        unimplemented!("SubagentDisplayConnection does not support prompt")
544    }
545
546    fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
547
548    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
549        self
550    }
551}