Detailed changes
@@ -935,6 +935,11 @@ pub struct RetryStatus {
pub duration: Duration,
}
+struct RunningTurn {
+ id: u32,
+ send_task: Task<()>,
+}
+
pub struct AcpThread {
parent_session_id: Option<acp::SessionId>,
title: SharedString,
@@ -943,7 +948,8 @@ pub struct AcpThread {
project: Entity<Project>,
action_log: Entity<ActionLog>,
shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
- send_task: Option<Task<()>>,
+ turn_id: u32,
+ running_turn: Option<RunningTurn>,
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
token_usage: Option<TokenUsage>,
@@ -952,9 +958,6 @@ pub struct AcpThread {
terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
- // subagent cancellation fields
- user_stopped: Arc<std::sync::atomic::AtomicBool>,
- user_stop_tx: watch::Sender<bool>,
}
impl From<&AcpThread> for ActionLogTelemetry {
@@ -1172,8 +1175,6 @@ impl AcpThread {
}
});
- let (user_stop_tx, _user_stop_rx) = watch::channel(false);
-
Self {
parent_session_id,
action_log,
@@ -1182,7 +1183,8 @@ impl AcpThread {
plan: Default::default(),
title: title.into(),
project,
- send_task: None,
+ running_turn: None,
+ turn_id: 0,
connection,
session_id,
token_usage: None,
@@ -1191,8 +1193,6 @@ impl AcpThread {
terminals: HashMap::default(),
pending_terminal_output: HashMap::default(),
pending_terminal_exit: HashMap::default(),
- user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
- user_stop_tx,
}
}
@@ -1204,22 +1204,6 @@ impl AcpThread {
self.prompt_capabilities.clone()
}
- /// Marks this thread as stopped by user action and signals any listeners.
- pub fn stop_by_user(&mut self) {
- self.user_stopped
- .store(true, std::sync::atomic::Ordering::SeqCst);
- self.user_stop_tx.send(true).ok();
- self.send_task.take();
- }
-
- pub fn was_stopped_by_user(&self) -> bool {
- self.user_stopped.load(std::sync::atomic::Ordering::SeqCst)
- }
-
- pub fn user_stop_receiver(&self) -> watch::Receiver<bool> {
- self.user_stop_tx.receiver()
- }
-
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
&self.connection
}
@@ -1245,7 +1229,7 @@ impl AcpThread {
}
pub fn status(&self) -> ThreadStatus {
- if self.send_task.is_some() {
+ if self.running_turn.is_some() {
ThreadStatus::Generating
} else {
ThreadStatus::Idle
@@ -1860,7 +1844,7 @@ impl AcpThread {
&mut self,
message: &str,
cx: &mut Context<Self>,
- ) -> BoxFuture<'static, Result<()>> {
+ ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
self.send(vec![message.into()], cx)
}
@@ -1868,7 +1852,7 @@ impl AcpThread {
&mut self,
message: Vec<acp::ContentBlock>,
cx: &mut Context<Self>,
- ) -> BoxFuture<'static, Result<()>> {
+ ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
let block = ContentBlock::new_combined(
message.clone(),
self.project.read(cx).languages().clone(),
@@ -1921,7 +1905,10 @@ impl AcpThread {
self.connection.retry(&self.session_id, cx).is_some()
}
- pub fn retry(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
+ pub fn retry(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| {
this.connection
@@ -1937,16 +1924,21 @@ impl AcpThread {
&mut self,
cx: &mut Context<Self>,
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
- ) -> BoxFuture<'static, Result<()>> {
+ ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
self.clear_completed_plan_entries(cx);
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
- self.send_task = Some(cx.spawn(async move |this, cx| {
- cancel_task.await;
- tx.send(f(this, cx).await).ok();
- }));
+ self.turn_id += 1;
+ let turn_id = self.turn_id;
+ self.running_turn = Some(RunningTurn {
+ id: turn_id,
+ send_task: cx.spawn(async move |this, cx| {
+ cancel_task.await;
+ tx.send(f(this, cx).await).ok();
+ }),
+ });
cx.spawn(async move |this, cx| {
let response = rx.await;
@@ -1957,43 +1949,39 @@ impl AcpThread {
this.update(cx, |this, cx| {
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx));
+
+ let Ok(response) = response else {
+ // tx dropped, just return
+ return Ok(None);
+ };
+
+ let is_same_turn = this
+ .running_turn
+ .as_ref()
+ .is_some_and(|turn| turn_id == turn.id);
+
+ // If the user submitted a follow up message, running_turn might
+ // already point to a different turn. Therefore we only want to
+ // take the task if it's the same turn.
+ if is_same_turn {
+ this.running_turn.take();
+ }
+
match response {
- Ok(Err(e)) => {
- this.send_task.take();
- cx.emit(AcpThreadEvent::Error);
- log::error!("Error in run turn: {:?}", e);
- Err(e)
- }
- Ok(Ok(r)) if r.stop_reason == acp::StopReason::MaxTokens => {
- this.send_task.take();
- cx.emit(AcpThreadEvent::Error);
- log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
- Err(anyhow!("Max tokens reached"))
- }
- result => {
- let canceled = matches!(
- result,
- Ok(Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::Cancelled,
- ..
- }))
- );
-
- // We only take the task if the current prompt wasn't canceled.
- //
- // This prompt may have been canceled because another one was sent
- // while it was still generating. In these cases, dropping `send_task`
- // would cause the next generation to be canceled.
- if !canceled {
- this.send_task.take();
+ Ok(r) => {
+ if r.stop_reason == acp::StopReason::MaxTokens {
+ cx.emit(AcpThreadEvent::Error);
+ log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
+ return Err(anyhow!("Max tokens reached"));
+ }
+
+ let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
+ if canceled {
+ this.mark_pending_tools_as_canceled();
}
// Handle refusal - distinguish between user prompt and tool call refusals
- if let Ok(Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::Refusal,
- ..
- })) = result
- {
+ if let acp::StopReason::Refusal = r.stop_reason {
if let Some((user_msg_ix, _)) = this.last_user_message() {
// Check if there's a completed tool call with results after the last user message
// This indicates the refusal is in response to tool output, not the user's prompt
@@ -2028,7 +2016,12 @@ impl AcpThread {
}
cx.emit(AcpThreadEvent::Stopped);
- Ok(())
+ Ok(Some(r))
+ }
+ Err(e) => {
+ cx.emit(AcpThreadEvent::Error);
+ log::error!("Error in run turn: {:?}", e);
+ Err(e)
}
}
})?
@@ -2037,10 +2030,18 @@ impl AcpThread {
}
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
- let Some(send_task) = self.send_task.take() else {
+ let Some(turn) = self.running_turn.take() else {
return Task::ready(());
};
+ self.connection.cancel(&self.session_id, cx);
+
+ self.mark_pending_tools_as_canceled();
+
+ // Wait for the send task to complete
+ cx.background_spawn(turn.send_task)
+ }
+ fn mark_pending_tools_as_canceled(&mut self) {
for entry in self.entries.iter_mut() {
if let AgentThreadEntry::ToolCall(call) = entry {
let cancel = matches!(
@@ -2055,11 +2056,6 @@ impl AcpThread {
}
}
}
-
- self.connection.cancel(&self.session_id, cx);
-
- // Wait for the send task to complete
- cx.foreground_executor().spawn(send_task)
}
/// Restores the git working tree to the state at the given checkpoint (if one exists)
@@ -3957,18 +3953,7 @@ mod tests {
}
}
- fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
- let sessions = self.sessions.lock();
- let thread = sessions.get(session_id).unwrap().clone();
-
- cx.spawn(async move |cx| {
- thread
- .update(cx, |thread, cx| thread.cancel(cx))
- .unwrap()
- .await
- })
- .detach();
- }
+ fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
fn truncate(
&self,
@@ -4298,7 +4283,7 @@ mod tests {
// Verify that no send_task is in progress after restore
// (cancel() clears the send_task)
- let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
+ let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
assert!(
!has_send_task_after,
"Should not have a send_task after restore (cancel should have cleared it)"
@@ -4419,4 +4404,161 @@ mod tests {
result.err()
);
}
+
+ /// Tests that when a follow-up message is sent during generation,
+ /// the first turn completing does NOT clear `running_turn` because
+ /// it now belongs to the second turn.
+ #[gpui::test]
+ async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+
+ // First handler waits for this signal before completing
+ let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
+ let first_complete_rx = RefCell::new(Some(first_complete_rx));
+
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+ move |params, _thread, _cx| {
+ let first_complete_rx = first_complete_rx.borrow_mut().take();
+ let is_first = params
+ .prompt
+ .iter()
+ .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
+
+ async move {
+ if is_first {
+ // First handler waits until signaled
+ if let Some(rx) = first_complete_rx {
+ rx.await.ok();
+ }
+ }
+ Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
+ }
+ .boxed_local()
+ }
+ }));
+
+ let thread = cx
+ .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .await
+ .unwrap();
+
+ // Send first message (turn_id=1) - handler will block
+ let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
+ assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
+
+ // Send second message (turn_id=2) while first is still blocked
+ // This calls cancel() which takes turn 1's running_turn and sets turn 2's
+ let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
+ assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
+
+ let running_turn_after_second_send =
+ thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
+ assert_eq!(
+ running_turn_after_second_send,
+ Some(2),
+ "running_turn should be set to turn 2 after sending second message"
+ );
+
+ // Now signal first handler to complete
+ first_complete_tx.send(()).ok();
+
+ // First request completes - should NOT clear running_turn
+ // because running_turn now belongs to turn 2
+ first_request.await.unwrap();
+
+ let running_turn_after_first =
+ thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
+ assert_eq!(
+ running_turn_after_first,
+ Some(2),
+ "first turn completing should not clear running_turn (belongs to turn 2)"
+ );
+
+ // Second request completes - SHOULD clear running_turn
+ second_request.await.unwrap();
+
+ let running_turn_after_second =
+ thread.read_with(cx, |thread, _| thread.running_turn.is_some());
+ assert!(
+ !running_turn_after_second,
+ "second turn completing should clear running_turn"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
+ cx: &mut TestAppContext,
+ ) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message(
+ move |_params, thread, mut cx| {
+ async move {
+ thread
+ .update(&mut cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::ToolCall(
+ acp::ToolCall::new(
+ acp::ToolCallId::new("test-tool"),
+ "Test Tool",
+ )
+ .kind(acp::ToolKind::Fetch)
+ .status(acp::ToolCallStatus::InProgress),
+ ),
+ cx,
+ )
+ })
+ .unwrap()
+ .unwrap();
+
+ Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
+ }
+ .boxed_local()
+ },
+ ));
+
+ let thread = cx
+ .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .await
+ .unwrap();
+
+ let response = thread
+ .update(cx, |thread, cx| thread.send_raw("test message", cx))
+ .await;
+
+ let response = response
+ .expect("send should succeed")
+ .expect("should have response");
+ assert_eq!(
+ response.stop_reason,
+ acp::StopReason::Cancelled,
+ "response should have Cancelled stop_reason"
+ );
+
+ thread.read_with(cx, |thread, _| {
+ let tool_entry = thread
+ .entries
+ .iter()
+ .find_map(|e| {
+ if let AgentThreadEntry::ToolCall(call) = e {
+ Some(call)
+ } else {
+ None
+ }
+ })
+ .expect("should have tool call entry");
+
+ assert!(
+ matches!(tool_entry.status, ToolCallStatus::Canceled),
+ "tool should be marked as Canceled when response is Cancelled, got {:?}",
+ tool_entry.status
+ );
+ });
+ }
}
@@ -1369,8 +1369,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
- if let Some(agent) = agent.sessions.get(session_id) {
- agent
+ if let Some(session) = agent.sessions.get(session_id) {
+ session
.thread
.update(cx, |thread, cx| thread.cancel(cx))
.detach();
@@ -1655,26 +1655,26 @@ impl NativeThreadEnvironment {
if let Some(timer) = timeout_timer {
futures::select! {
_ = timer.fuse() => SubagentInitialPromptResult::Timeout,
- _ = task.fuse() => SubagentInitialPromptResult::Completed,
+ response = task.fuse() => {
+ let response = response.log_err().flatten();
+ if response.is_some_and(|response| {
+ response.stop_reason == acp::StopReason::Cancelled
+ })
+ {
+ SubagentInitialPromptResult::Cancelled
+ } else {
+ SubagentInitialPromptResult::Completed
+ }
+ },
}
} else {
- task.await.log_err();
- SubagentInitialPromptResult::Completed
- }
- })
- .shared();
-
- let mut user_stop_rx: watch::Receiver<bool> =
- acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
-
- let user_cancelled = cx
- .background_spawn(async move {
- loop {
- if *user_stop_rx.borrow() {
- return;
- }
- if user_stop_rx.changed().await.is_err() {
- std::future::pending::<()>().await;
+ let response = task.await.log_err().flatten();
+ if response
+ .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
+ {
+ SubagentInitialPromptResult::Cancelled
+ } else {
+ SubagentInitialPromptResult::Completed
}
}
})
@@ -1686,7 +1686,6 @@ impl NativeThreadEnvironment {
parent_thread: parent_thread_entity.downgrade(),
acp_thread,
wait_for_prompt_to_complete,
- user_cancelled,
}) as _)
}
}
@@ -1750,6 +1749,7 @@ impl ThreadEnvironment for NativeThreadEnvironment {
enum SubagentInitialPromptResult {
Completed,
Timeout,
+ Cancelled,
}
pub struct NativeSubagentHandle {
@@ -1758,7 +1758,6 @@ pub struct NativeSubagentHandle {
subagent_thread: Entity<Thread>,
acp_thread: Entity<AcpThread>,
wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
- user_cancelled: Shared<Task<()>>,
}
impl SubagentHandle for NativeSubagentHandle {
@@ -1775,6 +1774,7 @@ impl SubagentHandle for NativeSubagentHandle {
let timed_out = match wait_for_prompt.await {
SubagentInitialPromptResult::Completed => false,
SubagentInitialPromptResult::Timeout => true,
+ SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
};
let summary_prompt = if timed_out {
@@ -1784,10 +1784,15 @@ impl SubagentHandle for NativeSubagentHandle {
summary_prompt
};
- acp_thread
+ let response = acp_thread
.update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
.await?;
+ let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled);
+ if was_canceled {
+ return Err(anyhow!("User cancelled"));
+ }
+
thread.read_with(cx, |thread, _cx| {
thread
.last_message()
@@ -1796,18 +1801,10 @@ impl SubagentHandle for NativeSubagentHandle {
})
});
- let user_cancelled = self.user_cancelled.clone();
- let thread = self.subagent_thread.clone();
let subagent_session_id = self.session_id.clone();
let parent_thread = self.parent_thread.clone();
cx.spawn(async move |cx| {
- let result = futures::select! {
- result = wait_for_summary_task.fuse() => result,
- _ = user_cancelled.fuse() => {
- thread.update(cx, |thread, cx| thread.cancel(cx).detach());
- Err(anyhow!("User cancelled"))
- },
- };
+ let result = wait_for_summary_task.await;
parent_thread
.update(cx, |parent_thread, cx| {
parent_thread.unregister_running_subagent(&subagent_session_id, cx)
@@ -1,6 +1,7 @@
use super::*;
use acp_thread::{
- AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
+ AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
+ UserMessageId,
};
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
@@ -160,15 +161,6 @@ struct FakeSubagentHandle {
wait_for_summary_task: Shared<Task<String>>,
}
-impl FakeSubagentHandle {
- fn new_never_completes(cx: &App) -> Self {
- Self {
- session_id: acp::SessionId::new("subagent-id"),
- wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
- }
- }
-}
-
impl SubagentHandle for FakeSubagentHandle {
fn id(&self) -> acp::SessionId {
self.session_id.clone()
@@ -193,13 +185,6 @@ impl FakeThreadEnvironment {
..self
}
}
-
- pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
- Self {
- subagent_handle: Some(subagent_handle.into()),
- ..self
- }
- }
}
impl crate::ThreadEnvironment for FakeThreadEnvironment {
@@ -4190,6 +4175,457 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
}
}
+#[gpui::test]
+async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
+ init_test(cx);
+ cx.update(|cx| {
+ LanguageModelRegistry::test(cx);
+ });
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "b.md": "Lorem"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = NativeAgent::new(
+ project.clone(),
+ thread_store.clone(),
+ Templates::new(),
+ None,
+ fs.clone(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), Path::new(""), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // Ensure empty threads are not saved, even if they get mutated.
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ });
+ cx.run_until_parked();
+
+ let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+ cx.run_until_parked();
+ model.send_last_completion_stream_text_chunk("spawning subagent");
+ let subagent_tool_input = SubagentToolInput {
+ label: "label".to_string(),
+ task_prompt: "subagent task prompt".to_string(),
+ summary_prompt: "subagent summary prompt".to_string(),
+ timeout_ms: None,
+ allowed_tools: None,
+ };
+ let subagent_tool_use = LanguageModelToolUse {
+ id: "subagent_1".into(),
+ name: SubagentTool::NAME.into(),
+ raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+ input: serde_json::to_value(&subagent_tool_input).unwrap(),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ subagent_tool_use,
+ ));
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ let subagent_session_id = thread.read_with(cx, |thread, cx| {
+ thread
+ .running_subagent_ids(cx)
+ .get(0)
+ .expect("subagent thread should be running")
+ .clone()
+ });
+
+ let subagent_thread = agent.read_with(cx, |agent, _cx| {
+ agent
+ .sessions
+ .get(&subagent_session_id)
+ .expect("subagent session should exist")
+ .acp_thread
+ .clone()
+ });
+
+ model.send_last_completion_stream_text_chunk("subagent task response");
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ model.send_last_completion_stream_text_chunk("subagent summary response");
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
+ indoc! {"
+ ## User
+
+ subagent task prompt
+
+ ## Assistant
+
+ subagent task response
+
+ ## User
+
+ subagent summary prompt
+
+ ## Assistant
+
+ subagent summary response
+
+ "}
+ );
+
+ model.send_last_completion_stream_text_chunk("Response");
+ model.end_last_completion_stream();
+
+ send.await.unwrap();
+
+ assert_eq!(
+ acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
+ format!(
+ indoc! {r#"
+ ## User
+
+ Prompt
+
+ ## Assistant
+
+ spawning subagent
+
+ **Tool Call: label**
+ Status: Completed
+
+ ```json
+ {{
+ "subagent_session_id": "{}",
+ "summary": "subagent summary response\n"
+ }}
+ ```
+
+ ## Assistant
+
+ Response
+
+ "#},
+ subagent_session_id
+ )
+ );
+}
+
+#[gpui::test]
+async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
+ init_test(cx);
+ cx.update(|cx| {
+ LanguageModelRegistry::test(cx);
+ });
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "b.md": "Lorem"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = NativeAgent::new(
+ project.clone(),
+ thread_store.clone(),
+ Templates::new(),
+ None,
+ fs.clone(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), Path::new(""), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // Ensure empty threads are not saved, even if they get mutated.
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ });
+ cx.run_until_parked();
+
+ let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+ cx.run_until_parked();
+ model.send_last_completion_stream_text_chunk("spawning subagent");
+ let subagent_tool_input = SubagentToolInput {
+ label: "label".to_string(),
+ task_prompt: "subagent task prompt".to_string(),
+ summary_prompt: "subagent summary prompt".to_string(),
+ timeout_ms: None,
+ allowed_tools: None,
+ };
+ let subagent_tool_use = LanguageModelToolUse {
+ id: "subagent_1".into(),
+ name: SubagentTool::NAME.into(),
+ raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+ input: serde_json::to_value(&subagent_tool_input).unwrap(),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ subagent_tool_use,
+ ));
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ let subagent_session_id = thread.read_with(cx, |thread, cx| {
+ thread
+ .running_subagent_ids(cx)
+ .get(0)
+ .expect("subagent thread should be running")
+ .clone()
+ });
+ let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
+ agent
+ .sessions
+ .get(&subagent_session_id)
+ .expect("subagent session should exist")
+ .acp_thread
+ .clone()
+ });
+
+ // model.send_last_completion_stream_text_chunk("subagent task response");
+ // model.end_last_completion_stream();
+
+ // cx.run_until_parked();
+
+ acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+
+ cx.run_until_parked();
+
+ send.await.unwrap();
+
+ acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(thread.status(), ThreadStatus::Idle);
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ Prompt
+
+ ## Assistant
+
+ spawning subagent
+
+ **Tool Call: label**
+ Status: Canceled
+
+ "}
+ );
+ });
+ subagent_acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(thread.status(), ThreadStatus::Idle);
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ subagent task prompt
+
+ "}
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) {
+ init_test(cx);
+ cx.update(|cx| {
+ LanguageModelRegistry::test(cx);
+ });
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "b.md": "Lorem"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = NativeAgent::new(
+ project.clone(),
+ thread_store.clone(),
+ Templates::new(),
+ None,
+ fs.clone(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), Path::new(""), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // Ensure empty threads are not saved, even if they get mutated.
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ });
+ cx.run_until_parked();
+
+ let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+ cx.run_until_parked();
+ model.send_last_completion_stream_text_chunk("spawning subagent");
+ let subagent_tool_input = SubagentToolInput {
+ label: "label".to_string(),
+ task_prompt: "subagent task prompt".to_string(),
+ summary_prompt: "subagent summary prompt".to_string(),
+ timeout_ms: None,
+ allowed_tools: None,
+ };
+ let subagent_tool_use = LanguageModelToolUse {
+ id: "subagent_1".into(),
+ name: SubagentTool::NAME.into(),
+ raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+ input: serde_json::to_value(&subagent_tool_input).unwrap(),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ subagent_tool_use,
+ ));
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ let subagent_session_id = thread.read_with(cx, |thread, cx| {
+ thread
+ .running_subagent_ids(cx)
+ .get(0)
+ .expect("subagent thread should be running")
+ .clone()
+ });
+ let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
+ agent
+ .sessions
+ .get(&subagent_session_id)
+ .expect("subagent session should exist")
+ .acp_thread
+ .clone()
+ });
+
+ model.send_last_completion_stream_text_chunk("subagent task response");
+ model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+
+ cx.run_until_parked();
+
+ send.await.unwrap();
+
+ acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(thread.status(), ThreadStatus::Idle);
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ Prompt
+
+ ## Assistant
+
+ spawning subagent
+
+ **Tool Call: label**
+ Status: Canceled
+
+ "}
+ );
+ });
+ subagent_acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(thread.status(), ThreadStatus::Idle);
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ subagent task prompt
+
+ ## Assistant
+
+ subagent task response
+
+ ## User
+
+ subagent summary prompt
+
+ "}
+ );
+ });
+}
+
#[gpui::test]
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
init_test(cx);
@@ -4382,84 +4818,6 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
});
}
-#[gpui::test]
-async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
- // This test verifies that the subagent tool properly handles user cancellation
- // via `event_stream.cancelled_by_user()` and stops all running subagents.
- init_test(cx);
- always_allow_tools(cx);
-
- cx.update(|cx| {
- cx.update_flags(true, vec!["subagents".to_string()]);
- });
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(path!("/test"), json!({})).await;
- let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
- let project_context = cx.new(|_cx| ProjectContext::default());
- let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let environment = Rc::new(cx.update(|cx| {
- FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
- }));
-
- let parent = cx.new(|cx| {
- Thread::new(
- project.clone(),
- project_context.clone(),
- context_server_registry.clone(),
- Templates::new(),
- Some(model.clone()),
- cx,
- )
- });
-
- #[allow(clippy::arc_with_non_send_sync)]
- let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));
-
- let (event_stream, _rx, mut cancellation_tx) =
- crate::ToolCallEventStream::test_with_cancellation();
-
- // Start the subagent tool
- let task = cx.update(|cx| {
- tool.run(
- SubagentToolInput {
- label: "Long running task".to_string(),
- task_prompt: "Do a very long task that takes forever".to_string(),
- summary_prompt: "Summarize".to_string(),
- timeout_ms: None,
- allowed_tools: None,
- },
- event_stream.clone(),
- cx,
- )
- });
-
- cx.run_until_parked();
-
- // Signal cancellation via the event stream
- crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
-
- // The task should complete promptly with a cancellation error
- let timeout = cx.background_executor.timer(Duration::from_secs(5));
- let result = futures::select! {
- result = task.fuse() => result,
- _ = timeout.fuse() => {
- panic!("subagent tool did not respond to cancellation within timeout");
- }
- };
-
- // Verify we got a cancellation error
- let err = result.unwrap_err();
- assert!(
- err.to_string().contains("cancelled by user"),
- "expected cancellation error, got: {}",
- err
- );
-}
-
#[gpui::test]
async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
init_test(cx);
@@ -2582,6 +2582,14 @@ impl Thread {
});
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn running_subagent_ids(&self, cx: &App) -> Vec<acp::SessionId> {
+ self.running_subagents
+ .iter()
+ .filter_map(|s| s.upgrade().map(|s| s.read(cx).id().clone()))
+ .collect()
+ }
+
pub fn running_subagent_count(&self) -> usize {
self.running_subagents
.iter()
@@ -1,7 +1,6 @@
use acp_thread::SUBAGENT_SESSION_ID_META_KEY;
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
-use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use language_model::LanguageModelToolResultContent;
use schemars::JsonSchema;
@@ -171,17 +170,11 @@ impl AgentTool for SubagentTool {
event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta));
cx.spawn(async move |cx| {
- let summary_task = subagent.wait_for_summary(input.summary_prompt, cx);
-
- futures::select_biased! {
- summary = summary_task.fuse() => summary.map(|summary| SubagentToolOutput {
- summary,
- subagent_session_id,
- }),
- _ = event_stream.cancelled_by_user().fuse() => {
- Err(anyhow!("Subagent was cancelled by user"))
- }
- }
+ let summary = subagent.wait_for_summary(input.summary_prompt, cx).await?;
+ Ok(SubagentToolOutput {
+ subagent_session_id,
+ summary,
+ })
})
}
@@ -810,7 +810,7 @@ impl AcpThreadView {
status,
turn_time_ms,
);
- res
+ res.map(|_| ())
});
cx.spawn(async move |this, cx| {
@@ -6164,8 +6164,8 @@ impl AcpThreadView {
|this, thread| {
this.on_click(cx.listener(
move |_this, _event, _window, cx| {
- thread.update(cx, |thread, _cx| {
- thread.stop_by_user();
+ thread.update(cx, |thread, cx| {
+ thread.cancel(cx).detach();
});
},
))