1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use schemars::JsonSchema;
13use serde::de::DeserializeOwned;
14use serde_json::{json, value::RawValue};
15use smol::stream::StreamExt;
16use std::{
17 cell::RefCell,
18 path::{Path, PathBuf},
19 rc::Rc,
20};
21use util::ResultExt;
22
23use crate::{
24 client::{CspResult, RequestId, Response},
25 types::{
26 CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
27 ToolResponseContent,
28 requests::{CallTool, ListTools},
29 },
30};
31
32pub struct McpServer {
33 socket_path: PathBuf,
34 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
35 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
36 _server_task: Task<()>,
37}
38
39struct RegisteredTool {
40 tool: Tool,
41 handler: ToolHandler,
42}
43
44type ToolHandler = Box<
45 dyn Fn(
46 Option<serde_json::Value>,
47 &mut AsyncApp,
48 ) -> Task<Result<ToolResponse<serde_json::Value>>>,
49>;
50type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
51
52impl McpServer {
53 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
54 let task = cx.background_spawn(async move {
55 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
56 let socket_path = temp_dir.path().join("mcp.sock");
57 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
58
59 anyhow::Ok((temp_dir, socket_path, listener))
60 });
61
62 cx.spawn(async move |cx| {
63 let (temp_dir, socket_path, listener) = task.await?;
64 let tools = Rc::new(RefCell::new(HashMap::default()));
65 let handlers = Rc::new(RefCell::new(HashMap::default()));
66 let server_task = cx.spawn({
67 let tools = tools.clone();
68 let handlers = handlers.clone();
69 async move |cx| {
70 while let Ok((stream, _)) = listener.accept().await {
71 Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
72 }
73 drop(temp_dir)
74 }
75 });
76 Ok(Self {
77 socket_path,
78 _server_task: server_task,
79 tools,
80 handlers: handlers,
81 })
82 })
83 }
84
85 pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
86 let output_schema = schemars::schema_for!(T::Output);
87 let unit_schema = schemars::schema_for!(());
88
89 let registered_tool = RegisteredTool {
90 tool: Tool {
91 name: T::NAME.into(),
92 description: Some(tool.description().into()),
93 input_schema: schemars::schema_for!(T::Input).into(),
94 output_schema: if output_schema == unit_schema {
95 None
96 } else {
97 Some(output_schema.into())
98 },
99 annotations: Some(tool.annotations()),
100 },
101 handler: Box::new({
102 let tool = tool.clone();
103 move |input_value, cx| {
104 let input = match input_value {
105 Some(input) => serde_json::from_value(input),
106 None => serde_json::from_value(serde_json::Value::Null),
107 };
108
109 let tool = tool.clone();
110 match input {
111 Ok(input) => cx.spawn(async move |cx| {
112 let output = tool.run(input, cx).await?;
113
114 Ok(ToolResponse {
115 content: output.content,
116 structured_content: serde_json::to_value(output.structured_content)
117 .unwrap_or_default(),
118 })
119 }),
120 Err(err) => Task::ready(Err(err.into())),
121 }
122 }
123 }),
124 };
125
126 self.tools.borrow_mut().insert(T::NAME, registered_tool);
127 }
128
129 pub fn handle_request<R: Request>(
130 &mut self,
131 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
132 ) {
133 let f = Box::new(f);
134 self.handlers.borrow_mut().insert(
135 R::METHOD,
136 Box::new(move |req_id, opt_params, cx| {
137 let result = match opt_params {
138 Some(params) => serde_json::from_str(params.get()),
139 None => serde_json::from_value(serde_json::Value::Null),
140 };
141
142 let params: R::Params = match result {
143 Ok(params) => params,
144 Err(e) => {
145 return Task::ready(
146 serde_json::to_string(&Response::<R::Response> {
147 jsonrpc: "2.0",
148 id: req_id,
149 value: CspResult::Error(Some(crate::client::Error {
150 message: format!("{e}"),
151 code: -32700,
152 })),
153 })
154 .unwrap(),
155 );
156 }
157 };
158 let task = f(params, cx);
159 cx.background_spawn(async move {
160 match task.await {
161 Ok(result) => serde_json::to_string(&Response {
162 jsonrpc: "2.0",
163 id: req_id,
164 value: CspResult::Ok(Some(result)),
165 })
166 .unwrap(),
167 Err(e) => serde_json::to_string(&Response {
168 jsonrpc: "2.0",
169 id: req_id,
170 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
171 message: format!("{e}"),
172 code: -32603,
173 })),
174 })
175 .unwrap(),
176 }
177 })
178 }),
179 );
180 }
181
182 pub fn socket_path(&self) -> &Path {
183 &self.socket_path
184 }
185
186 fn serve_connection(
187 stream: UnixStream,
188 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
189 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
190 cx: &mut AsyncApp,
191 ) {
192 let (read, write) = smol::io::split(stream);
193 let (incoming_tx, mut incoming_rx) = unbounded();
194 let (outgoing_tx, outgoing_rx) = unbounded();
195
196 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
197 .detach();
198
199 cx.spawn(async move |cx| {
200 while let Some(request) = incoming_rx.next().await {
201 let Some(request_id) = request.id.clone() else {
202 continue;
203 };
204
205 if request.method == CallTool::METHOD {
206 Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
207 .await;
208 } else if request.method == ListTools::METHOD {
209 Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
210 } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
211 let outgoing_tx = outgoing_tx.clone();
212
213 if let Some(task) = cx
214 .update(|cx| handler(request_id, request.params, cx))
215 .log_err()
216 {
217 cx.spawn(async move |_| {
218 let response = task.await;
219 outgoing_tx.unbounded_send(response).ok();
220 })
221 .detach();
222 }
223 } else {
224 Self::send_err(
225 request_id,
226 format!("unhandled method {}", request.method),
227 &outgoing_tx,
228 );
229 }
230 }
231 })
232 .detach();
233 }
234
235 fn handle_list_tools(
236 request_id: RequestId,
237 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
238 outgoing_tx: &UnboundedSender<String>,
239 ) {
240 let response = ListToolsResponse {
241 tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
242 next_cursor: None,
243 meta: None,
244 };
245
246 outgoing_tx
247 .unbounded_send(
248 serde_json::to_string(&Response {
249 jsonrpc: "2.0",
250 id: request_id,
251 value: CspResult::Ok(Some(response)),
252 })
253 .unwrap_or_default(),
254 )
255 .ok();
256 }
257
258 async fn handle_call_tool(
259 request_id: RequestId,
260 params: Option<Box<RawValue>>,
261 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
262 outgoing_tx: &UnboundedSender<String>,
263 cx: &mut AsyncApp,
264 ) {
265 let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
266 Some(params) => serde_json::from_str(params.get()),
267 None => serde_json::from_value(serde_json::Value::Null),
268 };
269
270 match result {
271 Ok(params) => {
272 if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
273 let outgoing_tx = outgoing_tx.clone();
274
275 let task = (tool.handler)(params.arguments, cx);
276 cx.spawn(async move |_| {
277 let response = match task.await {
278 Ok(result) => CallToolResponse {
279 content: result.content,
280 is_error: Some(false),
281 meta: None,
282 structured_content: if result.structured_content.is_null() {
283 None
284 } else {
285 Some(result.structured_content)
286 },
287 },
288 Err(err) => CallToolResponse {
289 content: vec![ToolResponseContent::Text {
290 text: err.to_string(),
291 }],
292 is_error: Some(true),
293 meta: None,
294 structured_content: None,
295 },
296 };
297
298 outgoing_tx
299 .unbounded_send(
300 serde_json::to_string(&Response {
301 jsonrpc: "2.0",
302 id: request_id,
303 value: CspResult::Ok(Some(response)),
304 })
305 .unwrap_or_default(),
306 )
307 .ok();
308 })
309 .detach();
310 } else {
311 Self::send_err(
312 request_id,
313 format!("Tool not found: {}", params.name),
314 &outgoing_tx,
315 );
316 }
317 }
318 Err(err) => {
319 Self::send_err(request_id, err.to_string(), &outgoing_tx);
320 }
321 }
322 }
323
324 fn send_err(
325 request_id: RequestId,
326 message: impl Into<String>,
327 outgoing_tx: &UnboundedSender<String>,
328 ) {
329 outgoing_tx
330 .unbounded_send(
331 serde_json::to_string(&Response::<()> {
332 jsonrpc: "2.0",
333 id: request_id,
334 value: CspResult::Error(Some(crate::client::Error {
335 message: message.into(),
336 code: -32601,
337 })),
338 })
339 .unwrap(),
340 )
341 .ok();
342 }
343
344 async fn handle_io(
345 mut outgoing_rx: UnboundedReceiver<String>,
346 incoming_tx: UnboundedSender<RawRequest>,
347 mut outgoing_bytes: impl Unpin + AsyncWrite,
348 incoming_bytes: impl Unpin + AsyncRead,
349 ) -> Result<()> {
350 let mut output_reader = BufReader::new(incoming_bytes);
351 let mut incoming_line = String::new();
352 loop {
353 select_biased! {
354 message = outgoing_rx.next().fuse() => {
355 if let Some(message) = message {
356 log::trace!("send: {}", &message);
357 outgoing_bytes.write_all(message.as_bytes()).await?;
358 outgoing_bytes.write_all(&[b'\n']).await?;
359 } else {
360 break;
361 }
362 }
363 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
364 if bytes_read? == 0 {
365 break
366 }
367 log::trace!("recv: {}", &incoming_line);
368 match serde_json::from_str(&incoming_line) {
369 Ok(message) => {
370 incoming_tx.unbounded_send(message).log_err();
371 }
372 Err(error) => {
373 outgoing_bytes.write_all(serde_json::to_string(&json!({
374 "jsonrpc": "2.0",
375 "error": json!({
376 "code": -32603,
377 "message": format!("Failed to parse: {error}"),
378 }),
379 }))?.as_bytes()).await?;
380 outgoing_bytes.write_all(&[b'\n']).await?;
381 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
382 }
383 }
384 incoming_line.clear();
385 }
386 }
387 }
388 Ok(())
389 }
390}
391
392pub trait McpServerTool {
393 type Input: DeserializeOwned + JsonSchema;
394 type Output: Serialize + JsonSchema;
395
396 const NAME: &'static str;
397
398 fn description(&self) -> &'static str;
399
400 fn annotations(&self) -> ToolAnnotations {
401 ToolAnnotations {
402 title: None,
403 read_only_hint: None,
404 destructive_hint: None,
405 idempotent_hint: None,
406 open_world_hint: None,
407 }
408 }
409
410 fn run(
411 &self,
412 input: Self::Input,
413 cx: &mut AsyncApp,
414 ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
415}
416
417pub struct ToolResponse<T> {
418 pub content: Vec<ToolResponseContent>,
419 pub structured_content: T,
420}
421
422#[derive(Serialize, Deserialize)]
423struct RawRequest {
424 #[serde(skip_serializing_if = "Option::is_none")]
425 id: Option<RequestId>,
426 method: String,
427 #[serde(skip_serializing_if = "Option::is_none")]
428 params: Option<Box<serde_json::value::RawValue>>,
429}
430
431#[derive(Serialize, Deserialize)]
432struct RawResponse {
433 jsonrpc: &'static str,
434 id: RequestId,
435 #[serde(skip_serializing_if = "Option::is_none")]
436 error: Option<crate::client::Error>,
437 #[serde(skip_serializing_if = "Option::is_none")]
438 result: Option<Box<serde_json::value::RawValue>>,
439}