passing roundtrip test

Conrad Irwin , Ben Brandt , Max Brunsfeld , and Agus Zubiaga created

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/agent2/src/acp.rs    |  25 +++---
crates/agent2/src/agent2.rs | 142 ++------------------------------------
2 files changed, 22 insertions(+), 145 deletions(-)

Detailed changes

crates/agent2/src/acp.rs 🔗

@@ -1,14 +1,13 @@
 use std::{io::Write as _, path::Path, sync::Arc};
 
 use crate::{
-    Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
-    Thread, ThreadEntryId, ThreadId,
+    Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread,
+    ThreadEntryId, ThreadId,
 };
 use agentic_coding_protocol as acp;
 use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
 use collections::HashMap;
-use futures::channel::mpsc::UnboundedReceiver;
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
 use parking_lot::Mutex;
 use project::Project;
@@ -31,10 +30,14 @@ struct AcpClientDelegate {
 }
 
 impl AcpClientDelegate {
-    fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
+    fn new(
+        project: Entity<Project>,
+        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
+        cx: AsyncApp,
+    ) -> Self {
         Self {
             project,
-            threads: Default::default(),
+            threads,
             cx: cx,
         }
     }
@@ -186,8 +189,9 @@ impl AcpAgent {
         let stdin = process.stdin.take().expect("process didn't have stdin");
         let stdout = process.stdout.take().expect("process didn't have stdout");
 
+        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>> = Default::default();
         let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate::new(project.clone(), cx.clone()),
+            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
             stdin,
             stdout,
         );
@@ -200,7 +204,7 @@ impl AcpAgent {
         Self {
             project,
             connection: Arc::new(connection),
-            threads: Default::default(),
+            threads,
             _handler_task: cx.foreground_executor().spawn(handler_fut),
             _io_task: io_task,
         }
@@ -286,15 +290,14 @@ impl Agent for AcpAgent {
         thread_id: ThreadId,
         message: crate::Message,
         cx: &mut AsyncApp,
-    ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
+    ) -> Result<()> {
         let thread = self
             .threads
             .lock()
             .get(&thread_id)
             .cloned()
             .ok_or_else(|| anyhow!("no such thread"))?;
-        let response = self
-            .connection
+        self.connection
             .request(acp::SendMessageParams {
                 thread_id: thread_id.clone().into(),
                 message: acp::Message {
@@ -317,7 +320,7 @@ impl Agent for AcpAgent {
                 },
             })
             .await?;
-        todo!()
+        Ok(())
     }
 }
 

crates/agent2/src/agent2.rs 🔗

@@ -3,15 +3,9 @@ mod acp;
 use anyhow::{Result, anyhow};
 use async_trait::async_trait;
 use chrono::{DateTime, Utc};
-use futures::{
-    FutureExt, StreamExt,
-    channel::{mpsc, oneshot},
-    select_biased,
-    stream::{BoxStream, FuturesUnordered},
-};
 use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
 use project::Project;
-use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
+use std::{ops::Range, path::PathBuf, sync::Arc};
 
 #[async_trait(?Send)]
 pub trait Agent: 'static {
@@ -28,34 +22,7 @@ pub trait Agent: 'static {
         thread_id: ThreadId,
         message: Message,
         cx: &mut AsyncApp,
-    ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
-}
-
-pub enum ResponseEvent {
-    MessageResponse(MessageResponse),
-    ReadFileRequest(ReadFileRequest),
-    // GlobSearchRequest(SearchRequest),
-    // RegexSearchRequest(RegexSearchRequest),
-    // RunCommandRequest(RunCommandRequest),
-    // WebSearchResponse(WebSearchResponse),
-}
-
-pub struct MessageResponse {
-    role: Role,
-    chunks: BoxStream<'static, Result<MessageChunk>>,
-}
-
-#[derive(Debug)]
-pub struct ReadFileRequest {
-    path: PathBuf,
-    range: Range<usize>,
-    response_tx: oneshot::Sender<Result<FileContent>>,
-}
-
-impl ReadFileRequest {
-    pub fn respond(self, content: Result<FileContent>) {
-        self.response_tx.send(content).ok();
-    }
+    ) -> Result<()>;
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -250,104 +217,10 @@ impl Thread {
         let agent = self.agent.clone();
         let id = self.id.clone();
         cx.spawn(async move |this, cx| {
-            let mut events = agent.send_thread_message(id, message, cx).await?;
-            let mut pending_event_handlers = FuturesUnordered::new();
-
-            loop {
-                let mut next_event_handler_result = pin!(
-                    async {
-                        if pending_event_handlers.is_empty() {
-                            future::pending::<()>().await;
-                        }
-
-                        pending_event_handlers.next().await
-                    }
-                    .fuse()
-                );
-
-                select_biased! {
-                    event = events.next() => {
-                        let Some(event) = event else {
-                            while let Some(result) = pending_event_handlers.next().await {
-                                result?;
-                            }
-
-                            break;
-                        };
-
-                        let task = match event {
-                            Ok(ResponseEvent::MessageResponse(message)) => {
-                                this.update(cx, |this, cx| this.handle_message_response(message, cx))?
-                            }
-                            Ok(ResponseEvent::ReadFileRequest(request)) => {
-                                this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
-                            }
-                            Err(_) => todo!(),
-                        };
-                        pending_event_handlers.push(task);
-                    }
-                    result = next_event_handler_result => {
-                        // Event handlers should only return errors that are
-                        // unrecoverable and should therefore stop this turn of
-                        // the agentic loop.
-                        result.unwrap()?;
-                    }
-                }
-            }
-
+            agent.send_thread_message(id, message, cx).await?;
             Ok(())
         })
     }
-
-    fn handle_message_response(
-        &mut self,
-        mut message: MessageResponse,
-        cx: &mut Context<Self>,
-    ) -> Task<Result<()>> {
-        let entry_id = self.next_entry_id.post_inc();
-        self.entries.push(ThreadEntry {
-            id: entry_id,
-            content: AgentThreadEntryContent::Message(Message {
-                role: message.role,
-                chunks: Vec::new(),
-            }),
-        });
-        cx.notify();
-
-        cx.spawn(async move |this, cx| {
-            while let Some(chunk) = message.chunks.next().await {
-                match chunk {
-                    Ok(chunk) => {
-                        this.update(cx, |this, cx| {
-                            let ix = this
-                                .entries
-                                .binary_search_by_key(&entry_id, |entry| entry.id)
-                                .map_err(|_| anyhow!("message not found"))?;
-                            let AgentThreadEntryContent::Message(message) =
-                                &mut this.entries[ix].content
-                            else {
-                                unreachable!()
-                            };
-                            message.chunks.push(chunk);
-                            cx.notify();
-                            anyhow::Ok(())
-                        })??;
-                    }
-                    Err(err) => todo!("show error"),
-                }
-            }
-
-            Ok(())
-        })
-    }
-
-    fn handle_read_file_request(
-        &mut self,
-        request: ReadFileRequest,
-        cx: &mut Context<Self>,
-    ) -> Task<Result<()>> {
-        todo!()
-    }
 }
 
 #[cfg(test)]
@@ -367,6 +240,7 @@ mod tests {
             let settings_store = SettingsStore::test(cx);
             cx.set_global(settings_store);
             Project::init_settings(cx);
+            language::init(cx);
         });
     }
 
@@ -378,11 +252,11 @@ mod tests {
 
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
-            path!("/test"),
+            path!("/tmp"),
             json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
         )
         .await;
-        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let project = Project::test(fs, [path!("/tmp").as_ref()], cx).await;
         let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
         let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
             .await
@@ -400,7 +274,7 @@ mod tests {
                     Message {
                         role: Role::User,
                         chunks: vec![
-                            "Read the 'test/foo' file and output all of its contents.".into(),
+                            "Read the '/tmp/foo' file and output all of its contents.".into(),
                         ],
                     },
                     cx,
@@ -413,7 +287,7 @@ mod tests {
                 thread.entries().iter().any(|entry| {
                     entry.content
                         == AgentThreadEntryContent::ReadFile {
-                            path: "test/foo".into(),
+                            path: "/tmp/foo".into(),
                             content: "Lorem ipsum dolor".into(),
                         }
                 }),