diff --git a/Cargo.lock b/Cargo.lock index fbad211c169d7fe285bfa83aeddf952f7e0f5c53..9ba2b4d862bff0f53c1c762844812a30fe7b067b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,9 @@ dependencies = [ "chrono", "futures 0.3.31", "gpui", + "project", + "serde_json", + "util", "uuid", ] diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index e5643cb9eef39d9feea4446f304ff76bbea7dd34..7bad05021614f0a30265e0993692a02d158203c4 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -22,7 +22,10 @@ anyhow.workspace = true chrono.workspace = true futures.workspace = true gpui.workspace = true +project.workspace = true uuid.workspace = true [dev-dependencies] gpui = { workspace = true, "features" = ["test-support"] } +serde_json.workspace = true +util.workspace = true diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 1af7925bd888bba2458b20898dd376d8aa7a3b89..30377f2cfba72fff67a5e8efe0ccd1652ad657c8 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,7 +1,8 @@ use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; -use futures::{StreamExt, stream::BoxStream}; +use futures::{StreamExt, channel::oneshot, stream::BoxStream}; use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use project::Project; use std::{ops::Range, path::PathBuf, sync::Arc}; use uuid::Uuid; @@ -15,11 +16,36 @@ pub trait Agent: 'static { pub trait AgentThread: 'static { fn entries(&self) -> impl Future>>; - fn send(&self, message: Message) -> impl Future>; - fn on_message( + fn send( &self, - handler: impl AsyncFn(Role, BoxStream<'static, Result>) -> Result<()>, - ); + message: Message, + ) -> impl Future>>>; +} + +pub enum ResponseEvent { + MessageResponse(MessageResponse), + ReadFileRequest(ReadFileRequest), + // GlobSearchRequest(SearchRequest), + // RegexSearchRequest(RegexSearchRequest), + // RunCommandRequest(RunCommandRequest), + // WebSearchResponse(WebSearchResponse), +} + +pub struct MessageResponse { + role: Role, + chunks: BoxStream<'static, Result>, +} + +pub struct ReadFileRequest { + path: PathBuf, + range: Range, + response_tx: oneshot::Sender>, +} + +impl ReadFileRequest { + pub fn respond(self, content: Result) { + self.response_tx.send(content).ok(); + } } pub struct ThreadId(Uuid); @@ -97,14 +123,23 @@ pub struct ThreadEntry { } pub struct ThreadStore { - agent: Arc, threads: Vec, + agent: Arc, + project: Entity, } impl ThreadStore { - pub async fn load(agent: Arc, cx: &mut AsyncApp) -> Result> { + pub async fn load( + agent: Arc, + project: Entity, + cx: &mut AsyncApp, + ) -> Result> { let threads = agent.threads().await?; - cx.new(|cx| Self { agent, threads }) + cx.new(|cx| Self { + threads, + agent, + project, + }) } /// Returns the threads in reverse chronological order. @@ -119,49 +154,49 @@ impl ThreadStore { cx: &mut Context, ) -> Task>>> { let agent = self.agent.clone(); + let project = self.project.clone(); cx.spawn(async move |_, cx| { let agent_thread = agent.open_thread(id).await?; - Thread::load(Arc::new(agent_thread), cx).await + Thread::load(Arc::new(agent_thread), project, cx).await }) } /// Creates a new thread. pub fn create_thread(&self, cx: &mut Context) -> Task>>> { let agent = self.agent.clone(); + let project = self.project.clone(); cx.spawn(async move |_, cx| { let agent_thread = agent.create_thread().await?; - Thread::load(Arc::new(agent_thread), cx).await + Thread::load(Arc::new(agent_thread), project, cx).await }) } } pub struct Thread { - agent_thread: Arc, - entries: Vec, next_entry_id: ThreadEntryId, + entries: Vec, + agent_thread: Arc, + project: Entity, } impl Thread { - pub async fn load(agent_thread: Arc, cx: &mut AsyncApp) -> Result> { + pub async fn load( + agent_thread: Arc, + project: Entity, + cx: &mut AsyncApp, + ) -> Result> { let entries = agent_thread.entries().await?; - cx.new(|cx| Self::new(agent_thread, entries, cx)) + cx.new(|cx| Self::new(agent_thread, entries, project, cx)) } pub fn new( agent_thread: Arc, entries: Vec, + project: Entity, cx: &mut Context, ) -> Self { - agent_thread.on_message({ - let this = cx.weak_entity(); - let cx = cx.to_async(); - async move |role, chunks| { - Self::handle_message(this.clone(), role, chunks, &mut cx.clone()).await - } - }); let mut next_entry_id = ThreadEntryId(0); Self { - agent_thread, entries: entries .into_iter() .map(|entry| ThreadEntry { @@ -170,6 +205,8 @@ impl Thread { }) .collect(), next_entry_id, + agent_thread, + project, } } @@ -221,24 +258,101 @@ impl Thread { pub fn send(&mut self, message: Message, cx: &mut Context) -> Task> { let agent_thread = self.agent_thread.clone(); - cx.spawn(async move |_, cx| agent_thread.send(message).await) + cx.spawn(async move |this, cx| { + let mut events = agent_thread.send(message).await?; + while let Some(event) = events.next().await { + match event { + Ok(ResponseEvent::MessageResponse(message)) => { + this.update(cx, |this, cx| this.handle_message_response(message, cx))? + .await?; + } + Ok(ResponseEvent::ReadFileRequest(request)) => { + this.update(cx, |this, cx| this.handle_read_file_request(request, cx))? + .await?; + } + Err(_) => todo!(), + } + } + Ok(()) + }) + } + + fn handle_message_response( + &mut self, + mut message: MessageResponse, + cx: &mut Context, + ) -> Task> { + let entry_id = self.next_entry_id.post_inc(); + self.entries.push(ThreadEntry { + id: entry_id, + entry: AgentThreadEntry::Message(Message { + role: message.role, + chunks: Vec::new(), + }), + }); + cx.notify(); + + cx.spawn(async move |this, cx| { + while let Some(chunk) = message.chunks.next().await { + match chunk { + Ok(chunk) => { + this.update(cx, |this, cx| { + let ix = this + .entries + .binary_search_by_key(&entry_id, |entry| entry.id) + .map_err(|_| anyhow!("message not found"))?; + let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry + else { + unreachable!() + }; + message.chunks.push(chunk); + cx.notify(); + anyhow::Ok(()) + })??; + } + Err(err) => todo!("show error"), + } + } + + Ok(()) + }) + } + + fn handle_read_file_request( + &mut self, + request: ReadFileRequest, + cx: &mut Context, + ) -> Task> { + todo!() } } #[cfg(test)] mod tests { - use std::path::Path; - use super::*; use gpui::{BackgroundExecutor, TestAppContext}; + use project::FakeFs; + use serde_json::json; + use std::path::Path; + use util::path; #[gpui::test] async fn test_basic(cx: &mut TestAppContext) { cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/test"), + json!({"foo": "foo", "bar": "bar", "baz": "baz"}), + ) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; let agent = GeminiAgent::start("~/gemini-cli/change-me.js", &cx.executor()) .await .unwrap(); - let thread_store = ThreadStore::load(Arc::new(agent), &mut cx.to_async()).await.unwrap(); + let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async()) + .await + .unwrap(); } struct GeminiAgent {} @@ -248,4 +362,35 @@ mod tests { executor.spawn(async move { Ok(GeminiAgent {}) }) } } + + impl Agent for GeminiAgent { + type Thread = GeminiAgentThread; + + async fn threads(&self) -> Result> { + todo!() + } + + async fn create_thread(&self) -> Result { + todo!() + } + + async fn open_thread(&self, id: ThreadId) -> Result { + todo!() + } + } + + struct GeminiAgentThread {} + + impl AgentThread for GeminiAgentThread { + async fn entries(&self) -> Result> { + todo!() + } + + async fn send( + &self, + message: Message, + ) -> Result>> { + todo!() + } + } }