diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index f476941b2e34d648000cf599dfef3e84a094f488..6e01511850c83398061bba41ac14d457bd952cee 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -965,7 +965,7 @@ impl AcpThread { } } - self.connection.cancel(cx); + self.connection.cancel(&self.session_id, cx); // Wait for the send task to complete cx.foreground_executor().spawn(send_task) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index a1c623906a67fb881df4bb17c1cacc3dfba87773..58b96203d1a129103418cbed36d8f13bf7b0d905 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -23,7 +23,7 @@ pub trait AgentConnection { fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task>; - fn cancel(&self, cx: &mut App); + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } #[derive(Debug)] @@ -111,7 +111,7 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn cancel(&self, cx: &mut App) { + fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { let task = self .connection .request_any(acp_old::CancelSendMessageParams.into_any()); diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 63d338fa198b62121b255d1f64cb8d4bb9f5259e..0a0b94b832d23c3a28875c438c8d4eeb918321e8 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -39,6 +39,7 @@ pub trait AgentServer: Send { fn connect( &self, + // these will go away when old_acp is fully removed root_dir: &Path, project: &Entity, cx: &mut App, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index a11419485cce3af57d7c7c477745a9e54a23b141..6b4561de1bae49ba2b2f272f6516f258c63b489c 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -26,7 +26,7 @@ use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::claude::mcp_server::{McpConfig, ZedMcpServer}; +use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; use acp_thread::{AcpThread, AgentConnection}; @@ -53,15 +53,50 @@ impl AgentServer for ClaudeCode { fn connect( &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, ) -> Task>> { - let project = project.clone(); - let root_dir = root_dir.to_path_buf(); + let connection = ClaudeAgentConnection { + sessions: Default::default(), + }; + + Task::ready(Ok(Rc::new(connection) as _)) + } +} + +#[cfg(unix)] +fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { + let pid = nix::unistd::Pid::from_raw(pid); + + nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) + .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) +} + +#[cfg(windows)] +fn send_interrupt(_pid: i32) -> anyhow::Result<()> { + panic!("Cancel not implemented on Windows") +} + +struct ClaudeAgentConnection { + sessions: Rc>>, +} + +impl AgentConnection for ClaudeAgentConnection { + fn name(&self) -> &'static str { + ClaudeCode.name() + } + + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let cwd = cwd.to_owned(); cx.spawn(async move |cx| { - let threads_map = Rc::new(RefCell::new(HashMap::default())); - let permission_mcp_server = ZedMcpServer::new(threads_map.clone(), cx).await?; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( @@ -109,7 +144,7 @@ impl AgentServer for ClaudeCode { mode, session_id.clone(), &mcp_config_path, - &root_dir, + &cwd, ) .await?; mode = ClaudeSessionMode::Resume; @@ -118,7 +153,7 @@ impl AgentServer for ClaudeCode { log::trace!("Spawned (pid: {})", pid); let mut io_fut = pin!( - ClaudeAgentConnection::handle_io( + ClaudeAgentSession::handle_io( outgoing_rx.take().unwrap(), incoming_message_tx.clone(), child.stdin.take().unwrap(), @@ -155,11 +190,11 @@ impl AgentServer for ClaudeCode { let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ let end_turn_tx = end_turn_tx.clone(); - let threads_map = threads_map.clone(); + let thread_rx = thread_rx.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentConnection::handle_message( - threads_map.clone(), + ClaudeAgentSession::handle_message( + thread_rx.clone(), message, end_turn_tx.clone(), cx, @@ -169,56 +204,23 @@ impl AgentServer for ClaudeCode { } }); - let connection = ClaudeAgentConnection { - threads_map, + let thread = + cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; + + thread_tx.send(thread.downgrade())?; + + let session = ClaudeAgentSession { outgoing_tx, end_turn_tx, cancel_tx, - session_id, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; - Ok(Rc::new(connection) as _) - }) - } -} - -#[cfg(unix)] -fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { - let pid = nix::unistd::Pid::from_raw(pid); + self.sessions.borrow_mut().insert(session_id, session); - nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) - .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) -} - -#[cfg(windows)] -fn send_interrupt(_pid: i32) -> anyhow::Result<()> { - panic!("Cancel not implemented on Windows") -} - -impl AgentConnection for ClaudeAgentConnection { - fn name(&self) -> &'static str { - ClaudeCode.name() - } - - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let session_id = self.session_id.clone(); - let thread_result = - cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)); - - if let Ok(thread) = &thread_result { - self.threads_map - .borrow_mut() - .insert(session_id, thread.downgrade()); - } - - Task::ready(thread_result) + Ok(thread) + }) } fn authenticate(&self, _cx: &mut App) -> Task> { @@ -226,8 +228,16 @@ impl AgentConnection for ClaudeAgentConnection { } fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(¶ms.session_id) else { + return Task::ready(Err(anyhow!( + "Attempted to send message to nonexistent session {}", + params.session_id + ))); + }; + let (tx, rx) = oneshot::channel(); - self.end_turn_tx.borrow_mut().replace(tx); + session.end_turn_tx.borrow_mut().replace(tx); let mut content = String::new(); for chunk in params.prompt { @@ -246,7 +256,7 @@ impl AgentConnection for ClaudeAgentConnection { } } - if let Err(err) = self.outgoing_tx.unbounded_send(SdkMessage::User { + if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { message: Message { role: Role::User, content: Content::UntaggedText(content), @@ -267,9 +277,20 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn cancel(&self, cx: &mut App) { + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(&session_id) else { + log::warn!("Attempted to cancel nonexistent session {}", session_id); + return; + }; + let (done_tx, done_rx) = oneshot::channel(); - if self.cancel_tx.unbounded_send(done_tx).log_err().is_some() { + if session + .cancel_tx + .unbounded_send(done_tx) + .log_err() + .is_some() + { cx.foreground_executor() .spawn(async move { done_rx.await? }) .detach_and_log_err(cx); @@ -326,19 +347,17 @@ async fn spawn_claude( Ok(child) } -struct ClaudeAgentConnection { - threads_map: Rc>>>, - session_id: acp::SessionId, +struct ClaudeAgentSession { outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, cancel_tx: UnboundedSender>>, - _mcp_server: Option, + _mcp_server: Option, _handler_task: Task<()>, } -impl ClaudeAgentConnection { +impl ClaudeAgentSession { async fn handle_message( - threads_map: Rc>>>, + mut thread_rx: watch::Receiver>, message: SdkMessage, end_turn_tx: Rc>>>>, cx: &mut AsyncApp, @@ -346,20 +365,22 @@ impl ClaudeAgentConnection { match message { SdkMessage::Assistant { message, - session_id, + session_id: _, } | SdkMessage::User { message, - session_id, + session_id: _, } => { - let threads_map = threads_map.borrow(); - let Some(thread) = session_id - .and_then(|session_id| threads_map.get(&acp::SessionId(session_id.into()))) + let Some(thread) = thread_rx + .recv() + .await + .log_err() .and_then(|entity| entity.upgrade()) else { - log::error!("Thread not found for session"); + log::error!("Received an SDK message but thread is gone"); return; }; + for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 30fde3b9938bef0cbf1234248365270c2c139c15..0a39a02931caaa4100677b41b2daec2f127137ea 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, path::PathBuf, rc::Rc}; +use std::path::PathBuf; use acp_thread::AcpThread; use agent_client_protocol as acp; @@ -9,13 +9,13 @@ use context_server::types::{ ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, }; -use gpui::{App, AsyncApp, Task, WeakEntity}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; -pub struct ZedMcpServer { +pub struct ClaudeZedMcpServer { server: context_server::listener::McpServer, } @@ -45,16 +45,16 @@ enum PermissionToolBehavior { Deny, } -impl ZedMcpServer { +impl ClaudeZedMcpServer { pub async fn new( - thread_map: Rc>>>, + thread_rx: watch::Receiver>, cx: &AsyncApp, ) -> Result { let mut mcp_server = context_server::listener::McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); mcp_server.handle_request::(Self::handle_list_tools); mcp_server.handle_request::(move |request, cx| { - Self::handle_call_tool(request, thread_map.clone(), cx) + Self::handle_call_tool(request, thread_rx.clone(), cx) }); Ok(Self { server: mcp_server }) @@ -142,15 +142,19 @@ impl ZedMcpServer { fn handle_call_tool( request: CallToolParams, - threads_map: Rc>>>, + mut thread_rx: watch::Receiver>, cx: &App, ) -> Task> { cx.spawn(async move |cx| { + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + if request.name.as_str() == PERMISSION_TOOL { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - let result = Self::handle_permissions_tool_call(input, threads_map, cx).await?; + let result = Self::handle_permissions_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content: vec![ToolResponseContent::Text { text: serde_json::to_string(&result)?, @@ -162,7 +166,7 @@ impl ZedMcpServer { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - let content = Self::handle_read_tool_call(input, threads_map, cx).await?; + let content = Self::handle_read_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content, is_error: None, @@ -172,7 +176,7 @@ impl ZedMcpServer { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - Self::handle_edit_tool_call(input, threads_map, cx).await?; + Self::handle_edit_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content: vec![], is_error: None, @@ -190,19 +194,10 @@ impl ZedMcpServer { offset, limit, }: ReadToolParams, - threads_map: Rc>>>, + thread: Entity, cx: &AsyncApp, ) -> Task>> { cx.spawn(async move |cx| { - // todo! get session id somehow - let thread = { - let threads_map = threads_map.borrow(); - let Some((_, thread)) = threads_map.iter().next() else { - anyhow::bail!("Server not available"); - }; - thread.clone() - }; - let content = thread .update(cx, |thread, cx| { thread.read_text_file(abs_path, offset, limit, false, cx) @@ -215,19 +210,10 @@ impl ZedMcpServer { fn handle_edit_tool_call( params: EditToolParams, - threads_map: Rc>>>, + thread: Entity, cx: &AsyncApp, ) -> Task> { cx.spawn(async move |cx| { - // todo! get session id somehow - let thread = { - let threads_map = threads_map.borrow(); - let Some((_, thread)) = threads_map.iter().next() else { - anyhow::bail!("Server not available"); - }; - thread.clone() - }; - let content = thread .update(cx, |threads, cx| { threads.read_text_file(params.abs_path.clone(), None, None, true, cx) @@ -251,19 +237,10 @@ impl ZedMcpServer { fn handle_permissions_tool_call( params: PermissionToolParams, - threads_map: Rc>>>, + thread: Entity, cx: &AsyncApp, ) -> Task> { cx.spawn(async move |cx| { - // todo! get session id somehow - let thread = { - let threads_map = threads_map.borrow(); - let Some((_, thread)) = threads_map.iter().next() else { - anyhow::bail!("Server not available"); - }; - thread.clone() - }; - let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); let tool_call_id =