From 08fead27fb362a2047cbed72d744c18698be9e73 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Thu, 3 Jul 2025 10:46:41 -0600 Subject: [PATCH] WIP --- Cargo.lock | 1 + crates/acp/Cargo.toml | 1 + crates/acp/src/acp.rs | 155 ++++++++++++++++++++++++++++++++++++++- crates/acp/src/server.rs | 32 ++++++++ 4 files changed, 186 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b35a7ca94c4c4af9a08ba9aab7931a5234479693..3f88c7ad113311ac931289e427e3a5ec6ac8d887 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,7 @@ version = "0.1.0" dependencies = [ "agentic-coding-protocol", "anyhow", + "async-pipe", "async-trait", "base64 0.22.1", "buffer_diff", diff --git a/crates/acp/Cargo.toml b/crates/acp/Cargo.toml index 9482bcbcffd40826bb8b490a568da77ee3d26336..2c98f2469ce39c68cd7d99627c64a9965d1219a7 100644 --- a/crates/acp/Cargo.toml +++ b/crates/acp/Cargo.toml @@ -42,6 +42,7 @@ workspace-hack.workspace = true zed_actions.workspace = true [dev-dependencies] +async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } project = { workspace = true, "features" = ["test-support"] } diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index 08e8ce6cf074f7196d8b8d734905d30994aae91f..2af60300b61af7da554805553ef95a02b0145ff1 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -771,13 +771,15 @@ pub struct ToolCallRequest { #[cfg(test)] mod tests { use super::*; - use futures::{FutureExt as _, channel::mpsc, select}; - use gpui::TestAppContext; + use async_pipe::{PipeReader, PipeWriter}; + use async_trait::async_trait; + use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select}; + use gpui::{AsyncApp, TestAppContext}; use project::FakeFs; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{env, path::Path, process::Stdio, time::Duration}; + use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration}; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -790,6 +792,42 @@ mod tests { }); } + #[gpui::test] + async fn test_message_receipt(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (server, fake_server) = fake_acp_server(project, cx); + + server.initialize().await.unwrap(); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |params, server, mut cx| async move { + server + .update(&mut cx, |server, cx| { + let future = + server + .connection + .request(acp::StreamAssistantMessageChunkParams { + thread_id: params.thread_id, + chunk: acp::AssistantMessageChunk::Thought { + chunk: "Thinking ".into(), + }, + }); + + cx.spawn(async move |_, _| future.await) + })? + .await + .unwrap(); + + Ok(acp::SendUserMessageResponse) + }) + }) + } + #[gpui::test] async fn test_gemini_basic(cx: &mut TestAppContext) { init_test(cx); @@ -1043,4 +1081,115 @@ mod tests { server.initialize().await.unwrap(); server } + + pub fn fake_acp_server( + project: Entity, + cx: &mut TestAppContext, + ) -> (Arc, Entity) { + let (stdin_tx, stdin_rx) = async_pipe::pipe(); + let (stdout_tx, stdout_rx) = async_pipe::pipe(); + let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx)); + let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); + (server, agent) + } + + pub struct FakeAcpServer { + connection: acp::ClientConnection, + _handler_task: Task<()>, + _io_task: Task<()>, + on_user_message: Option< + Rc< + dyn Fn( + acp::SendUserMessageParams, + Entity, + AsyncApp, + ) + -> LocalBoxFuture<'static, Result>, + >, + >, + } + + #[derive(Clone)] + struct FakeAgent { + server: Entity, + cx: AsyncApp, + } + + #[async_trait(?Send)] + impl acp::Agent for FakeAgent { + async fn initialize( + &self, + _request: acp::InitializeParams, + ) -> Result { + Ok(acp::InitializeResponse { + is_authenticated: true, + }) + } + + async fn authenticate( + &self, + _request: acp::AuthenticateParams, + ) -> Result { + Ok(acp::AuthenticateResponse) + } + + async fn create_thread( + &self, + _request: acp::CreateThreadParams, + ) -> Result { + Ok(acp::CreateThreadResponse { + thread_id: acp::ThreadId("test-thread".into()), + }) + } + + async fn send_user_message( + &self, + request: acp::SendUserMessageParams, + ) -> Result { + let mut cx = self.cx.clone(); + let handler = self + .server + .update(&mut cx, |server, _| server.on_user_message.clone()) + .ok() + .flatten(); + if let Some(handler) = handler { + handler(request, self.server.clone(), self.cx.clone()).await + } else { + anyhow::bail!("No handler for on_user_message") + } + } + } + + impl FakeAcpServer { + fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { + let agent = FakeAgent { + server: cx.entity(), + cx: cx.to_async(), + }; + + let (connection, handler_fut, io_fut) = + acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin); + FakeAcpServer { + connection: connection, + on_user_message: None, + _handler_task: cx.foreground_executor().spawn(handler_fut), + _io_task: cx.background_spawn(async move { + io_fut.await.log_err(); + }), + } + } + + fn on_user_message( + &mut self, + handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F + + 'static, + ) where + F: Future> + 'static, + { + self.on_user_message + .replace(Rc::new(move |request, server, cx| { + handler(request, server, cx).boxed_local() + })); + } + } } diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index 29025e9f3cf13a6362ec54405f81ac69dcd1097a..c826f2ff39fc99a628ed2561c8b6435c02cb1983 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -159,6 +159,38 @@ impl AcpServer { }) } + #[cfg(test)] + pub fn fake( + stdin: async_pipe::PipeWriter, + stdout: async_pipe::PipeReader, + project: Entity, + cx: &mut App, + ) -> Arc { + let threads: Arc>>> = Default::default(); + let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()), + stdin, + stdout, + ); + + let exit_status: Arc>> = Default::default(); + let io_task = cx.background_spawn({ + async move { + io_fut.await.log_err(); + // todo!() exit status? + } + }); + + Arc::new(Self { + project, + connection: Arc::new(connection), + threads, + exit_status, + _handler_task: cx.foreground_executor().spawn(handler_fut), + _io_task: io_task, + }) + } + pub async fn initialize(&self) -> Result { self.connection .request(acp::InitializeParams)