diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index 33fa793e848d25a4bcedf3588b9652e3e64ae575..06faf048ecdacd6b8e9079cee5c511cb793f0760 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -446,11 +446,13 @@ pub struct ToolCallRequest { #[cfg(test)] mod tests { use super::*; + use futures::{FutureExt as _, channel::mpsc, select}; use gpui::{AsyncApp, TestAppContext}; use project::FakeFs; use serde_json::json; use settings::SettingsStore; - use std::{env, path::Path, process::Stdio}; + use smol::stream::StreamExt as _; + use std::{env, path::Path, process::Stdio, time::Duration}; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -523,9 +525,9 @@ mod tests { .unwrap(); thread.read_with(cx, |thread, cx| { let AgentThreadEntryContent::ToolCall(ToolCall { - id, display_name, - status: ToolCallStatus::Allowed { content, .. }, + status: ToolCallStatus::Allowed { .. }, + .. }) = &thread.entries()[1].content else { panic!(); @@ -535,16 +537,112 @@ mod tests { assert_eq!(md.source(), "ReadFile"); }); - // todo! - // description.read_with(cx, |md, _cx| { - // assert!( - // md.source().contains("foo"), - // "Expected description to contain 'foo', but got {}", - // md.source() - // ); - // }); + assert!(matches!( + thread.entries[2].content, + AgentThreadEntryContent::Message(Message { + role: Role::Assistant, + .. + }) + )); + }); + } + + #[gpui::test] + async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap(); + let thread = server.create_thread(&mut cx.to_async()).await.unwrap(); + let full_turn = thread.update(cx, |thread, cx| { + thread.send(r#"Run `echo "Hello, world!"`"#, cx) + }); + + run_until_tool_call(&thread, cx).await; + + let tool_call_id = thread.read_with(cx, |thread, cx| { + let AgentThreadEntryContent::ToolCall(ToolCall { + id, + display_name, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: acp::ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + }) = &thread.entries()[1].content + else { + panic!(); + }; + + assert_eq!(root_command, "echo"); + + display_name.read_with(cx, |md, _cx| { + assert_eq!(md.source(), "Shell"); + }); + *id }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); + + assert!(matches!( + &thread.entries()[1].content, + AgentThreadEntryContent::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + )); + }); + + full_turn.await.unwrap(); + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntryContent::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { content, .. }, + .. + }) = &thread.entries()[1].content + else { + panic!(); + }; + + content.as_ref().unwrap().read_with(cx, |md, _cx| { + assert!( + md.source().contains("Hello, world!"), + r#"Expected '{}' to contain "Hello, world!""#, + md.source() + ); + }); + }); + } + + async fn run_until_tool_call(thread: &Entity, cx: &mut TestAppContext) { + let (mut tx, mut rx) = mpsc::channel::<()>(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + if thread + .read(cx) + .entries + .iter() + .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_))) + { + tx.try_send(()).unwrap(); + } + }) + }); + + select! { + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => { + panic!("Timeout waiting for tool call") + } + _ = rx.next().fuse() => { + drop(subscription); + } + } } pub fn gemini_acp_server(project: Entity, mut cx: AsyncApp) -> Result> { @@ -554,7 +652,6 @@ mod tests { command .arg(cli_path) .arg("--acp") - .args(["--model", "gemini-2.5-flash"]) .current_dir("/private/tmp") .stdin(Stdio::piped()) .stdout(Stdio::piped())