stdio_transport.rs

  1use std::path::PathBuf;
  2use std::pin::Pin;
  3
  4use anyhow::{Context as _, Result};
  5use async_trait::async_trait;
  6use futures::io::{BufReader, BufWriter};
  7use futures::{
  8    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
  9};
 10use gpui::AsyncApp;
 11use smol::channel;
 12use smol::process::Child;
 13use util::TryFutureExt as _;
 14
 15use crate::client::ModelContextServerBinary;
 16use crate::transport::Transport;
 17
 18pub struct StdioTransport {
 19    stdout_sender: channel::Sender<String>,
 20    stdin_receiver: channel::Receiver<String>,
 21    stderr_receiver: channel::Receiver<String>,
 22    server: Child,
 23}
 24
 25impl StdioTransport {
 26    pub fn new(
 27        binary: ModelContextServerBinary,
 28        working_directory: &Option<PathBuf>,
 29        cx: &AsyncApp,
 30    ) -> Result<Self> {
 31        let mut command = util::command::new_smol_command(&binary.executable);
 32        command
 33            .args(&binary.args)
 34            .envs(binary.env.unwrap_or_default())
 35            .stdin(std::process::Stdio::piped())
 36            .stdout(std::process::Stdio::piped())
 37            .stderr(std::process::Stdio::piped())
 38            .kill_on_drop(true);
 39
 40        if let Some(working_directory) = working_directory {
 41            command.current_dir(working_directory);
 42        }
 43
 44        let mut server = command.spawn().with_context(|| {
 45            format!(
 46                "failed to spawn command. (path={:?}, args={:?})",
 47                binary.executable, &binary.args
 48            )
 49        })?;
 50
 51        let stdin = server.stdin.take().unwrap();
 52        let stdout = server.stdout.take().unwrap();
 53        let stderr = server.stderr.take().unwrap();
 54
 55        let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
 56        let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
 57        let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
 58
 59        cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
 60            .detach();
 61
 62        cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
 63            .detach();
 64
 65        cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
 66            .detach();
 67
 68        Ok(Self {
 69            stdout_sender,
 70            stdin_receiver,
 71            stderr_receiver,
 72            server,
 73        })
 74    }
 75
 76    async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
 77    where
 78        Stdout: AsyncRead + Unpin + Send + 'static,
 79    {
 80        let mut stdin = BufReader::new(stdin);
 81        let mut line = String::new();
 82        while let Ok(n) = stdin.read_line(&mut line).await {
 83            if n == 0 {
 84                break;
 85            }
 86            if inbound_rx.send(line.clone()).await.is_err() {
 87                break;
 88            }
 89            line.clear();
 90        }
 91    }
 92
 93    async fn handle_output<Stdin>(
 94        stdin: Stdin,
 95        outbound_rx: channel::Receiver<String>,
 96    ) -> Result<()>
 97    where
 98        Stdin: AsyncWrite + Unpin + Send + 'static,
 99    {
100        let mut stdin = BufWriter::new(stdin);
101        let mut pinned_rx = Box::pin(outbound_rx);
102        while let Some(message) = pinned_rx.next().await {
103            log::trace!("outgoing message: {}", message);
104
105            stdin.write_all(message.as_bytes()).await?;
106            stdin.write_all(b"\n").await?;
107            stdin.flush().await?;
108        }
109        Ok(())
110    }
111
112    async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
113    where
114        Stderr: AsyncRead + Unpin + Send + 'static,
115    {
116        let mut stderr = BufReader::new(stderr);
117        let mut line = String::new();
118        while let Ok(n) = stderr.read_line(&mut line).await {
119            if n == 0 {
120                break;
121            }
122            if stderr_tx.send(line.clone()).await.is_err() {
123                break;
124            }
125            line.clear();
126        }
127    }
128}
129
130#[async_trait]
131impl Transport for StdioTransport {
132    async fn send(&self, message: String) -> Result<()> {
133        Ok(self.stdout_sender.send(message).await?)
134    }
135
136    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
137        Box::pin(self.stdin_receiver.clone())
138    }
139
140    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
141        Box::pin(self.stderr_receiver.clone())
142    }
143}
144
145impl Drop for StdioTransport {
146    fn drop(&mut self) {
147        let _ = self.server.kill();
148    }
149}