listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use schemars::JsonSchema;
 13use serde::de::DeserializeOwned;
 14use serde_json::{json, value::RawValue};
 15use smol::stream::StreamExt;
 16use std::{
 17    any::TypeId,
 18    cell::RefCell,
 19    path::{Path, PathBuf},
 20    rc::Rc,
 21};
 22use util::ResultExt;
 23
 24use crate::{
 25    client::{CspResult, RequestId, Response},
 26    types::{
 27        CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
 28        ToolResponseContent,
 29        requests::{CallTool, ListTools},
 30    },
 31};
 32
 33pub struct McpServer {
 34    socket_path: PathBuf,
 35    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
 36    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
 37    _server_task: Task<()>,
 38}
 39
 40struct RegisteredTool {
 41    tool: Tool,
 42    handler: ToolHandler,
 43}
 44
 45type ToolHandler = Box<
 46    dyn Fn(
 47        Option<serde_json::Value>,
 48        &mut AsyncApp,
 49    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
 50>;
 51type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 52
 53impl McpServer {
 54    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 55        let task = cx.background_spawn(async move {
 56            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 57            let socket_path = temp_dir.path().join("mcp.sock");
 58            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 59
 60            anyhow::Ok((temp_dir, socket_path, listener))
 61        });
 62
 63        cx.spawn(async move |cx| {
 64            let (temp_dir, socket_path, listener) = task.await?;
 65            let tools = Rc::new(RefCell::new(HashMap::default()));
 66            let handlers = Rc::new(RefCell::new(HashMap::default()));
 67            let server_task = cx.spawn({
 68                let tools = tools.clone();
 69                let handlers = handlers.clone();
 70                async move |cx| {
 71                    while let Ok((stream, _)) = listener.accept().await {
 72                        Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
 73                    }
 74                    drop(temp_dir)
 75                }
 76            });
 77            Ok(Self {
 78                socket_path,
 79                _server_task: server_task,
 80                tools,
 81                handlers,
 82            })
 83        })
 84    }
 85
 86    pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
 87        let mut settings = schemars::generate::SchemaSettings::draft07();
 88        settings.inline_subschemas = true;
 89        let mut generator = settings.into_generator();
 90
 91        let input_schema = generator.root_schema_for::<T::Input>();
 92
 93        let description = input_schema
 94            .get("description")
 95            .and_then(|desc| desc.as_str())
 96            .map(|desc| desc.to_string());
 97        debug_assert!(
 98            description.is_some(),
 99            "Input schema struct must include a doc comment for the tool description"
100        );
101
102        let registered_tool = RegisteredTool {
103            tool: Tool {
104                name: T::NAME.into(),
105                description,
106                input_schema: input_schema.into(),
107                output_schema: if TypeId::of::<T::Output>() == TypeId::of::<()>() {
108                    None
109                } else {
110                    Some(generator.root_schema_for::<T::Output>().into())
111                },
112                annotations: Some(tool.annotations()),
113            },
114            handler: Box::new({
115                move |input_value, cx| {
116                    let input = match input_value {
117                        Some(input) => serde_json::from_value(input),
118                        None => serde_json::from_value(serde_json::Value::Null),
119                    };
120
121                    let tool = tool.clone();
122                    match input {
123                        Ok(input) => cx.spawn(async move |cx| {
124                            let output = tool.run(input, cx).await?;
125
126                            Ok(ToolResponse {
127                                content: output.content,
128                                structured_content: serde_json::to_value(output.structured_content)
129                                    .unwrap_or_default(),
130                            })
131                        }),
132                        Err(err) => Task::ready(Err(err.into())),
133                    }
134                }
135            }),
136        };
137
138        self.tools.borrow_mut().insert(T::NAME, registered_tool);
139    }
140
141    pub fn handle_request<R: Request>(
142        &mut self,
143        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
144    ) {
145        let f = Box::new(f);
146        self.handlers.borrow_mut().insert(
147            R::METHOD,
148            Box::new(move |req_id, opt_params, cx| {
149                let result = match opt_params {
150                    Some(params) => serde_json::from_str(params.get()),
151                    None => serde_json::from_value(serde_json::Value::Null),
152                };
153
154                let params: R::Params = match result {
155                    Ok(params) => params,
156                    Err(e) => {
157                        return Task::ready(
158                            serde_json::to_string(&Response::<R::Response> {
159                                jsonrpc: "2.0",
160                                id: req_id,
161                                value: CspResult::Error(Some(crate::client::Error {
162                                    message: format!("{e}"),
163                                    code: -32700,
164                                })),
165                            })
166                            .unwrap(),
167                        );
168                    }
169                };
170                let task = f(params, cx);
171                cx.background_spawn(async move {
172                    match task.await {
173                        Ok(result) => serde_json::to_string(&Response {
174                            jsonrpc: "2.0",
175                            id: req_id,
176                            value: CspResult::Ok(Some(result)),
177                        })
178                        .unwrap(),
179                        Err(e) => serde_json::to_string(&Response {
180                            jsonrpc: "2.0",
181                            id: req_id,
182                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
183                                message: format!("{e}"),
184                                code: -32603,
185                            })),
186                        })
187                        .unwrap(),
188                    }
189                })
190            }),
191        );
192    }
193
194    pub fn socket_path(&self) -> &Path {
195        &self.socket_path
196    }
197
198    fn serve_connection(
199        stream: UnixStream,
200        tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
201        handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
202        cx: &mut AsyncApp,
203    ) {
204        let (read, write) = smol::io::split(stream);
205        let (incoming_tx, mut incoming_rx) = unbounded();
206        let (outgoing_tx, outgoing_rx) = unbounded();
207
208        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
209            .detach();
210
211        cx.spawn(async move |cx| {
212            while let Some(request) = incoming_rx.next().await {
213                let Some(request_id) = request.id.clone() else {
214                    continue;
215                };
216
217                if request.method == CallTool::METHOD {
218                    Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
219                        .await;
220                } else if request.method == ListTools::METHOD {
221                    Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
222                } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
223                    let outgoing_tx = outgoing_tx.clone();
224
225                    if let Some(task) = cx
226                        .update(|cx| handler(request_id, request.params, cx))
227                        .log_err()
228                    {
229                        cx.spawn(async move |_| {
230                            let response = task.await;
231                            outgoing_tx.unbounded_send(response).ok();
232                        })
233                        .detach();
234                    }
235                } else {
236                    Self::send_err(
237                        request_id,
238                        format!("unhandled method {}", request.method),
239                        &outgoing_tx,
240                    );
241                }
242            }
243        })
244        .detach();
245    }
246
247    fn handle_list_tools(
248        request_id: RequestId,
249        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
250        outgoing_tx: &UnboundedSender<String>,
251    ) {
252        let response = ListToolsResponse {
253            tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
254            next_cursor: None,
255            meta: None,
256        };
257
258        outgoing_tx
259            .unbounded_send(
260                serde_json::to_string(&Response {
261                    jsonrpc: "2.0",
262                    id: request_id,
263                    value: CspResult::Ok(Some(response)),
264                })
265                .unwrap_or_default(),
266            )
267            .ok();
268    }
269
270    async fn handle_call_tool(
271        request_id: RequestId,
272        params: Option<Box<RawValue>>,
273        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
274        outgoing_tx: &UnboundedSender<String>,
275        cx: &mut AsyncApp,
276    ) {
277        let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
278            Some(params) => serde_json::from_str(params.get()),
279            None => serde_json::from_value(serde_json::Value::Null),
280        };
281
282        match result {
283            Ok(params) => {
284                if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
285                    let outgoing_tx = outgoing_tx.clone();
286
287                    let task = (tool.handler)(params.arguments, cx);
288                    cx.spawn(async move |_| {
289                        let response = match task.await {
290                            Ok(result) => CallToolResponse {
291                                content: result.content,
292                                is_error: Some(false),
293                                meta: None,
294                                structured_content: if result.structured_content.is_null() {
295                                    None
296                                } else {
297                                    Some(result.structured_content)
298                                },
299                            },
300                            Err(err) => CallToolResponse {
301                                content: vec![ToolResponseContent::Text {
302                                    text: err.to_string(),
303                                }],
304                                is_error: Some(true),
305                                meta: None,
306                                structured_content: None,
307                            },
308                        };
309
310                        outgoing_tx
311                            .unbounded_send(
312                                serde_json::to_string(&Response {
313                                    jsonrpc: "2.0",
314                                    id: request_id,
315                                    value: CspResult::Ok(Some(response)),
316                                })
317                                .unwrap_or_default(),
318                            )
319                            .ok();
320                    })
321                    .detach();
322                } else {
323                    Self::send_err(
324                        request_id,
325                        format!("Tool not found: {}", params.name),
326                        outgoing_tx,
327                    );
328                }
329            }
330            Err(err) => {
331                Self::send_err(request_id, err.to_string(), outgoing_tx);
332            }
333        }
334    }
335
336    fn send_err(
337        request_id: RequestId,
338        message: impl Into<String>,
339        outgoing_tx: &UnboundedSender<String>,
340    ) {
341        outgoing_tx
342            .unbounded_send(
343                serde_json::to_string(&Response::<()> {
344                    jsonrpc: "2.0",
345                    id: request_id,
346                    value: CspResult::Error(Some(crate::client::Error {
347                        message: message.into(),
348                        code: -32601,
349                    })),
350                })
351                .unwrap(),
352            )
353            .ok();
354    }
355
356    async fn handle_io(
357        mut outgoing_rx: UnboundedReceiver<String>,
358        incoming_tx: UnboundedSender<RawRequest>,
359        mut outgoing_bytes: impl Unpin + AsyncWrite,
360        incoming_bytes: impl Unpin + AsyncRead,
361    ) -> Result<()> {
362        let mut output_reader = BufReader::new(incoming_bytes);
363        let mut incoming_line = String::new();
364        loop {
365            select_biased! {
366                message = outgoing_rx.next().fuse() => {
367                    if let Some(message) = message {
368                        log::trace!("send: {}", &message);
369                        outgoing_bytes.write_all(message.as_bytes()).await?;
370                        outgoing_bytes.write_all(&[b'\n']).await?;
371                    } else {
372                        break;
373                    }
374                }
375                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
376                    if bytes_read? == 0 {
377                        break
378                    }
379                    log::trace!("recv: {}", &incoming_line);
380                    match serde_json::from_str(&incoming_line) {
381                        Ok(message) => {
382                            incoming_tx.unbounded_send(message).log_err();
383                        }
384                        Err(error) => {
385                            outgoing_bytes.write_all(serde_json::to_string(&json!({
386                                "jsonrpc": "2.0",
387                                "error": json!({
388                                    "code": -32603,
389                                    "message": format!("Failed to parse: {error}"),
390                                }),
391                            }))?.as_bytes()).await?;
392                            outgoing_bytes.write_all(&[b'\n']).await?;
393                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
394                        }
395                    }
396                    incoming_line.clear();
397                }
398            }
399        }
400        Ok(())
401    }
402}
403
404pub trait McpServerTool {
405    type Input: DeserializeOwned + JsonSchema;
406    type Output: Serialize + JsonSchema;
407
408    const NAME: &'static str;
409
410    fn annotations(&self) -> ToolAnnotations {
411        ToolAnnotations {
412            title: None,
413            read_only_hint: None,
414            destructive_hint: None,
415            idempotent_hint: None,
416            open_world_hint: None,
417        }
418    }
419
420    fn run(
421        &self,
422        input: Self::Input,
423        cx: &mut AsyncApp,
424    ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
425}
426
427#[derive(Debug)]
428pub struct ToolResponse<T> {
429    pub content: Vec<ToolResponseContent>,
430    pub structured_content: T,
431}
432
433#[derive(Debug, Serialize, Deserialize)]
434struct RawRequest {
435    #[serde(skip_serializing_if = "Option::is_none")]
436    id: Option<RequestId>,
437    method: String,
438    #[serde(skip_serializing_if = "Option::is_none")]
439    params: Option<Box<serde_json::value::RawValue>>,
440}