Remove ReadFile entry and test tool call

Agus Zubiaga created

Change summary

crates/acp/src/acp.rs         | 146 ++++++++++++++++++++++++++++++------
crates/acp/src/server.rs      |  15 ---
crates/acp/src/thread_view.rs |   6 -
3 files changed, 123 insertions(+), 44 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -117,7 +117,6 @@ impl MessageChunk {
 #[derive(Debug)]
 pub enum AgentThreadEntryContent {
     Message(Message),
-    ReadFile { path: PathBuf, content: String },
     ToolCall(ToolCall),
 }
 
@@ -343,15 +342,17 @@ impl AcpThread {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use futures::{FutureExt as _, channel::mpsc, select};
     use gpui::{AsyncApp, TestAppContext};
     use project::FakeFs;
     use serde_json::json;
     use settings::SettingsStore;
-    use std::{env, path::Path, process::Stdio};
+    use smol::stream::StreamExt;
+    use std::{env, path::Path, process::Stdio, time::Duration};
     use util::path;
 
     fn init_test(cx: &mut TestAppContext) {
-        env_logger::init();
+        env_logger::try_init().ok();
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);
             cx.set_global(settings_store);
@@ -361,7 +362,41 @@ mod tests {
     }
 
     #[gpui::test]
-    async fn test_gemini(cx: &mut TestAppContext) {
+    async fn test_gemini_basic(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        cx.executor().allow_parking();
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
+        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+        thread
+            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
+            .await
+            .unwrap();
+
+        thread.read_with(cx, |thread, _| {
+            assert_eq!(thread.entries.len(), 2);
+            assert!(matches!(
+                thread.entries[0].content,
+                AgentThreadEntryContent::Message(Message {
+                    role: Role::User,
+                    ..
+                })
+            ));
+            assert!(matches!(
+                thread.entries[1].content,
+                AgentThreadEntryContent::Message(Message {
+                    role: Role::Assistant,
+                    ..
+                })
+            ));
+        });
+    }
+
+    #[gpui::test]
+    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
         init_test(cx);
 
         cx.executor().allow_parking();
@@ -375,17 +410,52 @@ mod tests {
         let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
         let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
         let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
-        thread
-            .update(cx, |thread, cx| {
-                thread.send(
-                    "Read the '/private/tmp/foo' file and output all of its contents.",
-                    cx,
-                )
-            })
-            .await
-            .unwrap();
+        let full_turn = thread.update(cx, |thread, cx| {
+            thread.send(
+                "Read the '/private/tmp/foo' file and tell me what you see.",
+                cx,
+            )
+        });
+
+        run_until_tool_call(&thread, cx).await;
+
+        let tool_call_id = thread.read_with(cx, |thread, cx| {
+            let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
+                id,
+                tool_name,
+                description,
+                ..
+            }) = &thread.entries().last().unwrap().content
+            else {
+                panic!();
+            };
+
+            tool_name.read_with(cx, |md, _cx| {
+                assert_eq!(md.source(), "read_file");
+            });
+
+            description.read_with(cx, |md, _cx| {
+                assert!(
+                    md.source().contains("foo"),
+                    "Expected description to contain 'foo', but got {}",
+                    md.source()
+                );
+            });
+            *id
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(tool_call_id, true, cx);
+            assert!(matches!(
+                thread.entries().last().unwrap().content,
+                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
+            ));
+        });
+
+        full_turn.await.unwrap();
 
         thread.read_with(cx, |thread, _| {
+            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
             assert!(matches!(
                 thread.entries[0].content,
                 AgentThreadEntryContent::Message(Message {
@@ -393,20 +463,44 @@ mod tests {
                     ..
                 })
             ));
-            assert!(
-                thread.entries().iter().any(|entry| {
-                    match &entry.content {
-                        AgentThreadEntryContent::ReadFile { path, content } => {
-                            path.to_string_lossy().to_string() == "/private/tmp/foo"
-                                && content == "Lorem ipsum dolor"
-                        }
-                        _ => false,
-                    }
-                }),
-                "Thread does not contain entry. Actual: {:?}",
-                thread.entries()
-            );
+            assert!(matches!(
+                thread.entries[1].content,
+                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
+            ));
+            assert!(matches!(
+                thread.entries[2].content,
+                AgentThreadEntryContent::Message(Message {
+                    role: Role::Assistant,
+                    ..
+                })
+            ));
+        });
+    }
+
+    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
+        let (mut tx, mut rx) = mpsc::channel(1);
+
+        let subscription = cx.update(|cx| {
+            cx.subscribe(thread, move |thread, _, cx| {
+                if thread
+                    .read(cx)
+                    .entries
+                    .iter()
+                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
+                {
+                    tx.try_send(()).unwrap();
+                }
+            })
         });
+
+        select! {
+            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
+                panic!("Timeout waiting for tool call")
+            }
+            _ = rx.next().fuse() => {
+                drop(subscription);
+            }
+        }
     }
 
     pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {

crates/acp/src/server.rs 🔗

@@ -1,4 +1,4 @@
-use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId};
+use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId};
 use agentic_coding_protocol as acp;
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
@@ -107,7 +107,7 @@ impl acp::Client for AcpClientDelegate {
             })??
             .await?;
 
-        buffer.update(cx, |buffer, cx| {
+        buffer.update(cx, |buffer, _cx| {
             let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
             let end = match request.line_limit {
                 None => buffer.max_point(),
@@ -115,15 +115,6 @@ impl acp::Client for AcpClientDelegate {
             };
 
             let content: String = buffer.text_for_range(start..end).collect();
-            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
-                thread.push_entry(
-                    AgentThreadEntryContent::ReadFile {
-                        path: request.path.clone(),
-                        content: content.clone(),
-                    },
-                    cx,
-                );
-            });
 
             acp::ReadTextFileResponse {
                 content,
@@ -203,7 +194,7 @@ impl acp::Client for AcpClientDelegate {
             })?
             .context("Failed to update thread")?;
 
-        if dbg!(rx.await)? {
+        if rx.await? {
             Ok(acp::RequestToolCallResponse::Allowed {
                 id: entry_id.into(),
             })

crates/acp/src/thread_view.rs 🔗

@@ -241,12 +241,6 @@ impl AcpThreadView {
                         .into_any(),
                 }
             }
-            AgentThreadEntryContent::ReadFile { path, content: _ } => {
-                // todo!
-                div()
-                    .child(format!("<Reading file {}>", path.display()))
-                    .into_any()
-            }
             AgentThreadEntryContent::ToolCall(tool_call) => match tool_call {
                 ToolCall::WaitingForConfirmation {
                     id,