Cargo.lock 🔗
@@ -113,6 +113,9 @@ dependencies = [
"chrono",
"futures 0.3.31",
"gpui",
+ "project",
+ "serde_json",
+ "util",
"uuid",
]
Ben Brandt , Antonio Scandurra , and Agus Zubiaga created
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Cargo.lock | 3
crates/agent2/Cargo.toml | 3
crates/agent2/src/agent2.rs | 197 +++++++++++++++++++++++++++++++++-----
3 files changed, 177 insertions(+), 26 deletions(-)
@@ -113,6 +113,9 @@ dependencies = [
"chrono",
"futures 0.3.31",
"gpui",
+ "project",
+ "serde_json",
+ "util",
"uuid",
]
@@ -22,7 +22,10 @@ anyhow.workspace = true
chrono.workspace = true
futures.workspace = true
gpui.workspace = true
+project.workspace = true
uuid.workspace = true
[dev-dependencies]
gpui = { workspace = true, "features" = ["test-support"] }
+serde_json.workspace = true
+util.workspace = true
@@ -1,7 +1,8 @@
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
-use futures::{StreamExt, stream::BoxStream};
+use futures::{StreamExt, channel::oneshot, stream::BoxStream};
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
+use project::Project;
use std::{ops::Range, path::PathBuf, sync::Arc};
use uuid::Uuid;
@@ -15,11 +16,36 @@ pub trait Agent: 'static {
pub trait AgentThread: 'static {
fn entries(&self) -> impl Future<Output = Result<Vec<AgentThreadEntry>>>;
- fn send(&self, message: Message) -> impl Future<Output = Result<()>>;
- fn on_message(
+ fn send(
&self,
- handler: impl AsyncFn(Role, BoxStream<'static, Result<MessageChunk>>) -> Result<()>,
- );
+ message: Message,
+ ) -> impl Future<Output = Result<BoxStream<'static, Result<ResponseEvent>>>>;
+}
+
+pub enum ResponseEvent {
+ MessageResponse(MessageResponse),
+ ReadFileRequest(ReadFileRequest),
+ // GlobSearchRequest(SearchRequest),
+ // RegexSearchRequest(RegexSearchRequest),
+ // RunCommandRequest(RunCommandRequest),
+ // WebSearchResponse(WebSearchResponse),
+}
+
+pub struct MessageResponse {
+ role: Role,
+ chunks: BoxStream<'static, Result<MessageChunk>>,
+}
+
+pub struct ReadFileRequest {
+ path: PathBuf,
+ range: Range<usize>,
+ response_tx: oneshot::Sender<Result<FileContent>>,
+}
+
+impl ReadFileRequest {
+ pub fn respond(self, content: Result<FileContent>) {
+ self.response_tx.send(content).ok();
+ }
}
pub struct ThreadId(Uuid);
@@ -97,14 +123,23 @@ pub struct ThreadEntry {
}
pub struct ThreadStore<T: Agent> {
- agent: Arc<T>,
threads: Vec<AgentThreadSummary>,
+ agent: Arc<T>,
+ project: Entity<Project>,
}
impl<T: Agent> ThreadStore<T> {
- pub async fn load(agent: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
+ pub async fn load(
+ agent: Arc<T>,
+ project: Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Entity<Self>> {
let threads = agent.threads().await?;
- cx.new(|cx| Self { agent, threads })
+ cx.new(|cx| Self {
+ threads,
+ agent,
+ project,
+ })
}
/// Returns the threads in reverse chronological order.
@@ -119,49 +154,49 @@ impl<T: Agent> ThreadStore<T> {
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread<T::Thread>>>> {
let agent = self.agent.clone();
+ let project = self.project.clone();
cx.spawn(async move |_, cx| {
let agent_thread = agent.open_thread(id).await?;
- Thread::load(Arc::new(agent_thread), cx).await
+ Thread::load(Arc::new(agent_thread), project, cx).await
})
}
/// Creates a new thread.
pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread<T::Thread>>>> {
let agent = self.agent.clone();
+ let project = self.project.clone();
cx.spawn(async move |_, cx| {
let agent_thread = agent.create_thread().await?;
- Thread::load(Arc::new(agent_thread), cx).await
+ Thread::load(Arc::new(agent_thread), project, cx).await
})
}
}
pub struct Thread<T: AgentThread> {
- agent_thread: Arc<T>,
- entries: Vec<ThreadEntry>,
next_entry_id: ThreadEntryId,
+ entries: Vec<ThreadEntry>,
+ agent_thread: Arc<T>,
+ project: Entity<Project>,
}
impl<T: AgentThread> Thread<T> {
- pub async fn load(agent_thread: Arc<T>, cx: &mut AsyncApp) -> Result<Entity<Self>> {
+ pub async fn load(
+ agent_thread: Arc<T>,
+ project: Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Entity<Self>> {
let entries = agent_thread.entries().await?;
- cx.new(|cx| Self::new(agent_thread, entries, cx))
+ cx.new(|cx| Self::new(agent_thread, entries, project, cx))
}
pub fn new(
agent_thread: Arc<T>,
entries: Vec<AgentThreadEntry>,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Self {
- agent_thread.on_message({
- let this = cx.weak_entity();
- let cx = cx.to_async();
- async move |role, chunks| {
- Self::handle_message(this.clone(), role, chunks, &mut cx.clone()).await
- }
- });
let mut next_entry_id = ThreadEntryId(0);
Self {
- agent_thread,
entries: entries
.into_iter()
.map(|entry| ThreadEntry {
@@ -170,6 +205,8 @@ impl<T: AgentThread> Thread<T> {
})
.collect(),
next_entry_id,
+ agent_thread,
+ project,
}
}
@@ -221,24 +258,101 @@ impl<T: AgentThread> Thread<T> {
pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent_thread = self.agent_thread.clone();
- cx.spawn(async move |_, cx| agent_thread.send(message).await)
+ cx.spawn(async move |this, cx| {
+ let mut events = agent_thread.send(message).await?;
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(ResponseEvent::MessageResponse(message)) => {
+ this.update(cx, |this, cx| this.handle_message_response(message, cx))?
+ .await?;
+ }
+ Ok(ResponseEvent::ReadFileRequest(request)) => {
+ this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
+ .await?;
+ }
+ Err(_) => todo!(),
+ }
+ }
+ Ok(())
+ })
+ }
+
+ fn handle_message_response(
+ &mut self,
+ mut message: MessageResponse,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let entry_id = self.next_entry_id.post_inc();
+ self.entries.push(ThreadEntry {
+ id: entry_id,
+ entry: AgentThreadEntry::Message(Message {
+ role: message.role,
+ chunks: Vec::new(),
+ }),
+ });
+ cx.notify();
+
+ cx.spawn(async move |this, cx| {
+ while let Some(chunk) = message.chunks.next().await {
+ match chunk {
+ Ok(chunk) => {
+ this.update(cx, |this, cx| {
+ let ix = this
+ .entries
+ .binary_search_by_key(&entry_id, |entry| entry.id)
+ .map_err(|_| anyhow!("message not found"))?;
+ let AgentThreadEntry::Message(message) = &mut this.entries[ix].entry
+ else {
+ unreachable!()
+ };
+ message.chunks.push(chunk);
+ cx.notify();
+ anyhow::Ok(())
+ })??;
+ }
+ Err(err) => todo!("show error"),
+ }
+ }
+
+ Ok(())
+ })
+ }
+
+ fn handle_read_file_request(
+ &mut self,
+ request: ReadFileRequest,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ todo!()
}
}
#[cfg(test)]
mod tests {
- use std::path::Path;
-
use super::*;
use gpui::{BackgroundExecutor, TestAppContext};
+ use project::FakeFs;
+ use serde_json::json;
+ use std::path::Path;
+ use util::path;
#[gpui::test]
async fn test_basic(cx: &mut TestAppContext) {
cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/test"),
+ json!({"foo": "foo", "bar": "bar", "baz": "baz"}),
+ )
+ .await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let agent = GeminiAgent::start("~/gemini-cli/change-me.js", &cx.executor())
.await
.unwrap();
- let thread_store = ThreadStore::load(Arc::new(agent), &mut cx.to_async()).await.unwrap();
+ let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
+ .await
+ .unwrap();
}
struct GeminiAgent {}
@@ -248,4 +362,35 @@ mod tests {
executor.spawn(async move { Ok(GeminiAgent {}) })
}
}
+
+ impl Agent for GeminiAgent {
+ type Thread = GeminiAgentThread;
+
+ async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
+ todo!()
+ }
+
+ async fn create_thread(&self) -> Result<Self::Thread> {
+ todo!()
+ }
+
+ async fn open_thread(&self, id: ThreadId) -> Result<Self::Thread> {
+ todo!()
+ }
+ }
+
+ struct GeminiAgentThread {}
+
+ impl AgentThread for GeminiAgentThread {
+ async fn entries(&self) -> Result<Vec<AgentThreadEntry>> {
+ todo!()
+ }
+
+ async fn send(
+ &self,
+ message: Message,
+ ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+ todo!()
+ }
+ }
}