From 27dffc12df16723205001162bd725ba874c044e8 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 18 Feb 2026 18:33:11 +0100 Subject: [PATCH] Fix cancellation issues with subagents (#49350) This PR fixes issues with subagent cancellation, prior to this we would still show a wait indicator for subagents that were cancelled from the parent thread. We also removed the `stop_by_user` workaround, that code path now uses `thread.cancel` directly Before you mark this PR as ready for review, make sure that you have: - [x] Added a solid test coverage and/or screenshots from doing manual testing - [x] Done a self-review taking into account security and performance aspects - [x] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) Release Notes: - N/A --------- Co-authored-by: Ben Brandt --- crates/acp_thread/src/acp_thread.rs | 316 +++++++--- crates/agent/src/agent.rs | 61 +- crates/agent/src/tests/mod.rs | 548 +++++++++++++++--- crates/agent/src/thread.rs | 8 + crates/agent/src/tools/subagent_tool.rs | 17 +- .../src/acp/thread_view/active_thread.rs | 6 +- 6 files changed, 727 insertions(+), 229 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 433abfb12e206bd1d289d9c5b418955917f22744..786ed9e0aa2663d5c55ba00e49ae834448e79260 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -935,6 +935,11 @@ pub struct RetryStatus { pub duration: Duration, } +struct RunningTurn { + id: u32, + send_task: Task<()>, +} + pub struct AcpThread { parent_session_id: Option, title: SharedString, @@ -943,7 +948,8 @@ pub struct AcpThread { project: Entity, action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, - send_task: Option>, + turn_id: u32, + running_turn: Option, connection: Rc, session_id: acp::SessionId, token_usage: Option, @@ -952,9 +958,6 @@ pub struct AcpThread { terminals: HashMap>, pending_terminal_output: HashMap>>, pending_terminal_exit: HashMap, - // subagent cancellation fields - user_stopped: Arc, - user_stop_tx: watch::Sender, } impl From<&AcpThread> for ActionLogTelemetry { @@ -1172,8 +1175,6 @@ impl AcpThread { } }); - let (user_stop_tx, _user_stop_rx) = watch::channel(false); - Self { parent_session_id, action_log, @@ -1182,7 +1183,8 @@ impl AcpThread { plan: Default::default(), title: title.into(), project, - send_task: None, + running_turn: None, + turn_id: 0, connection, session_id, token_usage: None, @@ -1191,8 +1193,6 @@ impl AcpThread { terminals: HashMap::default(), pending_terminal_output: HashMap::default(), pending_terminal_exit: HashMap::default(), - user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), - user_stop_tx, } } @@ -1204,22 +1204,6 @@ impl AcpThread { self.prompt_capabilities.clone() } - /// Marks this thread as stopped by user action and signals any listeners. - pub fn stop_by_user(&mut self) { - self.user_stopped - .store(true, std::sync::atomic::Ordering::SeqCst); - self.user_stop_tx.send(true).ok(); - self.send_task.take(); - } - - pub fn was_stopped_by_user(&self) -> bool { - self.user_stopped.load(std::sync::atomic::Ordering::SeqCst) - } - - pub fn user_stop_receiver(&self) -> watch::Receiver { - self.user_stop_tx.receiver() - } - pub fn connection(&self) -> &Rc { &self.connection } @@ -1245,7 +1229,7 @@ impl AcpThread { } pub fn status(&self) -> ThreadStatus { - if self.send_task.is_some() { + if self.running_turn.is_some() { ThreadStatus::Generating } else { ThreadStatus::Idle @@ -1860,7 +1844,7 @@ impl AcpThread { &mut self, message: &str, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result>> { self.send(vec![message.into()], cx) } @@ -1868,7 +1852,7 @@ impl AcpThread { &mut self, message: Vec, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result>> { let block = ContentBlock::new_combined( message.clone(), self.project.read(cx).languages().clone(), @@ -1921,7 +1905,10 @@ impl AcpThread { self.connection.retry(&self.session_id, cx).is_some() } - pub fn retry(&mut self, cx: &mut Context) -> BoxFuture<'static, Result<()>> { + pub fn retry( + &mut self, + cx: &mut Context, + ) -> BoxFuture<'static, Result>> { self.run_turn(cx, async move |this, cx| { this.update(cx, |this, cx| { this.connection @@ -1937,16 +1924,21 @@ impl AcpThread { &mut self, cx: &mut Context, f: impl 'static + AsyncFnOnce(WeakEntity, &mut AsyncApp) -> Result, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result>> { self.clear_completed_plan_entries(cx); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); - self.send_task = Some(cx.spawn(async move |this, cx| { - cancel_task.await; - tx.send(f(this, cx).await).ok(); - })); + self.turn_id += 1; + let turn_id = self.turn_id; + self.running_turn = Some(RunningTurn { + id: turn_id, + send_task: cx.spawn(async move |this, cx| { + cancel_task.await; + tx.send(f(this, cx).await).ok(); + }), + }); cx.spawn(async move |this, cx| { let response = rx.await; @@ -1957,43 +1949,39 @@ impl AcpThread { this.update(cx, |this, cx| { this.project .update(cx, |project, cx| project.set_agent_location(None, cx)); + + let Ok(response) = response else { + // tx dropped, just return + return Ok(None); + }; + + let is_same_turn = this + .running_turn + .as_ref() + .is_some_and(|turn| turn_id == turn.id); + + // If the user submitted a follow up message, running_turn might + // already point to a different turn. Therefore we only want to + // take the task if it's the same turn. + if is_same_turn { + this.running_turn.take(); + } + match response { - Ok(Err(e)) => { - this.send_task.take(); - cx.emit(AcpThreadEvent::Error); - log::error!("Error in run turn: {:?}", e); - Err(e) - } - Ok(Ok(r)) if r.stop_reason == acp::StopReason::MaxTokens => { - this.send_task.take(); - cx.emit(AcpThreadEvent::Error); - log::error!("Max tokens reached. Usage: {:?}", this.token_usage); - Err(anyhow!("Max tokens reached")) - } - result => { - let canceled = matches!( - result, - Ok(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Cancelled, - .. - })) - ); - - // We only take the task if the current prompt wasn't canceled. - // - // This prompt may have been canceled because another one was sent - // while it was still generating. In these cases, dropping `send_task` - // would cause the next generation to be canceled. - if !canceled { - this.send_task.take(); + Ok(r) => { + if r.stop_reason == acp::StopReason::MaxTokens { + cx.emit(AcpThreadEvent::Error); + log::error!("Max tokens reached. Usage: {:?}", this.token_usage); + return Err(anyhow!("Max tokens reached")); + } + + let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled); + if canceled { + this.mark_pending_tools_as_canceled(); } // Handle refusal - distinguish between user prompt and tool call refusals - if let Ok(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Refusal, - .. - })) = result - { + if let acp::StopReason::Refusal = r.stop_reason { if let Some((user_msg_ix, _)) = this.last_user_message() { // Check if there's a completed tool call with results after the last user message // This indicates the refusal is in response to tool output, not the user's prompt @@ -2028,7 +2016,12 @@ impl AcpThread { } cx.emit(AcpThreadEvent::Stopped); - Ok(()) + Ok(Some(r)) + } + Err(e) => { + cx.emit(AcpThreadEvent::Error); + log::error!("Error in run turn: {:?}", e); + Err(e) } } })? @@ -2037,10 +2030,18 @@ impl AcpThread { } pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { - let Some(send_task) = self.send_task.take() else { + let Some(turn) = self.running_turn.take() else { return Task::ready(()); }; + self.connection.cancel(&self.session_id, cx); + + self.mark_pending_tools_as_canceled(); + + // Wait for the send task to complete + cx.background_spawn(turn.send_task) + } + fn mark_pending_tools_as_canceled(&mut self) { for entry in self.entries.iter_mut() { if let AgentThreadEntry::ToolCall(call) = entry { let cancel = matches!( @@ -2055,11 +2056,6 @@ impl AcpThread { } } } - - self.connection.cancel(&self.session_id, cx); - - // Wait for the send task to complete - cx.foreground_executor().spawn(send_task) } /// Restores the git working tree to the state at the given checkpoint (if one exists) @@ -3957,18 +3953,7 @@ mod tests { } } - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - let sessions = self.sessions.lock(); - let thread = sessions.get(session_id).unwrap().clone(); - - cx.spawn(async move |cx| { - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .unwrap() - .await - }) - .detach(); - } + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} fn truncate( &self, @@ -4298,7 +4283,7 @@ mod tests { // Verify that no send_task is in progress after restore // (cancel() clears the send_task) - let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some()); + let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some()); assert!( !has_send_task_after, "Should not have a send_task after restore (cancel should have cleared it)" @@ -4419,4 +4404,161 @@ mod tests { result.err() ); } + + /// Tests that when a follow-up message is sent during generation, + /// the first turn completing does NOT clear `running_turn` because + /// it now belongs to the second turn. + #[gpui::test] + async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + // First handler waits for this signal before completing + let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>(); + let first_complete_rx = RefCell::new(Some(first_complete_rx)); + + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + move |params, _thread, _cx| { + let first_complete_rx = first_complete_rx.borrow_mut().take(); + let is_first = params + .prompt + .iter() + .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first"))); + + async move { + if is_first { + // First handler waits until signaled + if let Some(rx) = first_complete_rx { + rx.await.ok(); + } + } + Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) + } + .boxed_local() + } + })); + + let thread = cx + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Send first message (turn_id=1) - handler will block + let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx)); + assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1); + + // Send second message (turn_id=2) while first is still blocked + // This calls cancel() which takes turn 1's running_turn and sets turn 2's + let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx)); + assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2); + + let running_turn_after_second_send = + thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id)); + assert_eq!( + running_turn_after_second_send, + Some(2), + "running_turn should be set to turn 2 after sending second message" + ); + + // Now signal first handler to complete + first_complete_tx.send(()).ok(); + + // First request completes - should NOT clear running_turn + // because running_turn now belongs to turn 2 + first_request.await.unwrap(); + + let running_turn_after_first = + thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id)); + assert_eq!( + running_turn_after_first, + Some(2), + "first turn completing should not clear running_turn (belongs to turn 2)" + ); + + // Second request completes - SHOULD clear running_turn + second_request.await.unwrap(); + + let running_turn_after_second = + thread.read_with(cx, |thread, _| thread.running_turn.is_some()); + assert!( + !running_turn_after_second, + "second turn completing should clear running_turn" + ); + } + + #[gpui::test] + async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + let connection = Rc::new(FakeAgentConnection::new().on_user_message( + move |_params, thread, mut cx| { + async move { + thread + .update(&mut cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall( + acp::ToolCall::new( + acp::ToolCallId::new("test-tool"), + "Test Tool", + ) + .kind(acp::ToolKind::Fetch) + .status(acp::ToolCallStatus::InProgress), + ), + cx, + ) + }) + .unwrap() + .unwrap(); + + Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)) + } + .boxed_local() + }, + )); + + let thread = cx + .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + let response = thread + .update(cx, |thread, cx| thread.send_raw("test message", cx)) + .await; + + let response = response + .expect("send should succeed") + .expect("should have response"); + assert_eq!( + response.stop_reason, + acp::StopReason::Cancelled, + "response should have Cancelled stop_reason" + ); + + thread.read_with(cx, |thread, _| { + let tool_entry = thread + .entries + .iter() + .find_map(|e| { + if let AgentThreadEntry::ToolCall(call) = e { + Some(call) + } else { + None + } + }) + .expect("should have tool call entry"); + + assert!( + matches!(tool_entry.status, ToolCallStatus::Canceled), + "tool should be marked as Canceled when response is Cancelled, got {:?}", + tool_entry.status + ); + }); + } } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 80649e406595027f95c61ec844d496c32d41fd07..0418547c33c0e891a4d703d0c92e006e475bef07 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1369,8 +1369,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { - if let Some(agent) = agent.sessions.get(session_id) { - agent + if let Some(session) = agent.sessions.get(session_id) { + session .thread .update(cx, |thread, cx| thread.cancel(cx)) .detach(); @@ -1655,26 +1655,26 @@ impl NativeThreadEnvironment { if let Some(timer) = timeout_timer { futures::select! { _ = timer.fuse() => SubagentInitialPromptResult::Timeout, - _ = task.fuse() => SubagentInitialPromptResult::Completed, + response = task.fuse() => { + let response = response.log_err().flatten(); + if response.is_some_and(|response| { + response.stop_reason == acp::StopReason::Cancelled + }) + { + SubagentInitialPromptResult::Cancelled + } else { + SubagentInitialPromptResult::Completed + } + }, } } else { - task.await.log_err(); - SubagentInitialPromptResult::Completed - } - }) - .shared(); - - let mut user_stop_rx: watch::Receiver = - acp_thread.update(cx, |thread, _| thread.user_stop_receiver()); - - let user_cancelled = cx - .background_spawn(async move { - loop { - if *user_stop_rx.borrow() { - return; - } - if user_stop_rx.changed().await.is_err() { - std::future::pending::<()>().await; + let response = task.await.log_err().flatten(); + if response + .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled) + { + SubagentInitialPromptResult::Cancelled + } else { + SubagentInitialPromptResult::Completed } } }) @@ -1686,7 +1686,6 @@ impl NativeThreadEnvironment { parent_thread: parent_thread_entity.downgrade(), acp_thread, wait_for_prompt_to_complete, - user_cancelled, }) as _) } } @@ -1750,6 +1749,7 @@ impl ThreadEnvironment for NativeThreadEnvironment { enum SubagentInitialPromptResult { Completed, Timeout, + Cancelled, } pub struct NativeSubagentHandle { @@ -1758,7 +1758,6 @@ pub struct NativeSubagentHandle { subagent_thread: Entity, acp_thread: Entity, wait_for_prompt_to_complete: Shared>, - user_cancelled: Shared>, } impl SubagentHandle for NativeSubagentHandle { @@ -1775,6 +1774,7 @@ impl SubagentHandle for NativeSubagentHandle { let timed_out = match wait_for_prompt.await { SubagentInitialPromptResult::Completed => false, SubagentInitialPromptResult::Timeout => true, + SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")), }; let summary_prompt = if timed_out { @@ -1784,10 +1784,15 @@ impl SubagentHandle for NativeSubagentHandle { summary_prompt }; - acp_thread + let response = acp_thread .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx)) .await?; + let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled); + if was_canceled { + return Err(anyhow!("User cancelled")); + } + thread.read_with(cx, |thread, _cx| { thread .last_message() @@ -1796,18 +1801,10 @@ impl SubagentHandle for NativeSubagentHandle { }) }); - let user_cancelled = self.user_cancelled.clone(); - let thread = self.subagent_thread.clone(); let subagent_session_id = self.session_id.clone(); let parent_thread = self.parent_thread.clone(); cx.spawn(async move |cx| { - let result = futures::select! { - result = wait_for_summary_task.fuse() => result, - _ = user_cancelled.fuse() => { - thread.update(cx, |thread, cx| thread.cancel(cx).detach()); - Err(anyhow!("User cancelled")) - }, - }; + let result = wait_for_summary_task.await; parent_thread .update(cx, |parent_thread, cx| { parent_thread.unregister_running_subagent(&subagent_session_id, cx) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 25bdcea4d806a173101d80822e5807dc9bcd239b..b117a8b54c43a8157641077958c6e197caff8e53 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -1,6 +1,7 @@ use super::*; use acp_thread::{ - AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId, + AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus, + UserMessageId, }; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; @@ -160,15 +161,6 @@ struct FakeSubagentHandle { wait_for_summary_task: Shared>, } -impl FakeSubagentHandle { - fn new_never_completes(cx: &App) -> Self { - Self { - session_id: acp::SessionId::new("subagent-id"), - wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(), - } - } -} - impl SubagentHandle for FakeSubagentHandle { fn id(&self) -> acp::SessionId { self.session_id.clone() @@ -193,13 +185,6 @@ impl FakeThreadEnvironment { ..self } } - - pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self { - Self { - subagent_handle: Some(subagent_handle.into()), - ..self - } - } } impl crate::ThreadEnvironment for FakeThreadEnvironment { @@ -4190,6 +4175,457 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model = Arc::new(FakeLanguageModel::default()); + + // Ensure empty threads are not saved, even if they get mutated. + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SubagentToolInput { + label: "label".to_string(), + task_prompt: "subagent task prompt".to_string(), + summary_prompt: "subagent summary prompt".to_string(), + timeout_ms: None, + allowed_tools: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SubagentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + let subagent_session_id = thread.read_with(cx, |thread, cx| { + thread + .running_subagent_ids(cx) + .get(0) + .expect("subagent thread should be running") + .clone() + }); + + let subagent_thread = agent.read_with(cx, |agent, _cx| { + agent + .sessions + .get(&subagent_session_id) + .expect("subagent session should exist") + .acp_thread + .clone() + }); + + model.send_last_completion_stream_text_chunk("subagent task response"); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("subagent summary response"); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + assert_eq!( + subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)), + indoc! {" + ## User + + subagent task prompt + + ## Assistant + + subagent task response + + ## User + + subagent summary prompt + + ## Assistant + + subagent summary response + + "} + ); + + model.send_last_completion_stream_text_chunk("Response"); + model.end_last_completion_stream(); + + send.await.unwrap(); + + assert_eq!( + acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)), + format!( + indoc! {r#" + ## User + + Prompt + + ## Assistant + + spawning subagent + + **Tool Call: label** + Status: Completed + + ```json + {{ + "subagent_session_id": "{}", + "summary": "subagent summary response\n" + }} + ``` + + ## Assistant + + Response + + "#}, + subagent_session_id + ) + ); +} + +#[gpui::test] +async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model = Arc::new(FakeLanguageModel::default()); + + // Ensure empty threads are not saved, even if they get mutated. + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SubagentToolInput { + label: "label".to_string(), + task_prompt: "subagent task prompt".to_string(), + summary_prompt: "subagent summary prompt".to_string(), + timeout_ms: None, + allowed_tools: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SubagentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + let subagent_session_id = thread.read_with(cx, |thread, cx| { + thread + .running_subagent_ids(cx) + .get(0) + .expect("subagent thread should be running") + .clone() + }); + let subagent_acp_thread = agent.read_with(cx, |agent, _cx| { + agent + .sessions + .get(&subagent_session_id) + .expect("subagent session should exist") + .acp_thread + .clone() + }); + + // model.send_last_completion_stream_text_chunk("subagent task response"); + // model.end_last_completion_stream(); + + // cx.run_until_parked(); + + acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await; + + cx.run_until_parked(); + + send.await.unwrap(); + + acp_thread.read_with(cx, |thread, cx| { + assert_eq!(thread.status(), ThreadStatus::Idle); + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + Prompt + + ## Assistant + + spawning subagent + + **Tool Call: label** + Status: Canceled + + "} + ); + }); + subagent_acp_thread.read_with(cx, |thread, cx| { + assert_eq!(thread.status(), ThreadStatus::Idle); + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + subagent task prompt + + "} + ); + }); +} + +#[gpui::test] +async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model = Arc::new(FakeLanguageModel::default()); + + // Ensure empty threads are not saved, even if they get mutated. + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SubagentToolInput { + label: "label".to_string(), + task_prompt: "subagent task prompt".to_string(), + summary_prompt: "subagent summary prompt".to_string(), + timeout_ms: None, + allowed_tools: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SubagentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + let subagent_session_id = thread.read_with(cx, |thread, cx| { + thread + .running_subagent_ids(cx) + .get(0) + .expect("subagent thread should be running") + .clone() + }); + let subagent_acp_thread = agent.read_with(cx, |agent, _cx| { + agent + .sessions + .get(&subagent_session_id) + .expect("subagent session should exist") + .acp_thread + .clone() + }); + + model.send_last_completion_stream_text_chunk("subagent task response"); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await; + + cx.run_until_parked(); + + send.await.unwrap(); + + acp_thread.read_with(cx, |thread, cx| { + assert_eq!(thread.status(), ThreadStatus::Idle); + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + Prompt + + ## Assistant + + spawning subagent + + **Tool Call: label** + Status: Canceled + + "} + ); + }); + subagent_acp_thread.read_with(cx, |thread, cx| { + assert_eq!(thread.status(), ThreadStatus::Idle); + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + subagent task prompt + + ## Assistant + + subagent task response + + ## User + + subagent summary prompt + + "} + ); + }); +} + #[gpui::test] async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) { init_test(cx); @@ -4382,84 +4818,6 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) { }); } -#[gpui::test] -async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) { - // This test verifies that the subagent tool properly handles user cancellation - // via `event_stream.cancelled_by_user()` and stops all running subagents. - init_test(cx); - always_allow_tools(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let environment = Rc::new(cx.update(|cx| { - FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx)) - })); - - let parent = cx.new(|cx| { - Thread::new( - project.clone(), - project_context.clone(), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - - #[allow(clippy::arc_with_non_send_sync)] - let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment)); - - let (event_stream, _rx, mut cancellation_tx) = - crate::ToolCallEventStream::test_with_cancellation(); - - // Start the subagent tool - let task = cx.update(|cx| { - tool.run( - SubagentToolInput { - label: "Long running task".to_string(), - task_prompt: "Do a very long task that takes forever".to_string(), - summary_prompt: "Summarize".to_string(), - timeout_ms: None, - allowed_tools: None, - }, - event_stream.clone(), - cx, - ) - }); - - cx.run_until_parked(); - - // Signal cancellation via the event stream - crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx); - - // The task should complete promptly with a cancellation error - let timeout = cx.background_executor.timer(Duration::from_secs(5)); - let result = futures::select! { - result = task.fuse() => result, - _ = timeout.fuse() => { - panic!("subagent tool did not respond to cancellation within timeout"); - } - }; - - // Verify we got a cancellation error - let err = result.unwrap_err(); - assert!( - err.to_string().contains("cancelled by user"), - "expected cancellation error, got: {}", - err - ); -} - #[gpui::test] async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 965898f362f028017b6ae3d1daafe7401c64eb8b..9323eb6f70ec75fbaf17b81d42071cf3ad9c331a 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2582,6 +2582,14 @@ impl Thread { }); } + #[cfg(any(test, feature = "test-support"))] + pub fn running_subagent_ids(&self, cx: &App) -> Vec { + self.running_subagents + .iter() + .filter_map(|s| s.upgrade().map(|s| s.read(cx).id().clone())) + .collect() + } + pub fn running_subagent_count(&self) -> usize { self.running_subagents .iter() diff --git a/crates/agent/src/tools/subagent_tool.rs b/crates/agent/src/tools/subagent_tool.rs index ad63ee656d49d209481b69ce5b0dc28f31b3895e..8b38f5f655a37e18cf3114f71103c896ffad8d0b 100644 --- a/crates/agent/src/tools/subagent_tool.rs +++ b/crates/agent/src/tools/subagent_tool.rs @@ -1,7 +1,6 @@ use acp_thread::SUBAGENT_SESSION_ID_META_KEY; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; -use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task, WeakEntity}; use language_model::LanguageModelToolResultContent; use schemars::JsonSchema; @@ -171,17 +170,11 @@ impl AgentTool for SubagentTool { event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); cx.spawn(async move |cx| { - let summary_task = subagent.wait_for_summary(input.summary_prompt, cx); - - futures::select_biased! { - summary = summary_task.fuse() => summary.map(|summary| SubagentToolOutput { - summary, - subagent_session_id, - }), - _ = event_stream.cancelled_by_user().fuse() => { - Err(anyhow!("Subagent was cancelled by user")) - } - } + let summary = subagent.wait_for_summary(input.summary_prompt, cx).await?; + Ok(SubagentToolOutput { + subagent_session_id, + summary, + }) }) } diff --git a/crates/agent_ui/src/acp/thread_view/active_thread.rs b/crates/agent_ui/src/acp/thread_view/active_thread.rs index 839b58b07e92b7878a6ec5207840cfeb6ccb0b1f..306994e5db0da15ff29cad9e6e8a7ee24f583305 100644 --- a/crates/agent_ui/src/acp/thread_view/active_thread.rs +++ b/crates/agent_ui/src/acp/thread_view/active_thread.rs @@ -810,7 +810,7 @@ impl AcpThreadView { status, turn_time_ms, ); - res + res.map(|_| ()) }); cx.spawn(async move |this, cx| { @@ -6164,8 +6164,8 @@ impl AcpThreadView { |this, thread| { this.on_click(cx.listener( move |_this, _event, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.stop_by_user(); + thread.update(cx, |thread, cx| { + thread.cancel(cx).detach(); }); }, ))