Get agent2 compiling

Max Brunsfeld , Conrad Irwin , and Antonio Scandurra created

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/agent2/src/acp.rs    | 153 ++++++++++++++++++++------------------
crates/agent2/src/agent2.rs |  55 +++++++------
2 files changed, 109 insertions(+), 99 deletions(-)

Detailed changes

crates/agent2/src/acp.rs 🔗

@@ -1,19 +1,15 @@
-use std::{
-    io::{Cursor, Write as _},
-    path::Path,
-    sync::{Arc, Weak},
-};
+use std::{io::Write as _, path::Path, sync::Arc};
 
 use crate::{
-    Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk,
-    ResponseEvent, Role, Thread, ThreadEntry, ThreadId,
+    Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
+    Thread, ThreadEntryId, ThreadId,
 };
-use agentic_coding_protocol::{self as acp, TurnId};
-use anyhow::{Context as _, Result};
+use agentic_coding_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
 use collections::HashMap;
 use futures::channel::mpsc::UnboundedReceiver;
-use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 use parking_lot::Mutex;
 use project::Project;
 use smol::process::Child;
@@ -21,18 +17,43 @@ use util::ResultExt;
 
 pub struct AcpAgent {
     connection: Arc<acp::AgentConnection>,
-    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
+    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
+    project: Entity<Project>,
     _handler_task: Task<()>,
     _io_task: Task<()>,
 }
 
 struct AcpClientDelegate {
     project: Entity<Project>,
-    threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
+    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
     cx: AsyncApp,
     // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 }
 
+impl AcpClientDelegate {
+    fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
+        Self {
+            project,
+            threads: Default::default(),
+            cx: cx,
+        }
+    }
+
+    fn update_thread<R>(
+        &self,
+        thread_id: &ThreadId,
+        cx: &mut App,
+        callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
+    ) -> Option<R> {
+        let thread = self.threads.lock().get(&thread_id)?.clone();
+        let Some(thread) = thread.upgrade() else {
+            self.threads.lock().remove(&thread_id);
+            return None;
+        };
+        Some(thread.update(cx, callback))
+    }
+}
+
 #[async_trait(?Send)]
 impl acp::Client for AcpClientDelegate {
     async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
@@ -58,7 +79,7 @@ impl acp::Client for AcpClientDelegate {
 
     async fn stream_message_chunk(
         &self,
-        request: acp::StreamMessageChunkParams,
+        chunk: acp::StreamMessageChunkParams,
     ) -> Result<acp::StreamMessageChunkResponse> {
         Ok(acp::StreamMessageChunkResponse)
     }
@@ -78,25 +99,23 @@ impl acp::Client for AcpClientDelegate {
             })??
             .await?;
 
-        buffer.update(cx, |buffer, _| {
+        buffer.update(cx, |buffer, cx| {
             let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
             let end = match request.line_limit {
                 None => buffer.max_point(),
                 Some(limit) => start + language::Point::new(limit + 1, 0),
             };
 
-            let content = buffer.text_for_range(start..end).collect();
-
-            if let Some(thread) = self.threads.lock().get(&request.thread_id) {
-                thread.update(cx, |thread, cx| {
-                    thread.push_entry(ThreadEntry {
-                        content: AgentThreadEntryContent::ReadFile {
-                            path: request.path.clone(),
-                            content: content.clone(),
-                        },
-                    });
-                })
-            }
+            let content: String = buffer.text_for_range(start..end).collect();
+            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+                thread.push_entry(
+                    AgentThreadEntryContent::ReadFile {
+                        path: request.path.clone(),
+                        content: content.clone(),
+                    },
+                    cx,
+                );
+            });
 
             acp::ReadTextFileResponse {
                 content,
@@ -135,7 +154,7 @@ impl acp::Client for AcpClientDelegate {
 
                 let mut base64_content = Vec::new();
                 let mut base64_encoder = base64::write::EncoderWriter::new(
-                    Cursor::new(&mut base64_content),
+                    std::io::Cursor::new(&mut base64_content),
                     &base64::engine::general_purpose::STANDARD,
                 );
                 base64_encoder.write_all(range_content)?;
@@ -168,10 +187,7 @@ impl AcpAgent {
         let stdout = process.stdout.take().expect("process didn't have stdout");
 
         let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate {
-                project,
-                cx: cx.clone(),
-            },
+            AcpClientDelegate::new(project.clone(), cx.clone()),
             stdin,
             stdout,
         );
@@ -182,17 +198,18 @@ impl AcpAgent {
         });
 
         Self {
+            project,
             connection: Arc::new(connection),
-            threads: Mutex::default(),
+            threads: Default::default(),
             _handler_task: cx.foreground_executor().spawn(handler_fut),
             _io_task: io_task,
         }
     }
 }
 
-#[async_trait]
+#[async_trait(?Send)]
 impl Agent for AcpAgent {
-    async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
+    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
         let response = self.connection.request(acp::GetThreadsParams).await?;
         response
             .threads
@@ -207,31 +224,34 @@ impl Agent for AcpAgent {
             .collect()
     }
 
-    async fn create_thread(&self) -> Result<Arc<Self::Thread>> {
+    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
         let response = self.connection.request(acp::CreateThreadParams).await?;
-        let thread = Arc::new(AcpAgentThread {
-            id: response.thread_id.clone(),
-            connection: self.connection.clone(),
-            state: Mutex::new(AcpAgentThreadState {
-                turn: None,
-                next_turn_id: TurnId::default(),
-            }),
-        });
-        self.threads
-            .lock()
-            .insert(response.thread_id, Arc::downgrade(&thread));
+        let thread_id: ThreadId = response.thread_id.into();
+        let agent = self.clone();
+        let thread = cx.new(|_| Thread {
+            id: thread_id.clone(),
+            next_entry_id: ThreadEntryId(0),
+            entries: Vec::default(),
+            project: self.project.clone(),
+            agent,
+        })?;
+        self.threads.lock().insert(thread_id, thread.downgrade());
         Ok(thread)
     }
 
-    async fn open_thread(&self, id: ThreadId) -> Result<Thread> {
+    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
         todo!()
     }
 
-    async fn thread_entries(&self, thread_id: ThreadId) -> Result<Vec<AgentThreadEntryContent>> {
+    async fn thread_entries(
+        &self,
+        thread_id: ThreadId,
+        cx: &mut AsyncApp,
+    ) -> Result<Vec<AgentThreadEntryContent>> {
         let response = self
             .connection
             .request(acp::GetThreadEntriesParams {
-                thread_id: self.id.clone(),
+                thread_id: thread_id.clone().into(),
             })
             .await?;
 
@@ -265,18 +285,18 @@ impl Agent for AcpAgent {
         &self,
         thread_id: ThreadId,
         message: crate::Message,
+        cx: &mut AsyncApp,
     ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
-        let turn_id = {
-            let mut state = self.state.lock();
-            let turn_id = state.next_turn_id.post_inc();
-            state.turn = Some(AcpAgentThreadTurn { id: turn_id });
-            turn_id
-        };
+        let thread = self
+            .threads
+            .lock()
+            .get(&thread_id)
+            .cloned()
+            .ok_or_else(|| anyhow!("no such thread"))?;
         let response = self
             .connection
             .request(acp::SendMessageParams {
-                thread_id: self.id.clone(),
-                turn_id,
+                thread_id: thread_id.clone().into(),
                 message: acp::Message {
                     role: match message.role {
                         Role::User => acp::Role::User,
@@ -301,29 +321,14 @@ impl Agent for AcpAgent {
     }
 }
 
-pub struct AcpAgentThread {
-    id: acp::ThreadId,
-    connection: Arc<acp::AgentConnection>,
-    state: Mutex<AcpAgentThreadState>,
-}
-
-struct AcpAgentThreadState {
-    next_turn_id: acp::TurnId,
-    turn: Option<AcpAgentThreadTurn>,
-}
-
-struct AcpAgentThreadTurn {
-    id: acp::TurnId,
-}
-
 impl From<acp::ThreadId> for ThreadId {
     fn from(thread_id: acp::ThreadId) -> Self {
-        Self(thread_id.0)
+        Self(thread_id.0.into())
     }
 }
 
 impl From<ThreadId> for acp::ThreadId {
     fn from(thread_id: ThreadId) -> Self {
-        acp::ThreadId(thread_id.0)
+        acp::ThreadId(thread_id.0.to_string())
     }
 }

crates/agent2/src/agent2.rs 🔗

@@ -13,16 +13,21 @@ use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
 use project::Project;
 use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
 
-#[async_trait]
+#[async_trait(?Send)]
 pub trait Agent: 'static {
-    async fn threads(&self) -> Result<Vec<AgentThreadSummary>>;
-    async fn create_thread(&self) -> Result<Entity<Thread>>;
-    async fn open_thread(&self, id: ThreadId) -> Result<Entity<Thread>>;
-    async fn thread_entries(&self, id: ThreadId) -> Result<Vec<AgentThreadEntryContent>>;
+    async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
+    async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
+    async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
+    async fn thread_entries(
+        &self,
+        id: ThreadId,
+        cx: &mut AsyncApp,
+    ) -> Result<Vec<AgentThreadEntryContent>>;
     async fn send_thread_message(
         &self,
         thread_id: ThreadId,
         message: Message,
+        cx: &mut AsyncApp,
     ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
 }
 
@@ -53,7 +58,7 @@ impl ReadFileRequest {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct ThreadId(SharedString);
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -145,20 +150,20 @@ pub struct ThreadEntry {
     pub content: AgentThreadEntryContent,
 }
 
-pub struct ThreadStore<T: Agent> {
+pub struct ThreadStore {
     threads: Vec<AgentThreadSummary>,
-    agent: Arc<T>,
+    agent: Arc<dyn Agent>,
     project: Entity<Project>,
 }
 
-impl<T: Agent> ThreadStore<T> {
+impl ThreadStore {
     pub async fn load(
-        agent: Arc<T>,
+        agent: Arc<dyn Agent>,
         project: Entity<Project>,
         cx: &mut AsyncApp,
     ) -> Result<Entity<Self>> {
-        let threads = agent.threads().await?;
-        cx.new(|cx| Self {
+        let threads = agent.threads(cx).await?;
+        cx.new(|_cx| Self {
             threads,
             agent,
             project,
@@ -177,21 +182,13 @@ impl<T: Agent> ThreadStore<T> {
         cx: &mut Context<Self>,
     ) -> Task<Result<Entity<Thread>>> {
         let agent = self.agent.clone();
-        let project = self.project.clone();
-        cx.spawn(async move |_, cx| {
-            let agent_thread = agent.open_thread(id).await?;
-            Thread::load(agent_thread, project, cx).await
-        })
+        cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
     }
 
     /// Creates a new thread.
     pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
         let agent = self.agent.clone();
-        let project = self.project.clone();
-        cx.spawn(async move |_, cx| {
-            let agent_thread = agent.create_thread().await?;
-            Thread::load(agent_thread, project, cx).await
-        })
+        cx.spawn(async move |_, cx| agent.create_thread(cx).await)
     }
 }
 
@@ -210,7 +207,7 @@ impl Thread {
         project: Entity<Project>,
         cx: &mut AsyncApp,
     ) -> Result<Entity<Self>> {
-        let entries = agent.thread_entries(thread_id.clone()).await?;
+        let entries = agent.thread_entries(thread_id.clone(), cx).await?;
         cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
     }
 
@@ -241,11 +238,19 @@ impl Thread {
         &self.entries
     }
 
+    pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
+        self.entries.push(ThreadEntry {
+            id: self.next_entry_id.post_inc(),
+            content: entry,
+        });
+        cx.notify();
+    }
+
     pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
         let agent = self.agent.clone();
-        let id = self.id;
+        let id = self.id.clone();
         cx.spawn(async move |this, cx| {
-            let mut events = agent.send_thread_message(id, message).await?;
+            let mut events = agent.send_thread_message(id, message, cx).await?;
             let mut pending_event_handlers = FuturesUnordered::new();
 
             loop {