1use std::pin::Pin;
2
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use futures::io::{BufReader, BufWriter};
6use futures::{
7 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
8};
9use gpui::AsyncApp;
10use smol::channel;
11use smol::process::Child;
12use util::TryFutureExt as _;
13
14use crate::client::ModelContextServerBinary;
15use crate::transport::Transport;
16
/// A [`Transport`] that talks to a context server child process over its
/// standard streams, one message per line.
///
/// Field names describe the channels from this side of the pipe, not the
/// child's: messages sent on `stdout_sender` are written to the child's
/// *stdin*, and lines the child prints on its *stdout* arrive on
/// `stdin_receiver`.
pub struct StdioTransport {
    /// Outgoing messages; a background task writes each one, followed by a
    /// newline, to the child's stdin.
    stdout_sender: channel::Sender<String>,
    /// Lines read from the child's stdout (trailing newline included when present).
    stdin_receiver: channel::Receiver<String>,
    /// Lines read from the child's stderr (trailing newline included when present).
    stderr_receiver: channel::Receiver<String>,
    /// The spawned server process; it is killed when the transport is dropped.
    server: Child,
}
23
24impl StdioTransport {
25 pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
26 let mut command = util::command::new_smol_command(&binary.executable);
27 command
28 .args(&binary.args)
29 .envs(binary.env.unwrap_or_default())
30 .stdin(std::process::Stdio::piped())
31 .stdout(std::process::Stdio::piped())
32 .stderr(std::process::Stdio::piped())
33 .kill_on_drop(true);
34
35 let mut server = command.spawn().with_context(|| {
36 format!(
37 "failed to spawn command. (path={:?}, args={:?})",
38 binary.executable, &binary.args
39 )
40 })?;
41
42 let stdin = server.stdin.take().unwrap();
43 let stdout = server.stdout.take().unwrap();
44 let stderr = server.stderr.take().unwrap();
45
46 let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
47 let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
48 let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
49
50 cx.spawn(async move |_| Self::handle_output(stdin, stdout_receiver).log_err().await)
51 .detach();
52
53 cx.spawn(async move |_| Self::handle_input(stdout, stdin_sender).await)
54 .detach();
55
56 cx.spawn(async move |_| Self::handle_err(stderr, stderr_sender).await)
57 .detach();
58
59 Ok(Self {
60 stdout_sender,
61 stdin_receiver,
62 stderr_receiver,
63 server,
64 })
65 }
66
67 async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
68 where
69 Stdout: AsyncRead + Unpin + Send + 'static,
70 {
71 let mut stdin = BufReader::new(stdin);
72 let mut line = String::new();
73 while let Ok(n) = stdin.read_line(&mut line).await {
74 if n == 0 {
75 break;
76 }
77 if inbound_rx.send(line.clone()).await.is_err() {
78 break;
79 }
80 line.clear();
81 }
82 }
83
84 async fn handle_output<Stdin>(
85 stdin: Stdin,
86 outbound_rx: channel::Receiver<String>,
87 ) -> Result<()>
88 where
89 Stdin: AsyncWrite + Unpin + Send + 'static,
90 {
91 let mut stdin = BufWriter::new(stdin);
92 let mut pinned_rx = Box::pin(outbound_rx);
93 while let Some(message) = pinned_rx.next().await {
94 log::trace!("outgoing message: {}", message);
95
96 stdin.write_all(message.as_bytes()).await?;
97 stdin.write_all(b"\n").await?;
98 stdin.flush().await?;
99 }
100 Ok(())
101 }
102
103 async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
104 where
105 Stderr: AsyncRead + Unpin + Send + 'static,
106 {
107 let mut stderr = BufReader::new(stderr);
108 let mut line = String::new();
109 while let Ok(n) = stderr.read_line(&mut line).await {
110 if n == 0 {
111 break;
112 }
113 if stderr_tx.send(line.clone()).await.is_err() {
114 break;
115 }
116 line.clear();
117 }
118 }
119}
120
121#[async_trait]
122impl Transport for StdioTransport {
123 async fn send(&self, message: String) -> Result<()> {
124 Ok(self.stdout_sender.send(message).await?)
125 }
126
127 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
128 Box::pin(self.stdin_receiver.clone())
129 }
130
131 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
132 Box::pin(self.stderr_receiver.clone())
133 }
134}
135
136impl Drop for StdioTransport {
137 fn drop(&mut self) {
138 let _ = self.server.kill();
139 }
140}