listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use schemars::JsonSchema;
 13use serde::de::DeserializeOwned;
 14use serde_json::{json, value::RawValue};
 15use smol::stream::StreamExt;
 16use std::{
 17    cell::RefCell,
 18    path::{Path, PathBuf},
 19    rc::Rc,
 20};
 21use util::ResultExt;
 22
 23use crate::{
 24    client::{CspResult, RequestId, Response},
 25    types::{
 26        CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
 27        ToolResponseContent,
 28        requests::{CallTool, ListTools},
 29    },
 30};
 31
/// An MCP (Model Context Protocol) server that accepts newline-delimited JSON-RPC
/// connections on a Unix domain socket.
///
/// Tools and custom request handlers live behind shared `Rc<RefCell<..>>` maps, so
/// every served connection sees registrations made after the server was created.
pub struct McpServer {
    /// Location of the listening socket, inside a temporary directory whose guard is
    /// owned by the accept loop in `_server_task`.
    socket_path: PathBuf,
    /// Registered tools keyed by tool name; used for `tools/list` and `tools/call`.
    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
    /// Handlers for other JSON-RPC methods, keyed by method name.
    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
    /// The accept loop. Held (never awaited) so it lives as long as the server.
    /// NOTE(review): relies on gpui cancelling a `Task` when it is dropped.
    _server_task: Task<()>,
}
 38
/// A tool registered through `McpServer::add_tool`: the metadata advertised to
/// clients plus the type-erased handler that runs it.
struct RegisteredTool {
    /// Metadata returned from `tools/list`.
    tool: Tool,
    /// Invoked for `tools/call` with the request's raw `arguments`.
    handler: ToolHandler,
}
 43
/// Type-erased tool entry point built by `McpServer::add_tool`.
///
/// Receives the raw `arguments` of a `tools/call` request (`None` when absent),
/// deserializes them into the tool's input type and runs the tool, yielding its
/// content and its structured output converted to JSON.
type ToolHandler = Box<
    dyn Fn(
        Option<serde_json::Value>,
        &mut AsyncApp,
    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
>;
/// Type-erased JSON-RPC method handler built by `McpServer::handle_request`.
///
/// Receives the request id and raw params and returns a task that resolves to the
/// fully serialized JSON-RPC response (success or error) ready to be sent.
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 51
 52impl McpServer {
 53    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 54        let task = cx.background_spawn(async move {
 55            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 56            let socket_path = temp_dir.path().join("mcp.sock");
 57            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 58
 59            anyhow::Ok((temp_dir, socket_path, listener))
 60        });
 61
 62        cx.spawn(async move |cx| {
 63            let (temp_dir, socket_path, listener) = task.await?;
 64            let tools = Rc::new(RefCell::new(HashMap::default()));
 65            let handlers = Rc::new(RefCell::new(HashMap::default()));
 66            let server_task = cx.spawn({
 67                let tools = tools.clone();
 68                let handlers = handlers.clone();
 69                async move |cx| {
 70                    while let Ok((stream, _)) = listener.accept().await {
 71                        Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
 72                    }
 73                    drop(temp_dir)
 74                }
 75            });
 76            Ok(Self {
 77                socket_path,
 78                _server_task: server_task,
 79                tools,
 80                handlers: handlers,
 81            })
 82        })
 83    }
 84
 85    pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
 86        let output_schema = schemars::schema_for!(T::Output);
 87        let unit_schema = schemars::schema_for!(());
 88
 89        let registered_tool = RegisteredTool {
 90            tool: Tool {
 91                name: T::NAME.into(),
 92                description: Some(tool.description().into()),
 93                input_schema: schemars::schema_for!(T::Input).into(),
 94                output_schema: if output_schema == unit_schema {
 95                    None
 96                } else {
 97                    Some(output_schema.into())
 98                },
 99                annotations: Some(tool.annotations()),
100            },
101            handler: Box::new({
102                let tool = tool.clone();
103                move |input_value, cx| {
104                    let input = match input_value {
105                        Some(input) => serde_json::from_value(input),
106                        None => serde_json::from_value(serde_json::Value::Null),
107                    };
108
109                    let tool = tool.clone();
110                    match input {
111                        Ok(input) => cx.spawn(async move |cx| {
112                            let output = tool.run(input, cx).await?;
113
114                            Ok(ToolResponse {
115                                content: output.content,
116                                structured_content: serde_json::to_value(output.structured_content)
117                                    .unwrap_or_default(),
118                            })
119                        }),
120                        Err(err) => Task::ready(Err(err.into())),
121                    }
122                }
123            }),
124        };
125
126        self.tools.borrow_mut().insert(T::NAME, registered_tool);
127    }
128
129    pub fn handle_request<R: Request>(
130        &mut self,
131        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
132    ) {
133        let f = Box::new(f);
134        self.handlers.borrow_mut().insert(
135            R::METHOD,
136            Box::new(move |req_id, opt_params, cx| {
137                let result = match opt_params {
138                    Some(params) => serde_json::from_str(params.get()),
139                    None => serde_json::from_value(serde_json::Value::Null),
140                };
141
142                let params: R::Params = match result {
143                    Ok(params) => params,
144                    Err(e) => {
145                        return Task::ready(
146                            serde_json::to_string(&Response::<R::Response> {
147                                jsonrpc: "2.0",
148                                id: req_id,
149                                value: CspResult::Error(Some(crate::client::Error {
150                                    message: format!("{e}"),
151                                    code: -32700,
152                                })),
153                            })
154                            .unwrap(),
155                        );
156                    }
157                };
158                let task = f(params, cx);
159                cx.background_spawn(async move {
160                    match task.await {
161                        Ok(result) => serde_json::to_string(&Response {
162                            jsonrpc: "2.0",
163                            id: req_id,
164                            value: CspResult::Ok(Some(result)),
165                        })
166                        .unwrap(),
167                        Err(e) => serde_json::to_string(&Response {
168                            jsonrpc: "2.0",
169                            id: req_id,
170                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
171                                message: format!("{e}"),
172                                code: -32603,
173                            })),
174                        })
175                        .unwrap(),
176                    }
177                })
178            }),
179        );
180    }
181
182    pub fn socket_path(&self) -> &Path {
183        &self.socket_path
184    }
185
186    fn serve_connection(
187        stream: UnixStream,
188        tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
189        handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
190        cx: &mut AsyncApp,
191    ) {
192        let (read, write) = smol::io::split(stream);
193        let (incoming_tx, mut incoming_rx) = unbounded();
194        let (outgoing_tx, outgoing_rx) = unbounded();
195
196        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
197            .detach();
198
199        cx.spawn(async move |cx| {
200            while let Some(request) = incoming_rx.next().await {
201                let Some(request_id) = request.id.clone() else {
202                    continue;
203                };
204
205                if request.method == CallTool::METHOD {
206                    Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
207                        .await;
208                } else if request.method == ListTools::METHOD {
209                    Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
210                } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
211                    let outgoing_tx = outgoing_tx.clone();
212
213                    if let Some(task) = cx
214                        .update(|cx| handler(request_id, request.params, cx))
215                        .log_err()
216                    {
217                        cx.spawn(async move |_| {
218                            let response = task.await;
219                            outgoing_tx.unbounded_send(response).ok();
220                        })
221                        .detach();
222                    }
223                } else {
224                    Self::send_err(
225                        request_id,
226                        format!("unhandled method {}", request.method),
227                        &outgoing_tx,
228                    );
229                }
230            }
231        })
232        .detach();
233    }
234
235    fn handle_list_tools(
236        request_id: RequestId,
237        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
238        outgoing_tx: &UnboundedSender<String>,
239    ) {
240        let response = ListToolsResponse {
241            tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
242            next_cursor: None,
243            meta: None,
244        };
245
246        outgoing_tx
247            .unbounded_send(
248                serde_json::to_string(&Response {
249                    jsonrpc: "2.0",
250                    id: request_id,
251                    value: CspResult::Ok(Some(response)),
252                })
253                .unwrap_or_default(),
254            )
255            .ok();
256    }
257
258    async fn handle_call_tool(
259        request_id: RequestId,
260        params: Option<Box<RawValue>>,
261        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
262        outgoing_tx: &UnboundedSender<String>,
263        cx: &mut AsyncApp,
264    ) {
265        let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
266            Some(params) => serde_json::from_str(params.get()),
267            None => serde_json::from_value(serde_json::Value::Null),
268        };
269
270        match result {
271            Ok(params) => {
272                if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
273                    let outgoing_tx = outgoing_tx.clone();
274
275                    let task = (tool.handler)(params.arguments, cx);
276                    cx.spawn(async move |_| {
277                        let response = match task.await {
278                            Ok(result) => CallToolResponse {
279                                content: result.content,
280                                is_error: Some(false),
281                                meta: None,
282                                structured_content: if result.structured_content.is_null() {
283                                    None
284                                } else {
285                                    Some(result.structured_content)
286                                },
287                            },
288                            Err(err) => CallToolResponse {
289                                content: vec![ToolResponseContent::Text {
290                                    text: err.to_string(),
291                                }],
292                                is_error: Some(true),
293                                meta: None,
294                                structured_content: None,
295                            },
296                        };
297
298                        outgoing_tx
299                            .unbounded_send(
300                                serde_json::to_string(&Response {
301                                    jsonrpc: "2.0",
302                                    id: request_id,
303                                    value: CspResult::Ok(Some(response)),
304                                })
305                                .unwrap_or_default(),
306                            )
307                            .ok();
308                    })
309                    .detach();
310                } else {
311                    Self::send_err(
312                        request_id,
313                        format!("Tool not found: {}", params.name),
314                        &outgoing_tx,
315                    );
316                }
317            }
318            Err(err) => {
319                Self::send_err(request_id, err.to_string(), &outgoing_tx);
320            }
321        }
322    }
323
324    fn send_err(
325        request_id: RequestId,
326        message: impl Into<String>,
327        outgoing_tx: &UnboundedSender<String>,
328    ) {
329        outgoing_tx
330            .unbounded_send(
331                serde_json::to_string(&Response::<()> {
332                    jsonrpc: "2.0",
333                    id: request_id,
334                    value: CspResult::Error(Some(crate::client::Error {
335                        message: message.into(),
336                        code: -32601,
337                    })),
338                })
339                .unwrap(),
340            )
341            .ok();
342    }
343
344    async fn handle_io(
345        mut outgoing_rx: UnboundedReceiver<String>,
346        incoming_tx: UnboundedSender<RawRequest>,
347        mut outgoing_bytes: impl Unpin + AsyncWrite,
348        incoming_bytes: impl Unpin + AsyncRead,
349    ) -> Result<()> {
350        let mut output_reader = BufReader::new(incoming_bytes);
351        let mut incoming_line = String::new();
352        loop {
353            select_biased! {
354                message = outgoing_rx.next().fuse() => {
355                    if let Some(message) = message {
356                        log::trace!("send: {}", &message);
357                        outgoing_bytes.write_all(message.as_bytes()).await?;
358                        outgoing_bytes.write_all(&[b'\n']).await?;
359                    } else {
360                        break;
361                    }
362                }
363                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
364                    if bytes_read? == 0 {
365                        break
366                    }
367                    log::trace!("recv: {}", &incoming_line);
368                    match serde_json::from_str(&incoming_line) {
369                        Ok(message) => {
370                            incoming_tx.unbounded_send(message).log_err();
371                        }
372                        Err(error) => {
373                            outgoing_bytes.write_all(serde_json::to_string(&json!({
374                                "jsonrpc": "2.0",
375                                "error": json!({
376                                    "code": -32603,
377                                    "message": format!("Failed to parse: {error}"),
378                                }),
379                            }))?.as_bytes()).await?;
380                            outgoing_bytes.write_all(&[b'\n']).await?;
381                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
382                        }
383                    }
384                    incoming_line.clear();
385                }
386            }
387        }
388        Ok(())
389    }
390}
391
392pub trait McpServerTool {
393    type Input: DeserializeOwned + JsonSchema;
394    type Output: Serialize + JsonSchema;
395
396    const NAME: &'static str;
397
398    fn description(&self) -> &'static str;
399
400    fn annotations(&self) -> ToolAnnotations {
401        ToolAnnotations {
402            title: None,
403            read_only_hint: None,
404            destructive_hint: None,
405            idempotent_hint: None,
406            open_world_hint: None,
407        }
408    }
409
410    fn run(
411        &self,
412        input: Self::Input,
413        cx: &mut AsyncApp,
414    ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
415}
416
/// Result of a successful [`McpServerTool::run`].
pub struct ToolResponse<T> {
    /// Content blocks returned to the client as the call's `content`.
    pub content: Vec<ToolResponseContent>,
    /// Typed structured result. If it serializes to JSON `null`, no structured
    /// content is included in the `tools/call` response.
    pub structured_content: T,
}
421
/// An incoming JSON-RPC message whose `params` stay unparsed until the selected
/// handler deserializes them into its own type.
#[derive(Serialize, Deserialize)]
struct RawRequest {
    /// Request id; `None` for notifications, which the server skips without replying.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    /// JSON-RPC method name used for dispatch.
    method: String,
    /// Raw JSON params, if the message carried any.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}
430
/// A JSON-RPC response whose `result` is kept as raw JSON.
///
/// NOTE(review): not referenced by the code visible in this file; confirm it is still
/// needed.
#[derive(Serialize, Deserialize)]
struct RawResponse {
    jsonrpc: &'static str,
    id: RequestId,
    /// Error object, if any; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<crate::client::Error>,
    /// Unparsed result JSON, if any; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
}