@@ -1,19 +1,15 @@
-use std::{
- io::{Cursor, Write as _},
- path::Path,
- sync::{Arc, Weak},
-};
+use std::{io::Write as _, path::Path, sync::Arc};
use crate::{
- Agent, AgentThread, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk,
- ResponseEvent, Role, Thread, ThreadEntry, ThreadId,
+ Agent, AgentThreadEntryContent, AgentThreadSummary, Message, MessageChunk, ResponseEvent, Role,
+ Thread, ThreadEntryId, ThreadId,
};
-use agentic_coding_protocol::{self as acp, TurnId};
-use anyhow::{Context as _, Result};
+use agentic_coding_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use collections::HashMap;
use futures::channel::mpsc::UnboundedReceiver;
-use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use parking_lot::Mutex;
use project::Project;
use smol::process::Child;
@@ -21,18 +17,43 @@ use util::ResultExt;
pub struct AcpAgent {
connection: Arc<acp::AgentConnection>,
- threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
+ threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
+ project: Entity<Project>,
_handler_task: Task<()>,
_io_task: Task<()>,
}
struct AcpClientDelegate {
project: Entity<Project>,
- threads: Arc<Mutex<HashMap<acp::ThreadId, WeakEntity<Thread>>>>,
+ threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<Thread>>>>,
cx: AsyncApp,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
+impl AcpClientDelegate {
+ fn new(project: Entity<Project>, cx: AsyncApp) -> Self {
+ Self {
+ project,
+ threads: Default::default(),
+ cx: cx,
+ }
+ }
+
+ fn update_thread<R>(
+ &self,
+ thread_id: &ThreadId,
+ cx: &mut App,
+ callback: impl FnMut(&mut Thread, &mut Context<Thread>) -> R,
+ ) -> Option<R> {
+ let thread = self.threads.lock().get(&thread_id)?.clone();
+ let Some(thread) = thread.upgrade() else {
+ self.threads.lock().remove(&thread_id);
+ return None;
+ };
+ Some(thread.update(cx, callback))
+ }
+}
+
#[async_trait(?Send)]
impl acp::Client for AcpClientDelegate {
async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
@@ -58,7 +79,7 @@ impl acp::Client for AcpClientDelegate {
async fn stream_message_chunk(
&self,
- request: acp::StreamMessageChunkParams,
+ chunk: acp::StreamMessageChunkParams,
) -> Result<acp::StreamMessageChunkResponse> {
Ok(acp::StreamMessageChunkResponse)
}
@@ -78,25 +99,23 @@ impl acp::Client for AcpClientDelegate {
})??
.await?;
- buffer.update(cx, |buffer, _| {
+ buffer.update(cx, |buffer, cx| {
let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
let end = match request.line_limit {
None => buffer.max_point(),
Some(limit) => start + language::Point::new(limit + 1, 0),
};
- let content = buffer.text_for_range(start..end).collect();
-
- if let Some(thread) = self.threads.lock().get(&request.thread_id) {
- thread.update(cx, |thread, cx| {
- thread.push_entry(ThreadEntry {
- content: AgentThreadEntryContent::ReadFile {
- path: request.path.clone(),
- content: content.clone(),
- },
- });
- })
- }
+ let content: String = buffer.text_for_range(start..end).collect();
+ self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+ thread.push_entry(
+ AgentThreadEntryContent::ReadFile {
+ path: request.path.clone(),
+ content: content.clone(),
+ },
+ cx,
+ );
+ });
acp::ReadTextFileResponse {
content,
@@ -135,7 +154,7 @@ impl acp::Client for AcpClientDelegate {
let mut base64_content = Vec::new();
let mut base64_encoder = base64::write::EncoderWriter::new(
- Cursor::new(&mut base64_content),
+ std::io::Cursor::new(&mut base64_content),
&base64::engine::general_purpose::STANDARD,
);
base64_encoder.write_all(range_content)?;
@@ -168,10 +187,7 @@ impl AcpAgent {
let stdout = process.stdout.take().expect("process didn't have stdout");
let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
- AcpClientDelegate {
- project,
- cx: cx.clone(),
- },
+ AcpClientDelegate::new(project.clone(), cx.clone()),
stdin,
stdout,
);
@@ -182,17 +198,18 @@ impl AcpAgent {
});
Self {
+ project,
connection: Arc::new(connection),
- threads: Mutex::default(),
+ threads: Default::default(),
_handler_task: cx.foreground_executor().spawn(handler_fut),
_io_task: io_task,
}
}
}
-#[async_trait]
+#[async_trait(?Send)]
impl Agent for AcpAgent {
- async fn threads(&self) -> Result<Vec<AgentThreadSummary>> {
+ async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>> {
let response = self.connection.request(acp::GetThreadsParams).await?;
response
.threads
@@ -207,31 +224,34 @@ impl Agent for AcpAgent {
.collect()
}
- async fn create_thread(&self) -> Result<Arc<Self::Thread>> {
+ async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
let response = self.connection.request(acp::CreateThreadParams).await?;
- let thread = Arc::new(AcpAgentThread {
- id: response.thread_id.clone(),
- connection: self.connection.clone(),
- state: Mutex::new(AcpAgentThreadState {
- turn: None,
- next_turn_id: TurnId::default(),
- }),
- });
- self.threads
- .lock()
- .insert(response.thread_id, Arc::downgrade(&thread));
+ let thread_id: ThreadId = response.thread_id.into();
+ let agent = self.clone();
+ let thread = cx.new(|_| Thread {
+ id: thread_id.clone(),
+ next_entry_id: ThreadEntryId(0),
+ entries: Vec::default(),
+ project: self.project.clone(),
+ agent,
+ })?;
+ self.threads.lock().insert(thread_id, thread.downgrade());
Ok(thread)
}
- async fn open_thread(&self, id: ThreadId) -> Result<Thread> {
+ async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>> {
todo!()
}
- async fn thread_entries(&self, thread_id: ThreadId) -> Result<Vec<AgentThreadEntryContent>> {
+ async fn thread_entries(
+ &self,
+ thread_id: ThreadId,
+ cx: &mut AsyncApp,
+ ) -> Result<Vec<AgentThreadEntryContent>> {
let response = self
.connection
.request(acp::GetThreadEntriesParams {
- thread_id: self.id.clone(),
+ thread_id: thread_id.clone().into(),
})
.await?;
@@ -265,18 +285,18 @@ impl Agent for AcpAgent {
&self,
thread_id: ThreadId,
message: crate::Message,
+ cx: &mut AsyncApp,
) -> Result<UnboundedReceiver<Result<ResponseEvent>>> {
- let turn_id = {
- let mut state = self.state.lock();
- let turn_id = state.next_turn_id.post_inc();
- state.turn = Some(AcpAgentThreadTurn { id: turn_id });
- turn_id
- };
+ let thread = self
+ .threads
+ .lock()
+ .get(&thread_id)
+ .cloned()
+ .ok_or_else(|| anyhow!("no such thread"))?;
let response = self
.connection
.request(acp::SendMessageParams {
- thread_id: self.id.clone(),
- turn_id,
+ thread_id: thread_id.clone().into(),
message: acp::Message {
role: match message.role {
Role::User => acp::Role::User,
@@ -301,29 +321,14 @@ impl Agent for AcpAgent {
}
}
-pub struct AcpAgentThread {
- id: acp::ThreadId,
- connection: Arc<acp::AgentConnection>,
- state: Mutex<AcpAgentThreadState>,
-}
-
-struct AcpAgentThreadState {
- next_turn_id: acp::TurnId,
- turn: Option<AcpAgentThreadTurn>,
-}
-
-struct AcpAgentThreadTurn {
- id: acp::TurnId,
-}
-
impl From<acp::ThreadId> for ThreadId {
fn from(thread_id: acp::ThreadId) -> Self {
- Self(thread_id.0)
+ Self(thread_id.0.into())
}
}
impl From<ThreadId> for acp::ThreadId {
fn from(thread_id: ThreadId) -> Self {
- acp::ThreadId(thread_id.0)
+ acp::ThreadId(thread_id.0.to_string())
}
}
@@ -13,16 +13,21 @@ use gpui::{AppContext, AsyncApp, Context, Entity, SharedString, Task};
use project::Project;
use std::{future, ops::Range, path::PathBuf, pin::pin, sync::Arc};
-#[async_trait]
+#[async_trait(?Send)]
pub trait Agent: 'static {
- async fn threads(&self) -> Result<Vec<AgentThreadSummary>>;
- async fn create_thread(&self) -> Result<Entity<Thread>>;
- async fn open_thread(&self, id: ThreadId) -> Result<Entity<Thread>>;
- async fn thread_entries(&self, id: ThreadId) -> Result<Vec<AgentThreadEntryContent>>;
+ async fn threads(&self, cx: &mut AsyncApp) -> Result<Vec<AgentThreadSummary>>;
+ async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
+ async fn open_thread(&self, id: ThreadId, cx: &mut AsyncApp) -> Result<Entity<Thread>>;
+ async fn thread_entries(
+ &self,
+ id: ThreadId,
+ cx: &mut AsyncApp,
+ ) -> Result<Vec<AgentThreadEntryContent>>;
async fn send_thread_message(
&self,
thread_id: ThreadId,
message: Message,
+ cx: &mut AsyncApp,
) -> Result<mpsc::UnboundedReceiver<Result<ResponseEvent>>>;
}
@@ -53,7 +58,7 @@ impl ReadFileRequest {
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -145,20 +150,20 @@ pub struct ThreadEntry {
pub content: AgentThreadEntryContent,
}
-pub struct ThreadStore<T: Agent> {
+pub struct ThreadStore {
threads: Vec<AgentThreadSummary>,
- agent: Arc<T>,
+ agent: Arc<dyn Agent>,
project: Entity<Project>,
}
-impl<T: Agent> ThreadStore<T> {
+impl ThreadStore {
pub async fn load(
- agent: Arc<T>,
+ agent: Arc<dyn Agent>,
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Result<Entity<Self>> {
- let threads = agent.threads().await?;
- cx.new(|cx| Self {
+ let threads = agent.threads(cx).await?;
+ cx.new(|_cx| Self {
threads,
agent,
project,
@@ -177,21 +182,13 @@ impl<T: Agent> ThreadStore<T> {
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> {
let agent = self.agent.clone();
- let project = self.project.clone();
- cx.spawn(async move |_, cx| {
- let agent_thread = agent.open_thread(id).await?;
- Thread::load(agent_thread, project, cx).await
- })
+ cx.spawn(async move |_, cx| agent.open_thread(id, cx).await)
}
/// Creates a new thread.
pub fn create_thread(&self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
let agent = self.agent.clone();
- let project = self.project.clone();
- cx.spawn(async move |_, cx| {
- let agent_thread = agent.create_thread().await?;
- Thread::load(agent_thread, project, cx).await
- })
+ cx.spawn(async move |_, cx| agent.create_thread(cx).await)
}
}
@@ -210,7 +207,7 @@ impl Thread {
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Result<Entity<Self>> {
- let entries = agent.thread_entries(thread_id.clone()).await?;
+ let entries = agent.thread_entries(thread_id.clone(), cx).await?;
cx.new(|cx| Self::new(agent, thread_id, entries, project, cx))
}
@@ -241,11 +238,19 @@ impl Thread {
&self.entries
}
+ pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
+ self.entries.push(ThreadEntry {
+ id: self.next_entry_id.post_inc(),
+ content: entry,
+ });
+ cx.notify();
+ }
+
pub fn send(&mut self, message: Message, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent = self.agent.clone();
- let id = self.id;
+ let id = self.id.clone();
cx.spawn(async move |this, cx| {
- let mut events = agent.send_thread_message(id, message).await?;
+ let mut events = agent.send_thread_message(id, message, cx).await?;
let mut pending_event_handlers = FuturesUnordered::new();
loop {