1use std::path::PathBuf;
2use std::pin::Pin;
3
4use anyhow::{Context as _, Result};
5use async_trait::async_trait;
6use futures::io::{BufReader, BufWriter};
7use futures::{
8 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
9};
10use gpui::AsyncApp;
11use settings::Settings as _;
12use smol::channel;
13use smol::process::Child;
14use terminal::terminal_settings::TerminalSettings;
15use util::TryFutureExt as _;
16use util::shell_builder::ShellBuilder;
17
18use crate::client::ModelContextServerBinary;
19use crate::transport::Transport;
20
21pub struct StdioTransport {
22 stdout_sender: channel::Sender<String>,
23 stdin_receiver: channel::Receiver<String>,
24 stderr_receiver: channel::Receiver<String>,
25 server: Child,
26}
27
28impl StdioTransport {
29 pub fn new(
30 binary: ModelContextServerBinary,
31 working_directory: &Option<PathBuf>,
32 cx: &AsyncApp,
33 ) -> Result<Self> {
34 let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
35 let builder = ShellBuilder::new(&shell, cfg!(windows));
36 let (command, args) =
37 builder.build(Some(binary.executable.display().to_string()), &binary.args);
38
39 let mut command = util::command::new_smol_command(command);
40 command
41 .args(args)
42 .envs(binary.env.unwrap_or_default())
43 .stdin(std::process::Stdio::piped())
44 .stdout(std::process::Stdio::piped())
45 .stderr(std::process::Stdio::piped())
46 .kill_on_drop(true);
47
48 if let Some(working_directory) = working_directory {
49 command.current_dir(working_directory);
50 }
51
52 let mut server = command
53 .spawn()
54 .with_context(|| format!("failed to spawn command {command:?})",))?;
55
56 let stdin = server.stdin.take().unwrap();
57 let stdout = server.stdout.take().unwrap();
58 let stderr = server.stderr.take().unwrap();
59
60 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
61 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
62 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
63
64 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
65 .detach();
66
67 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
68 .detach();
69
70 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
71 .detach();
72
73 Ok(Self {
74 stdout_sender,
75 stdin_receiver,
76 stderr_receiver,
77 server,
78 })
79 }
80
81 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
82 where
83 Stdout: AsyncRead + Unpin + Send + 'static,
84 {
85 let mut stdin = BufReader::new(stdin);
86 let mut line = String::new();
87 while let Ok(n) = stdin.read_line(&mut line).await {
88 if n == 0 {
89 break;
90 }
91 if inbound_rx.send(line.clone()).await.is_err() {
92 break;
93 }
94 line.clear();
95 }
96 }
97
98 async fn handle_output<Stdin>(
99 stdin: Stdin,
100 outbound_rx: channel::Receiver<String>,
101 ) -> Result<()>
102 where
103 Stdin: AsyncWrite + Unpin + Send + 'static,
104 {
105 let mut stdin = BufWriter::new(stdin);
106 let mut pinned_rx = Box::pin(outbound_rx);
107 while let Some(message) = pinned_rx.next().await {
108 log::trace!("outgoing message: {}", message);
109
110 stdin.write_all(message.as_bytes()).await?;
111 stdin.write_all(b"\n").await?;
112 stdin.flush().await?;
113 }
114 Ok(())
115 }
116
117 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
118 where
119 Stderr: AsyncRead + Unpin + Send + 'static,
120 {
121 let mut stderr = BufReader::new(stderr);
122 let mut line = String::new();
123 while let Ok(n) = stderr.read_line(&mut line).await {
124 if n == 0 {
125 break;
126 }
127 if stderr_tx.send(line.clone()).await.is_err() {
128 break;
129 }
130 line.clear();
131 }
132 }
133}
134
135#[async_trait]
136impl Transport for StdioTransport {
137 async fn send(&self, message: String) -> Result<()> {
138 Ok(self.stdout_sender.send(message).await?)
139 }
140
141 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
142 Box::pin(self.stdin_receiver.clone())
143 }
144
145 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
146 Box::pin(self.stderr_receiver.clone())
147 }
148}
149
150impl Drop for StdioTransport {
151 fn drop(&mut self) {
152 let _ = self.server.kill();
153 }
154}