1use ::serde::{Deserialize, Serialize};
2use anyhow::{Context as _, Result};
3use collections::HashMap;
4use futures::{
5 AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
6 channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
7 io::BufReader,
8 select_biased,
9};
10use gpui::{App, AppContext, AsyncApp, Task};
11use net::async_net::{UnixListener, UnixStream};
12use schemars::JsonSchema;
13use serde::de::DeserializeOwned;
14use serde_json::{json, value::RawValue};
15use smol::stream::StreamExt;
16use std::{
17 any::TypeId,
18 cell::RefCell,
19 path::{Path, PathBuf},
20 rc::Rc,
21};
22use util::ResultExt;
23
24use crate::{
25 client::{CspResult, RequestId, Response},
26 types::{
27 CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
28 ToolResponseContent,
29 requests::{CallTool, ListTools},
30 },
31};
32
/// An MCP server that accepts JSON-RPC connections on a Unix domain socket
/// created inside a fresh temporary directory.
///
/// Tools are registered with [`McpServer::add_tool`]; other methods are
/// registered with [`McpServer::handle_request`]. Both registries are shared
/// (via `Rc`) with every connection, and each request looks them up at the time
/// it is handled.
pub struct McpServer {
    /// Path of the listening socket (`mcp.sock` inside the temporary directory).
    socket_path: PathBuf,
    /// Registered tools, keyed by `McpServerTool::NAME`.
    tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
    /// Custom request handlers, keyed by the request's `METHOD`.
    handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
    /// Accept loop. It owns the listener and the temporary directory (which it
    /// drops when the loop ends). NOTE(review): presumably dropping this gpui
    /// `Task` cancels the loop and thereby removes the socket directory — confirm.
    _server_task: Task<()>,
}
39
/// A registered tool: the metadata advertised to clients plus the type-erased
/// closure that executes it.
struct RegisteredTool {
    /// Name, description, input/output schemas and annotations returned in
    /// `ListTools` responses.
    tool: Tool,
    /// Invoked for `CallTool` requests that name this tool.
    handler: ToolHandler,
}
44
/// Type-erased tool entry point.
///
/// Receives the raw `arguments` of a `CallTool` request (`None` when absent) and
/// resolves to the tool's output with its structured content already converted
/// to a `serde_json::Value`.
type ToolHandler = Box<
    dyn Fn(
        Option<serde_json::Value>,
        &mut AsyncApp,
    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
>;
/// Type-erased handler for a custom JSON-RPC method.
///
/// Receives the request id and raw params, and resolves to the fully serialized
/// JSON-RPC response (without the trailing newline, which the writer adds).
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
52
53impl McpServer {
54 pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
55 let task = cx.background_spawn(async move {
56 let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
57 let socket_path = temp_dir.path().join("mcp.sock");
58 let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
59
60 anyhow::Ok((temp_dir, socket_path, listener))
61 });
62
63 cx.spawn(async move |cx| {
64 let (temp_dir, socket_path, listener) = task.await?;
65 let tools = Rc::new(RefCell::new(HashMap::default()));
66 let handlers = Rc::new(RefCell::new(HashMap::default()));
67 let server_task = cx.spawn({
68 let tools = tools.clone();
69 let handlers = handlers.clone();
70 async move |cx| {
71 while let Ok((stream, _)) = listener.accept().await {
72 Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
73 }
74 drop(temp_dir)
75 }
76 });
77 Ok(Self {
78 socket_path,
79 _server_task: server_task,
80 tools,
81 handlers,
82 })
83 })
84 }
85
86 pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
87 let mut settings = schemars::generate::SchemaSettings::draft07();
88 settings.inline_subschemas = true;
89 let mut generator = settings.into_generator();
90
91 let input_schema = generator.root_schema_for::<T::Input>();
92
93 let description = input_schema
94 .get("description")
95 .and_then(|desc| desc.as_str())
96 .map(|desc| desc.to_string());
97 debug_assert!(
98 description.is_some(),
99 "Input schema struct must include a doc comment for the tool description"
100 );
101
102 let registered_tool = RegisteredTool {
103 tool: Tool {
104 name: T::NAME.into(),
105 description,
106 input_schema: input_schema.into(),
107 output_schema: if TypeId::of::<T::Output>() == TypeId::of::<()>() {
108 None
109 } else {
110 Some(generator.root_schema_for::<T::Output>().into())
111 },
112 annotations: Some(tool.annotations()),
113 },
114 handler: Box::new({
115 move |input_value, cx| {
116 let input = match input_value {
117 Some(input) => serde_json::from_value(input),
118 None => serde_json::from_value(serde_json::Value::Null),
119 };
120
121 let tool = tool.clone();
122 match input {
123 Ok(input) => cx.spawn(async move |cx| {
124 let output = tool.run(input, cx).await?;
125
126 Ok(ToolResponse {
127 content: output.content,
128 structured_content: serde_json::to_value(output.structured_content)
129 .unwrap_or_default(),
130 })
131 }),
132 Err(err) => Task::ready(Err(err.into())),
133 }
134 }
135 }),
136 };
137
138 self.tools.borrow_mut().insert(T::NAME, registered_tool);
139 }
140
141 pub fn handle_request<R: Request>(
142 &mut self,
143 f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
144 ) {
145 let f = Box::new(f);
146 self.handlers.borrow_mut().insert(
147 R::METHOD,
148 Box::new(move |req_id, opt_params, cx| {
149 let result = match opt_params {
150 Some(params) => serde_json::from_str(params.get()),
151 None => serde_json::from_value(serde_json::Value::Null),
152 };
153
154 let params: R::Params = match result {
155 Ok(params) => params,
156 Err(e) => {
157 return Task::ready(
158 serde_json::to_string(&Response::<R::Response> {
159 jsonrpc: "2.0",
160 id: req_id,
161 value: CspResult::Error(Some(crate::client::Error {
162 message: format!("{e}"),
163 code: -32700,
164 })),
165 })
166 .unwrap(),
167 );
168 }
169 };
170 let task = f(params, cx);
171 cx.background_spawn(async move {
172 match task.await {
173 Ok(result) => serde_json::to_string(&Response {
174 jsonrpc: "2.0",
175 id: req_id,
176 value: CspResult::Ok(Some(result)),
177 })
178 .unwrap(),
179 Err(e) => serde_json::to_string(&Response {
180 jsonrpc: "2.0",
181 id: req_id,
182 value: CspResult::Error::<R::Response>(Some(crate::client::Error {
183 message: format!("{e}"),
184 code: -32603,
185 })),
186 })
187 .unwrap(),
188 }
189 })
190 }),
191 );
192 }
193
194 pub fn socket_path(&self) -> &Path {
195 &self.socket_path
196 }
197
198 fn serve_connection(
199 stream: UnixStream,
200 tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
201 handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
202 cx: &mut AsyncApp,
203 ) {
204 let (read, write) = smol::io::split(stream);
205 let (incoming_tx, mut incoming_rx) = unbounded();
206 let (outgoing_tx, outgoing_rx) = unbounded();
207
208 cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
209 .detach();
210
211 cx.spawn(async move |cx| {
212 while let Some(request) = incoming_rx.next().await {
213 let Some(request_id) = request.id.clone() else {
214 continue;
215 };
216
217 if request.method == CallTool::METHOD {
218 Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
219 .await;
220 } else if request.method == ListTools::METHOD {
221 Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
222 } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
223 let outgoing_tx = outgoing_tx.clone();
224
225 if let Some(task) = cx
226 .update(|cx| handler(request_id, request.params, cx))
227 .log_err()
228 {
229 cx.spawn(async move |_| {
230 let response = task.await;
231 outgoing_tx.unbounded_send(response).ok();
232 })
233 .detach();
234 }
235 } else {
236 Self::send_err(
237 request_id,
238 format!("unhandled method {}", request.method),
239 &outgoing_tx,
240 );
241 }
242 }
243 })
244 .detach();
245 }
246
247 fn handle_list_tools(
248 request_id: RequestId,
249 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
250 outgoing_tx: &UnboundedSender<String>,
251 ) {
252 let response = ListToolsResponse {
253 tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
254 next_cursor: None,
255 meta: None,
256 };
257
258 outgoing_tx
259 .unbounded_send(
260 serde_json::to_string(&Response {
261 jsonrpc: "2.0",
262 id: request_id,
263 value: CspResult::Ok(Some(response)),
264 })
265 .unwrap_or_default(),
266 )
267 .ok();
268 }
269
270 async fn handle_call_tool(
271 request_id: RequestId,
272 params: Option<Box<RawValue>>,
273 tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
274 outgoing_tx: &UnboundedSender<String>,
275 cx: &mut AsyncApp,
276 ) {
277 let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
278 Some(params) => serde_json::from_str(params.get()),
279 None => serde_json::from_value(serde_json::Value::Null),
280 };
281
282 match result {
283 Ok(params) => {
284 if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
285 let outgoing_tx = outgoing_tx.clone();
286
287 let task = (tool.handler)(params.arguments, cx);
288 cx.spawn(async move |_| {
289 let response = match task.await {
290 Ok(result) => CallToolResponse {
291 content: result.content,
292 is_error: Some(false),
293 meta: None,
294 structured_content: if result.structured_content.is_null() {
295 None
296 } else {
297 Some(result.structured_content)
298 },
299 },
300 Err(err) => CallToolResponse {
301 content: vec![ToolResponseContent::Text {
302 text: err.to_string(),
303 }],
304 is_error: Some(true),
305 meta: None,
306 structured_content: None,
307 },
308 };
309
310 outgoing_tx
311 .unbounded_send(
312 serde_json::to_string(&Response {
313 jsonrpc: "2.0",
314 id: request_id,
315 value: CspResult::Ok(Some(response)),
316 })
317 .unwrap_or_default(),
318 )
319 .ok();
320 })
321 .detach();
322 } else {
323 Self::send_err(
324 request_id,
325 format!("Tool not found: {}", params.name),
326 outgoing_tx,
327 );
328 }
329 }
330 Err(err) => {
331 Self::send_err(request_id, err.to_string(), outgoing_tx);
332 }
333 }
334 }
335
336 fn send_err(
337 request_id: RequestId,
338 message: impl Into<String>,
339 outgoing_tx: &UnboundedSender<String>,
340 ) {
341 outgoing_tx
342 .unbounded_send(
343 serde_json::to_string(&Response::<()> {
344 jsonrpc: "2.0",
345 id: request_id,
346 value: CspResult::Error(Some(crate::client::Error {
347 message: message.into(),
348 code: -32601,
349 })),
350 })
351 .unwrap(),
352 )
353 .ok();
354 }
355
356 async fn handle_io(
357 mut outgoing_rx: UnboundedReceiver<String>,
358 incoming_tx: UnboundedSender<RawRequest>,
359 mut outgoing_bytes: impl Unpin + AsyncWrite,
360 incoming_bytes: impl Unpin + AsyncRead,
361 ) -> Result<()> {
362 let mut output_reader = BufReader::new(incoming_bytes);
363 let mut incoming_line = String::new();
364 loop {
365 select_biased! {
366 message = outgoing_rx.next().fuse() => {
367 if let Some(message) = message {
368 log::trace!("send: {}", &message);
369 outgoing_bytes.write_all(message.as_bytes()).await?;
370 outgoing_bytes.write_all(&[b'\n']).await?;
371 } else {
372 break;
373 }
374 }
375 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
376 if bytes_read? == 0 {
377 break
378 }
379 log::trace!("recv: {}", &incoming_line);
380 match serde_json::from_str(&incoming_line) {
381 Ok(message) => {
382 incoming_tx.unbounded_send(message).log_err();
383 }
384 Err(error) => {
385 outgoing_bytes.write_all(serde_json::to_string(&json!({
386 "jsonrpc": "2.0",
387 "error": json!({
388 "code": -32603,
389 "message": format!("Failed to parse: {error}"),
390 }),
391 }))?.as_bytes()).await?;
392 outgoing_bytes.write_all(&[b'\n']).await?;
393 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
394 }
395 }
396 incoming_line.clear();
397 }
398 }
399 }
400 Ok(())
401 }
402}
403
404pub trait McpServerTool {
405 type Input: DeserializeOwned + JsonSchema;
406 type Output: Serialize + JsonSchema;
407
408 const NAME: &'static str;
409
410 fn annotations(&self) -> ToolAnnotations {
411 ToolAnnotations {
412 title: None,
413 read_only_hint: None,
414 destructive_hint: None,
415 idempotent_hint: None,
416 open_world_hint: None,
417 }
418 }
419
420 fn run(
421 &self,
422 input: Self::Input,
423 cx: &mut AsyncApp,
424 ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
425}
426
/// Output of a single tool run.
#[derive(Debug)]
pub struct ToolResponse<T> {
    /// Content blocks returned to the client in the `CallToolResponse`.
    pub content: Vec<ToolResponseContent>,
    /// Typed structured output. It is converted to JSON, and omitted from the
    /// response when it converts to `null` (as `()` does).
    pub structured_content: T,
}
432
/// An incoming JSON-RPC message. `params` is kept unparsed so each handler can
/// deserialize it into its own type.
#[derive(Debug, Serialize, Deserialize)]
struct RawRequest {
    /// Request id. `None` marks a notification, which the server does not answer.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<RequestId>,
    /// Method name used to dispatch the request.
    method: String,
    /// Raw `params` JSON, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Box<serde_json::value::RawValue>>,
}