Merge branch 'acp' of github.com:zed-industries/zed into acp

Agus Zubiaga created

Change summary

crates/acp/src/acp.rs         | 119 +++++++++---------------------------
crates/acp/src/thread_view.rs |   2 
2 files changed, 30 insertions(+), 91 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -123,7 +123,7 @@ pub enum AgentThreadEntryContent {
 #[derive(Debug)]
 pub struct ToolCall {
     id: ToolCallId,
-    tool_name: Entity<Markdown>,
+    display_name: Entity<Markdown>,
     status: ToolCallStatus,
 }
 
@@ -271,7 +271,7 @@ impl AcpThread {
 
     pub fn request_tool_call(
         &mut self,
-        title: String,
+        display_name: String,
         confirmation: acp::ToolCallConfirmation,
         cx: &mut Context<Self>,
     ) -> ToolCallRequest {
@@ -282,22 +282,22 @@ impl AcpThread {
             respond_tx: tx,
         };
 
-        let id = self.insert_tool_call(title, status, cx);
+        let id = self.insert_tool_call(display_name, status, cx);
         ToolCallRequest { id, outcome: rx }
     }
 
-    pub fn push_tool_call(&mut self, title: String, cx: &mut Context<Self>) -> ToolCallId {
+    pub fn push_tool_call(&mut self, display_name: String, cx: &mut Context<Self>) -> ToolCallId {
         let status = ToolCallStatus::Allowed {
             status: acp::ToolCallStatus::Running,
             content: None,
         };
 
-        self.insert_tool_call(title, status, cx)
+        self.insert_tool_call(display_name, status, cx)
     }
 
     fn insert_tool_call(
         &mut self,
-        title: String,
+        display_name: String,
         status: ToolCallStatus,
         cx: &mut Context<Self>,
     ) -> ToolCallId {
@@ -307,8 +307,13 @@ impl AcpThread {
             AgentThreadEntryContent::ToolCall(ToolCall {
                 // todo! clean up id creation
                 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
-                tool_name: cx.new(|cx| {
-                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
+                display_name: cx.new(|cx| {
+                    Markdown::new(
+                        display_name.into(),
+                        Some(language_registry.clone()),
+                        None,
+                        cx,
+                    )
                 }),
                 status,
             }),
@@ -441,13 +446,11 @@ pub struct ToolCallRequest {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::{FutureExt as _, channel::mpsc, select};
     use gpui::{AsyncApp, TestAppContext};
     use project::FakeFs;
     use serde_json::json;
     use settings::SettingsStore;
-    use smol::stream::StreamExt;
-    use std::{env, path::Path, process::Stdio, time::Duration};
+    use std::{env, path::Path, process::Stdio};
     use util::path;
 
     fn init_test(cx: &mut TestAppContext) {
@@ -509,27 +512,27 @@ mod tests {
         let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
         let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
         let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
-        let full_turn = thread.update(cx, |thread, cx| {
-            thread.send(
-                "Read the '/private/tmp/foo' file and tell me what you see.",
-                cx,
-            )
-        });
-
-        run_until_tool_call(&thread, cx).await;
-
-        let tool_call_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .update(cx, |thread, cx| {
+                thread.send(
+                    "Read the '/private/tmp/foo' file and tell me what you see.",
+                    cx,
+                )
+            })
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
             let AgentThreadEntryContent::ToolCall(ToolCall {
                 id,
-                tool_name,
-                status: ToolCallStatus::Allowed { .. },
-            }) = &thread.entries().last().unwrap().content
+                display_name,
+                status: ToolCallStatus::Allowed { content, .. },
+            }) = &thread.entries()[1].content
             else {
                 panic!();
             };
 
-            tool_name.read_with(cx, |md, _cx| {
-                assert_eq!(md.source(), "read_file");
+            display_name.read_with(cx, |md, _cx| {
+                assert_eq!(md.source(), "ReadFile");
             });
 
             // todo!
@@ -542,70 +545,6 @@ mod tests {
             // });
             *id
         });
-
-        thread.update(cx, |thread, cx| {
-            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
-            assert!(matches!(
-                thread.entries().last().unwrap().content,
-                AgentThreadEntryContent::ToolCall(ToolCall {
-                    status: ToolCallStatus::Allowed { .. },
-                    ..
-                })
-            ));
-        });
-
-        full_turn.await.unwrap();
-
-        thread.read_with(cx, |thread, _| {
-            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
-            assert!(matches!(
-                thread.entries[0].content,
-                AgentThreadEntryContent::Message(Message {
-                    role: Role::User,
-                    ..
-                })
-            ));
-            assert!(matches!(
-                thread.entries[1].content,
-                AgentThreadEntryContent::ToolCall(ToolCall {
-                    status: ToolCallStatus::Allowed { .. },
-                    ..
-                })
-            ));
-            assert!(matches!(
-                thread.entries[2].content,
-                AgentThreadEntryContent::Message(Message {
-                    role: Role::Assistant,
-                    ..
-                })
-            ));
-        });
-    }
-
-    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
-        let (mut tx, mut rx) = mpsc::channel(1);
-
-        let subscription = cx.update(|cx| {
-            cx.subscribe(thread, move |thread, _, cx| {
-                if thread
-                    .read(cx)
-                    .entries
-                    .iter()
-                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
-                {
-                    tx.try_send(()).unwrap();
-                }
-            })
-        });
-
-        select! {
-            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
-                panic!("Timeout waiting for tool call")
-            }
-            _ = rx.next().fuse() => {
-                drop(subscription);
-            }
-        }
     }
 
     pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {

crates/acp/src/thread_view.rs 🔗

@@ -329,7 +329,7 @@ impl AcpThreadView {
                             .color(Color::Muted),
                     )
                     .child(MarkdownElement::new(
-                        tool_call.tool_name.clone(),
+                        tool_call.display_name.clone(),
                         default_markdown_style(window, cx),
                     ))
                     .child(div().w_full())