Cargo.lock 🔗
@@ -8,6 +8,7 @@ version = "0.1.0"
dependencies = [
"agentic-coding-protocol",
"anyhow",
+ "async-pipe",
"async-trait",
"base64 0.22.1",
"buffer_diff",
Nathan Sobo created
Cargo.lock | 1
crates/acp/Cargo.toml | 1
crates/acp/src/acp.rs | 155 +++++++++++++++++++++++++++++++++++++++++
crates/acp/src/server.rs | 32 ++++++++
4 files changed, 186 insertions(+), 3 deletions(-)
@@ -8,6 +8,7 @@ version = "0.1.0"
dependencies = [
"agentic-coding-protocol",
"anyhow",
+ "async-pipe",
"async-trait",
"base64 0.22.1",
"buffer_diff",
@@ -42,6 +42,7 @@ workspace-hack.workspace = true
zed_actions.workspace = true
[dev-dependencies]
+async-pipe.workspace = true
env_logger.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
@@ -771,13 +771,15 @@ pub struct ToolCallRequest {
#[cfg(test)]
mod tests {
use super::*;
- use futures::{FutureExt as _, channel::mpsc, select};
- use gpui::TestAppContext;
+ use async_pipe::{PipeReader, PipeWriter};
+ use async_trait::async_trait;
+ use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
+ use gpui::{AsyncApp, TestAppContext};
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
- use std::{env, path::Path, process::Stdio, time::Duration};
+ use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
use util::path;
fn init_test(cx: &mut TestAppContext) {
@@ -790,6 +792,42 @@ mod tests {
});
}
+ #[gpui::test]
+ async fn test_message_receipt(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+ let (server, fake_server) = fake_acp_server(project, cx);
+
+ server.initialize().await.unwrap();
+
+ fake_server.update(cx, |fake_server, _| {
+ fake_server.on_user_message(move |params, server, mut cx| async move {
+ server
+ .update(&mut cx, |server, cx| {
+ let future =
+ server
+ .connection
+ .request(acp::StreamAssistantMessageChunkParams {
+ thread_id: params.thread_id,
+ chunk: acp::AssistantMessageChunk::Thought {
+ chunk: "Thinking ".into(),
+ },
+ });
+
+ cx.spawn(async move |_, _| future.await)
+ })?
+ .await
+ .unwrap();
+
+ Ok(acp::SendUserMessageResponse)
+ })
+ })
+ }
+
#[gpui::test]
async fn test_gemini_basic(cx: &mut TestAppContext) {
init_test(cx);
@@ -1043,4 +1081,115 @@ mod tests {
server.initialize().await.unwrap();
server
}
+
+ pub fn fake_acp_server(
+ project: Entity<Project>,
+ cx: &mut TestAppContext,
+ ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
+ let (stdin_tx, stdin_rx) = async_pipe::pipe();
+ let (stdout_tx, stdout_rx) = async_pipe::pipe();
+ let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
+ let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
+ (server, agent)
+ }
+
+ pub struct FakeAcpServer {
+ connection: acp::ClientConnection,
+ _handler_task: Task<()>,
+ _io_task: Task<()>,
+ on_user_message: Option<
+ Rc<
+ dyn Fn(
+ acp::SendUserMessageParams,
+ Entity<FakeAcpServer>,
+ AsyncApp,
+ )
+ -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
+ >,
+ >,
+ }
+
+ #[derive(Clone)]
+ struct FakeAgent {
+ server: Entity<FakeAcpServer>,
+ cx: AsyncApp,
+ }
+
+ #[async_trait(?Send)]
+ impl acp::Agent for FakeAgent {
+ async fn initialize(
+ &self,
+ _request: acp::InitializeParams,
+ ) -> Result<acp::InitializeResponse> {
+ Ok(acp::InitializeResponse {
+ is_authenticated: true,
+ })
+ }
+
+ async fn authenticate(
+ &self,
+ _request: acp::AuthenticateParams,
+ ) -> Result<acp::AuthenticateResponse> {
+ Ok(acp::AuthenticateResponse)
+ }
+
+ async fn create_thread(
+ &self,
+ _request: acp::CreateThreadParams,
+ ) -> Result<acp::CreateThreadResponse> {
+ Ok(acp::CreateThreadResponse {
+ thread_id: acp::ThreadId("test-thread".into()),
+ })
+ }
+
+ async fn send_user_message(
+ &self,
+ request: acp::SendUserMessageParams,
+ ) -> Result<acp::SendUserMessageResponse> {
+ let mut cx = self.cx.clone();
+ let handler = self
+ .server
+ .update(&mut cx, |server, _| server.on_user_message.clone())
+ .ok()
+ .flatten();
+ if let Some(handler) = handler {
+ handler(request, self.server.clone(), self.cx.clone()).await
+ } else {
+ anyhow::bail!("No handler for on_user_message")
+ }
+ }
+ }
+
+ impl FakeAcpServer {
+ fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
+ let agent = FakeAgent {
+ server: cx.entity(),
+ cx: cx.to_async(),
+ };
+
+ let (connection, handler_fut, io_fut) =
+ acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
+ FakeAcpServer {
+ connection: connection,
+ on_user_message: None,
+ _handler_task: cx.foreground_executor().spawn(handler_fut),
+ _io_task: cx.background_spawn(async move {
+ io_fut.await.log_err();
+ }),
+ }
+ }
+
+ fn on_user_message<F>(
+ &mut self,
+ handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
+ + 'static,
+ ) where
+ F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
+ {
+ self.on_user_message
+ .replace(Rc::new(move |request, server, cx| {
+ handler(request, server, cx).boxed_local()
+ }));
+ }
+ }
}
@@ -159,6 +159,38 @@ impl AcpServer {
})
}
+ #[cfg(test)]
+ pub fn fake(
+ stdin: async_pipe::PipeWriter,
+ stdout: async_pipe::PipeReader,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Arc<Self> {
+ let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
+ let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
+ AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
+ stdin,
+ stdout,
+ );
+
+ let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
+ let io_task = cx.background_spawn({
+ async move {
+ io_fut.await.log_err();
+ // todo!() exit status?
+ }
+ });
+
+ Arc::new(Self {
+ project,
+ connection: Arc::new(connection),
+ threads,
+ exit_status,
+ _handler_task: cx.foreground_executor().spawn(handler_fut),
+ _io_task: io_task,
+ })
+ }
+
pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
self.connection
.request(acp::InitializeParams)