diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index 7f75a3ea046b3165aa17b6d43b5402efef12cd04..34dfd15fc978805259b7bd9368b7a11505031a02 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -10,8 +10,10 @@ use futures::channel::oneshot; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _}; use markdown::Markdown; +use parking_lot::Mutex; +use parking_lot::Mutex; use project::Project; -use std::{mem, ops::Range, path::PathBuf, sync::Arc}; +use std::{mem, ops::Range, path::PathBuf, process::ExitStatus, sync::Arc}; use ui::{App, IconName}; use util::{ResultExt, debug_panic}; @@ -377,13 +379,17 @@ pub struct ThreadEntry { } pub struct AcpThread { - id: ThreadId, next_entry_id: ThreadEntryId, entries: Vec, server: Arc, title: SharedString, project: Entity, send_task: Option>, + + connection: Arc, + exit_status: Arc>>, + _handler_task: Task<()>, + _io_task: Task<()>, } enum AcpThreadEvent { @@ -403,7 +409,6 @@ impl EventEmitter for AcpThread {} impl AcpThread { pub fn new( server: Arc, - thread_id: ThreadId, entries: Vec, project: Entity, _: &mut Context, @@ -419,7 +424,6 @@ impl AcpThread { }) .collect(), server, - id: thread_id, next_entry_id, project, send_task: None, @@ -680,7 +684,6 @@ impl AcpThread { cx: &mut Context, ) -> impl use<> + Future> { let agent = self.server.clone(); - let id = self.id.clone(); let chunk = UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx); let message = UserMessage { @@ -695,7 +698,7 @@ impl AcpThread { self.send_task = Some(cx.spawn(async move |this, cx| { cancel.await.log_err(); - let result = agent.send_message(id, acp_message, cx).await; + let result = agent.send_message(acp_message, cx).await; tx.send(result).log_err(); this.update(cx, |this, _cx| this.send_task.take()).log_err(); })); @@ -710,11 +713,10 @@ impl AcpThread { pub fn cancel(&mut self, cx: &mut Context) -> Task> { let agent = self.server.clone(); - let id = self.id.clone(); if self.send_task.take().is_some() { cx.spawn(async move |this, cx| { - agent.cancel_send_message(id, cx).await?; + agent.cancel_send_message(cx).await?; this.update(cx, |this, _cx| { for entry in this.entries.iter_mut() { @@ -851,7 +853,6 @@ mod tests { server .update(&mut cx, |server, _| { server.send_to_zed(acp::StreamAssistantMessageChunkParams { - thread_id: params.thread_id.clone(), chunk: acp::AssistantMessageChunk::Thought { chunk: "Thinking ".into(), }, @@ -862,7 +863,6 @@ mod tests { server .update(&mut cx, |server, _| { server.send_to_zed(acp::StreamAssistantMessageChunkParams { - thread_id: params.thread_id, chunk: acp::AssistantMessageChunk::Thought { chunk: "hard!".into(), }, @@ -1151,10 +1151,11 @@ mod tests { pub fn fake_acp_server( project: Entity, cx: &mut TestAppContext, - ) -> (Arc, Entity) { + ) -> (Entity, Arc, Entity) { let (stdin_tx, stdin_rx) = async_pipe::pipe(); let (stdout_tx, stdout_rx) = async_pipe::pipe(); let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx)); + let thread = server.thread.upgrade().unwrap(); let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); (server, agent) } @@ -1199,15 +1200,6 @@ mod tests { Ok(acp::AuthenticateResponse) } - async fn create_thread( - &self, - _request: acp::CreateThreadParams, - ) -> Result { - Ok(acp::CreateThreadResponse { - thread_id: acp::ThreadId("test-thread".into()), - }) - } - async fn send_user_message( &self, request: acp::SendUserMessageParams, diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index c826f2ff39fc99a628ed2561c8b6435c02cb1983..602323c970656b4979f8035d8c190576af689f56 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -1,9 +1,8 @@ -use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest}; +use crate::{AcpThread, ThreadEntryId, ToolCallId, ToolCallRequest}; use agentic_coding_protocol as acp; use anyhow::{Context as _, Result}; use async_trait::async_trait; -use collections::HashMap; -use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use parking_lot::Mutex; use project::Project; use smol::process::Child; @@ -11,37 +10,23 @@ use std::{process::ExitStatus, sync::Arc}; use util::ResultExt; pub struct AcpServer { - connection: Arc, - threads: Arc>>>, + thread: WeakEntity, project: Entity, + connection: Arc, exit_status: Arc>>, _handler_task: Task<()>, _io_task: Task<()>, } struct AcpClientDelegate { - threads: Arc>>>, + thread: WeakEntity, cx: AsyncApp, // sent_buffer_versions: HashMap, HashMap>, } impl AcpClientDelegate { - fn new(threads: Arc>>>, cx: AsyncApp) -> Self { - Self { threads, cx: cx } - } - - fn update_thread( - &self, - thread_id: &ThreadId, - cx: &mut App, - callback: impl FnOnce(&mut AcpThread, &mut Context) -> R, - ) -> Option { - let thread = self.threads.lock().get(&thread_id)?.clone(); - let Some(thread) = thread.upgrade() else { - self.threads.lock().remove(&thread_id); - return None; - }; - Some(thread.update(cx, callback)) + fn new(thread: WeakEntity, cx: AsyncApp) -> Self { + Self { thread, cx } } } @@ -54,7 +39,7 @@ impl acp::Client for AcpClientDelegate { let cx = &mut self.cx.clone(); cx.update(|cx| { - self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| { + self.thread.update(cx, |thread, cx| { thread.push_assistant_chunk(params.chunk, cx) }); })?; @@ -69,7 +54,7 @@ impl acp::Client for AcpClientDelegate { let cx = &mut self.cx.clone(); let ToolCallRequest { id, outcome } = cx .update(|cx| { - self.update_thread(&request.thread_id.into(), cx, |thread, cx| { + self.thread.update(cx, |thread, cx| { thread.request_tool_call( request.label, request.icon, @@ -94,7 +79,7 @@ impl acp::Client for AcpClientDelegate { let cx = &mut self.cx.clone(); let entry_id = cx .update(|cx| { - self.update_thread(&request.thread_id.into(), cx, |thread, cx| { + self.thread.update(cx, |thread, cx| { thread.push_tool_call(request.label, request.icon, request.content, cx) }) })? @@ -112,7 +97,7 @@ impl acp::Client for AcpClientDelegate { let cx = &mut self.cx.clone(); cx.update(|cx| { - self.update_thread(&request.thread_id.into(), cx, |thread, cx| { + self.thread.update(cx, |thread, cx| { thread.update_tool_call( request.tool_call_id.into(), request.status, @@ -132,31 +117,42 @@ impl AcpServer { let stdin = process.stdin.take().expect("process didn't have stdin"); let stdout = process.stdout.take().expect("process didn't have stdout"); - let threads: Arc>>> = Default::default(); - let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(threads.clone(), cx.to_async()), - stdin, - stdout, - ); - - let exit_status: Arc>> = Default::default(); - let io_task = cx.background_spawn({ - let exit_status = exit_status.clone(); - async move { - io_fut.await.log_err(); - let result = process.status().await.log_err(); - *exit_status.lock() = result; - } + let mut connection = None; + cx.new(|cx| { + let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin, + stdout, + ); + + let exit_status: Arc>> = Default::default(); + let io_task = cx.background_spawn({ + let exit_status = exit_status.clone(); + async move { + io_fut.await.log_err(); + let result = process.status().await.log_err(); + *exit_status.lock() = result; + } + }); + + connection.replace(Arc::new(Self { + project: project.clone(), + connection: Arc::new(conn), + thread: cx.entity().downgrade(), + exit_status, + _handler_task: cx.foreground_executor().spawn(handler_fut), + _io_task: io_task, + })); + + AcpThread::new( + connection.clone().unwrap(), + Vec::default(), + project.clone(), + cx, + ) }); - Arc::new(Self { - project, - connection: Arc::new(connection), - threads, - exit_status, - _handler_task: cx.foreground_executor().spawn(handler_fut), - _io_task: io_task, - }) + connection.unwrap() } #[cfg(test)] @@ -166,29 +162,40 @@ impl AcpServer { project: Entity, cx: &mut App, ) -> Arc { - let threads: Arc>>> = Default::default(); - let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()), - stdin, - stdout, - ); - - let exit_status: Arc>> = Default::default(); - let io_task = cx.background_spawn({ - async move { - io_fut.await.log_err(); - // todo!() exit status? - } + let mut connection = None; + cx.new(|cx| { + let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin, + stdout, + ); + + let exit_status: Arc>> = Default::default(); + let io_task = cx.background_spawn({ + async move { + io_fut.await.log_err(); + // todo!() exit status? + } + }); + + connection.replace(Arc::new(Self { + project: project.clone(), + connection: Arc::new(conn), + thread: cx.entity().downgrade(), + exit_status, + _handler_task: cx.foreground_executor().spawn(handler_fut), + _io_task: io_task, + })); + + AcpThread::new( + connection.clone().unwrap(), + Vec::default(), + project.clone(), + cx, + ) }); - Arc::new(Self { - project, - connection: Arc::new(connection), - threads, - exit_status, - _handler_task: cx.foreground_executor().spawn(handler_fut), - _io_task: io_task, - }) + connection.unwrap() } pub async fn initialize(&self) -> Result { @@ -207,49 +214,17 @@ impl AcpServer { Ok(()) } - pub async fn create_thread(self: Arc, cx: &mut AsyncApp) -> Result> { - let response = self - .connection - .request(acp::CreateThreadParams) - .await - .map_err(to_anyhow)?; - - let thread_id: ThreadId = response.thread_id.into(); - let server = self.clone(); - let thread = cx.new(|cx| { - AcpThread::new( - server, - thread_id.clone(), - Vec::default(), - self.project.clone(), - cx, - ) - })?; - self.threads.lock().insert(thread_id, thread.downgrade()); - Ok(thread) - } - - pub async fn send_message( - &self, - thread_id: ThreadId, - message: acp::UserMessage, - _cx: &mut AsyncApp, - ) -> Result<()> { + pub async fn send_message(&self, message: acp::UserMessage, _cx: &mut AsyncApp) -> Result<()> { self.connection - .request(acp::SendUserMessageParams { - thread_id: thread_id.clone().into(), - message, - }) + .request(acp::SendUserMessageParams { message }) .await .map_err(to_anyhow)?; Ok(()) } - pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> { + pub async fn cancel_send_message(&self, _cx: &mut AsyncApp) -> Result<()> { self.connection - .request(acp::CancelSendMessageParams { - thread_id: thread_id.clone().into(), - }) + .request(acp::CancelSendMessageParams) .await .map_err(to_anyhow)?; Ok(()) @@ -270,18 +245,6 @@ fn to_anyhow(e: acp::Error) -> anyhow::Error { anyhow::anyhow!(e.message) } -impl From for ThreadId { - fn from(thread_id: acp::ThreadId) -> Self { - Self(thread_id.0.into()) - } -} - -impl From for acp::ThreadId { - fn from(thread_id: ThreadId) -> Self { - acp::ThreadId(thread_id.0.to_string()) - } -} - impl From for ToolCallId { fn from(tool_call_id: acp::ToolCallId) -> Self { Self(ThreadEntryId(tool_call_id.0))