1use anyhow::{Context as _, Result, anyhow};
2use collections::HashMap;
3use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
4use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
5use parking_lot::Mutex;
6use postage::barrier;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use serde_json::{Value, value::RawValue};
9use slotmap::SlotMap;
10use smol::channel;
11use std::{
12 fmt,
13 path::PathBuf,
14 pin::pin,
15 sync::{
16 Arc,
17 atomic::{AtomicI32, Ordering::SeqCst},
18 },
19 time::{Duration, Instant},
20};
21use util::{ResultExt, TryFutureExt};
22
23use crate::{
24 transport::{StdioTransport, Transport},
25 types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
26};
27
/// JSON-RPC protocol version written into every request and notification built by `Client`.
const JSON_RPC_VERSION: &str = "2.0";
/// Timeout `Client::request` falls back to when the client has no configured request timeout.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(60_000);

// Standard JSON-RPC 2.0 error codes (JSON-RPC 2.0 specification, section 5.1).
/// Invalid JSON was received.
pub const PARSE_ERROR: i32 = -32_700;
/// The JSON sent is not a valid request object.
pub const INVALID_REQUEST: i32 = -32_600;
/// The requested method does not exist or is not available.
pub const METHOD_NOT_FOUND: i32 = -32_601;
/// Invalid method parameter(s).
pub const INVALID_PARAMS: i32 = -32_602;
/// Internal JSON-RPC error.
pub const INTERNAL_ERROR: i32 = -32_603;
37
38type ResponseHandler = Box<dyn Send + FnOnce(String)>;
39type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
40type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
41
/// Identifier of a JSON-RPC request; the protocol allows either a number or a string.
///
/// `#[serde(untagged)]` (de)serializes it as a bare JSON number or string. Untagged
/// deserialization tries the variants in declaration order, so `Int` is attempted before `Str`.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    /// Numeric id. `Client` allocates these from an incrementing counter.
    Int(i32),
    /// String id.
    Str(String),
}
48
/// JSON-RPC client for a single context server, exchanging messages over a [`Transport`].
pub(crate) struct Client {
    /// Identifier of the server; shown as `id` in `Debug` output.
    server_id: ContextServerId,
    /// Counter from which outgoing request ids are allocated (one `fetch_add` per request).
    next_id: AtomicI32,
    /// Queue of serialized outgoing messages, drained into the transport by the output task.
    outbound_tx: channel::Sender<String>,
    /// Server name; shown as `name` in `Debug` output.
    name: Arc<str>,
    /// Notification handlers registered through `on_notification`.
    subscription_set: Arc<Mutex<NotificationSubscriptionSet>>,
    /// Handlers for requests still awaiting a response, keyed by request id. Set to `None`
    /// when the output task exits, after which new requests fail with "server shut down".
    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
    /// The (input, output) I/O tasks. Never read; stored so their lifetime is tied to the
    /// client's.
    #[allow(clippy::type_complexity)]
    #[allow(dead_code)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    /// Receiving end of the barrier whose sender is dropped when the output task finishes.
    /// Currently never read.
    #[allow(dead_code)]
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Executor used to create request timeout timers.
    executor: BackgroundExecutor,
    /// Transport shared with the I/O tasks; not used directly by the client's methods.
    #[allow(dead_code)]
    transport: Arc<dyn Transport>,
    /// Timeout used by `request`; `DEFAULT_REQUEST_TIMEOUT` applies when this is `None`.
    request_timeout: Option<Duration>,
    /// Single-slot side channel for the last transport-level error. When the
    /// output task encounters a send failure it stashes the error here and
    /// exits; the next request to observe cancellation `.take()`s it so it can
    /// propagate a typed error (e.g. `TransportError::AuthRequired`) instead
    /// of a generic "cancelled". This works because `initialize` is the sole
    /// in-flight request at startup, but would need rethinking if concurrent
    /// requests are ever issued during that phase.
    last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
}
74
/// Identifier of a context server: a cheaply clonable shared string.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `Arc<str>`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub(crate) struct ContextServerId(pub Arc<str>);
78
79fn is_null_value<T: Serialize>(value: &T) -> bool {
80 matches!(serde_json::to_value(value), Ok(Value::Null))
81}
82
/// Outgoing JSON-RPC request with typed parameters.
#[derive(Serialize, Deserialize)]
pub struct Request<'a, T> {
    /// Protocol version; `Client` always sets `JSON_RPC_VERSION`.
    pub jsonrpc: &'static str,
    /// Identifier the matching response will carry.
    pub id: RequestId,
    /// Name of the method to invoke.
    pub method: &'a str,
    /// Method parameters. Omitted from the serialized output when they serialize to JSON
    /// `null` (for example `()` or `None`).
    #[serde(skip_serializing_if = "is_null_value")]
    pub params: T,
}
91
/// JSON-RPC request whose parameters are kept as unparsed raw JSON borrowed from the message.
///
/// `Client` uses it to recognize server-to-client requests in incoming messages.
#[derive(Serialize, Deserialize)]
pub struct AnyRequest<'a> {
    /// Protocol version string as sent by the peer.
    pub jsonrpc: &'a str,
    /// Identifier to echo back in the response.
    pub id: RequestId,
    /// Requested method name; used to look up the request handler.
    pub method: &'a str,
    /// Raw parameters, or `None` when absent. Skipped on serialization when they are `null`.
    #[serde(skip_serializing_if = "is_null_value")]
    pub params: Option<&'a RawValue>,
}
100
/// JSON-RPC response with its `result` left as raw JSON, so it can later be deserialized
/// into whatever type the caller of `Client::request` expects.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    /// Protocol version string as sent by the peer.
    jsonrpc: &'a str,
    /// Id of the request this response answers; used to find the pending response handler.
    id: RequestId,
    /// Error object, or `None` when the message has no `error` member.
    #[serde(default)]
    error: Option<Error>,
    /// Raw result. A missing member and JSON `null` both deserialize to `None`.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
110
/// Typed JSON-RPC response; `value` is flattened so it contributes either a top-level
/// `result` or `error` member.
#[derive(Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct Response<T> {
    /// Protocol version.
    pub jsonrpc: &'static str,
    /// Id of the request being answered.
    pub id: RequestId,
    /// The `result` / `error` payload.
    #[serde(flatten)]
    pub value: CspResult<T>,
}
119
/// Payload of a typed response: serialized under the key `result` (`Ok`) or `error` (`Error`).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum CspResult<T> {
    /// Successful result, serialized under `result`.
    #[serde(rename = "result")]
    Ok(Option<T>),
    /// Error object, serialized under `error` (via `rename_all = "snake_case"`).
    #[allow(dead_code)]
    Error(Option<Error>),
}
128
/// Outgoing JSON-RPC notification. It has no `id`, so no response is expected.
///
/// Unlike `Request`, `params` has no `skip_serializing_if`, so it is always serialized,
/// even as `null`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Protocol version; `Client::notify` sets `JSON_RPC_VERSION`.
    jsonrpc: &'static str,
    /// Notification method name.
    #[serde(borrow)]
    method: &'a str,
    /// Notification parameters.
    params: T,
}
136
/// Incoming JSON-RPC notification, dispatched to subscribers by `method`.
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
    /// Required for the message to parse as a notification, but otherwise unused.
    #[expect(
        unused,
        reason = "Part of the JSON-RPC protocol - we expect the field to be present in a valid JSON-RPC notification"
    )]
    jsonrpc: &'a str,
    /// Notification method name.
    method: String,
    /// Notification parameters; `None` when absent.
    #[serde(default)]
    params: Option<Value>,
}
148
/// JSON-RPC error object carried in a response's `error` member.
///
/// Only `message` and `code` are modeled; any other members (such as `data`) are ignored
/// on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Error {
    /// Human-readable error description; this becomes the error text `Client::request` returns.
    pub message: String,
    /// Numeric error code (see the standard codes such as `METHOD_NOT_FOUND`).
    pub code: i32,
}
154
/// How to launch a context server process; consumed by `Client::stdio`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerBinary {
    /// Path of the executable to run. Its file name also becomes the client's name.
    pub executable: PathBuf,
    /// Arguments for the executable.
    pub args: Vec<String>,
    /// Optional environment variables for the process. NOTE(review): whether these replace or
    /// extend the inherited environment is decided by `StdioTransport` — confirm there.
    pub env: Option<HashMap<String, String>>,
    /// Optional default request timeout, in seconds (converted with `Duration::from_secs`).
    pub timeout: Option<u64>,
}
162
163impl Client {
164 /// Creates a new Client instance for a context server.
165 ///
166 /// This function initializes a new Client by spawning a child process for the context server,
167 /// setting up communication channels, and initializing handlers for input/output operations.
168 /// It takes a server ID, binary information, and an async app context as input.
169 pub fn stdio(
170 server_id: ContextServerId,
171 binary: ModelContextServerBinary,
172 working_directory: &Option<PathBuf>,
173 cx: AsyncApp,
174 ) -> Result<Self> {
175 log::debug!(
176 "starting context server (executable={:?}, args={:?})",
177 binary.executable,
178 &binary.args
179 );
180
181 let server_name = binary
182 .executable
183 .file_name()
184 .map(|name| name.to_string_lossy().into_owned())
185 .unwrap_or_else(String::new);
186
187 let timeout = binary.timeout.map(Duration::from_secs);
188 let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?);
189 Self::new(server_id, server_name.into(), transport, timeout, cx)
190 }
191
192 /// Creates a new Client instance for a context server.
193 pub fn new(
194 server_id: ContextServerId,
195 server_name: Arc<str>,
196 transport: Arc<dyn Transport>,
197 request_timeout: Option<Duration>,
198 cx: AsyncApp,
199 ) -> Result<Self> {
200 let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
201 let (output_done_tx, output_done_rx) = barrier::channel();
202
203 let subscription_set = Arc::new(Mutex::new(NotificationSubscriptionSet::default()));
204 let response_handlers =
205 Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
206 let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
207
208 let receive_input_task = cx.spawn({
209 let subscription_set = subscription_set.clone();
210 let response_handlers = response_handlers.clone();
211 let request_handlers = request_handlers.clone();
212 let transport = transport.clone();
213 async move |cx| {
214 Self::handle_input(
215 transport,
216 subscription_set,
217 request_handlers,
218 response_handlers,
219 cx,
220 )
221 .log_err()
222 .await
223 }
224 });
225 let receive_err_task = cx.spawn({
226 let transport = transport.clone();
227 async move |_| Self::handle_err(transport).log_err().await
228 });
229 let input_task = cx.spawn(async move |_| {
230 let (input, err) = futures::join!(receive_input_task, receive_err_task);
231 input.or(err)
232 });
233
234 let last_transport_error: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
235 let output_task = cx.background_spawn({
236 let transport = transport.clone();
237 let last_transport_error = last_transport_error.clone();
238 Self::handle_output(
239 transport,
240 outbound_rx,
241 output_done_tx,
242 response_handlers.clone(),
243 last_transport_error,
244 )
245 .log_err()
246 });
247
248 Ok(Self {
249 server_id,
250 subscription_set,
251 response_handlers,
252 name: server_name,
253 next_id: Default::default(),
254 outbound_tx,
255 executor: cx.background_executor().clone(),
256 io_tasks: Mutex::new(Some((input_task, output_task))),
257 output_done_rx: Mutex::new(Some(output_done_rx)),
258 transport,
259 request_timeout,
260 last_transport_error,
261 })
262 }
263
264 /// Handles input from the server's stdout.
265 ///
266 /// This function continuously reads lines from the provided stdout stream,
267 /// parses them as JSON-RPC responses or notifications, and dispatches them
268 /// to the appropriate handlers. It processes both responses (which are matched
269 /// to pending requests) and notifications (which trigger registered handlers).
270 async fn handle_input(
271 transport: Arc<dyn Transport>,
272 subscription_set: Arc<Mutex<NotificationSubscriptionSet>>,
273 request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
274 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
275 cx: &mut AsyncApp,
276 ) -> anyhow::Result<()> {
277 let mut receiver = transport.receive();
278
279 while let Some(message) = receiver.next().await {
280 log::trace!("recv: {}", &message);
281 if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
282 let mut request_handlers = request_handlers.lock();
283 if let Some(handler) = request_handlers.get_mut(request.method) {
284 handler(
285 request.id,
286 request.params.unwrap_or(RawValue::NULL),
287 cx.clone(),
288 );
289 }
290 } else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
291 if let Some(handlers) = response_handlers.lock().as_mut()
292 && let Some(handler) = handlers.remove(&response.id)
293 {
294 handler(message.to_string());
295 }
296 } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
297 subscription_set.lock().notify(
298 ¬ification.method,
299 notification.params.unwrap_or(Value::Null),
300 cx,
301 )
302 } else {
303 log::error!("Unhandled JSON from context_server: {}", message);
304 }
305 }
306
307 smol::future::yield_now().await;
308
309 Ok(())
310 }
311
312 /// Handles the stderr output from the context server.
313 /// Continuously reads and logs any error messages from the server.
314 async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
315 while let Some(err) = transport.receive_err().next().await {
316 log::debug!("context server stderr: {}", err.trim());
317 }
318
319 Ok(())
320 }
321
322 /// Handles the output to the context server's stdin.
323 /// This function continuously receives messages from the outbound channel,
324 /// writes them to the server's stdin, and manages the lifecycle of response handlers.
325 async fn handle_output(
326 transport: Arc<dyn Transport>,
327 outbound_rx: channel::Receiver<String>,
328 output_done_tx: barrier::Sender,
329 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
330 last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
331 ) -> anyhow::Result<()> {
332 let _clear_response_handlers = util::defer({
333 let response_handlers = response_handlers.clone();
334 move || {
335 response_handlers.lock().take();
336 }
337 });
338 while let Ok(message) = outbound_rx.recv().await {
339 log::trace!("outgoing message: {}", message);
340 if let Err(err) = transport.send(message).await {
341 log::debug!("transport send failed: {:#}", err);
342 *last_transport_error.lock() = Some(err);
343 return Ok(());
344 }
345 }
346 drop(output_done_tx);
347 Ok(())
348 }
349
350 /// Sends a JSON-RPC request to the context server and waits for a response.
351 /// This function handles serialization, deserialization, timeout, and error handling.
352 pub async fn request<T: DeserializeOwned>(
353 &self,
354 method: &str,
355 params: impl Serialize,
356 ) -> Result<T> {
357 self.request_with(
358 method,
359 params,
360 None,
361 self.request_timeout.or(Some(DEFAULT_REQUEST_TIMEOUT)),
362 )
363 .await
364 }
365
366 pub async fn request_with<T: DeserializeOwned>(
367 &self,
368 method: &str,
369 params: impl Serialize,
370 cancel_rx: Option<oneshot::Receiver<()>>,
371 timeout: Option<Duration>,
372 ) -> Result<T> {
373 let id = self.next_id.fetch_add(1, SeqCst);
374 let request = serde_json::to_string(&Request {
375 jsonrpc: JSON_RPC_VERSION,
376 id: RequestId::Int(id),
377 method,
378 params,
379 })
380 .unwrap();
381
382 let (tx, rx) = oneshot::channel();
383 let handle_response = self
384 .response_handlers
385 .lock()
386 .as_mut()
387 .context("server shut down")
388 .map(|handlers| {
389 handlers.insert(
390 RequestId::Int(id),
391 Box::new(move |result| {
392 let _ = tx.send(result);
393 }),
394 );
395 });
396
397 let send = self
398 .outbound_tx
399 .try_send(request)
400 .context("failed to write to context server's stdin");
401
402 let executor = self.executor.clone();
403 let started = Instant::now();
404 handle_response?;
405 send?;
406
407 let mut timeout_fut = pin!(
408 match timeout {
409 Some(timeout) => future::Either::Left(executor.timer(timeout)),
410 None => future::Either::Right(future::pending()),
411 }
412 .fuse()
413 );
414 let mut cancel_fut = pin!(
415 match cancel_rx {
416 Some(rx) => future::Either::Left(async {
417 rx.await.log_err();
418 }),
419 None => future::Either::Right(future::pending()),
420 }
421 .fuse()
422 );
423
424 select! {
425 response = rx.fuse() => {
426 let elapsed = started.elapsed();
427 log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
428 match response {
429 Ok(response) => {
430 let parsed: AnyResponse = serde_json::from_str(&response)?;
431 if let Some(error) = parsed.error {
432 Err(anyhow!(error.message))
433 } else if let Some(result) = parsed.result {
434 Ok(serde_json::from_str(result.get())?)
435 } else {
436 anyhow::bail!("Invalid response: no result or error");
437 }
438 }
439 Err(_canceled) => {
440 if let Some(err) = self.last_transport_error.lock().take() {
441 return Err(err);
442 }
443 anyhow::bail!("cancelled")
444 }
445 }
446 }
447 _ = cancel_fut => {
448 self.notify(
449 Cancelled::METHOD,
450 ClientNotification::Cancelled(CancelledParams {
451 request_id: RequestId::Int(id),
452 reason: None
453 })
454 ).log_err();
455 anyhow::bail!(RequestCanceled)
456 }
457 _ = timeout_fut => {
458 log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap());
459 anyhow::bail!("Context server request timeout");
460 }
461 }
462 }
463
464 /// Sends a notification to the context server without expecting a response.
465 /// This function serializes the notification and sends it through the outbound channel.
466 pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
467 let notification = serde_json::to_string(&Notification {
468 jsonrpc: JSON_RPC_VERSION,
469 method,
470 params,
471 })
472 .unwrap();
473 self.outbound_tx.try_send(notification)?;
474 Ok(())
475 }
476
477 #[must_use]
478 pub fn on_notification(
479 &self,
480 method: &'static str,
481 f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
482 ) -> NotificationSubscription {
483 let mut notification_subscriptions = self.subscription_set.lock();
484
485 NotificationSubscription {
486 id: notification_subscriptions.add_handler(method, f),
487 set: self.subscription_set.clone(),
488 }
489 }
490}
491
/// Error returned when an in-flight context server request is canceled by its caller.
#[derive(Debug)]
pub struct RequestCanceled;

impl fmt::Display for RequestCanceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Context server request was canceled")
    }
}

impl std::error::Error for RequestCanceled {}
502
503impl fmt::Display for ContextServerId {
504 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505 self.0.fmt(f)
506 }
507}
508
509impl fmt::Debug for Client {
510 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511 f.debug_struct("Context Server Client")
512 .field("id", &self.server_id.0)
513 .field("name", &self.name)
514 .finish_non_exhaustive()
515 }
516}
517
// Key type for `NotificationSubscriptionSet::handlers`. Each registered notification handler
// gets one, and `NotificationSubscription` uses it to unregister that handler when dropped.
slotmap::new_key_type! {
    struct NotificationSubscriptionId;
}
521
/// Registry of notification handlers, grouped by JSON-RPC method name.
#[derive(Default)]
pub struct NotificationSubscriptionSet {
    // we have very few subscriptions at the moment
    /// Handler ids per method. This is a `Vec` searched linearly rather than a map, because
    /// the number of subscribed methods is small.
    methods: Vec<(&'static str, Vec<NotificationSubscriptionId>)>,
    /// Handler callbacks, keyed by subscription id.
    handlers: SlotMap<NotificationSubscriptionId, NotificationHandler>,
}
528
529impl NotificationSubscriptionSet {
530 #[must_use]
531 fn add_handler(
532 &mut self,
533 method: &'static str,
534 handler: NotificationHandler,
535 ) -> NotificationSubscriptionId {
536 let id = self.handlers.insert(handler);
537 if let Some((_, handler_ids)) = self
538 .methods
539 .iter_mut()
540 .find(|(probe_method, _)| method == *probe_method)
541 {
542 debug_assert!(
543 handler_ids.len() < 20,
544 "Too many MCP handlers for {}. Consider using a different data structure.",
545 method
546 );
547
548 handler_ids.push(id);
549 } else {
550 self.methods.push((method, vec![id]));
551 };
552 id
553 }
554
555 fn notify(&mut self, method: &str, payload: Value, cx: &mut AsyncApp) {
556 let Some((_, handler_ids)) = self
557 .methods
558 .iter_mut()
559 .find(|(probe_method, _)| method == *probe_method)
560 else {
561 return;
562 };
563
564 for handler_id in handler_ids {
565 if let Some(handler) = self.handlers.get_mut(*handler_id) {
566 handler(payload.clone(), cx.clone());
567 }
568 }
569 }
570}
571
/// Handle for a notification handler registered with `Client::on_notification`.
///
/// Dropping it unregisters the handler.
pub struct NotificationSubscription {
    /// Id of the registered handler.
    id: NotificationSubscriptionId,
    /// The set the handler lives in.
    set: Arc<Mutex<NotificationSubscriptionSet>>,
}
576
577impl Drop for NotificationSubscription {
578 fn drop(&mut self) {
579 let mut set = self.set.lock();
580 set.handlers.remove(self.id);
581 set.methods.retain_mut(|(_, handler_ids)| {
582 handler_ids.retain(|id| *id != self.id);
583 !handler_ids.is_empty()
584 });
585 }
586}