1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use schemars::JsonSchema;
13use serde::de::DeserializeOwned;
14use serde_json::{json, value::RawValue};
15use smol::stream::StreamExt;
16use std::{
17 cell::RefCell,
18 path::{Path, PathBuf},
19 rc::Rc,
20};
21use util::ResultExt;
22
23use crate::{
24 client::{CspResult, RequestId, Response},
25 types::{
26 CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
27 ToolResponseContent,
28 requests::{CallTool, ListTools},
29 },
30};
31
/// An MCP server speaking newline-delimited JSON-RPC over a Unix domain socket.
///
/// The tool and handler registries are shared (via `Rc`) with every connection and are
/// looked up per request, so registrations made after a client connects are visible to it.
pub struct McpServer {
    /// Path of the listening socket; give this to clients so they can connect.
    socket_path: PathBuf,
    /// Registered tools, keyed by `McpServerTool::NAME`.
    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
    /// Handlers for additional JSON-RPC methods, keyed by `Request::METHOD`.
    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
    /// The accept loop task, owned by the server. NOTE(review): presumably dropping the
    /// server cancels this task and stops accepting connections — confirm gpui `Task` semantics.
    _server_task: Task<()>,
}
38
/// A tool's advertised metadata paired with the type-erased closure that executes it.
struct RegisteredTool {
    /// Metadata returned to clients in `ListTools` responses.
    tool: Tool,
    /// Deserializes the call's arguments and runs the tool (built in `McpServer::add_tool`).
    handler: ToolHandler,
}
43
/// Type-erased tool entry point: receives the `arguments` of a `CallTool` request
/// (`None` when absent) and returns a task resolving to the tool's result.
type ToolHandler =
    Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
/// Type-erased handler for a custom JSON-RPC method: receives the request id and raw params
/// and returns a task resolving to the already-serialized JSON-RPC response string.
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
47
48impl McpServer {
49 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
50 let task = cx.background_spawn(async move {
51 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
52 let socket_path = temp_dir.path().join("mcp.sock");
53 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
54
55 anyhow::Ok((temp_dir, socket_path, listener))
56 });
57
58 cx.spawn(async move |cx| {
59 let (temp_dir, socket_path, listener) = task.await?;
60 let tools = Rc::new(RefCell::new(HashMap::default()));
61 let handlers = Rc::new(RefCell::new(HashMap::default()));
62 let server_task = cx.spawn({
63 let tools = tools.clone();
64 let handlers = handlers.clone();
65 async move |cx| {
66 while let Ok((stream, _)) = listener.accept().await {
67 Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
68 }
69 drop(temp_dir)
70 }
71 });
72 Ok(Self {
73 socket_path,
74 _server_task: server_task,
75 tools,
76 handlers: handlers,
77 })
78 })
79 }
80
81 pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
82 let registered_tool = RegisteredTool {
83 tool: Tool {
84 name: T::NAME.into(),
85 description: Some(tool.description().into()),
86 input_schema: schemars::schema_for!(T::Input).into(),
87 annotations: Some(tool.annotations()),
88 },
89 handler: Box::new({
90 let tool = tool.clone();
91 move |input_value, cx| {
92 let input = match input_value {
93 Some(input) => serde_json::from_value(input),
94 None => serde_json::from_value(serde_json::Value::Null),
95 };
96
97 let tool = tool.clone();
98 match input {
99 Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
100 Err(err) => Task::ready(Err(err.into())),
101 }
102 }
103 }),
104 };
105
106 self.tools.borrow_mut().insert(T::NAME, registered_tool);
107 }
108
109 pub fn handle_request<R: Request>(
110 &mut self,
111 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
112 ) {
113 let f = Box::new(f);
114 self.handlers.borrow_mut().insert(
115 R::METHOD,
116 Box::new(move |req_id, opt_params, cx| {
117 let result = match opt_params {
118 Some(params) => serde_json::from_str(params.get()),
119 None => serde_json::from_value(serde_json::Value::Null),
120 };
121
122 let params: R::Params = match result {
123 Ok(params) => params,
124 Err(e) => {
125 return Task::ready(
126 serde_json::to_string(&Response::<R::Response> {
127 jsonrpc: "2.0",
128 id: req_id,
129 value: CspResult::Error(Some(crate::client::Error {
130 message: format!("{e}"),
131 code: -32700,
132 })),
133 })
134 .unwrap(),
135 );
136 }
137 };
138 let task = f(params, cx);
139 cx.background_spawn(async move {
140 match task.await {
141 Ok(result) => serde_json::to_string(&Response {
142 jsonrpc: "2.0",
143 id: req_id,
144 value: CspResult::Ok(Some(result)),
145 })
146 .unwrap(),
147 Err(e) => serde_json::to_string(&Response {
148 jsonrpc: "2.0",
149 id: req_id,
150 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
151 message: format!("{e}"),
152 code: -32603,
153 })),
154 })
155 .unwrap(),
156 }
157 })
158 }),
159 );
160 }
161
162 pub fn socket_path(&self) -> &Path {
163 &self.socket_path
164 }
165
166 fn serve_connection(
167 stream: UnixStream,
168 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
169 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
170 cx: &mut AsyncApp,
171 ) {
172 let (read, write) = smol::io::split(stream);
173 let (incoming_tx, mut incoming_rx) = unbounded();
174 let (outgoing_tx, outgoing_rx) = unbounded();
175
176 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
177 .detach();
178
179 cx.spawn(async move |cx| {
180 while let Some(request) = incoming_rx.next().await {
181 let Some(request_id) = request.id.clone() else {
182 continue;
183 };
184
185 if request.method == CallTool::METHOD {
186 Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
187 .await;
188 } else if request.method == ListTools::METHOD {
189 Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
190 } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
191 let outgoing_tx = outgoing_tx.clone();
192
193 if let Some(task) = cx
194 .update(|cx| handler(request_id, request.params, cx))
195 .log_err()
196 {
197 cx.spawn(async move |_| {
198 let response = task.await;
199 outgoing_tx.unbounded_send(response).ok();
200 })
201 .detach();
202 }
203 } else {
204 Self::send_err(
205 request_id,
206 format!("unhandled method {}", request.method),
207 &outgoing_tx,
208 );
209 }
210 }
211 })
212 .detach();
213 }
214
215 fn handle_list_tools(
216 request_id: RequestId,
217 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
218 outgoing_tx: &UnboundedSender<String>,
219 ) {
220 let response = ListToolsResponse {
221 tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
222 next_cursor: None,
223 meta: None,
224 };
225
226 outgoing_tx
227 .unbounded_send(
228 serde_json::to_string(&Response {
229 jsonrpc: "2.0",
230 id: request_id,
231 value: CspResult::Ok(Some(response)),
232 })
233 .unwrap_or_default(),
234 )
235 .ok();
236 }
237
238 async fn handle_call_tool(
239 request_id: RequestId,
240 params: Option<Box<RawValue>>,
241 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
242 outgoing_tx: &UnboundedSender<String>,
243 cx: &mut AsyncApp,
244 ) {
245 let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
246 Some(params) => serde_json::from_str(params.get()),
247 None => serde_json::from_value(serde_json::Value::Null),
248 };
249
250 match result {
251 Ok(params) => {
252 if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
253 let outgoing_tx = outgoing_tx.clone();
254
255 let task = (tool.handler)(params.arguments, cx);
256 cx.spawn(async move |_| {
257 let response = match task.await {
258 Ok(result) => CallToolResponse {
259 content: result.content,
260 is_error: Some(false),
261 meta: None,
262 structured_content: result.structured_content,
263 },
264 Err(err) => CallToolResponse {
265 content: vec![ToolResponseContent::Text {
266 text: err.to_string(),
267 }],
268 is_error: Some(true),
269 meta: None,
270 structured_content: None,
271 },
272 };
273
274 outgoing_tx
275 .unbounded_send(
276 serde_json::to_string(&Response {
277 jsonrpc: "2.0",
278 id: request_id,
279 value: CspResult::Ok(Some(response)),
280 })
281 .unwrap_or_default(),
282 )
283 .ok();
284 })
285 .detach();
286 } else {
287 Self::send_err(
288 request_id,
289 format!("Tool not found: {}", params.name),
290 &outgoing_tx,
291 );
292 }
293 }
294 Err(err) => {
295 Self::send_err(request_id, err.to_string(), &outgoing_tx);
296 }
297 }
298 }
299
300 fn send_err(
301 request_id: RequestId,
302 message: impl Into<String>,
303 outgoing_tx: &UnboundedSender<String>,
304 ) {
305 outgoing_tx
306 .unbounded_send(
307 serde_json::to_string(&Response::<()> {
308 jsonrpc: "2.0",
309 id: request_id,
310 value: CspResult::Error(Some(crate::client::Error {
311 message: message.into(),
312 code: -32601,
313 })),
314 })
315 .unwrap(),
316 )
317 .ok();
318 }
319
320 async fn handle_io(
321 mut outgoing_rx: UnboundedReceiver<String>,
322 incoming_tx: UnboundedSender<RawRequest>,
323 mut outgoing_bytes: impl Unpin + AsyncWrite,
324 incoming_bytes: impl Unpin + AsyncRead,
325 ) -> Result<()> {
326 let mut output_reader = BufReader::new(incoming_bytes);
327 let mut incoming_line = String::new();
328 loop {
329 select_biased! {
330 message = outgoing_rx.next().fuse() => {
331 if let Some(message) = message {
332 log::trace!("send: {}", &message);
333 outgoing_bytes.write_all(message.as_bytes()).await?;
334 outgoing_bytes.write_all(&[b'\n']).await?;
335 } else {
336 break;
337 }
338 }
339 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
340 if bytes_read? == 0 {
341 break
342 }
343 log::trace!("recv: {}", &incoming_line);
344 match serde_json::from_str(&incoming_line) {
345 Ok(message) => {
346 incoming_tx.unbounded_send(message).log_err();
347 }
348 Err(error) => {
349 outgoing_bytes.write_all(serde_json::to_string(&json!({
350 "jsonrpc": "2.0",
351 "error": json!({
352 "code": -32603,
353 "message": format!("Failed to parse: {error}"),
354 }),
355 }))?.as_bytes()).await?;
356 outgoing_bytes.write_all(&[b'\n']).await?;
357 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
358 }
359 }
360 incoming_line.clear();
361 }
362 }
363 }
364 Ok(())
365 }
366}
367
/// A tool that can be registered on an [`McpServer`] via [`McpServer::add_tool`].
pub trait McpServerTool {
    /// Argument type. It is deserialized from the call's `arguments` (from JSON `null` when
    /// they are omitted), and its JSON schema is advertised as the tool's input schema.
    type Input: DeserializeOwned + JsonSchema;
    /// Tool name used as the registry key; registering another tool with the same name
    /// replaces the earlier one.
    const NAME: &'static str;

    /// Human-readable description advertised to clients.
    fn description(&self) -> &'static str;

    /// Behavioral hints advertised to clients. The default sets no hints.
    fn annotations(&self) -> ToolAnnotations {
        ToolAnnotations {
            title: None,
            read_only_hint: None,
            destructive_hint: None,
            idempotent_hint: None,
            open_world_hint: None,
        }
    }

    /// Executes the tool with already-deserialized `input`.
    ///
    /// An `Err` is reported to the client as a tool result whose text is the error message
    /// and whose `is_error` is `Some(true)`, not as a JSON-RPC error.
    fn run(
        &self,
        input: Self::Input,
        cx: &mut AsyncApp,
    ) -> impl Future<Output = Result<ToolResponse>>;
}
390
/// Successful output of [`McpServerTool::run`], forwarded to the client in a `CallToolResponse`.
pub struct ToolResponse {
    /// Content blocks returned to the client.
    pub content: Vec<ToolResponseContent>,
    /// Optional structured JSON output, forwarded unchanged as `structured_content`.
    pub structured_content: Option<serde_json::Value>,
}
395
/// One incoming JSON-RPC message, with its params left as unparsed JSON.
///
/// Only `id`, `method` and `params` are captured; other top-level fields (e.g. `jsonrpc`)
/// are not stored. A missing `id` means the message is treated as a notification and gets
/// no response.
#[derive(Serialize, Deserialize)]
struct RawRequest {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    method: String,
    // Kept raw so each handler can deserialize into its own params type.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}
404
/// Wire shape of a JSON-RPC response whose `result` is left as unparsed JSON.
///
/// NOTE(review): not referenced in the visible part of this file — confirm it is used
/// elsewhere, or remove it.
#[derive(Serialize, Deserialize)]
struct RawResponse {
    jsonrpc: &'static str,
    id: RequestId,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<crate::client::Error>,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
}