TEMP

Conrad Irwin created

Change summary

crates/acp/src/acp.rs    |  32 ++----
crates/acp/src/server.rs | 199 +++++++++++++++++------------------------
2 files changed, 93 insertions(+), 138 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -10,8 +10,10 @@ use futures::channel::oneshot;
 use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
 use markdown::Markdown;
+use parking_lot::Mutex;
+use parking_lot::Mutex;
 use project::Project;
-use std::{mem, ops::Range, path::PathBuf, sync::Arc};
+use std::{mem, ops::Range, path::PathBuf, process::ExitStatus, sync::Arc};
 use ui::{App, IconName};
 use util::{ResultExt, debug_panic};
 
@@ -377,13 +379,17 @@ pub struct ThreadEntry {
 }
 
 pub struct AcpThread {
-    id: ThreadId,
     next_entry_id: ThreadEntryId,
     entries: Vec<ThreadEntry>,
     server: Arc<AcpServer>,
     title: SharedString,
     project: Entity<Project>,
     send_task: Option<Task<()>>,
+
+    connection: Arc<acp::AgentConnection>,
+    exit_status: Arc<Mutex<Option<ExitStatus>>>,
+    _handler_task: Task<()>,
+    _io_task: Task<()>,
 }
 
 enum AcpThreadEvent {
@@ -403,7 +409,6 @@ impl EventEmitter<AcpThreadEvent> for AcpThread {}
 impl AcpThread {
     pub fn new(
         server: Arc<AcpServer>,
-        thread_id: ThreadId,
         entries: Vec<AgentThreadEntryContent>,
         project: Entity<Project>,
         _: &mut Context<Self>,
@@ -419,7 +424,6 @@ impl AcpThread {
                 })
                 .collect(),
             server,
-            id: thread_id,
             next_entry_id,
             project,
             send_task: None,
@@ -680,7 +684,6 @@ impl AcpThread {
         cx: &mut Context<Self>,
     ) -> impl use<> + Future<Output = Result<()>> {
         let agent = self.server.clone();
-        let id = self.id.clone();
         let chunk =
             UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
         let message = UserMessage {
@@ -695,7 +698,7 @@ impl AcpThread {
         self.send_task = Some(cx.spawn(async move |this, cx| {
             cancel.await.log_err();
 
-            let result = agent.send_message(id, acp_message, cx).await;
+            let result = agent.send_message(acp_message, cx).await;
             tx.send(result).log_err();
             this.update(cx, |this, _cx| this.send_task.take()).log_err();
         }));
@@ -710,11 +713,10 @@ impl AcpThread {
 
     pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
         let agent = self.server.clone();
-        let id = self.id.clone();
 
         if self.send_task.take().is_some() {
             cx.spawn(async move |this, cx| {
-                agent.cancel_send_message(id, cx).await?;
+                agent.cancel_send_message(cx).await?;
 
                 this.update(cx, |this, _cx| {
                     for entry in this.entries.iter_mut() {
@@ -851,7 +853,6 @@ mod tests {
                 server
                     .update(&mut cx, |server, _| {
                         server.send_to_zed(acp::StreamAssistantMessageChunkParams {
-                            thread_id: params.thread_id.clone(),
                             chunk: acp::AssistantMessageChunk::Thought {
                                 chunk: "Thinking ".into(),
                             },
@@ -862,7 +863,6 @@ mod tests {
                 server
                     .update(&mut cx, |server, _| {
                         server.send_to_zed(acp::StreamAssistantMessageChunkParams {
-                            thread_id: params.thread_id,
                             chunk: acp::AssistantMessageChunk::Thought {
                                 chunk: "hard!".into(),
                             },
@@ -1151,10 +1151,11 @@ mod tests {
     pub fn fake_acp_server(
         project: Entity<Project>,
         cx: &mut TestAppContext,
-    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
+    ) -> (Entity<Thread>, Arc<AcpServer>, Entity<FakeAcpServer>) {
         let (stdin_tx, stdin_rx) = async_pipe::pipe();
         let (stdout_tx, stdout_rx) = async_pipe::pipe();
         let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
+        let thread = server.thread.upgrade().unwrap();
         let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
         (server, agent)
     }
@@ -1199,15 +1200,6 @@ mod tests {
             Ok(acp::AuthenticateResponse)
         }
 
-        async fn create_thread(
-            &self,
-            _request: acp::CreateThreadParams,
-        ) -> Result<acp::CreateThreadResponse> {
-            Ok(acp::CreateThreadResponse {
-                thread_id: acp::ThreadId("test-thread".into()),
-            })
-        }
-
         async fn send_user_message(
             &self,
             request: acp::SendUserMessageParams,

crates/acp/src/server.rs 🔗

@@ -1,9 +1,8 @@
-use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
+use crate::{AcpThread, ThreadEntryId, ToolCallId, ToolCallRequest};
 use agentic_coding_protocol as acp;
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
-use collections::HashMap;
-use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use parking_lot::Mutex;
 use project::Project;
 use smol::process::Child;
@@ -11,37 +10,23 @@ use std::{process::ExitStatus, sync::Arc};
 use util::ResultExt;
 
 pub struct AcpServer {
-    connection: Arc<acp::AgentConnection>,
-    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
+    thread: WeakEntity<AcpThread>,
     project: Entity<Project>,
+    connection: Arc<acp::AgentConnection>,
     exit_status: Arc<Mutex<Option<ExitStatus>>>,
     _handler_task: Task<()>,
     _io_task: Task<()>,
 }
 
 struct AcpClientDelegate {
-    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
+    thread: WeakEntity<AcpThread>,
     cx: AsyncApp,
     // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 }
 
 impl AcpClientDelegate {
-    fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
-        Self { threads, cx: cx }
-    }
-
-    fn update_thread<R>(
-        &self,
-        thread_id: &ThreadId,
-        cx: &mut App,
-        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
-    ) -> Option<R> {
-        let thread = self.threads.lock().get(&thread_id)?.clone();
-        let Some(thread) = thread.upgrade() else {
-            self.threads.lock().remove(&thread_id);
-            return None;
-        };
-        Some(thread.update(cx, callback))
+    fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
+        Self { thread, cx }
     }
 }
 
@@ -54,7 +39,7 @@ impl acp::Client for AcpClientDelegate {
         let cx = &mut self.cx.clone();
 
         cx.update(|cx| {
-            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
+            self.thread.update(cx, |thread, cx| {
                 thread.push_assistant_chunk(params.chunk, cx)
             });
         })?;
@@ -69,7 +54,7 @@ impl acp::Client for AcpClientDelegate {
         let cx = &mut self.cx.clone();
         let ToolCallRequest { id, outcome } = cx
             .update(|cx| {
-                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+                self.thread.update(cx, |thread, cx| {
                     thread.request_tool_call(
                         request.label,
                         request.icon,
@@ -94,7 +79,7 @@ impl acp::Client for AcpClientDelegate {
         let cx = &mut self.cx.clone();
         let entry_id = cx
             .update(|cx| {
-                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+                self.thread.update(cx, |thread, cx| {
                     thread.push_tool_call(request.label, request.icon, request.content, cx)
                 })
             })?
@@ -112,7 +97,7 @@ impl acp::Client for AcpClientDelegate {
         let cx = &mut self.cx.clone();
 
         cx.update(|cx| {
-            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
+            self.thread.update(cx, |thread, cx| {
                 thread.update_tool_call(
                     request.tool_call_id.into(),
                     request.status,
@@ -132,31 +117,42 @@ impl AcpServer {
         let stdin = process.stdin.take().expect("process didn't have stdin");
         let stdout = process.stdout.take().expect("process didn't have stdout");
 
-        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
-        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate::new(threads.clone(), cx.to_async()),
-            stdin,
-            stdout,
-        );
-
-        let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
-        let io_task = cx.background_spawn({
-            let exit_status = exit_status.clone();
-            async move {
-                io_fut.await.log_err();
-                let result = process.status().await.log_err();
-                *exit_status.lock() = result;
-            }
+        let mut connection = None;
+        cx.new(|cx| {
+            let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
+                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
+                stdin,
+                stdout,
+            );
+
+            let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
+            let io_task = cx.background_spawn({
+                let exit_status = exit_status.clone();
+                async move {
+                    io_fut.await.log_err();
+                    let result = process.status().await.log_err();
+                    *exit_status.lock() = result;
+                }
+            });
+
+            connection.replace(Arc::new(Self {
+                project: project.clone(),
+                connection: Arc::new(conn),
+                thread: cx.entity().downgrade(),
+                exit_status,
+                _handler_task: cx.foreground_executor().spawn(handler_fut),
+                _io_task: io_task,
+            }));
+
+            AcpThread::new(
+                connection.clone().unwrap(),
+                Vec::default(),
+                project.clone(),
+                cx,
+            )
         });
 
-        Arc::new(Self {
-            project,
-            connection: Arc::new(connection),
-            threads,
-            exit_status,
-            _handler_task: cx.foreground_executor().spawn(handler_fut),
-            _io_task: io_task,
-        })
+        connection.unwrap()
     }
 
     #[cfg(test)]
@@ -166,29 +162,40 @@ impl AcpServer {
         project: Entity<Project>,
         cx: &mut App,
     ) -> Arc<Self> {
-        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
-        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
-            stdin,
-            stdout,
-        );
-
-        let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
-        let io_task = cx.background_spawn({
-            async move {
-                io_fut.await.log_err();
-                // todo!() exit status?
-            }
+        let mut connection = None;
+        cx.new(|cx| {
+            let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
+                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
+                stdin,
+                stdout,
+            );
+
+            let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
+            let io_task = cx.background_spawn({
+                async move {
+                    io_fut.await.log_err();
+                    // todo!() exit status?
+                }
+            });
+
+            connection.replace(Arc::new(Self {
+                project: project.clone(),
+                connection: Arc::new(conn),
+                thread: cx.entity().downgrade(),
+                exit_status,
+                _handler_task: cx.foreground_executor().spawn(handler_fut),
+                _io_task: io_task,
+            }));
+
+            AcpThread::new(
+                connection.clone().unwrap(),
+                Vec::default(),
+                project.clone(),
+                cx,
+            )
         });
 
-        Arc::new(Self {
-            project,
-            connection: Arc::new(connection),
-            threads,
-            exit_status,
-            _handler_task: cx.foreground_executor().spawn(handler_fut),
-            _io_task: io_task,
-        })
+        connection.unwrap()
     }
 
     pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
@@ -207,49 +214,17 @@ impl AcpServer {
         Ok(())
     }
 
-    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
-        let response = self
-            .connection
-            .request(acp::CreateThreadParams)
-            .await
-            .map_err(to_anyhow)?;
-
-        let thread_id: ThreadId = response.thread_id.into();
-        let server = self.clone();
-        let thread = cx.new(|cx| {
-            AcpThread::new(
-                server,
-                thread_id.clone(),
-                Vec::default(),
-                self.project.clone(),
-                cx,
-            )
-        })?;
-        self.threads.lock().insert(thread_id, thread.downgrade());
-        Ok(thread)
-    }
-
-    pub async fn send_message(
-        &self,
-        thread_id: ThreadId,
-        message: acp::UserMessage,
-        _cx: &mut AsyncApp,
-    ) -> Result<()> {
+    pub async fn send_message(&self, message: acp::UserMessage, _cx: &mut AsyncApp) -> Result<()> {
         self.connection
-            .request(acp::SendUserMessageParams {
-                thread_id: thread_id.clone().into(),
-                message,
-            })
+            .request(acp::SendUserMessageParams { message })
             .await
             .map_err(to_anyhow)?;
         Ok(())
     }
 
-    pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
+    pub async fn cancel_send_message(&self, _cx: &mut AsyncApp) -> Result<()> {
         self.connection
-            .request(acp::CancelSendMessageParams {
-                thread_id: thread_id.clone().into(),
-            })
+            .request(acp::CancelSendMessageParams)
             .await
             .map_err(to_anyhow)?;
         Ok(())
@@ -270,18 +245,6 @@ fn to_anyhow(e: acp::Error) -> anyhow::Error {
     anyhow::anyhow!(e.message)
 }
 
-impl From<acp::ThreadId> for ThreadId {
-    fn from(thread_id: acp::ThreadId) -> Self {
-        Self(thread_id.0.into())
-    }
-}
-
-impl From<ThreadId> for acp::ThreadId {
-    fn from(thread_id: ThreadId) -> Self {
-        acp::ThreadId(thread_id.0.to_string())
-    }
-}
-
 impl From<acp::ToolCallId> for ToolCallId {
     fn from(tool_call_id: acp::ToolCallId) -> Self {
         Self(ThreadEntryId(tool_call_id.0))