stdio_transport.rs

  1use std::path::PathBuf;
  2use std::pin::Pin;
  3
  4use anyhow::{Context as _, Result};
  5use async_trait::async_trait;
  6use futures::io::{BufReader, BufWriter};
  7use futures::{
  8    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
  9};
 10use gpui::AsyncApp;
 11use smol::channel;
 12use smol::process::Child;
 13use util::TryFutureExt as _;
 14use util::shell::Shell;
 15use util::shell_builder::ShellBuilder;
 16
 17use crate::client::ModelContextServerBinary;
 18use crate::transport::Transport;
 19
/// Transport that exchanges newline-delimited messages with a spawned server
/// process over its standard streams.
///
/// The field names describe the channel halves from the child's perspective
/// rather than ours, so read them carefully.
pub struct StdioTransport {
    /// Sending half of the outbound queue. Messages pushed here are drained by
    /// the `handle_output` task and written to the child's stdin.
    stdout_sender: channel::Sender<String>,
    /// Receiving half of the inbound queue. It yields the lines that the
    /// `handle_input` task reads from the child's stdout.
    stdin_receiver: channel::Receiver<String>,
    /// Receiving half of the queue fed by the `handle_err` task with lines read
    /// from the child's stderr.
    stderr_receiver: channel::Receiver<String>,
    /// Handle to the spawned server process; it is killed when the transport
    /// is dropped.
    server: Child,
}
 26
 27impl StdioTransport {
 28    pub fn new(
 29        binary: ModelContextServerBinary,
 30        working_directory: &Option<PathBuf>,
 31        cx: &AsyncApp,
 32    ) -> Result<Self> {
 33        let builder = ShellBuilder::new(&Shell::System, cfg!(windows)).non_interactive();
 34        let mut command =
 35            builder.build_smol_command(Some(binary.executable.display().to_string()), &binary.args);
 36
 37        command
 38            .envs(binary.env.unwrap_or_default())
 39            .stdin(std::process::Stdio::piped())
 40            .stdout(std::process::Stdio::piped())
 41            .stderr(std::process::Stdio::piped())
 42            .kill_on_drop(true);
 43
 44        if let Some(working_directory) = working_directory {
 45            command.current_dir(working_directory);
 46        }
 47
 48        let mut server = command
 49            .spawn()
 50            .with_context(|| format!("failed to spawn command {command:?})",))?;
 51
 52        let stdin = server.stdin.take().unwrap();
 53        let stdout = server.stdout.take().unwrap();
 54        let stderr = server.stderr.take().unwrap();
 55
 56        let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
 57        let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
 58        let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
 59
 60        cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
 61            .detach();
 62
 63        cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
 64            .detach();
 65
 66        cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
 67            .detach();
 68
 69        Ok(Self {
 70            stdout_sender,
 71            stdin_receiver,
 72            stderr_receiver,
 73            server,
 74        })
 75    }
 76
 77    async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
 78    where
 79        Stdout: AsyncRead + Unpin + Send + 'static,
 80    {
 81        let mut stdin = BufReader::new(stdin);
 82        let mut line = String::new();
 83        while let Ok(n) = stdin.read_line(&mut line).await {
 84            if n == 0 {
 85                break;
 86            }
 87            if inbound_rx.send(line.clone()).await.is_err() {
 88                break;
 89            }
 90            line.clear();
 91        }
 92    }
 93
 94    async fn handle_output<Stdin>(
 95        stdin: Stdin,
 96        outbound_rx: channel::Receiver<String>,
 97    ) -> Result<()>
 98    where
 99        Stdin: AsyncWrite + Unpin + Send + 'static,
100    {
101        let mut stdin = BufWriter::new(stdin);
102        let mut pinned_rx = Box::pin(outbound_rx);
103        while let Some(message) = pinned_rx.next().await {
104            log::trace!("outgoing message: {}", message);
105
106            stdin.write_all(message.as_bytes()).await?;
107            stdin.write_all(b"\n").await?;
108            stdin.flush().await?;
109        }
110        Ok(())
111    }
112
113    async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
114    where
115        Stderr: AsyncRead + Unpin + Send + 'static,
116    {
117        let mut stderr = BufReader::new(stderr);
118        let mut line = String::new();
119        while let Ok(n) = stderr.read_line(&mut line).await {
120            if n == 0 {
121                break;
122            }
123            if stderr_tx.send(line.clone()).await.is_err() {
124                break;
125            }
126            line.clear();
127        }
128    }
129}
130
131#[async_trait]
132impl Transport for StdioTransport {
133    async fn send(&self, message: String) -> Result<()> {
134        Ok(self.stdout_sender.send(message).await?)
135    }
136
137    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
138        Box::pin(self.stdin_receiver.clone())
139    }
140
141    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
142        Box::pin(self.stderr_receiver.clone())
143    }
144}
145
146impl Drop for StdioTransport {
147    fn drop(&mut self) {
148        let _ = self.server.kill();
149    }
150}