stdio_transport.rs

  1use std::path::PathBuf;
  2use std::pin::Pin;
  3
  4use anyhow::{Context as _, Result};
  5use async_trait::async_trait;
  6use futures::io::{BufReader, BufWriter};
  7use futures::{
  8    AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
  9};
 10use gpui::AsyncApp;
 11use settings::Settings as _;
 12use smol::channel;
 13use smol::process::Child;
 14use terminal::terminal_settings::TerminalSettings;
 15use util::TryFutureExt as _;
 16use util::shell_builder::ShellBuilder;
 17
 18use crate::client::ModelContextServerBinary;
 19use crate::transport::Transport;
 20
 21pub struct StdioTransport {
 22    stdout_sender: channel::Sender<String>,
 23    stdin_receiver: channel::Receiver<String>,
 24    stderr_receiver: channel::Receiver<String>,
 25    server: Child,
 26}
 27
 28impl StdioTransport {
 29    pub fn new(
 30        binary: ModelContextServerBinary,
 31        working_directory: &Option<PathBuf>,
 32        cx: &AsyncApp,
 33    ) -> Result<Self> {
 34        let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
 35        let builder = ShellBuilder::new(&shell, cfg!(windows));
 36        let (command, args) =
 37            builder.build(Some(binary.executable.display().to_string()), &binary.args);
 38
 39        let mut command = util::command::new_smol_command(command);
 40        command
 41            .args(args)
 42            .envs(binary.env.unwrap_or_default())
 43            .stdin(std::process::Stdio::piped())
 44            .stdout(std::process::Stdio::piped())
 45            .stderr(std::process::Stdio::piped())
 46            .kill_on_drop(true);
 47
 48        if let Some(working_directory) = working_directory {
 49            command.current_dir(working_directory);
 50        }
 51
 52        let mut server = command
 53            .spawn()
 54            .with_context(|| format!("failed to spawn command {command:?})",))?;
 55
 56        let stdin = server.stdin.take().unwrap();
 57        let stdout = server.stdout.take().unwrap();
 58        let stderr = server.stderr.take().unwrap();
 59
 60        let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
 61        let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
 62        let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
 63
 64        cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
 65            .detach();
 66
 67        cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
 68            .detach();
 69
 70        cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
 71            .detach();
 72
 73        Ok(Self {
 74            stdout_sender,
 75            stdin_receiver,
 76            stderr_receiver,
 77            server,
 78        })
 79    }
 80
 81    async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
 82    where
 83        Stdout: AsyncRead + Unpin + Send + 'static,
 84    {
 85        let mut stdin = BufReader::new(stdin);
 86        let mut line = String::new();
 87        while let Ok(n) = stdin.read_line(&mut line).await {
 88            if n == 0 {
 89                break;
 90            }
 91            if inbound_rx.send(line.clone()).await.is_err() {
 92                break;
 93            }
 94            line.clear();
 95        }
 96    }
 97
 98    async fn handle_output<Stdin>(
 99        stdin: Stdin,
100        outbound_rx: channel::Receiver<String>,
101    ) -> Result<()>
102    where
103        Stdin: AsyncWrite + Unpin + Send + 'static,
104    {
105        let mut stdin = BufWriter::new(stdin);
106        let mut pinned_rx = Box::pin(outbound_rx);
107        while let Some(message) = pinned_rx.next().await {
108            log::trace!("outgoing message: {}", message);
109
110            stdin.write_all(message.as_bytes()).await?;
111            stdin.write_all(b"\n").await?;
112            stdin.flush().await?;
113        }
114        Ok(())
115    }
116
117    async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
118    where
119        Stderr: AsyncRead + Unpin + Send + 'static,
120    {
121        let mut stderr = BufReader::new(stderr);
122        let mut line = String::new();
123        while let Ok(n) = stderr.read_line(&mut line).await {
124            if n == 0 {
125                break;
126            }
127            if stderr_tx.send(line.clone()).await.is_err() {
128                break;
129            }
130            line.clear();
131        }
132    }
133}
134
135#[async_trait]
136impl Transport for StdioTransport {
137    async fn send(&self, message: String) -> Result<()> {
138        Ok(self.stdout_sender.send(message).await?)
139    }
140
141    fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
142        Box::pin(self.stdin_receiver.clone())
143    }
144
145    fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
146        Box::pin(self.stderr_receiver.clone())
147    }
148}
149
150impl Drop for StdioTransport {
151    fn drop(&mut self) {
152        let _ = self.server.kill();
153    }
154}