Fix interrupting ACP threads and CC cancellation (#35752)

Agus Zubiaga and Cole Miller created

Fixes a bug where generation wouldn't continue after interrupting the
agent, and improves CC cancellation so we don't display "[Request
interrupted by user]"

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>

Change summary

crates/acp_thread/src/acp_thread.rs |  54 ++++---
crates/agent_servers/src/claude.rs  | 197 ++++++++++++++++++++++--------
2 files changed, 171 insertions(+), 80 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -6,6 +6,7 @@ use anyhow::{Context as _, Result};
 use assistant_tool::ActionLog;
 use buffer_diff::BufferDiff;
 use editor::{Bias, MultiBuffer, PathKey};
+use futures::future::{Fuse, FusedFuture};
 use futures::{FutureExt, channel::oneshot, future::BoxFuture};
 use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 use itertools::Itertools;
@@ -572,7 +573,7 @@ pub struct AcpThread {
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
     shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
-    send_task: Option<Task<()>>,
+    send_task: Option<Fuse<Task<()>>>,
     connection: Rc<dyn AgentConnection>,
     session_id: acp::SessionId,
 }
@@ -662,7 +663,11 @@ impl AcpThread {
     }
 
     pub fn status(&self) -> ThreadStatus {
-        if self.send_task.is_some() {
+        if self
+            .send_task
+            .as_ref()
+            .map_or(false, |t| !t.is_terminated())
+        {
             if self.waiting_for_tool_confirmation() {
                 ThreadStatus::WaitingForToolConfirmation
             } else {
@@ -1037,28 +1042,31 @@ impl AcpThread {
         let (tx, rx) = oneshot::channel();
         let cancel_task = self.cancel(cx);
 
-        self.send_task = Some(cx.spawn(async move |this, cx| {
-            async {
-                cancel_task.await;
+        self.send_task = Some(
+            cx.spawn(async move |this, cx| {
+                async {
+                    cancel_task.await;
+
+                    let result = this
+                        .update(cx, |this, cx| {
+                            this.connection.prompt(
+                                acp::PromptRequest {
+                                    prompt: message,
+                                    session_id: this.session_id.clone(),
+                                },
+                                cx,
+                            )
+                        })?
+                        .await;
 
-                let result = this
-                    .update(cx, |this, cx| {
-                        this.connection.prompt(
-                            acp::PromptRequest {
-                                prompt: message,
-                                session_id: this.session_id.clone(),
-                            },
-                            cx,
-                        )
-                    })?
-                    .await;
-                tx.send(result).log_err();
-                this.update(cx, |this, _cx| this.send_task.take())?;
-                anyhow::Ok(())
-            }
-            .await
-            .log_err();
-        }));
+                    tx.send(result).log_err();
+                    anyhow::Ok(())
+                }
+                .await
+                .log_err();
+            })
+            .fuse(),
+        );
 
         cx.spawn(async move |this, cx| match rx.await {
             Ok(Err(e)) => {

crates/agent_servers/src/claude.rs 🔗

@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
 use project::Project;
 use settings::SettingsStore;
 use smol::process::Child;
-use std::cell::RefCell;
+use std::cell::{Cell, RefCell};
 use std::fmt::Display;
 use std::path::Path;
 use std::rc::Rc;
@@ -24,7 +24,7 @@ use futures::{
 };
 use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use serde::{Deserialize, Serialize};
-use util::ResultExt;
+use util::{ResultExt, debug_panic};
 
 use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
 use crate::claude::tools::ClaudeTool;
@@ -153,16 +153,20 @@ impl AgentConnection for ClaudeAgentConnection {
             })
             .detach();
 
+            let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
+
             let end_turn_tx = Rc::new(RefCell::new(None));
             let handler_task = cx.spawn({
                 let end_turn_tx = end_turn_tx.clone();
                 let mut thread_rx = thread_rx.clone();
+                let cancellation_state = pending_cancellation.clone();
                 async move |cx| {
                     while let Some(message) = incoming_message_rx.next().await {
                         ClaudeAgentSession::handle_message(
                             thread_rx.clone(),
                             message,
                             end_turn_tx.clone(),
+                            cancellation_state.clone(),
                             cx,
                         )
                         .await
@@ -189,6 +193,7 @@ impl AgentConnection for ClaudeAgentConnection {
             let session = ClaudeAgentSession {
                 outgoing_tx,
                 end_turn_tx,
+                pending_cancellation,
                 _handler_task: handler_task,
                 _mcp_server: Some(permission_mcp_server),
             };
@@ -255,7 +260,12 @@ impl AgentConnection for ClaudeAgentConnection {
             return Task::ready(Err(anyhow!(err)));
         }
 
-        cx.foreground_executor().spawn(async move { rx.await? })
+        let cancellation_state = session.pending_cancellation.clone();
+        cx.foreground_executor().spawn(async move {
+            let result = rx.await??;
+            cancellation_state.set(PendingCancellation::None);
+            Ok(result)
+        })
     }
 
     fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@@ -265,18 +275,19 @@ impl AgentConnection for ClaudeAgentConnection {
             return;
         };
 
+        let request_id = new_request_id();
+
+        session.pending_cancellation.set(PendingCancellation::Sent {
+            request_id: request_id.clone(),
+        });
+
         session
             .outgoing_tx
-            .unbounded_send(SdkMessage::new_interrupt_message())
+            .unbounded_send(SdkMessage::ControlRequest {
+                request_id,
+                request: ControlRequest::Interrupt,
+            })
             .log_err();
-
-        if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
-            end_turn_tx
-                .send(Ok(acp::PromptResponse {
-                    stop_reason: acp::StopReason::Cancelled,
-                }))
-                .ok();
-        }
     }
 }
 
@@ -339,25 +350,107 @@ fn spawn_claude(
 struct ClaudeAgentSession {
     outgoing_tx: UnboundedSender<SdkMessage>,
     end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
+    pending_cancellation: Rc<Cell<PendingCancellation>>,
     _mcp_server: Option<ClaudeZedMcpServer>,
     _handler_task: Task<()>,
 }
 
+#[derive(Debug, Default, PartialEq)]
+enum PendingCancellation {
+    #[default]
+    None,
+    Sent {
+        request_id: String,
+    },
+    Confirmed,
+}
+
 impl ClaudeAgentSession {
     async fn handle_message(
         mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
         message: SdkMessage,
         end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
+        pending_cancellation: Rc<Cell<PendingCancellation>>,
         cx: &mut AsyncApp,
     ) {
         match message {
             // we should only be sending these out, they don't need to be in the thread
             SdkMessage::ControlRequest { .. } => {}
-            SdkMessage::Assistant {
+            SdkMessage::User {
                 message,
                 session_id: _,
+            } => {
+                let Some(thread) = thread_rx
+                    .recv()
+                    .await
+                    .log_err()
+                    .and_then(|entity| entity.upgrade())
+                else {
+                    log::error!("Received an SDK message but thread is gone");
+                    return;
+                };
+
+                for chunk in message.content.chunks() {
+                    match chunk {
+                        ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
+                            let state = pending_cancellation.take();
+                            if state != PendingCancellation::Confirmed {
+                                thread
+                                    .update(cx, |thread, cx| {
+                                        thread.push_user_content_block(text.into(), cx)
+                                    })
+                                    .log_err();
+                            }
+                            pending_cancellation.set(state);
+                        }
+                        ContentChunk::ToolResult {
+                            content,
+                            tool_use_id,
+                        } => {
+                            let content = content.to_string();
+                            thread
+                                .update(cx, |thread, cx| {
+                                    thread.update_tool_call(
+                                        acp::ToolCallUpdate {
+                                            id: acp::ToolCallId(tool_use_id.into()),
+                                            fields: acp::ToolCallUpdateFields {
+                                                status: Some(acp::ToolCallStatus::Completed),
+                                                content: (!content.is_empty())
+                                                    .then(|| vec![content.into()]),
+                                                ..Default::default()
+                                            },
+                                        },
+                                        cx,
+                                    )
+                                })
+                                .log_err();
+                        }
+                        ContentChunk::Thinking { .. }
+                        | ContentChunk::RedactedThinking
+                        | ContentChunk::ToolUse { .. } => {
+                            debug_panic!(
+                                "Should not get {:?} with role: assistant. should we handle this?",
+                                chunk
+                            );
+                        }
+
+                        ContentChunk::Image
+                        | ContentChunk::Document
+                        | ContentChunk::WebSearchToolResult => {
+                            thread
+                                .update(cx, |thread, cx| {
+                                    thread.push_assistant_content_block(
+                                        format!("Unsupported content: {:?}", chunk).into(),
+                                        false,
+                                        cx,
+                                    )
+                                })
+                                .log_err();
+                        }
+                    }
+                }
             }
-            | SdkMessage::User {
+            SdkMessage::Assistant {
                 message,
                 session_id: _,
             } => {
@@ -423,31 +516,12 @@ impl ClaudeAgentSession {
                                 })
                                 .log_err();
                         }
-                        ContentChunk::ToolResult {
-                            content,
-                            tool_use_id,
-                        } => {
-                            let content = content.to_string();
-                            thread
-                                .update(cx, |thread, cx| {
-                                    thread.update_tool_call(
-                                        acp::ToolCallUpdate {
-                                            id: acp::ToolCallId(tool_use_id.into()),
-                                            fields: acp::ToolCallUpdateFields {
-                                                status: Some(acp::ToolCallStatus::Completed),
-                                                content: (!content.is_empty())
-                                                    .then(|| vec![content.into()]),
-                                                ..Default::default()
-                                            },
-                                        },
-                                        cx,
-                                    )
-                                })
-                                .log_err();
+                        ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
+                            debug_panic!(
+                                "Should not get tool results with role: assistant. should we handle this?"
+                            );
                         }
-                        ContentChunk::Image
-                        | ContentChunk::Document
-                        | ContentChunk::WebSearchToolResult => {
+                        ContentChunk::Image | ContentChunk::Document => {
                             thread
                                 .update(cx, |thread, cx| {
                                     thread.push_assistant_content_block(
@@ -468,7 +542,10 @@ impl ClaudeAgentSession {
                 ..
             } => {
                 if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
-                    if is_error || subtype == ResultErrorType::ErrorDuringExecution {
+                    if is_error
+                        || (subtype == ResultErrorType::ErrorDuringExecution
+                            && pending_cancellation.take() != PendingCancellation::Confirmed)
+                    {
                         end_turn_tx
                             .send(Err(anyhow!(
                                 "Error: {}",
@@ -479,7 +556,7 @@ impl ClaudeAgentSession {
                         let stop_reason = match subtype {
                             ResultErrorType::Success => acp::StopReason::EndTurn,
                             ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
-                            ResultErrorType::ErrorDuringExecution => unreachable!(),
+                            ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
                         };
                         end_turn_tx
                             .send(Ok(acp::PromptResponse { stop_reason }))
@@ -487,7 +564,20 @@ impl ClaudeAgentSession {
                     }
                 }
             }
-            SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
+            SdkMessage::ControlResponse { response } => {
+                if matches!(response.subtype, ResultErrorType::Success) {
+                    let pending_cancellation_value = pending_cancellation.take();
+
+                    if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
+                        && request_id == &response.request_id
+                    {
+                        pending_cancellation.set(PendingCancellation::Confirmed);
+                    } else {
+                        pending_cancellation.set(pending_cancellation_value);
+                    }
+                }
+            }
+            SdkMessage::System { .. } => {}
         }
     }
 
@@ -728,22 +818,15 @@ impl Display for ResultErrorType {
     }
 }
 
-impl SdkMessage {
-    fn new_interrupt_message() -> Self {
-        use rand::Rng;
-        // In the Claude Code TS SDK they just generate a random 12 character string,
-        // `Math.random().toString(36).substring(2, 15)`
-        let request_id = rand::thread_rng()
-            .sample_iter(&rand::distributions::Alphanumeric)
-            .take(12)
-            .map(char::from)
-            .collect();
-
-        Self::ControlRequest {
-            request_id,
-            request: ControlRequest::Interrupt,
-        }
-    }
+fn new_request_id() -> String {
+    use rand::Rng;
+    // In the Claude Code TS SDK they just generate a random 12 character string,
+    // `Math.random().toString(36).substring(2, 15)`
+    rand::thread_rng()
+        .sample_iter(&rand::distributions::Alphanumeric)
+        .take(12)
+        .map(char::from)
+        .collect()
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]