Implement ACP threads

Ben Brandt created

The `create_thread` and `get_threads` methods are now implemented for
the ACP agent. A test is added to verify the file reading flow.

Change summary

crates/agent2/src/acp.rs    |  43 +++++++++--
crates/agent2/src/agent2.rs | 140 ++++++++++++++++++++------------------
2 files changed, 106 insertions(+), 77 deletions(-)

Detailed changes

crates/agent2/src/acp.rs 🔗

@@ -1,7 +1,9 @@
 use std::path::Path;
 
-use crate::{Agent, AgentThread, AgentThreadEntry, AgentThreadSummary, ResponseEvent, ThreadId};
-use agentic_coding_protocol as acp;
+use crate::{
+    Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, ResponseEvent, ThreadId,
+};
+use agentic_coding_protocol::{self as acp};
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
 use futures::channel::mpsc::UnboundedReceiver;
@@ -45,6 +47,10 @@ impl acp::Client for AcpClientDelegate {
     async fn glob_search(&self, request: acp::GlobSearchParams) -> Result<acp::GlobSearchResponse> {
         todo!()
     }
+
+    async fn end_turn(&self, request: acp::EndTurnParams) -> Result<acp::EndTurnResponse> {
+        todo!()
+    }
 }
 
 impl AcpAgent {
@@ -78,33 +84,38 @@ impl Agent for AcpAgent {
     type Thread = AcpAgentThread;
 
     async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
-        let threads = self.connection.request(acp::ListThreadsParams).await?;
-        threads
+        let response = self.connection.request(acp::GetThreadsParams).await?;
+        response
             .threads
             .into_iter()
             .map(|thread| {
                 Ok(AgentThreadSummary {
-                    id: ThreadId(thread.id.0),
+                    id: thread.id.into(),
                     title: thread.title,
-                    created_at: thread.created_at,
+                    created_at: thread.modified_at,
                 })
             })
             .collect()
     }
 
     async fn create_thread(&self) -> Result<Self::Thread> {
-        todo!()
+        let response = self.connection.request(acp::CreateThreadParams).await?;
+        Ok(AcpAgentThread {
+            id: response.thread_id,
+        })
     }
 
-    async fn open_thread(&self, id: crate::ThreadId) -> Result<Self::Thread> {
+    async fn open_thread(&self, id: ThreadId) -> Result<Self::Thread> {
         todo!()
     }
 }
 
-pub struct AcpAgentThread {}
+pub struct AcpAgentThread {
+    id: acp::ThreadId,
+}
 
 impl AgentThread for AcpAgentThread {
-    async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
+    async fn entries(&self) -> Result<Vec<AgentThreadEntryContent>> {
         todo!()
     }
 
@@ -115,3 +126,15 @@ impl AgentThread for AcpAgentThread {
         todo!()
     }
 }
+
+impl From<acp::ThreadId> for ThreadId {
+    fn from(thread_id: acp::ThreadId) -> Self {
+        Self(thread_id.0)
+    }
+}
+
+impl From<ThreadId> for acp::ThreadId {
+    fn from(thread_id: ThreadId) -> Self {
+        acp::ThreadId(thread_id.0)
+    }
+}

crates/agent2/src/agent2.rs 🔗

@@ -8,7 +8,7 @@ use futures::{
     select_biased,
     stream::{BoxStream, FuturesUnordered},
 };
-use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
+use gpui::{AppContext, AsyncApp, Context, Entity, Task};
 use project::Project;
 use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 
@@ -21,7 +21,7 @@ pub trait Agent: 'static {
 }
 
 pub trait AgentThread: 'static {
-    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
+    fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntryContent>>>;
     fn send(
         &self,
         message: Message,
@@ -58,36 +58,36 @@ impl ReadFileRequest {
 #[derive(Debug, Clone)]
 pub struct ThreadId(String);
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub struct FileVersion(u64);
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct AgentThreadSummary {
     pub id: ThreadId,
     pub title: String,
     pub created_at: DateTime<Utc>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct FileContent {
     pub path: PathBuf,
     pub version: FileVersion,
     pub content: String,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Role {
     User,
     Assistant,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Eq, PartialEq)]
 pub struct Message {
     pub role: Role,
     pub chunks: Vec<MessageChunk>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Eq, PartialEq)]
 pub enum MessageChunk {
     Text {
         chunk: String,
@@ -108,7 +108,7 @@ pub enum MessageChunk {
     },
     Thread {
         title: String,
-        content: Vec<AgentThreadEntry>,
+        content: Vec<AgentThreadEntryContent>,
     },
     Fetch {
         url: String,
@@ -116,9 +116,18 @@ pub enum MessageChunk {
     },
 }
 
-#[derive(Debug, Clone)]
-pub enum AgentThreadEntry {
+impl From<&str> for MessageChunk {
+    fn from(chunk: &str) -> Self {
+        MessageChunk::Text {
+            chunk: chunk.to_string(),
+        }
+    }
+}
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum AgentThreadEntryContent {
     Message(Message),
+    ReadFile { path: PathBuf, content: String },
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
@@ -132,10 +141,10 @@ impl ThreadEntryId {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct ThreadEntry {
     pub id: ThreadEntryId,
-    pub entry: AgentThreadEntry,
+    pub content: AgentThreadEntryContent,
 }
 
 pub struct ThreadStore<T: Agent> {
@@ -207,7 +216,7 @@ impl<T: AgentThread> Thread<T> {
 
     pub fn new(
         agent_thread: Arc<T>,
-        entries: Vec<AgentThreadEntry>,
+        entries: Vec<AgentThreadEntryContent>,
         project: Entity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -217,7 +226,7 @@ impl<T: AgentThread> Thread<T> {
                 .into_iter()
                 .map(|entry| ThreadEntry {
                     id: next_entry_id.post_inc(),
-                    entry,
+                    content: entry,
                 })
                 .collect(),
             next_entry_id,
@@ -226,48 +235,6 @@ impl<T: AgentThread> Thread<T> {
         }
     }
 
-    async fn handle_message(
-        this: WeakEntity<Self>,
-        role: Role,
-        mut chunks: BoxStream<'static, Result<MessageChunk>>,
-        cx: &mut AsyncApp,
-    ) -> Result<()> {
-        let entry_id = this.update(cx, |this, cx| {
-            let entry_id = this.next_entry_id.post_inc();
-            this.entries.push(ThreadEntry {
-                id: entry_id,
-                entry: AgentThreadEntry::Message(Message {
-                    role,
-                    chunks: Vec::new(),
-                }),
-            });
-            cx.notify();
-            entry_id
-        })?;
-
-        while let Some(chunk) = chunks.next().await {
-            match chunk {
-                Ok(chunk) => {
-                    this.update(cx, |this, cx| {
-                        let ix = this
-                            .entries
-                            .binary_search_by_key(&entry_id, |entry| entry.id)
-                            .map_err(|_| anyhow!("message not found"))?;
-                        let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry else {
-                            unreachable!()
-                        };
-                        message.chunks.push(chunk);
-                        cx.notify();
-                        anyhow::Ok(())
-                    })??;
-                }
-                Err(err) => todo!("show error"),
-            }
-        }
-
-        Ok(())
-    }
-
     pub fn entries(&self) -> &[ThreadEntry] {
         &self.entries
     }
@@ -279,13 +246,16 @@ impl<T: AgentThread> Thread<T> {
             let mut pending_event_handlers = FuturesUnordered::new();
 
             loop {
-                let mut next_event_handler_result = pin!(async {
-                    if pending_event_handlers.is_empty() {
-                        future::pending::<()>().await;
-                    }
+                let mut next_event_handler_result = pin!(
+                    async {
+                        if pending_event_handlers.is_empty() {
+                            future::pending::<()>().await;
+                        }
 
-                    pending_event_handlers.next().await
-                }.fuse());
+                        pending_event_handlers.next().await
+                    }
+                    .fuse()
+                );
 
                 select_biased! {
                     event = events.next() => {
@@ -329,7 +299,7 @@ impl<T: AgentThread> Thread<T> {
         let entry_id = self.next_entry_id.post_inc();
         self.entries.push(ThreadEntry {
             id: entry_id,
-            entry: AgentThreadEntry::Message(Message {
+            content: AgentThreadEntryContent::Message(Message {
                 role: message.role,
                 chunks: Vec::new(),
             }),
@@ -345,7 +315,8 @@ impl<T: AgentThread> Thread<T> {
                                 .entries
                                 .binary_search_by_key(&entry_id, |entry| entry.id)
                                 .map_err(|_| anyhow!("message not found"))?;
-                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
+                            let AgentThreadEntryContent::Message(message) =
+                                &mut this.entries[ix].content
                             else {
                                 unreachable!()
                             };
@@ -392,7 +363,7 @@ mod tests {
     }
 
     #[gpui::test]
-    async fn test_basic(cx: &mut TestAppContext) {
+    async fn test_gemini(cx: &mut TestAppContext) {
         init_test(cx);
 
         cx.executor().allow_parking();
@@ -400,7 +371,7 @@ mod tests {
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
             path!("/test"),
-            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
+            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
         )
         .await;
         let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
@@ -408,12 +379,47 @@ mod tests {
         let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
             .await
             .unwrap();
+        let thread = thread_store
+            .update(cx, |thread_store, cx| {
+                assert_eq!(thread_store.threads().len(), 0);
+                thread_store.create_thread(cx)
+            })
+            .await
+            .unwrap();
+        thread
+            .update(cx, |thread, cx| {
+                thread.send(
+                    Message {
+                        role: Role::User,
+                        chunks: vec![
+                            "Read the 'test/foo' file and output all of its contents.".into(),
+                        ],
+                    },
+                    cx,
+                )
+            })
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert!(
+                thread.entries().iter().any(|entry| {
+                    entry.content
+                        == AgentThreadEntryContent::ReadFile {
+                            path: "test/foo".into(),
+                            content: "Lorem ipsum dolor".into(),
+                        }
+                }),
+                "Thread does not contain entry. Actual: {:?}",
+                thread.entries()
+            );
+        });
     }
 
     pub fn gemini_agent(project: Entity<Project>, cx: AsyncApp) -> Result<AcpAgent> {
         let child = util::command::new_smol_command("node")
             .arg("../../../gemini-cli/packages/cli")
             .arg("--acp")
+            // .args(["--model", "gemini-2.5-flash"])
             .env("GEMINI_API_KEY", env::var("GEMINI_API_KEY").unwrap())
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())