@@ -6,6 +6,7 @@ use anyhow::{Context as _, Result};
use assistant_tool::ActionLog;
use buffer_diff::BufferDiff;
use editor::{Bias, MultiBuffer, PathKey};
+use futures::future::{Fuse, FusedFuture};
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use itertools::Itertools;
@@ -572,7 +573,7 @@ pub struct AcpThread {
project: Entity<Project>,
action_log: Entity<ActionLog>,
shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
- send_task: Option<Task<()>>,
+ send_task: Option<Fuse<Task<()>>>,
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
}
@@ -662,7 +663,11 @@ impl AcpThread {
}
pub fn status(&self) -> ThreadStatus {
- if self.send_task.is_some() {
+ if self
+ .send_task
+ .as_ref()
+ .map_or(false, |t| !t.is_terminated())
+ {
if self.waiting_for_tool_confirmation() {
ThreadStatus::WaitingForToolConfirmation
} else {
@@ -1037,28 +1042,31 @@ impl AcpThread {
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
- self.send_task = Some(cx.spawn(async move |this, cx| {
- async {
- cancel_task.await;
+ self.send_task = Some(
+ cx.spawn(async move |this, cx| {
+ async {
+ cancel_task.await;
+
+ let result = this
+ .update(cx, |this, cx| {
+ this.connection.prompt(
+ acp::PromptRequest {
+ prompt: message,
+ session_id: this.session_id.clone(),
+ },
+ cx,
+ )
+ })?
+ .await;
- let result = this
- .update(cx, |this, cx| {
- this.connection.prompt(
- acp::PromptRequest {
- prompt: message,
- session_id: this.session_id.clone(),
- },
- cx,
- )
- })?
- .await;
- tx.send(result).log_err();
- this.update(cx, |this, _cx| this.send_task.take())?;
- anyhow::Ok(())
- }
- .await
- .log_err();
- }));
+ tx.send(result).log_err();
+ anyhow::Ok(())
+ }
+ .await
+ .log_err();
+ })
+ .fuse(),
+ );
cx.spawn(async move |this, cx| match rx.await {
Ok(Err(e)) => {
@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
-use std::cell::RefCell;
+use std::cell::{Cell, RefCell};
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@@ -24,7 +24,7 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize};
-use util::ResultExt;
+use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
@@ -153,16 +153,20 @@ impl AgentConnection for ClaudeAgentConnection {
})
.detach();
+ let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
+
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
let mut thread_rx = thread_rx.clone();
+ let cancellation_state = pending_cancellation.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
end_turn_tx.clone(),
+ cancellation_state.clone(),
cx,
)
.await
@@ -189,6 +193,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
+ pending_cancellation,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@@ -255,7 +260,12 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
- cx.foreground_executor().spawn(async move { rx.await? })
+ let cancellation_state = session.pending_cancellation.clone();
+ cx.foreground_executor().spawn(async move {
+ let result = rx.await??;
+ cancellation_state.set(PendingCancellation::None);
+ Ok(result)
+ })
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@@ -265,18 +275,19 @@ impl AgentConnection for ClaudeAgentConnection {
return;
};
+ let request_id = new_request_id();
+
+ session.pending_cancellation.set(PendingCancellation::Sent {
+ request_id: request_id.clone(),
+ });
+
session
.outgoing_tx
- .unbounded_send(SdkMessage::new_interrupt_message())
+ .unbounded_send(SdkMessage::ControlRequest {
+ request_id,
+ request: ControlRequest::Interrupt,
+ })
.log_err();
-
- if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
- end_turn_tx
- .send(Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::Cancelled,
- }))
- .ok();
- }
}
}
@@ -339,25 +350,107 @@ fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
+ pending_cancellation: Rc<Cell<PendingCancellation>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
+#[derive(Debug, Default, PartialEq)]
+enum PendingCancellation {
+ #[default]
+ None,
+ Sent {
+ request_id: String,
+ },
+ Confirmed,
+}
+
impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
+ pending_cancellation: Rc<Cell<PendingCancellation>>,
cx: &mut AsyncApp,
) {
match message {
// we should only be sending these out, they don't need to be in the thread
SdkMessage::ControlRequest { .. } => {}
- SdkMessage::Assistant {
+ SdkMessage::User {
message,
session_id: _,
+ } => {
+ let Some(thread) = thread_rx
+ .recv()
+ .await
+ .log_err()
+ .and_then(|entity| entity.upgrade())
+ else {
+ log::error!("Received an SDK message but thread is gone");
+ return;
+ };
+
+ for chunk in message.content.chunks() {
+ match chunk {
+ ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
+ let state = pending_cancellation.take();
+ if state != PendingCancellation::Confirmed {
+ thread
+ .update(cx, |thread, cx| {
+ thread.push_user_content_block(text.into(), cx)
+ })
+ .log_err();
+ }
+ pending_cancellation.set(state);
+ }
+ ContentChunk::ToolResult {
+ content,
+ tool_use_id,
+ } => {
+ let content = content.to_string();
+ thread
+ .update(cx, |thread, cx| {
+ thread.update_tool_call(
+ acp::ToolCallUpdate {
+ id: acp::ToolCallId(tool_use_id.into()),
+ fields: acp::ToolCallUpdateFields {
+ status: Some(acp::ToolCallStatus::Completed),
+ content: (!content.is_empty())
+ .then(|| vec![content.into()]),
+ ..Default::default()
+ },
+ },
+ cx,
+ )
+ })
+ .log_err();
+ }
+ ContentChunk::Thinking { .. }
+ | ContentChunk::RedactedThinking
+ | ContentChunk::ToolUse { .. } => {
+ debug_panic!(
+ "Should not get {:?} with role: assistant. should we handle this?",
+ chunk
+ );
+ }
+
+ ContentChunk::Image
+ | ContentChunk::Document
+ | ContentChunk::WebSearchToolResult => {
+ thread
+ .update(cx, |thread, cx| {
+ thread.push_assistant_content_block(
+ format!("Unsupported content: {:?}", chunk).into(),
+ false,
+ cx,
+ )
+ })
+ .log_err();
+ }
+ }
+ }
}
- | SdkMessage::User {
+ SdkMessage::Assistant {
message,
session_id: _,
} => {
@@ -423,31 +516,12 @@ impl ClaudeAgentSession {
})
.log_err();
}
- ContentChunk::ToolResult {
- content,
- tool_use_id,
- } => {
- let content = content.to_string();
- thread
- .update(cx, |thread, cx| {
- thread.update_tool_call(
- acp::ToolCallUpdate {
- id: acp::ToolCallId(tool_use_id.into()),
- fields: acp::ToolCallUpdateFields {
- status: Some(acp::ToolCallStatus::Completed),
- content: (!content.is_empty())
- .then(|| vec![content.into()]),
- ..Default::default()
- },
- },
- cx,
- )
- })
- .log_err();
+ ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
+ debug_panic!(
+ "Should not get tool results with role: assistant. should we handle this?"
+ );
}
- ContentChunk::Image
- | ContentChunk::Document
- | ContentChunk::WebSearchToolResult => {
+ ContentChunk::Image | ContentChunk::Document => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@@ -468,7 +542,10 @@ impl ClaudeAgentSession {
..
} => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
- if is_error || subtype == ResultErrorType::ErrorDuringExecution {
+ if is_error
+ || (subtype == ResultErrorType::ErrorDuringExecution
+ && pending_cancellation.take() != PendingCancellation::Confirmed)
+ {
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
@@ -479,7 +556,7 @@ impl ClaudeAgentSession {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
- ResultErrorType::ErrorDuringExecution => unreachable!(),
+ ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
@@ -487,7 +564,20 @@ impl ClaudeAgentSession {
}
}
}
- SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
+ SdkMessage::ControlResponse { response } => {
+ if matches!(response.subtype, ResultErrorType::Success) {
+ let pending_cancellation_value = pending_cancellation.take();
+
+ if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
+ && request_id == &response.request_id
+ {
+ pending_cancellation.set(PendingCancellation::Confirmed);
+ } else {
+ pending_cancellation.set(pending_cancellation_value);
+ }
+ }
+ }
+ SdkMessage::System { .. } => {}
}
}
@@ -728,22 +818,15 @@ impl Display for ResultErrorType {
}
}
-impl SdkMessage {
- fn new_interrupt_message() -> Self {
- use rand::Rng;
- // In the Claude Code TS SDK they just generate a random 12 character string,
- // `Math.random().toString(36).substring(2, 15)`
- let request_id = rand::thread_rng()
- .sample_iter(&rand::distributions::Alphanumeric)
- .take(12)
- .map(char::from)
- .collect();
-
- Self::ControlRequest {
- request_id,
- request: ControlRequest::Interrupt,
- }
- }
+fn new_request_id() -> String {
+ use rand::Rng;
+ // In the Claude Code TS SDK they just generate a random 12 character string,
+ // `Math.random().toString(36).substring(2, 15)`
+ rand::thread_rng()
+ .sample_iter(&rand::distributions::Alphanumeric)
+ .take(12)
+ .map(char::from)
+ .collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]