mcp_server.rs

  1use std::path::PathBuf;
  2
  3use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
  4use acp_thread::AcpThread;
  5use agent_client_protocol as acp;
  6use agent_settings::AgentSettings;
  7use anyhow::{Context, Result};
  8use collections::HashMap;
  9use context_server::listener::{McpServerTool, ToolResponse};
 10use context_server::types::{
 11    Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
 12    ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
 13};
 14use gpui::{App, AsyncApp, Task, WeakEntity};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::Settings;
 18
/// Wrapper around an MCP listener that serves this module's tools.
///
/// `ClaudeZedMcpServer::new` registers the initialize handler plus the
/// `PermissionTool`, `ReadTool`, and `EditTool` on the wrapped server.
pub struct ClaudeZedMcpServer {
    // The underlying listener; its socket path is handed out by `server_config`.
    server: context_server::listener::McpServer,
}
 22
/// Name reported as `server_info.name` in the MCP initialize response.
// NOTE(review): the tool descriptions in this file mention `mcp__zed__Read` /
// `mcp__zed__Edit`, which presumably embed this name — confirm before renaming.
pub const SERVER_NAME: &str = "zed";
 24
 25impl ClaudeZedMcpServer {
 26    pub async fn new(
 27        thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
 28        cx: &AsyncApp,
 29    ) -> Result<Self> {
 30        let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
 31        mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
 32
 33        mcp_server.add_tool(PermissionTool {
 34            thread_rx: thread_rx.clone(),
 35        });
 36        mcp_server.add_tool(ReadTool {
 37            thread_rx: thread_rx.clone(),
 38        });
 39        mcp_server.add_tool(EditTool {
 40            thread_rx: thread_rx.clone(),
 41        });
 42
 43        Ok(Self { server: mcp_server })
 44    }
 45
 46    pub fn server_config(&self) -> Result<McpServerConfig> {
 47        #[cfg(not(test))]
 48        let zed_path = std::env::current_exe()
 49            .context("finding current executable path for use in mcp_server")?;
 50
 51        #[cfg(test)]
 52        let zed_path = crate::e2e_tests::get_zed_path();
 53
 54        Ok(McpServerConfig {
 55            command: zed_path,
 56            args: vec![
 57                "--nc".into(),
 58                self.server.socket_path().display().to_string(),
 59            ],
 60            env: None,
 61        })
 62    }
 63
 64    fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
 65        cx.foreground_executor().spawn(async move {
 66            Ok(InitializeResponse {
 67                protocol_version: ProtocolVersion("2025-06-18".into()),
 68                capabilities: ServerCapabilities {
 69                    experimental: None,
 70                    logging: None,
 71                    completions: None,
 72                    prompts: None,
 73                    resources: None,
 74                    tools: Some(ToolsCapabilities {
 75                        list_changed: Some(false),
 76                    }),
 77                },
 78                server_info: Implementation {
 79                    name: SERVER_NAME.into(),
 80                    version: "0.1.0".into(),
 81                },
 82                meta: None,
 83            })
 84        })
 85    }
 86}
 87
/// Serializable collection of MCP server launch configurations.
///
/// Because of `rename_all = "camelCase"`, the field is serialized under the
/// key `mcpServers`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpConfig {
    /// Launch configurations keyed by a string.
    // NOTE(review): the key is presumably the server name — confirm at call sites.
    pub mcp_servers: HashMap<String, McpServerConfig>,
}
 93
/// Serializable description of how to launch an MCP server process.
///
/// `ClaudeZedMcpServer::server_config` produces one whose command is the
/// current executable and whose arguments are `--nc` and the listener's
/// socket path.
#[derive(Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct McpServerConfig {
    /// Path of the executable to run.
    pub command: PathBuf,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Optional environment variables; the key is omitted from the serialized
    /// output entirely when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
}
102
103// Tools
104
/// MCP tool (registered under the name `Confirmation`) that decides whether a
/// tool call may proceed, either from settings or by asking through the ACP
/// thread.
#[derive(Clone)]
pub struct PermissionTool {
    // Yields a weak handle to the ACP thread; upgraded each time the tool runs.
    thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
109
// Input of the permission tool.
//
// NOTE: plain `//` comments are used here on purpose. This type derives
// `JsonSchema`, and `///` doc comments would be emitted as `description`
// entries in the generated schema, changing what the tool advertises.
#[derive(Deserialize, JsonSchema, Debug)]
pub struct PermissionToolParams {
    // Name of the tool whose invocation is being authorized.
    tool_name: String,
    // Raw input of that tool call; echoed back unchanged as `updatedInput`.
    input: serde_json::Value,
    // Identifier of the tool call. Optional on the wire, but required whenever
    // a prompt is needed: `run` fails with "Tool ID required" if it is missing.
    tool_use_id: Option<String>,
}
116
/// Result of a permission request, serialized to JSON and returned as the
/// tool's text content.
///
/// Keys are camelCased on the wire: `behavior` and `updatedInput`.
/// `Deserialize` is derived only in test builds so tests can parse the output.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(test, derive(serde::Deserialize))]
pub struct PermissionToolResponse {
    /// Whether the tool call is allowed or denied.
    behavior: PermissionToolBehavior,
    /// The tool input to use; `PermissionTool::run` passes the original input
    /// through unmodified.
    updated_input: serde_json::Value,
}
124
/// Outcome of a permission request.
///
/// Serialized in snake_case, i.e. as `"allow"` or `"deny"`.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(test, derive(serde::Deserialize))]
pub enum PermissionToolBehavior {
    /// The tool call may proceed.
    Allow,
    /// The tool call was rejected.
    Deny,
}
132
133impl McpServerTool for PermissionTool {
134    type Input = PermissionToolParams;
135    type Output = ();
136
137    const NAME: &'static str = "Confirmation";
138
139    fn description(&self) -> &'static str {
140        "Request permission for tool calls"
141    }
142
143    async fn run(
144        &self,
145        input: Self::Input,
146        cx: &mut AsyncApp,
147    ) -> Result<ToolResponse<Self::Output>> {
148        // Check if we should automatically allow tool actions
149        let always_allow =
150            cx.update(|cx| AgentSettings::get_global(cx).always_allow_tool_actions)?;
151
152        if always_allow {
153            // If always_allow_tool_actions is true, immediately return Allow without prompting
154            let response = PermissionToolResponse {
155                behavior: PermissionToolBehavior::Allow,
156                updated_input: input.input,
157            };
158
159            return Ok(ToolResponse {
160                content: vec![ToolResponseContent::Text {
161                    text: serde_json::to_string(&response)?,
162                }],
163                structured_content: (),
164            });
165        }
166
167        // Otherwise, proceed with the normal permission flow
168        let mut thread_rx = self.thread_rx.clone();
169        let Some(thread) = thread_rx.recv().await?.upgrade() else {
170            anyhow::bail!("Thread closed");
171        };
172
173        let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
174        let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
175        let allow_option_id = acp::PermissionOptionId("allow".into());
176        let reject_option_id = acp::PermissionOptionId("reject".into());
177
178        let chosen_option = thread
179            .update(cx, |thread, cx| {
180                thread.request_tool_call_permission(
181                    claude_tool.as_acp(tool_call_id),
182                    vec![
183                        acp::PermissionOption {
184                            id: allow_option_id.clone(),
185                            label: "Allow".into(),
186                            kind: acp::PermissionOptionKind::AllowOnce,
187                        },
188                        acp::PermissionOption {
189                            id: reject_option_id.clone(),
190                            label: "Reject".into(),
191                            kind: acp::PermissionOptionKind::RejectOnce,
192                        },
193                    ],
194                    cx,
195                )
196            })?
197            .await?;
198
199        let response = if chosen_option == allow_option_id {
200            PermissionToolResponse {
201                behavior: PermissionToolBehavior::Allow,
202                updated_input: input.input,
203            }
204        } else {
205            debug_assert_eq!(chosen_option, reject_option_id);
206            PermissionToolResponse {
207                behavior: PermissionToolBehavior::Deny,
208                updated_input: input.input,
209            }
210        };
211
212        Ok(ToolResponse {
213            content: vec![ToolResponseContent::Text {
214                text: serde_json::to_string(&response)?,
215            }],
216            structured_content: (),
217        })
218    }
219}
220
/// MCP tool (registered under the name `Read`) that reads file contents
/// through the ACP thread.
#[derive(Clone)]
pub struct ReadTool {
    // Yields a weak handle to the ACP thread; upgraded each time the tool runs.
    thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
225
226impl McpServerTool for ReadTool {
227    type Input = ReadToolParams;
228    type Output = ();
229
230    const NAME: &'static str = "Read";
231
232    fn description(&self) -> &'static str {
233        "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
234    }
235
236    fn annotations(&self) -> ToolAnnotations {
237        ToolAnnotations {
238            title: Some("Read file".to_string()),
239            read_only_hint: Some(true),
240            destructive_hint: Some(false),
241            open_world_hint: Some(false),
242            idempotent_hint: None,
243        }
244    }
245
246    async fn run(
247        &self,
248        input: Self::Input,
249        cx: &mut AsyncApp,
250    ) -> Result<ToolResponse<Self::Output>> {
251        let mut thread_rx = self.thread_rx.clone();
252        let Some(thread) = thread_rx.recv().await?.upgrade() else {
253            anyhow::bail!("Thread closed");
254        };
255
256        let content = thread
257            .update(cx, |thread, cx| {
258                thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
259            })?
260            .await?;
261
262        Ok(ToolResponse {
263            content: vec![ToolResponseContent::Text { text: content }],
264            structured_content: (),
265        })
266    }
267}
268
/// MCP tool (registered under the name `Edit`) that rewrites file contents
/// through the ACP thread by replacing one piece of text with another.
#[derive(Clone)]
pub struct EditTool {
    // Yields a weak handle to the ACP thread; upgraded each time the tool runs.
    thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
273
274impl McpServerTool for EditTool {
275    type Input = EditToolParams;
276    type Output = ();
277
278    const NAME: &'static str = "Edit";
279
280    fn description(&self) -> &'static str {
281        "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
282    }
283
284    fn annotations(&self) -> ToolAnnotations {
285        ToolAnnotations {
286            title: Some("Edit file".to_string()),
287            read_only_hint: Some(false),
288            destructive_hint: Some(false),
289            open_world_hint: Some(false),
290            idempotent_hint: Some(false),
291        }
292    }
293
294    async fn run(
295        &self,
296        input: Self::Input,
297        cx: &mut AsyncApp,
298    ) -> Result<ToolResponse<Self::Output>> {
299        let mut thread_rx = self.thread_rx.clone();
300        let Some(thread) = thread_rx.recv().await?.upgrade() else {
301            anyhow::bail!("Thread closed");
302        };
303
304        let content = thread
305            .update(cx, |thread, cx| {
306                thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
307            })?
308            .await?;
309
310        let new_content = content.replace(&input.old_text, &input.new_text);
311        if new_content == content {
312            return Err(anyhow::anyhow!("The old_text was not found in the content"));
313        }
314
315        thread
316            .update(cx, |thread, cx| {
317                thread.write_text_file(input.abs_path, new_content, cx)
318            })?
319            .await?;
320
321        Ok(ToolResponse {
322            content: vec![],
323            structured_content: (),
324        })
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use project::Project;
    use settings::{Settings, SettingsStore};

    /// With `always_allow_tool_actions` enabled, the permission tool must
    /// answer `Allow` immediately, without prompting through the thread.
    #[gpui::test]
    async fn test_permission_tool_respects_always_allow_setting(cx: &mut TestAppContext) {
        // Initialize settings
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            agent_settings::init(cx);
        });

        // Create a test thread
        let project = cx.update(|cx| gpui::Entity::new(cx, |_cx| Project::local()));
        let thread = cx.update(|cx| {
            gpui::Entity::new(cx, |_cx| {
                acp_thread::AcpThread::new(
                    acp::ConnectionId("test".into()),
                    project,
                    std::path::Path::new("/tmp"),
                )
            })
        });

        // The sender is never written to; binding it as `_tx` keeps it alive
        // until the end of the test without an unused-variable warning.
        let (_tx, rx) = watch::channel(thread.downgrade());
        let tool = PermissionTool { thread_rx: rx };

        // Test with always_allow_tool_actions = true
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    always_allow_tool_actions: true,
                    ..Default::default()
                },
                cx,
            );
        });

        let input = PermissionToolParams {
            tool_name: "test_tool".to_string(),
            input: serde_json::json!({"test": "data"}),
            tool_use_id: Some("test_id".to_string()),
        };

        // `PermissionToolParams` does not implement `Clone`, and the value is
        // not needed afterwards, so it is moved into `run`.
        let result = tool.run(input, &mut cx.to_async()).await.unwrap();

        // Should return Allow without prompting
        assert_eq!(result.content.len(), 1);
        if let ToolResponseContent::Text { text } = &result.content[0] {
            let response: PermissionToolResponse = serde_json::from_str(text).unwrap();
            assert!(matches!(response.behavior, PermissionToolBehavior::Allow));
        } else {
            panic!("Expected text response");
        }

        // Test with always_allow_tool_actions = false
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    always_allow_tool_actions: false,
                    ..Default::default()
                },
                cx,
            );
        });

        // TODO: the prompting path is not exercised yet. Doing so requires
        // mocking the permission prompt response; in the real scenario the
        // tool would wait for user input.
    }
}