@@ -369,11 +369,13 @@ impl NativeAgent {
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let weak = cx.weak_entity();
+ let weak_thread = thread_handle.downgrade();
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
thread.add_default_tools(
Rc::new(NativeThreadEnvironment {
acp_thread: acp_thread.downgrade(),
+ thread: weak_thread,
agent: weak,
}) as _,
cx,
@@ -1576,17 +1578,19 @@ impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
pub struct NativeThreadEnvironment {
agent: WeakEntity<NativeAgent>,
+ thread: WeakEntity<Thread>,
acp_thread: WeakEntity<AcpThread>,
}
impl NativeThreadEnvironment {
pub(crate) fn create_subagent_thread(
- agent: WeakEntity<NativeAgent>,
- parent_thread_entity: Entity<Thread>,
+ &self,
label: String,
- initial_prompt: String,
cx: &mut App,
) -> Result<Rc<dyn SubagentHandle>> {
+ let Some(parent_thread_entity) = self.thread.upgrade() else {
+ anyhow::bail!("Parent thread no longer exists".to_string());
+ };
let parent_thread = parent_thread_entity.read(cx);
let current_depth = parent_thread.depth();
@@ -1605,28 +1609,19 @@ impl NativeThreadEnvironment {
let session_id = subagent_thread.read(cx).id().clone();
- let acp_thread = agent.update(cx, |agent, cx| {
+ let acp_thread = self.agent.update(cx, |agent, cx| {
agent.register_session(subagent_thread.clone(), cx)
})?;
- Self::prompt_subagent(
- session_id,
- subagent_thread,
- acp_thread,
- parent_thread_entity,
- initial_prompt,
- cx,
- )
+ self.prompt_subagent(session_id, subagent_thread, acp_thread)
}
pub(crate) fn resume_subagent_thread(
- agent: WeakEntity<NativeAgent>,
- parent_thread_entity: Entity<Thread>,
+ &self,
session_id: acp::SessionId,
- follow_up_prompt: String,
cx: &mut App,
) -> Result<Rc<dyn SubagentHandle>> {
- let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| {
+ let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
let session = agent
.sessions
.get(&session_id)
@@ -1634,31 +1629,23 @@ impl NativeThreadEnvironment {
anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
})??;
- Self::prompt_subagent(
- session_id,
- subagent_thread,
- acp_thread,
- parent_thread_entity,
- follow_up_prompt,
- cx,
- )
+ self.prompt_subagent(session_id, subagent_thread, acp_thread)
}
fn prompt_subagent(
+ &self,
session_id: acp::SessionId,
subagent_thread: Entity<Thread>,
acp_thread: Entity<acp_thread::AcpThread>,
- parent_thread_entity: Entity<Thread>,
- prompt: String,
- cx: &mut App,
) -> Result<Rc<dyn SubagentHandle>> {
+ let Some(parent_thread_entity) = self.thread.upgrade() else {
+ anyhow::bail!("Parent thread no longer exists".to_string());
+ };
Ok(Rc::new(NativeSubagentHandle::new(
session_id,
subagent_thread,
acp_thread,
parent_thread_entity,
- prompt,
- cx,
)) as _)
}
}
@@ -1697,36 +1684,16 @@ impl ThreadEnvironment for NativeThreadEnvironment {
})
}
- fn create_subagent(
- &self,
- parent_thread_entity: Entity<Thread>,
- label: String,
- initial_prompt: String,
- cx: &mut App,
- ) -> Result<Rc<dyn SubagentHandle>> {
- Self::create_subagent_thread(
- self.agent.clone(),
- parent_thread_entity,
- label,
- initial_prompt,
- cx,
- )
+ fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
+ self.create_subagent_thread(label, cx)
}
fn resume_subagent(
&self,
- parent_thread_entity: Entity<Thread>,
session_id: acp::SessionId,
- follow_up_prompt: String,
cx: &mut App,
) -> Result<Rc<dyn SubagentHandle>> {
- Self::resume_subagent_thread(
- self.agent.clone(),
- parent_thread_entity,
- session_id,
- follow_up_prompt,
- cx,
- )
+ self.resume_subagent_thread(session_id, cx)
}
}
@@ -1742,8 +1709,7 @@ pub struct NativeSubagentHandle {
session_id: acp::SessionId,
parent_thread: WeakEntity<Thread>,
subagent_thread: Entity<Thread>,
- wait_for_prompt_to_complete: Shared<Task<SubagentPromptResult>>,
- _subscription: Subscription,
+ acp_thread: Entity<acp_thread::AcpThread>,
}
impl NativeSubagentHandle {
@@ -1752,71 +1718,12 @@ impl NativeSubagentHandle {
subagent_thread: Entity<Thread>,
acp_thread: Entity<acp_thread::AcpThread>,
parent_thread_entity: Entity<Thread>,
- prompt: String,
- cx: &mut App,
) -> Self {
- let ratio_before_prompt = subagent_thread
- .read(cx)
- .latest_token_usage()
- .map(|usage| usage.ratio());
-
- parent_thread_entity.update(cx, |parent_thread, _cx| {
- parent_thread.register_running_subagent(subagent_thread.downgrade())
- });
-
- let task = acp_thread.update(cx, |acp_thread, cx| {
- acp_thread.send(vec![prompt.into()], cx)
- });
-
- let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
- let mut token_limit_tx = Some(token_limit_tx);
-
- let subscription = cx.subscribe(
- &subagent_thread,
- move |_thread, event: &TokenUsageUpdated, _cx| {
- if let Some(usage) = &event.0 {
- let old_ratio = ratio_before_prompt
- .clone()
- .unwrap_or(TokenUsageRatio::Normal);
- let new_ratio = usage.ratio();
- if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning
- {
- if let Some(tx) = token_limit_tx.take() {
- tx.send(()).ok();
- }
- }
- }
- },
- );
-
- let wait_for_prompt_to_complete = cx
- .background_spawn(async move {
- futures::select! {
- response = task.fuse() => match response {
- Ok(Some(response)) =>{
- match response.stop_reason {
- acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
- acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
- acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
- acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
- acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
- }
-
- }
- Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
- Err(error) => SubagentPromptResult::Error(error.to_string()),
- },
- _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
- }
- })
- .shared();
-
NativeSubagentHandle {
session_id,
subagent_thread,
parent_thread: parent_thread_entity.downgrade(),
- wait_for_prompt_to_complete,
- _subscription: subscription,
+ acp_thread,
}
}
}
@@ -1826,15 +1733,75 @@ impl SubagentHandle for NativeSubagentHandle {
self.session_id.clone()
}
- fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
+ fn run_turn(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
let thread = self.subagent_thread.clone();
- let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
-
+ let acp_thread = self.acp_thread.clone();
let subagent_session_id = self.session_id.clone();
let parent_thread = self.parent_thread.clone();
cx.spawn(async move |cx| {
- let result = match wait_for_prompt.await {
+ let (task, _subscription) = cx.update(|cx| {
+ let ratio_before_prompt = thread
+ .read(cx)
+ .latest_token_usage()
+ .map(|usage| usage.ratio());
+
+ parent_thread
+ .update(cx, |parent_thread, _cx| {
+ parent_thread.register_running_subagent(thread.downgrade())
+ })
+ .ok();
+
+ let task = acp_thread.update(cx, |acp_thread, cx| {
+ acp_thread.send(vec![message.into()], cx)
+ });
+
+ let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
+ let mut token_limit_tx = Some(token_limit_tx);
+
+ let subscription = cx.subscribe(
+ &thread,
+ move |_thread, event: &TokenUsageUpdated, _cx| {
+ if let Some(usage) = &event.0 {
+ let old_ratio = ratio_before_prompt
+ .clone()
+ .unwrap_or(TokenUsageRatio::Normal);
+ let new_ratio = usage.ratio();
+ if old_ratio == TokenUsageRatio::Normal
+ && new_ratio == TokenUsageRatio::Warning
+ {
+ if let Some(tx) = token_limit_tx.take() {
+ tx.send(()).ok();
+ }
+ }
+ }
+ },
+ );
+
+ let wait_for_prompt = cx
+ .background_spawn(async move {
+ futures::select! {
+ response = task.fuse() => match response {
+ Ok(Some(response)) => {
+ match response.stop_reason {
+ acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
+ acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
+ acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
+ acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
+ acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
+ }
+ }
+ Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
+ Err(error) => SubagentPromptResult::Error(error.to_string()),
+ },
+ _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
+ }
+ });
+
+ (wait_for_prompt, subscription)
+ });
+
+ let result = match task.await {
SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
thread
.last_message()
@@ -1,14 +1,14 @@
use acp_thread::SUBAGENT_SESSION_ID_META_KEY;
use agent_client_protocol as acp;
use anyhow::Result;
-use gpui::{App, SharedString, Task, WeakEntity};
+use gpui::{App, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::rc::Rc;
use std::sync::Arc;
-use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream, ToolInput};
+use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput};
/// Spawns an agent to perform a delegated task.
///
@@ -59,16 +59,12 @@ impl From<SpawnAgentToolOutput> for LanguageModelToolResultContent {
/// Tool that spawns an agent thread to work on a task.
pub struct SpawnAgentTool {
- parent_thread: WeakEntity<Thread>,
environment: Rc<dyn ThreadEnvironment>,
}
impl SpawnAgentTool {
- pub fn new(parent_thread: WeakEntity<Thread>, environment: Rc<dyn ThreadEnvironment>) -> Self {
- Self {
- parent_thread,
- environment,
- }
+ pub fn new(environment: Rc<dyn ThreadEnvironment>) -> Self {
+ Self { environment }
}
}
@@ -108,27 +104,10 @@ impl AgentTool for SpawnAgentTool {
})?;
let (subagent, subagent_session_id) = cx.update(|cx| {
- let Some(parent_thread_entity) = self.parent_thread.upgrade() else {
- return Err(SpawnAgentToolOutput::Error {
- session_id: None,
- error: "Parent thread no longer exists".to_string(),
- });
- };
-
let subagent = if let Some(session_id) = input.session_id {
- self.environment.resume_subagent(
- parent_thread_entity,
- session_id,
- input.message,
- cx,
- )
+ self.environment.resume_subagent(session_id, cx)
} else {
- self.environment.create_subagent(
- parent_thread_entity,
- input.label,
- input.message,
- cx,
- )
+ self.environment.create_subagent(input.label, cx)
};
let subagent = subagent.map_err(|err| SpawnAgentToolOutput::Error {
session_id: None,
@@ -146,7 +125,7 @@ impl AgentTool for SpawnAgentTool {
Ok((subagent, subagent_session_id))
})?;
- match subagent.wait_for_output(cx).await {
+ match subagent.run_turn(input.message, cx).await {
Ok(output) => {
event_stream.update_fields(
acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]),