1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use schemars::JsonSchema;
13use serde::de::DeserializeOwned;
14use serde_json::{json, value::RawValue};
15use smol::stream::StreamExt;
16use std::{
17 any::TypeId,
18 cell::RefCell,
19 path::{Path, PathBuf},
20 rc::Rc,
21};
22use util::ResultExt;
23
24use crate::{
25 client::{CspResult, RequestId, Response},
26 types::{
27 CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
28 ToolResponseContent,
29 requests::{CallTool, ListTools},
30 },
31};
32
33pub struct McpServer {
34 socket_path: PathBuf,
35 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
36 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
37 _server_task: Task<()>,
38}
39
40struct RegisteredTool {
41 tool: Tool,
42 handler: ToolHandler,
43}
44
45type ToolHandler = Box<
46 dyn Fn(
47 Option<serde_json::Value>,
48 &mut AsyncApp,
49 ) -> Task<Result<ToolResponse<serde_json::Value>>>,
50>;
51type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
52
53impl McpServer {
54 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
55 let task = cx.background_spawn(async move {
56 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
57 let socket_path = temp_dir.path().join("mcp.sock");
58 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
59
60 anyhow::Ok((temp_dir, socket_path, listener))
61 });
62
63 cx.spawn(async move |cx| {
64 let (temp_dir, socket_path, listener) = task.await?;
65 let tools = Rc::new(RefCell::new(HashMap::default()));
66 let handlers = Rc::new(RefCell::new(HashMap::default()));
67 let server_task = cx.spawn({
68 let tools = tools.clone();
69 let handlers = handlers.clone();
70 async move |cx| {
71 while let Ok((stream, _)) = listener.accept().await {
72 Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
73 }
74 drop(temp_dir)
75 }
76 });
77 Ok(Self {
78 socket_path,
79 _server_task: server_task,
80 tools,
81 handlers,
82 })
83 })
84 }
85
86 pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
87 let mut settings = schemars::generate::SchemaSettings::draft07();
88 settings.inline_subschemas = true;
89 let mut generator = settings.into_generator();
90
91 let input_schema = generator.root_schema_for::<T::Input>();
92
93 let description = input_schema
94 .get("description")
95 .and_then(|desc| desc.as_str())
96 .map(|desc| desc.to_string());
97 debug_assert!(
98 description.is_some(),
99 "Input schema struct must include a doc comment for the tool description"
100 );
101
102 let registered_tool = RegisteredTool {
103 tool: Tool {
104 name: T::NAME.into(),
105 description,
106 input_schema: input_schema.into(),
107 output_schema: if TypeId::of::<T::Output>() == TypeId::of::<()>() {
108 None
109 } else {
110 Some(generator.root_schema_for::<T::Output>().into())
111 },
112 annotations: Some(tool.annotations()),
113 },
114 handler: Box::new({
115 let tool = tool.clone();
116 move |input_value, cx| {
117 let input = match input_value {
118 Some(input) => serde_json::from_value(input),
119 None => serde_json::from_value(serde_json::Value::Null),
120 };
121
122 let tool = tool.clone();
123 match input {
124 Ok(input) => cx.spawn(async move |cx| {
125 let output = tool.run(input, cx).await?;
126
127 Ok(ToolResponse {
128 content: output.content,
129 structured_content: serde_json::to_value(output.structured_content)
130 .unwrap_or_default(),
131 })
132 }),
133 Err(err) => Task::ready(Err(err.into())),
134 }
135 }
136 }),
137 };
138
139 self.tools.borrow_mut().insert(T::NAME, registered_tool);
140 }
141
142 pub fn handle_request<R: Request>(
143 &mut self,
144 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
145 ) {
146 let f = Box::new(f);
147 self.handlers.borrow_mut().insert(
148 R::METHOD,
149 Box::new(move |req_id, opt_params, cx| {
150 let result = match opt_params {
151 Some(params) => serde_json::from_str(params.get()),
152 None => serde_json::from_value(serde_json::Value::Null),
153 };
154
155 let params: R::Params = match result {
156 Ok(params) => params,
157 Err(e) => {
158 return Task::ready(
159 serde_json::to_string(&Response::<R::Response> {
160 jsonrpc: "2.0",
161 id: req_id,
162 value: CspResult::Error(Some(crate::client::Error {
163 message: format!("{e}"),
164 code: -32700,
165 })),
166 })
167 .unwrap(),
168 );
169 }
170 };
171 let task = f(params, cx);
172 cx.background_spawn(async move {
173 match task.await {
174 Ok(result) => serde_json::to_string(&Response {
175 jsonrpc: "2.0",
176 id: req_id,
177 value: CspResult::Ok(Some(result)),
178 })
179 .unwrap(),
180 Err(e) => serde_json::to_string(&Response {
181 jsonrpc: "2.0",
182 id: req_id,
183 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
184 message: format!("{e}"),
185 code: -32603,
186 })),
187 })
188 .unwrap(),
189 }
190 })
191 }),
192 );
193 }
194
195 pub fn socket_path(&self) -> &Path {
196 &self.socket_path
197 }
198
199 fn serve_connection(
200 stream: UnixStream,
201 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
202 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
203 cx: &mut AsyncApp,
204 ) {
205 let (read, write) = smol::io::split(stream);
206 let (incoming_tx, mut incoming_rx) = unbounded();
207 let (outgoing_tx, outgoing_rx) = unbounded();
208
209 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
210 .detach();
211
212 cx.spawn(async move |cx| {
213 while let Some(request) = incoming_rx.next().await {
214 let Some(request_id) = request.id.clone() else {
215 continue;
216 };
217
218 if request.method == CallTool::METHOD {
219 Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
220 .await;
221 } else if request.method == ListTools::METHOD {
222 Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
223 } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
224 let outgoing_tx = outgoing_tx.clone();
225
226 if let Some(task) = cx
227 .update(|cx| handler(request_id, request.params, cx))
228 .log_err()
229 {
230 cx.spawn(async move |_| {
231 let response = task.await;
232 outgoing_tx.unbounded_send(response).ok();
233 })
234 .detach();
235 }
236 } else {
237 Self::send_err(
238 request_id,
239 format!("unhandled method {}", request.method),
240 &outgoing_tx,
241 );
242 }
243 }
244 })
245 .detach();
246 }
247
248 fn handle_list_tools(
249 request_id: RequestId,
250 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
251 outgoing_tx: &UnboundedSender<String>,
252 ) {
253 let response = ListToolsResponse {
254 tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
255 next_cursor: None,
256 meta: None,
257 };
258
259 outgoing_tx
260 .unbounded_send(
261 serde_json::to_string(&Response {
262 jsonrpc: "2.0",
263 id: request_id,
264 value: CspResult::Ok(Some(response)),
265 })
266 .unwrap_or_default(),
267 )
268 .ok();
269 }
270
271 async fn handle_call_tool(
272 request_id: RequestId,
273 params: Option<Box<RawValue>>,
274 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
275 outgoing_tx: &UnboundedSender<String>,
276 cx: &mut AsyncApp,
277 ) {
278 let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
279 Some(params) => serde_json::from_str(params.get()),
280 None => serde_json::from_value(serde_json::Value::Null),
281 };
282
283 match result {
284 Ok(params) => {
285 if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
286 let outgoing_tx = outgoing_tx.clone();
287
288 let task = (tool.handler)(params.arguments, cx);
289 cx.spawn(async move |_| {
290 let response = match task.await {
291 Ok(result) => CallToolResponse {
292 content: result.content,
293 is_error: Some(false),
294 meta: None,
295 structured_content: if result.structured_content.is_null() {
296 None
297 } else {
298 Some(result.structured_content)
299 },
300 },
301 Err(err) => CallToolResponse {
302 content: vec![ToolResponseContent::Text {
303 text: err.to_string(),
304 }],
305 is_error: Some(true),
306 meta: None,
307 structured_content: None,
308 },
309 };
310
311 outgoing_tx
312 .unbounded_send(
313 serde_json::to_string(&Response {
314 jsonrpc: "2.0",
315 id: request_id,
316 value: CspResult::Ok(Some(response)),
317 })
318 .unwrap_or_default(),
319 )
320 .ok();
321 })
322 .detach();
323 } else {
324 Self::send_err(
325 request_id,
326 format!("Tool not found: {}", params.name),
327 outgoing_tx,
328 );
329 }
330 }
331 Err(err) => {
332 Self::send_err(request_id, err.to_string(), outgoing_tx);
333 }
334 }
335 }
336
337 fn send_err(
338 request_id: RequestId,
339 message: impl Into<String>,
340 outgoing_tx: &UnboundedSender<String>,
341 ) {
342 outgoing_tx
343 .unbounded_send(
344 serde_json::to_string(&Response::<()> {
345 jsonrpc: "2.0",
346 id: request_id,
347 value: CspResult::Error(Some(crate::client::Error {
348 message: message.into(),
349 code: -32601,
350 })),
351 })
352 .unwrap(),
353 )
354 .ok();
355 }
356
357 async fn handle_io(
358 mut outgoing_rx: UnboundedReceiver<String>,
359 incoming_tx: UnboundedSender<RawRequest>,
360 mut outgoing_bytes: impl Unpin + AsyncWrite,
361 incoming_bytes: impl Unpin + AsyncRead,
362 ) -> Result<()> {
363 let mut output_reader = BufReader::new(incoming_bytes);
364 let mut incoming_line = String::new();
365 loop {
366 select_biased! {
367 message = outgoing_rx.next().fuse() => {
368 if let Some(message) = message {
369 log::trace!("send: {}", &message);
370 outgoing_bytes.write_all(message.as_bytes()).await?;
371 outgoing_bytes.write_all(&[b'\n']).await?;
372 } else {
373 break;
374 }
375 }
376 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
377 if bytes_read? == 0 {
378 break
379 }
380 log::trace!("recv: {}", &incoming_line);
381 match serde_json::from_str(&incoming_line) {
382 Ok(message) => {
383 incoming_tx.unbounded_send(message).log_err();
384 }
385 Err(error) => {
386 outgoing_bytes.write_all(serde_json::to_string(&json!({
387 "jsonrpc": "2.0",
388 "error": json!({
389 "code": -32603,
390 "message": format!("Failed to parse: {error}"),
391 }),
392 }))?.as_bytes()).await?;
393 outgoing_bytes.write_all(&[b'\n']).await?;
394 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
395 }
396 }
397 incoming_line.clear();
398 }
399 }
400 }
401 Ok(())
402 }
403}
404
405pub trait McpServerTool {
406 type Input: DeserializeOwned + JsonSchema;
407 type Output: Serialize + JsonSchema;
408
409 const NAME: &'static str;
410
411 fn annotations(&self) -> ToolAnnotations {
412 ToolAnnotations {
413 title: None,
414 read_only_hint: None,
415 destructive_hint: None,
416 idempotent_hint: None,
417 open_world_hint: None,
418 }
419 }
420
421 fn run(
422 &self,
423 input: Self::Input,
424 cx: &mut AsyncApp,
425 ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
426}
427
428#[derive(Debug)]
429pub struct ToolResponse<T> {
430 pub content: Vec<ToolResponseContent>,
431 pub structured_content: T,
432}
433
434#[derive(Debug, Serialize, Deserialize)]
435struct RawRequest {
436 #[serde(skip_serializing_if = "Option::is_none")]
437 id: Option<RequestId>,
438 method: String,
439 #[serde(skip_serializing_if = "Option::is_none")]
440 params: Option<Box<serde_json::value::RawValue>>,
441}
442
443#[derive(Serialize, Deserialize)]
444struct RawResponse {
445 jsonrpc: &'static str,
446 id: RequestId,
447 #[serde(skip_serializing_if = "Option::is_none")]
448 error: Option<crate::client::Error>,
449 #[serde(skip_serializing_if = "Option::is_none")]
450 result: Option<Box<serde_json::value::RawValue>>,
451}