1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use serde_json::{json, value::RawValue};
13use smol::stream::StreamExt;
14use std::{
15 cell::RefCell,
16 path::{Path, PathBuf},
17 rc::Rc,
18};
19use util::ResultExt;
20
21use crate::{
22 client::{CspResult, RequestId, Response},
23 types::Request,
24};
25
/// A JSON-RPC 2.0 server speaking newline-delimited JSON over a Unix domain
/// socket that lives in a temporary directory.
///
/// Create it with `McpServer::new`, register per-method handlers with
/// `McpServer::handle_request`, and hand `McpServer::socket_path` to clients.
pub struct McpServer {
    /// Path of the Unix socket that clients connect to.
    socket_path: PathBuf,
    /// Registered handlers keyed by JSON-RPC method name. The same map is shared
    /// with every connection, so handlers registered later are visible to
    /// connections that are already open.
    handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
    /// The accept loop; held so it lives as long as the server does.
    _server_task: Task<()>,
}
31
/// A type-erased request handler. Given the request id, the raw (not yet
/// deserialized) params and the app, it returns a task that resolves to the
/// complete, already-serialized JSON-RPC response line (without the trailing
/// newline).
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
33
34impl McpServer {
35 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
36 let task = cx.background_spawn(async move {
37 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
38 let socket_path = temp_dir.path().join("mcp.sock");
39 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
40
41 anyhow::Ok((temp_dir, socket_path, listener))
42 });
43
44 cx.spawn(async move |cx| {
45 let (temp_dir, socket_path, listener) = task.await?;
46 let handlers = Rc::new(RefCell::new(HashMap::default()));
47 let server_task = cx.spawn({
48 let handlers = handlers.clone();
49 async move |cx| {
50 while let Ok((stream, _)) = listener.accept().await {
51 Self::serve_connection(stream, handlers.clone(), cx);
52 }
53 drop(temp_dir)
54 }
55 });
56 Ok(Self {
57 socket_path,
58 _server_task: server_task,
59 handlers: handlers.clone(),
60 })
61 })
62 }
63
64 pub fn handle_request<R: Request>(
65 &mut self,
66 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
67 ) {
68 let f = Box::new(f);
69 self.handlers.borrow_mut().insert(
70 R::METHOD,
71 Box::new(move |req_id, opt_params, cx| {
72 let result = match opt_params {
73 Some(params) => serde_json::from_str(params.get()),
74 None => serde_json::from_value(serde_json::Value::Null),
75 };
76
77 let params: R::Params = match result {
78 Ok(params) => params,
79 Err(e) => {
80 return Task::ready(
81 serde_json::to_string(&Response::<R::Response> {
82 jsonrpc: "2.0",
83 id: req_id,
84 value: CspResult::Error(Some(crate::client::Error {
85 message: format!("{e}"),
86 code: -32700,
87 })),
88 })
89 .unwrap(),
90 );
91 }
92 };
93 let task = f(params, cx);
94 cx.background_spawn(async move {
95 match task.await {
96 Ok(result) => serde_json::to_string(&Response {
97 jsonrpc: "2.0",
98 id: req_id,
99 value: CspResult::Ok(Some(result)),
100 })
101 .unwrap(),
102 Err(e) => serde_json::to_string(&Response {
103 jsonrpc: "2.0",
104 id: req_id,
105 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
106 message: format!("{e}"),
107 code: -32603,
108 })),
109 })
110 .unwrap(),
111 }
112 })
113 }),
114 );
115 }
116
117 pub fn socket_path(&self) -> &Path {
118 &self.socket_path
119 }
120
121 fn serve_connection(
122 stream: UnixStream,
123 handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
124 cx: &mut AsyncApp,
125 ) {
126 let (read, write) = smol::io::split(stream);
127 let (incoming_tx, mut incoming_rx) = unbounded();
128 let (outgoing_tx, outgoing_rx) = unbounded();
129
130 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
131 .detach();
132
133 cx.spawn(async move |cx| {
134 while let Some(request) = incoming_rx.next().await {
135 let Some(request_id) = request.id.clone() else {
136 continue;
137 };
138 if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
139 let outgoing_tx = outgoing_tx.clone();
140
141 if let Some(task) = cx
142 .update(|cx| handler(request_id, request.params, cx))
143 .log_err()
144 {
145 cx.spawn(async move |_| {
146 let response = task.await;
147 outgoing_tx.unbounded_send(response).ok();
148 })
149 .detach();
150 }
151 } else {
152 outgoing_tx
153 .unbounded_send(
154 serde_json::to_string(&Response::<()> {
155 jsonrpc: "2.0",
156 id: request.id.unwrap(),
157 value: CspResult::Error(Some(crate::client::Error {
158 message: format!("unhandled method {}", request.method),
159 code: -32601,
160 })),
161 })
162 .unwrap(),
163 )
164 .ok();
165 }
166 }
167 })
168 .detach();
169 }
170
171 async fn handle_io(
172 mut outgoing_rx: UnboundedReceiver<String>,
173 incoming_tx: UnboundedSender<RawRequest>,
174 mut outgoing_bytes: impl Unpin + AsyncWrite,
175 incoming_bytes: impl Unpin + AsyncRead,
176 ) -> Result<()> {
177 let mut output_reader = BufReader::new(incoming_bytes);
178 let mut incoming_line = String::new();
179 loop {
180 select_biased! {
181 message = outgoing_rx.next().fuse() => {
182 if let Some(message) = message {
183 log::trace!("send: {}", &message);
184 outgoing_bytes.write_all(message.as_bytes()).await?;
185 outgoing_bytes.write_all(&[b'\n']).await?;
186 } else {
187 break;
188 }
189 }
190 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
191 if bytes_read? == 0 {
192 break
193 }
194 log::trace!("recv: {}", &incoming_line);
195 match serde_json::from_str(&incoming_line) {
196 Ok(message) => {
197 incoming_tx.unbounded_send(message).log_err();
198 }
199 Err(error) => {
200 outgoing_bytes.write_all(serde_json::to_string(&json!({
201 "jsonrpc": "2.0",
202 "error": json!({
203 "code": -32603,
204 "message": format!("Failed to parse: {error}"),
205 }),
206 }))?.as_bytes()).await?;
207 outgoing_bytes.write_all(&[b'\n']).await?;
208 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
209 }
210 }
211 incoming_line.clear();
212 }
213 }
214 }
215 Ok(())
216 }
217}
218
/// An incoming JSON-RPC message as read off the wire. `params` is kept as
/// unparsed JSON so that each registered handler can deserialize it into its own
/// parameter type.
///
/// A message without an `id` is treated by the server as a notification and is
/// not answered. No `jsonrpc` field is declared, so serde's default behavior
/// ignores that member (and any other unknown member) on input.
#[derive(Serialize, Deserialize)]
struct RawRequest {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}
227
/// A JSON-RPC response whose `result` is kept as unparsed JSON. At most one of
/// `error` / `result` is expected to be set; absent ones are omitted when
/// serializing.
///
/// NOTE(review): not referenced anywhere in the code visible here, since responses
/// are built from `crate::client::Response`. Confirm whether it is still needed.
#[derive(Serialize, Deserialize)]
struct RawResponse {
    jsonrpc: &'static str,
    id: RequestId,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<crate::client::Error>,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
}