Moar fixes

Agus Zubiaga created

Change summary

crates/acp_thread/src/acp_thread.rs    |  84 +++----------
crates/agent_servers/src/claude.rs     |   5 
crates/agent_servers/src/mcp_server.rs | 174 ++++++++++++++-------------
3 files changed, 111 insertions(+), 152 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1001,7 +1001,9 @@ impl AcpThread {
 
     pub fn read_text_file(
         &self,
-        request: acp::ReadTextFileArguments,
+        path: PathBuf,
+        line: Option<u32>,
+        limit: Option<u32>,
         reuse_shared_snapshot: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<String>> {
@@ -1010,7 +1012,7 @@ impl AcpThread {
         cx.spawn(async move |this, cx| {
             let load = project.update(cx, |project, cx| {
                 let path = project
-                    .project_path_for_absolute_path(&request.path, cx)
+                    .project_path_for_absolute_path(&path, cx)
                     .context("invalid path")?;
                 anyhow::Ok(project.open_buffer(path, cx))
             });
@@ -1036,7 +1038,7 @@ impl AcpThread {
                     let position = buffer
                         .read(cx)
                         .snapshot()
-                        .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
+                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
                     project.set_agent_location(
                         Some(AgentLocation {
                             buffer: buffer.downgrade(),
@@ -1052,11 +1054,11 @@ impl AcpThread {
             this.update(cx, |this, _| {
                 let text = snapshot.text();
                 this.shared_buffers.insert(buffer.clone(), snapshot);
-                if request.line.is_none() && request.limit.is_none() {
+                if line.is_none() && limit.is_none() {
                     return Ok(text);
                 }
-                let limit = request.limit.unwrap_or(u32::MAX) as usize;
-                let Some(line) = request.line else {
+                let limit = limit.unwrap_or(u32::MAX) as usize;
+                let Some(line) = line else {
                     return Ok(text.lines().take(limit).collect::<String>());
                 };
 
@@ -1075,7 +1077,8 @@ impl AcpThread {
 
     pub fn write_text_file(
         &self,
-        request: acp::WriteTextFileToolArguments,
+        path: PathBuf,
+        content: String,
         cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
         let project = self.project.clone();
@@ -1083,7 +1086,7 @@ impl AcpThread {
         cx.spawn(async move |this, cx| {
             let load = project.update(cx, |project, cx| {
                 let path = project
-                    .project_path_for_absolute_path(&request.path, cx)
+                    .project_path_for_absolute_path(&path, cx)
                     .context("invalid path")?;
                 anyhow::Ok(project.open_buffer(path, cx))
             });
@@ -1098,7 +1101,7 @@ impl AcpThread {
                 .background_executor()
                 .spawn(async move {
                     let old_text = snapshot.text();
-                    text_diff(old_text.as_str(), &request.content)
+                    text_diff(old_text.as_str(), &content)
                         .into_iter()
                         .map(|(range, replacement)| {
                             (
@@ -1165,43 +1168,8 @@ impl OldAcpClientDelegate {
             next_tool_call_id: Rc::new(RefCell::new(0)),
         }
     }
-
-    pub async fn clear_completed_plan_entries(&self) -> Result<()> {
-        let cx = &mut self.cx.clone();
-        cx.update(|cx| {
-            self.thread
-                .borrow()
-                .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx))
-        })?
-        .context("Failed to update thread")?;
-
-        Ok(())
-    }
-
-    pub async fn read_text_file_reusing_snapshot(
-        &self,
-        request: acp_old::ReadTextFileParams,
-    ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
-        let content = self
-            .cx
-            .update(|cx| {
-                self.thread.borrow().update(cx, |thread, cx| {
-                    thread.read_text_file(
-                        acp::ReadTextFileArguments {
-                            path: request.path,
-                            line: request.line,
-                            limit: request.limit,
-                        },
-                        true,
-                        cx,
-                    )
-                })
-            })?
-            .context("Failed to update thread")?
-            .await?;
-        Ok(acp_old::ReadTextFileResponse { content })
-    }
 }
+
 impl acp_old::Client for OldAcpClientDelegate {
     async fn stream_assistant_message_chunk(
         &self,
@@ -1412,21 +1380,13 @@ impl acp_old::Client for OldAcpClientDelegate {
 
     async fn read_text_file(
         &self,
-        request: acp_old::ReadTextFileParams,
+        acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
     ) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
         let content = self
             .cx
             .update(|cx| {
                 self.thread.borrow().update(cx, |thread, cx| {
-                    thread.read_text_file(
-                        acp::ReadTextFileArguments {
-                            path: request.path,
-                            line: request.line,
-                            limit: request.limit,
-                        },
-                        false,
-                        cx,
-                    )
+                    thread.read_text_file(path, line, limit, false, cx)
                 })
             })?
             .context("Failed to update thread")?
@@ -1436,19 +1396,13 @@ impl acp_old::Client for OldAcpClientDelegate {
 
     async fn write_text_file(
         &self,
-        request: acp_old::WriteTextFileParams,
+        acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
     ) -> Result<(), acp_old::Error> {
         self.cx
             .update(|cx| {
-                self.thread.borrow().update(cx, |thread, cx| {
-                    thread.write_text_file(
-                        acp::WriteTextFileToolArguments {
-                            path: request.path,
-                            content: request.content,
-                        },
-                        cx,
-                    )
-                })
+                self.thread
+                    .borrow()
+                    .update(cx, |thread, cx| thread.write_text_file(path, content, cx))
             })?
             .context("Failed to update thread")?
             .await?;

crates/agent_servers/src/claude.rs 🔗

@@ -65,10 +65,7 @@ impl AgentServer for ClaudeCode {
         let root_dir = root_dir.to_path_buf();
         cx.spawn(async move |cx| {
             let threads_map = Rc::new(RefCell::new(HashMap::default()));
-            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
-
-            let permission_mcp_server =
-                ZedMcpServer::new(threads_map, tool_id_map.clone(), cx).await?;
+            let permission_mcp_server = ZedMcpServer::new(threads_map, cx).await?;
 
             let mut mcp_servers = HashMap::default();
             mcp_servers.insert(

crates/agent_servers/src/mcp_server.rs 🔗

@@ -1,8 +1,9 @@
+// todo! move this back to claude since, it won't share any of the tools with other agents
+
 use std::{cell::RefCell, path::PathBuf, rc::Rc};
 
-use acp_thread::{AcpThread, OldAcpClientDelegate};
-use agent_client_protocol::{self as acp};
-use agentic_coding_protocol::{self as acp_old, Client as _};
+use acp_thread::AcpThread;
+use agent_client_protocol as acp;
 use anyhow::{Context, Result};
 use collections::HashMap;
 use context_server::types::{
@@ -10,16 +11,11 @@ use context_server::types::{
     ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
     ToolResponseContent, ToolsCapabilities, requests,
 };
-use gpui::{App, AsyncApp, Task};
+use gpui::{App, AsyncApp, Task, WeakEntity};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use util::debug_panic;
 
-// todo! use shared tool inference?
-use crate::claude::{
-    McpServerConfig,
-    tools::{ClaudeTool, EditToolParams, ReadToolParams},
-};
+use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
 
 pub struct ZedMcpServer {
     server: context_server::listener::McpServer,
@@ -54,14 +50,13 @@ enum PermissionToolBehavior {
 impl ZedMcpServer {
     pub async fn new(
         thread_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
-        tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
         cx: &AsyncApp,
     ) -> Result<Self> {
         let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
         mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
         mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
         mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
-            Self::handle_call_tool(request, thread_map.clone(), tool_id_map.clone(), cx)
+            Self::handle_call_tool(request, thread_map.clone(), cx)
         });
 
         Ok(Self { server: mcp_server })
@@ -149,22 +144,15 @@ impl ZedMcpServer {
 
     fn handle_call_tool(
         request: CallToolParams,
-        mut delegate_watch: watch::Receiver<Option<OldAcpClientDelegate>>,
-        tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
+        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
         cx: &App,
     ) -> Task<Result<CallToolResponse>> {
         cx.spawn(async move |cx| {
-            let Some(delegate) = delegate_watch.recv().await? else {
-                debug_panic!("Sent None delegate");
-                anyhow::bail!("Server not available");
-            };
-
             if request.name.as_str() == PERMISSION_TOOL {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                let result =
-                    Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?;
+                let result = Self::handle_permissions_tool_call(input, threads_map, cx).await?;
                 Ok(CallToolResponse {
                     content: vec![ToolResponseContent::Text {
                         text: serde_json::to_string(&result)?,
@@ -176,7 +164,7 @@ impl ZedMcpServer {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                let content = Self::handle_read_tool_call(input, delegate, cx).await?;
+                let content = Self::handle_read_tool_call(input, threads_map, cx).await?;
                 Ok(CallToolResponse {
                     content,
                     is_error: None,
@@ -186,7 +174,7 @@ impl ZedMcpServer {
                 let input =
                     serde_json::from_value(request.arguments.context("Arguments required")?)?;
 
-                Self::handle_edit_tool_call(input, delegate, cx).await?;
+                Self::handle_edit_tool_call(input, threads_map, cx).await?;
                 Ok(CallToolResponse {
                     content: vec![],
                     is_error: None,
@@ -199,49 +187,58 @@ impl ZedMcpServer {
     }
 
     fn handle_read_tool_call(
-        params: ReadToolParams,
-        delegate: OldAcpClientDelegate,
+        ReadToolParams {
+            abs_path,
+            offset,
+            limit,
+        }: ReadToolParams,
+        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
         cx: &AsyncApp,
     ) -> Task<Result<Vec<ToolResponseContent>>> {
-        cx.foreground_executor().spawn(async move {
-            let response = delegate
-                .read_text_file(acp_old::ReadTextFileParams {
-                    path: params.abs_path,
-                    line: params.offset,
-                    limit: params.limit,
-                })
+        cx.spawn(async move |cx| {
+            // todo! get session id somehow
+            let threads_map = threads_map.borrow();
+            let Some((_, thread)) = threads_map.iter().next() else {
+                anyhow::bail!("Server not available");
+            };
+
+            let content = thread
+                .update(cx, |thread, cx| {
+                    thread.read_text_file(abs_path, offset, limit, false, cx)
+                })?
                 .await?;
 
-            Ok(vec![ToolResponseContent::Text {
-                text: response.content,
-            }])
+            Ok(vec![ToolResponseContent::Text { text: content }])
         })
     }
 
     fn handle_edit_tool_call(
         params: EditToolParams,
-        delegate: OldAcpClientDelegate,
+        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
         cx: &AsyncApp,
     ) -> Task<Result<()>> {
-        cx.foreground_executor().spawn(async move {
-            let response = delegate
-                .read_text_file_reusing_snapshot(acp_old::ReadTextFileParams {
-                    path: params.abs_path.clone(),
-                    line: None,
-                    limit: None,
-                })
+        cx.spawn(async move |cx| {
+            // todo! get session id somehow
+            let threads_map = threads_map.borrow();
+            let Some((_, thread)) = threads_map.iter().next() else {
+                anyhow::bail!("Server not available");
+            };
+
+            let content = thread
+                .update(cx, |threads, cx| {
+                    threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
+                })?
                 .await?;
 
-            let new_content = response.content.replace(&params.old_text, &params.new_text);
-            if new_content == response.content {
+            let new_content = content.replace(&params.old_text, &params.new_text);
+            if new_content == content {
                 return Err(anyhow::anyhow!("The old_text was not found in the content"));
             }
 
-            delegate
-                .write_text_file(acp_old::WriteTextFileParams {
-                    path: params.abs_path,
-                    content: new_content,
-                })
+            thread
+                .update(cx, |threads, cx| {
+                    threads.write_text_file(params.abs_path, new_content, cx)
+                })?
                 .await?;
 
             Ok(())
@@ -250,45 +247,56 @@ impl ZedMcpServer {
 
     fn handle_permissions_tool_call(
         params: PermissionToolParams,
-        delegate: OldAcpClientDelegate,
-        tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
+        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
         cx: &AsyncApp,
     ) -> Task<Result<PermissionToolResponse>> {
-        cx.foreground_executor().spawn(async move {
+        cx.spawn(async move |cx| {
+            // todo! get session id somehow
+            let threads_map = threads_map.borrow();
+            let Some((_, thread)) = threads_map.iter().next() else {
+                anyhow::bail!("Server not available");
+            };
+
             let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone());
 
-            let tool_call_id = match params.tool_use_id {
-                Some(tool_use_id) => tool_id_map
-                    .borrow()
-                    .get(&tool_use_id)
-                    .cloned()
-                    .context("Tool call ID not found")?,
+            let tool_call_id =
+                acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into());
 
-                None => delegate.push_tool_call(claude_tool.as_acp()).await?.id,
-            };
+            let allow_option_id = acp::PermissionOptionId("allow".into());
+            let reject_option_id = acp::PermissionOptionId("reject".into());
+
+            let chosen_option = thread
+                .update(cx, |thread, cx| {
+                    thread.request_tool_call_permission(
+                        claude_tool.as_acp(tool_call_id),
+                        vec![
+                            acp::PermissionOption {
+                                id: allow_option_id.clone(),
+                                label: "Allow".into(),
+                                kind: acp::PermissionOptionKind::AllowOnce,
+                            },
+                            acp::PermissionOption {
+                                id: reject_option_id,
+                                label: "Reject".into(),
+                                kind: acp::PermissionOptionKind::RejectOnce,
+                            },
+                        ],
+                        cx,
+                    )
+                })?
+                .await?;
 
-            todo!("use regular request_tool_call_confirmation")
-            // let outcome = delegate
-            //     .request_existing_tool_call_confirmation(
-            //         tool_call_id,
-            //         claude_tool.confirmation(None),
-            //     )
-            //     .await?;
-
-            // match outcome {
-            //     acp::ToolCallConfirmationOutcome::Allow
-            //     | acp::ToolCallConfirmationOutcome::AlwaysAllow
-            //     | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
-            //     | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
-            //         behavior: PermissionToolBehavior::Allow,
-            //         updated_input: params.input,
-            //     }),
-            //     acp::ToolCallConfirmationOutcome::Reject
-            //     | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
-            //         behavior: PermissionToolBehavior::Deny,
-            //         updated_input: params.input,
-            //     }),
-            // }
+            if chosen_option == allow_option_id {
+                Ok(PermissionToolResponse {
+                    behavior: PermissionToolBehavior::Allow,
+                    updated_input: params.input,
+                })
+            } else {
+                Ok(PermissionToolResponse {
+                    behavior: PermissionToolBehavior::Deny,
+                    updated_input: params.input,
+                })
+            }
         })
     }
 }