1use anyhow::{anyhow, Context, Result};
2use collections::HashMap;
3use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
4use gpui::{AsyncAppContext, BackgroundExecutor, Task};
5use parking_lot::Mutex;
6use postage::barrier;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8use serde_json::{value::RawValue, Value};
9use smol::{
10 channel,
11 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12 process::{self, Child},
13};
14use std::{
15 fmt,
16 path::PathBuf,
17 sync::{
18 atomic::{AtomicI32, Ordering::SeqCst},
19 Arc,
20 },
21 time::{Duration, Instant},
22};
23use util::TryFutureExt;
24
/// Protocol version written into the `jsonrpc` field of every outgoing request and notification.
const JSON_RPC_VERSION: &str = "2.0";
/// How long `Client::request` waits for a matching response before failing with a timeout.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(60_000);
27
28type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
29type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
30
/// Identifier of a JSON-RPC request, either numeric or textual.
///
/// With `#[serde(untagged)]` the id is (de)serialized as the bare value (`1` or `"abc"`) with
/// no variant tag; when deserializing, serde tries the variants in declaration order, so
/// `Int` must stay before `Str`.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    /// Numeric id; `Client::request` always allocates ids of this form.
    Int(i32),
    /// Textual id.
    Str(String),
}
37
/// Client for a context server child process, exchanging newline-delimited JSON-RPC messages
/// over the child's stdin/stdout.
pub struct Client {
    /// Identifier of the server this client was created for.
    server_id: ContextServerId,
    /// Counter used to allocate request ids; `request` increments it once per call.
    next_id: AtomicI32,
    /// Queue of serialized outgoing messages, drained by the output task into the child's stdin.
    outbound_tx: channel::Sender<String>,
    /// Display name: the executable's file name, or empty if it has none.
    name: Arc<str>,
    /// Handlers for server notifications, keyed by JSON-RPC method name.
    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
    /// Pending request callbacks keyed by request id. Set to `None` when the I/O tasks shut
    /// down, after which `request` fails with "server shut down".
    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
    // The nested Option/tuple/Task type trips clippy's `type_complexity` lint. The field is
    // never read (hence `dead_code`); it only owns the input/output tasks so they are dropped
    // together with the client.
    #[allow(clippy::type_complexity)]
    #[allow(dead_code)]
    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
    // Never read in this file (hence `dead_code`); the output task drops the matching
    // `barrier::Sender` when it finishes.
    #[allow(dead_code)]
    output_done_rx: Mutex<Option<barrier::Receiver>>,
    /// Executor used to create the per-request timeout timer.
    executor: BackgroundExecutor,
    /// The spawned server process; taken and killed when the client is dropped.
    server: Arc<Mutex<Option<Child>>>,
}
53
/// Newtype identifying a context server; its `Display` output is the inner string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ContextServerId(pub String);
57
/// Outgoing JSON-RPC request envelope, serialized by `Client::request`.
///
/// Field names are the wire keys, so they must not be renamed.
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
    /// Protocol version; `Client::request` always sets `JSON_RPC_VERSION`.
    jsonrpc: &'static str,
    /// Id the server echoes back so the response can be matched to this request.
    id: RequestId,
    /// JSON-RPC method name.
    method: &'a str,
    /// Method parameters, serialized as-is.
    params: T,
}
65
/// Loosely-typed view of an incoming JSON-RPC response.
///
/// The stdout reader parses each line with this type to route it to a pending request by
/// `id`; `Client::request` parses it again to choose between `error` and `result` before
/// decoding the result into the caller's type.
///
/// NOTE(review): there is no `method` field and no `deny_unknown_fields`, so a
/// server-to-client *request* (which carries both `id` and `method`) presumably also parses
/// as an `AnyResponse` — confirm whether such messages need separate handling.
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
    jsonrpc: &'a str,
    id: RequestId,
    /// Error object; a missing `error` key deserializes to `None`.
    #[serde(default)]
    error: Option<Error>,
    /// Raw, still-unparsed `result` JSON, borrowed from the input text.
    #[serde(borrow)]
    result: Option<&'a RawValue>,
}
75
/// Typed JSON-RPC response shape. Not used by the visible code (hence `dead_code`).
///
/// NOTE(review): `jsonrpc` is `&'static str`, so this presumably can only be deserialized
/// from `'static` input — confirm before using it on data read at runtime.
#[derive(Deserialize)]
#[allow(dead_code)]
struct Response<T> {
    jsonrpc: &'static str,
    id: RequestId,
    /// The `result` or `error` payload, flattened into the top-level object.
    #[serde(flatten)]
    value: CspResult<T>,
}
84
/// Outcome part of `Response`: either a `result` key or an `error` key.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum CspResult<T> {
    /// Success payload, read from the `result` key.
    #[serde(rename = "result")]
    Ok(Option<T>),
    /// Error payload, read from the `error` key (the snake_case form of the variant name).
    #[allow(dead_code)]
    Error(Option<Error>),
}
93
/// Outgoing JSON-RPC notification envelope, serialized by `Client::notify`.
///
/// It has no `id`, so no response is matched to it.
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
    /// Protocol version; `Client::notify` always sets `JSON_RPC_VERSION`.
    jsonrpc: &'static str,
    /// JSON-RPC method name.
    #[serde(borrow)]
    method: &'a str,
    /// Notification parameters, serialized as-is.
    params: T,
}
101
/// Loosely-typed view of an incoming notification. `method` selects the registered handler.
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
    jsonrpc: &'a str,
    method: String,
    /// Parameters. A missing `params` key deserializes to `None`, which the stdout reader
    /// replaces with `Value::Null` before calling the handler.
    #[serde(default)]
    params: Option<Value>,
}
109
/// Error object carried in a JSON-RPC response.
///
/// Only `message` is kept. Any other fields of the object are ignored when deserializing,
/// because the type does not use `deny_unknown_fields`.
#[derive(Debug, Serialize, Deserialize)]
struct Error {
    /// Human-readable error text; `Client::request` turns it into the returned error.
    message: String,
}
114
/// How to launch a context server process.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerBinary {
    /// Executable to spawn; its file name also becomes the client's display name.
    pub executable: PathBuf,
    /// Command-line arguments passed to the executable.
    pub args: Vec<String>,
    /// Extra environment variables applied to the child via `Command::envs`; `None` adds none.
    pub env: Option<HashMap<String, String>>,
}
121
122impl Client {
123 /// Creates a new Client instance for a context server.
124 ///
125 /// This function initializes a new Client by spawning a child process for the context server,
126 /// setting up communication channels, and initializing handlers for input/output operations.
127 /// It takes a server ID, binary information, and an async app context as input.
128 pub fn new(
129 server_id: ContextServerId,
130 binary: ModelContextServerBinary,
131 cx: AsyncAppContext,
132 ) -> Result<Self> {
133 log::info!(
134 "starting context server (executable={:?}, args={:?})",
135 binary.executable,
136 &binary.args
137 );
138
139 let mut command = process::Command::new(&binary.executable);
140 command
141 .args(&binary.args)
142 .envs(binary.env.unwrap_or_default())
143 .stdin(std::process::Stdio::piped())
144 .stdout(std::process::Stdio::piped())
145 .stderr(std::process::Stdio::piped())
146 .kill_on_drop(true);
147
148 let mut server = command.spawn().with_context(|| {
149 format!(
150 "failed to spawn command. (path={:?}, args={:?})",
151 binary.executable, &binary.args
152 )
153 })?;
154
155 let stdin = server.stdin.take().unwrap();
156 let stdout = server.stdout.take().unwrap();
157 let stderr = server.stderr.take().unwrap();
158
159 let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
160 let (output_done_tx, output_done_rx) = barrier::channel();
161
162 let notification_handlers =
163 Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
164 let response_handlers =
165 Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
166
167 let stdout_input_task = cx.spawn({
168 let notification_handlers = notification_handlers.clone();
169 let response_handlers = response_handlers.clone();
170 move |cx| {
171 Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
172 }
173 });
174 let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
175 let input_task = cx.spawn(|_| async move {
176 let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
177 stdout.or(stderr)
178 });
179 let output_task = cx.background_executor().spawn({
180 Self::handle_output(
181 stdin,
182 outbound_rx,
183 output_done_tx,
184 response_handlers.clone(),
185 )
186 .log_err()
187 });
188
189 let mut context_server = Self {
190 server_id,
191 notification_handlers,
192 response_handlers,
193 name: "".into(),
194 next_id: Default::default(),
195 outbound_tx,
196 executor: cx.background_executor().clone(),
197 io_tasks: Mutex::new(Some((input_task, output_task))),
198 output_done_rx: Mutex::new(Some(output_done_rx)),
199 server: Arc::new(Mutex::new(Some(server))),
200 };
201
202 if let Some(name) = binary.executable.file_name() {
203 context_server.name = name.to_string_lossy().into();
204 }
205
206 Ok(context_server)
207 }
208
209 /// Handles input from the server's stdout.
210 ///
211 /// This function continuously reads lines from the provided stdout stream,
212 /// parses them as JSON-RPC responses or notifications, and dispatches them
213 /// to the appropriate handlers. It processes both responses (which are matched
214 /// to pending requests) and notifications (which trigger registered handlers).
215 async fn handle_input<Stdout>(
216 stdout: Stdout,
217 notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
218 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
219 cx: AsyncAppContext,
220 ) -> anyhow::Result<()>
221 where
222 Stdout: AsyncRead + Unpin + Send + 'static,
223 {
224 let mut stdout = BufReader::new(stdout);
225 let mut buffer = String::new();
226
227 loop {
228 buffer.clear();
229 if stdout.read_line(&mut buffer).await? == 0 {
230 return Ok(());
231 }
232
233 let content = buffer.trim();
234
235 if !content.is_empty() {
236 if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
237 if let Some(handlers) = response_handlers.lock().as_mut() {
238 if let Some(handler) = handlers.remove(&response.id) {
239 handler(Ok(content.to_string()));
240 }
241 }
242 } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
243 let mut notification_handlers = notification_handlers.lock();
244 if let Some(handler) =
245 notification_handlers.get_mut(notification.method.as_str())
246 {
247 handler(notification.params.unwrap_or(Value::Null), cx.clone());
248 }
249 }
250 }
251
252 smol::future::yield_now().await;
253 }
254 }
255
256 /// Handles the stderr output from the context server.
257 /// Continuously reads and logs any error messages from the server.
258 async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
259 where
260 Stderr: AsyncRead + Unpin + Send + 'static,
261 {
262 let mut stderr = BufReader::new(stderr);
263 let mut buffer = String::new();
264
265 loop {
266 buffer.clear();
267 if stderr.read_line(&mut buffer).await? == 0 {
268 return Ok(());
269 }
270 log::warn!("context server stderr: {}", buffer.trim());
271 smol::future::yield_now().await;
272 }
273 }
274
275 /// Handles the output to the context server's stdin.
276 /// This function continuously receives messages from the outbound channel,
277 /// writes them to the server's stdin, and manages the lifecycle of response handlers.
278 async fn handle_output<Stdin>(
279 stdin: Stdin,
280 outbound_rx: channel::Receiver<String>,
281 output_done_tx: barrier::Sender,
282 response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
283 ) -> anyhow::Result<()>
284 where
285 Stdin: AsyncWrite + Unpin + Send + 'static,
286 {
287 let mut stdin = BufWriter::new(stdin);
288 let _clear_response_handlers = util::defer({
289 let response_handlers = response_handlers.clone();
290 move || {
291 response_handlers.lock().take();
292 }
293 });
294 while let Ok(message) = outbound_rx.recv().await {
295 log::trace!("outgoing message: {}", message);
296
297 stdin.write_all(message.as_bytes()).await?;
298 stdin.write_all(b"\n").await?;
299 stdin.flush().await?;
300 }
301 drop(output_done_tx);
302 Ok(())
303 }
304
305 /// Sends a JSON-RPC request to the context server and waits for a response.
306 /// This function handles serialization, deserialization, timeout, and error handling.
307 pub async fn request<T: DeserializeOwned>(
308 &self,
309 method: &str,
310 params: impl Serialize,
311 ) -> Result<T> {
312 let id = self.next_id.fetch_add(1, SeqCst);
313 let request = serde_json::to_string(&Request {
314 jsonrpc: JSON_RPC_VERSION,
315 id: RequestId::Int(id),
316 method,
317 params,
318 })
319 .unwrap();
320
321 let (tx, rx) = oneshot::channel();
322 let handle_response = self
323 .response_handlers
324 .lock()
325 .as_mut()
326 .ok_or_else(|| anyhow!("server shut down"))
327 .map(|handlers| {
328 handlers.insert(
329 RequestId::Int(id),
330 Box::new(move |result| {
331 let _ = tx.send(result);
332 }),
333 );
334 });
335
336 let send = self
337 .outbound_tx
338 .try_send(request)
339 .context("failed to write to context server's stdin");
340
341 let executor = self.executor.clone();
342 let started = Instant::now();
343 handle_response?;
344 send?;
345
346 let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
347 select! {
348 response = rx.fuse() => {
349 let elapsed = started.elapsed();
350 log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
351 match response? {
352 Ok(response) => {
353 let parsed: AnyResponse = serde_json::from_str(&response)?;
354 if let Some(error) = parsed.error {
355 Err(anyhow!(error.message))
356 } else if let Some(result) = parsed.result {
357 Ok(serde_json::from_str(result.get())?)
358 } else {
359 Err(anyhow!("Invalid response: no result or error"))
360 }
361 }
362 Err(_) => anyhow::bail!("cancelled")
363 }
364 }
365 _ = timeout => {
366 log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
367 anyhow::bail!("Context server request timeout");
368 }
369 }
370 }
371
372 /// Sends a notification to the context server without expecting a response.
373 /// This function serializes the notification and sends it through the outbound channel.
374 pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
375 let notification = serde_json::to_string(&Notification {
376 jsonrpc: JSON_RPC_VERSION,
377 method,
378 params,
379 })
380 .unwrap();
381 self.outbound_tx.try_send(notification)?;
382 Ok(())
383 }
384
385 pub fn on_notification<F>(&self, method: &'static str, f: F)
386 where
387 F: 'static + Send + FnMut(Value, AsyncAppContext),
388 {
389 self.notification_handlers
390 .lock()
391 .insert(method, Box::new(f));
392 }
393
394 pub fn name(&self) -> &str {
395 &self.name
396 }
397
398 pub fn server_id(&self) -> ContextServerId {
399 self.server_id.clone()
400 }
401}
402
403impl Drop for Client {
404 fn drop(&mut self) {
405 if let Some(mut server) = self.server.lock().take() {
406 let _ = server.kill();
407 }
408 }
409}
410
411impl fmt::Display for ContextServerId {
412 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413 self.0.fmt(f)
414 }
415}
416
417impl fmt::Debug for Client {
418 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419 f.debug_struct("Context Server Client")
420 .field("id", &self.server_id.0)
421 .field("name", &self.name)
422 .finish_non_exhaustive()
423 }
424}