1use anyhow::{anyhow, Context as _, Result};
2use collections::HashMap;
3use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
4use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
5use parking_lot::Mutex;
6use postage::barrier;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_json::{value::RawValue, Value};
9use smol::{
10 channel,
11 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12 process::Child,
13};
14use std::{
15 fmt,
16 path::PathBuf,
17 sync::{
18 atomic::{AtomicI32, Ordering::SeqCst},
19 Arc,
20 },
21 time::{Duration, Instant},
22};
23use util::TryFutureExt;
24
/// Value of the `jsonrpc` field on every outgoing request and notification.
const JSON_RPC_VERSION: &str = "2.0";
/// How long `Client::request` waits for a response before failing with a timeout error.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
27
// Standard JSON-RPC error codes
/// JSON-RPC "parse error" code.
pub const PARSE_ERROR: i32 = -32700;
/// JSON-RPC "invalid request" code.
pub const INVALID_REQUEST: i32 = -32600;
/// JSON-RPC "method not found" code.
pub const METHOD_NOT_FOUND: i32 = -32601;
/// JSON-RPC "invalid params" code.
pub const INVALID_PARAMS: i32 = -32602;
/// JSON-RPC "internal error" code.
pub const INTERNAL_ERROR: i32 = -32603;
34
/// One-shot callback for a single pending request; it receives the raw (trimmed) text of
/// the response line read from the server's stdout.
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
/// Callback for a notification method; it is invoked with the notification's `params`
/// (`Value::Null` when absent) every time a notification with that method arrives.
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
37
/// A JSON-RPC request id: either an integer or a string.
///
/// `untagged`, so it (de)serializes as a bare JSON number or string. When deserializing,
/// serde tries the variants in declaration order (`Int` first). This client only ever issues
/// `Int` ids for its own requests.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    Int(i32),
    Str(String),
}
44
/// A JSON-RPC client talking to a context server child process over its stdio.
///
/// Outgoing messages are queued on `outbound_tx` and written to the process's stdin by a
/// background task. Lines read from its stdout are dispatched to the registered response or
/// notification handlers.
pub struct Client {
    server_id: ContextServerId,
    /// Source of ids for outgoing requests; incremented once per request.
    next_id: AtomicI32,
    /// Queue of serialized messages that the stdin writer task drains.
    outbound_tx: channel::Sender<String>,
    /// Human-readable name: the executable's file name when it has one, otherwise empty.
    name: Arc<str>,
    /// Notification callbacks keyed by method name.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests keyed by request id. Set to `None` when the stdin
    /// writer task exits; requests issued after that fail with "server shut down".
    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
    // Never read: the (input, output) I/O tasks are stored so they live as long as the
    // client (presumably dropping a gpui `Task` cancels it — confirm).
    #[allow(clippy::type_complexity)]
    #[allow(dead_code)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    // Never read: the matching barrier sender is dropped when the stdin writer task ends.
    #[allow(dead_code)]
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Executor used for request timeout timers.
    executor: BackgroundExecutor,
    /// The spawned server process; taken and killed when the client is dropped.
    server: Arc<Mutex<Option<Child>>>,
}
60
/// Identifier of a context server, wrapping a shared string.
///
/// `repr(transparent)`: the newtype has the same layout as the inner `Arc<str>`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ContextServerId(pub Arc<str>);
64
65fn is_null_value<T: Serialize>(value: &T) -> bool {
66 if let Ok(Value::Null) = serde_json::to_value(value) {
67 true
68 } else {
69 false
70 }
71}
72
/// An outgoing JSON-RPC request.
///
/// Keys are emitted in field declaration order. `params` is left out entirely when it
/// serializes to `null` (see `is_null_value`).
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    jsonrpc: &'static str,
    id: RequestId,
    method: &'a str,
    #[serde(skip_serializing_if = "is_null_value")]
    params: T,
}
81
/// Loosely typed view of an incoming JSON-RPC response, borrowing from the input text.
///
/// `result` is kept as raw JSON so that it can later be deserialized into the caller's
/// expected type. A missing `error` or `result` deserializes as `None` (so does an explicit
/// `null`). Any message with an `id` therefore parses as this type, including requests
/// initiated by the server.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    jsonrpc: &'a str,
    id: RequestId,
    #[serde(default)]
    error: Option<Error>,
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
91
/// Strongly typed JSON-RPC response whose `result`/`error` key is flattened into
/// `CspResult`.
///
/// NOTE(review): not used in the visible code (hence `allow(dead_code)`); confirm whether
/// it is still needed.
#[derive(Deserialize)]
#[allow(dead_code)]
struct Response<T> {
    jsonrpc: &'static str,
    id: RequestId,
    #[serde(flatten)]
    value: CspResult<T>,
}
100
/// Outcome of a typed response: the `Ok` variant is read from a `result` key, and the
/// `Error` variant from an `error` key (renamed to snake case by `rename_all`).
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum CspResult<T> {
    #[serde(rename = "result")]
    Ok(Option<T>),
    // The payload is never read in the visible code.
    #[allow(dead_code)]
    Error(Option<Error>),
}
109
/// An outgoing JSON-RPC notification (a message with no `id`, so no response is expected).
///
/// Unlike `Request`, `params` is always serialized, even when it is `null`.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    jsonrpc: &'static str,
    #[serde(borrow)]
    method: &'a str,
    params: T,
}
117
/// Loosely typed view of an incoming notification.
///
/// Any message with a string `method` parses as this type; other keys such as `id` are
/// ignored. A missing `params` deserializes as `None`.
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
    jsonrpc: &'a str,
    method: String,
    #[serde(default)]
    params: Option<Value>,
}
125
/// The `error` object of a JSON-RPC response.
///
/// Only `message` is kept; other members of the object (such as a numeric code) are
/// ignored during deserialization.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    message: String,
}
130
/// How to launch a context server process.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerBinary {
    /// Path of the program to spawn.
    pub executable: PathBuf,
    /// Command-line arguments passed to the program.
    pub args: Vec<String>,
    /// Extra environment variables set on the spawned process, if any.
    pub env: Option<HashMap<String, String>>,
}
137
138impl Client {
139 /// Creates a new Client instance for a context server.
140 ///
141 /// This function initializes a new Client by spawning a child process for the context server,
142 /// setting up communication channels, and initializing handlers for input/output operations.
143 /// It takes a server ID, binary information, and an async app context as input.
144 pub fn new(
145 server_id: ContextServerId,
146 binary: ModelContextServerBinary,
147 cx: AsyncApp,
148 ) -> Result<Self> {
149 log::info!(
150 "starting context server (executable={:?}, args={:?})",
151 binary.executable,
152 &binary.args
153 );
154
155 let mut command = util::command::new_smol_command(&binary.executable);
156 command
157 .args(&binary.args)
158 .envs(binary.env.unwrap_or_default())
159 .stdin(std::process::Stdio::piped())
160 .stdout(std::process::Stdio::piped())
161 .stderr(std::process::Stdio::piped())
162 .kill_on_drop(true);
163
164 let mut server = command.spawn().with_context(|| {
165 format!(
166 "failed to spawn command. (path={:?}, args={:?})",
167 binary.executable, &binary.args
168 )
169 })?;
170
171 let stdin = server.stdin.take().unwrap();
172 let stdout = server.stdout.take().unwrap();
173 let stderr = server.stderr.take().unwrap();
174
175 let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
176 let (output_done_tx, output_done_rx) = barrier::channel();
177
178 let notification_handlers =
179 Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
180 let response_handlers =
181 Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
182
183 let stdout_input_task = cx.spawn({
184 let notification_handlers = notification_handlers.clone();
185 let response_handlers = response_handlers.clone();
186 move |cx| {
187 Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
188 }
189 });
190 let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
191 let input_task = cx.spawn(|_| async move {
192 let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
193 stdout.or(stderr)
194 });
195 let output_task = cx.background_spawn({
196 Self::handle_output(
197 stdin,
198 outbound_rx,
199 output_done_tx,
200 response_handlers.clone(),
201 )
202 .log_err()
203 });
204
205 let mut context_server = Self {
206 server_id,
207 notification_handlers,
208 response_handlers,
209 name: "".into(),
210 next_id: Default::default(),
211 outbound_tx,
212 executor: cx.background_executor().clone(),
213 io_tasks: Mutex::new(Some((input_task, output_task))),
214 output_done_rx: Mutex::new(Some(output_done_rx)),
215 server: Arc::new(Mutex::new(Some(server))),
216 };
217
218 if let Some(name) = binary.executable.file_name() {
219 context_server.name = name.to_string_lossy().into();
220 }
221
222 Ok(context_server)
223 }
224
225 /// Handles input from the server's stdout.
226 ///
227 /// This function continuously reads lines from the provided stdout stream,
228 /// parses them as JSON-RPC responses or notifications, and dispatches them
229 /// to the appropriate handlers. It processes both responses (which are matched
230 /// to pending requests) and notifications (which trigger registered handlers).
231 async fn handle_input<Stdout>(
232 stdout: Stdout,
233 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
234 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
235 cx: AsyncApp,
236 ) -> anyhow::Result<()>
237 where
238 Stdout: AsyncRead + Unpin + Send + 'static,
239 {
240 let mut stdout = BufReader::new(stdout);
241 let mut buffer = String::new();
242
243 loop {
244 buffer.clear();
245 if stdout.read_line(&mut buffer).await? == 0 {
246 return Ok(());
247 }
248
249 let content = buffer.trim();
250
251 if !content.is_empty() {
252 if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
253 if let Some(handlers) = response_handlers.lock().as_mut() {
254 if let Some(handler) = handlers.remove(&response.id) {
255 handler(Ok(content.to_string()));
256 }
257 }
258 } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
259 let mut notification_handlers = notification_handlers.lock();
260 if let Some(handler) =
261 notification_handlers.get_mut(notification.method.as_str())
262 {
263 handler(notification.params.unwrap_or(Value::Null), cx.clone());
264 }
265 }
266 }
267
268 smol::future::yield_now().await;
269 }
270 }
271
272 /// Handles the stderr output from the context server.
273 /// Continuously reads and logs any error messages from the server.
274 async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
275 where
276 Stderr: AsyncRead + Unpin + Send + 'static,
277 {
278 let mut stderr = BufReader::new(stderr);
279 let mut buffer = String::new();
280
281 loop {
282 buffer.clear();
283 if stderr.read_line(&mut buffer).await? == 0 {
284 return Ok(());
285 }
286 log::warn!("context server stderr: {}", buffer.trim());
287 smol::future::yield_now().await;
288 }
289 }
290
291 /// Handles the output to the context server's stdin.
292 /// This function continuously receives messages from the outbound channel,
293 /// writes them to the server's stdin, and manages the lifecycle of response handlers.
294 async fn handle_output<Stdin>(
295 stdin: Stdin,
296 outbound_rx: channel::Receiver<String>,
297 output_done_tx: barrier::Sender,
298 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
299 ) -> anyhow::Result<()>
300 where
301 Stdin: AsyncWrite + Unpin + Send + 'static,
302 {
303 let mut stdin = BufWriter::new(stdin);
304 let _clear_response_handlers = util::defer({
305 let response_handlers = response_handlers.clone();
306 move || {
307 response_handlers.lock().take();
308 }
309 });
310 while let Ok(message) = outbound_rx.recv().await {
311 log::trace!("outgoing message: {}", message);
312
313 stdin.write_all(message.as_bytes()).await?;
314 stdin.write_all(b"\n").await?;
315 stdin.flush().await?;
316 }
317 drop(output_done_tx);
318 Ok(())
319 }
320
321 /// Sends a JSON-RPC request to the context server and waits for a response.
322 /// This function handles serialization, deserialization, timeout, and error handling.
323 pub async fn request<T: DeserializeOwned>(
324 &self,
325 method: &str,
326 params: impl Serialize,
327 ) -> Result<T> {
328 let id = self.next_id.fetch_add(1, SeqCst);
329 let request = serde_json::to_string(&Request {
330 jsonrpc: JSON_RPC_VERSION,
331 id: RequestId::Int(id),
332 method,
333 params,
334 })
335 .unwrap();
336
337 let (tx, rx) = oneshot::channel();
338 let handle_response = self
339 .response_handlers
340 .lock()
341 .as_mut()
342 .ok_or_else(|| anyhow!("server shut down"))
343 .map(|handlers| {
344 handlers.insert(
345 RequestId::Int(id),
346 Box::new(move |result| {
347 let _ = tx.send(result);
348 }),
349 );
350 });
351
352 let send = self
353 .outbound_tx
354 .try_send(request)
355 .context("failed to write to context server's stdin");
356
357 let executor = self.executor.clone();
358 let started = Instant::now();
359 handle_response?;
360 send?;
361
362 let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
363 select! {
364 response = rx.fuse() => {
365 let elapsed = started.elapsed();
366 log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
367 match response? {
368 Ok(response) => {
369 let parsed: AnyResponse = serde_json::from_str(&response)?;
370 if let Some(error) = parsed.error {
371 Err(anyhow!(error.message))
372 } else if let Some(result) = parsed.result {
373 Ok(serde_json::from_str(result.get())?)
374 } else {
375 Err(anyhow!("Invalid response: no result or error"))
376 }
377 }
378 Err(_) => anyhow::bail!("cancelled")
379 }
380 }
381 _ = timeout => {
382 log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
383 anyhow::bail!("Context server request timeout");
384 }
385 }
386 }
387
388 /// Sends a notification to the context server without expecting a response.
389 /// This function serializes the notification and sends it through the outbound channel.
390 pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
391 let notification = serde_json::to_string(&Notification {
392 jsonrpc: JSON_RPC_VERSION,
393 method,
394 params,
395 })
396 .unwrap();
397 self.outbound_tx.try_send(notification)?;
398 Ok(())
399 }
400
401 pub fn on_notification<F>(&self, method: &'static str, f: F)
402 where
403 F: 'static + Send + FnMut(Value, AsyncApp),
404 {
405 self.notification_handlers
406 .lock()
407 .insert(method, Box::new(f));
408 }
409
410 pub fn name(&self) -> &str {
411 &self.name
412 }
413
414 pub fn server_id(&self) -> ContextServerId {
415 self.server_id.clone()
416 }
417}
418
419impl Drop for Client {
420 fn drop(&mut self) {
421 if let Some(mut server) = self.server.lock().take() {
422 let _ = server.kill();
423 }
424 }
425}
426
427impl fmt::Display for ContextServerId {
428 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429 self.0.fmt(f)
430 }
431}
432
433impl fmt::Debug for Client {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 f.debug_struct("Context Server Client")
436 .field("id", &self.server_id.0)
437 .field("name", &self.name)
438 .finish_non_exhaustive()
439 }
440}