listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use schemars::JsonSchema;
 13use serde::de::DeserializeOwned;
 14use serde_json::{json, value::RawValue};
 15use smol::stream::StreamExt;
 16use std::{
 17    any::TypeId,
 18    cell::RefCell,
 19    path::{Path, PathBuf},
 20    rc::Rc,
 21};
 22use util::ResultExt;
 23
 24use crate::{
 25    client::{CspResult, RequestId, Response},
 26    types::{
 27        CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
 28        ToolResponseContent,
 29        requests::{CallTool, ListTools},
 30    },
 31};
 32
 33pub struct McpServer {
 34    socket_path: PathBuf,
 35    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
 36    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
 37    _server_task: Task<()>,
 38}
 39
 40struct RegisteredTool {
 41    tool: Tool,
 42    handler: ToolHandler,
 43}
 44
 45type ToolHandler = Box<
 46    dyn Fn(
 47        Option<serde_json::Value>,
 48        &mut AsyncApp,
 49    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
 50>;
 51type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 52
 53impl McpServer {
 54    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 55        let task = cx.background_spawn(async move {
 56            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 57            let socket_path = temp_dir.path().join("mcp.sock");
 58            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 59
 60            anyhow::Ok((temp_dir, socket_path, listener))
 61        });
 62
 63        cx.spawn(async move |cx| {
 64            let (temp_dir, socket_path, listener) = task.await?;
 65            let tools = Rc::new(RefCell::new(HashMap::default()));
 66            let handlers = Rc::new(RefCell::new(HashMap::default()));
 67            let server_task = cx.spawn({
 68                let tools = tools.clone();
 69                let handlers = handlers.clone();
 70                async move |cx| {
 71                    while let Ok((stream, _)) = listener.accept().await {
 72                        Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
 73                    }
 74                    drop(temp_dir)
 75                }
 76            });
 77            Ok(Self {
 78                socket_path,
 79                _server_task: server_task,
 80                tools,
 81                handlers,
 82            })
 83        })
 84    }
 85
 86    pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
 87        let mut settings = schemars::generate::SchemaSettings::draft07();
 88        settings.inline_subschemas = true;
 89        let mut generator = settings.into_generator();
 90
 91        let input_schema = generator.root_schema_for::<T::Input>();
 92
 93        let description = input_schema
 94            .get("description")
 95            .and_then(|desc| desc.as_str())
 96            .map(|desc| desc.to_string());
 97        debug_assert!(
 98            description.is_some(),
 99            "Input schema struct must include a doc comment for the tool description"
100        );
101
102        let registered_tool = RegisteredTool {
103            tool: Tool {
104                name: T::NAME.into(),
105                description,
106                input_schema: input_schema.into(),
107                output_schema: if TypeId::of::<T::Output>() == TypeId::of::<()>() {
108                    None
109                } else {
110                    Some(generator.root_schema_for::<T::Output>().into())
111                },
112                annotations: Some(tool.annotations()),
113            },
114            handler: Box::new({
115                let tool = tool.clone();
116                move |input_value, cx| {
117                    let input = match input_value {
118                        Some(input) => serde_json::from_value(input),
119                        None => serde_json::from_value(serde_json::Value::Null),
120                    };
121
122                    let tool = tool.clone();
123                    match input {
124                        Ok(input) => cx.spawn(async move |cx| {
125                            let output = tool.run(input, cx).await?;
126
127                            Ok(ToolResponse {
128                                content: output.content,
129                                structured_content: serde_json::to_value(output.structured_content)
130                                    .unwrap_or_default(),
131                            })
132                        }),
133                        Err(err) => Task::ready(Err(err.into())),
134                    }
135                }
136            }),
137        };
138
139        self.tools.borrow_mut().insert(T::NAME, registered_tool);
140    }
141
142    pub fn handle_request<R: Request>(
143        &mut self,
144        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
145    ) {
146        let f = Box::new(f);
147        self.handlers.borrow_mut().insert(
148            R::METHOD,
149            Box::new(move |req_id, opt_params, cx| {
150                let result = match opt_params {
151                    Some(params) => serde_json::from_str(params.get()),
152                    None => serde_json::from_value(serde_json::Value::Null),
153                };
154
155                let params: R::Params = match result {
156                    Ok(params) => params,
157                    Err(e) => {
158                        return Task::ready(
159                            serde_json::to_string(&Response::<R::Response> {
160                                jsonrpc: "2.0",
161                                id: req_id,
162                                value: CspResult::Error(Some(crate::client::Error {
163                                    message: format!("{e}"),
164                                    code: -32700,
165                                })),
166                            })
167                            .unwrap(),
168                        );
169                    }
170                };
171                let task = f(params, cx);
172                cx.background_spawn(async move {
173                    match task.await {
174                        Ok(result) => serde_json::to_string(&Response {
175                            jsonrpc: "2.0",
176                            id: req_id,
177                            value: CspResult::Ok(Some(result)),
178                        })
179                        .unwrap(),
180                        Err(e) => serde_json::to_string(&Response {
181                            jsonrpc: "2.0",
182                            id: req_id,
183                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
184                                message: format!("{e}"),
185                                code: -32603,
186                            })),
187                        })
188                        .unwrap(),
189                    }
190                })
191            }),
192        );
193    }
194
195    pub fn socket_path(&self) -> &Path {
196        &self.socket_path
197    }
198
199    fn serve_connection(
200        stream: UnixStream,
201        tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
202        handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
203        cx: &mut AsyncApp,
204    ) {
205        let (read, write) = smol::io::split(stream);
206        let (incoming_tx, mut incoming_rx) = unbounded();
207        let (outgoing_tx, outgoing_rx) = unbounded();
208
209        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
210            .detach();
211
212        cx.spawn(async move |cx| {
213            while let Some(request) = incoming_rx.next().await {
214                let Some(request_id) = request.id.clone() else {
215                    continue;
216                };
217
218                if request.method == CallTool::METHOD {
219                    Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
220                        .await;
221                } else if request.method == ListTools::METHOD {
222                    Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
223                } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
224                    let outgoing_tx = outgoing_tx.clone();
225
226                    if let Some(task) = cx
227                        .update(|cx| handler(request_id, request.params, cx))
228                        .log_err()
229                    {
230                        cx.spawn(async move |_| {
231                            let response = task.await;
232                            outgoing_tx.unbounded_send(response).ok();
233                        })
234                        .detach();
235                    }
236                } else {
237                    Self::send_err(
238                        request_id,
239                        format!("unhandled method {}", request.method),
240                        &outgoing_tx,
241                    );
242                }
243            }
244        })
245        .detach();
246    }
247
248    fn handle_list_tools(
249        request_id: RequestId,
250        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
251        outgoing_tx: &UnboundedSender<String>,
252    ) {
253        let response = ListToolsResponse {
254            tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
255            next_cursor: None,
256            meta: None,
257        };
258
259        outgoing_tx
260            .unbounded_send(
261                serde_json::to_string(&Response {
262                    jsonrpc: "2.0",
263                    id: request_id,
264                    value: CspResult::Ok(Some(response)),
265                })
266                .unwrap_or_default(),
267            )
268            .ok();
269    }
270
271    async fn handle_call_tool(
272        request_id: RequestId,
273        params: Option<Box<RawValue>>,
274        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
275        outgoing_tx: &UnboundedSender<String>,
276        cx: &mut AsyncApp,
277    ) {
278        let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
279            Some(params) => serde_json::from_str(params.get()),
280            None => serde_json::from_value(serde_json::Value::Null),
281        };
282
283        match result {
284            Ok(params) => {
285                if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
286                    let outgoing_tx = outgoing_tx.clone();
287
288                    let task = (tool.handler)(params.arguments, cx);
289                    cx.spawn(async move |_| {
290                        let response = match task.await {
291                            Ok(result) => CallToolResponse {
292                                content: result.content,
293                                is_error: Some(false),
294                                meta: None,
295                                structured_content: if result.structured_content.is_null() {
296                                    None
297                                } else {
298                                    Some(result.structured_content)
299                                },
300                            },
301                            Err(err) => CallToolResponse {
302                                content: vec![ToolResponseContent::Text {
303                                    text: err.to_string(),
304                                }],
305                                is_error: Some(true),
306                                meta: None,
307                                structured_content: None,
308                            },
309                        };
310
311                        outgoing_tx
312                            .unbounded_send(
313                                serde_json::to_string(&Response {
314                                    jsonrpc: "2.0",
315                                    id: request_id,
316                                    value: CspResult::Ok(Some(response)),
317                                })
318                                .unwrap_or_default(),
319                            )
320                            .ok();
321                    })
322                    .detach();
323                } else {
324                    Self::send_err(
325                        request_id,
326                        format!("Tool not found: {}", params.name),
327                        outgoing_tx,
328                    );
329                }
330            }
331            Err(err) => {
332                Self::send_err(request_id, err.to_string(), outgoing_tx);
333            }
334        }
335    }
336
337    fn send_err(
338        request_id: RequestId,
339        message: impl Into<String>,
340        outgoing_tx: &UnboundedSender<String>,
341    ) {
342        outgoing_tx
343            .unbounded_send(
344                serde_json::to_string(&Response::<()> {
345                    jsonrpc: "2.0",
346                    id: request_id,
347                    value: CspResult::Error(Some(crate::client::Error {
348                        message: message.into(),
349                        code: -32601,
350                    })),
351                })
352                .unwrap(),
353            )
354            .ok();
355    }
356
357    async fn handle_io(
358        mut outgoing_rx: UnboundedReceiver<String>,
359        incoming_tx: UnboundedSender<RawRequest>,
360        mut outgoing_bytes: impl Unpin + AsyncWrite,
361        incoming_bytes: impl Unpin + AsyncRead,
362    ) -> Result<()> {
363        let mut output_reader = BufReader::new(incoming_bytes);
364        let mut incoming_line = String::new();
365        loop {
366            select_biased! {
367                message = outgoing_rx.next().fuse() => {
368                    if let Some(message) = message {
369                        log::trace!("send: {}", &message);
370                        outgoing_bytes.write_all(message.as_bytes()).await?;
371                        outgoing_bytes.write_all(&[b'\n']).await?;
372                    } else {
373                        break;
374                    }
375                }
376                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
377                    if bytes_read? == 0 {
378                        break
379                    }
380                    log::trace!("recv: {}", &incoming_line);
381                    match serde_json::from_str(&incoming_line) {
382                        Ok(message) => {
383                            incoming_tx.unbounded_send(message).log_err();
384                        }
385                        Err(error) => {
386                            outgoing_bytes.write_all(serde_json::to_string(&json!({
387                                "jsonrpc": "2.0",
388                                "error": json!({
389                                    "code": -32603,
390                                    "message": format!("Failed to parse: {error}"),
391                                }),
392                            }))?.as_bytes()).await?;
393                            outgoing_bytes.write_all(&[b'\n']).await?;
394                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
395                        }
396                    }
397                    incoming_line.clear();
398                }
399            }
400        }
401        Ok(())
402    }
403}
404
405pub trait McpServerTool {
406    type Input: DeserializeOwned + JsonSchema;
407    type Output: Serialize + JsonSchema;
408
409    const NAME: &'static str;
410
411    fn annotations(&self) -> ToolAnnotations {
412        ToolAnnotations {
413            title: None,
414            read_only_hint: None,
415            destructive_hint: None,
416            idempotent_hint: None,
417            open_world_hint: None,
418        }
419    }
420
421    fn run(
422        &self,
423        input: Self::Input,
424        cx: &mut AsyncApp,
425    ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
426}
427
428#[derive(Debug)]
429pub struct ToolResponse<T> {
430    pub content: Vec<ToolResponseContent>,
431    pub structured_content: T,
432}
433
434#[derive(Debug, Serialize, Deserialize)]
435struct RawRequest {
436    #[serde(skip_serializing_if = "Option::is_none")]
437    id: Option<RequestId>,
438    method: String,
439    #[serde(skip_serializing_if = "Option::is_none")]
440    params: Option<Box<serde_json::value::RawValue>>,
441}
442
443#[derive(Serialize, Deserialize)]
444struct RawResponse {
445    jsonrpc: &'static str,
446    id: RequestId,
447    #[serde(skip_serializing_if = "Option::is_none")]
448    error: Option<crate::client::Error>,
449    #[serde(skip_serializing_if = "Option::is_none")]
450    result: Option<Box<serde_json::value::RawValue>>,
451}