wip: request / response in send loop

Ben Brandt , Antonio Scandurra , and Agus Zubiaga created

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

Cargo.lock                  |   3 
crates/agent2/Cargo.toml    |   3 
crates/agent2/src/agent2.rs | 197 +++++++++++++++++++++++++++++++++-----
3 files changed, 177 insertions(+), 26 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -113,6 +113,9 @@ dependencies = [
  "chrono",
  "futures 0.3.31",
  "gpui",
+ "project",
+ "serde_json",
+ "util",
  "uuid",
 ]
 

crates/agent2/Cargo.toml 🔗

@@ -22,7 +22,10 @@ anyhow.workspace = true
 chrono.workspace = true
 futures.workspace = true
 gpui.workspace = true
+project.workspace = true
 uuid.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, "features" = ["test-support"] }
+serde_json.workspace = true
+util.workspace = true

crates/agent2/src/agent2.rs 🔗

@@ -1,7 +1,8 @@
 use anyhow::{Result, anyhow};
 use chrono::{DateTime, Utc};
-use futures::{StreamExt, stream::BoxStream};
+use futures::{StreamExt, channel::oneshot, stream::BoxStream};
 use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
+use project::Project;
 use std::{ops::Range, path::PathBuf, sync::Arc};
 use uuid::Uuid;
 
@@ -15,11 +16,36 @@ pub trait Agent: 'static {
 
 pub trait AgentThread: 'static {
     fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
-    fn send(&self, message: Message) -> impl Future<Output = Result<()>>;
-    fn on_message(
+    fn send(
         &self,
-        handler: impl AsyncFn(Role, BoxStream<'static, Result<MessageChunk>>) -> Result<()>,
-    );
+        message: Message,
+    ) -> impl Future<Output = Result<BoxStream<'static, Result<ResponseEvent>>>>;
+}
+
+pub enum ResponseEvent {
+    MessageResponse(MessageResponse),
+    ReadFileRequest(ReadFileRequest),
+    // GlobSearchRequest(SearchRequest),
+    // RegexSearchRequest(RegexSearchRequest),
+    // RunCommandRequest(RunCommandRequest),
+    // WebSearchResponse(WebSearchResponse),
+}
+
+pub struct MessageResponse {
+    role: Role,
+    chunks: BoxStream<'static, Result<MessageChunk>>,
+}
+
+pub struct ReadFileRequest {
+    path: PathBuf,
+    range: Range<usize>,
+    response_tx: oneshot::Sender<Result<FileContent>>,
+}
+
+impl ReadFileRequest {
+    pub fn respond(self, content: Result<FileContent>) {
+        self.response_tx.send(content).ok();
+    }
 }
 
 pub struct ThreadId(Uuid);
@@ -97,14 +123,23 @@ pub struct ThreadEntry {
 }
 
 pub struct ThreadStore<T: Agent> {
-    agent: Arc<T>,
     threads: Vec<AgentThreadSummary>,
+    agent: Arc<T>,
+    project: Entity<Project>,
 }
 
 impl<T: Agent> ThreadStore<T> {
-    pub async fn load(agent: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
+    pub async fn load(
+        agent: Arc<T>,
+        project: Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Result<Entity<Self>> {
         let threads = agent.threads().await?;
-        cx.new(|cx| Self { agent, threads })
+        cx.new(|cx| Self {
+            threads,
+            agent,
+            project,
+        })
     }
 
     /// Returns the threads in reverse chronological order.
@@ -119,49 +154,49 @@ impl<T: Agent> ThreadStore<T> {
         cx: &mut Context<Self>,
     ) -> Task<Result<Entity<Thread<T::Thread>>>> {
         let agent = self.agent.clone();
+        let project = self.project.clone();
         cx.spawn(async move |_, cx| {
             let agent_thread = agent.open_thread(id).await?;
-            Thread::load(Arc::new(agent_thread), cx).await
+            Thread::load(Arc::new(agent_thread), project, cx).await
         })
     }
 
     /// Creates a new thread.
     pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
         let agent = self.agent.clone();
+        let project = self.project.clone();
         cx.spawn(async move |_, cx| {
             let agent_thread = agent.create_thread().await?;
-            Thread::load(Arc::new(agent_thread), cx).await
+            Thread::load(Arc::new(agent_thread), project, cx).await
         })
     }
 }
 
 pub struct Thread<T: AgentThread> {
-    agent_thread: Arc<T>,
-    entries: Vec<ThreadEntry>,
     next_entry_id: ThreadEntryId,
+    entries: Vec<ThreadEntry>,
+    agent_thread: Arc<T>,
+    project: Entity<Project>,
 }
 
 impl<T: AgentThread> Thread<T> {
-    pub async fn load(agent_thread: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
+    pub async fn load(
+        agent_thread: Arc<T>,
+        project: Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Result<Entity<Self>> {
         let entries = agent_thread.entries().await?;
-        cx.new(|cx| Self::new(agent_thread, entries, cx))
+        cx.new(|cx| Self::new(agent_thread, entries, project, cx))
     }
 
     pub fn new(
         agent_thread: Arc<T>,
         entries: Vec<AgentThreadEntry>,
+        project: Entity<Project>,
         cx: &mut Context<Self>,
     ) -> Self {
-        agent_thread.on_message({
-            let this = cx.weak_entity();
-            let cx = cx.to_async();
-            async move |role, chunks| {
-                Self::handle_message(this.clone(), role, chunks, &mut cx.clone()).await
-            }
-        });
         let mut next_entry_id = ThreadEntryId(0);
         Self {
-            agent_thread,
             entries: entries
                 .into_iter()
                 .map(|entry| ThreadEntry {
@@ -170,6 +205,8 @@ impl<T: AgentThread> Thread<T> {
                 })
                 .collect(),
             next_entry_id,
+            agent_thread,
+            project,
         }
     }
 
@@ -221,24 +258,101 @@ impl<T: AgentThread> Thread<T> {
 
     pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
         let agent_thread = self.agent_thread.clone();
-        cx.spawn(async move |_, cx| agent_thread.send(message).await)
+        cx.spawn(async move |this, cx| {
+            let mut events = agent_thread.send(message).await?;
+            while let Some(event) = events.next().await {
+                match event {
+                    Ok(ResponseEvent::MessageResponse(message)) => {
+                        this.update(cx, |this, cx| this.handle_message_response(message, cx))?
+                            .await?;
+                    }
+                    Ok(ResponseEvent::ReadFileRequest(request)) => {
+                        this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
+                            .await?;
+                    }
+                    Err(_) => todo!(),
+                }
+            }
+            Ok(())
+        })
+    }
+
+    fn handle_message_response(
+        &mut self,
+        mut message: MessageResponse,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let entry_id = self.next_entry_id.post_inc();
+        self.entries.push(ThreadEntry {
+            id: entry_id,
+            entry: AgentThreadEntry::Message(Message {
+                role: message.role,
+                chunks: Vec::new(),
+            }),
+        });
+        cx.notify();
+
+        cx.spawn(async move |this, cx| {
+            while let Some(chunk) = message.chunks.next().await {
+                match chunk {
+                    Ok(chunk) => {
+                        this.update(cx, |this, cx| {
+                            let ix = this
+                                .entries
+                                .binary_search_by_key(&entry_id, |entry| entry.id)
+                                .map_err(|_| anyhow!("message not found"))?;
+                            let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
+                            else {
+                                unreachable!()
+                            };
+                            message.chunks.push(chunk);
+                            cx.notify();
+                            anyhow::Ok(())
+                        })??;
+                    }
+                    Err(err) => todo!("show error"),
+                }
+            }
+
+            Ok(())
+        })
+    }
+
+    fn handle_read_file_request(
+        &mut self,
+        request: ReadFileRequest,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        todo!()
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use std::path::Path;
-
     use super::*;
     use gpui::{BackgroundExecutor, TestAppContext};
+    use project::FakeFs;
+    use serde_json::json;
+    use std::path::Path;
+    use util::path;
 
     #[gpui::test]
     async fn test_basic(cx: &mut TestAppContext) {
         cx.executor().allow_parking();
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/test"),
+            json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
+        )
+        .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
         let agent = GeminiAgent::start("~/gemini-cli/change-me.js", &cx.executor())
             .await
             .unwrap();
-        let thread_store = ThreadStore::load(Arc::new(agent), &mut cx.to_async()).await.unwrap();
+        let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
+            .await
+            .unwrap();
     }
 
     struct GeminiAgent {}
@@ -248,4 +362,35 @@ mod tests {
             executor.spawn(async move { Ok(GeminiAgent {}) })
         }
     }
+
+    impl Agent for GeminiAgent {
+        type Thread = GeminiAgentThread;
+
+        async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
+            todo!()
+        }
+
+        async fn create_thread(&self) -> Result<Self::Thread> {
+            todo!()
+        }
+
+        async fn open_thread(&self, id: ThreadId) -> Result<Self::Thread> {
+            todo!()
+        }
+    }
+
+    struct GeminiAgentThread {}
+
+    impl AgentThread for GeminiAgentThread {
+        async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
+            todo!()
+        }
+
+        async fn send(
+            &self,
+            message: Message,
+        ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+            todo!()
+        }
+    }
 }