@@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
-use std::cell::{Cell, RefCell};
+use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
})
.detach();
- let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
+ let turn_state = Rc::new(RefCell::new(TurnState::None));
- let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({
- let end_turn_tx = end_turn_tx.clone();
+ let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone();
- let cancellation_state = pending_cancellation.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
- end_turn_tx.clone(),
- cancellation_state.clone(),
+ turn_state.clone(),
cx,
)
.await
@@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
- end_turn_tx,
- pending_cancellation,
+ turn_state,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
)));
};
- let (tx, rx) = oneshot::channel();
- session.end_turn_tx.borrow_mut().replace(tx);
+ let (end_tx, end_rx) = oneshot::channel();
+ session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new();
for chunk in params.prompt {
@@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
- let cancellation_state = session.pending_cancellation.clone();
- cx.foreground_executor().spawn(async move {
- let result = rx.await??;
- cancellation_state.set(PendingCancellation::None);
- Ok(result)
- })
+ cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection {
let request_id = new_request_id();
- session.pending_cancellation.set(PendingCancellation::Sent {
+ let turn_state = session.turn_state.take();
+ let TurnState::InProgress { end_tx } = turn_state else {
+ // Already cancelled or idle, put it back
+ session.turn_state.replace(turn_state);
+ return;
+ };
+
+ session.turn_state.replace(TurnState::CancelRequested {
+ end_tx,
request_id: request_id.clone(),
});
@@ -349,28 +348,56 @@ fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
- end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
- pending_cancellation: Rc<Cell<PendingCancellation>>,
+ turn_state: Rc<RefCell<TurnState>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
-#[derive(Debug, Default, PartialEq)]
-enum PendingCancellation {
+#[derive(Debug, Default)]
+enum TurnState {
#[default]
None,
- Sent {
+ InProgress {
+ end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
+ },
+ CancelRequested {
+ end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String,
},
- Confirmed,
+ CancelConfirmed {
+ end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
+ },
+}
+
+impl TurnState {
+ fn is_cancelled(&self) -> bool {
+ matches!(self, TurnState::CancelConfirmed { .. })
+ }
+
+ fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
+ match self {
+ TurnState::None => None,
+ TurnState::InProgress { end_tx, .. } => Some(end_tx),
+ TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
+ TurnState::CancelConfirmed { end_tx } => Some(end_tx),
+ }
+ }
+
+ fn confirm_cancellation(self, id: &str) -> Self {
+ match self {
+ TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
+ TurnState::CancelConfirmed { end_tx }
+ }
+ _ => self,
+ }
+ }
}
impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
- end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
- pending_cancellation: Rc<Cell<PendingCancellation>>,
+ turn_state: Rc<RefCell<TurnState>>,
cx: &mut AsyncApp,
) {
match message {
@@ -393,15 +420,13 @@ impl ClaudeAgentSession {
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
- let state = pending_cancellation.take();
- if state != PendingCancellation::Confirmed {
+ if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
})
.log_err();
}
- pending_cancellation.set(state);
}
ContentChunk::ToolResult {
content,
@@ -414,7 +439,12 @@ impl ClaudeAgentSession {
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
- status: Some(acp::ToolCallStatus::Completed),
+ status: if turn_state.borrow().is_cancelled() {
+ // Do not set to completed if turn was cancelled
+ None
+ } else {
+ Some(acp::ToolCallStatus::Completed)
+ },
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
@@ -541,40 +571,38 @@ impl ClaudeAgentSession {
result,
..
} => {
- if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
- if is_error
- || (subtype == ResultErrorType::ErrorDuringExecution
- && pending_cancellation.take() != PendingCancellation::Confirmed)
- {
- end_turn_tx
- .send(Err(anyhow!(
- "Error: {}",
- result.unwrap_or_else(|| subtype.to_string())
- )))
- .ok();
- } else {
- let stop_reason = match subtype {
- ResultErrorType::Success => acp::StopReason::EndTurn,
- ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
- ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
- };
- end_turn_tx
- .send(Ok(acp::PromptResponse { stop_reason }))
- .ok();
- }
+ let turn_state = turn_state.take();
+ let was_cancelled = turn_state.is_cancelled();
+ let Some(end_turn_tx) = turn_state.end_tx() else {
+ debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
+ return;
+ };
+
+ if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
+ {
+ end_turn_tx
+ .send(Err(anyhow!(
+ "Error: {}",
+ result.unwrap_or_else(|| subtype.to_string())
+ )))
+ .ok();
+ } else {
+ let stop_reason = match subtype {
+ ResultErrorType::Success => acp::StopReason::EndTurn,
+ ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
+ ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
+ };
+ end_turn_tx
+ .send(Ok(acp::PromptResponse { stop_reason }))
+ .ok();
}
}
SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) {
- let pending_cancellation_value = pending_cancellation.take();
-
- if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
- && request_id == &response.request_id
- {
- pending_cancellation.set(PendingCancellation::Confirmed);
- } else {
- pending_cancellation.set(pending_cancellation_value);
- }
+ let new_state = turn_state.take().confirm_cancellation(&response.request_id);
+ turn_state.replace(new_state);
+ } else {
+ log::error!("Control response error: {:?}", response);
}
}
SdkMessage::System { .. } => {}
@@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
- let full_turn = thread.update(cx, |thread, cx| {
+ let _ = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
cx,
@@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
id.clone()
});
- let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
- full_turn.await.unwrap();
- thread.read_with(cx, |thread, _| {
+ thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+ thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Canceled,
..