mcp_server.rs

  1use std::{cell::RefCell, path::PathBuf, rc::Rc};
  2
  3use acp_thread::AcpThread;
  4use agent_client_protocol as acp;
  5use anyhow::{Context, Result};
  6use collections::HashMap;
  7use context_server::types::{
  8    CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
  9    ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
 10    ToolResponseContent, ToolsCapabilities, requests,
 11};
 12use gpui::{App, AsyncApp, Task, WeakEntity};
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15
 16use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
 17
/// MCP server hosted by Zed that exposes Zed-backed tools
/// ([`READ_TOOL`], [`EDIT_TOOL`] and [`PERMISSION_TOOL`]).
pub struct ZedMcpServer {
    /// Underlying MCP listener; request handlers are registered on it in
    /// `ZedMcpServer::new`, and its socket path is exposed via `server_config`.
    server: context_server::listener::McpServer,
}
 21
/// Name this server reports in its `initialize` response (`server_info.name`).
pub const SERVER_NAME: &str = "zed";
/// MCP tool name for reading a file through Zed.
pub const READ_TOOL: &str = "Read";
/// MCP tool name for editing a file through Zed.
pub const EDIT_TOOL: &str = "Edit";
/// MCP tool name used to ask the user for permission before a tool call runs.
pub const PERMISSION_TOOL: &str = "Confirmation";
 26
// Arguments of a `PERMISSION_TOOL` call: the tool the client wants to run and
// its input. Plain `//` comments are used on purpose: `schemars` would turn
// `///` doc comments into schema descriptions and change the schema that
// `tools/list` advertises.
#[derive(Deserialize, JsonSchema, Debug)]
struct PermissionToolParams {
    // Name of the tool awaiting permission; passed to `ClaudeTool::infer`.
    tool_name: String,
    // The tool's input; echoed back as `updatedInput` in the response.
    input: serde_json::Value,
    // ID of the pending tool use; the permission handler errors when it is absent.
    tool_use_id: Option<String>,
}
 33
/// Result of a `PERMISSION_TOOL` call. It is serialized to JSON with camelCase
/// keys (`behavior`, `updatedInput`) and returned as the tool's text content.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct PermissionToolResponse {
    /// Whether the requested tool call may proceed.
    behavior: PermissionToolBehavior,
    /// Input the tool should run with; the handler currently echoes the
    /// original input unchanged for both allow and deny.
    updated_input: serde_json::Value,
}
 40
/// Decision reported by the permission tool; serializes as `"allow"` or `"deny"`.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
    /// The user chose the "Allow" option.
    Allow,
    /// The user chose any other option.
    Deny,
}
 47
 48impl ZedMcpServer {
 49    pub async fn new(
 50        thread_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
 51        cx: &AsyncApp,
 52    ) -> Result<Self> {
 53        let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
 54        mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
 55        mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
 56        mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
 57            Self::handle_call_tool(request, thread_map.clone(), cx)
 58        });
 59
 60        Ok(Self { server: mcp_server })
 61    }
 62
 63    pub fn server_config(&self) -> Result<McpServerConfig> {
 64        let zed_path = std::env::current_exe()
 65            .context("finding current executable path for use in mcp_server")?;
 66
 67        Ok(McpServerConfig {
 68            command: zed_path,
 69            args: vec![
 70                "--nc".into(),
 71                self.server.socket_path().display().to_string(),
 72            ],
 73            env: None,
 74        })
 75    }
 76
 77    fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
 78        cx.foreground_executor().spawn(async move {
 79            Ok(InitializeResponse {
 80                protocol_version: ProtocolVersion("2025-06-18".into()),
 81                capabilities: ServerCapabilities {
 82                    experimental: None,
 83                    logging: None,
 84                    completions: None,
 85                    prompts: None,
 86                    resources: None,
 87                    tools: Some(ToolsCapabilities {
 88                        list_changed: Some(false),
 89                    }),
 90                },
 91                server_info: Implementation {
 92                    name: SERVER_NAME.into(),
 93                    version: "0.1.0".into(),
 94                },
 95                meta: None,
 96            })
 97        })
 98    }
 99
100    fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
101        cx.foreground_executor().spawn(async move {
102            Ok(ListToolsResponse {
103                tools: vec![
104                    Tool {
105                        name: PERMISSION_TOOL.into(),
106                        input_schema: schemars::schema_for!(PermissionToolParams).into(),
107                        description: None,
108                        annotations: None,
109                    },
110                    Tool {
111                        name: READ_TOOL.into(),
112                        input_schema: schemars::schema_for!(ReadToolParams).into(),
113                        description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
114                        annotations: Some(ToolAnnotations {
115                            title: Some("Read file".to_string()),
116                            read_only_hint: Some(true),
117                            destructive_hint: Some(false),
118                            open_world_hint: Some(false),
119                            // if time passes the contents might change, but it's not going to do anything different
120                            // true or false seem too strong, let's try a none.
121                            idempotent_hint: None,
122                        }),
123                    },
124                    Tool {
125                        name: EDIT_TOOL.into(),
126                        input_schema: schemars::schema_for!(EditToolParams).into(),
127                        description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
128                        annotations: Some(ToolAnnotations {
129                            title: Some("Edit file".to_string()),
130                            read_only_hint: Some(false),
131                            destructive_hint: Some(false),
132                            open_world_hint: Some(false),
133                            idempotent_hint: Some(false),
134                        }),
135                    },
136                ],
137                next_cursor: None,
138                meta: None,
139            })
140        })
141    }
142
143    fn handle_call_tool(
144        request: CallToolParams,
145        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
146        cx: &App,
147    ) -> Task<Result<CallToolResponse>> {
148        cx.spawn(async move |cx| {
149            if request.name.as_str() == PERMISSION_TOOL {
150                let input =
151                    serde_json::from_value(request.arguments.context("Arguments required")?)?;
152
153                let result = Self::handle_permissions_tool_call(input, threads_map, cx).await?;
154                Ok(CallToolResponse {
155                    content: vec![ToolResponseContent::Text {
156                        text: serde_json::to_string(&result)?,
157                    }],
158                    is_error: None,
159                    meta: None,
160                })
161            } else if request.name.as_str() == READ_TOOL {
162                let input =
163                    serde_json::from_value(request.arguments.context("Arguments required")?)?;
164
165                let content = Self::handle_read_tool_call(input, threads_map, cx).await?;
166                Ok(CallToolResponse {
167                    content,
168                    is_error: None,
169                    meta: None,
170                })
171            } else if request.name.as_str() == EDIT_TOOL {
172                let input =
173                    serde_json::from_value(request.arguments.context("Arguments required")?)?;
174
175                Self::handle_edit_tool_call(input, threads_map, cx).await?;
176                Ok(CallToolResponse {
177                    content: vec![],
178                    is_error: None,
179                    meta: None,
180                })
181            } else {
182                anyhow::bail!("Unsupported tool");
183            }
184        })
185    }
186
187    fn handle_read_tool_call(
188        ReadToolParams {
189            abs_path,
190            offset,
191            limit,
192        }: ReadToolParams,
193        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
194        cx: &AsyncApp,
195    ) -> Task<Result<Vec<ToolResponseContent>>> {
196        cx.spawn(async move |cx| {
197            // todo! get session id somehow
198            let thread = {
199                let threads_map = threads_map.borrow();
200                let Some((_, thread)) = threads_map.iter().next() else {
201                    anyhow::bail!("Server not available");
202                };
203                thread.clone()
204            };
205
206            let content = thread
207                .update(cx, |thread, cx| {
208                    thread.read_text_file(abs_path, offset, limit, false, cx)
209                })?
210                .await?;
211
212            Ok(vec![ToolResponseContent::Text { text: content }])
213        })
214    }
215
216    fn handle_edit_tool_call(
217        params: EditToolParams,
218        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
219        cx: &AsyncApp,
220    ) -> Task<Result<()>> {
221        cx.spawn(async move |cx| {
222            // todo! get session id somehow
223            let thread = {
224                let threads_map = threads_map.borrow();
225                let Some((_, thread)) = threads_map.iter().next() else {
226                    anyhow::bail!("Server not available");
227                };
228                thread.clone()
229            };
230
231            let content = thread
232                .update(cx, |threads, cx| {
233                    threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
234                })?
235                .await?;
236
237            let new_content = content.replace(&params.old_text, &params.new_text);
238            if new_content == content {
239                return Err(anyhow::anyhow!("The old_text was not found in the content"));
240            }
241
242            thread
243                .update(cx, |threads, cx| {
244                    threads.write_text_file(params.abs_path, new_content, cx)
245                })?
246                .await?;
247
248            Ok(())
249        })
250    }
251
252    fn handle_permissions_tool_call(
253        params: PermissionToolParams,
254        threads_map: Rc<RefCell<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
255        cx: &AsyncApp,
256    ) -> Task<Result<PermissionToolResponse>> {
257        cx.spawn(async move |cx| {
258            // todo! get session id somehow
259            let thread = {
260                let threads_map = threads_map.borrow();
261                let Some((_, thread)) = threads_map.iter().next() else {
262                    anyhow::bail!("Server not available");
263                };
264                thread.clone()
265            };
266
267            let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone());
268
269            let tool_call_id =
270                acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into());
271
272            let allow_option_id = acp::PermissionOptionId("allow".into());
273            let reject_option_id = acp::PermissionOptionId("reject".into());
274
275            let chosen_option = thread
276                .update(cx, |thread, cx| {
277                    thread.request_tool_call_permission(
278                        claude_tool.as_acp(tool_call_id),
279                        vec![
280                            acp::PermissionOption {
281                                id: allow_option_id.clone(),
282                                label: "Allow".into(),
283                                kind: acp::PermissionOptionKind::AllowOnce,
284                            },
285                            acp::PermissionOption {
286                                id: reject_option_id,
287                                label: "Reject".into(),
288                                kind: acp::PermissionOptionKind::RejectOnce,
289                            },
290                        ],
291                        cx,
292                    )
293                })?
294                .await?;
295
296            if chosen_option == allow_option_id {
297                Ok(PermissionToolResponse {
298                    behavior: PermissionToolBehavior::Allow,
299                    updated_input: params.input,
300                })
301            } else {
302                Ok(PermissionToolResponse {
303                    behavior: PermissionToolBehavior::Deny,
304                    updated_input: params.input,
305                })
306            }
307        })
308    }
309}
310
/// Serializable MCP configuration: server name mapped to its launch config,
/// emitted under the `mcpServers` key (camelCase).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpConfig {
    /// Launch configuration for each MCP server, keyed by server name.
    pub mcp_servers: HashMap<String, McpServerConfig>,
}
316
/// How to launch an MCP server process: the executable, its arguments, and an
/// optional environment.
#[derive(Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct McpServerConfig {
    /// Path of the executable to run.
    pub command: PathBuf,
    /// Command-line arguments passed to `command`.
    pub args: Vec<String>,
    /// Extra environment variables; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
}