Fix cancellation issues with subagents (#49350)

Bennet Bo Fenner and Ben Brandt created

This PR fixes issues with subagent cancellation, prior to this we would
still show a wait indicator for subagents that were cancelled from the
parent thread. We also removed the `stop_by_user` workaround, that code
path now uses `thread.cancel` directly

Before you mark this PR as ready for review, make sure that you have:
- [x] Added a solid test coverage and/or screenshots from doing manual
testing
- [x] Done a self-review taking into account security and performance
aspects
- [x] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/acp_thread/src/acp_thread.rs                  | 316 +++++--
crates/agent/src/agent.rs                            |  61 
crates/agent/src/tests/mod.rs                        | 548 +++++++++++--
crates/agent/src/thread.rs                           |   8 
crates/agent/src/tools/subagent_tool.rs              |  17 
crates/agent_ui/src/acp/thread_view/active_thread.rs |   6 
6 files changed, 727 insertions(+), 229 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -935,6 +935,11 @@ pub struct RetryStatus {
     pub duration: Duration,
 }
 
+struct RunningTurn {
+    id: u32,
+    send_task: Task<()>,
+}
+
 pub struct AcpThread {
     parent_session_id: Option<acp::SessionId>,
     title: SharedString,
@@ -943,7 +948,8 @@ pub struct AcpThread {
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
     shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
-    send_task: Option<Task<()>>,
+    turn_id: u32,
+    running_turn: Option<RunningTurn>,
     connection: Rc<dyn AgentConnection>,
     session_id: acp::SessionId,
     token_usage: Option<TokenUsage>,
@@ -952,9 +958,6 @@ pub struct AcpThread {
     terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
     pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
     pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
-    // subagent cancellation fields
-    user_stopped: Arc<std::sync::atomic::AtomicBool>,
-    user_stop_tx: watch::Sender<bool>,
 }
 
 impl From<&AcpThread> for ActionLogTelemetry {
@@ -1172,8 +1175,6 @@ impl AcpThread {
             }
         });
 
-        let (user_stop_tx, _user_stop_rx) = watch::channel(false);
-
         Self {
             parent_session_id,
             action_log,
@@ -1182,7 +1183,8 @@ impl AcpThread {
             plan: Default::default(),
             title: title.into(),
             project,
-            send_task: None,
+            running_turn: None,
+            turn_id: 0,
             connection,
             session_id,
             token_usage: None,
@@ -1191,8 +1193,6 @@ impl AcpThread {
             terminals: HashMap::default(),
             pending_terminal_output: HashMap::default(),
             pending_terminal_exit: HashMap::default(),
-            user_stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
-            user_stop_tx,
         }
     }
 
@@ -1204,22 +1204,6 @@ impl AcpThread {
         self.prompt_capabilities.clone()
     }
 
-    /// Marks this thread as stopped by user action and signals any listeners.
-    pub fn stop_by_user(&mut self) {
-        self.user_stopped
-            .store(true, std::sync::atomic::Ordering::SeqCst);
-        self.user_stop_tx.send(true).ok();
-        self.send_task.take();
-    }
-
-    pub fn was_stopped_by_user(&self) -> bool {
-        self.user_stopped.load(std::sync::atomic::Ordering::SeqCst)
-    }
-
-    pub fn user_stop_receiver(&self) -> watch::Receiver<bool> {
-        self.user_stop_tx.receiver()
-    }
-
     pub fn connection(&self) -> &Rc<dyn AgentConnection> {
         &self.connection
     }
@@ -1245,7 +1229,7 @@ impl AcpThread {
     }
 
     pub fn status(&self) -> ThreadStatus {
-        if self.send_task.is_some() {
+        if self.running_turn.is_some() {
             ThreadStatus::Generating
         } else {
             ThreadStatus::Idle
@@ -1860,7 +1844,7 @@ impl AcpThread {
         &mut self,
         message: &str,
         cx: &mut Context<Self>,
-    ) -> BoxFuture<'static, Result<()>> {
+    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
         self.send(vec![message.into()], cx)
     }
 
@@ -1868,7 +1852,7 @@ impl AcpThread {
         &mut self,
         message: Vec<acp::ContentBlock>,
         cx: &mut Context<Self>,
-    ) -> BoxFuture<'static, Result<()>> {
+    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
         let block = ContentBlock::new_combined(
             message.clone(),
             self.project.read(cx).languages().clone(),
@@ -1921,7 +1905,10 @@ impl AcpThread {
         self.connection.retry(&self.session_id, cx).is_some()
     }
 
-    pub fn retry(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
+    pub fn retry(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
         self.run_turn(cx, async move |this, cx| {
             this.update(cx, |this, cx| {
                 this.connection
@@ -1937,16 +1924,21 @@ impl AcpThread {
         &mut self,
         cx: &mut Context<Self>,
         f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
-    ) -> BoxFuture<'static, Result<()>> {
+    ) -> BoxFuture<'static, Result<Option<acp::PromptResponse>>> {
         self.clear_completed_plan_entries(cx);
 
         let (tx, rx) = oneshot::channel();
         let cancel_task = self.cancel(cx);
 
-        self.send_task = Some(cx.spawn(async move |this, cx| {
-            cancel_task.await;
-            tx.send(f(this, cx).await).ok();
-        }));
+        self.turn_id += 1;
+        let turn_id = self.turn_id;
+        self.running_turn = Some(RunningTurn {
+            id: turn_id,
+            send_task: cx.spawn(async move |this, cx| {
+                cancel_task.await;
+                tx.send(f(this, cx).await).ok();
+            }),
+        });
 
         cx.spawn(async move |this, cx| {
             let response = rx.await;
@@ -1957,43 +1949,39 @@ impl AcpThread {
             this.update(cx, |this, cx| {
                 this.project
                     .update(cx, |project, cx| project.set_agent_location(None, cx));
+
+                let Ok(response) = response else {
+                    // tx dropped, just return
+                    return Ok(None);
+                };
+
+                let is_same_turn = this
+                    .running_turn
+                    .as_ref()
+                    .is_some_and(|turn| turn_id == turn.id);
+
+                // If the user submitted a follow up message, running_turn might
+                // already point to a different turn. Therefore we only want to
+                // take the task if it's the same turn.
+                if is_same_turn {
+                    this.running_turn.take();
+                }
+
                 match response {
-                    Ok(Err(e)) => {
-                        this.send_task.take();
-                        cx.emit(AcpThreadEvent::Error);
-                        log::error!("Error in run turn: {:?}", e);
-                        Err(e)
-                    }
-                    Ok(Ok(r)) if r.stop_reason == acp::StopReason::MaxTokens => {
-                        this.send_task.take();
-                        cx.emit(AcpThreadEvent::Error);
-                        log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
-                        Err(anyhow!("Max tokens reached"))
-                    }
-                    result => {
-                        let canceled = matches!(
-                            result,
-                            Ok(Ok(acp::PromptResponse {
-                                stop_reason: acp::StopReason::Cancelled,
-                                ..
-                            }))
-                        );
-
-                        // We only take the task if the current prompt wasn't canceled.
-                        //
-                        // This prompt may have been canceled because another one was sent
-                        // while it was still generating. In these cases, dropping `send_task`
-                        // would cause the next generation to be canceled.
-                        if !canceled {
-                            this.send_task.take();
+                    Ok(r) => {
+                        if r.stop_reason == acp::StopReason::MaxTokens {
+                            cx.emit(AcpThreadEvent::Error);
+                            log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
+                            return Err(anyhow!("Max tokens reached"));
+                        }
+
+                        let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
+                        if canceled {
+                            this.mark_pending_tools_as_canceled();
                         }
 
                         // Handle refusal - distinguish between user prompt and tool call refusals
-                        if let Ok(Ok(acp::PromptResponse {
-                            stop_reason: acp::StopReason::Refusal,
-                            ..
-                        })) = result
-                        {
+                        if let acp::StopReason::Refusal = r.stop_reason {
                             if let Some((user_msg_ix, _)) = this.last_user_message() {
                                 // Check if there's a completed tool call with results after the last user message
                                 // This indicates the refusal is in response to tool output, not the user's prompt
@@ -2028,7 +2016,12 @@ impl AcpThread {
                         }
 
                         cx.emit(AcpThreadEvent::Stopped);
-                        Ok(())
+                        Ok(Some(r))
+                    }
+                    Err(e) => {
+                        cx.emit(AcpThreadEvent::Error);
+                        log::error!("Error in run turn: {:?}", e);
+                        Err(e)
                     }
                 }
             })?
@@ -2037,10 +2030,18 @@ impl AcpThread {
     }
 
     pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
-        let Some(send_task) = self.send_task.take() else {
+        let Some(turn) = self.running_turn.take() else {
             return Task::ready(());
         };
+        self.connection.cancel(&self.session_id, cx);
+
+        self.mark_pending_tools_as_canceled();
+
+        // Wait for the send task to complete
+        cx.background_spawn(turn.send_task)
+    }
 
+    fn mark_pending_tools_as_canceled(&mut self) {
         for entry in self.entries.iter_mut() {
             if let AgentThreadEntry::ToolCall(call) = entry {
                 let cancel = matches!(
@@ -2055,11 +2056,6 @@ impl AcpThread {
                 }
             }
         }
-
-        self.connection.cancel(&self.session_id, cx);
-
-        // Wait for the send task to complete
-        cx.foreground_executor().spawn(send_task)
     }
 
     /// Restores the git working tree to the state at the given checkpoint (if one exists)
@@ -3957,18 +3953,7 @@ mod tests {
             }
         }
 
-        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
-            let sessions = self.sessions.lock();
-            let thread = sessions.get(session_id).unwrap().clone();
-
-            cx.spawn(async move |cx| {
-                thread
-                    .update(cx, |thread, cx| thread.cancel(cx))
-                    .unwrap()
-                    .await
-            })
-            .detach();
-        }
+        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
 
         fn truncate(
             &self,
@@ -4298,7 +4283,7 @@ mod tests {
 
         // Verify that no send_task is in progress after restore
         // (cancel() clears the send_task)
-        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
+        let has_send_task_after = thread.read_with(cx, |thread, _| thread.running_turn.is_some());
         assert!(
             !has_send_task_after,
             "Should not have a send_task after restore (cancel should have cleared it)"
@@ -4419,4 +4404,161 @@ mod tests {
             result.err()
         );
     }
+
+    /// Tests that when a follow-up message is sent during generation,
+    /// the first turn completing does NOT clear `running_turn` because
+    /// it now belongs to the second turn.
+    #[gpui::test]
+    async fn test_follow_up_message_during_generation_does_not_clear_turn(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+
+        // First handler waits for this signal before completing
+        let (first_complete_tx, first_complete_rx) = futures::channel::oneshot::channel::<()>();
+        let first_complete_rx = RefCell::new(Some(first_complete_rx));
+
+        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+            move |params, _thread, _cx| {
+                let first_complete_rx = first_complete_rx.borrow_mut().take();
+                let is_first = params
+                    .prompt
+                    .iter()
+                    .any(|c| matches!(c, acp::ContentBlock::Text(t) if t.text.contains("first")));
+
+                async move {
+                    if is_first {
+                        // First handler waits until signaled
+                        if let Some(rx) = first_complete_rx {
+                            rx.await.ok();
+                        }
+                    }
+                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
+                }
+                .boxed_local()
+            }
+        }));
+
+        let thread = cx
+            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+            .await
+            .unwrap();
+
+        // Send first message (turn_id=1) - handler will block
+        let first_request = thread.update(cx, |thread, cx| thread.send_raw("first", cx));
+        assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 1);
+
+        // Send second message (turn_id=2) while first is still blocked
+        // This calls cancel() which takes turn 1's running_turn and sets turn 2's
+        let second_request = thread.update(cx, |thread, cx| thread.send_raw("second", cx));
+        assert_eq!(thread.read_with(cx, |t, _| t.turn_id), 2);
+
+        let running_turn_after_second_send =
+            thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
+        assert_eq!(
+            running_turn_after_second_send,
+            Some(2),
+            "running_turn should be set to turn 2 after sending second message"
+        );
+
+        // Now signal first handler to complete
+        first_complete_tx.send(()).ok();
+
+        // First request completes - should NOT clear running_turn
+        // because running_turn now belongs to turn 2
+        first_request.await.unwrap();
+
+        let running_turn_after_first =
+            thread.read_with(cx, |thread, _| thread.running_turn.as_ref().map(|t| t.id));
+        assert_eq!(
+            running_turn_after_first,
+            Some(2),
+            "first turn completing should not clear running_turn (belongs to turn 2)"
+        );
+
+        // Second request completes - SHOULD clear running_turn
+        second_request.await.unwrap();
+
+        let running_turn_after_second =
+            thread.read_with(cx, |thread, _| thread.running_turn.is_some());
+        assert!(
+            !running_turn_after_second,
+            "second turn completing should clear running_turn"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_send_returns_cancelled_response_and_marks_tools_as_cancelled(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+
+        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
+            move |_params, thread, mut cx| {
+                async move {
+                    thread
+                        .update(&mut cx, |thread, cx| {
+                            thread.handle_session_update(
+                                acp::SessionUpdate::ToolCall(
+                                    acp::ToolCall::new(
+                                        acp::ToolCallId::new("test-tool"),
+                                        "Test Tool",
+                                    )
+                                    .kind(acp::ToolKind::Fetch)
+                                    .status(acp::ToolCallStatus::InProgress),
+                                ),
+                                cx,
+                            )
+                        })
+                        .unwrap()
+                        .unwrap();
+
+                    Ok(acp::PromptResponse::new(acp::StopReason::Cancelled))
+                }
+                .boxed_local()
+            },
+        ));
+
+        let thread = cx
+            .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+            .await
+            .unwrap();
+
+        let response = thread
+            .update(cx, |thread, cx| thread.send_raw("test message", cx))
+            .await;
+
+        let response = response
+            .expect("send should succeed")
+            .expect("should have response");
+        assert_eq!(
+            response.stop_reason,
+            acp::StopReason::Cancelled,
+            "response should have Cancelled stop_reason"
+        );
+
+        thread.read_with(cx, |thread, _| {
+            let tool_entry = thread
+                .entries
+                .iter()
+                .find_map(|e| {
+                    if let AgentThreadEntry::ToolCall(call) = e {
+                        Some(call)
+                    } else {
+                        None
+                    }
+                })
+                .expect("should have tool call entry");
+
+            assert!(
+                matches!(tool_entry.status, ToolCallStatus::Canceled),
+                "tool should be marked as Canceled when response is Cancelled, got {:?}",
+                tool_entry.status
+            );
+        });
+    }
 }

crates/agent/src/agent.rs 🔗

@@ -1369,8 +1369,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
         log::info!("Cancelling on session: {}", session_id);
         self.0.update(cx, |agent, cx| {
-            if let Some(agent) = agent.sessions.get(session_id) {
-                agent
+            if let Some(session) = agent.sessions.get(session_id) {
+                session
                     .thread
                     .update(cx, |thread, cx| thread.cancel(cx))
                     .detach();
@@ -1655,26 +1655,26 @@ impl NativeThreadEnvironment {
                 if let Some(timer) = timeout_timer {
                     futures::select! {
                         _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
-                        _ = task.fuse() => SubagentInitialPromptResult::Completed,
+                        response = task.fuse() => {
+                            let response = response.log_err().flatten();
+                            if response.is_some_and(|response| {
+                                response.stop_reason == acp::StopReason::Cancelled
+                            })
+                            {
+                                SubagentInitialPromptResult::Cancelled
+                            } else {
+                                SubagentInitialPromptResult::Completed
+                            }
+                        },
                     }
                 } else {
-                    task.await.log_err();
-                    SubagentInitialPromptResult::Completed
-                }
-            })
-            .shared();
-
-        let mut user_stop_rx: watch::Receiver<bool> =
-            acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
-
-        let user_cancelled = cx
-            .background_spawn(async move {
-                loop {
-                    if *user_stop_rx.borrow() {
-                        return;
-                    }
-                    if user_stop_rx.changed().await.is_err() {
-                        std::future::pending::<()>().await;
+                    let response = task.await.log_err().flatten();
+                    if response
+                        .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
+                    {
+                        SubagentInitialPromptResult::Cancelled
+                    } else {
+                        SubagentInitialPromptResult::Completed
                     }
                 }
             })
@@ -1686,7 +1686,6 @@ impl NativeThreadEnvironment {
             parent_thread: parent_thread_entity.downgrade(),
             acp_thread,
             wait_for_prompt_to_complete,
-            user_cancelled,
         }) as _)
     }
 }
@@ -1750,6 +1749,7 @@ impl ThreadEnvironment for NativeThreadEnvironment {
 enum SubagentInitialPromptResult {
     Completed,
     Timeout,
+    Cancelled,
 }
 
 pub struct NativeSubagentHandle {
@@ -1758,7 +1758,6 @@ pub struct NativeSubagentHandle {
     subagent_thread: Entity<Thread>,
     acp_thread: Entity<AcpThread>,
     wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
-    user_cancelled: Shared<Task<()>>,
 }
 
 impl SubagentHandle for NativeSubagentHandle {
@@ -1775,6 +1774,7 @@ impl SubagentHandle for NativeSubagentHandle {
             let timed_out = match wait_for_prompt.await {
                 SubagentInitialPromptResult::Completed => false,
                 SubagentInitialPromptResult::Timeout => true,
+                SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
             };
 
             let summary_prompt = if timed_out {
@@ -1784,10 +1784,15 @@ impl SubagentHandle for NativeSubagentHandle {
                 summary_prompt
             };
 
-            acp_thread
+            let response = acp_thread
                 .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
                 .await?;
 
+            let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled);
+            if was_canceled {
+                return Err(anyhow!("User cancelled"));
+            }
+
             thread.read_with(cx, |thread, _cx| {
                 thread
                     .last_message()
@@ -1796,18 +1801,10 @@ impl SubagentHandle for NativeSubagentHandle {
             })
         });
 
-        let user_cancelled = self.user_cancelled.clone();
-        let thread = self.subagent_thread.clone();
         let subagent_session_id = self.session_id.clone();
         let parent_thread = self.parent_thread.clone();
         cx.spawn(async move |cx| {
-            let result = futures::select! {
-                result = wait_for_summary_task.fuse() => result,
-                _ = user_cancelled.fuse() => {
-                    thread.update(cx, |thread, cx| thread.cancel(cx).detach());
-                    Err(anyhow!("User cancelled"))
-                },
-            };
+            let result = wait_for_summary_task.await;
             parent_thread
                 .update(cx, |parent_thread, cx| {
                     parent_thread.unregister_running_subagent(&subagent_session_id, cx)

crates/agent/src/tests/mod.rs 🔗

@@ -1,6 +1,7 @@
 use super::*;
 use acp_thread::{
-    AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
+    AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
+    UserMessageId,
 };
 use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
@@ -160,15 +161,6 @@ struct FakeSubagentHandle {
     wait_for_summary_task: Shared<Task<String>>,
 }
 
-impl FakeSubagentHandle {
-    fn new_never_completes(cx: &App) -> Self {
-        Self {
-            session_id: acp::SessionId::new("subagent-id"),
-            wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
-        }
-    }
-}
-
 impl SubagentHandle for FakeSubagentHandle {
     fn id(&self) -> acp::SessionId {
         self.session_id.clone()
@@ -193,13 +185,6 @@ impl FakeThreadEnvironment {
             ..self
         }
     }
-
-    pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
-        Self {
-            subagent_handle: Some(subagent_handle.into()),
-            ..self
-        }
-    }
 }
 
 impl crate::ThreadEnvironment for FakeThreadEnvironment {
@@ -4190,6 +4175,457 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
     }
 }
 
+#[gpui::test]
+async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    // Ensure empty threads are not saved, even if they get mutated.
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SubagentToolInput {
+        label: "label".to_string(),
+        task_prompt: "subagent task prompt".to_string(),
+        summary_prompt: "subagent summary prompt".to_string(),
+        timeout_ms: None,
+        allowed_tools: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SubagentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    let subagent_session_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .running_subagent_ids(cx)
+            .get(0)
+            .expect("subagent thread should be running")
+            .clone()
+    });
+
+    let subagent_thread = agent.read_with(cx, |agent, _cx| {
+        agent
+            .sessions
+            .get(&subagent_session_id)
+            .expect("subagent session should exist")
+            .acp_thread
+            .clone()
+    });
+
+    model.send_last_completion_stream_text_chunk("subagent task response");
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    model.send_last_completion_stream_text_chunk("subagent summary response");
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    assert_eq!(
+        subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
+        indoc! {"
+            ## User
+
+            subagent task prompt
+
+            ## Assistant
+
+            subagent task response
+
+            ## User
+
+            subagent summary prompt
+
+            ## Assistant
+
+            subagent summary response
+
+        "}
+    );
+
+    model.send_last_completion_stream_text_chunk("Response");
+    model.end_last_completion_stream();
+
+    send.await.unwrap();
+
+    assert_eq!(
+        acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
+        format!(
+            indoc! {r#"
+                ## User
+
+                Prompt
+
+                ## Assistant
+
+                spawning subagent
+
+                **Tool Call: label**
+                Status: Completed
+
+                ```json
+                {{
+                  "subagent_session_id": "{}",
+                  "summary": "subagent summary response\n"
+                }}
+                ```
+
+                ## Assistant
+
+                Response
+
+            "#},
+            subagent_session_id
+        )
+    );
+}
+
+#[gpui::test]
+async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    // Ensure empty threads are not saved, even if they get mutated.
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SubagentToolInput {
+        label: "label".to_string(),
+        task_prompt: "subagent task prompt".to_string(),
+        summary_prompt: "subagent summary prompt".to_string(),
+        timeout_ms: None,
+        allowed_tools: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SubagentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    let subagent_session_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .running_subagent_ids(cx)
+            .get(0)
+            .expect("subagent thread should be running")
+            .clone()
+    });
+    let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
+        agent
+            .sessions
+            .get(&subagent_session_id)
+            .expect("subagent session should exist")
+            .acp_thread
+            .clone()
+    });
+
+    // model.send_last_completion_stream_text_chunk("subagent task response");
+    // model.end_last_completion_stream();
+
+    // cx.run_until_parked();
+
+    acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+
+    cx.run_until_parked();
+
+    send.await.unwrap();
+
+    acp_thread.read_with(cx, |thread, cx| {
+        assert_eq!(thread.status(), ThreadStatus::Idle);
+        assert_eq!(
+            thread.to_markdown(cx),
+            indoc! {"
+                ## User
+
+                Prompt
+
+                ## Assistant
+
+                spawning subagent
+
+                **Tool Call: label**
+                Status: Canceled
+
+            "}
+        );
+    });
+    subagent_acp_thread.read_with(cx, |thread, cx| {
+        assert_eq!(thread.status(), ThreadStatus::Idle);
+        assert_eq!(
+            thread.to_markdown(cx),
+            indoc! {"
+                ## User
+
+                subagent task prompt
+
+            "}
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    // Ensure empty threads are not saved, even if they get mutated.
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SubagentToolInput {
+        label: "label".to_string(),
+        task_prompt: "subagent task prompt".to_string(),
+        summary_prompt: "subagent summary prompt".to_string(),
+        timeout_ms: None,
+        allowed_tools: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SubagentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    let subagent_session_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .running_subagent_ids(cx)
+            .get(0)
+            .expect("subagent thread should be running")
+            .clone()
+    });
+    let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
+        agent
+            .sessions
+            .get(&subagent_session_id)
+            .expect("subagent session should exist")
+            .acp_thread
+            .clone()
+    });
+
+    model.send_last_completion_stream_text_chunk("subagent task response");
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+
+    cx.run_until_parked();
+
+    send.await.unwrap();
+
+    acp_thread.read_with(cx, |thread, cx| {
+        assert_eq!(thread.status(), ThreadStatus::Idle);
+        assert_eq!(
+            thread.to_markdown(cx),
+            indoc! {"
+                ## User
+
+                Prompt
+
+                ## Assistant
+
+                spawning subagent
+
+                **Tool Call: label**
+                Status: Canceled
+
+            "}
+        );
+    });
+    subagent_acp_thread.read_with(cx, |thread, cx| {
+        assert_eq!(thread.status(), ThreadStatus::Idle);
+        assert_eq!(
+            thread.to_markdown(cx),
+            indoc! {"
+                ## User
+
+                subagent task prompt
+
+                ## Assistant
+
+                subagent task response
+
+                ## User
+
+                subagent summary prompt
+
+            "}
+        );
+    });
+}
+
 #[gpui::test]
 async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
     init_test(cx);
@@ -4382,84 +4818,6 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
     });
 }
 
-#[gpui::test]
-async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
-    // This test verifies that the subagent tool properly handles user cancellation
-    // via `event_stream.cancelled_by_user()` and stops all running subagents.
-    init_test(cx);
-    always_allow_tools(cx);
-
-    cx.update(|cx| {
-        cx.update_flags(true, vec!["subagents".to_string()]);
-    });
-
-    let fs = FakeFs::new(cx.executor());
-    fs.insert_tree(path!("/test"), json!({})).await;
-    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
-    let project_context = cx.new(|_cx| ProjectContext::default());
-    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
-    let context_server_registry =
-        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
-    let model = Arc::new(FakeLanguageModel::default());
-    let environment = Rc::new(cx.update(|cx| {
-        FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
-    }));
-
-    let parent = cx.new(|cx| {
-        Thread::new(
-            project.clone(),
-            project_context.clone(),
-            context_server_registry.clone(),
-            Templates::new(),
-            Some(model.clone()),
-            cx,
-        )
-    });
-
-    #[allow(clippy::arc_with_non_send_sync)]
-    let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));
-
-    let (event_stream, _rx, mut cancellation_tx) =
-        crate::ToolCallEventStream::test_with_cancellation();
-
-    // Start the subagent tool
-    let task = cx.update(|cx| {
-        tool.run(
-            SubagentToolInput {
-                label: "Long running task".to_string(),
-                task_prompt: "Do a very long task that takes forever".to_string(),
-                summary_prompt: "Summarize".to_string(),
-                timeout_ms: None,
-                allowed_tools: None,
-            },
-            event_stream.clone(),
-            cx,
-        )
-    });
-
-    cx.run_until_parked();
-
-    // Signal cancellation via the event stream
-    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
-
-    // The task should complete promptly with a cancellation error
-    let timeout = cx.background_executor.timer(Duration::from_secs(5));
-    let result = futures::select! {
-        result = task.fuse() => result,
-        _ = timeout.fuse() => {
-            panic!("subagent tool did not respond to cancellation within timeout");
-        }
-    };
-
-    // Verify we got a cancellation error
-    let err = result.unwrap_err();
-    assert!(
-        err.to_string().contains("cancelled by user"),
-        "expected cancellation error, got: {}",
-        err
-    );
-}
-
 #[gpui::test]
 async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
     init_test(cx);

crates/agent/src/thread.rs 🔗

@@ -2582,6 +2582,14 @@ impl Thread {
         });
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn running_subagent_ids(&self, cx: &App) -> Vec<acp::SessionId> {
+        self.running_subagents
+            .iter()
+            .filter_map(|s| s.upgrade().map(|s| s.read(cx).id().clone()))
+            .collect()
+    }
+
     pub fn running_subagent_count(&self) -> usize {
         self.running_subagents
             .iter()

crates/agent/src/tools/subagent_tool.rs 🔗

@@ -1,7 +1,6 @@
 use acp_thread::SUBAGENT_SESSION_ID_META_KEY;
 use agent_client_protocol as acp;
 use anyhow::{Result, anyhow};
-use futures::FutureExt as _;
 use gpui::{App, Entity, SharedString, Task, WeakEntity};
 use language_model::LanguageModelToolResultContent;
 use schemars::JsonSchema;
@@ -171,17 +170,11 @@ impl AgentTool for SubagentTool {
         event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta));
 
         cx.spawn(async move |cx| {
-            let summary_task = subagent.wait_for_summary(input.summary_prompt, cx);
-
-            futures::select_biased! {
-                summary = summary_task.fuse() => summary.map(|summary| SubagentToolOutput {
-                    summary,
-                    subagent_session_id,
-                }),
-                _ = event_stream.cancelled_by_user().fuse() => {
-                    Err(anyhow!("Subagent was cancelled by user"))
-                }
-            }
+            let summary = subagent.wait_for_summary(input.summary_prompt, cx).await?;
+            Ok(SubagentToolOutput {
+                subagent_session_id,
+                summary,
+            })
         })
     }
 

crates/agent_ui/src/acp/thread_view/active_thread.rs 🔗

@@ -810,7 +810,7 @@ impl AcpThreadView {
                 status,
                 turn_time_ms,
             );
-            res
+            res.map(|_| ())
         });
 
         cx.spawn(async move |this, cx| {
@@ -6164,8 +6164,8 @@ impl AcpThreadView {
                                             |this, thread| {
                                                 this.on_click(cx.listener(
                                                     move |_this, _event, _window, cx| {
-                                                        thread.update(cx, |thread, _cx| {
-                                                            thread.stop_by_user();
+                                                        thread.update(cx, |thread, cx| {
+                                                            thread.cancel(cx).detach();
                                                         });
                                                     },
                                                 ))