1use std::path::PathBuf;
2use std::pin::Pin;
3
4use anyhow::{Context as _, Result};
5use async_trait::async_trait;
6use futures::io::{BufReader, BufWriter};
7use futures::{
8 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
9};
10use gpui::AsyncApp;
11use smol::channel;
12use smol::process::Child;
13use util::TryFutureExt as _;
14use util::shell::Shell;
15use util::shell_builder::ShellBuilder;
16
17use crate::client::ModelContextServerBinary;
18use crate::transport::Transport;
19
/// A [`Transport`] that talks to a context server over the server process's
/// standard streams.
///
/// Channel names are from the client's point of view: `stdout_sender` carries
/// messages *out* to the server's stdin, and `stdin_receiver` carries lines
/// read *in* from the server's stdout.
pub struct StdioTransport {
    // Outgoing messages; a background task writes each one, followed by a
    // newline, to the server's stdin.
    stdout_sender: channel::Sender<String>,
    // Lines read from the server's stdout, as produced by `read_line`
    // (the trailing newline is kept when present).
    stdin_receiver: channel::Receiver<String>,
    // Lines read from the server's stderr.
    stderr_receiver: channel::Receiver<String>,
    // The spawned server process; killed in `Drop`.
    server: Child,
}
26
27impl StdioTransport {
28 pub fn new(
29 binary: ModelContextServerBinary,
30 working_directory: &Option<PathBuf>,
31 cx: &AsyncApp,
32 ) -> Result<Self> {
33 let builder = ShellBuilder::new(&Shell::System, cfg!(windows)).non_interactive();
34 let mut command =
35 builder.build_smol_command(Some(binary.executable.display().to_string()), &binary.args);
36
37 command
38 .envs(binary.env.unwrap_or_default())
39 .stdin(std::process::Stdio::piped())
40 .stdout(std::process::Stdio::piped())
41 .stderr(std::process::Stdio::piped())
42 .kill_on_drop(true);
43
44 if let Some(working_directory) = working_directory {
45 command.current_dir(working_directory);
46 }
47
48 let mut server = command
49 .spawn()
50 .with_context(|| format!("failed to spawn command {command:?})",))?;
51
52 let stdin = server.stdin.take().unwrap();
53 let stdout = server.stdout.take().unwrap();
54 let stderr = server.stderr.take().unwrap();
55
56 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
57 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
58 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
59
60 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
61 .detach();
62
63 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
64 .detach();
65
66 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
67 .detach();
68
69 Ok(Self {
70 stdout_sender,
71 stdin_receiver,
72 stderr_receiver,
73 server,
74 })
75 }
76
77 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
78 where
79 Stdout: AsyncRead + Unpin + Send + 'static,
80 {
81 let mut stdin = BufReader::new(stdin);
82 let mut line = String::new();
83 while let Ok(n) = stdin.read_line(&mut line).await {
84 if n == 0 {
85 break;
86 }
87 if inbound_rx.send(line.clone()).await.is_err() {
88 break;
89 }
90 line.clear();
91 }
92 }
93
94 async fn handle_output<Stdin>(
95 stdin: Stdin,
96 outbound_rx: channel::Receiver<String>,
97 ) -> Result<()>
98 where
99 Stdin: AsyncWrite + Unpin + Send + 'static,
100 {
101 let mut stdin = BufWriter::new(stdin);
102 let mut pinned_rx = Box::pin(outbound_rx);
103 while let Some(message) = pinned_rx.next().await {
104 log::trace!("outgoing message: {}", message);
105
106 stdin.write_all(message.as_bytes()).await?;
107 stdin.write_all(b"\n").await?;
108 stdin.flush().await?;
109 }
110 Ok(())
111 }
112
113 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
114 where
115 Stderr: AsyncRead + Unpin + Send + 'static,
116 {
117 let mut stderr = BufReader::new(stderr);
118 let mut line = String::new();
119 while let Ok(n) = stderr.read_line(&mut line).await {
120 if n == 0 {
121 break;
122 }
123 if stderr_tx.send(line.clone()).await.is_err() {
124 break;
125 }
126 line.clear();
127 }
128 }
129}
130
131#[async_trait]
132impl Transport for StdioTransport {
133 async fn send(&self, message: String) -> Result<()> {
134 Ok(self.stdout_sender.send(message).await?)
135 }
136
137 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
138 Box::pin(self.stdin_receiver.clone())
139 }
140
141 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
142 Box::pin(self.stderr_receiver.clone())
143 }
144}
145
146impl Drop for StdioTransport {
147 fn drop(&mut self) {
148 let _ = self.server.kill();
149 }
150}