@@ -1,14 +1,13 @@
use std::{io::Write as _, path::Path, sync::Arc};
use crate::{
- Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
- Thread, ThreadEntryId, ThreadId,
+ Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, Role, Thread,
+ ThreadEntryId, ThreadId,
};
use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use collections::HashMap;
-use futures::channel::mpsc::UnboundedReceiver;
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use parking_lot::Mutex;
use project::Project;
@@ -31,10 +30,14 @@ struct AcpClientDelegate {
}
impl AcpClientDelegate {
- fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
+ fn new(
+ project: Entity<Project>,
+ threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
+ cx: AsyncApp,
+ ) -> Self {
Self {
project,
- threads: Default::default(),
+ threads,
cx: cx,
}
}
@@ -186,8 +189,9 @@ impl AcpAgent {
let stdin = process.stdin.take().expect("process didn't have stdin");
let stdout = process.stdout.take().expect("process didn't have stdout");
+ let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>> = Default::default();
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
- AcpClientDelegate::new(project.clone(), cx.clone()),
+ AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
stdin,
stdout,
);
@@ -200,7 +204,7 @@ impl AcpAgent {
Self {
project,
connection: Arc::new(connection),
- threads: Default::default(),
+ threads,
_handler_task: cx.foreground_executor().spawn(handler_fut),
_io_task: io_task,
}
@@ -286,15 +290,14 @@ impl Agent for AcpAgent {
thread_id: ThreadId,
message: crate::Message,
cx: &mut AsyncApp,
- ) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
+ ) -> Result<()> {
let thread = self
.threads
.lock()
.get(&thread_id)
.cloned()
.ok_or_else(|| anyhow!("no such thread"))?;
- let response = self
- .connection
+ self.connection
.request(acp::SendMessageParams {
thread_id: thread_id.clone().into(),
message: acp::Message {
@@ -317,7 +320,7 @@ impl Agent for AcpAgent {
},
})
.await?;
- todo!()
+ Ok(())
}
}
@@ -3,15 +3,9 @@ mod acp;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
-use futures::{
- FutureExt, StreamExt,
- channel::{mpsc, oneshot},
- select_biased,
- stream::{BoxStream, FuturesUnordered},
-};
use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
use project::Project;
-use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
+use std::{ops::Range, path::PathBuf, sync::Arc};
#[async_trait(?Send)]
pub trait Agent: 'static {
@@ -28,34 +22,7 @@ pub trait Agent: 'static {
thread_id: ThreadId,
message: Message,
cx: &mut AsyncApp,
- ) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
-}
-
-pub enum ResponseEvent {
- MessageResponse(MessageResponse),
- ReadFileRequest(ReadFileRequest),
- // GlobSearchRequest(SearchRequest),
- // RegexSearchRequest(RegexSearchRequest),
- // RunCommandRequest(RunCommandRequest),
- // WebSearchResponse(WebSearchResponse),
-}
-
-pub struct MessageResponse {
- role: Role,
- chunks: BoxStream<'static, Result<MessageChunk>>,
-}
-
-#[derive(Debug)]
-pub struct ReadFileRequest {
- path: PathBuf,
- range: Range<usize>,
- response_tx: oneshot::Sender<Result<FileContent>>,
-}
-
-impl ReadFileRequest {
- pub fn respond(self, content: Result<FileContent>) {
- self.response_tx.send(content).ok();
- }
+ ) -> Result<()>;
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -250,104 +217,10 @@ impl Thread {
let agent = self.agent.clone();
let id = self.id.clone();
cx.spawn(async move |this, cx| {
- let mut events = agent.send_thread_message(id, message, cx).await?;
- let mut pending_event_handlers = FuturesUnordered::new();
-
- loop {
- let mut next_event_handler_result = pin!(
- async {
- if pending_event_handlers.is_empty() {
- future::pending::<()>().await;
- }
-
- pending_event_handlers.next().await
- }
- .fuse()
- );
-
- select_biased! {
- event = events.next() => {
- let Some(event) = event else {
- while let Some(result) = pending_event_handlers.next().await {
- result?;
- }
-
- break;
- };
-
- let task = match event {
- Ok(ResponseEvent::MessageResponse(message)) => {
- this.update(cx, |this, cx| this.handle_message_response(message, cx))?
- }
- Ok(ResponseEvent::ReadFileRequest(request)) => {
- this.update(cx, |this, cx| this.handle_read_file_request(request, cx))?
- }
- Err(_) => todo!(),
- };
- pending_event_handlers.push(task);
- }
- result = next_event_handler_result => {
- // Event handlers should only return errors that are
- // unrecoverable and should therefore stop this turn of
- // the agentic loop.
- result.unwrap()?;
- }
- }
- }
-
+ agent.send_thread_message(id, message, cx).await?;
Ok(())
})
}
-
- fn handle_message_response(
- &mut self,
- mut message: MessageResponse,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let entry_id = self.next_entry_id.post_inc();
- self.entries.push(ThreadEntry {
- id: entry_id,
- content: AgentThreadEntryContent::Message(Message {
- role: message.role,
- chunks: Vec::new(),
- }),
- });
- cx.notify();
-
- cx.spawn(async move |this, cx| {
- while let Some(chunk) = message.chunks.next().await {
- match chunk {
- Ok(chunk) => {
- this.update(cx, |this, cx| {
- let ix = this
- .entries
- .binary_search_by_key(&entry_id, |entry| entry.id)
- .map_err(|_| anyhow!("message not found"))?;
- let AgentThreadEntryContent::Message(message) =
- &mut this.entries[ix].content
- else {
- unreachable!()
- };
- message.chunks.push(chunk);
- cx.notify();
- anyhow::Ok(())
- })??;
- }
- Err(err) => todo!("show error"),
- }
- }
-
- Ok(())
- })
- }
-
- fn handle_read_file_request(
- &mut self,
- request: ReadFileRequest,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- todo!()
- }
}
#[cfg(test)]
@@ -367,6 +240,7 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
+ language::init(cx);
});
}
@@ -378,11 +252,11 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- path!("/test"),
+ path!("/tmp"),
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
)
.await;
- let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project = Project::test(fs, [path!("/tmp").as_ref()], cx).await;
let agent = gemini_agent(project.clone(), cx.to_async()).unwrap();
let thread_store = ThreadStore::load(Arc::new(agent), project, &mut cx.to_async())
.await
@@ -400,7 +274,7 @@ mod tests {
Message {
role: Role::User,
chunks: vec![
- "Read the 'test/foo' file and output all of its contents.".into(),
+ "Read the '/tmp/foo' file and output all of its contents.".into(),
],
},
cx,
@@ -413,7 +287,7 @@ mod tests {
thread.entries().iter().any(|entry| {
entry.content
== AgentThreadEntryContent::ReadFile {
- path: "test/foo".into(),
+ path: "/tmp/foo".into(),
content: "Lorem ipsum dolor".into(),
}
}),