crates/agent_servers/Cargo.toml 🔗
@@ -7,7 +7,7 @@ license = "GPL-3.0-or-later"
[features]
test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"]
-gemini = []
+e2e = []
[lints]
workspace = true
Agus Zubiaga created
Release Notes:
- N/A
crates/agent_servers/Cargo.toml | 2
crates/agent_servers/src/agent_servers.rs | 125 ++++
crates/agent_servers/src/claude.rs | 77 ++-
crates/agent_servers/src/e2e_tests.rs | 368 +++++++++++++++++
crates/agent_servers/src/gemini.rs | 428 +------------------
crates/agent_servers/src/settings.rs | 1
crates/agent_servers/src/stdio_agent_server.rs | 54 --
7 files changed, 551 insertions(+), 504 deletions(-)
@@ -7,7 +7,7 @@ license = "GPL-3.0-or-later"
[features]
test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"]
-gemini = []
+e2e = []
[lints]
workspace = true
@@ -3,6 +3,9 @@ mod gemini;
mod settings;
mod stdio_agent_server;
+#[cfg(test)]
+mod e2e_tests;
+
pub use claude::*;
pub use gemini::*;
pub use settings::*;
@@ -11,34 +14,20 @@ pub use stdio_agent_server::*;
use acp_thread::AcpThread;
use anyhow::Result;
use collections::HashMap;
-use gpui::{App, Entity, SharedString, Task};
+use gpui::{App, AsyncApp, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::path::{Path, PathBuf};
+use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use util::ResultExt as _;
pub fn init(cx: &mut App) {
settings::init(cx);
}
-#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
-pub struct AgentServerCommand {
- #[serde(rename = "command")]
- pub path: PathBuf,
- #[serde(default)]
- pub args: Vec<String>,
- pub env: Option<HashMap<String, String>>,
-}
-
-pub enum AgentServerVersion {
- Supported,
- Unsupported {
- error_message: SharedString,
- upgrade_message: SharedString,
- upgrade_command: String,
- },
-}
-
pub trait AgentServer: Send {
fn logo(&self) -> ui::IconName;
fn name(&self) -> &'static str;
@@ -78,3 +67,99 @@ impl std::fmt::Debug for AgentServerCommand {
.finish()
}
}
+
+pub enum AgentServerVersion {
+ Supported,
+ Unsupported {
+ error_message: SharedString,
+ upgrade_message: SharedString,
+ upgrade_command: String,
+ },
+}
+
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
+pub struct AgentServerCommand {
+ #[serde(rename = "command")]
+ pub path: PathBuf,
+ #[serde(default)]
+ pub args: Vec<String>,
+ pub env: Option<HashMap<String, String>>,
+}
+
+impl AgentServerCommand {
+ pub(crate) async fn resolve(
+ path_bin_name: &'static str,
+ extra_args: &[&'static str],
+ settings: Option<AgentServerSettings>,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Option<Self> {
+ if let Some(agent_settings) = settings {
+ return Some(Self {
+ path: agent_settings.command.path,
+ args: agent_settings
+ .command
+ .args
+ .into_iter()
+ .chain(extra_args.iter().map(|arg| arg.to_string()))
+ .collect(),
+ env: agent_settings.command.env,
+ });
+ } else {
+ find_bin_in_path(path_bin_name, project, cx)
+ .await
+ .map(|path| Self {
+ path,
+ args: extra_args.iter().map(|arg| arg.to_string()).collect(),
+ env: None,
+ })
+ }
+ }
+}
+
+async fn find_bin_in_path(
+ bin_name: &'static str,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Option<PathBuf> {
+ let (env_task, root_dir) = project
+ .update(cx, |project, cx| {
+ let worktree = project.visible_worktrees(cx).next();
+ match worktree {
+ Some(worktree) => {
+ let env_task = project.environment().update(cx, |env, cx| {
+ env.get_worktree_environment(worktree.clone(), cx)
+ });
+
+ let path = worktree.read(cx).abs_path();
+ (env_task, path)
+ }
+ None => {
+ let path: Arc<Path> = paths::home_dir().as_path().into();
+ let env_task = project.environment().update(cx, |env, cx| {
+ env.get_directory_environment(path.clone(), cx)
+ });
+ (env_task, path)
+ }
+ }
+ })
+ .log_err()?;
+
+ cx.background_executor()
+ .spawn(async move {
+ let which_result = if cfg!(windows) {
+ which::which(bin_name)
+ } else {
+ let env = env_task.await.unwrap_or_default();
+ let shell_path = env.get("PATH").cloned();
+ which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
+ };
+
+ if let Err(which::Error::CannotFindBinaryPath) = which_result {
+ return None;
+ }
+
+ which_result.log_err()
+ })
+ .await
+}
@@ -3,6 +3,7 @@ mod tools;
use collections::HashMap;
use project::Project;
+use settings::SettingsStore;
use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
@@ -12,7 +13,7 @@ use agentic_coding_protocol::{
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
};
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use futures::channel::oneshot;
use futures::future::LocalBoxFuture;
use futures::{AsyncBufReadExt, AsyncWriteExt};
@@ -28,7 +29,7 @@ use util::ResultExt;
use crate::claude::mcp_server::ClaudeMcpServer;
use crate::claude::tools::ClaudeTool;
-use crate::{AgentServer, find_bin_in_path};
+use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
#[derive(Clone)]
@@ -87,31 +88,41 @@ impl AgentServer for ClaudeCode {
.await?;
mcp_config_file.flush().await?;
- let command = find_bin_in_path("claude", &project, cx)
- .await
- .context("Failed to find claude binary")?;
-
- let mut child = util::command::new_smol_command(&command)
- .args([
- "--input-format",
- "stream-json",
- "--output-format",
- "stream-json",
- "--print",
- "--verbose",
- "--mcp-config",
- mcp_config_path.to_string_lossy().as_ref(),
- "--permission-prompt-tool",
- &format!(
- "mcp__{}__{}",
- mcp_server::SERVER_NAME,
- mcp_server::PERMISSION_TOOL
- ),
- "--allowedTools",
- "mcp__zed__Read,mcp__zed__Edit",
- "--disallowedTools",
- "Read,Edit",
- ])
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).claude.clone()
+ })?;
+
+ let Some(command) =
+ AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
+ else {
+ anyhow::bail!("Failed to find claude binary");
+ };
+
+ let mut child = util::command::new_smol_command(&command.path)
+ .args(
+ [
+ "--input-format",
+ "stream-json",
+ "--output-format",
+ "stream-json",
+ "--print",
+ "--verbose",
+ "--mcp-config",
+ mcp_config_path.to_string_lossy().as_ref(),
+ "--permission-prompt-tool",
+ &format!(
+ "mcp__{}__{}",
+ mcp_server::SERVER_NAME,
+ mcp_server::PERMISSION_TOOL
+ ),
+ "--allowedTools",
+ "mcp__zed__Read,mcp__zed__Edit",
+ "--disallowedTools",
+ "Read,Edit",
+ ]
+ .into_iter()
+ .chain(command.args.iter().map(|arg| arg.as_str())),
+ )
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
@@ -562,10 +573,20 @@ struct McpServerConfig {
}
#[cfg(test)]
-mod tests {
+pub(crate) mod tests {
use super::*;
use serde_json::json;
+ // crate::common_e2e_tests!(ClaudeCode);
+
+ pub fn local_command() -> AgentServerCommand {
+ AgentServerCommand {
+ path: "claude".into(),
+ args: vec![],
+ env: None,
+ }
+ }
+
#[test]
fn test_deserialize_content_untagged_text() {
let json = json!("Hello, world!");
@@ -0,0 +1,368 @@
+use std::{path::Path, sync::Arc, time::Duration};
+
+use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
+use acp_thread::{
+ AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
+};
+use agentic_coding_protocol as acp;
+use futures::{FutureExt, StreamExt, channel::mpsc, select};
+use gpui::{Entity, TestAppContext};
+use indoc::indoc;
+use project::{FakeFs, Project};
+use serde_json::json;
+use settings::{Settings, SettingsStore};
+use util::path;
+
+pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+ let fs = init_test(cx).await;
+ let project = Project::test(fs, [], cx).await;
+ let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+
+ thread
+ .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
+ .await
+ .unwrap();
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.entries().len(), 2);
+ assert!(matches!(
+ thread.entries()[0],
+ AgentThreadEntry::UserMessage(_)
+ ));
+ assert!(matches!(
+ thread.entries()[1],
+ AgentThreadEntry::AssistantMessage(_)
+ ));
+ });
+}
+
+pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+ let _fs = init_test(cx).await;
+
+ let tempdir = tempfile::tempdir().unwrap();
+ std::fs::write(
+ tempdir.path().join("foo.rs"),
+ indoc! {"
+ fn main() {
+ println!(\"Hello, world!\");
+ }
+ "},
+ )
+ .expect("failed to write file");
+ let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
+ let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ acp::SendUserMessageParams {
+ chunks: vec![
+ acp::UserMessageChunk::Text {
+ text: "Read the file ".into(),
+ },
+ acp::UserMessageChunk::Path {
+ path: Path::new("foo.rs").into(),
+ },
+ acp::UserMessageChunk::Text {
+ text: " and tell me what the content of the println! is".into(),
+ },
+ ],
+ },
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(thread.entries().len(), 3);
+ assert!(matches!(
+ thread.entries()[0],
+ AgentThreadEntry::UserMessage(_)
+ ));
+ assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
+ let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
+ panic!("Expected AssistantMessage")
+ };
+ assert!(
+ assistant_message.to_markdown(cx).contains("Hello, world!"),
+ "unexpected assistant message: {:?}",
+ assistant_message.to_markdown(cx)
+ );
+ });
+}
+
+pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+ let fs = init_test(cx).await;
+ fs.insert_tree(
+ path!("/private/tmp"),
+ json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
+ )
+ .await;
+ let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+ let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.send_raw(
+ "Read the '/private/tmp/foo' file and tell me what you see.",
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, _cx| {
+ assert!(matches!(
+ &thread.entries()[2],
+ AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::Allowed { .. },
+ ..
+ })
+ ));
+
+ assert!(matches!(
+ thread.entries()[3],
+ AgentThreadEntry::AssistantMessage(_)
+ ));
+ });
+}
+
+pub async fn test_tool_call_with_confirmation(
+ server: impl AgentServer + 'static,
+ cx: &mut TestAppContext,
+) {
+ let fs = init_test(cx).await;
+ let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+ let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+ let full_turn = thread.update(cx, |thread, cx| {
+ thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+ });
+
+ run_until_first_tool_call(&thread, cx).await;
+
+ let tool_call_id = thread.read_with(cx, |thread, _cx| {
+ let AgentThreadEntry::ToolCall(ToolCall {
+ id,
+ status:
+ ToolCallStatus::WaitingForConfirmation {
+ confirmation: ToolCallConfirmation::Execute { root_command, .. },
+ ..
+ },
+ ..
+ }) = &thread.entries()[2]
+ else {
+ panic!();
+ };
+
+ assert_eq!(root_command, "echo");
+
+ *id
+ });
+
+ thread.update(cx, |thread, cx| {
+ thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
+
+ assert!(matches!(
+ &thread.entries()[2],
+ AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::Allowed { .. },
+ ..
+ })
+ ));
+ });
+
+ full_turn.await.unwrap();
+
+ thread.read_with(cx, |thread, cx| {
+ let AgentThreadEntry::ToolCall(ToolCall {
+ content: Some(ToolCallContent::Markdown { markdown }),
+ status: ToolCallStatus::Allowed { .. },
+ ..
+ }) = &thread.entries()[2]
+ else {
+ panic!();
+ };
+
+ markdown.read_with(cx, |md, _cx| {
+ assert!(
+ md.source().contains("Hello, world!"),
+ r#"Expected '{}' to contain "Hello, world!""#,
+ md.source()
+ );
+ });
+ });
+}
+
+pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+ let fs = init_test(cx).await;
+
+ let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+ let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+ let full_turn = thread.update(cx, |thread, cx| {
+ thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+ });
+
+ let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
+
+ thread.read_with(cx, |thread, _cx| {
+ let AgentThreadEntry::ToolCall(ToolCall {
+ id,
+ status:
+ ToolCallStatus::WaitingForConfirmation {
+ confirmation: ToolCallConfirmation::Execute { root_command, .. },
+ ..
+ },
+ ..
+ }) = &thread.entries()[first_tool_call_ix]
+ else {
+ panic!("{:?}", thread.entries()[1]);
+ };
+
+ assert_eq!(root_command, "echo");
+
+ *id
+ });
+
+ thread
+ .update(cx, |thread, cx| thread.cancel(cx))
+ .await
+ .unwrap();
+ full_turn.await.unwrap();
+ thread.read_with(cx, |thread, _| {
+ let AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::Canceled,
+ ..
+ }) = &thread.entries()[first_tool_call_ix]
+ else {
+ panic!();
+ };
+ });
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
+ })
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, _| {
+ assert!(matches!(
+ &thread.entries().last().unwrap(),
+ AgentThreadEntry::AssistantMessage(..),
+ ))
+ });
+}
+
+#[macro_export]
+macro_rules! common_e2e_tests {
+ ($server:expr) => {
+ mod common_e2e {
+ use super::*;
+
+ #[::gpui::test]
+ #[cfg_attr(not(feature = "e2e"), ignore)]
+ async fn basic(cx: &mut ::gpui::TestAppContext) {
+ $crate::e2e_tests::test_basic($server, cx).await;
+ }
+
+ #[::gpui::test]
+ #[cfg_attr(not(feature = "e2e"), ignore)]
+ async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
+ $crate::e2e_tests::test_path_mentions($server, cx).await;
+ }
+
+ #[::gpui::test]
+ #[cfg_attr(not(feature = "e2e"), ignore)]
+ async fn tool_call(cx: &mut ::gpui::TestAppContext) {
+ $crate::e2e_tests::test_tool_call($server, cx).await;
+ }
+
+ #[::gpui::test]
+ #[cfg_attr(not(feature = "e2e"), ignore)]
+ async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
+ $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
+ }
+
+ #[::gpui::test]
+ #[cfg_attr(not(feature = "e2e"), ignore)]
+ async fn cancel(cx: &mut ::gpui::TestAppContext) {
+ $crate::e2e_tests::test_cancel($server, cx).await;
+ }
+ }
+ };
+}
+
+// Helpers
+
+pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
+ env_logger::try_init().ok();
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+ crate::settings::init(cx);
+
+ crate::AllAgentServersSettings::override_global(
+ AllAgentServersSettings {
+ claude: Some(AgentServerSettings {
+ command: crate::claude::tests::local_command(),
+ }),
+ gemini: Some(AgentServerSettings {
+ command: crate::gemini::tests::local_command(),
+ }),
+ },
+ cx,
+ );
+ });
+
+ cx.executor().allow_parking();
+
+ FakeFs::new(cx.executor())
+}
+
+pub async fn new_test_thread(
+ server: impl AgentServer + 'static,
+ project: Entity<Project>,
+ current_dir: impl AsRef<Path>,
+ cx: &mut TestAppContext,
+) -> Entity<AcpThread> {
+ let thread = cx
+ .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
+ .await
+ .unwrap();
+
+ thread
+ .update(cx, |thread, _| thread.initialize())
+ .await
+ .unwrap();
+ thread
+}
+
+pub async fn run_until_first_tool_call(
+ thread: &Entity<AcpThread>,
+ cx: &mut TestAppContext,
+) -> usize {
+ let (mut tx, mut rx) = mpsc::channel::<usize>(1);
+
+ let subscription = cx.update(|cx| {
+ cx.subscribe(thread, move |thread, _, cx| {
+ for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
+ if matches!(entry, AgentThreadEntry::ToolCall(_)) {
+ return tx.try_send(ix).unwrap();
+ }
+ }
+ })
+ });
+
+ select! {
+ // We have to use a smol timer here because
+ // cx.background_executor().timer isn't real in the test context
+ _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+ panic!("Timeout waiting for tool call")
+ }
+ ix = rx.next().fuse() => {
+ drop(subscription);
+ ix.unwrap()
+ }
+ }
+}
@@ -1,4 +1,4 @@
-use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path};
+use crate::stdio_agent_server::StdioAgentServer;
use crate::{AgentServerCommand, AgentServerVersion};
use anyhow::{Context as _, Result};
use gpui::{AsyncApp, Entity};
@@ -38,35 +38,15 @@ impl StdioAgentServer for Gemini {
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<AgentServerCommand> {
- let custom_command = cx.read_global(|settings: &SettingsStore, _| {
- let settings = settings.get::<AllAgentServersSettings>(None);
- settings
- .gemini
- .as_ref()
- .map(|gemini_settings| AgentServerCommand {
- path: gemini_settings.command.path.clone(),
- args: gemini_settings
- .command
- .args
- .iter()
- .cloned()
- .chain(std::iter::once(ACP_ARG.into()))
- .collect(),
- env: gemini_settings.command.env.clone(),
- })
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
- if let Some(custom_command) = custom_command {
- return Ok(custom_command);
- }
-
- if let Some(path) = find_bin_in_path("gemini", project, cx).await {
- return Ok(AgentServerCommand {
- path,
- args: vec![ACP_ARG.into()],
- env: None,
- });
- }
+ if let Some(command) =
+ AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
+ {
+ return Ok(command);
+ };
let (fs, node_runtime) = project.update(cx, |project, _| {
(project.fs().clone(), project.node_runtime().cloned())
@@ -121,381 +101,23 @@ impl StdioAgentServer for Gemini {
}
#[cfg(test)]
-mod test {
- use std::{path::Path, time::Duration};
-
- use acp_thread::{
- AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
- ToolCallStatus,
- };
- use agentic_coding_protocol as acp;
- use anyhow::Result;
- use futures::{FutureExt, StreamExt, channel::mpsc, select};
- use gpui::{AsyncApp, Entity, TestAppContext};
- use indoc::indoc;
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer};
-
- pub async fn gemini_acp_thread(
- project: Entity<Project>,
- current_dir: impl AsRef<Path>,
- cx: &mut TestAppContext,
- ) -> Entity<AcpThread> {
- #[derive(Clone)]
- struct DevGemini;
-
- impl StdioAgentServer for DevGemini {
- async fn command(
- &self,
- _project: &Entity<Project>,
- _cx: &mut AsyncApp,
- ) -> Result<AgentServerCommand> {
- let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
- .join("../../../gemini-cli/packages/cli")
- .to_string_lossy()
- .to_string();
-
- Ok(AgentServerCommand {
- path: "node".into(),
- args: vec![cli_path, "--experimental-acp".into()],
- env: None,
- })
- }
-
- async fn version(&self, _command: &AgentServerCommand) -> Result<AgentServerVersion> {
- Ok(AgentServerVersion::Supported)
- }
-
- fn logo(&self) -> ui::IconName {
- ui::IconName::AiGemini
- }
-
- fn name(&self) -> &'static str {
- "test"
- }
-
- fn empty_state_headline(&self) -> &'static str {
- "test"
- }
-
- fn empty_state_message(&self) -> &'static str {
- "test"
- }
-
- fn supports_always_allow(&self) -> bool {
- true
- }
- }
-
- let thread = cx
- .update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx))
- .await
- .unwrap();
-
- thread
- .update(cx, |thread, _| thread.initialize())
- .await
- .unwrap();
- thread
- }
-
- fn init_test(cx: &mut TestAppContext) {
- env_logger::try_init().ok();
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- language::init(cx);
- });
- }
-
- #[gpui::test]
- #[cfg_attr(not(feature = "gemini"), ignore)]
- async fn test_gemini_basic(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.executor().allow_parking();
-
- let fs = FakeFs::new(cx.executor());
- let project = Project::test(fs, [], cx).await;
- let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
- thread
- .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
- .await
- .unwrap();
-
- thread.read_with(cx, |thread, _| {
- assert_eq!(thread.entries().len(), 2);
- assert!(matches!(
- thread.entries()[0],
- AgentThreadEntry::UserMessage(_)
- ));
- assert!(matches!(
- thread.entries()[1],
- AgentThreadEntry::AssistantMessage(_)
- ));
- });
- }
-
- #[gpui::test]
- #[cfg_attr(not(feature = "gemini"), ignore)]
- async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.executor().allow_parking();
- let tempdir = tempfile::tempdir().unwrap();
- std::fs::write(
- tempdir.path().join("foo.rs"),
- indoc! {"
- fn main() {
- println!(\"Hello, world!\");
- }
- "},
- )
- .expect("failed to write file");
- let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
- let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
- thread
- .update(cx, |thread, cx| {
- thread.send(
- acp::SendUserMessageParams {
- chunks: vec![
- acp::UserMessageChunk::Text {
- text: "Read the file ".into(),
- },
- acp::UserMessageChunk::Path {
- path: Path::new("foo.rs").into(),
- },
- acp::UserMessageChunk::Text {
- text: " and tell me what the content of the println! is".into(),
- },
- ],
- },
- cx,
- )
- })
- .await
- .unwrap();
-
- thread.read_with(cx, |thread, cx| {
- assert_eq!(thread.entries().len(), 3);
- assert!(matches!(
- thread.entries()[0],
- AgentThreadEntry::UserMessage(_)
- ));
- assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
- let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
- panic!("Expected AssistantMessage")
- };
- assert!(
- assistant_message.to_markdown(cx).contains("Hello, world!"),
- "unexpected assistant message: {:?}",
- assistant_message.to_markdown(cx)
- );
- });
- }
-
- #[gpui::test]
- #[cfg_attr(not(feature = "gemini"), ignore)]
- async fn test_gemini_tool_call(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.executor().allow_parking();
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/private/tmp"),
- json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
- )
- .await;
- let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
- let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
- thread
- .update(cx, |thread, cx| {
- thread.send_raw(
- "Read the '/private/tmp/foo' file and tell me what you see.",
- cx,
- )
- })
- .await
- .unwrap();
- thread.read_with(cx, |thread, _cx| {
- assert!(matches!(
- &thread.entries()[2],
- AgentThreadEntry::ToolCall(ToolCall {
- status: ToolCallStatus::Allowed { .. },
- ..
- })
- ));
-
- assert!(matches!(
- thread.entries()[3],
- AgentThreadEntry::AssistantMessage(_)
- ));
- });
- }
-
- #[gpui::test]
- #[cfg_attr(not(feature = "gemini"), ignore)]
- async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.executor().allow_parking();
-
- let fs = FakeFs::new(cx.executor());
- let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
- let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
- let full_turn = thread.update(cx, |thread, cx| {
- thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
- });
-
- run_until_first_tool_call(&thread, cx).await;
-
- let tool_call_id = thread.read_with(cx, |thread, _cx| {
- let AgentThreadEntry::ToolCall(ToolCall {
- id,
- status:
- ToolCallStatus::WaitingForConfirmation {
- confirmation: ToolCallConfirmation::Execute { root_command, .. },
- ..
- },
- ..
- }) = &thread.entries()[2]
- else {
- panic!();
- };
-
- assert_eq!(root_command, "echo");
-
- *id
- });
-
- thread.update(cx, |thread, cx| {
- thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
-
- assert!(matches!(
- &thread.entries()[2],
- AgentThreadEntry::ToolCall(ToolCall {
- status: ToolCallStatus::Allowed { .. },
- ..
- })
- ));
- });
-
- full_turn.await.unwrap();
-
- thread.read_with(cx, |thread, cx| {
- let AgentThreadEntry::ToolCall(ToolCall {
- content: Some(ToolCallContent::Markdown { markdown }),
- status: ToolCallStatus::Allowed { .. },
- ..
- }) = &thread.entries()[2]
- else {
- panic!();
- };
-
- markdown.read_with(cx, |md, _cx| {
- assert!(
- md.source().contains("Hello, world!"),
- r#"Expected '{}' to contain "Hello, world!""#,
- md.source()
- );
- });
- });
- }
-
- #[gpui::test]
- #[cfg_attr(not(feature = "gemini"), ignore)]
- async fn test_gemini_cancel(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.executor().allow_parking();
-
- let fs = FakeFs::new(cx.executor());
- let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
- let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
- let full_turn = thread.update(cx, |thread, cx| {
- thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
- });
-
- let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
-
- thread.read_with(cx, |thread, _cx| {
- let AgentThreadEntry::ToolCall(ToolCall {
- id,
- status:
- ToolCallStatus::WaitingForConfirmation {
- confirmation: ToolCallConfirmation::Execute { root_command, .. },
- ..
- },
- ..
- }) = &thread.entries()[first_tool_call_ix]
- else {
- panic!("{:?}", thread.entries()[1]);
- };
-
- assert_eq!(root_command, "echo");
-
- *id
- });
-
- thread
- .update(cx, |thread, cx| thread.cancel(cx))
- .await
- .unwrap();
- full_turn.await.unwrap();
- thread.read_with(cx, |thread, _| {
- let AgentThreadEntry::ToolCall(ToolCall {
- status: ToolCallStatus::Canceled,
- ..
- }) = &thread.entries()[first_tool_call_ix]
- else {
- panic!();
- };
- });
-
- thread
- .update(cx, |thread, cx| {
- thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
- })
- .await
- .unwrap();
- thread.read_with(cx, |thread, _| {
- assert!(matches!(
- &thread.entries().last().unwrap(),
- AgentThreadEntry::AssistantMessage(..),
- ))
- });
- }
-
- async fn run_until_first_tool_call(
- thread: &Entity<AcpThread>,
- cx: &mut TestAppContext,
- ) -> usize {
- let (mut tx, mut rx) = mpsc::channel::<usize>(1);
-
- let subscription = cx.update(|cx| {
- cx.subscribe(thread, move |thread, _, cx| {
- for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
- if matches!(entry, AgentThreadEntry::ToolCall(_)) {
- return tx.try_send(ix).unwrap();
- }
- }
- })
- });
-
- select! {
- _ = cx.executor().timer(Duration::from_secs(10)).fuse() => {
- panic!("Timeout waiting for tool call")
- }
- ix = rx.next().fuse() => {
- drop(subscription);
- ix.unwrap()
- }
+pub(crate) mod tests {
+ use super::*;
+ use crate::AgentServerCommand;
+ use std::path::Path;
+
+ crate::common_e2e_tests!(Gemini);
+
+ pub fn local_command() -> AgentServerCommand {
+ let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
+ .join("../../../gemini-cli/packages/cli")
+ .to_string_lossy()
+ .to_string();
+
+ AgentServerCommand {
+ path: "node".into(),
+ args: vec![cli_path, ACP_ARG.into()],
+ env: None,
}
}
}
@@ -12,6 +12,7 @@ pub fn init(cx: &mut App) {
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>,
+ pub claude: Option<AgentServerSettings>,
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@@ -4,11 +4,8 @@ use agentic_coding_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
use project::Project;
-use std::{
- path::{Path, PathBuf},
- sync::Arc,
-};
-use util::{ResultExt, paths};
+use std::path::Path;
+use util::ResultExt;
pub trait StdioAgentServer: Send + Clone {
fn logo(&self) -> ui::IconName;
@@ -120,50 +117,3 @@ impl<T: StdioAgentServer + 'static> AgentServer for T {
})
}
}
-
-pub async fn find_bin_in_path(
- bin_name: &'static str,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
-) -> Option<PathBuf> {
- let (env_task, root_dir) = project
- .update(cx, |project, cx| {
- let worktree = project.visible_worktrees(cx).next();
- match worktree {
- Some(worktree) => {
- let env_task = project.environment().update(cx, |env, cx| {
- env.get_worktree_environment(worktree.clone(), cx)
- });
-
- let path = worktree.read(cx).abs_path();
- (env_task, path)
- }
- None => {
- let path: Arc<Path> = paths::home_dir().as_path().into();
- let env_task = project.environment().update(cx, |env, cx| {
- env.get_directory_environment(path.clone(), cx)
- });
- (env_task, path)
- }
- }
- })
- .log_err()?;
-
- cx.background_executor()
- .spawn(async move {
- let which_result = if cfg!(windows) {
- which::which(bin_name)
- } else {
- let env = env_task.await.unwrap_or_default();
- let shell_path = env.get("PATH").cloned();
- which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
- };
-
- if let Err(which::Error::CannotFindBinaryPath) = which_result {
- return None;
- }
-
- which_result.log_err()
- })
- .await
-}