diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index b28e204c89eb5b1ef0998e818830a1ef6bf84641..8804d0cfe2515988ae0f508d5bbce55d48869064 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -117,7 +117,6 @@ impl MessageChunk { #[derive(Debug)] pub enum AgentThreadEntryContent { Message(Message), - ReadFile { path: PathBuf, content: String }, ToolCall(ToolCall), } @@ -343,15 +342,17 @@ impl AcpThread { #[cfg(test)] mod tests { use super::*; + use futures::{FutureExt as _, channel::mpsc, select}; use gpui::{AsyncApp, TestAppContext}; use project::FakeFs; use serde_json::json; use settings::SettingsStore; - use std::{env, path::Path, process::Stdio}; + use smol::stream::StreamExt; + use std::{env, path::Path, process::Stdio, time::Duration}; use util::path; fn init_test(cx: &mut TestAppContext) { - env_logger::init(); + env_logger::try_init().ok(); cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -361,7 +362,41 @@ mod tests { } #[gpui::test] - async fn test_gemini(cx: &mut TestAppContext) { + async fn test_gemini_basic(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap(); + let thread = server.create_thread(&mut cx.to_async()).await.unwrap(); + thread + .update(cx, |thread, cx| thread.send("Hello from Zed!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries.len(), 2); + assert!(matches!( + thread.entries[0].content, + AgentThreadEntryContent::Message(Message { + role: Role::User, + .. + }) + )); + assert!(matches!( + thread.entries[1].content, + AgentThreadEntryContent::Message(Message { + role: Role::Assistant, + .. + }) + )); + }); + } + + #[gpui::test] + async fn test_gemini_tool_call(cx: &mut TestAppContext) { init_test(cx); cx.executor().allow_parking(); @@ -375,17 +410,52 @@ mod tests { let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap(); let thread = server.create_thread(&mut cx.to_async()).await.unwrap(); - thread - .update(cx, |thread, cx| { - thread.send( - "Read the '/private/tmp/foo' file and output all of its contents.", - cx, - ) - }) - .await - .unwrap(); + let full_turn = thread.update(cx, |thread, cx| { + thread.send( + "Read the '/private/tmp/foo' file and tell me what you see.", + cx, + ) + }); + + run_until_tool_call(&thread, cx).await; + + let tool_call_id = thread.read_with(cx, |thread, cx| { + let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation { + id, + tool_name, + description, + .. + }) = &thread.entries().last().unwrap().content + else { + panic!(); + }; + + tool_name.read_with(cx, |md, _cx| { + assert_eq!(md.source(), "read_file"); + }); + + description.read_with(cx, |md, _cx| { + assert!( + md.source().contains("foo"), + "Expected description to contain 'foo', but got {}", + md.source() + ); + }); + *id + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(tool_call_id, true, cx); + assert!(matches!( + thread.entries().last().unwrap().content, + AgentThreadEntryContent::ToolCall(ToolCall::Allowed) + )); + }); + + full_turn.await.unwrap(); thread.read_with(cx, |thread, _| { + assert!(thread.entries.len() >= 3, "{:?}", &thread.entries); assert!(matches!( thread.entries[0].content, AgentThreadEntryContent::Message(Message { @@ -393,20 +463,44 @@ mod tests { .. }) )); - assert!( - thread.entries().iter().any(|entry| { - match &entry.content { - AgentThreadEntryContent::ReadFile { path, content } => { - path.to_string_lossy().to_string() == "/private/tmp/foo" - && content == "Lorem ipsum dolor" - } - _ => false, - } - }), - "Thread does not contain entry. Actual: {:?}", - thread.entries() - ); + assert!(matches!( + thread.entries[1].content, + AgentThreadEntryContent::ToolCall(ToolCall::Allowed) + )); + assert!(matches!( + thread.entries[2].content, + AgentThreadEntryContent::Message(Message { + role: Role::Assistant, + .. + }) + )); + }); + } + + async fn run_until_tool_call(thread: &Entity, cx: &mut TestAppContext) { + let (mut tx, mut rx) = mpsc::channel(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + if thread + .read(cx) + .entries + .iter() + .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_))) + { + tx.try_send(()).unwrap(); + } + }) }); + + select! { + _ = cx.executor().timer(Duration::from_secs(5)).fuse() => { + panic!("Timeout waiting for tool call") + } + _ = rx.next().fuse() => { + drop(subscription); + } + } } pub fn gemini_acp_server(project: Entity, mut cx: AsyncApp) -> Result> { diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index 44b5acc3e6569253c77080ac0bd0f546b788014e..6bb198b87aa3ce3ff826d329724e8cf989e9e9b4 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -1,4 +1,4 @@ -use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId}; +use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId}; use agentic_coding_protocol as acp; use anyhow::{Context as _, Result}; use async_trait::async_trait; @@ -107,7 +107,7 @@ impl acp::Client for AcpClientDelegate { })?? .await?; - buffer.update(cx, |buffer, cx| { + buffer.update(cx, |buffer, _cx| { let start = language::Point::new(request.line_offset.unwrap_or(0), 0); let end = match request.line_limit { None => buffer.max_point(), @@ -115,15 +115,6 @@ impl acp::Client for AcpClientDelegate { }; let content: String = buffer.text_for_range(start..end).collect(); - self.update_thread(&request.thread_id.into(), cx, |thread, cx| { - thread.push_entry( - AgentThreadEntryContent::ReadFile { - path: request.path.clone(), - content: content.clone(), - }, - cx, - ); - }); acp::ReadTextFileResponse { content, @@ -203,7 +194,7 @@ impl acp::Client for AcpClientDelegate { })? .context("Failed to update thread")?; - if dbg!(rx.await)? { + if rx.await? { Ok(acp::RequestToolCallResponse::Allowed { id: entry_id.into(), }) diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index cddefeb9647a8223070253098db8560c8aa4f8a1..e07b2bbc7ae5db99e8805dfa4ef722015ac60598 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -241,12 +241,6 @@ impl AcpThreadView { .into_any(), } } - AgentThreadEntryContent::ReadFile { path, content: _ } => { - // todo! - div() - .child(format!("", path.display())) - .into_any() - } AgentThreadEntryContent::ToolCall(tool_call) => match tool_call { ToolCall::WaitingForConfirmation { id,