diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index aa17a80b5b4348cbf752e8dbb004d8f3f060cd2e..be9952fd559ab8da075c404e3d340b023a79c547 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -6,6 +6,7 @@ use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{Bias, MultiBuffer, PathKey}; +use futures::future::{Fuse, FusedFuture}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; @@ -572,7 +573,7 @@ pub struct AcpThread { project: Entity, action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, - send_task: Option>, + send_task: Option>>, connection: Rc, session_id: acp::SessionId, } @@ -662,7 +663,11 @@ impl AcpThread { } pub fn status(&self) -> ThreadStatus { - if self.send_task.is_some() { + if self + .send_task + .as_ref() + .map_or(false, |t| !t.is_terminated()) + { if self.waiting_for_tool_confirmation() { ThreadStatus::WaitingForToolConfirmation } else { @@ -1037,28 +1042,31 @@ impl AcpThread { let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); - self.send_task = Some(cx.spawn(async move |this, cx| { - async { - cancel_task.await; + self.send_task = Some( + cx.spawn(async move |this, cx| { + async { + cancel_task.await; + + let result = this + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptRequest { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) + })? + .await; - let result = this - .update(cx, |this, cx| { - this.connection.prompt( - acp::PromptRequest { - prompt: message, - session_id: this.session_id.clone(), - }, - cx, - ) - })? - .await; - tx.send(result).log_err(); - this.update(cx, |this, _cx| this.send_task.take())?; - anyhow::Ok(()) - } - .await - .log_err(); - })); + tx.send(result).log_err(); + anyhow::Ok(()) + } + .await + .log_err(); + }) + .fuse(), + ); cx.spawn(async move |this, cx| match rx.await { Ok(Err(e)) => { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 59e2e87433cbd596fb5c24f0afd35441a9fce835..09d08fdcf871a89d37fb1e49c44de58f8976740d 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -24,7 +24,7 @@ use futures::{ }; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; -use util::ResultExt; +use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; @@ -153,16 +153,20 @@ impl AgentConnection for ClaudeAgentConnection { }) .detach(); + let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); + let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ let end_turn_tx = end_turn_tx.clone(); let mut thread_rx = thread_rx.clone(); + let cancellation_state = pending_cancellation.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( thread_rx.clone(), message, end_turn_tx.clone(), + cancellation_state.clone(), cx, ) .await @@ -189,6 +193,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, end_turn_tx, + pending_cancellation, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -255,7 +260,12 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - cx.foreground_executor().spawn(async move { rx.await? }) + let cancellation_state = session.pending_cancellation.clone(); + cx.foreground_executor().spawn(async move { + let result = rx.await??; + cancellation_state.set(PendingCancellation::None); + Ok(result) + }) } fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { @@ -265,18 +275,19 @@ impl AgentConnection for ClaudeAgentConnection { return; }; + let request_id = new_request_id(); + + session.pending_cancellation.set(PendingCancellation::Sent { + request_id: request_id.clone(), + }); + session .outgoing_tx - .unbounded_send(SdkMessage::new_interrupt_message()) + .unbounded_send(SdkMessage::ControlRequest { + request_id, + request: ControlRequest::Interrupt, + }) .log_err(); - - if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() { - end_turn_tx - .send(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Cancelled, - })) - .ok(); - } } } @@ -339,25 +350,107 @@ fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, + pending_cancellation: Rc>, _mcp_server: Option, _handler_task: Task<()>, } +#[derive(Debug, Default, PartialEq)] +enum PendingCancellation { + #[default] + None, + Sent { + request_id: String, + }, + Confirmed, +} + impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, end_turn_tx: Rc>>>>, + pending_cancellation: Rc>, cx: &mut AsyncApp, ) { match message { // we should only be sending these out, they don't need to be in the thread SdkMessage::ControlRequest { .. } => {} - SdkMessage::Assistant { + SdkMessage::User { message, session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + let state = pending_cancellation.take(); + if state != PendingCancellation::Confirmed { + thread + .update(cx, |thread, cx| { + thread.push_user_content_block(text.into(), cx) + }) + .log_err(); + } + pending_cancellation.set(state); + } + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + content: (!content.is_empty()) + .then(|| vec![content.into()]), + ..Default::default() + }, + }, + cx, + ) + }) + .log_err(); + } + ContentChunk::Thinking { .. } + | ContentChunk::RedactedThinking + | ContentChunk::ToolUse { .. } => { + debug_panic!( + "Should not get {:?} with role: assistant. should we handle this?", + chunk + ); + } + + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::WebSearchToolResult => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_content_block( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) + .log_err(); + } + } + } } - | SdkMessage::User { + SdkMessage::Assistant { message, session_id: _, } => { @@ -423,31 +516,12 @@ impl ClaudeAgentSession { }) .log_err(); } - ContentChunk::ToolResult { - content, - tool_use_id, - } => { - let content = content.to_string(); - thread - .update(cx, |thread, cx| { - thread.update_tool_call( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use_id.into()), - fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), - content: (!content.is_empty()) - .then(|| vec![content.into()]), - ..Default::default() - }, - }, - cx, - ) - }) - .log_err(); + ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => { + debug_panic!( + "Should not get tool results with role: assistant. should we handle this?" + ); } - ContentChunk::Image - | ContentChunk::Document - | ContentChunk::WebSearchToolResult => { + ContentChunk::Image | ContentChunk::Document => { thread .update(cx, |thread, cx| { thread.push_assistant_content_block( @@ -468,7 +542,10 @@ impl ClaudeAgentSession { .. } => { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error || subtype == ResultErrorType::ErrorDuringExecution { + if is_error + || (subtype == ResultErrorType::ErrorDuringExecution + && pending_cancellation.take() != PendingCancellation::Confirmed) + { end_turn_tx .send(Err(anyhow!( "Error: {}", @@ -479,7 +556,7 @@ impl ClaudeAgentSession { let stop_reason = match subtype { ResultErrorType::Success => acp::StopReason::EndTurn, ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, - ResultErrorType::ErrorDuringExecution => unreachable!(), + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, }; end_turn_tx .send(Ok(acp::PromptResponse { stop_reason })) @@ -487,7 +564,20 @@ impl ClaudeAgentSession { } } } - SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {} + SdkMessage::ControlResponse { response } => { + if matches!(response.subtype, ResultErrorType::Success) { + let pending_cancellation_value = pending_cancellation.take(); + + if let PendingCancellation::Sent { request_id } = &pending_cancellation_value + && request_id == &response.request_id + { + pending_cancellation.set(PendingCancellation::Confirmed); + } else { + pending_cancellation.set(pending_cancellation_value); + } + } + } + SdkMessage::System { .. } => {} } } @@ -728,22 +818,15 @@ impl Display for ResultErrorType { } } -impl SdkMessage { - fn new_interrupt_message() -> Self { - use rand::Rng; - // In the Claude Code TS SDK they just generate a random 12 character string, - // `Math.random().toString(36).substring(2, 15)` - let request_id = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(12) - .map(char::from) - .collect(); - - Self::ControlRequest { - request_id, - request: ControlRequest::Interrupt, - } - } +fn new_request_id() -> String { + use rand::Rng; + // In the Claude Code TS SDK they just generate a random 12 character string, + // `Math.random().toString(36).substring(2, 15)` + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(12) + .map(char::from) + .collect() } #[derive(Debug, Clone, Serialize, Deserialize)]