Detailed changes
@@ -965,7 +965,7 @@ impl AcpThread {
}
}
- self.connection.cancel(cx);
+ self.connection.cancel(&self.session_id, cx);
// Wait for the send task to complete
cx.foreground_executor().spawn(send_task)
@@ -23,7 +23,7 @@ pub trait AgentConnection {
fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>>;
- fn cancel(&self, cx: &mut App);
+ fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
}
#[derive(Debug)]
@@ -111,7 +111,7 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
- fn cancel(&self, cx: &mut App) {
+ fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
let task = self
.connection
.request_any(acp_old::CancelSendMessageParams.into_any());
@@ -39,6 +39,7 @@ pub trait AgentServer: Send {
fn connect(
&self,
+ // these will go away when old_acp is fully removed
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
@@ -26,7 +26,7 @@ use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize};
use util::ResultExt;
-use crate::claude::mcp_server::{McpConfig, ZedMcpServer};
+use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentConnection};
@@ -53,15 +53,50 @@ impl AgentServer for ClaudeCode {
fn connect(
&self,
- root_dir: &Path,
- project: &Entity<Project>,
- cx: &mut App,
+ _root_dir: &Path,
+ _project: &Entity<Project>,
+ _cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
- let project = project.clone();
- let root_dir = root_dir.to_path_buf();
+ let connection = ClaudeAgentConnection {
+ sessions: Default::default(),
+ };
+
+ Task::ready(Ok(Rc::new(connection) as _))
+ }
+}
+
+#[cfg(unix)]
+fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
+ let pid = nix::unistd::Pid::from_raw(pid);
+
+ nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
+ .map_err(|e| anyhow!("Failed to interrupt process: {}", e))
+}
+
+#[cfg(windows)]
+fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
+ panic!("Cancel not implemented on Windows")
+}
+
+struct ClaudeAgentConnection {
+ sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
+}
+
+impl AgentConnection for ClaudeAgentConnection {
+ fn name(&self) -> &'static str {
+ ClaudeCode.name()
+ }
+
+ fn new_thread(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ cwd: &Path,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ let cwd = cwd.to_owned();
cx.spawn(async move |cx| {
- let threads_map = Rc::new(RefCell::new(HashMap::default()));
- let permission_mcp_server = ZedMcpServer::new(threads_map.clone(), cx).await?;
+ let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
+ let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
let mut mcp_servers = HashMap::default();
mcp_servers.insert(
@@ -109,7 +144,7 @@ impl AgentServer for ClaudeCode {
mode,
session_id.clone(),
&mcp_config_path,
- &root_dir,
+ &cwd,
)
.await?;
mode = ClaudeSessionMode::Resume;
@@ -118,7 +153,7 @@ impl AgentServer for ClaudeCode {
log::trace!("Spawned (pid: {})", pid);
let mut io_fut = pin!(
- ClaudeAgentConnection::handle_io(
+ ClaudeAgentSession::handle_io(
outgoing_rx.take().unwrap(),
incoming_message_tx.clone(),
child.stdin.take().unwrap(),
@@ -155,11 +190,11 @@ impl AgentServer for ClaudeCode {
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
- let threads_map = threads_map.clone();
+ let thread_rx = thread_rx.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
- ClaudeAgentConnection::handle_message(
- threads_map.clone(),
+ ClaudeAgentSession::handle_message(
+ thread_rx.clone(),
message,
end_turn_tx.clone(),
cx,
@@ -169,56 +204,23 @@ impl AgentServer for ClaudeCode {
}
});
- let connection = ClaudeAgentConnection {
- threads_map,
+ let thread =
+ cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
+
+ thread_tx.send(thread.downgrade())?;
+
+ let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
cancel_tx,
- session_id,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
- Ok(Rc::new(connection) as _)
- })
- }
-}
-
-#[cfg(unix)]
-fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
- let pid = nix::unistd::Pid::from_raw(pid);
+ self.sessions.borrow_mut().insert(session_id, session);
- nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
- .map_err(|e| anyhow!("Failed to interrupt process: {}", e))
-}
-
-#[cfg(windows)]
-fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
- panic!("Cancel not implemented on Windows")
-}
-
-impl AgentConnection for ClaudeAgentConnection {
- fn name(&self) -> &'static str {
- ClaudeCode.name()
- }
-
- fn new_thread(
- self: Rc<Self>,
- project: Entity<Project>,
- _cwd: &Path,
- cx: &mut AsyncApp,
- ) -> Task<Result<Entity<AcpThread>>> {
- let session_id = self.session_id.clone();
- let thread_result =
- cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx));
-
- if let Ok(thread) = &thread_result {
- self.threads_map
- .borrow_mut()
- .insert(session_id, thread.downgrade());
- }
-
- Task::ready(thread_result)
+ Ok(thread)
+ })
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
@@ -226,8 +228,16 @@ impl AgentConnection for ClaudeAgentConnection {
}
fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>> {
+ let sessions = self.sessions.borrow();
+ let Some(session) = sessions.get(¶ms.session_id) else {
+ return Task::ready(Err(anyhow!(
+ "Attempted to send message to nonexistent session {}",
+ params.session_id
+ )));
+ };
+
let (tx, rx) = oneshot::channel();
- self.end_turn_tx.borrow_mut().replace(tx);
+ session.end_turn_tx.borrow_mut().replace(tx);
let mut content = String::new();
for chunk in params.prompt {
@@ -246,7 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
}
}
- if let Err(err) = self.outgoing_tx.unbounded_send(SdkMessage::User {
+ if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
@@ -267,9 +277,20 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
- fn cancel(&self, cx: &mut App) {
+ fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+ let sessions = self.sessions.borrow();
+ let Some(session) = sessions.get(&session_id) else {
+ log::warn!("Attempted to cancel nonexistent session {}", session_id);
+ return;
+ };
+
let (done_tx, done_rx) = oneshot::channel();
- if self.cancel_tx.unbounded_send(done_tx).log_err().is_some() {
+ if session
+ .cancel_tx
+ .unbounded_send(done_tx)
+ .log_err()
+ .is_some()
+ {
cx.foreground_executor()
.spawn(async move { done_rx.await? })
.detach_and_log_err(cx);
@@ -326,19 +347,17 @@ async fn spawn_claude(
Ok(child)
}
-struct ClaudeAgentConnection {
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
- session_id: acp::SessionId,
+struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
- _mcp_server: Option<ZedMcpServer>,
+ _mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
-impl ClaudeAgentConnection {
+impl ClaudeAgentSession {
async fn handle_message(
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
cx: &mut AsyncApp,
@@ -346,20 +365,22 @@ impl ClaudeAgentConnection {
match message {
SdkMessage::Assistant {
message,
- session_id,
+ session_id: _,
}
| SdkMessage::User {
message,
- session_id,
+ session_id: _,
} => {
- let threads_map = threads_map.borrow();
- let Some(thread) = session_id
- .and_then(|session_id| threads_map.get(&acp::SessionId(session_id.into())))
+ let Some(thread) = thread_rx
+ .recv()
+ .await
+ .log_err()
.and_then(|entity| entity.upgrade())
else {
- log::error!("Thread not found for session");
+ log::error!("Received an SDK message but thread is gone");
return;
};
+
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
@@ -1,4 +1,4 @@
-use std::{cell::RefCell, path::PathBuf, rc::Rc};
+use std::path::PathBuf;
use acp_thread::AcpThread;
use agent_client_protocol as acp;
@@ -9,13 +9,13 @@ use context_server::types::{
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
ToolResponseContent, ToolsCapabilities, requests,
};
-use gpui::{App, AsyncApp, Task, WeakEntity};
+use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
-pub struct ZedMcpServer {
+pub struct ClaudeZedMcpServer {
server: context_server::listener::McpServer,
}
@@ -45,16 +45,16 @@ enum PermissionToolBehavior {
Deny,
}
-impl ZedMcpServer {
+impl ClaudeZedMcpServer {
pub async fn new(
- thread_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
cx: &AsyncApp,
) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
- Self::handle_call_tool(request, thread_map.clone(), cx)
+ Self::handle_call_tool(request, thread_rx.clone(), cx)
});
Ok(Self { server: mcp_server })
@@ -142,15 +142,19 @@ impl ZedMcpServer {
fn handle_call_tool(
request: CallToolParams,
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
cx: &App,
) -> Task<Result<CallToolResponse>> {
cx.spawn(async move |cx| {
+ let Some(thread) = thread_rx.recv().await?.upgrade() else {
+ anyhow::bail!("Thread closed");
+ };
+
if request.name.as_str() == PERMISSION_TOOL {
let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?;
- let result = Self::handle_permissions_tool_call(input, threads_map, cx).await?;
+ let result = Self::handle_permissions_tool_call(input, thread, cx).await?;
Ok(CallToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&result)?,
@@ -162,7 +166,7 @@ impl ZedMcpServer {
let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?;
- let content = Self::handle_read_tool_call(input, threads_map, cx).await?;
+ let content = Self::handle_read_tool_call(input, thread, cx).await?;
Ok(CallToolResponse {
content,
is_error: None,
@@ -172,7 +176,7 @@ impl ZedMcpServer {
let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?;
- Self::handle_edit_tool_call(input, threads_map, cx).await?;
+ Self::handle_edit_tool_call(input, thread, cx).await?;
Ok(CallToolResponse {
content: vec![],
is_error: None,
@@ -190,19 +194,10 @@ impl ZedMcpServer {
offset,
limit,
}: ReadToolParams,
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ thread: Entity<AcpThread>,
cx: &AsyncApp,
) -> Task<Result<Vec<ToolResponseContent>>> {
cx.spawn(async move |cx| {
- // todo! get session id somehow
- let thread = {
- let threads_map = threads_map.borrow();
- let Some((_, thread)) = threads_map.iter().next() else {
- anyhow::bail!("Server not available");
- };
- thread.clone()
- };
-
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(abs_path, offset, limit, false, cx)
@@ -215,19 +210,10 @@ impl ZedMcpServer {
fn handle_edit_tool_call(
params: EditToolParams,
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ thread: Entity<AcpThread>,
cx: &AsyncApp,
) -> Task<Result<()>> {
cx.spawn(async move |cx| {
- // todo! get session id somehow
- let thread = {
- let threads_map = threads_map.borrow();
- let Some((_, thread)) = threads_map.iter().next() else {
- anyhow::bail!("Server not available");
- };
- thread.clone()
- };
-
let content = thread
.update(cx, |threads, cx| {
threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
@@ -251,19 +237,10 @@ impl ZedMcpServer {
fn handle_permissions_tool_call(
params: PermissionToolParams,
- threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ thread: Entity<AcpThread>,
cx: &AsyncApp,
) -> Task<Result<PermissionToolResponse>> {
cx.spawn(async move |cx| {
- // todo! get session id somehow
- let thread = {
- let threads_map = threads_map.borrow();
- let Some((_, thread)) = threads_map.iter().next() else {
- anyhow::bail!("Server not available");
- };
- thread.clone()
- };
-
let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone());
let tool_call_id =