Add tool call with confirmation test

Agus Zubiaga created

Change summary

crates/acp/src/acp.rs | 121 ++++++++++++++++++++++++++++++++++++++++----
1 file changed, 109 insertions(+), 12 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -446,11 +446,13 @@ pub struct ToolCallRequest {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use futures::{FutureExt as _, channel::mpsc, select};
     use gpui::{AsyncApp, TestAppContext};
     use project::FakeFs;
     use serde_json::json;
     use settings::SettingsStore;
-    use std::{env, path::Path, process::Stdio};
+    use smol::stream::StreamExt as _;
+    use std::{env, path::Path, process::Stdio, time::Duration};
     use util::path;
 
     fn init_test(cx: &mut TestAppContext) {
@@ -523,9 +525,9 @@ mod tests {
             .unwrap();
         thread.read_with(cx, |thread, cx| {
             let AgentThreadEntryContent::ToolCall(ToolCall {
-                id,
                 display_name,
-                status: ToolCallStatus::Allowed { content, .. },
+                status: ToolCallStatus::Allowed { .. },
+                ..
             }) = &thread.entries()[1].content
             else {
                 panic!();
@@ -535,16 +537,112 @@ mod tests {
                 assert_eq!(md.source(), "ReadFile");
             });
 
-            // todo!
-            // description.read_with(cx, |md, _cx| {
-            //     assert!(
-            //         md.source().contains("foo"),
-            //         "Expected description to contain 'foo', but got {}",
-            //         md.source()
-            //     );
-            // });
+            assert!(matches!(
+                thread.entries[2].content,
+                AgentThreadEntryContent::Message(Message {
+                    role: Role::Assistant,
+                    ..
+                })
+            ));
+        });
+    }
+
+    #[gpui::test]
+    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        cx.executor().allow_parking();
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
+        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+        let full_turn = thread.update(cx, |thread, cx| {
+            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
+        });
+
+        run_until_tool_call(&thread, cx).await;
+
+        let tool_call_id = thread.read_with(cx, |thread, cx| {
+            let AgentThreadEntryContent::ToolCall(ToolCall {
+                id,
+                display_name,
+                status:
+                    ToolCallStatus::WaitingForConfirmation {
+                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
+                        ..
+                    },
+            }) = &thread.entries()[1].content
+            else {
+                panic!();
+            };
+
+            assert_eq!(root_command, "echo");
+
+            display_name.read_with(cx, |md, _cx| {
+                assert_eq!(md.source(), "Shell");
+            });
+
             *id
         });
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
+
+            assert!(matches!(
+                &thread.entries()[1].content,
+                AgentThreadEntryContent::ToolCall(ToolCall {
+                    status: ToolCallStatus::Allowed { .. },
+                    ..
+                })
+            ));
+        });
+
+        full_turn.await.unwrap();
+
+        thread.read_with(cx, |thread, cx| {
+            let AgentThreadEntryContent::ToolCall(ToolCall {
+                status: ToolCallStatus::Allowed { content, .. },
+                ..
+            }) = &thread.entries()[1].content
+            else {
+                panic!();
+            };
+
+            content.as_ref().unwrap().read_with(cx, |md, _cx| {
+                assert!(
+                    md.source().contains("Hello, world!"),
+                    r#"Expected '{}' to contain "Hello, world!""#,
+                    md.source()
+                );
+            });
+        });
+    }
+
+    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
+        let (mut tx, mut rx) = mpsc::channel::<()>(1);
+
+        let subscription = cx.update(|cx| {
+            cx.subscribe(thread, move |thread, _, cx| {
+                if thread
+                    .read(cx)
+                    .entries
+                    .iter()
+                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
+                {
+                    tx.try_send(()).unwrap();
+                }
+            })
+        });
+
+        select! {
+            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+                panic!("Timeout waiting for tool call")
+            }
+            _ = rx.next().fuse() => {
+                drop(subscription);
+            }
+        }
     }
 
     pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
@@ -554,7 +652,6 @@ mod tests {
         command
             .arg(cli_path)
             .arg("--acp")
-            .args(["--model", "gemini-2.5-flash"])
             .current_dir("/private/tmp")
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())