1use std::path::PathBuf;
2use std::pin::Pin;
3
4use anyhow::{Context as _, Result};
5use async_trait::async_trait;
6use futures::io::{BufReader, BufWriter};
7use futures::{
8 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
9};
10use gpui::AsyncApp;
11use smol::channel;
12use smol::process::Child;
13use util::TryFutureExt as _;
14
15use crate::client::ModelContextServerBinary;
16use crate::transport::Transport;
17
/// A [`Transport`] that talks to a context server child process over its standard streams.
///
/// NOTE: fields are dropped in declaration order, after `Drop::drop` has run, so the
/// channel endpoints below are released after the server has been sent a kill signal.
pub struct StdioTransport {
    /// Sending half of the channel whose receiver is drained by the task that writes
    /// messages to the server's stdin. (Named from the transport's outgoing side, not
    /// after the child's stream.)
    stdout_sender: channel::Sender<String>,
    /// Receiving half of the channel fed with lines read from the server's stdout.
    stdin_receiver: channel::Receiver<String>,
    /// Receiving half of the channel fed with lines read from the server's stderr.
    stderr_receiver: channel::Receiver<String>,
    /// Handle to the spawned server process.
    server: Child,
}
24
25impl StdioTransport {
26 pub fn new(
27 binary: ModelContextServerBinary,
28 working_directory: &Option<PathBuf>,
29 cx: &AsyncApp,
30 ) -> Result<Self> {
31 let mut command = util::command::new_smol_command(&binary.executable);
32 command
33 .args(&binary.args)
34 .envs(binary.env.unwrap_or_default())
35 .stdin(std::process::Stdio::piped())
36 .stdout(std::process::Stdio::piped())
37 .stderr(std::process::Stdio::piped())
38 .kill_on_drop(true);
39
40 if let Some(working_directory) = working_directory {
41 command.current_dir(working_directory);
42 }
43
44 let mut server = command
45 .spawn()
46 .with_context(|| format!("failed to spawn command {command:?})",))?;
47
48 let stdin = server.stdin.take().unwrap();
49 let stdout = server.stdout.take().unwrap();
50 let stderr = server.stderr.take().unwrap();
51
52 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
53 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
54 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
55
56 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
57 .detach();
58
59 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
60 .detach();
61
62 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
63 .detach();
64
65 Ok(Self {
66 stdout_sender,
67 stdin_receiver,
68 stderr_receiver,
69 server,
70 })
71 }
72
73 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
74 where
75 Stdout: AsyncRead + Unpin + Send + 'static,
76 {
77 let mut stdin = BufReader::new(stdin);
78 let mut line = String::new();
79 while let Ok(n) = stdin.read_line(&mut line).await {
80 if n == 0 {
81 break;
82 }
83 if inbound_rx.send(line.clone()).await.is_err() {
84 break;
85 }
86 line.clear();
87 }
88 }
89
90 async fn handle_output<Stdin>(
91 stdin: Stdin,
92 outbound_rx: channel::Receiver<String>,
93 ) -> Result<()>
94 where
95 Stdin: AsyncWrite + Unpin + Send + 'static,
96 {
97 let mut stdin = BufWriter::new(stdin);
98 let mut pinned_rx = Box::pin(outbound_rx);
99 while let Some(message) = pinned_rx.next().await {
100 log::trace!("outgoing message: {}", message);
101
102 stdin.write_all(message.as_bytes()).await?;
103 stdin.write_all(b"\n").await?;
104 stdin.flush().await?;
105 }
106 Ok(())
107 }
108
109 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
110 where
111 Stderr: AsyncRead + Unpin + Send + 'static,
112 {
113 let mut stderr = BufReader::new(stderr);
114 let mut line = String::new();
115 while let Ok(n) = stderr.read_line(&mut line).await {
116 if n == 0 {
117 break;
118 }
119 if stderr_tx.send(line.clone()).await.is_err() {
120 break;
121 }
122 line.clear();
123 }
124 }
125}
126
127#[async_trait]
128impl Transport for StdioTransport {
129 async fn send(&self, message: String) -> Result<()> {
130 Ok(self.stdout_sender.send(message).await?)
131 }
132
133 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
134 Box::pin(self.stdin_receiver.clone())
135 }
136
137 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
138 Box::pin(self.stderr_receiver.clone())
139 }
140}
141
142impl Drop for StdioTransport {
143 fn drop(&mut self) {
144 let _ = self.server.kill();
145 }
146}