diff --git a/crates/agent2/src/acp.rs b/crates/agent2/src/acp.rs index a9fbc7ac28b49e4e815eed273945580bf2fc2eba..d6a1befb3c4a502a0fe4429f2908ae619a4baa10 100644 --- a/crates/agent2/src/acp.rs +++ b/crates/agent2/src/acp.rs @@ -1,7 +1,9 @@ use std::path::Path; -use crate::{Agent, AgentThread, AgentThreadEntry, AgentThreadSummary, ResponseEvent, ThreadId}; -use agentic_coding_protocol as acp; +use crate::{ + Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, ResponseEvent, ThreadId, +}; +use agentic_coding_protocol::{self as acp}; use anyhow::{Context as _, Result}; use async_trait::async_trait; use futures::channel::mpsc::UnboundedReceiver; @@ -45,6 +47,10 @@ impl acp::Client for AcpClientDelegate { async fn glob_search(&self, request: acp::GlobSearchParams) -> Result { todo!() } + + async fn end_turn(&self, request: acp::EndTurnParams) -> Result { + todo!() + } } impl AcpAgent { @@ -78,33 +84,38 @@ impl Agent for AcpAgent { type Thread = AcpAgentThread; async fn threads(&self) -> Result> { - let threads = self.connection.request(acp::ListThreadsParams).await?; - threads + let response = self.connection.request(acp::GetThreadsParams).await?; + response .threads .into_iter() .map(|thread| { Ok(AgentThreadSummary { - id: ThreadId(thread.id.0), + id: thread.id.into(), title: thread.title, - created_at: thread.created_at, + created_at: thread.modified_at, }) }) .collect() } async fn create_thread(&self) -> Result { - todo!() + let response = self.connection.request(acp::CreateThreadParams).await?; + Ok(AcpAgentThread { + id: response.thread_id, + }) } - async fn open_thread(&self, id: crate::ThreadId) -> Result { + async fn open_thread(&self, id: ThreadId) -> Result { todo!() } } -pub struct AcpAgentThread {} +pub struct AcpAgentThread { + id: acp::ThreadId, +} impl AgentThread for AcpAgentThread { - async fn entries(&self) -> Result> { + async fn entries(&self) -> Result> { todo!() } @@ -115,3 +126,15 @@ impl AgentThread for AcpAgentThread { todo!() } } + +impl From for ThreadId { + fn from(thread_id: acp::ThreadId) -> Self { + Self(thread_id.0) + } +} + +impl From for acp::ThreadId { + fn from(thread_id: ThreadId) -> Self { + acp::ThreadId(thread_id.0) + } +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index d98fb05d6872b73c231d173a22601c2fb0434880..cd6ce3aa92fd2bf230fc01e714d18cda2635be17 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -8,7 +8,7 @@ use futures::{ select_biased, stream::{BoxStream, FuturesUnordered}, }; -use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use gpui::{AppContext, AsyncApp, Context, Entity, Task}; use project::Project; use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc}; @@ -21,7 +21,7 @@ pub trait Agent: 'static { } pub trait AgentThread: 'static { - fn entries(&self) -> impl Future>>; + fn entries(&self) -> impl Future>>; fn send( &self, message: Message, @@ -58,36 +58,36 @@ impl ReadFileRequest { #[derive(Debug, Clone)] pub struct ThreadId(String); -#[derive(Debug, Clone, Copy)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct FileVersion(u64); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct AgentThreadSummary { pub id: ThreadId, pub title: String, pub created_at: DateTime, } -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq)] pub struct FileContent { pub path: PathBuf, pub version: FileVersion, pub content: String, } -#[derive(Debug, Clone)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Role { User, Assistant, } -#[derive(Debug, Clone)] +#[derive(Debug, Eq, PartialEq)] pub struct Message { pub role: Role, pub chunks: Vec, } -#[derive(Debug, Clone)] +#[derive(Debug, Eq, PartialEq)] pub enum MessageChunk { Text { chunk: String, @@ -108,7 +108,7 @@ pub enum MessageChunk { }, Thread { title: String, - content: Vec, + content: Vec, }, Fetch { url: String, @@ -116,9 +116,18 @@ pub enum MessageChunk { }, } -#[derive(Debug, Clone)] -pub enum AgentThreadEntry { +impl From<&str> for MessageChunk { + fn from(chunk: &str) -> Self { + MessageChunk::Text { + chunk: chunk.to_string(), + } + } +} + +#[derive(Debug, Eq, PartialEq)] +pub enum AgentThreadEntryContent { Message(Message), + ReadFile { path: PathBuf, content: String }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -132,10 +141,10 @@ impl ThreadEntryId { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ThreadEntry { pub id: ThreadEntryId, - pub entry: AgentThreadEntry, + pub content: AgentThreadEntryContent, } pub struct ThreadStore { @@ -207,7 +216,7 @@ impl Thread { pub fn new( agent_thread: Arc, - entries: Vec, + entries: Vec, project: Entity, cx: &mut Context, ) -> Self { @@ -217,7 +226,7 @@ impl Thread { .into_iter() .map(|entry| ThreadEntry { id: next_entry_id.post_inc(), - entry, + content: entry, }) .collect(), next_entry_id, @@ -226,48 +235,6 @@ impl Thread { } } - async fn handle_message( - this: WeakEntity, - role: Role, - mut chunks: BoxStream<'static, Result>, - cx: &mut AsyncApp, - ) -> Result<()> { - let entry_id = this.update(cx, |this, cx| { - let entry_id = this.next_entry_id.post_inc(); - this.entries.push(ThreadEntry { - id: entry_id, - entry: AgentThreadEntry::Message(Message { - role, - chunks: Vec::new(), - }), - }); - cx.notify(); - entry_id - })?; - - while let Some(chunk) = chunks.next().await { - match chunk { - Ok(chunk) => { - this.update(cx, |this, cx| { - let ix = this - .entries - .binary_search_by_key(&entry_id, |entry| entry.id) - .map_err(|_| anyhow!("message not found"))?; - let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else { - unreachable!() - }; - message.chunks.push(chunk); - cx.notify(); - anyhow::Ok(()) - })??; - } - Err(err) => todo!("show error"), - } - } - - Ok(()) - } - pub fn entries(&self) -> &[ThreadEntry] { &self.entries } @@ -279,13 +246,16 @@ impl Thread { let mut pending_event_handlers = FuturesUnordered::new(); loop { - let mut next_event_handler_result = pin!(async { - if pending_event_handlers.is_empty() { - future::pending::<()>().await; - } + let mut next_event_handler_result = pin!( + async { + if pending_event_handlers.is_empty() { + future::pending::<()>().await; + } - pending_event_handlers.next().await - }.fuse()); + pending_event_handlers.next().await + } + .fuse() + ); select_biased! { event = events.next() => { @@ -329,7 +299,7 @@ impl Thread { let entry_id = self.next_entry_id.post_inc(); self.entries.push(ThreadEntry { id: entry_id, - entry: AgentThreadEntry::Message(Message { + content: AgentThreadEntryContent::Message(Message { role: message.role, chunks: Vec::new(), }), @@ -345,7 +315,8 @@ impl Thread { .entries .binary_search_by_key(&entry_id, |entry| entry.id) .map_err(|_| anyhow!("message not found"))?; - let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry + let AgentThreadEntryContent::Message(message) = + &mut this.entries[ix].content else { unreachable!() }; @@ -392,7 +363,7 @@ mod tests { } #[gpui::test] - async fn test_basic(cx: &mut TestAppContext) { + async fn test_gemini(cx: &mut TestAppContext) { init_test(cx); cx.executor().allow_parking(); @@ -400,7 +371,7 @@ mod tests { let fs = FakeFs::new(cx.executor()); fs.insert_tree( path!("/test"), - json!({"foo": "foo", "bar": "bar", "baz": "baz"}), + json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), ) .await; let project = Project::test(fs, [path!("/test").as_ref()], cx).await; @@ -408,12 +379,47 @@ mod tests { let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async()) .await .unwrap(); + let thread = thread_store + .update(cx, |thread_store, cx| { + assert_eq!(thread_store.threads().len(), 0); + thread_store.create_thread(cx) + }) + .await + .unwrap(); + thread + .update(cx, |thread, cx| { + thread.send( + Message { + role: Role::User, + chunks: vec![ + "Read the 'test/foo' file and output all of its contents.".into(), + ], + }, + cx, + ) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert!( + thread.entries().iter().any(|entry| { + entry.content + == AgentThreadEntryContent::ReadFile { + path: "test/foo".into(), + content: "Lorem ipsum dolor".into(), + } + }), + "Thread does not contain entry. Actual: {:?}", + thread.entries() + ); + }); } pub fn gemini_agent(project: Entity, cx: AsyncApp) -> Result { let child = util::command::new_smol_command("node") .arg("../../../gemini-cli/packages/cli") .arg("--acp") + // .args(["--model", "gemini-2.5-flash"]) .env("GEMINI_API_KEY", env::var("GEMINI_API_KEY").unwrap()) .stdin(Stdio::piped()) .stdout(Stdio::piped())