listener.rs

  1use ::serde::{Deserialize, Serialize};
  2use anyhow::{Context as _, Result};
  3use collections::HashMap;
  4use futures::{
  5    AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
  6    channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
  7    io::BufReader,
  8    select_biased,
  9};
 10use gpui::{App, AppContext, AsyncApp, Task};
 11use net::async_net::{UnixListener, UnixStream};
 12use serde_json::{json, value::RawValue};
 13use smol::stream::StreamExt;
 14use std::{
 15    cell::RefCell,
 16    path::{Path, PathBuf},
 17    rc::Rc,
 18};
 19use util::ResultExt;
 20
 21use crate::{
 22    client::{CspResult, RequestId, Response},
 23    types::Request,
 24};
 25
 26pub struct McpServer {
 27    socket_path: PathBuf,
 28    handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
 29    _server_task: Task<()>,
 30}
 31
 32type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 33
 34impl McpServer {
 35    pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
 36        let task = cx.background_spawn(async move {
 37            let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
 38            let socket_path = temp_dir.path().join("mcp.sock");
 39            let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
 40
 41            anyhow::Ok((temp_dir, socket_path, listener))
 42        });
 43
 44        cx.spawn(async move |cx| {
 45            let (temp_dir, socket_path, listener) = task.await?;
 46            let handlers = Rc::new(RefCell::new(HashMap::default()));
 47            let server_task = cx.spawn({
 48                let handlers = handlers.clone();
 49                async move |cx| {
 50                    while let Ok((stream, _)) = listener.accept().await {
 51                        Self::serve_connection(stream, handlers.clone(), cx);
 52                    }
 53                    drop(temp_dir)
 54                }
 55            });
 56            Ok(Self {
 57                socket_path,
 58                _server_task: server_task,
 59                handlers: handlers.clone(),
 60            })
 61        })
 62    }
 63
 64    pub fn handle_request<R: Request>(
 65        &mut self,
 66        f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
 67    ) {
 68        let f = Box::new(f);
 69        self.handlers.borrow_mut().insert(
 70            R::METHOD,
 71            Box::new(move |req_id, opt_params, cx| {
 72                let result = match opt_params {
 73                    Some(params) => serde_json::from_str(params.get()),
 74                    None => serde_json::from_value(serde_json::Value::Null),
 75                };
 76
 77                let params: R::Params = match result {
 78                    Ok(params) => params,
 79                    Err(e) => {
 80                        return Task::ready(
 81                            serde_json::to_string(&Response::<R::Response> {
 82                                jsonrpc: "2.0",
 83                                id: req_id,
 84                                value: CspResult::Error(Some(crate::client::Error {
 85                                    message: format!("{e}"),
 86                                    code: -32700,
 87                                })),
 88                            })
 89                            .unwrap(),
 90                        );
 91                    }
 92                };
 93                let task = f(params, cx);
 94                cx.background_spawn(async move {
 95                    match task.await {
 96                        Ok(result) => serde_json::to_string(&Response {
 97                            jsonrpc: "2.0",
 98                            id: req_id,
 99                            value: CspResult::Ok(Some(result)),
100                        })
101                        .unwrap(),
102                        Err(e) => serde_json::to_string(&Response {
103                            jsonrpc: "2.0",
104                            id: req_id,
105                            value: CspResult::Error::<R::Response>(Some(crate::client::Error {
106                                message: format!("{e}"),
107                                code: -32603,
108                            })),
109                        })
110                        .unwrap(),
111                    }
112                })
113            }),
114        );
115    }
116
117    pub fn socket_path(&self) -> &Path {
118        &self.socket_path
119    }
120
121    fn serve_connection(
122        stream: UnixStream,
123        handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
124        cx: &mut AsyncApp,
125    ) {
126        let (read, write) = smol::io::split(stream);
127        let (incoming_tx, mut incoming_rx) = unbounded();
128        let (outgoing_tx, outgoing_rx) = unbounded();
129
130        cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
131            .detach();
132
133        cx.spawn(async move |cx| {
134            while let Some(request) = incoming_rx.next().await {
135                let Some(request_id) = request.id.clone() else {
136                    continue;
137                };
138                if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
139                    let outgoing_tx = outgoing_tx.clone();
140
141                    if let Some(task) = cx
142                        .update(|cx| handler(request_id, request.params, cx))
143                        .log_err()
144                    {
145                        cx.spawn(async move |_| {
146                            let response = task.await;
147                            outgoing_tx.unbounded_send(response).ok();
148                        })
149                        .detach();
150                    }
151                } else {
152                    outgoing_tx
153                        .unbounded_send(
154                            serde_json::to_string(&Response::<()> {
155                                jsonrpc: "2.0",
156                                id: request.id.unwrap(),
157                                value: CspResult::Error(Some(crate::client::Error {
158                                    message: format!("unhandled method {}", request.method),
159                                    code: -32601,
160                                })),
161                            })
162                            .unwrap(),
163                        )
164                        .ok();
165                }
166            }
167        })
168        .detach();
169    }
170
171    async fn handle_io(
172        mut outgoing_rx: UnboundedReceiver<String>,
173        incoming_tx: UnboundedSender<RawRequest>,
174        mut outgoing_bytes: impl Unpin + AsyncWrite,
175        incoming_bytes: impl Unpin + AsyncRead,
176    ) -> Result<()> {
177        let mut output_reader = BufReader::new(incoming_bytes);
178        let mut incoming_line = String::new();
179        loop {
180            select_biased! {
181                message = outgoing_rx.next().fuse() => {
182                    if let Some(message) = message {
183                        log::trace!("send: {}", &message);
184                        outgoing_bytes.write_all(message.as_bytes()).await?;
185                        outgoing_bytes.write_all(&[b'\n']).await?;
186                    } else {
187                        break;
188                    }
189                }
190                bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
191                    if bytes_read? == 0 {
192                        break
193                    }
194                    log::trace!("recv: {}", &incoming_line);
195                    match serde_json::from_str(&incoming_line) {
196                        Ok(message) => {
197                            incoming_tx.unbounded_send(message).log_err();
198                        }
199                        Err(error) => {
200                            outgoing_bytes.write_all(serde_json::to_string(&json!({
201                                "jsonrpc": "2.0",
202                                "error": json!({
203                                    "code": -32603,
204                                    "message": format!("Failed to parse: {error}"),
205                                }),
206                            }))?.as_bytes()).await?;
207                            outgoing_bytes.write_all(&[b'\n']).await?;
208                            log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
209                        }
210                    }
211                    incoming_line.clear();
212                }
213            }
214        }
215        Ok(())
216    }
217}
218
219#[derive(Serialize, Deserialize)]
220struct RawRequest {
221    #[serde(skip_serializing_if = "Option::is_none")]
222    id: Option<RequestId>,
223    method: String,
224    #[serde(skip_serializing_if = "Option::is_none")]
225    params: Option<Box<serde_json::value::RawValue>>,
226}
227
228#[derive(Serialize, Deserialize)]
229struct RawResponse {
230    jsonrpc: &'static str,
231    id: RequestId,
232    #[serde(skip_serializing_if = "Option::is_none")]
233    error: Option<crate::client::Error>,
234    #[serde(skip_serializing_if = "Option::is_none")]
235    result: Option<Box<serde_json::value::RawValue>>,
236}