Move process creation to ClaudeConnection::new_thread

Agus Zubiaga and Ben Brandt created

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/acp_thread/src/acp_thread.rs           |   2 
crates/acp_thread/src/connection.rs           |   4 
crates/agent_servers/src/agent_servers.rs     |   1 
crates/agent_servers/src/claude.rs            | 163 +++++++++++---------
crates/agent_servers/src/claude/mcp_server.rs |  57 ++-----
5 files changed, 113 insertions(+), 114 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -965,7 +965,7 @@ impl AcpThread {
             }
         }
 
-        self.connection.cancel(cx);
+        self.connection.cancel(&self.session_id, cx);
 
         // Wait for the send task to complete
         cx.foreground_executor().spawn(send_task)

crates/acp_thread/src/connection.rs 🔗

@@ -23,7 +23,7 @@ pub trait AgentConnection {
 
     fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>>;
 
-    fn cancel(&self, cx: &mut App);
+    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 }
 
 #[derive(Debug)]
@@ -111,7 +111,7 @@ impl AgentConnection for OldAcpAgentConnection {
         })
     }
 
-    fn cancel(&self, cx: &mut App) {
+    fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
         let task = self
             .connection
             .request_any(acp_old::CancelSendMessageParams.into_any());

crates/agent_servers/src/agent_servers.rs 🔗

@@ -39,6 +39,7 @@ pub trait AgentServer: Send {
 
     fn connect(
         &self,
+        // these will go away when old_acp is fully removed
         root_dir: &Path,
         project: &Entity<Project>,
         cx: &mut App,

crates/agent_servers/src/claude.rs 🔗

@@ -26,7 +26,7 @@ use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use serde::{Deserialize, Serialize};
 use util::ResultExt;
 
-use crate::claude::mcp_server::{McpConfig, ZedMcpServer};
+use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
 use crate::claude::tools::ClaudeTool;
 use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 use acp_thread::{AcpThread, AgentConnection};
@@ -53,15 +53,50 @@ impl AgentServer for ClaudeCode {
 
     fn connect(
         &self,
-        root_dir: &Path,
-        project: &Entity<Project>,
-        cx: &mut App,
+        _root_dir: &Path,
+        _project: &Entity<Project>,
+        _cx: &mut App,
     ) -> Task<Result<Rc<dyn AgentConnection>>> {
-        let project = project.clone();
-        let root_dir = root_dir.to_path_buf();
+        let connection = ClaudeAgentConnection {
+            sessions: Default::default(),
+        };
+
+        Task::ready(Ok(Rc::new(connection) as _))
+    }
+}
+
+#[cfg(unix)]
+fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
+    let pid = nix::unistd::Pid::from_raw(pid);
+
+    nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
+        .map_err(|e| anyhow!("Failed to interrupt process: {}", e))
+}
+
+#[cfg(windows)]
+fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
+    panic!("Cancel not implemented on Windows")
+}
+
+struct ClaudeAgentConnection {
+    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
+}
+
+impl AgentConnection for ClaudeAgentConnection {
+    fn name(&self) -> &'static str {
+        ClaudeCode.name()
+    }
+
+    fn new_thread(
+        self: Rc<Self>,
+        project: Entity<Project>,
+        cwd: &Path,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<Entity<AcpThread>>> {
+        let cwd = cwd.to_owned();
         cx.spawn(async move |cx| {
-            let threads_map = Rc::new(RefCell::new(HashMap::default()));
-            let permission_mcp_server = ZedMcpServer::new(threads_map.clone(), cx).await?;
+            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
+            let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
 
             let mut mcp_servers = HashMap::default();
             mcp_servers.insert(
@@ -109,7 +144,7 @@ impl AgentServer for ClaudeCode {
                             mode,
                             session_id.clone(),
                             &mcp_config_path,
-                            &root_dir,
+                            &cwd,
                         )
                         .await?;
                         mode = ClaudeSessionMode::Resume;
@@ -118,7 +153,7 @@ impl AgentServer for ClaudeCode {
                         log::trace!("Spawned (pid: {})", pid);
 
                         let mut io_fut = pin!(
-                            ClaudeAgentConnection::handle_io(
+                            ClaudeAgentSession::handle_io(
                                 outgoing_rx.take().unwrap(),
                                 incoming_message_tx.clone(),
                                 child.stdin.take().unwrap(),
@@ -155,11 +190,11 @@ impl AgentServer for ClaudeCode {
             let end_turn_tx = Rc::new(RefCell::new(None));
             let handler_task = cx.spawn({
                 let end_turn_tx = end_turn_tx.clone();
-                let threads_map = threads_map.clone();
+                let thread_rx = thread_rx.clone();
                 async move |cx| {
                     while let Some(message) = incoming_message_rx.next().await {
-                        ClaudeAgentConnection::handle_message(
-                            threads_map.clone(),
+                        ClaudeAgentSession::handle_message(
+                            thread_rx.clone(),
                             message,
                             end_turn_tx.clone(),
                             cx,
@@ -169,56 +204,23 @@ impl AgentServer for ClaudeCode {
                 }
             });
 
-            let connection = ClaudeAgentConnection {
-                threads_map,
+            let thread =
+                cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
+
+            thread_tx.send(thread.downgrade())?;
+
+            let session = ClaudeAgentSession {
                 outgoing_tx,
                 end_turn_tx,
                 cancel_tx,
-                session_id,
                 _handler_task: handler_task,
                 _mcp_server: Some(permission_mcp_server),
             };
 
-            Ok(Rc::new(connection) as _)
-        })
-    }
-}
-
-#[cfg(unix)]
-fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
-    let pid = nix::unistd::Pid::from_raw(pid);
+            self.sessions.borrow_mut().insert(session_id, session);
 
-    nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
-        .map_err(|e| anyhow!("Failed to interrupt process: {}", e))
-}
-
-#[cfg(windows)]
-fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
-    panic!("Cancel not implemented on Windows")
-}
-
-impl AgentConnection for ClaudeAgentConnection {
-    fn name(&self) -> &'static str {
-        ClaudeCode.name()
-    }
-
-    fn new_thread(
-        self: Rc<Self>,
-        project: Entity<Project>,
-        _cwd: &Path,
-        cx: &mut AsyncApp,
-    ) -> Task<Result<Entity<AcpThread>>> {
-        let session_id = self.session_id.clone();
-        let thread_result =
-            cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx));
-
-        if let Ok(thread) = &thread_result {
-            self.threads_map
-                .borrow_mut()
-                .insert(session_id, thread.downgrade());
-        }
-
-        Task::ready(thread_result)
+            Ok(thread)
+        })
     }
 
     fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
@@ -226,8 +228,16 @@ impl AgentConnection for ClaudeAgentConnection {
     }
 
     fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>> {
+        let sessions = self.sessions.borrow();
+        let Some(session) = sessions.get(&params.session_id) else {
+            return Task::ready(Err(anyhow!(
+                "Attempted to send message to nonexistent session {}",
+                params.session_id
+            )));
+        };
+
         let (tx, rx) = oneshot::channel();
-        self.end_turn_tx.borrow_mut().replace(tx);
+        session.end_turn_tx.borrow_mut().replace(tx);
 
         let mut content = String::new();
         for chunk in params.prompt {
@@ -246,7 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
             }
         }
 
-        if let Err(err) = self.outgoing_tx.unbounded_send(SdkMessage::User {
+        if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
             message: Message {
                 role: Role::User,
                 content: Content::UntaggedText(content),
@@ -267,9 +277,20 @@ impl AgentConnection for ClaudeAgentConnection {
         })
     }
 
-    fn cancel(&self, cx: &mut App) {
+    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+        let sessions = self.sessions.borrow();
+        let Some(session) = sessions.get(&session_id) else {
+            log::warn!("Attempted to cancel nonexistent session {}", session_id);
+            return;
+        };
+
         let (done_tx, done_rx) = oneshot::channel();
-        if self.cancel_tx.unbounded_send(done_tx).log_err().is_some() {
+        if session
+            .cancel_tx
+            .unbounded_send(done_tx)
+            .log_err()
+            .is_some()
+        {
             cx.foreground_executor()
                 .spawn(async move { done_rx.await? })
                 .detach_and_log_err(cx);
@@ -326,19 +347,17 @@ async fn spawn_claude(
     Ok(child)
 }
 
-struct ClaudeAgentConnection {
-    threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
-    session_id: acp::SessionId,
+struct ClaudeAgentSession {
     outgoing_tx: UnboundedSender<SdkMessage>,
     end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
     cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
-    _mcp_server: Option<ZedMcpServer>,
+    _mcp_server: Option<ClaudeZedMcpServer>,
     _handler_task: Task<()>,
 }
 
-impl ClaudeAgentConnection {
+impl ClaudeAgentSession {
     async fn handle_message(
-        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
         message: SdkMessage,
         end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
         cx: &mut AsyncApp,
@@ -346,20 +365,22 @@ impl ClaudeAgentConnection {
         match message {
             SdkMessage::Assistant {
                 message,
-                session_id,
+                session_id: _,
             }
             | SdkMessage::User {
                 message,
-                session_id,
+                session_id: _,
             } => {
-                let threads_map = threads_map.borrow();
-                let Some(thread) = session_id
-                    .and_then(|session_id| threads_map.get(&acp::SessionId(session_id.into())))
+                let Some(thread) = thread_rx
+                    .recv()
+                    .await
+                    .log_err()
                     .and_then(|entity| entity.upgrade())
                 else {
-                    log::error!("Thread not found for session");
+                    log::error!("Received an SDK message but thread is gone");
                     return;
                 };
+
                 for chunk in message.content.chunks() {
                     match chunk {
                         ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {

crates/agent_servers/src/claude/mcp_server.rs 🔗

@@ -1,4 +1,4 @@
-use std::{cell::RefCell, path::PathBuf, rc::Rc};
+use std::path::PathBuf;
 
 use acp_thread::AcpThread;
 use agent_client_protocol as acp;
@@ -9,13 +9,13 @@ use context_server::types::{
     ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
     ToolResponseContent, ToolsCapabilities, requests,
 };
-use gpui::{App, AsyncApp, Task, WeakEntity};
+use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
 use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
 
-pub struct ZedMcpServer {
+pub struct ClaudeZedMcpServer {
     server: context_server::listener::McpServer,
 }
 
@@ -45,16 +45,16 @@ enum PermissionToolBehavior {
     Deny,
 }
 
-impl ZedMcpServer {
+impl ClaudeZedMcpServer {
     pub async fn new(
-        thread_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
         cx: &AsyncApp,
     ) -> Result<Self> {
         let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
         mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
         mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
         mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
-            Self::handle_call_tool(request, thread_map.clone(), cx)
+            Self::handle_call_tool(request, thread_rx.clone(), cx)
         });
 
         Ok(Self { server: mcp_server })
@@ -142,15 +142,19 @@ impl ZedMcpServer {
 
     fn handle_call_tool(
         request: CallToolParams,
-        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
         cx: &App,
     ) -> Task<Result<CallToolResponse>> {
         cx.spawn(async move |cx| {
+            let Some(thread) = thread_rx.recv().await?.upgrade() else {
+                anyhow::bail!("Thread closed");
+            };
+
             if request.name.as_str() == PERMISSION_TOOL {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                let result = Self::handle_permissions_tool_call(input, threads_map, cx).await?;
+                let result = Self::handle_permissions_tool_call(input, thread, cx).await?;
                 Ok(CallToolResponse {
                     content: vec![ToolResponseContent::Text {
                         text: serde_json::to_string(&result)?,
@@ -162,7 +166,7 @@ impl ZedMcpServer {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                let content = Self::handle_read_tool_call(input, threads_map, cx).await?;
+                let content = Self::handle_read_tool_call(input, thread, cx).await?;
                 Ok(CallToolResponse {
                     content,
                     is_error: None,
@@ -172,7 +176,7 @@ impl ZedMcpServer {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                Self::handle_edit_tool_call(input, threads_map, cx).await?;
+                Self::handle_edit_tool_call(input, thread, cx).await?;
                 Ok(CallToolResponse {
                     content: vec![],
                     is_error: None,
@@ -190,19 +194,10 @@ impl ZedMcpServer {
             offset,
             limit,
         }: ReadToolParams,
-        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        thread: Entity<AcpThread>,
         cx: &AsyncApp,
     ) -> Task<Result<Vec<ToolResponseContent>>> {
         cx.spawn(async move |cx| {
-            // todo! get session id somehow
-            let thread = {
-                let threads_map = threads_map.borrow();
-                let Some((_, thread)) = threads_map.iter().next() else {
-                    anyhow::bail!("Server not available");
-                };
-                thread.clone()
-            };
-
             let content = thread
                 .update(cx, |thread, cx| {
                     thread.read_text_file(abs_path, offset, limit, false, cx)
@@ -215,19 +210,10 @@ impl ZedMcpServer {
 
     fn handle_edit_tool_call(
         params: EditToolParams,
-        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        thread: Entity<AcpThread>,
         cx: &AsyncApp,
     ) -> Task<Result<()>> {
         cx.spawn(async move |cx| {
-            // todo! get session id somehow
-            let thread = {
-                let threads_map = threads_map.borrow();
-                let Some((_, thread)) = threads_map.iter().next() else {
-                    anyhow::bail!("Server not available");
-                };
-                thread.clone()
-            };
-
             let content = thread
                 .update(cx, |threads, cx| {
                     threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
@@ -251,19 +237,10 @@ impl ZedMcpServer {
 
     fn handle_permissions_tool_call(
         params: PermissionToolParams,
-        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        thread: Entity<AcpThread>,
         cx: &AsyncApp,
     ) -> Task<Result<PermissionToolResponse>> {
         cx.spawn(async move |cx| {
-            // todo! get session id somehow
-            let thread = {
-                let threads_map = threads_map.borrow();
-                let Some((_, thread)) = threads_map.iter().next() else {
-                    anyhow::bail!("Server not available");
-                };
-                thread.clone()
-            };
-
             let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone());
 
             let tool_call_id =