agent: Simplify subagent tool and handle (#50230)

Ben Brandt created

Moves parent thread information into the environment, and also models it
as a turn on the handle rather than waiting for output.

Release Notes:

- N/A

Change summary

crates/agent/src/agent.rs                  | 203 ++++++++++-------------
crates/agent/src/tests/mod.rs              |  18 -
crates/agent/src/thread.rs                 |  14 -
crates/agent/src/tools/spawn_agent_tool.rs |  35 ---
crates/eval/src/instance.rs                |   2 
5 files changed, 98 insertions(+), 174 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -369,11 +369,13 @@ impl NativeAgent {
         let summarization_model = registry.thread_summary_model().map(|c| c.model);
 
         let weak = cx.weak_entity();
+        let weak_thread = thread_handle.downgrade();
         thread_handle.update(cx, |thread, cx| {
             thread.set_summarization_model(summarization_model, cx);
             thread.add_default_tools(
                 Rc::new(NativeThreadEnvironment {
                     acp_thread: acp_thread.downgrade(),
+                    thread: weak_thread,
                     agent: weak,
                 }) as _,
                 cx,
@@ -1576,17 +1578,19 @@ impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
 
 pub struct NativeThreadEnvironment {
     agent: WeakEntity<NativeAgent>,
+    thread: WeakEntity<Thread>,
     acp_thread: WeakEntity<AcpThread>,
 }
 
 impl NativeThreadEnvironment {
     pub(crate) fn create_subagent_thread(
-        agent: WeakEntity<NativeAgent>,
-        parent_thread_entity: Entity<Thread>,
+        &self,
         label: String,
-        initial_prompt: String,
         cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
+        let Some(parent_thread_entity) = self.thread.upgrade() else {
+            anyhow::bail!("Parent thread no longer exists".to_string());
+        };
         let parent_thread = parent_thread_entity.read(cx);
         let current_depth = parent_thread.depth();
 
@@ -1605,28 +1609,19 @@ impl NativeThreadEnvironment {
 
         let session_id = subagent_thread.read(cx).id().clone();
 
-        let acp_thread = agent.update(cx, |agent, cx| {
+        let acp_thread = self.agent.update(cx, |agent, cx| {
             agent.register_session(subagent_thread.clone(), cx)
         })?;
 
-        Self::prompt_subagent(
-            session_id,
-            subagent_thread,
-            acp_thread,
-            parent_thread_entity,
-            initial_prompt,
-            cx,
-        )
+        self.prompt_subagent(session_id, subagent_thread, acp_thread)
     }
 
     pub(crate) fn resume_subagent_thread(
-        agent: WeakEntity<NativeAgent>,
-        parent_thread_entity: Entity<Thread>,
+        &self,
         session_id: acp::SessionId,
-        follow_up_prompt: String,
         cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
-        let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| {
+        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
             let session = agent
                 .sessions
                 .get(&session_id)
@@ -1634,31 +1629,23 @@ impl NativeThreadEnvironment {
             anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
         })??;
 
-        Self::prompt_subagent(
-            session_id,
-            subagent_thread,
-            acp_thread,
-            parent_thread_entity,
-            follow_up_prompt,
-            cx,
-        )
+        self.prompt_subagent(session_id, subagent_thread, acp_thread)
     }
 
     fn prompt_subagent(
+        &self,
         session_id: acp::SessionId,
         subagent_thread: Entity<Thread>,
         acp_thread: Entity<acp_thread::AcpThread>,
-        parent_thread_entity: Entity<Thread>,
-        prompt: String,
-        cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
+        let Some(parent_thread_entity) = self.thread.upgrade() else {
+            anyhow::bail!("Parent thread no longer exists".to_string());
+        };
         Ok(Rc::new(NativeSubagentHandle::new(
             session_id,
             subagent_thread,
             acp_thread,
             parent_thread_entity,
-            prompt,
-            cx,
         )) as _)
     }
 }
@@ -1697,36 +1684,16 @@ impl ThreadEnvironment for NativeThreadEnvironment {
         })
     }
 
-    fn create_subagent(
-        &self,
-        parent_thread_entity: Entity<Thread>,
-        label: String,
-        initial_prompt: String,
-        cx: &mut App,
-    ) -> Result<Rc<dyn SubagentHandle>> {
-        Self::create_subagent_thread(
-            self.agent.clone(),
-            parent_thread_entity,
-            label,
-            initial_prompt,
-            cx,
-        )
+    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
+        self.create_subagent_thread(label, cx)
     }
 
     fn resume_subagent(
         &self,
-        parent_thread_entity: Entity<Thread>,
         session_id: acp::SessionId,
-        follow_up_prompt: String,
         cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
-        Self::resume_subagent_thread(
-            self.agent.clone(),
-            parent_thread_entity,
-            session_id,
-            follow_up_prompt,
-            cx,
-        )
+        self.resume_subagent_thread(session_id, cx)
     }
 }
 
@@ -1742,8 +1709,7 @@ pub struct NativeSubagentHandle {
     session_id: acp::SessionId,
     parent_thread: WeakEntity<Thread>,
     subagent_thread: Entity<Thread>,
-    wait_for_prompt_to_complete: Shared<Task<SubagentPromptResult>>,
-    _subscription: Subscription,
+    acp_thread: Entity<acp_thread::AcpThread>,
 }
 
 impl NativeSubagentHandle {
@@ -1752,71 +1718,12 @@ impl NativeSubagentHandle {
         subagent_thread: Entity<Thread>,
         acp_thread: Entity<acp_thread::AcpThread>,
         parent_thread_entity: Entity<Thread>,
-        prompt: String,
-        cx: &mut App,
     ) -> Self {
-        let ratio_before_prompt = subagent_thread
-            .read(cx)
-            .latest_token_usage()
-            .map(|usage| usage.ratio());
-
-        parent_thread_entity.update(cx, |parent_thread, _cx| {
-            parent_thread.register_running_subagent(subagent_thread.downgrade())
-        });
-
-        let task = acp_thread.update(cx, |acp_thread, cx| {
-            acp_thread.send(vec![prompt.into()], cx)
-        });
-
-        let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
-        let mut token_limit_tx = Some(token_limit_tx);
-
-        let subscription = cx.subscribe(
-            &subagent_thread,
-            move |_thread, event: &TokenUsageUpdated, _cx| {
-                if let Some(usage) = &event.0 {
-                    let old_ratio = ratio_before_prompt
-                        .clone()
-                        .unwrap_or(TokenUsageRatio::Normal);
-                    let new_ratio = usage.ratio();
-                    if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning
-                    {
-                        if let Some(tx) = token_limit_tx.take() {
-                            tx.send(()).ok();
-                        }
-                    }
-                }
-            },
-        );
-
-        let wait_for_prompt_to_complete = cx
-            .background_spawn(async move {
-                futures::select! {
-                    response = task.fuse() => match response {
-                        Ok(Some(response)) =>{
-                            match response.stop_reason {
-                                acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
-                                acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
-                                acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
-                                acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
-                                acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
-                            }
-
-                        }
-                        Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
-                        Err(error) => SubagentPromptResult::Error(error.to_string()),
-                    },
-                    _ = token_limit_rx.fuse() =>  SubagentPromptResult::ContextWindowWarning,
-                }
-            })
-            .shared();
-
         NativeSubagentHandle {
             session_id,
             subagent_thread,
             parent_thread: parent_thread_entity.downgrade(),
-            wait_for_prompt_to_complete,
-            _subscription: subscription,
+            acp_thread,
         }
     }
 }
@@ -1826,15 +1733,75 @@ impl SubagentHandle for NativeSubagentHandle {
         self.session_id.clone()
     }
 
-    fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
+    fn run_turn(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
         let thread = self.subagent_thread.clone();
-        let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
-
+        let acp_thread = self.acp_thread.clone();
         let subagent_session_id = self.session_id.clone();
         let parent_thread = self.parent_thread.clone();
 
         cx.spawn(async move |cx| {
-            let result = match wait_for_prompt.await {
+            let (task, _subscription) = cx.update(|cx| {
+                let ratio_before_prompt = thread
+                    .read(cx)
+                    .latest_token_usage()
+                    .map(|usage| usage.ratio());
+
+                parent_thread
+                    .update(cx, |parent_thread, _cx| {
+                        parent_thread.register_running_subagent(thread.downgrade())
+                    })
+                    .ok();
+
+                let task = acp_thread.update(cx, |acp_thread, cx| {
+                    acp_thread.send(vec![message.into()], cx)
+                });
+
+                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
+                let mut token_limit_tx = Some(token_limit_tx);
+
+                let subscription = cx.subscribe(
+                    &thread,
+                    move |_thread, event: &TokenUsageUpdated, _cx| {
+                        if let Some(usage) = &event.0 {
+                            let old_ratio = ratio_before_prompt
+                                .clone()
+                                .unwrap_or(TokenUsageRatio::Normal);
+                            let new_ratio = usage.ratio();
+                            if old_ratio == TokenUsageRatio::Normal
+                                && new_ratio == TokenUsageRatio::Warning
+                            {
+                                if let Some(tx) = token_limit_tx.take() {
+                                    tx.send(()).ok();
+                                }
+                            }
+                        }
+                    },
+                );
+
+                let wait_for_prompt = cx
+                    .background_spawn(async move {
+                        futures::select! {
+                            response = task.fuse() => match response {
+                                Ok(Some(response)) => {
+                                    match response.stop_reason {
+                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
+                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
+                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
+                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
+                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
+                                    }
+                                }
+                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
+                                Err(error) => SubagentPromptResult::Error(error.to_string()),
+                            },
+                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
+                        }
+                    });
+
+                (wait_for_prompt, subscription)
+            });
+
+            let result = match task.await {
                 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                     thread
                         .last_message()

crates/agent/src/tests/mod.rs 🔗

@@ -167,7 +167,7 @@ impl SubagentHandle for FakeSubagentHandle {
         self.session_id.clone()
     }
 
-    fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
+    fn run_turn(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
         let task = self.wait_for_summary_task.clone();
         cx.background_spawn(async move { Ok(task.await) })
     }
@@ -203,13 +203,7 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment {
         Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
     }
 
-    fn create_subagent(
-        &self,
-        _parent_thread: Entity<Thread>,
-        _label: String,
-        _initial_prompt: String,
-        _cx: &mut App,
-    ) -> Result<Rc<dyn SubagentHandle>> {
+    fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
         Ok(self
             .subagent_handle
             .clone()
@@ -248,13 +242,7 @@ impl crate::ThreadEnvironment for MultiTerminalEnvironment {
         Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
     }
 
-    fn create_subagent(
-        &self,
-        _parent_thread: Entity<Thread>,
-        _label: String,
-        _initial_prompt: String,
-        _cx: &mut App,
-    ) -> Result<Rc<dyn SubagentHandle>> {
+    fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
         unimplemented!()
     }
 }

crates/agent/src/thread.rs 🔗

@@ -605,7 +605,7 @@ pub trait TerminalHandle {
 
 pub trait SubagentHandle {
     fn id(&self) -> acp::SessionId;
-    fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>>;
+    fn run_turn(&self, message: String, cx: &AsyncApp) -> Task<Result<String>>;
 }
 
 pub trait ThreadEnvironment {
@@ -617,19 +617,11 @@ pub trait ThreadEnvironment {
         cx: &mut AsyncApp,
     ) -> Task<Result<Rc<dyn TerminalHandle>>>;
 
-    fn create_subagent(
-        &self,
-        parent_thread: Entity<Thread>,
-        label: String,
-        initial_prompt: String,
-        cx: &mut App,
-    ) -> Result<Rc<dyn SubagentHandle>>;
+    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>>;
 
     fn resume_subagent(
         &self,
-        _parent_thread: Entity<Thread>,
         _session_id: acp::SessionId,
-        _follow_up_prompt: String,
         _cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
         Err(anyhow::anyhow!(
@@ -1376,7 +1368,7 @@ impl Thread {
         self.add_tool(WebSearchTool);
 
         if cx.has_flag::<SubagentsFeatureFlag>() && self.depth() < MAX_SUBAGENT_DEPTH {
-            self.add_tool(SpawnAgentTool::new(cx.weak_entity(), environment));
+            self.add_tool(SpawnAgentTool::new(environment));
         }
     }
 

crates/agent/src/tools/spawn_agent_tool.rs 🔗

@@ -1,14 +1,14 @@
 use acp_thread::SUBAGENT_SESSION_ID_META_KEY;
 use agent_client_protocol as acp;
 use anyhow::Result;
-use gpui::{App, SharedString, Task, WeakEntity};
+use gpui::{App, SharedString, Task};
 use language_model::LanguageModelToolResultContent;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::rc::Rc;
 use std::sync::Arc;
 
-use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream, ToolInput};
+use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput};
 
 /// Spawns an agent to perform a delegated task.
 ///
@@ -59,16 +59,12 @@ impl From<SpawnAgentToolOutput> for LanguageModelToolResultContent {
 
 /// Tool that spawns an agent thread to work on a task.
 pub struct SpawnAgentTool {
-    parent_thread: WeakEntity<Thread>,
     environment: Rc<dyn ThreadEnvironment>,
 }
 
 impl SpawnAgentTool {
-    pub fn new(parent_thread: WeakEntity<Thread>, environment: Rc<dyn ThreadEnvironment>) -> Self {
-        Self {
-            parent_thread,
-            environment,
-        }
+    pub fn new(environment: Rc<dyn ThreadEnvironment>) -> Self {
+        Self { environment }
     }
 }
 
@@ -108,27 +104,10 @@ impl AgentTool for SpawnAgentTool {
                 })?;
 
             let (subagent, subagent_session_id) = cx.update(|cx| {
-                let Some(parent_thread_entity) = self.parent_thread.upgrade() else {
-                    return Err(SpawnAgentToolOutput::Error {
-                        session_id: None,
-                        error: "Parent thread no longer exists".to_string(),
-                    });
-                };
-
                 let subagent = if let Some(session_id) = input.session_id {
-                    self.environment.resume_subagent(
-                        parent_thread_entity,
-                        session_id,
-                        input.message,
-                        cx,
-                    )
+                    self.environment.resume_subagent(session_id, cx)
                 } else {
-                    self.environment.create_subagent(
-                        parent_thread_entity,
-                        input.label,
-                        input.message,
-                        cx,
-                    )
+                    self.environment.create_subagent(input.label, cx)
                 };
                 let subagent = subagent.map_err(|err| SpawnAgentToolOutput::Error {
                     session_id: None,
@@ -146,7 +125,7 @@ impl AgentTool for SpawnAgentTool {
                 Ok((subagent, subagent_session_id))
             })?;
 
-            match subagent.wait_for_output(cx).await {
+            match subagent.run_turn(input.message, cx).await {
                 Ok(output) => {
                     event_stream.update_fields(
                         acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]),

crates/eval/src/instance.rs 🔗

@@ -682,9 +682,7 @@ impl agent::ThreadEnvironment for EvalThreadEnvironment {
 
     fn create_subagent(
         &self,
-        _parent_thread: Entity<agent::Thread>,
         _label: String,
-        _initial_prompt: String,
         _cx: &mut App,
     ) -> Result<Rc<dyn agent::SubagentHandle>> {
         unimplemented!()