listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use schemars::JsonSchema;
 13use serde::de::DeserializeOwned;
 14use serde_json::{json, value::RawValue};
 15use smol::stream::StreamExt;
 16use std::{
 17    cell::RefCell,
 18    path::{Path, PathBuf},
 19    rc::Rc,
 20};
 21use util::ResultExt;
 22
 23use crate::{
 24    client::{CspResult, RequestId, Response},
 25    types::{
 26        CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
 27        ToolResponseContent,
 28        requests::{CallTool, ListTools},
 29    },
 30};
 31
/// A server that accepts client connections on a Unix domain socket and answers their requests.
///
/// Tools are registered with `add_tool` and handlers for other request methods with
/// `handle_request`. Both registries are reference-counted and shared with the task that serves
/// connections, so registrations made after construction are seen by later requests.
pub struct McpServer {
    /// Filesystem path of the socket the server listens on (see `socket_path`).
    socket_path: PathBuf,
    /// Registered tools, keyed by tool name.
    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
    /// Handlers for non-tool request methods, keyed by method name.
    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
    /// Task running the accept loop. It is never read; it is only stored alongside the server.
    // NOTE(review): presumably dropping this task cancels the accept loop — confirm against gpui.
    _server_task: Task<()>,
}
 38
/// A tool together with the callback that executes it.
struct RegisteredTool {
    /// Tool description returned to clients when they list tools.
    tool: Tool,
    /// Callback invoked with the call's arguments when a client calls the tool.
    handler: ToolHandler,
}
 43
/// Callback that runs a registered tool: it receives the call's optional JSON arguments and
/// returns a task resolving to the tool's response (or an error).
type ToolHandler =
    Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
/// Callback for a registered request method: it receives the request id and the optional raw JSON
/// params, and returns a task resolving to the already-serialized response text.
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 47
 48impl McpServer {
 49    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 50        let task = cx.background_spawn(async move {
 51            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 52            let socket_path = temp_dir.path().join("mcp.sock");
 53            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 54
 55            anyhow::Ok((temp_dir, socket_path, listener))
 56        });
 57
 58        cx.spawn(async move |cx| {
 59            let (temp_dir, socket_path, listener) = task.await?;
 60            let tools = Rc::new(RefCell::new(HashMap::default()));
 61            let handlers = Rc::new(RefCell::new(HashMap::default()));
 62            let server_task = cx.spawn({
 63                let tools = tools.clone();
 64                let handlers = handlers.clone();
 65                async move |cx| {
 66                    while let Ok((stream, _)) = listener.accept().await {
 67                        Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
 68                    }
 69                    drop(temp_dir)
 70                }
 71            });
 72            Ok(Self {
 73                socket_path,
 74                _server_task: server_task,
 75                tools,
 76                handlers: handlers,
 77            })
 78        })
 79    }
 80
 81    pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
 82        let registered_tool = RegisteredTool {
 83            tool: Tool {
 84                name: T::NAME.into(),
 85                description: Some(tool.description().into()),
 86                input_schema: schemars::schema_for!(T::Input).into(),
 87                annotations: Some(tool.annotations()),
 88            },
 89            handler: Box::new({
 90                let tool = tool.clone();
 91                move |input_value, cx| {
 92                    let input = match input_value {
 93                        Some(input) => serde_json::from_value(input),
 94                        None => serde_json::from_value(serde_json::Value::Null),
 95                    };
 96
 97                    let tool = tool.clone();
 98                    match input {
 99                        Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
100                        Err(err) => Task::ready(Err(err.into())),
101                    }
102                }
103            }),
104        };
105
106        self.tools.borrow_mut().insert(T::NAME, registered_tool);
107    }
108
109    pub fn handle_request<R: Request>(
110        &mut self,
111        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
112    ) {
113        let f = Box::new(f);
114        self.handlers.borrow_mut().insert(
115            R::METHOD,
116            Box::new(move |req_id, opt_params, cx| {
117                let result = match opt_params {
118                    Some(params) => serde_json::from_str(params.get()),
119                    None => serde_json::from_value(serde_json::Value::Null),
120                };
121
122                let params: R::Params = match result {
123                    Ok(params) => params,
124                    Err(e) => {
125                        return Task::ready(
126                            serde_json::to_string(&Response::<R::Response> {
127                                jsonrpc: "2.0",
128                                id: req_id,
129                                value: CspResult::Error(Some(crate::client::Error {
130                                    message: format!("{e}"),
131                                    code: -32700,
132                                })),
133                            })
134                            .unwrap(),
135                        );
136                    }
137                };
138                let task = f(params, cx);
139                cx.background_spawn(async move {
140                    match task.await {
141                        Ok(result) => serde_json::to_string(&Response {
142                            jsonrpc: "2.0",
143                            id: req_id,
144                            value: CspResult::Ok(Some(result)),
145                        })
146                        .unwrap(),
147                        Err(e) => serde_json::to_string(&Response {
148                            jsonrpc: "2.0",
149                            id: req_id,
150                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
151                                message: format!("{e}"),
152                                code: -32603,
153                            })),
154                        })
155                        .unwrap(),
156                    }
157                })
158            }),
159        );
160    }
161
162    pub fn socket_path(&self) -> &Path {
163        &self.socket_path
164    }
165
166    fn serve_connection(
167        stream: UnixStream,
168        tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
169        handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
170        cx: &mut AsyncApp,
171    ) {
172        let (read, write) = smol::io::split(stream);
173        let (incoming_tx, mut incoming_rx) = unbounded();
174        let (outgoing_tx, outgoing_rx) = unbounded();
175
176        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
177            .detach();
178
179        cx.spawn(async move |cx| {
180            while let Some(request) = incoming_rx.next().await {
181                let Some(request_id) = request.id.clone() else {
182                    continue;
183                };
184
185                if request.method == CallTool::METHOD {
186                    Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
187                        .await;
188                } else if request.method == ListTools::METHOD {
189                    Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
190                } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
191                    let outgoing_tx = outgoing_tx.clone();
192
193                    if let Some(task) = cx
194                        .update(|cx| handler(request_id, request.params, cx))
195                        .log_err()
196                    {
197                        cx.spawn(async move |_| {
198                            let response = task.await;
199                            outgoing_tx.unbounded_send(response).ok();
200                        })
201                        .detach();
202                    }
203                } else {
204                    Self::send_err(
205                        request_id,
206                        format!("unhandled method {}", request.method),
207                        &outgoing_tx,
208                    );
209                }
210            }
211        })
212        .detach();
213    }
214
215    fn handle_list_tools(
216        request_id: RequestId,
217        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
218        outgoing_tx: &UnboundedSender<String>,
219    ) {
220        let response = ListToolsResponse {
221            tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
222            next_cursor: None,
223            meta: None,
224        };
225
226        outgoing_tx
227            .unbounded_send(
228                serde_json::to_string(&Response {
229                    jsonrpc: "2.0",
230                    id: request_id,
231                    value: CspResult::Ok(Some(response)),
232                })
233                .unwrap_or_default(),
234            )
235            .ok();
236    }
237
238    async fn handle_call_tool(
239        request_id: RequestId,
240        params: Option<Box<RawValue>>,
241        tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
242        outgoing_tx: &UnboundedSender<String>,
243        cx: &mut AsyncApp,
244    ) {
245        let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
246            Some(params) => serde_json::from_str(params.get()),
247            None => serde_json::from_value(serde_json::Value::Null),
248        };
249
250        match result {
251            Ok(params) => {
252                if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
253                    let outgoing_tx = outgoing_tx.clone();
254
255                    let task = (tool.handler)(params.arguments, cx);
256                    cx.spawn(async move |_| {
257                        let response = match task.await {
258                            Ok(result) => CallToolResponse {
259                                content: result.content,
260                                is_error: Some(false),
261                                meta: None,
262                                structured_content: result.structured_content,
263                            },
264                            Err(err) => CallToolResponse {
265                                content: vec![ToolResponseContent::Text {
266                                    text: err.to_string(),
267                                }],
268                                is_error: Some(true),
269                                meta: None,
270                                structured_content: None,
271                            },
272                        };
273
274                        outgoing_tx
275                            .unbounded_send(
276                                serde_json::to_string(&Response {
277                                    jsonrpc: "2.0",
278                                    id: request_id,
279                                    value: CspResult::Ok(Some(response)),
280                                })
281                                .unwrap_or_default(),
282                            )
283                            .ok();
284                    })
285                    .detach();
286                } else {
287                    Self::send_err(
288                        request_id,
289                        format!("Tool not found: {}", params.name),
290                        &outgoing_tx,
291                    );
292                }
293            }
294            Err(err) => {
295                Self::send_err(request_id, err.to_string(), &outgoing_tx);
296            }
297        }
298    }
299
300    fn send_err(
301        request_id: RequestId,
302        message: impl Into<String>,
303        outgoing_tx: &UnboundedSender<String>,
304    ) {
305        outgoing_tx
306            .unbounded_send(
307                serde_json::to_string(&Response::<()> {
308                    jsonrpc: "2.0",
309                    id: request_id,
310                    value: CspResult::Error(Some(crate::client::Error {
311                        message: message.into(),
312                        code: -32601,
313                    })),
314                })
315                .unwrap(),
316            )
317            .ok();
318    }
319
    /// Moves newline-delimited messages between the byte streams of a connection and the
    /// in-process channels.
    ///
    /// - Every string received on `outgoing_rx` is written to `outgoing_bytes` followed by `\n`.
    /// - Every line read from `incoming_bytes` is deserialized and forwarded on `incoming_tx`.
    ///   A line that fails to deserialize is answered directly with an error object and logged.
    ///
    /// The loop ends when `outgoing_rx` is closed or when reading returns zero bytes (end of
    /// input).
    ///
    /// # Errors
    ///
    /// Returns the first read or write error, or a failure to serialize the parse-error reply.
    async fn handle_io(
        mut outgoing_rx: UnboundedReceiver<String>,
        incoming_tx: UnboundedSender<RawRequest>,
        mut outgoing_bytes: impl Unpin + AsyncWrite,
        incoming_bytes: impl Unpin + AsyncRead,
    ) -> Result<()> {
        let mut output_reader = BufReader::new(incoming_bytes);
        let mut incoming_line = String::new();
        loop {
            // The outgoing branch is listed first, so with `select_biased!` it is preferred when
            // both branches are ready.
            // NOTE(review): when the outgoing branch wins, the in-flight `read_line` future is
            // dropped; whether bytes it had already consumed for a partial line are preserved
            // depends on `read_line`'s cancellation behaviour — TODO confirm.
            select_biased! {
                message = outgoing_rx.next().fuse() => {
                    if let Some(message) = message {
                        log::trace!("send: {}", &message);
                        outgoing_bytes.write_all(message.as_bytes()).await?;
                        outgoing_bytes.write_all(&[b'\n']).await?;
                    } else {
                        // All senders are gone: nothing more will ever be written.
                        break;
                    }
                }
                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
                    // Zero bytes read means the peer closed its end of the stream.
                    if bytes_read? == 0 {
                        break
                    }
                    log::trace!("recv: {}", &incoming_line);
                    match serde_json::from_str(&incoming_line) {
                        Ok(message) => {
                            incoming_tx.unbounded_send(message).log_err();
                        }
                        Err(error) => {
                            // The reply is written straight to the stream rather than through the
                            // outgoing channel, and carries no `id` member.
                            // NOTE(review): JSON-RPC 2.0 defines -32700 for parse errors and asks
                            // for `"id": null` in this situation — confirm whether -32603 is
                            // intentional.
                            outgoing_bytes.write_all(serde_json::to_string(&json!({
                                "jsonrpc": "2.0",
                                "error": json!({
                                    "code": -32603,
                                    "message": format!("Failed to parse: {error}"),
                                }),
                            }))?.as_bytes()).await?;
                            outgoing_bytes.write_all(&[b'\n']).await?;
                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
                        }
                    }
                    // `read_line` appends to the buffer, so it is emptied before the next line.
                    incoming_line.clear();
                }
            }
        }
        Ok(())
    }
366}
367
/// A tool that can be registered with `McpServer::add_tool`.
pub trait McpServerTool {
    /// Arguments accepted by the tool. They are deserialized from the JSON arguments of a call,
    /// and the JSON schema advertised for the tool is generated from this type.
    type Input: DeserializeOwned + JsonSchema;
    /// Name under which the tool is registered and looked up.
    const NAME: &'static str;

    /// Description advertised for the tool.
    fn description(&self) -> &'static str;

    /// Annotations advertised for the tool. The default leaves every field unset.
    fn annotations(&self) -> ToolAnnotations {
        ToolAnnotations {
            title: None,
            read_only_hint: None,
            destructive_hint: None,
            idempotent_hint: None,
            open_world_hint: None,
        }
    }

    /// Runs the tool with already-deserialized `input`.
    ///
    /// Returning `Err` makes the server report the error text to the client as the tool's
    /// content, flagged as an error.
    fn run(
        &self,
        input: Self::Input,
        cx: &mut AsyncApp,
    ) -> impl Future<Output = Result<ToolResponse>>;
}
390
/// Result of a successful tool run; its fields are copied into the response sent to the client.
pub struct ToolResponse {
    /// Content items returned to the client.
    pub content: Vec<ToolResponseContent>,
    /// Optional structured JSON value returned alongside `content`.
    pub structured_content: Option<serde_json::Value>,
}
395
/// An incoming message as deserialized from one line of input, before its params are interpreted.
#[derive(Serialize, Deserialize)]
struct RawRequest {
    /// Request id; messages without one are skipped by the dispatch loop and get no response.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    /// Name of the requested method, used to pick the handler.
    method: String,
    /// Params kept as raw JSON so each handler can deserialize them into its own type.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}
404
/// A response envelope whose result is kept as raw JSON.
// NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still used.
#[derive(Serialize, Deserialize)]
struct RawResponse {
    /// Protocol version string.
    jsonrpc: &'static str,
    /// Id of the request being answered.
    id: RequestId,
    /// Error payload; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<crate::client::Error>,
    /// Result payload as raw JSON; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
}