1use std::path::PathBuf;
2use std::pin::Pin;
3
4use anyhow::{Context as _, Result};
5use async_trait::async_trait;
6use futures::io::{BufReader, BufWriter};
7use futures::{
8 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
9};
10use gpui::AsyncApp;
11use smol::channel;
12use smol::process::Child;
13use util::TryFutureExt as _;
14
15use crate::client::ModelContextServerBinary;
16use crate::transport::Transport;
17
/// A [`Transport`] that talks to a context server child process over its
/// standard streams.
///
/// Messages are exchanged with background tasks (spawned in
/// `StdioTransport::new`) through unbounded channels.
///
/// Note: the field names describe the *transport's* side of each pipe, not
/// the child's, which makes them read inverted at first glance.
pub struct StdioTransport {
    // Outgoing messages; the receiving end is drained by a task that writes
    // each message to the child's **stdin**.
    stdout_sender: channel::Sender<String>,
    // Lines read from the child's **stdout** by a background task.
    stdin_receiver: channel::Receiver<String>,
    // Lines read from the child's stderr by a background task.
    stderr_receiver: channel::Receiver<String>,
    // Handle to the spawned child process; it is killed when the transport
    // is dropped. Kept last so it is dropped after the channels.
    server: Child,
}
24
25impl StdioTransport {
26 pub fn new(
27 binary: ModelContextServerBinary,
28 working_directory: &Option<PathBuf>,
29 cx: &AsyncApp,
30 ) -> Result<Self> {
31 let mut command = util::command::new_smol_command(&binary.executable);
32 command
33 .args(&binary.args)
34 .envs(binary.env.unwrap_or_default())
35 .stdin(std::process::Stdio::piped())
36 .stdout(std::process::Stdio::piped())
37 .stderr(std::process::Stdio::piped())
38 .kill_on_drop(true);
39
40 if let Some(working_directory) = working_directory {
41 command.current_dir(working_directory);
42 }
43
44 let mut server = command.spawn().with_context(|| {
45 format!(
46 "failed to spawn command. (path={:?}, args={:?})",
47 binary.executable, &binary.args
48 )
49 })?;
50
51 let stdin = server.stdin.take().unwrap();
52 let stdout = server.stdout.take().unwrap();
53 let stderr = server.stderr.take().unwrap();
54
55 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
56 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
57 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
58
59 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
60 .detach();
61
62 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
63 .detach();
64
65 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
66 .detach();
67
68 Ok(Self {
69 stdout_sender,
70 stdin_receiver,
71 stderr_receiver,
72 server,
73 })
74 }
75
76 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
77 where
78 Stdout: AsyncRead + Unpin + Send + 'static,
79 {
80 let mut stdin = BufReader::new(stdin);
81 let mut line = String::new();
82 while let Ok(n) = stdin.read_line(&mut line).await {
83 if n == 0 {
84 break;
85 }
86 if inbound_rx.send(line.clone()).await.is_err() {
87 break;
88 }
89 line.clear();
90 }
91 }
92
93 async fn handle_output<Stdin>(
94 stdin: Stdin,
95 outbound_rx: channel::Receiver<String>,
96 ) -> Result<()>
97 where
98 Stdin: AsyncWrite + Unpin + Send + 'static,
99 {
100 let mut stdin = BufWriter::new(stdin);
101 let mut pinned_rx = Box::pin(outbound_rx);
102 while let Some(message) = pinned_rx.next().await {
103 log::trace!("outgoing message: {}", message);
104
105 stdin.write_all(message.as_bytes()).await?;
106 stdin.write_all(b"\n").await?;
107 stdin.flush().await?;
108 }
109 Ok(())
110 }
111
112 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
113 where
114 Stderr: AsyncRead + Unpin + Send + 'static,
115 {
116 let mut stderr = BufReader::new(stderr);
117 let mut line = String::new();
118 while let Ok(n) = stderr.read_line(&mut line).await {
119 if n == 0 {
120 break;
121 }
122 if stderr_tx.send(line.clone()).await.is_err() {
123 break;
124 }
125 line.clear();
126 }
127 }
128}
129
130#[async_trait]
131impl Transport for StdioTransport {
132 async fn send(&self, message: String) -> Result<()> {
133 Ok(self.stdout_sender.send(message).await?)
134 }
135
136 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
137 Box::pin(self.stdin_receiver.clone())
138 }
139
140 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
141 Box::pin(self.stderr_receiver.clone())
142 }
143}
144
145impl Drop for StdioTransport {
146 fn drop(&mut self) {
147 let _ = self.server.kill();
148 }
149}