listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use schemars::JsonSchema;
 13use serde::de::DeserializeOwned;
 14use serde_json::{json, value::RawValue};
 15use smol::stream::StreamExt;
 16use std::{
 17    cell::RefCell,
 18    path::{Path, PathBuf},
 19    rc::Rc,
 20};
 21use util::ResultExt;
 22
 23use crate::{
 24    client::{CspResult, RequestId, Response},
 25    types::{
 26        CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
 27        ToolResponseContent,
 28        requests::{CallTool, ListTools},
 29    },
 30};
 31
 32pub struct McpServer {
 33    socket_path: PathBuf,
 34    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
 35    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
 36    _server_task: Task<()>,
 37}
 38
 39struct RegisteredTool {
 40    tool: Tool,
 41    handler: ToolHandler,
 42}
 43
 44type ToolHandler = Box<
 45    dyn Fn(
 46        Option<serde_json::Value>,
 47        &mut AsyncApp,
 48    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
 49>;
 50type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 51
 52impl McpServer {
 53    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 54        let task = cx.background_spawn(async move {
 55            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 56            let socket_path = temp_dir.path().join("mcp.sock");
 57            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 58
 59            anyhow::Ok((temp_dir, socket_path, listener))
 60        });
 61
 62        cx.spawn(async move |cx| {
 63            let (temp_dir, socket_path, listener) = task.await?;
 64            let tools = Rc::new(RefCell::new(HashMap::default()));
 65            let handlers = Rc::new(RefCell::new(HashMap::default()));
 66            let server_task = cx.spawn({
 67                let tools = tools.clone();
 68                let handlers = handlers.clone();
 69                async move |cx| {
 70                    while let Ok((stream, _)) = listener.accept().await {
 71                        Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
 72                    }
 73                    drop(temp_dir)
 74                }
 75            });
 76            Ok(Self {
 77                socket_path,
 78                _server_task: server_task,
 79                tools,
 80                handlers,
 81            })
 82        })
 83    }
 84
 85    pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
 86        let mut settings = schemars::generate::SchemaSettings::draft07();
 87        settings.inline_subschemas = true;
 88        let mut generator = settings.into_generator();
 89
 90        let output_schema = generator.root_schema_for::<T::Output>();
 91        let unit_schema = generator.root_schema_for::<T::Output>();
 92
 93        let registered_tool = RegisteredTool {
 94            tool: Tool {
 95                name: T::NAME.into(),
 96                description: Some(tool.description().into()),
 97                input_schema: generator.root_schema_for::<T::Input>().into(),
 98                output_schema: if output_schema == unit_schema {
 99                    None
100                } else {
101                    Some(output_schema.into())
102                },
103                annotations: Some(tool.annotations()),
104            },
105            handler: Box::new({
106                let tool = tool.clone();
107                move |input_value, cx| {
108                    let input = match input_value {
109                        Some(input) => serde_json::from_value(input),
110                        None => serde_json::from_value(serde_json::Value::Null),
111                    };
112
113                    let tool = tool.clone();
114                    match input {
115                        Ok(input) => cx.spawn(async move |cx| {
116                            let output = tool.run(input, cx).await?;
117
118                            Ok(ToolResponse {
119                                content: output.content,
120                                structured_content: serde_json::to_value(output.structured_content)
121                                    .unwrap_or_default(),
122                            })
123                        }),
124                        Err(err) => Task::ready(Err(err.into())),
125                    }
126                }
127            }),
128        };
129
130        self.tools.borrow_mut().insert(T::NAME, registered_tool);
131    }
132
133    pub fn handle_request<R: Request>(
134        &mut self,
135        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
136    ) {
137        let f = Box::new(f);
138        self.handlers.borrow_mut().insert(
139            R::METHOD,
140            Box::new(move |req_id, opt_params, cx| {
141                let result = match opt_params {
142                    Some(params) => serde_json::from_str(params.get()),
143                    None => serde_json::from_value(serde_json::Value::Null),
144                };
145
146                let params: R::Params = match result {
147                    Ok(params) => params,
148                    Err(e) => {
149                        return Task::ready(
150                            serde_json::to_string(&Response::<R::Response> {
151                                jsonrpc: "2.0",
152                                id: req_id,
153                                value: CspResult::Error(Some(crate::client::Error {
154                                    message: format!("{e}"),
155                                    code: -32700,
156                                })),
157                            })
158                            .unwrap(),
159                        );
160                    }
161                };
162                let task = f(params, cx);
163                cx.background_spawn(async move {
164                    match task.await {
165                        Ok(result) => serde_json::to_string(&Response {
166                            jsonrpc: "2.0",
167                            id: req_id,
168                            value: CspResult::Ok(Some(result)),
169                        })
170                        .unwrap(),
171                        Err(e) => serde_json::to_string(&Response {
172                            jsonrpc: "2.0",
173                            id: req_id,
174                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
175                                message: format!("{e}"),
176                                code: -32603,
177                            })),
178                        })
179                        .unwrap(),
180                    }
181                })
182            }),
183        );
184    }
185
186    pub fn socket_path(&self) -> &Path {
187        &self.socket_path
188    }
189
190    fn serve_connection(
191        stream: UnixStream,
192        tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
193        handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
194        cx: &mut AsyncApp,
195    ) {
196        let (read, write) = smol::io::split(stream);
197        let (incoming_tx, mut incoming_rx) = unbounded();
198        let (outgoing_tx, outgoing_rx) = unbounded();
199
200        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
201            .detach();
202
203        cx.spawn(async move |cx| {
204            while let Some(request) = incoming_rx.next().await {
205                let Some(request_id) = request.id.clone() else {
206                    continue;
207                };
208
209                if request.method == CallTool::METHOD {
210                    Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
211                        .await;
212                } else if request.method == ListTools::METHOD {
213                    Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
214                } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
215                    let outgoing_tx = outgoing_tx.clone();
216
217                    if let Some(task) = cx
218                        .update(|cx| handler(request_id, request.params, cx))
219                        .log_err()
220                    {
221                        cx.spawn(async move |_| {
222                            let response = task.await;
223                            outgoing_tx.unbounded_send(response).ok();
224                        })
225                        .detach();
226                    }
227                } else {
228                    Self::send_err(
229                        request_id,
230                        format!("unhandled method {}", request.method),
231                        &outgoing_tx,
232                    );
233                }
234            }
235        })
236        .detach();
237    }
238
239    fn handle_list_tools(
240        request_id: RequestId,
241        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
242        outgoing_tx: &UnboundedSender<String>,
243    ) {
244        let response = ListToolsResponse {
245            tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
246            next_cursor: None,
247            meta: None,
248        };
249
250        outgoing_tx
251            .unbounded_send(
252                serde_json::to_string(&Response {
253                    jsonrpc: "2.0",
254                    id: request_id,
255                    value: CspResult::Ok(Some(response)),
256                })
257                .unwrap_or_default(),
258            )
259            .ok();
260    }
261
262    async fn handle_call_tool(
263        request_id: RequestId,
264        params: Option<Box<RawValue>>,
265        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
266        outgoing_tx: &UnboundedSender<String>,
267        cx: &mut AsyncApp,
268    ) {
269        let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
270            Some(params) => serde_json::from_str(params.get()),
271            None => serde_json::from_value(serde_json::Value::Null),
272        };
273
274        match result {
275            Ok(params) => {
276                if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
277                    let outgoing_tx = outgoing_tx.clone();
278
279                    let task = (tool.handler)(params.arguments, cx);
280                    cx.spawn(async move |_| {
281                        let response = match task.await {
282                            Ok(result) => CallToolResponse {
283                                content: result.content,
284                                is_error: Some(false),
285                                meta: None,
286                                structured_content: if result.structured_content.is_null() {
287                                    None
288                                } else {
289                                    Some(result.structured_content)
290                                },
291                            },
292                            Err(err) => CallToolResponse {
293                                content: vec![ToolResponseContent::Text {
294                                    text: err.to_string(),
295                                }],
296                                is_error: Some(true),
297                                meta: None,
298                                structured_content: None,
299                            },
300                        };
301
302                        outgoing_tx
303                            .unbounded_send(
304                                serde_json::to_string(&Response {
305                                    jsonrpc: "2.0",
306                                    id: request_id,
307                                    value: CspResult::Ok(Some(response)),
308                                })
309                                .unwrap_or_default(),
310                            )
311                            .ok();
312                    })
313                    .detach();
314                } else {
315                    Self::send_err(
316                        request_id,
317                        format!("Tool not found: {}", params.name),
318                        outgoing_tx,
319                    );
320                }
321            }
322            Err(err) => {
323                Self::send_err(request_id, err.to_string(), outgoing_tx);
324            }
325        }
326    }
327
328    fn send_err(
329        request_id: RequestId,
330        message: impl Into<String>,
331        outgoing_tx: &UnboundedSender<String>,
332    ) {
333        outgoing_tx
334            .unbounded_send(
335                serde_json::to_string(&Response::<()> {
336                    jsonrpc: "2.0",
337                    id: request_id,
338                    value: CspResult::Error(Some(crate::client::Error {
339                        message: message.into(),
340                        code: -32601,
341                    })),
342                })
343                .unwrap(),
344            )
345            .ok();
346    }
347
348    async fn handle_io(
349        mut outgoing_rx: UnboundedReceiver<String>,
350        incoming_tx: UnboundedSender<RawRequest>,
351        mut outgoing_bytes: impl Unpin + AsyncWrite,
352        incoming_bytes: impl Unpin + AsyncRead,
353    ) -> Result<()> {
354        let mut output_reader = BufReader::new(incoming_bytes);
355        let mut incoming_line = String::new();
356        loop {
357            select_biased! {
358                message = outgoing_rx.next().fuse() => {
359                    if let Some(message) = message {
360                        log::trace!("send: {}", &message);
361                        outgoing_bytes.write_all(message.as_bytes()).await?;
362                        outgoing_bytes.write_all(&[b'\n']).await?;
363                    } else {
364                        break;
365                    }
366                }
367                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
368                    if bytes_read? == 0 {
369                        break
370                    }
371                    log::trace!("recv: {}", &incoming_line);
372                    match serde_json::from_str(&incoming_line) {
373                        Ok(message) => {
374                            incoming_tx.unbounded_send(message).log_err();
375                        }
376                        Err(error) => {
377                            outgoing_bytes.write_all(serde_json::to_string(&json!({
378                                "jsonrpc": "2.0",
379                                "error": json!({
380                                    "code": -32603,
381                                    "message": format!("Failed to parse: {error}"),
382                                }),
383                            }))?.as_bytes()).await?;
384                            outgoing_bytes.write_all(&[b'\n']).await?;
385                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
386                        }
387                    }
388                    incoming_line.clear();
389                }
390            }
391        }
392        Ok(())
393    }
394}
395
396pub trait McpServerTool {
397    type Input: DeserializeOwned + JsonSchema;
398    type Output: Serialize + JsonSchema;
399
400    const NAME: &'static str;
401
402    fn description(&self) -> &'static str;
403
404    fn annotations(&self) -> ToolAnnotations {
405        ToolAnnotations {
406            title: None,
407            read_only_hint: None,
408            destructive_hint: None,
409            idempotent_hint: None,
410            open_world_hint: None,
411        }
412    }
413
414    fn run(
415        &self,
416        input: Self::Input,
417        cx: &mut AsyncApp,
418    ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
419}
420
421pub struct ToolResponse<T> {
422    pub content: Vec<ToolResponseContent>,
423    pub structured_content: T,
424}
425
426#[derive(Debug, Serialize, Deserialize)]
427struct RawRequest {
428    #[serde(skip_serializing_if = "Option::is_none")]
429    id: Option<RequestId>,
430    method: String,
431    #[serde(skip_serializing_if = "Option::is_none")]
432    params: Option<Box<serde_json::value::RawValue>>,
433}
434
435#[derive(Serialize, Deserialize)]
436struct RawResponse {
437    jsonrpc: &'static str,
438    id: RequestId,
439    #[serde(skip_serializing_if = "Option::is_none")]
440    error: Option<crate::client::Error>,
441    #[serde(skip_serializing_if = "Option::is_none")]
442    result: Option<Box<serde_json::value::RawValue>>,
443}