@@ -446,11 +446,13 @@ pub struct ToolCallRequest {
#[cfg(test)]
mod tests {
use super::*;
+ use futures::{FutureExt as _, channel::mpsc, select};
use gpui::{AsyncApp, TestAppContext};
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
- use std::{env, path::Path, process::Stdio};
+ use smol::stream::StreamExt as _;
+ use std::{env, path::Path, process::Stdio, time::Duration};
use util::path;
fn init_test(cx: &mut TestAppContext) {
@@ -523,9 +525,9 @@ mod tests {
.unwrap();
thread.read_with(cx, |thread, cx| {
let AgentThreadEntryContent::ToolCall(ToolCall {
- id,
display_name,
- status: ToolCallStatus::Allowed { content, .. },
+ status: ToolCallStatus::Allowed { .. },
+ ..
}) = &thread.entries()[1].content
else {
panic!();
@@ -535,16 +537,112 @@ mod tests {
assert_eq!(md.source(), "ReadFile");
});
- // todo!
- // description.read_with(cx, |md, _cx| {
- // assert!(
- // md.source().contains("foo"),
- // "Expected description to contain 'foo', but got {}",
- // md.source()
- // );
- // });
+ assert!(matches!(
+ thread.entries[2].content,
+ AgentThreadEntryContent::Message(Message {
+ role: Role::Assistant,
+ ..
+ })
+ ));
+ });
+ }
+
+ #[gpui::test]
+ async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+ let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
+ let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+ let full_turn = thread.update(cx, |thread, cx| {
+ thread.send(r#"Run `echo "Hello, world!"`"#, cx)
+ });
+
+ run_until_tool_call(&thread, cx).await;
+
+ let tool_call_id = thread.read_with(cx, |thread, cx| {
+ let AgentThreadEntryContent::ToolCall(ToolCall {
+ id,
+ display_name,
+ status:
+ ToolCallStatus::WaitingForConfirmation {
+ confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
+ ..
+ },
+ }) = &thread.entries()[1].content
+ else {
+ panic!();
+ };
+
+ assert_eq!(root_command, "echo");
+
+ display_name.read_with(cx, |md, _cx| {
+ assert_eq!(md.source(), "Shell");
+ });
+
*id
});
+
+ thread.update(cx, |thread, cx| {
+ thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
+
+ assert!(matches!(
+ &thread.entries()[1].content,
+ AgentThreadEntryContent::ToolCall(ToolCall {
+ status: ToolCallStatus::Allowed { .. },
+ ..
+ })
+ ));
+ });
+
+ full_turn.await.unwrap();
+
+ thread.read_with(cx, |thread, cx| {
+ let AgentThreadEntryContent::ToolCall(ToolCall {
+ status: ToolCallStatus::Allowed { content, .. },
+ ..
+ }) = &thread.entries()[1].content
+ else {
+ panic!();
+ };
+
+ content.as_ref().unwrap().read_with(cx, |md, _cx| {
+ assert!(
+ md.source().contains("Hello, world!"),
+ r#"Expected '{}' to contain "Hello, world!""#,
+ md.source()
+ );
+ });
+ });
+ }
+
+ async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
+ let (mut tx, mut rx) = mpsc::channel::<()>(1);
+
+ let subscription = cx.update(|cx| {
+ cx.subscribe(thread, move |thread, _, cx| {
+ if thread
+ .read(cx)
+ .entries
+ .iter()
+ .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
+ {
+ tx.try_send(()).unwrap();
+ }
+ })
+ });
+
+ select! {
+ _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+ panic!("Timeout waiting for tool call")
+ }
+ _ = rx.next().fuse() => {
+ drop(subscription);
+ }
+ }
}
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
@@ -554,7 +652,6 @@ mod tests {
command
.arg(cli_path)
.arg("--acp")
- .args(["--model", "gemini-2.5-flash"])
.current_dir("/private/tmp")
.stdin(Stdio::piped())
.stdout(Stdio::piped())