1use anyhow::{anyhow, Context, Result};
2use collections::HashMap;
3use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
4use gpui::{AsyncAppContext, BackgroundExecutor, Task};
5use parking_lot::Mutex;
6use postage::barrier;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_json::{value::RawValue, Value};
9use smol::{
10 channel,
11 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12 process::{self, Child},
13};
14use std::{
15 fmt,
16 path::PathBuf,
17 sync::{
18 atomic::{AtomicI32, Ordering::SeqCst},
19 Arc,
20 },
21 time::{Duration, Instant},
22};
23use util::TryFutureExt;
24
/// Protocol version placed in the `jsonrpc` field of every outgoing request and notification.
const JSON_RPC_VERSION: &str = "2.0";
/// Maximum time `Client::request` waits for a response before failing (one minute).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(60_000);
27
28type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
29type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
30
31#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
32#[serde(untagged)]
33pub enum RequestId {
34 Int(i32),
35 Str(String),
36}
37
/// Client for a context server that runs as a child process and exchanges newline-delimited
/// JSON-RPC messages over the process's stdin/stdout.
pub struct Client {
    /// Identifier supplied by the caller of `Client::new`.
    server_id: ContextServerId,
    /// Counter from which outgoing request ids are allocated.
    next_id: AtomicI32,
    /// Serialized outgoing messages; the output task writes them to the server's stdin.
    outbound_tx: channel::Sender<String>,
    /// Human-readable name (the executable's file name when one is available).
    name: Arc<str>,
    /// Handlers for incoming notifications, keyed by JSON-RPC method name.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Completion callbacks for in-flight requests, keyed by request id. Set to `None` once the
    /// client's I/O shuts down, after which `request` fails with "server shut down".
    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
    // The (input, output) I/O tasks. Never read, hence `dead_code`; presumably stored only so
    // the tasks stay alive as long as the client does.
    #[allow(clippy::type_complexity)]
    #[allow(dead_code)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    // Receiving side of the barrier whose sender the output task drops when it finishes.
    // Never read in this file, hence `dead_code`.
    #[allow(dead_code)]
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Executor used to create request-timeout timers.
    executor: BackgroundExecutor,
    /// The spawned server process; taken and killed when the client is dropped.
    server: Arc<Mutex<Option<Child>>>,
}
53
/// Identifier of a context server: a thin, transparent wrapper around its string id.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct ContextServerId(pub String);
57
58fn is_null_value<T: Serialize>(value: &T) -> bool {
59 if let Ok(Value::Null) = serde_json::to_value(value) {
60 true
61 } else {
62 false
63 }
64}
65
66#[derive(Serialize, Deserialize)]
67struct Request<'a, T> {
68 jsonrpc: &'static str,
69 id: RequestId,
70 method: &'a str,
71 #[serde(skip_serializing_if = "is_null_value")]
72 params: T,
73}
74
75#[derive(Serialize, Deserialize)]
76struct AnyResponse<'a> {
77 jsonrpc: &'a str,
78 id: RequestId,
79 #[serde(default)]
80 error: Option<Error>,
81 #[serde(borrow)]
82 result: Option<&'a RawValue>,
83}
84
85#[derive(Deserialize)]
86#[allow(dead_code)]
87struct Response<T> {
88 jsonrpc: &'static str,
89 id: RequestId,
90 #[serde(flatten)]
91 value: CspResult<T>,
92}
93
94#[derive(Deserialize)]
95#[serde(rename_all = "snake_case")]
96enum CspResult<T> {
97 #[serde(rename = "result")]
98 Ok(Option<T>),
99 #[allow(dead_code)]
100 Error(Option<Error>),
101}
102
103#[derive(Serialize, Deserialize)]
104struct Notification<'a, T> {
105 jsonrpc: &'static str,
106 #[serde(borrow)]
107 method: &'a str,
108 params: T,
109}
110
111#[derive(Debug, Clone, Deserialize)]
112struct AnyNotification<'a> {
113 jsonrpc: &'a str,
114 method: String,
115 #[serde(default)]
116 params: Option<Value>,
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120struct Error {
121 message: String,
122}
123
124#[derive(Debug, Clone, Deserialize)]
125pub struct ModelContextServerBinary {
126 pub executable: PathBuf,
127 pub args: Vec<String>,
128 pub env: Option<HashMap<String, String>>,
129}
130
131impl Client {
132 /// Creates a new Client instance for a context server.
133 ///
134 /// This function initializes a new Client by spawning a child process for the context server,
135 /// setting up communication channels, and initializing handlers for input/output operations.
136 /// It takes a server ID, binary information, and an async app context as input.
137 pub fn new(
138 server_id: ContextServerId,
139 binary: ModelContextServerBinary,
140 cx: AsyncAppContext,
141 ) -> Result<Self> {
142 log::info!(
143 "starting context server (executable={:?}, args={:?})",
144 binary.executable,
145 &binary.args
146 );
147
148 let mut command = process::Command::new(&binary.executable);
149 command
150 .args(&binary.args)
151 .envs(binary.env.unwrap_or_default())
152 .stdin(std::process::Stdio::piped())
153 .stdout(std::process::Stdio::piped())
154 .stderr(std::process::Stdio::piped())
155 .kill_on_drop(true);
156
157 let mut server = command.spawn().with_context(|| {
158 format!(
159 "failed to spawn command. (path={:?}, args={:?})",
160 binary.executable, &binary.args
161 )
162 })?;
163
164 let stdin = server.stdin.take().unwrap();
165 let stdout = server.stdout.take().unwrap();
166 let stderr = server.stderr.take().unwrap();
167
168 let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
169 let (output_done_tx, output_done_rx) = barrier::channel();
170
171 let notification_handlers =
172 Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
173 let response_handlers =
174 Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
175
176 let stdout_input_task = cx.spawn({
177 let notification_handlers = notification_handlers.clone();
178 let response_handlers = response_handlers.clone();
179 move |cx| {
180 Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
181 }
182 });
183 let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
184 let input_task = cx.spawn(|_| async move {
185 let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
186 stdout.or(stderr)
187 });
188 let output_task = cx.background_executor().spawn({
189 Self::handle_output(
190 stdin,
191 outbound_rx,
192 output_done_tx,
193 response_handlers.clone(),
194 )
195 .log_err()
196 });
197
198 let mut context_server = Self {
199 server_id,
200 notification_handlers,
201 response_handlers,
202 name: "".into(),
203 next_id: Default::default(),
204 outbound_tx,
205 executor: cx.background_executor().clone(),
206 io_tasks: Mutex::new(Some((input_task, output_task))),
207 output_done_rx: Mutex::new(Some(output_done_rx)),
208 server: Arc::new(Mutex::new(Some(server))),
209 };
210
211 if let Some(name) = binary.executable.file_name() {
212 context_server.name = name.to_string_lossy().into();
213 }
214
215 Ok(context_server)
216 }
217
218 /// Handles input from the server's stdout.
219 ///
220 /// This function continuously reads lines from the provided stdout stream,
221 /// parses them as JSON-RPC responses or notifications, and dispatches them
222 /// to the appropriate handlers. It processes both responses (which are matched
223 /// to pending requests) and notifications (which trigger registered handlers).
224 async fn handle_input<Stdout>(
225 stdout: Stdout,
226 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
227 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
228 cx: AsyncAppContext,
229 ) -> anyhow::Result<()>
230 where
231 Stdout: AsyncRead + Unpin + Send + 'static,
232 {
233 let mut stdout = BufReader::new(stdout);
234 let mut buffer = String::new();
235
236 loop {
237 buffer.clear();
238 if stdout.read_line(&mut buffer).await? == 0 {
239 return Ok(());
240 }
241
242 let content = buffer.trim();
243
244 if !content.is_empty() {
245 if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
246 if let Some(handlers) = response_handlers.lock().as_mut() {
247 if let Some(handler) = handlers.remove(&response.id) {
248 handler(Ok(content.to_string()));
249 }
250 }
251 } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
252 let mut notification_handlers = notification_handlers.lock();
253 if let Some(handler) =
254 notification_handlers.get_mut(notification.method.as_str())
255 {
256 handler(notification.params.unwrap_or(Value::Null), cx.clone());
257 }
258 }
259 }
260
261 smol::future::yield_now().await;
262 }
263 }
264
265 /// Handles the stderr output from the context server.
266 /// Continuously reads and logs any error messages from the server.
267 async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
268 where
269 Stderr: AsyncRead + Unpin + Send + 'static,
270 {
271 let mut stderr = BufReader::new(stderr);
272 let mut buffer = String::new();
273
274 loop {
275 buffer.clear();
276 if stderr.read_line(&mut buffer).await? == 0 {
277 return Ok(());
278 }
279 log::warn!("context server stderr: {}", buffer.trim());
280 smol::future::yield_now().await;
281 }
282 }
283
284 /// Handles the output to the context server's stdin.
285 /// This function continuously receives messages from the outbound channel,
286 /// writes them to the server's stdin, and manages the lifecycle of response handlers.
287 async fn handle_output<Stdin>(
288 stdin: Stdin,
289 outbound_rx: channel::Receiver<String>,
290 output_done_tx: barrier::Sender,
291 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
292 ) -> anyhow::Result<()>
293 where
294 Stdin: AsyncWrite + Unpin + Send + 'static,
295 {
296 let mut stdin = BufWriter::new(stdin);
297 let _clear_response_handlers = util::defer({
298 let response_handlers = response_handlers.clone();
299 move || {
300 response_handlers.lock().take();
301 }
302 });
303 while let Ok(message) = outbound_rx.recv().await {
304 log::trace!("outgoing message: {}", message);
305
306 stdin.write_all(message.as_bytes()).await?;
307 stdin.write_all(b"\n").await?;
308 stdin.flush().await?;
309 }
310 drop(output_done_tx);
311 Ok(())
312 }
313
314 /// Sends a JSON-RPC request to the context server and waits for a response.
315 /// This function handles serialization, deserialization, timeout, and error handling.
316 pub async fn request<T: DeserializeOwned>(
317 &self,
318 method: &str,
319 params: impl Serialize,
320 ) -> Result<T> {
321 let id = self.next_id.fetch_add(1, SeqCst);
322 let request = serde_json::to_string(&Request {
323 jsonrpc: JSON_RPC_VERSION,
324 id: RequestId::Int(id),
325 method,
326 params,
327 })
328 .unwrap();
329
330 let (tx, rx) = oneshot::channel();
331 let handle_response = self
332 .response_handlers
333 .lock()
334 .as_mut()
335 .ok_or_else(|| anyhow!("server shut down"))
336 .map(|handlers| {
337 handlers.insert(
338 RequestId::Int(id),
339 Box::new(move |result| {
340 let _ = tx.send(result);
341 }),
342 );
343 });
344
345 let send = self
346 .outbound_tx
347 .try_send(request)
348 .context("failed to write to context server's stdin");
349
350 let executor = self.executor.clone();
351 let started = Instant::now();
352 handle_response?;
353 send?;
354
355 let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
356 select! {
357 response = rx.fuse() => {
358 let elapsed = started.elapsed();
359 log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
360 match response? {
361 Ok(response) => {
362 let parsed: AnyResponse = serde_json::from_str(&response)?;
363 if let Some(error) = parsed.error {
364 Err(anyhow!(error.message))
365 } else if let Some(result) = parsed.result {
366 Ok(serde_json::from_str(result.get())?)
367 } else {
368 Err(anyhow!("Invalid response: no result or error"))
369 }
370 }
371 Err(_) => anyhow::bail!("cancelled")
372 }
373 }
374 _ = timeout => {
375 log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
376 anyhow::bail!("Context server request timeout");
377 }
378 }
379 }
380
381 /// Sends a notification to the context server without expecting a response.
382 /// This function serializes the notification and sends it through the outbound channel.
383 pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
384 let notification = serde_json::to_string(&Notification {
385 jsonrpc: JSON_RPC_VERSION,
386 method,
387 params,
388 })
389 .unwrap();
390 self.outbound_tx.try_send(notification)?;
391 Ok(())
392 }
393
394 pub fn on_notification<F>(&self, method: &'static str, f: F)
395 where
396 F: 'static + Send + FnMut(Value, AsyncAppContext),
397 {
398 self.notification_handlers
399 .lock()
400 .insert(method, Box::new(f));
401 }
402
403 pub fn name(&self) -> &str {
404 &self.name
405 }
406
407 pub fn server_id(&self) -> ContextServerId {
408 self.server_id.clone()
409 }
410}
411
412impl Drop for Client {
413 fn drop(&mut self) {
414 if let Some(mut server) = self.server.lock().take() {
415 let _ = server.kill();
416 }
417 }
418}
419
420impl fmt::Display for ContextServerId {
421 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422 self.0.fmt(f)
423 }
424}
425
426impl fmt::Debug for Client {
427 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
428 f.debug_struct("Context Server Client")
429 .field("id", &self.server_id.0)
430 .field("name", &self.name)
431 .finish_non_exhaustive()
432 }
433}