1use std::path::PathBuf;
2use std::pin::Pin;
3
4use anyhow::{Context as _, Result};
5use async_trait::async_trait;
6use futures::io::{BufReader, BufWriter};
7use futures::{
8 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
9};
10use gpui::AsyncApp;
11use settings::Settings as _;
12use smol::channel;
13use smol::process::Child;
14use terminal::terminal_settings::TerminalSettings;
15use util::TryFutureExt as _;
16use util::shell_builder::ShellBuilder;
17
18use crate::client::ModelContextServerBinary;
19use crate::transport::Transport;
20
/// A [`Transport`] that talks to a context server child process over its standard streams.
///
/// Messages travel as newline-delimited text: outgoing messages are written to the child's
/// stdin followed by `\n`, and incoming data is read from the child's stdout and stderr one
/// line at a time. Background tasks spawned in [`StdioTransport::new`] move data between
/// the channels below and the child's pipes.
///
/// Field names are from the transport's point of view: `stdout_*` is what the transport
/// emits (the child's stdin), and `stdin_*` is what it takes in (the child's stdout).
pub struct StdioTransport {
    /// Outgoing messages; a background task drains this channel into the child's stdin.
    stdout_sender: channel::Sender<String>,
    /// Lines read from the child's stdout, trailing newline included.
    stdin_receiver: channel::Receiver<String>,
    /// Lines read from the child's stderr, trailing newline included.
    stderr_receiver: channel::Receiver<String>,
    /// Handle to the spawned server process.
    server: Child,
}
27
28impl StdioTransport {
29 pub fn new(
30 binary: ModelContextServerBinary,
31 working_directory: &Option<PathBuf>,
32 cx: &AsyncApp,
33 ) -> Result<Self> {
34 let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
35 let builder = ShellBuilder::new(&shell, cfg!(windows));
36 let mut command =
37 builder.build_command(Some(binary.executable.display().to_string()), &binary.args);
38
39 command
40 .envs(binary.env.unwrap_or_default())
41 .stdin(std::process::Stdio::piped())
42 .stdout(std::process::Stdio::piped())
43 .stderr(std::process::Stdio::piped())
44 .kill_on_drop(true);
45
46 if let Some(working_directory) = working_directory {
47 command.current_dir(working_directory);
48 }
49
50 let mut server = command
51 .spawn()
52 .with_context(|| format!("failed to spawn command {command:?})",))?;
53
54 let stdin = server.stdin.take().unwrap();
55 let stdout = server.stdout.take().unwrap();
56 let stderr = server.stderr.take().unwrap();
57
58 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
59 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
60 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
61
62 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
63 .detach();
64
65 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
66 .detach();
67
68 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
69 .detach();
70
71 Ok(Self {
72 stdout_sender,
73 stdin_receiver,
74 stderr_receiver,
75 server,
76 })
77 }
78
79 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
80 where
81 Stdout: AsyncRead + Unpin + Send + 'static,
82 {
83 let mut stdin = BufReader::new(stdin);
84 let mut line = String::new();
85 while let Ok(n) = stdin.read_line(&mut line).await {
86 if n == 0 {
87 break;
88 }
89 if inbound_rx.send(line.clone()).await.is_err() {
90 break;
91 }
92 line.clear();
93 }
94 }
95
96 async fn handle_output<Stdin>(
97 stdin: Stdin,
98 outbound_rx: channel::Receiver<String>,
99 ) -> Result<()>
100 where
101 Stdin: AsyncWrite + Unpin + Send + 'static,
102 {
103 let mut stdin = BufWriter::new(stdin);
104 let mut pinned_rx = Box::pin(outbound_rx);
105 while let Some(message) = pinned_rx.next().await {
106 log::trace!("outgoing message: {}", message);
107
108 stdin.write_all(message.as_bytes()).await?;
109 stdin.write_all(b"\n").await?;
110 stdin.flush().await?;
111 }
112 Ok(())
113 }
114
115 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
116 where
117 Stderr: AsyncRead + Unpin + Send + 'static,
118 {
119 let mut stderr = BufReader::new(stderr);
120 let mut line = String::new();
121 while let Ok(n) = stderr.read_line(&mut line).await {
122 if n == 0 {
123 break;
124 }
125 if stderr_tx.send(line.clone()).await.is_err() {
126 break;
127 }
128 line.clear();
129 }
130 }
131}
132
133#[async_trait]
134impl Transport for StdioTransport {
135 async fn send(&self, message: String) -> Result<()> {
136 Ok(self.stdout_sender.send(message).await?)
137 }
138
139 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
140 Box::pin(self.stdin_receiver.clone())
141 }
142
143 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
144 Box::pin(self.stderr_receiver.clone())
145 }
146}
147
148impl Drop for StdioTransport {
149 fn drop(&mut self) {
150 let _ = self.server.kill();
151 }
152}