diff --git a/crates/agent2/src/acp.rs b/crates/agent2/src/acp.rs index 529b33e8282e2c571da5ede6b879f7296cb78c9d..3981dda23b9e834e34bd41f061444fea46ad33ce 100644 --- a/crates/agent2/src/acp.rs +++ b/crates/agent2/src/acp.rs @@ -1,14 +1,13 @@ use std::{io::Write as _, path::Path, sync::Arc}; use crate::{ - Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role, - Thread, ThreadEntryId, ThreadId, + Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread, + ThreadEntryId, ThreadId, }; use agentic_coding_protocol as acp; use anyhow::{Context as _, Result, anyhow}; use async_trait::async_trait; use collections::HashMap; -use futures::channel::mpsc::UnboundedReceiver; use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use parking_lot::Mutex; use project::Project; @@ -31,10 +30,14 @@ struct AcpClientDelegate { } impl AcpClientDelegate { - fn new(project: Entity, cx: AsyncApp) -> Self { + fn new( + project: Entity, + threads: Arc>>>, + cx: AsyncApp, + ) -> Self { Self { project, - threads: Default::default(), + threads, cx: cx, } } @@ -186,8 +189,9 @@ impl AcpAgent { let stdin = process.stdin.take().expect("process didn't have stdin"); let stdout = process.stdout.take().expect("process didn't have stdout"); + let threads: Arc>>> = Default::default(); let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(project.clone(), cx.clone()), + AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()), stdin, stdout, ); @@ -200,7 +204,7 @@ impl AcpAgent { Self { project, connection: Arc::new(connection), - threads: Default::default(), + threads, _handler_task: cx.foreground_executor().spawn(handler_fut), _io_task: io_task, } @@ -286,15 +290,14 @@ impl Agent for AcpAgent { thread_id: ThreadId, message: crate::Message, cx: &mut AsyncApp, - ) -> Result>> { + ) -> Result<()> { let thread = self .threads .lock() .get(&thread_id) .cloned() .ok_or_else(|| anyhow!("no such thread"))?; - let response = self - .connection + self.connection .request(acp::SendMessageParams { thread_id: thread_id.clone().into(), message: acp::Message { @@ -317,7 +320,7 @@ impl Agent for AcpAgent { }, }) .await?; - todo!() + Ok(()) } } diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 309fcc27289672e112fbf5abf17b615f6189eb3d..519e23f35e77f887955c01b070a4336d0b6b2d65 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -3,15 +3,9 @@ mod acp; use anyhow::{Result, anyhow}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use futures::{ - FutureExt, StreamExt, - channel::{mpsc, oneshot}, - select_biased, - stream::{BoxStream, FuturesUnordered}, -}; use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task}; use project::Project; -use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc}; +use std::{ops::Range, path::PathBuf, sync::Arc}; #[async_trait(?Send)] pub trait Agent: 'static { @@ -28,34 +22,7 @@ pub trait Agent: 'static { thread_id: ThreadId, message: Message, cx: &mut AsyncApp, - ) -> Result>>; -} - -pub enum ResponseEvent { - MessageResponse(MessageResponse), - ReadFileRequest(ReadFileRequest), - // GlobSearchRequest(SearchRequest), - // RegexSearchRequest(RegexSearchRequest), - // RunCommandRequest(RunCommandRequest), - // WebSearchResponse(WebSearchResponse), -} - -pub struct MessageResponse { - role: Role, - chunks: BoxStream<'static, Result>, -} - -#[derive(Debug)] -pub struct ReadFileRequest { - path: PathBuf, - range: Range, - response_tx: oneshot::Sender>, -} - -impl ReadFileRequest { - pub fn respond(self, content: Result) { - self.response_tx.send(content).ok(); - } + ) -> Result<()>; } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -250,104 +217,10 @@ impl Thread { let agent = self.agent.clone(); let id = self.id.clone(); cx.spawn(async move |this, cx| { - let mut events = agent.send_thread_message(id, message, cx).await?; - let mut pending_event_handlers = FuturesUnordered::new(); - - loop { - let mut next_event_handler_result = pin!( - async { - if pending_event_handlers.is_empty() { - future::pending::<()>().await; - } - - pending_event_handlers.next().await - } - .fuse() - ); - - select_biased! { - event = events.next() => { - let Some(event) = event else { - while let Some(result) = pending_event_handlers.next().await { - result?; - } - - break; - }; - - let task = match event { - Ok(ResponseEvent::MessageResponse(message)) => { - this.update(cx, |this, cx| this.handle_message_response(message, cx))? - } - Ok(ResponseEvent::ReadFileRequest(request)) => { - this.update(cx, |this, cx| this.handle_read_file_request(request, cx))? - } - Err(_) => todo!(), - }; - pending_event_handlers.push(task); - } - result = next_event_handler_result => { - // Event handlers should only return errors that are - // unrecoverable and should therefore stop this turn of - // the agentic loop. - result.unwrap()?; - } - } - } - + agent.send_thread_message(id, message, cx).await?; Ok(()) }) } - - fn handle_message_response( - &mut self, - mut message: MessageResponse, - cx: &mut Context, - ) -> Task> { - let entry_id = self.next_entry_id.post_inc(); - self.entries.push(ThreadEntry { - id: entry_id, - content: AgentThreadEntryContent::Message(Message { - role: message.role, - chunks: Vec::new(), - }), - }); - cx.notify(); - - cx.spawn(async move |this, cx| { - while let Some(chunk) = message.chunks.next().await { - match chunk { - Ok(chunk) => { - this.update(cx, |this, cx| { - let ix = this - .entries - .binary_search_by_key(&entry_id, |entry| entry.id) - .map_err(|_| anyhow!("message not found"))?; - let AgentThreadEntryContent::Message(message) = - &mut this.entries[ix].content - else { - unreachable!() - }; - message.chunks.push(chunk); - cx.notify(); - anyhow::Ok(()) - })??; - } - Err(err) => todo!("show error"), - } - } - - Ok(()) - }) - } - - fn handle_read_file_request( - &mut self, - request: ReadFileRequest, - cx: &mut Context, - ) -> Task> { - todo!() - } } #[cfg(test)] @@ -367,6 +240,7 @@ mod tests { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); Project::init_settings(cx); + language::init(cx); }); } @@ -378,11 +252,11 @@ mod tests { let fs = FakeFs::new(cx.executor()); fs.insert_tree( - path!("/test"), + path!("/tmp"), json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), ) .await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs, [path!("/tmp").as_ref()], cx).await; let agent = gemini_agent(project.clone(), cx.to_async()).unwrap(); let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async()) .await @@ -400,7 +274,7 @@ mod tests { Message { role: Role::User, chunks: vec![ - "Read the 'test/foo' file and output all of its contents.".into(), + "Read the '/tmp/foo' file and output all of its contents.".into(), ], }, cx, @@ -413,7 +287,7 @@ mod tests { thread.entries().iter().any(|entry| { entry.content == AgentThreadEntryContent::ReadFile { - path: "test/foo".into(), + path: "/tmp/foo".into(), content: "Lorem ipsum dolor".into(), } }),