@@ -10,8 +10,10 @@ use futures::channel::oneshot;
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
use markdown::Markdown;
+use parking_lot::Mutex;
+use parking_lot::Mutex;
use project::Project;
-use std::{mem, ops::Range, path::PathBuf, sync::Arc};
+use std::{mem, ops::Range, path::PathBuf, process::ExitStatus, sync::Arc};
use ui::{App, IconName};
use util::{ResultExt, debug_panic};
@@ -377,13 +379,17 @@ pub struct ThreadEntry {
}
pub struct AcpThread {
- id: ThreadId,
next_entry_id: ThreadEntryId,
entries: Vec<ThreadEntry>,
server: Arc<AcpServer>,
title: SharedString,
project: Entity<Project>,
send_task: Option<Task<()>>,
+
+ connection: Arc<acp::AgentConnection>,
+ exit_status: Arc<Mutex<Option<ExitStatus>>>,
+ _handler_task: Task<()>,
+ _io_task: Task<()>,
}
enum AcpThreadEvent {
@@ -403,7 +409,6 @@ impl EventEmitter<AcpThreadEvent> for AcpThread {}
impl AcpThread {
pub fn new(
server: Arc<AcpServer>,
- thread_id: ThreadId,
entries: Vec<AgentThreadEntryContent>,
project: Entity<Project>,
_: &mut Context<Self>,
@@ -419,7 +424,6 @@ impl AcpThread {
})
.collect(),
server,
- id: thread_id,
next_entry_id,
project,
send_task: None,
@@ -680,7 +684,6 @@ impl AcpThread {
cx: &mut Context<Self>,
) -> impl use<> + Future<Output = Result<()>> {
let agent = self.server.clone();
- let id = self.id.clone();
let chunk =
UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
let message = UserMessage {
@@ -695,7 +698,7 @@ impl AcpThread {
self.send_task = Some(cx.spawn(async move |this, cx| {
cancel.await.log_err();
- let result = agent.send_message(id, acp_message, cx).await;
+ let result = agent.send_message(acp_message, cx).await;
tx.send(result).log_err();
this.update(cx, |this, _cx| this.send_task.take()).log_err();
}));
@@ -710,11 +713,10 @@ impl AcpThread {
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent = self.server.clone();
- let id = self.id.clone();
if self.send_task.take().is_some() {
cx.spawn(async move |this, cx| {
- agent.cancel_send_message(id, cx).await?;
+ agent.cancel_send_message(cx).await?;
this.update(cx, |this, _cx| {
for entry in this.entries.iter_mut() {
@@ -851,7 +853,6 @@ mod tests {
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp::StreamAssistantMessageChunkParams {
- thread_id: params.thread_id.clone(),
chunk: acp::AssistantMessageChunk::Thought {
chunk: "Thinking ".into(),
},
@@ -862,7 +863,6 @@ mod tests {
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp::StreamAssistantMessageChunkParams {
- thread_id: params.thread_id,
chunk: acp::AssistantMessageChunk::Thought {
chunk: "hard!".into(),
},
@@ -1151,10 +1151,11 @@ mod tests {
pub fn fake_acp_server(
project: Entity<Project>,
cx: &mut TestAppContext,
- ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
+ ) -> (Entity<Thread>, Arc<AcpServer>, Entity<FakeAcpServer>) {
let (stdin_tx, stdin_rx) = async_pipe::pipe();
let (stdout_tx, stdout_rx) = async_pipe::pipe();
let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
+ let thread = server.thread.upgrade().unwrap();
let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
(server, agent)
}
@@ -1199,15 +1200,6 @@ mod tests {
Ok(acp::AuthenticateResponse)
}
- async fn create_thread(
- &self,
- _request: acp::CreateThreadParams,
- ) -> Result<acp::CreateThreadResponse> {
- Ok(acp::CreateThreadResponse {
- thread_id: acp::ThreadId("test-thread".into()),
- })
- }
-
async fn send_user_message(
&self,
request: acp::SendUserMessageParams,
@@ -1,9 +1,8 @@
-use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
+use crate::{AcpThread, ThreadEntryId, ToolCallId, ToolCallRequest};
use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
-use collections::HashMap;
-use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use parking_lot::Mutex;
use project::Project;
use smol::process::Child;
@@ -11,37 +10,23 @@ use std::{process::ExitStatus, sync::Arc};
use util::ResultExt;
pub struct AcpServer {
- connection: Arc<acp::AgentConnection>,
- threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
+ thread: WeakEntity<AcpThread>,
project: Entity<Project>,
+ connection: Arc<acp::AgentConnection>,
exit_status: Arc<Mutex<Option<ExitStatus>>>,
_handler_task: Task<()>,
_io_task: Task<()>,
}
struct AcpClientDelegate {
- threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
+ thread: WeakEntity<AcpThread>,
cx: AsyncApp,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl AcpClientDelegate {
- fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
- Self { threads, cx: cx }
- }
-
- fn update_thread<R>(
- &self,
- thread_id: &ThreadId,
- cx: &mut App,
- callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
- ) -> Option<R> {
- let thread = self.threads.lock().get(&thread_id)?.clone();
- let Some(thread) = thread.upgrade() else {
- self.threads.lock().remove(&thread_id);
- return None;
- };
- Some(thread.update(cx, callback))
+ fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
+ Self { thread, cx }
}
}
@@ -54,7 +39,7 @@ impl acp::Client for AcpClientDelegate {
let cx = &mut self.cx.clone();
cx.update(|cx| {
- self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
+ self.thread.update(cx, |thread, cx| {
thread.push_assistant_chunk(params.chunk, cx)
});
})?;
@@ -69,7 +54,7 @@ impl acp::Client for AcpClientDelegate {
let cx = &mut self.cx.clone();
let ToolCallRequest { id, outcome } = cx
.update(|cx| {
- self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+ self.thread.update(cx, |thread, cx| {
thread.request_tool_call(
request.label,
request.icon,
@@ -94,7 +79,7 @@ impl acp::Client for AcpClientDelegate {
let cx = &mut self.cx.clone();
let entry_id = cx
.update(|cx| {
- self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+ self.thread.update(cx, |thread, cx| {
thread.push_tool_call(request.label, request.icon, request.content, cx)
})
})?
@@ -112,7 +97,7 @@ impl acp::Client for AcpClientDelegate {
let cx = &mut self.cx.clone();
cx.update(|cx| {
- self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+ self.thread.update(cx, |thread, cx| {
thread.update_tool_call(
request.tool_call_id.into(),
request.status,
@@ -132,31 +117,42 @@ impl AcpServer {
let stdin = process.stdin.take().expect("process didn't have stdin");
let stdout = process.stdout.take().expect("process didn't have stdout");
- let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
- let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
- AcpClientDelegate::new(threads.clone(), cx.to_async()),
- stdin,
- stdout,
- );
-
- let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
- let io_task = cx.background_spawn({
- let exit_status = exit_status.clone();
- async move {
- io_fut.await.log_err();
- let result = process.status().await.log_err();
- *exit_status.lock() = result;
- }
+ let mut connection = None;
+ cx.new(|cx| {
+ let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
+ AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
+ stdin,
+ stdout,
+ );
+
+ let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
+ let io_task = cx.background_spawn({
+ let exit_status = exit_status.clone();
+ async move {
+ io_fut.await.log_err();
+ let result = process.status().await.log_err();
+ *exit_status.lock() = result;
+ }
+ });
+
+ connection.replace(Arc::new(Self {
+ project: project.clone(),
+ connection: Arc::new(conn),
+ thread: cx.entity().downgrade(),
+ exit_status,
+ _handler_task: cx.foreground_executor().spawn(handler_fut),
+ _io_task: io_task,
+ }));
+
+ AcpThread::new(
+ connection.clone().unwrap(),
+ Vec::default(),
+ project.clone(),
+ cx,
+ )
});
- Arc::new(Self {
- project,
- connection: Arc::new(connection),
- threads,
- exit_status,
- _handler_task: cx.foreground_executor().spawn(handler_fut),
- _io_task: io_task,
- })
+ connection.unwrap()
}
#[cfg(test)]
@@ -166,29 +162,40 @@ impl AcpServer {
project: Entity<Project>,
cx: &mut App,
) -> Arc<Self> {
- let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
- let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
- AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
- stdin,
- stdout,
- );
-
- let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
- let io_task = cx.background_spawn({
- async move {
- io_fut.await.log_err();
- // todo!() exit status?
- }
+ let mut connection = None;
+ cx.new(|cx| {
+ let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
+ AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
+ stdin,
+ stdout,
+ );
+
+ let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
+ let io_task = cx.background_spawn({
+ async move {
+ io_fut.await.log_err();
+ // todo!() exit status?
+ }
+ });
+
+ connection.replace(Arc::new(Self {
+ project: project.clone(),
+ connection: Arc::new(conn),
+ thread: cx.entity().downgrade(),
+ exit_status,
+ _handler_task: cx.foreground_executor().spawn(handler_fut),
+ _io_task: io_task,
+ }));
+
+ AcpThread::new(
+ connection.clone().unwrap(),
+ Vec::default(),
+ project.clone(),
+ cx,
+ )
});
- Arc::new(Self {
- project,
- connection: Arc::new(connection),
- threads,
- exit_status,
- _handler_task: cx.foreground_executor().spawn(handler_fut),
- _io_task: io_task,
- })
+ connection.unwrap()
}
pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
@@ -207,49 +214,17 @@ impl AcpServer {
Ok(())
}
- pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
- let response = self
- .connection
- .request(acp::CreateThreadParams)
- .await
- .map_err(to_anyhow)?;
-
- let thread_id: ThreadId = response.thread_id.into();
- let server = self.clone();
- let thread = cx.new(|cx| {
- AcpThread::new(
- server,
- thread_id.clone(),
- Vec::default(),
- self.project.clone(),
- cx,
- )
- })?;
- self.threads.lock().insert(thread_id, thread.downgrade());
- Ok(thread)
- }
-
- pub async fn send_message(
- &self,
- thread_id: ThreadId,
- message: acp::UserMessage,
- _cx: &mut AsyncApp,
- ) -> Result<()> {
+ pub async fn send_message(&self, message: acp::UserMessage, _cx: &mut AsyncApp) -> Result<()> {
self.connection
- .request(acp::SendUserMessageParams {
- thread_id: thread_id.clone().into(),
- message,
- })
+ .request(acp::SendUserMessageParams { message })
.await
.map_err(to_anyhow)?;
Ok(())
}
- pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
+ pub async fn cancel_send_message(&self, _cx: &mut AsyncApp) -> Result<()> {
self.connection
- .request(acp::CancelSendMessageParams {
- thread_id: thread_id.clone().into(),
- })
+ .request(acp::CancelSendMessageParams)
.await
.map_err(to_anyhow)?;
Ok(())
@@ -270,18 +245,6 @@ fn to_anyhow(e: acp::Error) -> anyhow::Error {
anyhow::anyhow!(e.message)
}
-impl From<acp::ThreadId> for ThreadId {
- fn from(thread_id: acp::ThreadId) -> Self {
- Self(thread_id.0.into())
- }
-}
-
-impl From<ThreadId> for acp::ThreadId {
- fn from(thread_id: ThreadId) -> Self {
- acp::ThreadId(thread_id.0.to_string())
- }
-}
-
impl From<acp::ToolCallId> for ToolCallId {
fn from(tool_call_id: acp::ToolCallId) -> Self {
Self(ThreadEntryId(tool_call_id.0))