diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 759c6e3b9c8c228a6ae6bea5330819b97200b603..7bf0468d3a65a619a70efd1e7e67f301402ad20c 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -369,11 +369,13 @@ impl NativeAgent { let summarization_model = registry.thread_summary_model().map(|c| c.model); let weak = cx.weak_entity(); + let weak_thread = thread_handle.downgrade(); thread_handle.update(cx, |thread, cx| { thread.set_summarization_model(summarization_model, cx); thread.add_default_tools( Rc::new(NativeThreadEnvironment { acp_thread: acp_thread.downgrade(), + thread: weak_thread, agent: weak, }) as _, cx, @@ -1576,17 +1578,19 @@ impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { pub struct NativeThreadEnvironment { agent: WeakEntity, + thread: WeakEntity, acp_thread: WeakEntity, } impl NativeThreadEnvironment { pub(crate) fn create_subagent_thread( - agent: WeakEntity, - parent_thread_entity: Entity, + &self, label: String, - initial_prompt: String, cx: &mut App, ) -> Result> { + let Some(parent_thread_entity) = self.thread.upgrade() else { + anyhow::bail!("Parent thread no longer exists".to_string()); + }; let parent_thread = parent_thread_entity.read(cx); let current_depth = parent_thread.depth(); @@ -1605,28 +1609,19 @@ impl NativeThreadEnvironment { let session_id = subagent_thread.read(cx).id().clone(); - let acp_thread = agent.update(cx, |agent, cx| { + let acp_thread = self.agent.update(cx, |agent, cx| { agent.register_session(subagent_thread.clone(), cx) })?; - Self::prompt_subagent( - session_id, - subagent_thread, - acp_thread, - parent_thread_entity, - initial_prompt, - cx, - ) + self.prompt_subagent(session_id, subagent_thread, acp_thread) } pub(crate) fn resume_subagent_thread( - agent: WeakEntity, - parent_thread_entity: Entity, + &self, session_id: acp::SessionId, - follow_up_prompt: String, cx: &mut App, ) -> Result> { - let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| { + let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| { let session = agent .sessions .get(&session_id) @@ -1634,31 +1629,23 @@ impl NativeThreadEnvironment { anyhow::Ok((session.thread.clone(), session.acp_thread.clone())) })??; - Self::prompt_subagent( - session_id, - subagent_thread, - acp_thread, - parent_thread_entity, - follow_up_prompt, - cx, - ) + self.prompt_subagent(session_id, subagent_thread, acp_thread) } fn prompt_subagent( + &self, session_id: acp::SessionId, subagent_thread: Entity, acp_thread: Entity, - parent_thread_entity: Entity, - prompt: String, - cx: &mut App, ) -> Result> { + let Some(parent_thread_entity) = self.thread.upgrade() else { + anyhow::bail!("Parent thread no longer exists".to_string()); + }; Ok(Rc::new(NativeSubagentHandle::new( session_id, subagent_thread, acp_thread, parent_thread_entity, - prompt, - cx, )) as _) } } @@ -1697,36 +1684,16 @@ impl ThreadEnvironment for NativeThreadEnvironment { }) } - fn create_subagent( - &self, - parent_thread_entity: Entity, - label: String, - initial_prompt: String, - cx: &mut App, - ) -> Result> { - Self::create_subagent_thread( - self.agent.clone(), - parent_thread_entity, - label, - initial_prompt, - cx, - ) + fn create_subagent(&self, label: String, cx: &mut App) -> Result> { + self.create_subagent_thread(label, cx) } fn resume_subagent( &self, - parent_thread_entity: Entity, session_id: acp::SessionId, - follow_up_prompt: String, cx: &mut App, ) -> Result> { - Self::resume_subagent_thread( - self.agent.clone(), - parent_thread_entity, - session_id, - follow_up_prompt, - cx, - ) + self.resume_subagent_thread(session_id, cx) } } @@ -1742,8 +1709,7 @@ pub struct NativeSubagentHandle { session_id: acp::SessionId, parent_thread: WeakEntity, subagent_thread: Entity, - wait_for_prompt_to_complete: Shared>, - _subscription: Subscription, + acp_thread: Entity, } impl NativeSubagentHandle { @@ -1752,71 +1718,12 @@ impl NativeSubagentHandle { subagent_thread: Entity, acp_thread: Entity, parent_thread_entity: Entity, - prompt: String, - cx: &mut App, ) -> Self { - let ratio_before_prompt = subagent_thread - .read(cx) - .latest_token_usage() - .map(|usage| usage.ratio()); - - parent_thread_entity.update(cx, |parent_thread, _cx| { - parent_thread.register_running_subagent(subagent_thread.downgrade()) - }); - - let task = acp_thread.update(cx, |acp_thread, cx| { - acp_thread.send(vec![prompt.into()], cx) - }); - - let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>(); - let mut token_limit_tx = Some(token_limit_tx); - - let subscription = cx.subscribe( - &subagent_thread, - move |_thread, event: &TokenUsageUpdated, _cx| { - if let Some(usage) = &event.0 { - let old_ratio = ratio_before_prompt - .clone() - .unwrap_or(TokenUsageRatio::Normal); - let new_ratio = usage.ratio(); - if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning - { - if let Some(tx) = token_limit_tx.take() { - tx.send(()).ok(); - } - } - } - }, - ); - - let wait_for_prompt_to_complete = cx - .background_spawn(async move { - futures::select! { - response = task.fuse() => match response { - Ok(Some(response)) =>{ - match response.stop_reason { - acp::StopReason::Cancelled => SubagentPromptResult::Cancelled, - acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()), - acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()), - acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()), - acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed, - } - - } - Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()), - Err(error) => SubagentPromptResult::Error(error.to_string()), - }, - _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning, - } - }) - .shared(); - NativeSubagentHandle { session_id, subagent_thread, parent_thread: parent_thread_entity.downgrade(), - wait_for_prompt_to_complete, - _subscription: subscription, + acp_thread, } } } @@ -1826,15 +1733,75 @@ impl SubagentHandle for NativeSubagentHandle { self.session_id.clone() } - fn wait_for_output(&self, cx: &AsyncApp) -> Task> { + fn run_turn(&self, message: String, cx: &AsyncApp) -> Task> { let thread = self.subagent_thread.clone(); - let wait_for_prompt = self.wait_for_prompt_to_complete.clone(); - + let acp_thread = self.acp_thread.clone(); let subagent_session_id = self.session_id.clone(); let parent_thread = self.parent_thread.clone(); cx.spawn(async move |cx| { - let result = match wait_for_prompt.await { + let (task, _subscription) = cx.update(|cx| { + let ratio_before_prompt = thread + .read(cx) + .latest_token_usage() + .map(|usage| usage.ratio()); + + parent_thread + .update(cx, |parent_thread, _cx| { + parent_thread.register_running_subagent(thread.downgrade()) + }) + .ok(); + + let task = acp_thread.update(cx, |acp_thread, cx| { + acp_thread.send(vec![message.into()], cx) + }); + + let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>(); + let mut token_limit_tx = Some(token_limit_tx); + + let subscription = cx.subscribe( + &thread, + move |_thread, event: &TokenUsageUpdated, _cx| { + if let Some(usage) = &event.0 { + let old_ratio = ratio_before_prompt + .clone() + .unwrap_or(TokenUsageRatio::Normal); + let new_ratio = usage.ratio(); + if old_ratio == TokenUsageRatio::Normal + && new_ratio == TokenUsageRatio::Warning + { + if let Some(tx) = token_limit_tx.take() { + tx.send(()).ok(); + } + } + } + }, + ); + + let wait_for_prompt = cx + .background_spawn(async move { + futures::select! { + response = task.fuse() => match response { + Ok(Some(response)) => { + match response.stop_reason { + acp::StopReason::Cancelled => SubagentPromptResult::Cancelled, + acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()), + acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()), + acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()), + acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed, + } + } + Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()), + Err(error) => SubagentPromptResult::Error(error.to_string()), + }, + _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning, + } + }); + + (wait_for_prompt, subscription) + }); + + let result = match task.await { SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| { thread .last_message() diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index e8c95c630b65870bfc8a78b9e965373a2604879d..3643704802d673a5b18075c7edbc684b68578219 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -167,7 +167,7 @@ impl SubagentHandle for FakeSubagentHandle { self.session_id.clone() } - fn wait_for_output(&self, cx: &AsyncApp) -> Task> { + fn run_turn(&self, _message: String, cx: &AsyncApp) -> Task> { let task = self.wait_for_summary_task.clone(); cx.background_spawn(async move { Ok(task.await) }) } @@ -203,13 +203,7 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment { Task::ready(Ok(handle as Rc)) } - fn create_subagent( - &self, - _parent_thread: Entity, - _label: String, - _initial_prompt: String, - _cx: &mut App, - ) -> Result> { + fn create_subagent(&self, _label: String, _cx: &mut App) -> Result> { Ok(self .subagent_handle .clone() @@ -248,13 +242,7 @@ impl crate::ThreadEnvironment for MultiTerminalEnvironment { Task::ready(Ok(handle as Rc)) } - fn create_subagent( - &self, - _parent_thread: Entity, - _label: String, - _initial_prompt: String, - _cx: &mut App, - ) -> Result> { + fn create_subagent(&self, _label: String, _cx: &mut App) -> Result> { unimplemented!() } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 923fbd11126f21459131b7ca194288de6af5498e..cfac50aba7daa9bf799b561bb06f14309bcf53dd 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -605,7 +605,7 @@ pub trait TerminalHandle { pub trait SubagentHandle { fn id(&self) -> acp::SessionId; - fn wait_for_output(&self, cx: &AsyncApp) -> Task>; + fn run_turn(&self, message: String, cx: &AsyncApp) -> Task>; } pub trait ThreadEnvironment { @@ -617,19 +617,11 @@ pub trait ThreadEnvironment { cx: &mut AsyncApp, ) -> Task>>; - fn create_subagent( - &self, - parent_thread: Entity, - label: String, - initial_prompt: String, - cx: &mut App, - ) -> Result>; + fn create_subagent(&self, label: String, cx: &mut App) -> Result>; fn resume_subagent( &self, - _parent_thread: Entity, _session_id: acp::SessionId, - _follow_up_prompt: String, _cx: &mut App, ) -> Result> { Err(anyhow::anyhow!( @@ -1376,7 +1368,7 @@ impl Thread { self.add_tool(WebSearchTool); if cx.has_flag::() && self.depth() < MAX_SUBAGENT_DEPTH { - self.add_tool(SpawnAgentTool::new(cx.weak_entity(), environment)); + self.add_tool(SpawnAgentTool::new(environment)); } } diff --git a/crates/agent/src/tools/spawn_agent_tool.rs b/crates/agent/src/tools/spawn_agent_tool.rs index e454377ce1a56134ca0677b37c469ff322a6ed90..8c97b222a901744d77429cba15d03686e31fbde2 100644 --- a/crates/agent/src/tools/spawn_agent_tool.rs +++ b/crates/agent/src/tools/spawn_agent_tool.rs @@ -1,14 +1,14 @@ use acp_thread::SUBAGENT_SESSION_ID_META_KEY; use agent_client_protocol as acp; use anyhow::Result; -use gpui::{App, SharedString, Task, WeakEntity}; +use gpui::{App, SharedString, Task}; use language_model::LanguageModelToolResultContent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::rc::Rc; use std::sync::Arc; -use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream, ToolInput}; +use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput}; /// Spawns an agent to perform a delegated task. /// @@ -59,16 +59,12 @@ impl From for LanguageModelToolResultContent { /// Tool that spawns an agent thread to work on a task. pub struct SpawnAgentTool { - parent_thread: WeakEntity, environment: Rc, } impl SpawnAgentTool { - pub fn new(parent_thread: WeakEntity, environment: Rc) -> Self { - Self { - parent_thread, - environment, - } + pub fn new(environment: Rc) -> Self { + Self { environment } } } @@ -108,27 +104,10 @@ impl AgentTool for SpawnAgentTool { })?; let (subagent, subagent_session_id) = cx.update(|cx| { - let Some(parent_thread_entity) = self.parent_thread.upgrade() else { - return Err(SpawnAgentToolOutput::Error { - session_id: None, - error: "Parent thread no longer exists".to_string(), - }); - }; - let subagent = if let Some(session_id) = input.session_id { - self.environment.resume_subagent( - parent_thread_entity, - session_id, - input.message, - cx, - ) + self.environment.resume_subagent(session_id, cx) } else { - self.environment.create_subagent( - parent_thread_entity, - input.label, - input.message, - cx, - ) + self.environment.create_subagent(input.label, cx) }; let subagent = subagent.map_err(|err| SpawnAgentToolOutput::Error { session_id: None, @@ -146,7 +125,7 @@ impl AgentTool for SpawnAgentTool { Ok((subagent, subagent_session_id)) })?; - match subagent.wait_for_output(cx).await { + match subagent.run_turn(input.message, cx).await { Ok(output) => { event_stream.update_fields( acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index cbc5cd1568cf80ad23b9da9dfcaab74730986533..59593578f1ffc512447f08fd728c6619943d6b6e 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -682,9 +682,7 @@ impl agent::ThreadEnvironment for EvalThreadEnvironment { fn create_subagent( &self, - _parent_thread: Entity, _label: String, - _initial_prompt: String, _cx: &mut App, ) -> Result> { unimplemented!()