@@ -117,7 +117,6 @@ impl MessageChunk {
#[derive(Debug)]
pub enum AgentThreadEntryContent {
Message(Message),
- ReadFile { path: PathBuf, content: String },
ToolCall(ToolCall),
}
@@ -343,15 +342,17 @@ impl AcpThread {
#[cfg(test)]
mod tests {
use super::*;
+ use futures::{FutureExt as _, channel::mpsc, select};
use gpui::{AsyncApp, TestAppContext};
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
- use std::{env, path::Path, process::Stdio};
+ use smol::stream::StreamExt;
+ use std::{env, path::Path, process::Stdio, time::Duration};
use util::path;
fn init_test(cx: &mut TestAppContext) {
- env_logger::init();
+ env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
@@ -361,7 +362,41 @@ mod tests {
}
#[gpui::test]
- async fn test_gemini(cx: &mut TestAppContext) {
+ async fn test_gemini_basic(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+ let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
+ let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+ thread
+ .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
+ .await
+ .unwrap();
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.entries.len(), 2);
+ assert!(matches!(
+ thread.entries[0].content,
+ AgentThreadEntryContent::Message(Message {
+ role: Role::User,
+ ..
+ })
+ ));
+ assert!(matches!(
+ thread.entries[1].content,
+ AgentThreadEntryContent::Message(Message {
+ role: Role::Assistant,
+ ..
+ })
+ ));
+ });
+ }
+
+ #[gpui::test]
+ async fn test_gemini_tool_call(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
@@ -375,17 +410,52 @@ mod tests {
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
- thread
- .update(cx, |thread, cx| {
- thread.send(
- "Read the '/private/tmp/foo' file and output all of its contents.",
- cx,
- )
- })
- .await
- .unwrap();
+ let full_turn = thread.update(cx, |thread, cx| {
+ thread.send(
+ "Read the '/private/tmp/foo' file and tell me what you see.",
+ cx,
+ )
+ });
+
+ run_until_tool_call(&thread, cx).await;
+
+ let tool_call_id = thread.read_with(cx, |thread, cx| {
+ let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
+ id,
+ tool_name,
+ description,
+ ..
+ }) = &thread.entries().last().unwrap().content
+ else {
+ panic!();
+ };
+
+ tool_name.read_with(cx, |md, _cx| {
+ assert_eq!(md.source(), "read_file");
+ });
+
+ description.read_with(cx, |md, _cx| {
+ assert!(
+ md.source().contains("foo"),
+ "Expected description to contain 'foo', but got {}",
+ md.source()
+ );
+ });
+ *id
+ });
+
+ thread.update(cx, |thread, cx| {
+ thread.authorize_tool_call(tool_call_id, true, cx);
+ assert!(matches!(
+ thread.entries().last().unwrap().content,
+ AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
+ ));
+ });
+
+ full_turn.await.unwrap();
thread.read_with(cx, |thread, _| {
+ assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
assert!(matches!(
thread.entries[0].content,
AgentThreadEntryContent::Message(Message {
@@ -393,20 +463,44 @@ mod tests {
..
})
));
- assert!(
- thread.entries().iter().any(|entry| {
- match &entry.content {
- AgentThreadEntryContent::ReadFile { path, content } => {
- path.to_string_lossy().to_string() == "/private/tmp/foo"
- && content == "Lorem ipsum dolor"
- }
- _ => false,
- }
- }),
- "Thread does not contain entry. Actual: {:?}",
- thread.entries()
- );
+ assert!(matches!(
+ thread.entries[1].content,
+ AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
+ ));
+ assert!(matches!(
+ thread.entries[2].content,
+ AgentThreadEntryContent::Message(Message {
+ role: Role::Assistant,
+ ..
+ })
+ ));
+ });
+ }
+
+ async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
+ let (mut tx, mut rx) = mpsc::channel(1);
+
+ let subscription = cx.update(|cx| {
+ cx.subscribe(thread, move |thread, _, cx| {
+ if thread
+ .read(cx)
+ .entries
+ .iter()
+ .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
+ {
+ tx.try_send(()).unwrap();
+ }
+ })
});
+
+ select! {
+ _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
+ panic!("Timeout waiting for tool call")
+ }
+ _ = rx.next().fuse() => {
+ drop(subscription);
+ }
+ }
}
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
@@ -1,4 +1,4 @@
-use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId};
+use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId};
use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
@@ -107,7 +107,7 @@ impl acp::Client for AcpClientDelegate {
})??
.await?;
- buffer.update(cx, |buffer, cx| {
+ buffer.update(cx, |buffer, _cx| {
let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
let end = match request.line_limit {
None => buffer.max_point(),
@@ -115,15 +115,6 @@ impl acp::Client for AcpClientDelegate {
};
let content: String = buffer.text_for_range(start..end).collect();
- self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
- thread.push_entry(
- AgentThreadEntryContent::ReadFile {
- path: request.path.clone(),
- content: content.clone(),
- },
- cx,
- );
- });
acp::ReadTextFileResponse {
content,
@@ -203,7 +194,7 @@ impl acp::Client for AcpClientDelegate {
})?
.context("Failed to update thread")?;
- if dbg!(rx.await)? {
+ if rx.await? {
Ok(acp::RequestToolCallResponse::Allowed {
id: entry_id.into(),
})