Cargo.lock 🔗
@@ -3101,6 +3101,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
+ "async-trait",
"collections",
"command_palette_hooks",
"context_server_settings",
Federico Dionisi and Marshall Bowers created
This PR abstracts the communication layer for context servers, laying
the groundwork for supporting multiple transport mechanisms and taking
one step towards enabling remote servers.
Key changes centre around creating a new `Transport` trait with methods
for sending and receiving messages. I've implemented this trait for the
existing stdio-based communication, which is now encapsulated in a
`StdioTransport` struct. The `Client` struct has been refactored to use
this new `Transport` trait instead of directly managing stdin and
stdout.
The next steps will involve implementing an SSE + HTTP transport and
defining alternative context server settings for remote servers.
Release Notes:
- N/A
---------
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Cargo.lock | 1
crates/context_server/Cargo.toml | 1
crates/context_server/src/client.rs | 156 ++++--------
crates/context_server/src/context_server.rs | 1
crates/context_server/src/transport.rs | 16 +
crates/context_server/src/transport/stdio_transport.rs | 140 ++++++++++
6 files changed, 211 insertions(+), 104 deletions(-)
@@ -3101,6 +3101,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
+ "async-trait",
"collections",
"command_palette_hooks",
"context_server_settings",
@@ -14,6 +14,7 @@ path = "src/context_server.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
+async-trait.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server_settings.workspace = true
@@ -1,16 +1,12 @@
-use anyhow::{anyhow, Context as _, Result};
+use anyhow::{anyhow, Context, Result};
use collections::HashMap;
-use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
+use futures::{channel::oneshot, select, FutureExt, StreamExt};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
-use smol::{
- channel,
- io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
- process::Child,
-};
+use smol::channel;
use std::{
fmt,
path::PathBuf,
@@ -22,6 +18,8 @@ use std::{
};
use util::TryFutureExt;
+use crate::transport::{StdioTransport, Transport};
+
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@@ -55,7 +53,8 @@ pub struct Client {
#[allow(dead_code)]
output_done_rx: Mutex<Option<barrier::Receiver>>,
executor: BackgroundExecutor,
- server: Arc<Mutex<Option<Child>>>,
+ #[allow(dead_code)]
+ transport: Arc<dyn Transport>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -152,25 +151,13 @@ impl Client {
&binary.args
);
- let mut command = util::command::new_smol_command(&binary.executable);
- command
- .args(&binary.args)
- .envs(binary.env.unwrap_or_default())
- .stdin(std::process::Stdio::piped())
- .stdout(std::process::Stdio::piped())
- .stderr(std::process::Stdio::piped())
- .kill_on_drop(true);
-
- let mut server = command.spawn().with_context(|| {
- format!(
- "failed to spawn command. (path={:?}, args={:?})",
- binary.executable, &binary.args
- )
- })?;
+ let server_name = binary
+ .executable
+ .file_name()
+ .map(|name| name.to_string_lossy().to_string())
+ .unwrap_or_else(String::new);
- let stdin = server.stdin.take().unwrap();
- let stdout = server.stdout.take().unwrap();
- let stderr = server.stderr.take().unwrap();
+ let transport = Arc::new(StdioTransport::new(binary, &cx)?);
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@@ -183,18 +170,22 @@ impl Client {
let stdout_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
+ let transport = transport.clone();
move |cx| {
- Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
+ Self::handle_input(transport, notification_handlers, response_handlers, cx)
+ .log_err()
}
});
- let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
+ let stderr_input_task = cx.spawn(|_| Self::handle_stderr(transport.clone()).log_err());
let input_task = cx.spawn(|_| async move {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
});
+
let output_task = cx.background_spawn({
+ let transport = transport.clone();
Self::handle_output(
- stdin,
+ transport,
outbound_rx,
output_done_tx,
response_handlers.clone(),
@@ -202,24 +193,18 @@ impl Client {
.log_err()
});
- let mut context_server = Self {
+ Ok(Self {
server_id,
notification_handlers,
response_handlers,
- name: "".into(),
+ name: server_name.into(),
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
io_tasks: Mutex::new(Some((input_task, output_task))),
output_done_rx: Mutex::new(Some(output_done_rx)),
- server: Arc::new(Mutex::new(Some(server))),
- };
-
- if let Some(name) = binary.executable.file_name() {
- context_server.name = name.to_string_lossy().into();
- }
-
- Ok(context_server)
+ transport,
+ })
}
/// Handles input from the server's stdout.
@@ -228,79 +213,53 @@ impl Client {
/// parses them as JSON-RPC responses or notifications, and dispatches them
/// to the appropriate handlers. It processes both responses (which are matched
/// to pending requests) and notifications (which trigger registered handlers).
- async fn handle_input<Stdout>(
- stdout: Stdout,
+ async fn handle_input(
+ transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: AsyncApp,
- ) -> anyhow::Result<()>
- where
- Stdout: AsyncRead + Unpin + Send + 'static,
- {
- let mut stdout = BufReader::new(stdout);
- let mut buffer = String::new();
-
- loop {
- buffer.clear();
- if stdout.read_line(&mut buffer).await? == 0 {
- return Ok(());
- }
-
- let content = buffer.trim();
-
- if !content.is_empty() {
- if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
- if let Some(handlers) = response_handlers.lock().as_mut() {
- if let Some(handler) = handlers.remove(&response.id) {
- handler(Ok(content.to_string()));
- }
- }
- } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
- let mut notification_handlers = notification_handlers.lock();
- if let Some(handler) =
- notification_handlers.get_mut(notification.method.as_str())
- {
- handler(notification.params.unwrap_or(Value::Null), cx.clone());
+ ) -> anyhow::Result<()> {
+ let mut receiver = transport.receive();
+
+ while let Some(message) = receiver.next().await {
+ if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
+ if let Some(handlers) = response_handlers.lock().as_mut() {
+ if let Some(handler) = handlers.remove(&response.id) {
+ handler(Ok(message.to_string()));
}
}
+ } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
+ let mut notification_handlers = notification_handlers.lock();
+ if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
+ handler(notification.params.unwrap_or(Value::Null), cx.clone());
+ }
}
-
- smol::future::yield_now().await;
}
+
+ smol::future::yield_now().await;
+
+ Ok(())
}
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
- async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
- where
- Stderr: AsyncRead + Unpin + Send + 'static,
- {
- let mut stderr = BufReader::new(stderr);
- let mut buffer = String::new();
-
- loop {
- buffer.clear();
- if stderr.read_line(&mut buffer).await? == 0 {
- return Ok(());
- }
- log::warn!("context server stderr: {}", buffer.trim());
- smol::future::yield_now().await;
+ async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
+ while let Some(err) = transport.receive_err().next().await {
+ log::warn!("context server stderr: {}", err.trim());
}
+
+ Ok(())
}
/// Handles the output to the context server's stdin.
/// This function continuously receives messages from the outbound channel,
/// writes them to the server's stdin, and manages the lifecycle of response handlers.
- async fn handle_output<Stdin>(
- stdin: Stdin,
+ async fn handle_output(
+ transport: Arc<dyn Transport>,
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
- ) -> anyhow::Result<()>
- where
- Stdin: AsyncWrite + Unpin + Send + 'static,
- {
- let mut stdin = BufWriter::new(stdin);
+ ) -> anyhow::Result<()> {
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
@@ -309,10 +268,7 @@ impl Client {
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
-
- stdin.write_all(message.as_bytes()).await?;
- stdin.write_all(b"\n").await?;
- stdin.flush().await?;
+ transport.send(message).await?;
}
drop(output_done_tx);
Ok(())
@@ -416,14 +372,6 @@ impl Client {
}
}
-impl Drop for Client {
- fn drop(&mut self) {
- if let Some(mut server) = self.server.lock().take() {
- let _ = server.kill();
- }
- }
-}
-
impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
@@ -4,6 +4,7 @@ mod extension_context_server;
pub mod manager;
pub mod protocol;
mod registry;
+mod transport;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;
@@ -0,0 +1,16 @@
+mod stdio_transport;
+
+use std::pin::Pin;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use futures::Stream;
+
+pub use stdio_transport::*;
+
+#[async_trait]
+pub trait Transport: Send + Sync {
+ async fn send(&self, message: String) -> Result<()>;
+ fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
+ fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
+}
@@ -0,0 +1,140 @@
+use std::pin::Pin;
+
+use anyhow::{Context as _, Result};
+use async_trait::async_trait;
+use futures::io::{BufReader, BufWriter};
+use futures::{
+ AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
+};
+use gpui::AsyncApp;
+use smol::channel;
+use smol::process::Child;
+use util::TryFutureExt as _;
+
+use crate::client::ModelContextServerBinary;
+use crate::transport::Transport;
+
+pub struct StdioTransport {
+ stdout_sender: channel::Sender<String>,
+ stdin_receiver: channel::Receiver<String>,
+ stderr_receiver: channel::Receiver<String>,
+ server: Child,
+}
+
+impl StdioTransport {
+ pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
+ let mut command = util::command::new_smol_command(&binary.executable);
+ command
+ .args(&binary.args)
+ .envs(binary.env.unwrap_or_default())
+ .stdin(std::process::Stdio::piped())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::piped())
+ .kill_on_drop(true);
+
+ let mut server = command.spawn().with_context(|| {
+ format!(
+ "failed to spawn command. (path={:?}, args={:?})",
+ binary.executable, &binary.args
+ )
+ })?;
+
+ let stdin = server.stdin.take().unwrap();
+ let stdout = server.stdout.take().unwrap();
+ let stderr = server.stderr.take().unwrap();
+
+ let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
+ let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
+ let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
+
+ cx.spawn(|_| Self::handle_output(stdin, stdout_receiver).log_err())
+ .detach();
+
+ cx.spawn(|_| async move { Self::handle_input(stdout, stdin_sender).await })
+ .detach();
+
+ cx.spawn(|_| async move { Self::handle_err(stderr, stderr_sender).await })
+ .detach();
+
+ Ok(Self {
+ stdout_sender,
+ stdin_receiver,
+ stderr_receiver,
+ server,
+ })
+ }
+
+ async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
+ where
+ Stdout: AsyncRead + Unpin + Send + 'static,
+ {
+ let mut stdin = BufReader::new(stdin);
+ let mut line = String::new();
+ while let Ok(n) = stdin.read_line(&mut line).await {
+ if n == 0 {
+ break;
+ }
+ if inbound_rx.send(line.clone()).await.is_err() {
+ break;
+ }
+ line.clear();
+ }
+ }
+
+ async fn handle_output<Stdin>(
+ stdin: Stdin,
+ outbound_rx: channel::Receiver<String>,
+ ) -> Result<()>
+ where
+ Stdin: AsyncWrite + Unpin + Send + 'static,
+ {
+ let mut stdin = BufWriter::new(stdin);
+ let mut pinned_rx = Box::pin(outbound_rx);
+ while let Some(message) = pinned_rx.next().await {
+ log::trace!("outgoing message: {}", message);
+
+ stdin.write_all(message.as_bytes()).await?;
+ stdin.write_all(b"\n").await?;
+ stdin.flush().await?;
+ }
+ Ok(())
+ }
+
+ async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
+ where
+ Stderr: AsyncRead + Unpin + Send + 'static,
+ {
+ let mut stderr = BufReader::new(stderr);
+ let mut line = String::new();
+ while let Ok(n) = stderr.read_line(&mut line).await {
+ if n == 0 {
+ break;
+ }
+ if stderr_tx.send(line.clone()).await.is_err() {
+ break;
+ }
+ line.clear();
+ }
+ }
+}
+
+#[async_trait]
+impl Transport for StdioTransport {
+ async fn send(&self, message: String) -> Result<()> {
+ Ok(self.stdout_sender.send(message).await?)
+ }
+
+ fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ Box::pin(self.stdin_receiver.clone())
+ }
+
+ fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ Box::pin(self.stderr_receiver.clone())
+ }
+}
+
+impl Drop for StdioTransport {
+ fn drop(&mut self) {
+ let _ = self.server.kill();
+ }
+}