stdio_transport.rs

  1use std::pin::Pin;
  2
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use futures::io::{BufReader, BufWriter};
  6use futures::{
  7    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
  8};
  9use gpui::AsyncApp;
 10use smol::channel;
 11use smol::process::Child;
 12use util::TryFutureExt as _;
 13
 14use crate::client::ModelContextServerBinary;
 15use crate::transport::Transport;
 16
 17pub struct StdioTransport {
 18    stdout_sender: channel::Sender<String>,
 19    stdin_receiver: channel::Receiver<String>,
 20    stderr_receiver: channel::Receiver<String>,
 21    server: Child,
 22}
 23
 24impl StdioTransport {
 25    pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
 26        let mut command = util::command::new_smol_command(&binary.executable);
 27        command
 28            .args(&binary.args)
 29            .envs(binary.env.unwrap_or_default())
 30            .stdin(std::process::Stdio::piped())
 31            .stdout(std::process::Stdio::piped())
 32            .stderr(std::process::Stdio::piped())
 33            .kill_on_drop(true);
 34
 35        let mut server = command.spawn().with_context(|| {
 36            format!(
 37                "failed to spawn command. (path={:?}, args={:?})",
 38                binary.executable, &binary.args
 39            )
 40        })?;
 41
 42        let stdin = server.stdin.take().unwrap();
 43        let stdout = server.stdout.take().unwrap();
 44        let stderr = server.stderr.take().unwrap();
 45
 46        let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
 47        let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
 48        let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
 49
 50        cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
 51            .detach();
 52
 53        cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
 54            .detach();
 55
 56        cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
 57            .detach();
 58
 59        Ok(Self {
 60            stdout_sender,
 61            stdin_receiver,
 62            stderr_receiver,
 63            server,
 64        })
 65    }
 66
 67    async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
 68    where
 69        Stdout: AsyncRead + Unpin + Send + 'static,
 70    {
 71        let mut stdin = BufReader::new(stdin);
 72        let mut line = String::new();
 73        while let Ok(n) = stdin.read_line(&mut line).await {
 74            if n == 0 {
 75                break;
 76            }
 77            if inbound_rx.send(line.clone()).await.is_err() {
 78                break;
 79            }
 80            line.clear();
 81        }
 82    }
 83
 84    async fn handle_output<Stdin>(
 85        stdin: Stdin,
 86        outbound_rx: channel::Receiver<String>,
 87    ) -> Result<()>
 88    where
 89        Stdin: AsyncWrite + Unpin + Send + 'static,
 90    {
 91        let mut stdin = BufWriter::new(stdin);
 92        let mut pinned_rx = Box::pin(outbound_rx);
 93        while let Some(message) = pinned_rx.next().await {
 94            log::trace!("outgoing message: {}", message);
 95
 96            stdin.write_all(message.as_bytes()).await?;
 97            stdin.write_all(b"\n").await?;
 98            stdin.flush().await?;
 99        }
100        Ok(())
101    }
102
103    async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
104    where
105        Stderr: AsyncRead + Unpin + Send + 'static,
106    {
107        let mut stderr = BufReader::new(stderr);
108        let mut line = String::new();
109        while let Ok(n) = stderr.read_line(&mut line).await {
110            if n == 0 {
111                break;
112            }
113            if stderr_tx.send(line.clone()).await.is_err() {
114                break;
115            }
116            line.clear();
117        }
118    }
119}
120
121#[async_trait]
122impl Transport for StdioTransport {
123    async fn send(&self, message: String) -> Result<()> {
124        Ok(self.stdout_sender.send(message).await?)
125    }
126
127    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
128        Box::pin(self.stdin_receiver.clone())
129    }
130
131    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
132        Box::pin(self.stderr_receiver.clone())
133    }
134}
135
136impl Drop for StdioTransport {
137    fn drop(&mut self) {
138        let _ = self.server.kill();
139    }
140}