@@ -1,49 +1,24 @@
use std::path::PathBuf;
+use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use anyhow::{Context, Result};
use collections::HashMap;
+use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{
- CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
- ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
- ToolResponseContent, ToolsCapabilities, requests,
+ Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
+ ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
};
-use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
+use gpui::{App, AsyncApp, Task, WeakEntity};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
-
pub struct ClaudeZedMcpServer {
server: context_server::listener::McpServer,
}
pub const SERVER_NAME: &str = "zed";
-pub const READ_TOOL: &str = "Read";
-pub const EDIT_TOOL: &str = "Edit";
-pub const PERMISSION_TOOL: &str = "Confirmation";
-
-#[derive(Deserialize, JsonSchema, Debug)]
-struct PermissionToolParams {
- tool_name: String,
- input: serde_json::Value,
- tool_use_id: Option<String>,
-}
-
-#[derive(Serialize)]
-#[serde(rename_all = "camelCase")]
-struct PermissionToolResponse {
- behavior: PermissionToolBehavior,
- updated_input: serde_json::Value,
-}
-
-#[derive(Serialize)]
-#[serde(rename_all = "snake_case")]
-enum PermissionToolBehavior {
- Allow,
- Deny,
-}
impl ClaudeZedMcpServer {
pub async fn new(
@@ -52,9 +27,15 @@ impl ClaudeZedMcpServer {
) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
- mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
- mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
- Self::handle_call_tool(request, thread_rx.clone(), cx)
+
+ mcp_server.add_tool(PermissionTool {
+ thread_rx: thread_rx.clone(),
+ });
+ mcp_server.add_tool(ReadTool {
+ thread_rx: thread_rx.clone(),
+ });
+ mcp_server.add_tool(EditTool {
+ thread_rx: thread_rx.clone(),
});
Ok(Self { server: mcp_server })
@@ -96,206 +77,203 @@ impl ClaudeZedMcpServer {
})
})
}
+}
- fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
- cx.foreground_executor().spawn(async move {
- Ok(ListToolsResponse {
- tools: vec![
- Tool {
- name: PERMISSION_TOOL.into(),
- input_schema: schemars::schema_for!(PermissionToolParams).into(),
- description: None,
- annotations: None,
- },
- Tool {
- name: READ_TOOL.into(),
- input_schema: schemars::schema_for!(ReadToolParams).into(),
- description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
- annotations: Some(ToolAnnotations {
- title: Some("Read file".to_string()),
- read_only_hint: Some(true),
- destructive_hint: Some(false),
- open_world_hint: Some(false),
- // if time passes the contents might change, but it's not going to do anything different
- // true or false seem too strong, let's try a none.
- idempotent_hint: None,
- }),
- },
- Tool {
- name: EDIT_TOOL.into(),
- input_schema: schemars::schema_for!(EditToolParams).into(),
- description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
- annotations: Some(ToolAnnotations {
- title: Some("Edit file".to_string()),
- read_only_hint: Some(false),
- destructive_hint: Some(false),
- open_world_hint: Some(false),
- idempotent_hint: Some(false),
- }),
- },
- ],
- next_cursor: None,
- meta: None,
- })
- })
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfig {
+ pub mcp_servers: HashMap<String, McpServerConfig>,
+}
+
+#[derive(Serialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServerConfig {
+ pub command: PathBuf,
+ pub args: Vec<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub env: Option<HashMap<String, String>>,
+}
+
+// Tools
+
+#[derive(Clone)]
+pub struct PermissionTool {
+ thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
+}
+
+#[derive(Deserialize, JsonSchema, Debug)]
+pub struct PermissionToolParams {
+ tool_name: String,
+ input: serde_json::Value,
+ tool_use_id: Option<String>,
+}
+
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionToolResponse {
+ behavior: PermissionToolBehavior,
+ updated_input: serde_json::Value,
+}
+
+#[derive(Serialize)]
+#[serde(rename_all = "snake_case")]
+enum PermissionToolBehavior {
+ Allow,
+ Deny,
+}
+
+impl McpServerTool for PermissionTool {
+ type Input = PermissionToolParams;
+ const NAME: &'static str = "Confirmation";
+
+ fn description(&self) -> &'static str {
+ "Request permission for tool calls"
}
- fn handle_call_tool(
- request: CallToolParams,
- mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
- cx: &App,
- ) -> Task<Result<CallToolResponse>> {
- cx.spawn(async move |cx| {
- let Some(thread) = thread_rx.recv().await?.upgrade() else {
- anyhow::bail!("Thread closed");
- };
-
- if request.name.as_str() == PERMISSION_TOOL {
- let input =
- serde_json::from_value(request.arguments.context("Arguments required")?)?;
-
- let result = Self::handle_permissions_tool_call(input, thread, cx).await?;
- Ok(CallToolResponse {
- content: vec![ToolResponseContent::Text {
- text: serde_json::to_string(&result)?,
- }],
- is_error: None,
- meta: None,
- })
- } else if request.name.as_str() == READ_TOOL {
- let input =
- serde_json::from_value(request.arguments.context("Arguments required")?)?;
-
- let content = Self::handle_read_tool_call(input, thread, cx).await?;
- Ok(CallToolResponse {
- content,
- is_error: None,
- meta: None,
- })
- } else if request.name.as_str() == EDIT_TOOL {
- let input =
- serde_json::from_value(request.arguments.context("Arguments required")?)?;
-
- Self::handle_edit_tool_call(input, thread, cx).await?;
- Ok(CallToolResponse {
- content: vec![],
- is_error: None,
- meta: None,
- })
- } else {
- anyhow::bail!("Unsupported tool");
+ async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+ let mut thread_rx = self.thread_rx.clone();
+ let Some(thread) = thread_rx.recv().await?.upgrade() else {
+ anyhow::bail!("Thread closed");
+ };
+
+ let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
+ let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
+ let allow_option_id = acp::PermissionOptionId("allow".into());
+ let reject_option_id = acp::PermissionOptionId("reject".into());
+
+ let chosen_option = thread
+ .update(cx, |thread, cx| {
+ thread.request_tool_call_permission(
+ claude_tool.as_acp(tool_call_id),
+ vec![
+ acp::PermissionOption {
+ id: allow_option_id.clone(),
+ label: "Allow".into(),
+ kind: acp::PermissionOptionKind::AllowOnce,
+ },
+ acp::PermissionOption {
+ id: reject_option_id.clone(),
+ label: "Reject".into(),
+ kind: acp::PermissionOptionKind::RejectOnce,
+ },
+ ],
+ cx,
+ )
+ })?
+ .await?;
+
+ let response = if chosen_option == allow_option_id {
+ PermissionToolResponse {
+ behavior: PermissionToolBehavior::Allow,
+ updated_input: input.input,
}
- })
- }
+ } else {
+ PermissionToolResponse {
+ behavior: PermissionToolBehavior::Deny,
+ updated_input: input.input,
+ }
+ };
- fn handle_read_tool_call(
- ReadToolParams {
- abs_path,
- offset,
- limit,
- }: ReadToolParams,
- thread: Entity<AcpThread>,
- cx: &AsyncApp,
- ) -> Task<Result<Vec<ToolResponseContent>>> {
- cx.spawn(async move |cx| {
- let content = thread
- .update(cx, |thread, cx| {
- thread.read_text_file(abs_path, offset, limit, false, cx)
- })?
- .await?;
-
- Ok(vec![ToolResponseContent::Text { text: content }])
+ Ok(ToolResponse {
+ content: vec![ToolResponseContent::Text {
+ text: serde_json::to_string(&response)?,
+ }],
+ structured_content: None,
})
}
+}
- fn handle_edit_tool_call(
- params: EditToolParams,
- thread: Entity<AcpThread>,
- cx: &AsyncApp,
- ) -> Task<Result<()>> {
- cx.spawn(async move |cx| {
- let content = thread
- .update(cx, |threads, cx| {
- threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
- })?
- .await?;
-
- let new_content = content.replace(¶ms.old_text, ¶ms.new_text);
- if new_content == content {
- return Err(anyhow::anyhow!("The old_text was not found in the content"));
- }
+#[derive(Clone)]
+pub struct ReadTool {
+ thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
+}
- thread
- .update(cx, |threads, cx| {
- threads.write_text_file(params.abs_path, new_content, cx)
- })?
- .await?;
+impl McpServerTool for ReadTool {
+ type Input = ReadToolParams;
+ const NAME: &'static str = "Read";
- Ok(())
- })
+ fn description(&self) -> &'static str {
+ "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
}
- fn handle_permissions_tool_call(
- params: PermissionToolParams,
- thread: Entity<AcpThread>,
- cx: &AsyncApp,
- ) -> Task<Result<PermissionToolResponse>> {
- cx.spawn(async move |cx| {
- let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone());
-
- let tool_call_id =
- acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into());
-
- let allow_option_id = acp::PermissionOptionId("allow".into());
- let reject_option_id = acp::PermissionOptionId("reject".into());
-
- let chosen_option = thread
- .update(cx, |thread, cx| {
- thread.request_tool_call_permission(
- claude_tool.as_acp(tool_call_id),
- vec![
- acp::PermissionOption {
- id: allow_option_id.clone(),
- label: "Allow".into(),
- kind: acp::PermissionOptionKind::AllowOnce,
- },
- acp::PermissionOption {
- id: reject_option_id,
- label: "Reject".into(),
- kind: acp::PermissionOptionKind::RejectOnce,
- },
- ],
- cx,
- )
- })?
- .await?;
-
- if chosen_option == allow_option_id {
- Ok(PermissionToolResponse {
- behavior: PermissionToolBehavior::Allow,
- updated_input: params.input,
- })
- } else {
- Ok(PermissionToolResponse {
- behavior: PermissionToolBehavior::Deny,
- updated_input: params.input,
- })
- }
+ fn annotations(&self) -> ToolAnnotations {
+ ToolAnnotations {
+ title: Some("Read file".to_string()),
+ read_only_hint: Some(true),
+ destructive_hint: Some(false),
+ open_world_hint: Some(false),
+ idempotent_hint: None,
+ }
+ }
+
+ async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+ let mut thread_rx = self.thread_rx.clone();
+ let Some(thread) = thread_rx.recv().await?.upgrade() else {
+ anyhow::bail!("Thread closed");
+ };
+
+ let content = thread
+ .update(cx, |thread, cx| {
+ thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
+ })?
+ .await?;
+
+ Ok(ToolResponse {
+ content: vec![ToolResponseContent::Text { text: content }],
+ structured_content: None,
})
}
}
-#[derive(Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct McpConfig {
- pub mcp_servers: HashMap<String, McpServerConfig>,
+#[derive(Clone)]
+pub struct EditTool {
+ thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
-#[derive(Serialize, Clone)]
-#[serde(rename_all = "camelCase")]
-pub struct McpServerConfig {
- pub command: PathBuf,
- pub args: Vec<String>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub env: Option<HashMap<String, String>>,
+impl McpServerTool for EditTool {
+ type Input = EditToolParams;
+ const NAME: &'static str = "Edit";
+
+ fn description(&self) -> &'static str {
+ "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
+ }
+
+ fn annotations(&self) -> ToolAnnotations {
+ ToolAnnotations {
+ title: Some("Edit file".to_string()),
+ read_only_hint: Some(false),
+ destructive_hint: Some(false),
+ open_world_hint: Some(false),
+ idempotent_hint: Some(false),
+ }
+ }
+
+ async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+ let mut thread_rx = self.thread_rx.clone();
+ let Some(thread) = thread_rx.recv().await?.upgrade() else {
+ anyhow::bail!("Thread closed");
+ };
+
+ let content = thread
+ .update(cx, |thread, cx| {
+ thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
+ })?
+ .await?;
+
+ let new_content = content.replace(&input.old_text, &input.new_text);
+ if new_content == content {
+ return Err(anyhow::anyhow!("The old_text was not found in the content"));
+ }
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.write_text_file(input.abs_path, new_content, cx)
+ })?
+ .await?;
+
+ Ok(ToolResponse {
+ content: vec![],
+ structured_content: None,
+ })
+ }
}
@@ -9,6 +9,8 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Task};
use net::async_net::{UnixListener, UnixStream};
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
use serde_json::{json, value::RawValue};
use smol::stream::StreamExt;
use std::{
@@ -20,16 +22,28 @@ use util::ResultExt;
use crate::{
client::{CspResult, RequestId, Response},
- types::Request,
+ types::{
+ CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
+ ToolResponseContent,
+ requests::{CallTool, ListTools},
+ },
};
pub struct McpServer {
socket_path: PathBuf,
- handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
+ tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
+ handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
_server_task: Task<()>,
}
-type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
+struct RegisteredTool {
+ tool: Tool,
+ handler: ToolHandler,
+}
+
+type ToolHandler =
+ Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
+type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
impl McpServer {
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
@@ -43,12 +57,14 @@ impl McpServer {
cx.spawn(async move |cx| {
let (temp_dir, socket_path, listener) = task.await?;
+ let tools = Rc::new(RefCell::new(HashMap::default()));
let handlers = Rc::new(RefCell::new(HashMap::default()));
let server_task = cx.spawn({
+ let tools = tools.clone();
let handlers = handlers.clone();
async move |cx| {
while let Ok((stream, _)) = listener.accept().await {
- Self::serve_connection(stream, handlers.clone(), cx);
+ Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
}
drop(temp_dir)
}
@@ -56,11 +72,40 @@ impl McpServer {
Ok(Self {
socket_path,
_server_task: server_task,
- handlers: handlers.clone(),
+ tools,
+ handlers: handlers,
})
})
}
+ pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
+ let registered_tool = RegisteredTool {
+ tool: Tool {
+ name: T::NAME.into(),
+ description: Some(tool.description().into()),
+ input_schema: schemars::schema_for!(T::Input).into(),
+ annotations: Some(tool.annotations()),
+ },
+ handler: Box::new({
+ let tool = tool.clone();
+ move |input_value, cx| {
+ let input = match input_value {
+ Some(input) => serde_json::from_value(input),
+ None => serde_json::from_value(serde_json::Value::Null),
+ };
+
+ let tool = tool.clone();
+ match input {
+ Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
+ Err(err) => Task::ready(Err(err.into())),
+ }
+ }
+ }),
+ };
+
+ self.tools.borrow_mut().insert(T::NAME, registered_tool);
+ }
+
pub fn handle_request<R: Request>(
&mut self,
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
@@ -120,7 +165,8 @@ impl McpServer {
fn serve_connection(
stream: UnixStream,
- handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
+ tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
+ handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
cx: &mut AsyncApp,
) {
let (read, write) = smol::io::split(stream);
@@ -135,7 +181,13 @@ impl McpServer {
let Some(request_id) = request.id.clone() else {
continue;
};
- if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
+
+ if request.method == CallTool::METHOD {
+ Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
+ .await;
+ } else if request.method == ListTools::METHOD {
+ Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
+ } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
if let Some(task) = cx
@@ -149,25 +201,122 @@ impl McpServer {
.detach();
}
} else {
- outgoing_tx
- .unbounded_send(
- serde_json::to_string(&Response::<()> {
- jsonrpc: "2.0",
- id: request.id.unwrap(),
- value: CspResult::Error(Some(crate::client::Error {
- message: format!("unhandled method {}", request.method),
- code: -32601,
- })),
- })
- .unwrap(),
- )
- .ok();
+ Self::send_err(
+ request_id,
+ format!("unhandled method {}", request.method),
+ &outgoing_tx,
+ );
}
}
})
.detach();
}
+ fn handle_list_tools(
+ request_id: RequestId,
+ tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
+ outgoing_tx: &UnboundedSender<String>,
+ ) {
+ let response = ListToolsResponse {
+ tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
+ next_cursor: None,
+ meta: None,
+ };
+
+ outgoing_tx
+ .unbounded_send(
+ serde_json::to_string(&Response {
+ jsonrpc: "2.0",
+ id: request_id,
+ value: CspResult::Ok(Some(response)),
+ })
+ .unwrap_or_default(),
+ )
+ .ok();
+ }
+
+ async fn handle_call_tool(
+ request_id: RequestId,
+ params: Option<Box<RawValue>>,
+ tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
+ outgoing_tx: &UnboundedSender<String>,
+ cx: &mut AsyncApp,
+ ) {
+ let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
+ Some(params) => serde_json::from_str(params.get()),
+ None => serde_json::from_value(serde_json::Value::Null),
+ };
+
+ match result {
+ Ok(params) => {
+ if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
+ let outgoing_tx = outgoing_tx.clone();
+
+ let task = (tool.handler)(params.arguments, cx);
+ cx.spawn(async move |_| {
+ let response = match task.await {
+ Ok(result) => CallToolResponse {
+ content: result.content,
+ is_error: Some(false),
+ meta: None,
+ structured_content: result.structured_content,
+ },
+ Err(err) => CallToolResponse {
+ content: vec![ToolResponseContent::Text {
+ text: err.to_string(),
+ }],
+ is_error: Some(true),
+ meta: None,
+ structured_content: None,
+ },
+ };
+
+ outgoing_tx
+ .unbounded_send(
+ serde_json::to_string(&Response {
+ jsonrpc: "2.0",
+ id: request_id,
+ value: CspResult::Ok(Some(response)),
+ })
+ .unwrap_or_default(),
+ )
+ .ok();
+ })
+ .detach();
+ } else {
+ Self::send_err(
+ request_id,
+ format!("Tool not found: {}", params.name),
+ &outgoing_tx,
+ );
+ }
+ }
+ Err(err) => {
+ Self::send_err(request_id, err.to_string(), &outgoing_tx);
+ }
+ }
+ }
+
+ fn send_err(
+ request_id: RequestId,
+ message: impl Into<String>,
+ outgoing_tx: &UnboundedSender<String>,
+ ) {
+ outgoing_tx
+ .unbounded_send(
+ serde_json::to_string(&Response::<()> {
+ jsonrpc: "2.0",
+ id: request_id,
+ value: CspResult::Error(Some(crate::client::Error {
+ message: message.into(),
+ code: -32601,
+ })),
+ })
+ .unwrap(),
+ )
+ .ok();
+ }
+
async fn handle_io(
mut outgoing_rx: UnboundedReceiver<String>,
incoming_tx: UnboundedSender<RawRequest>,
@@ -216,6 +365,34 @@ impl McpServer {
}
}
+pub trait McpServerTool {
+ type Input: DeserializeOwned + JsonSchema;
+ const NAME: &'static str;
+
+ fn description(&self) -> &'static str;
+
+ fn annotations(&self) -> ToolAnnotations {
+ ToolAnnotations {
+ title: None,
+ read_only_hint: None,
+ destructive_hint: None,
+ idempotent_hint: None,
+ open_world_hint: None,
+ }
+ }
+
+ fn run(
+ &self,
+ input: Self::Input,
+ cx: &mut AsyncApp,
+ ) -> impl Future<Output = Result<ToolResponse>>;
+}
+
+pub struct ToolResponse {
+ pub content: Vec<ToolResponseContent>,
+ pub structured_content: Option<serde_json::Value>,
+}
+
#[derive(Serialize, Deserialize)]
struct RawRequest {
#[serde(skip_serializing_if = "Option::is_none")]