1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use schemars::JsonSchema;
13use serde::de::DeserializeOwned;
14use serde_json::{json, value::RawValue};
15use smol::stream::StreamExt;
16use std::{
17 cell::RefCell,
18 path::{Path, PathBuf},
19 rc::Rc,
20};
21use util::ResultExt;
22
23use crate::{
24 client::{CspResult, RequestId, Response},
25 types::{
26 CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
27 ToolResponseContent,
28 requests::{CallTool, ListTools},
29 },
30};
31
/// An MCP server that accepts connections on a Unix domain socket and serves the tools and
/// request handlers registered on it to every connected client.
pub struct McpServer {
    /// Path of the listening socket; it lives inside a temporary directory owned by the
    /// accept-loop task.
    socket_path: PathBuf,
    /// Registered tools keyed by [`McpServerTool::NAME`]; shared with every connection.
    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
    /// Custom JSON-RPC request handlers keyed by method name; shared with every connection.
    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
    /// The accept loop. It is held here (not detached) so its lifetime is tied to the server.
    _server_task: Task<()>,
}
38
/// A tool's advertised metadata together with the type-erased closure that runs it.
struct RegisteredTool {
    /// Metadata (name, description, schemas, annotations) returned in `ListTools` responses.
    tool: Tool,
    /// Deserializes call arguments and runs the tool; built by `McpServer::add_tool`.
    handler: ToolHandler,
}
43
/// Type-erased tool entry point: receives the tool call's `arguments` (if any) and returns
/// a task yielding the tool's response, with its structured content already converted to
/// JSON.
type ToolHandler = Box<
    dyn Fn(
        Option<serde_json::Value>,
        &mut AsyncApp,
    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
>;
/// Type-erased JSON-RPC request handler: receives the request id and raw params and returns
/// a task yielding the complete serialized JSON-RPC response message.
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
51
52impl McpServer {
53 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
54 let task = cx.background_spawn(async move {
55 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
56 let socket_path = temp_dir.path().join("mcp.sock");
57 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
58
59 anyhow::Ok((temp_dir, socket_path, listener))
60 });
61
62 cx.spawn(async move |cx| {
63 let (temp_dir, socket_path, listener) = task.await?;
64 let tools = Rc::new(RefCell::new(HashMap::default()));
65 let handlers = Rc::new(RefCell::new(HashMap::default()));
66 let server_task = cx.spawn({
67 let tools = tools.clone();
68 let handlers = handlers.clone();
69 async move |cx| {
70 while let Ok((stream, _)) = listener.accept().await {
71 Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
72 }
73 drop(temp_dir)
74 }
75 });
76 Ok(Self {
77 socket_path,
78 _server_task: server_task,
79 tools,
80 handlers,
81 })
82 })
83 }
84
85 pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
86 let mut settings = schemars::generate::SchemaSettings::draft07();
87 settings.inline_subschemas = true;
88 let mut generator = settings.into_generator();
89
90 let output_schema = generator.root_schema_for::<T::Output>();
91 let unit_schema = generator.root_schema_for::<T::Output>();
92
93 let registered_tool = RegisteredTool {
94 tool: Tool {
95 name: T::NAME.into(),
96 description: Some(tool.description().into()),
97 input_schema: generator.root_schema_for::<T::Input>().into(),
98 output_schema: if output_schema == unit_schema {
99 None
100 } else {
101 Some(output_schema.into())
102 },
103 annotations: Some(tool.annotations()),
104 },
105 handler: Box::new({
106 let tool = tool.clone();
107 move |input_value, cx| {
108 let input = match input_value {
109 Some(input) => serde_json::from_value(input),
110 None => serde_json::from_value(serde_json::Value::Null),
111 };
112
113 let tool = tool.clone();
114 match input {
115 Ok(input) => cx.spawn(async move |cx| {
116 let output = tool.run(input, cx).await?;
117
118 Ok(ToolResponse {
119 content: output.content,
120 structured_content: serde_json::to_value(output.structured_content)
121 .unwrap_or_default(),
122 })
123 }),
124 Err(err) => Task::ready(Err(err.into())),
125 }
126 }
127 }),
128 };
129
130 self.tools.borrow_mut().insert(T::NAME, registered_tool);
131 }
132
133 pub fn handle_request<R: Request>(
134 &mut self,
135 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
136 ) {
137 let f = Box::new(f);
138 self.handlers.borrow_mut().insert(
139 R::METHOD,
140 Box::new(move |req_id, opt_params, cx| {
141 let result = match opt_params {
142 Some(params) => serde_json::from_str(params.get()),
143 None => serde_json::from_value(serde_json::Value::Null),
144 };
145
146 let params: R::Params = match result {
147 Ok(params) => params,
148 Err(e) => {
149 return Task::ready(
150 serde_json::to_string(&Response::<R::Response> {
151 jsonrpc: "2.0",
152 id: req_id,
153 value: CspResult::Error(Some(crate::client::Error {
154 message: format!("{e}"),
155 code: -32700,
156 })),
157 })
158 .unwrap(),
159 );
160 }
161 };
162 let task = f(params, cx);
163 cx.background_spawn(async move {
164 match task.await {
165 Ok(result) => serde_json::to_string(&Response {
166 jsonrpc: "2.0",
167 id: req_id,
168 value: CspResult::Ok(Some(result)),
169 })
170 .unwrap(),
171 Err(e) => serde_json::to_string(&Response {
172 jsonrpc: "2.0",
173 id: req_id,
174 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
175 message: format!("{e}"),
176 code: -32603,
177 })),
178 })
179 .unwrap(),
180 }
181 })
182 }),
183 );
184 }
185
186 pub fn socket_path(&self) -> &Path {
187 &self.socket_path
188 }
189
190 fn serve_connection(
191 stream: UnixStream,
192 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
193 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
194 cx: &mut AsyncApp,
195 ) {
196 let (read, write) = smol::io::split(stream);
197 let (incoming_tx, mut incoming_rx) = unbounded();
198 let (outgoing_tx, outgoing_rx) = unbounded();
199
200 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
201 .detach();
202
203 cx.spawn(async move |cx| {
204 while let Some(request) = incoming_rx.next().await {
205 let Some(request_id) = request.id.clone() else {
206 continue;
207 };
208
209 if request.method == CallTool::METHOD {
210 Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
211 .await;
212 } else if request.method == ListTools::METHOD {
213 Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
214 } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
215 let outgoing_tx = outgoing_tx.clone();
216
217 if let Some(task) = cx
218 .update(|cx| handler(request_id, request.params, cx))
219 .log_err()
220 {
221 cx.spawn(async move |_| {
222 let response = task.await;
223 outgoing_tx.unbounded_send(response).ok();
224 })
225 .detach();
226 }
227 } else {
228 Self::send_err(
229 request_id,
230 format!("unhandled method {}", request.method),
231 &outgoing_tx,
232 );
233 }
234 }
235 })
236 .detach();
237 }
238
239 fn handle_list_tools(
240 request_id: RequestId,
241 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
242 outgoing_tx: &UnboundedSender<String>,
243 ) {
244 let response = ListToolsResponse {
245 tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
246 next_cursor: None,
247 meta: None,
248 };
249
250 outgoing_tx
251 .unbounded_send(
252 serde_json::to_string(&Response {
253 jsonrpc: "2.0",
254 id: request_id,
255 value: CspResult::Ok(Some(response)),
256 })
257 .unwrap_or_default(),
258 )
259 .ok();
260 }
261
262 async fn handle_call_tool(
263 request_id: RequestId,
264 params: Option<Box<RawValue>>,
265 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
266 outgoing_tx: &UnboundedSender<String>,
267 cx: &mut AsyncApp,
268 ) {
269 let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
270 Some(params) => serde_json::from_str(params.get()),
271 None => serde_json::from_value(serde_json::Value::Null),
272 };
273
274 match result {
275 Ok(params) => {
276 if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
277 let outgoing_tx = outgoing_tx.clone();
278
279 let task = (tool.handler)(params.arguments, cx);
280 cx.spawn(async move |_| {
281 let response = match task.await {
282 Ok(result) => CallToolResponse {
283 content: result.content,
284 is_error: Some(false),
285 meta: None,
286 structured_content: if result.structured_content.is_null() {
287 None
288 } else {
289 Some(result.structured_content)
290 },
291 },
292 Err(err) => CallToolResponse {
293 content: vec![ToolResponseContent::Text {
294 text: err.to_string(),
295 }],
296 is_error: Some(true),
297 meta: None,
298 structured_content: None,
299 },
300 };
301
302 outgoing_tx
303 .unbounded_send(
304 serde_json::to_string(&Response {
305 jsonrpc: "2.0",
306 id: request_id,
307 value: CspResult::Ok(Some(response)),
308 })
309 .unwrap_or_default(),
310 )
311 .ok();
312 })
313 .detach();
314 } else {
315 Self::send_err(
316 request_id,
317 format!("Tool not found: {}", params.name),
318 outgoing_tx,
319 );
320 }
321 }
322 Err(err) => {
323 Self::send_err(request_id, err.to_string(), outgoing_tx);
324 }
325 }
326 }
327
328 fn send_err(
329 request_id: RequestId,
330 message: impl Into<String>,
331 outgoing_tx: &UnboundedSender<String>,
332 ) {
333 outgoing_tx
334 .unbounded_send(
335 serde_json::to_string(&Response::<()> {
336 jsonrpc: "2.0",
337 id: request_id,
338 value: CspResult::Error(Some(crate::client::Error {
339 message: message.into(),
340 code: -32601,
341 })),
342 })
343 .unwrap(),
344 )
345 .ok();
346 }
347
348 async fn handle_io(
349 mut outgoing_rx: UnboundedReceiver<String>,
350 incoming_tx: UnboundedSender<RawRequest>,
351 mut outgoing_bytes: impl Unpin + AsyncWrite,
352 incoming_bytes: impl Unpin + AsyncRead,
353 ) -> Result<()> {
354 let mut output_reader = BufReader::new(incoming_bytes);
355 let mut incoming_line = String::new();
356 loop {
357 select_biased! {
358 message = outgoing_rx.next().fuse() => {
359 if let Some(message) = message {
360 log::trace!("send: {}", &message);
361 outgoing_bytes.write_all(message.as_bytes()).await?;
362 outgoing_bytes.write_all(&[b'\n']).await?;
363 } else {
364 break;
365 }
366 }
367 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
368 if bytes_read? == 0 {
369 break
370 }
371 log::trace!("recv: {}", &incoming_line);
372 match serde_json::from_str(&incoming_line) {
373 Ok(message) => {
374 incoming_tx.unbounded_send(message).log_err();
375 }
376 Err(error) => {
377 outgoing_bytes.write_all(serde_json::to_string(&json!({
378 "jsonrpc": "2.0",
379 "error": json!({
380 "code": -32603,
381 "message": format!("Failed to parse: {error}"),
382 }),
383 }))?.as_bytes()).await?;
384 outgoing_bytes.write_all(&[b'\n']).await?;
385 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
386 }
387 }
388 incoming_line.clear();
389 }
390 }
391 }
392 Ok(())
393 }
394}
395
396pub trait McpServerTool {
397 type Input: DeserializeOwned + JsonSchema;
398 type Output: Serialize + JsonSchema;
399
400 const NAME: &'static str;
401
402 fn description(&self) -> &'static str;
403
404 fn annotations(&self) -> ToolAnnotations {
405 ToolAnnotations {
406 title: None,
407 read_only_hint: None,
408 destructive_hint: None,
409 idempotent_hint: None,
410 open_world_hint: None,
411 }
412 }
413
414 fn run(
415 &self,
416 input: Self::Input,
417 cx: &mut AsyncApp,
418 ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
419}
420
/// The result of running an [`McpServerTool`].
pub struct ToolResponse<T> {
    /// Content items sent back to the client as the call's content.
    pub content: Vec<ToolResponseContent>,
    /// Structured result. It is serialized to JSON and included as the response's structured
    /// content, unless it serializes to `null`.
    pub structured_content: T,
}
425
/// An incoming JSON-RPC message decoded from one line of the socket. `params` is kept as raw
/// JSON until the handler for `method` deserializes it.
#[derive(Debug, Serialize, Deserialize)]
struct RawRequest {
    /// Request id; messages without one are skipped by `McpServer::serve_connection`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    /// JSON-RPC method name used to select a handler.
    method: String,
    /// Unparsed params, if the message carried any.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}
434
/// A JSON-RPC response whose `result` is kept as raw, unparsed JSON.
///
/// NOTE(review): this type is not referenced anywhere in this file — confirm it is still
/// needed.
#[derive(Serialize, Deserialize)]
struct RawResponse {
    /// JSON-RPC protocol version string.
    jsonrpc: &'static str,
    /// Id of the request this response answers.
    id: RequestId,
    /// Error object; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<crate::client::Error>,
    /// Raw result payload; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Box<serde_json::value::RawValue>>,
}