client.rs

  1use anyhow::{anyhow, Context as _, Result};
  2use collections::HashMap;
  3use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
  4use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
  5use parking_lot::Mutex;
  6use postage::barrier;
  7use serde::{de::DeserializeOwned, Deserialize, Serialize};
  8use serde_json::{value::RawValue, Value};
  9use smol::{
 10    channel,
 11    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
 12    process::Child,
 13};
 14use std::{
 15    fmt,
 16    path::PathBuf,
 17    sync::{
 18        atomic::{AtomicI32, Ordering::SeqCst},
 19        Arc,
 20    },
 21    time::{Duration, Instant},
 22};
 23use util::TryFutureExt;
 24
 25const JSON_RPC_VERSION: &str = "2.0";
 26const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
 27
 28// Standard JSON-RPC error codes
 29pub const PARSE_ERROR: i32 = -32700;
 30pub const INVALID_REQUEST: i32 = -32600;
 31pub const METHOD_NOT_FOUND: i32 = -32601;
 32pub const INVALID_PARAMS: i32 = -32602;
 33pub const INTERNAL_ERROR: i32 = -32603;
 34
 35type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
 36type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
 37
 38#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
 39#[serde(untagged)]
 40pub enum RequestId {
 41    Int(i32),
 42    Str(String),
 43}
 44
 45pub struct Client {
 46    server_id: ContextServerId,
 47    next_id: AtomicI32,
 48    outbound_tx: channel::Sender<String>,
 49    name: Arc<str>,
 50    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
 51    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
 52    #[allow(clippy::type_complexity)]
 53    #[allow(dead_code)]
 54    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
 55    #[allow(dead_code)]
 56    output_done_rx: Mutex<Option<barrier::Receiver>>,
 57    executor: BackgroundExecutor,
 58    server: Arc<Mutex<Option<Child>>>,
 59}
 60
 61#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 62#[repr(transparent)]
 63pub struct ContextServerId(pub Arc<str>);
 64
 65fn is_null_value<T: Serialize>(value: &T) -> bool {
 66    if let Ok(Value::Null) = serde_json::to_value(value) {
 67        true
 68    } else {
 69        false
 70    }
 71}
 72
 73#[derive(Serialize, Deserialize)]
 74struct Request<'a, T> {
 75    jsonrpc: &'static str,
 76    id: RequestId,
 77    method: &'a str,
 78    #[serde(skip_serializing_if = "is_null_value")]
 79    params: T,
 80}
 81
 82#[derive(Serialize, Deserialize)]
 83struct AnyResponse<'a> {
 84    jsonrpc: &'a str,
 85    id: RequestId,
 86    #[serde(default)]
 87    error: Option<Error>,
 88    #[serde(borrow)]
 89    result: Option<&'a RawValue>,
 90}
 91
 92#[derive(Deserialize)]
 93#[allow(dead_code)]
 94struct Response<T> {
 95    jsonrpc: &'static str,
 96    id: RequestId,
 97    #[serde(flatten)]
 98    value: CspResult<T>,
 99}
100
101#[derive(Deserialize)]
102#[serde(rename_all = "snake_case")]
103enum CspResult<T> {
104    #[serde(rename = "result")]
105    Ok(Option<T>),
106    #[allow(dead_code)]
107    Error(Option<Error>),
108}
109
110#[derive(Serialize, Deserialize)]
111struct Notification<'a, T> {
112    jsonrpc: &'static str,
113    #[serde(borrow)]
114    method: &'a str,
115    params: T,
116}
117
118#[derive(Debug, Clone, Deserialize)]
119struct AnyNotification<'a> {
120    jsonrpc: &'a str,
121    method: String,
122    #[serde(default)]
123    params: Option<Value>,
124}
125
126#[derive(Debug, Serialize, Deserialize)]
127struct Error {
128    message: String,
129}
130
131#[derive(Debug, Clone, Deserialize)]
132pub struct ModelContextServerBinary {
133    pub executable: PathBuf,
134    pub args: Vec<String>,
135    pub env: Option<HashMap<String, String>>,
136}
137
138impl Client {
139    /// Creates a new Client instance for a context server.
140    ///
141    /// This function initializes a new Client by spawning a child process for the context server,
142    /// setting up communication channels, and initializing handlers for input/output operations.
143    /// It takes a server ID, binary information, and an async app context as input.
144    pub fn new(
145        server_id: ContextServerId,
146        binary: ModelContextServerBinary,
147        cx: AsyncApp,
148    ) -> Result<Self> {
149        log::info!(
150            "starting context server (executable={:?}, args={:?})",
151            binary.executable,
152            &binary.args
153        );
154
155        let mut command = util::command::new_smol_command(&binary.executable);
156        command
157            .args(&binary.args)
158            .envs(binary.env.unwrap_or_default())
159            .stdin(std::process::Stdio::piped())
160            .stdout(std::process::Stdio::piped())
161            .stderr(std::process::Stdio::piped())
162            .kill_on_drop(true);
163
164        let mut server = command.spawn().with_context(|| {
165            format!(
166                "failed to spawn command. (path={:?}, args={:?})",
167                binary.executable, &binary.args
168            )
169        })?;
170
171        let stdin = server.stdin.take().unwrap();
172        let stdout = server.stdout.take().unwrap();
173        let stderr = server.stderr.take().unwrap();
174
175        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
176        let (output_done_tx, output_done_rx) = barrier::channel();
177
178        let notification_handlers =
179            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
180        let response_handlers =
181            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
182
183        let stdout_input_task = cx.spawn({
184            let notification_handlers = notification_handlers.clone();
185            let response_handlers = response_handlers.clone();
186            move |cx| {
187                Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
188            }
189        });
190        let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
191        let input_task = cx.spawn(|_| async move {
192            let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
193            stdout.or(stderr)
194        });
195        let output_task = cx.background_spawn({
196            Self::handle_output(
197                stdin,
198                outbound_rx,
199                output_done_tx,
200                response_handlers.clone(),
201            )
202            .log_err()
203        });
204
205        let mut context_server = Self {
206            server_id,
207            notification_handlers,
208            response_handlers,
209            name: "".into(),
210            next_id: Default::default(),
211            outbound_tx,
212            executor: cx.background_executor().clone(),
213            io_tasks: Mutex::new(Some((input_task, output_task))),
214            output_done_rx: Mutex::new(Some(output_done_rx)),
215            server: Arc::new(Mutex::new(Some(server))),
216        };
217
218        if let Some(name) = binary.executable.file_name() {
219            context_server.name = name.to_string_lossy().into();
220        }
221
222        Ok(context_server)
223    }
224
225    /// Handles input from the server's stdout.
226    ///
227    /// This function continuously reads lines from the provided stdout stream,
228    /// parses them as JSON-RPC responses or notifications, and dispatches them
229    /// to the appropriate handlers. It processes both responses (which are matched
230    /// to pending requests) and notifications (which trigger registered handlers).
231    async fn handle_input<Stdout>(
232        stdout: Stdout,
233        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
234        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
235        cx: AsyncApp,
236    ) -> anyhow::Result<()>
237    where
238        Stdout: AsyncRead + Unpin + Send + 'static,
239    {
240        let mut stdout = BufReader::new(stdout);
241        let mut buffer = String::new();
242
243        loop {
244            buffer.clear();
245            if stdout.read_line(&mut buffer).await? == 0 {
246                return Ok(());
247            }
248
249            let content = buffer.trim();
250
251            if !content.is_empty() {
252                if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
253                    if let Some(handlers) = response_handlers.lock().as_mut() {
254                        if let Some(handler) = handlers.remove(&response.id) {
255                            handler(Ok(content.to_string()));
256                        }
257                    }
258                } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
259                    let mut notification_handlers = notification_handlers.lock();
260                    if let Some(handler) =
261                        notification_handlers.get_mut(notification.method.as_str())
262                    {
263                        handler(notification.params.unwrap_or(Value::Null), cx.clone());
264                    }
265                }
266            }
267
268            smol::future::yield_now().await;
269        }
270    }
271
272    /// Handles the stderr output from the context server.
273    /// Continuously reads and logs any error messages from the server.
274    async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
275    where
276        Stderr: AsyncRead + Unpin + Send + 'static,
277    {
278        let mut stderr = BufReader::new(stderr);
279        let mut buffer = String::new();
280
281        loop {
282            buffer.clear();
283            if stderr.read_line(&mut buffer).await? == 0 {
284                return Ok(());
285            }
286            log::warn!("context server stderr: {}", buffer.trim());
287            smol::future::yield_now().await;
288        }
289    }
290
291    /// Handles the output to the context server's stdin.
292    /// This function continuously receives messages from the outbound channel,
293    /// writes them to the server's stdin, and manages the lifecycle of response handlers.
294    async fn handle_output<Stdin>(
295        stdin: Stdin,
296        outbound_rx: channel::Receiver<String>,
297        output_done_tx: barrier::Sender,
298        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
299    ) -> anyhow::Result<()>
300    where
301        Stdin: AsyncWrite + Unpin + Send + 'static,
302    {
303        let mut stdin = BufWriter::new(stdin);
304        let _clear_response_handlers = util::defer({
305            let response_handlers = response_handlers.clone();
306            move || {
307                response_handlers.lock().take();
308            }
309        });
310        while let Ok(message) = outbound_rx.recv().await {
311            log::trace!("outgoing message: {}", message);
312
313            stdin.write_all(message.as_bytes()).await?;
314            stdin.write_all(b"\n").await?;
315            stdin.flush().await?;
316        }
317        drop(output_done_tx);
318        Ok(())
319    }
320
321    /// Sends a JSON-RPC request to the context server and waits for a response.
322    /// This function handles serialization, deserialization, timeout, and error handling.
323    pub async fn request<T: DeserializeOwned>(
324        &self,
325        method: &str,
326        params: impl Serialize,
327    ) -> Result<T> {
328        let id = self.next_id.fetch_add(1, SeqCst);
329        let request = serde_json::to_string(&Request {
330            jsonrpc: JSON_RPC_VERSION,
331            id: RequestId::Int(id),
332            method,
333            params,
334        })
335        .unwrap();
336
337        let (tx, rx) = oneshot::channel();
338        let handle_response = self
339            .response_handlers
340            .lock()
341            .as_mut()
342            .ok_or_else(|| anyhow!("server shut down"))
343            .map(|handlers| {
344                handlers.insert(
345                    RequestId::Int(id),
346                    Box::new(move |result| {
347                        let _ = tx.send(result);
348                    }),
349                );
350            });
351
352        let send = self
353            .outbound_tx
354            .try_send(request)
355            .context("failed to write to context server's stdin");
356
357        let executor = self.executor.clone();
358        let started = Instant::now();
359        handle_response?;
360        send?;
361
362        let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
363        select! {
364            response = rx.fuse() => {
365                let elapsed = started.elapsed();
366                log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
367                match response? {
368                    Ok(response) => {
369                        let parsed: AnyResponse = serde_json::from_str(&response)?;
370                        if let Some(error) = parsed.error {
371                            Err(anyhow!(error.message))
372                        } else if let Some(result) = parsed.result {
373                            Ok(serde_json::from_str(result.get())?)
374                        } else {
375                            Err(anyhow!("Invalid response: no result or error"))
376                        }
377                    }
378                    Err(_) => anyhow::bail!("cancelled")
379                }
380            }
381            _ = timeout => {
382                log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
383                anyhow::bail!("Context server request timeout");
384            }
385        }
386    }
387
388    /// Sends a notification to the context server without expecting a response.
389    /// This function serializes the notification and sends it through the outbound channel.
390    pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
391        let notification = serde_json::to_string(&Notification {
392            jsonrpc: JSON_RPC_VERSION,
393            method,
394            params,
395        })
396        .unwrap();
397        self.outbound_tx.try_send(notification)?;
398        Ok(())
399    }
400
401    pub fn on_notification<F>(&self, method: &'static str, f: F)
402    where
403        F: 'static + Send + FnMut(Value, AsyncApp),
404    {
405        self.notification_handlers
406            .lock()
407            .insert(method, Box::new(f));
408    }
409
410    pub fn name(&self) -> &str {
411        &self.name
412    }
413
414    pub fn server_id(&self) -> ContextServerId {
415        self.server_id.clone()
416    }
417}
418
419impl Drop for Client {
420    fn drop(&mut self) {
421        if let Some(mut server) = self.server.lock().take() {
422            let _ = server.kill();
423        }
424    }
425}
426
427impl fmt::Display for ContextServerId {
428    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429        self.0.fmt(f)
430    }
431}
432
433impl fmt::Debug for Client {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        f.debug_struct("Context Server Client")
436            .field("id", &self.server_id.0)
437            .field("name", &self.name)
438            .finish_non_exhaustive()
439    }
440}