Unify agent server settings and extract e2e tests out (#34642)

Agus Zubiaga created

Release Notes:

- N/A

Change summary

crates/agent_servers/Cargo.toml                |   2 
crates/agent_servers/src/agent_servers.rs      | 125 ++++
crates/agent_servers/src/claude.rs             |  77 ++-
crates/agent_servers/src/e2e_tests.rs          | 368 +++++++++++++++++
crates/agent_servers/src/gemini.rs             | 428 +------------------
crates/agent_servers/src/settings.rs           |   1 
crates/agent_servers/src/stdio_agent_server.rs |  54 --
7 files changed, 551 insertions(+), 504 deletions(-)

Detailed changes

crates/agent_servers/Cargo.toml 🔗

@@ -7,7 +7,7 @@ license = "GPL-3.0-or-later"
 
 [features]
 test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"]
-gemini = []
+e2e = []
 
 [lints]
 workspace = true

crates/agent_servers/src/agent_servers.rs 🔗

@@ -3,6 +3,9 @@ mod gemini;
 mod settings;
 mod stdio_agent_server;
 
+#[cfg(test)]
+mod e2e_tests;
+
 pub use claude::*;
 pub use gemini::*;
 pub use settings::*;
@@ -11,34 +14,20 @@ pub use stdio_agent_server::*;
 use acp_thread::AcpThread;
 use anyhow::Result;
 use collections::HashMap;
-use gpui::{App, Entity, SharedString, Task};
+use gpui::{App, AsyncApp, Entity, SharedString, Task};
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::path::{Path, PathBuf};
+use std::{
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+use util::ResultExt as _;
 
 pub fn init(cx: &mut App) {
     settings::init(cx);
 }
 
-#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
-pub struct AgentServerCommand {
-    #[serde(rename = "command")]
-    pub path: PathBuf,
-    #[serde(default)]
-    pub args: Vec<String>,
-    pub env: Option<HashMap<String, String>>,
-}
-
-pub enum AgentServerVersion {
-    Supported,
-    Unsupported {
-        error_message: SharedString,
-        upgrade_message: SharedString,
-        upgrade_command: String,
-    },
-}
-
 pub trait AgentServer: Send {
     fn logo(&self) -> ui::IconName;
     fn name(&self) -> &'static str;
@@ -78,3 +67,99 @@ impl std::fmt::Debug for AgentServerCommand {
             .finish()
     }
 }
+
+pub enum AgentServerVersion {
+    Supported,
+    Unsupported {
+        error_message: SharedString,
+        upgrade_message: SharedString,
+        upgrade_command: String,
+    },
+}
+
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
+pub struct AgentServerCommand {
+    #[serde(rename = "command")]
+    pub path: PathBuf,
+    #[serde(default)]
+    pub args: Vec<String>,
+    pub env: Option<HashMap<String, String>>,
+}
+
+impl AgentServerCommand {
+    pub(crate) async fn resolve(
+        path_bin_name: &'static str,
+        extra_args: &[&'static str],
+        settings: Option<AgentServerSettings>,
+        project: &Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Option<Self> {
+        if let Some(agent_settings) = settings {
+            return Some(Self {
+                path: agent_settings.command.path,
+                args: agent_settings
+                    .command
+                    .args
+                    .into_iter()
+                    .chain(extra_args.iter().map(|arg| arg.to_string()))
+                    .collect(),
+                env: agent_settings.command.env,
+            });
+        } else {
+            find_bin_in_path(path_bin_name, project, cx)
+                .await
+                .map(|path| Self {
+                    path,
+                    args: extra_args.iter().map(|arg| arg.to_string()).collect(),
+                    env: None,
+                })
+        }
+    }
+}
+
+async fn find_bin_in_path(
+    bin_name: &'static str,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Option<PathBuf> {
+    let (env_task, root_dir) = project
+        .update(cx, |project, cx| {
+            let worktree = project.visible_worktrees(cx).next();
+            match worktree {
+                Some(worktree) => {
+                    let env_task = project.environment().update(cx, |env, cx| {
+                        env.get_worktree_environment(worktree.clone(), cx)
+                    });
+
+                    let path = worktree.read(cx).abs_path();
+                    (env_task, path)
+                }
+                None => {
+                    let path: Arc<Path> = paths::home_dir().as_path().into();
+                    let env_task = project.environment().update(cx, |env, cx| {
+                        env.get_directory_environment(path.clone(), cx)
+                    });
+                    (env_task, path)
+                }
+            }
+        })
+        .log_err()?;
+
+    cx.background_executor()
+        .spawn(async move {
+            let which_result = if cfg!(windows) {
+                which::which(bin_name)
+            } else {
+                let env = env_task.await.unwrap_or_default();
+                let shell_path = env.get("PATH").cloned();
+                which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
+            };
+
+            if let Err(which::Error::CannotFindBinaryPath) = which_result {
+                return None;
+            }
+
+            which_result.log_err()
+        })
+        .await
+}

crates/agent_servers/src/claude.rs 🔗

@@ -3,6 +3,7 @@ mod tools;
 
 use collections::HashMap;
 use project::Project;
+use settings::SettingsStore;
 use std::cell::RefCell;
 use std::fmt::Display;
 use std::path::Path;
@@ -12,7 +13,7 @@ use agentic_coding_protocol::{
     self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
     StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
 };
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
 use futures::channel::oneshot;
 use futures::future::LocalBoxFuture;
 use futures::{AsyncBufReadExt, AsyncWriteExt};
@@ -28,7 +29,7 @@ use util::ResultExt;
 
 use crate::claude::mcp_server::ClaudeMcpServer;
 use crate::claude::tools::ClaudeTool;
-use crate::{AgentServer, find_bin_in_path};
+use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 
 #[derive(Clone)]
@@ -87,31 +88,41 @@ impl AgentServer for ClaudeCode {
                 .await?;
             mcp_config_file.flush().await?;
 
-            let command = find_bin_in_path("claude", &project, cx)
-                .await
-                .context("Failed to find claude binary")?;
-
-            let mut child = util::command::new_smol_command(&command)
-                .args([
-                    "--input-format",
-                    "stream-json",
-                    "--output-format",
-                    "stream-json",
-                    "--print",
-                    "--verbose",
-                    "--mcp-config",
-                    mcp_config_path.to_string_lossy().as_ref(),
-                    "--permission-prompt-tool",
-                    &format!(
-                        "mcp__{}__{}",
-                        mcp_server::SERVER_NAME,
-                        mcp_server::PERMISSION_TOOL
-                    ),
-                    "--allowedTools",
-                    "mcp__zed__Read,mcp__zed__Edit",
-                    "--disallowedTools",
-                    "Read,Edit",
-                ])
+            let settings = cx.read_global(|settings: &SettingsStore, _| {
+                settings.get::<AllAgentServersSettings>(None).claude.clone()
+            })?;
+
+            let Some(command) =
+                AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
+            else {
+                anyhow::bail!("Failed to find claude binary");
+            };
+
+            let mut child = util::command::new_smol_command(&command.path)
+                .args(
+                    [
+                        "--input-format",
+                        "stream-json",
+                        "--output-format",
+                        "stream-json",
+                        "--print",
+                        "--verbose",
+                        "--mcp-config",
+                        mcp_config_path.to_string_lossy().as_ref(),
+                        "--permission-prompt-tool",
+                        &format!(
+                            "mcp__{}__{}",
+                            mcp_server::SERVER_NAME,
+                            mcp_server::PERMISSION_TOOL
+                        ),
+                        "--allowedTools",
+                        "mcp__zed__Read,mcp__zed__Edit",
+                        "--disallowedTools",
+                        "Read,Edit",
+                    ]
+                    .into_iter()
+                    .chain(command.args.iter().map(|arg| arg.as_str())),
+                )
                 .current_dir(root_dir)
                 .stdin(std::process::Stdio::piped())
                 .stdout(std::process::Stdio::piped())
@@ -562,10 +573,20 @@ struct McpServerConfig {
 }
 
 #[cfg(test)]
-mod tests {
+pub(crate) mod tests {
     use super::*;
     use serde_json::json;
 
+    // crate::common_e2e_tests!(ClaudeCode);
+
+    pub fn local_command() -> AgentServerCommand {
+        AgentServerCommand {
+            path: "claude".into(),
+            args: vec![],
+            env: None,
+        }
+    }
+
     #[test]
     fn test_deserialize_content_untagged_text() {
         let json = json!("Hello, world!");

crates/agent_servers/src/e2e_tests.rs 🔗

@@ -0,0 +1,368 @@
+use std::{path::Path, sync::Arc, time::Duration};
+
+use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
+use acp_thread::{
+    AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
+};
+use agentic_coding_protocol as acp;
+use futures::{FutureExt, StreamExt, channel::mpsc, select};
+use gpui::{Entity, TestAppContext};
+use indoc::indoc;
+use project::{FakeFs, Project};
+use serde_json::json;
+use settings::{Settings, SettingsStore};
+use util::path;
+
+pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+    let fs = init_test(cx).await;
+    let project = Project::test(fs, [], cx).await;
+    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+
+    thread
+        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
+        .await
+        .unwrap();
+
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(thread.entries().len(), 2);
+        assert!(matches!(
+            thread.entries()[0],
+            AgentThreadEntry::UserMessage(_)
+        ));
+        assert!(matches!(
+            thread.entries()[1],
+            AgentThreadEntry::AssistantMessage(_)
+        ));
+    });
+}
+
+pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+    let _fs = init_test(cx).await;
+
+    let tempdir = tempfile::tempdir().unwrap();
+    std::fs::write(
+        tempdir.path().join("foo.rs"),
+        indoc! {"
+            fn main() {
+                println!(\"Hello, world!\");
+            }
+        "},
+    )
+    .expect("failed to write file");
+    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
+    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                acp::SendUserMessageParams {
+                    chunks: vec![
+                        acp::UserMessageChunk::Text {
+                            text: "Read the file ".into(),
+                        },
+                        acp::UserMessageChunk::Path {
+                            path: Path::new("foo.rs").into(),
+                        },
+                        acp::UserMessageChunk::Text {
+                            text: " and tell me what the content of the println! is".into(),
+                        },
+                    ],
+                },
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    thread.read_with(cx, |thread, cx| {
+        assert_eq!(thread.entries().len(), 3);
+        assert!(matches!(
+            thread.entries()[0],
+            AgentThreadEntry::UserMessage(_)
+        ));
+        assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
+        let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
+            panic!("Expected AssistantMessage")
+        };
+        assert!(
+            assistant_message.to_markdown(cx).contains("Hello, world!"),
+            "unexpected assistant message: {:?}",
+            assistant_message.to_markdown(cx)
+        );
+    });
+}
+
+pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+    let fs = init_test(cx).await;
+    fs.insert_tree(
+        path!("/private/tmp"),
+        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
+    )
+    .await;
+    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+
+    thread
+        .update(cx, |thread, cx| {
+            thread.send_raw(
+                "Read the '/private/tmp/foo' file and tell me what you see.",
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+    thread.read_with(cx, |thread, _cx| {
+        assert!(matches!(
+            &thread.entries()[2],
+            AgentThreadEntry::ToolCall(ToolCall {
+                status: ToolCallStatus::Allowed { .. },
+                ..
+            })
+        ));
+
+        assert!(matches!(
+            thread.entries()[3],
+            AgentThreadEntry::AssistantMessage(_)
+        ));
+    });
+}
+
+pub async fn test_tool_call_with_confirmation(
+    server: impl AgentServer + 'static,
+    cx: &mut TestAppContext,
+) {
+    let fs = init_test(cx).await;
+    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+    let full_turn = thread.update(cx, |thread, cx| {
+        thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+    });
+
+    run_until_first_tool_call(&thread, cx).await;
+
+    let tool_call_id = thread.read_with(cx, |thread, _cx| {
+        let AgentThreadEntry::ToolCall(ToolCall {
+            id,
+            status:
+                ToolCallStatus::WaitingForConfirmation {
+                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
+                    ..
+                },
+            ..
+        }) = &thread.entries()[2]
+        else {
+            panic!();
+        };
+
+        assert_eq!(root_command, "echo");
+
+        *id
+    });
+
+    thread.update(cx, |thread, cx| {
+        thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
+
+        assert!(matches!(
+            &thread.entries()[2],
+            AgentThreadEntry::ToolCall(ToolCall {
+                status: ToolCallStatus::Allowed { .. },
+                ..
+            })
+        ));
+    });
+
+    full_turn.await.unwrap();
+
+    thread.read_with(cx, |thread, cx| {
+        let AgentThreadEntry::ToolCall(ToolCall {
+            content: Some(ToolCallContent::Markdown { markdown }),
+            status: ToolCallStatus::Allowed { .. },
+            ..
+        }) = &thread.entries()[2]
+        else {
+            panic!();
+        };
+
+        markdown.read_with(cx, |md, _cx| {
+            assert!(
+                md.source().contains("Hello, world!"),
+                r#"Expected '{}' to contain "Hello, world!""#,
+                md.source()
+            );
+        });
+    });
+}
+
+pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
+    let fs = init_test(cx).await;
+
+    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
+    let full_turn = thread.update(cx, |thread, cx| {
+        thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+    });
+
+    let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
+
+    thread.read_with(cx, |thread, _cx| {
+        let AgentThreadEntry::ToolCall(ToolCall {
+            id,
+            status:
+                ToolCallStatus::WaitingForConfirmation {
+                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
+                    ..
+                },
+            ..
+        }) = &thread.entries()[first_tool_call_ix]
+        else {
+            panic!("{:?}", thread.entries()[1]);
+        };
+
+        assert_eq!(root_command, "echo");
+
+        *id
+    });
+
+    thread
+        .update(cx, |thread, cx| thread.cancel(cx))
+        .await
+        .unwrap();
+    full_turn.await.unwrap();
+    thread.read_with(cx, |thread, _| {
+        let AgentThreadEntry::ToolCall(ToolCall {
+            status: ToolCallStatus::Canceled,
+            ..
+        }) = &thread.entries()[first_tool_call_ix]
+        else {
+            panic!();
+        };
+    });
+
+    thread
+        .update(cx, |thread, cx| {
+            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
+        })
+        .await
+        .unwrap();
+    thread.read_with(cx, |thread, _| {
+        assert!(matches!(
+            &thread.entries().last().unwrap(),
+            AgentThreadEntry::AssistantMessage(..),
+        ))
+    });
+}
+
+#[macro_export]
+macro_rules! common_e2e_tests {
+    ($server:expr) => {
+        mod common_e2e {
+            use super::*;
+
+            #[::gpui::test]
+            #[cfg_attr(not(feature = "e2e"), ignore)]
+            async fn basic(cx: &mut ::gpui::TestAppContext) {
+                $crate::e2e_tests::test_basic($server, cx).await;
+            }
+
+            #[::gpui::test]
+            #[cfg_attr(not(feature = "e2e"), ignore)]
+            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
+                $crate::e2e_tests::test_path_mentions($server, cx).await;
+            }
+
+            #[::gpui::test]
+            #[cfg_attr(not(feature = "e2e"), ignore)]
+            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
+                $crate::e2e_tests::test_tool_call($server, cx).await;
+            }
+
+            #[::gpui::test]
+            #[cfg_attr(not(feature = "e2e"), ignore)]
+            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
+                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
+            }
+
+            #[::gpui::test]
+            #[cfg_attr(not(feature = "e2e"), ignore)]
+            async fn cancel(cx: &mut ::gpui::TestAppContext) {
+                $crate::e2e_tests::test_cancel($server, cx).await;
+            }
+        }
+    };
+}
+
+// Helpers
+
+pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
+    env_logger::try_init().ok();
+
+    cx.update(|cx| {
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
+        Project::init_settings(cx);
+        language::init(cx);
+        crate::settings::init(cx);
+
+        crate::AllAgentServersSettings::override_global(
+            AllAgentServersSettings {
+                claude: Some(AgentServerSettings {
+                    command: crate::claude::tests::local_command(),
+                }),
+                gemini: Some(AgentServerSettings {
+                    command: crate::gemini::tests::local_command(),
+                }),
+            },
+            cx,
+        );
+    });
+
+    cx.executor().allow_parking();
+
+    FakeFs::new(cx.executor())
+}
+
+pub async fn new_test_thread(
+    server: impl AgentServer + 'static,
+    project: Entity<Project>,
+    current_dir: impl AsRef<Path>,
+    cx: &mut TestAppContext,
+) -> Entity<AcpThread> {
+    let thread = cx
+        .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
+        .await
+        .unwrap();
+
+    thread
+        .update(cx, |thread, _| thread.initialize())
+        .await
+        .unwrap();
+    thread
+}
+
+pub async fn run_until_first_tool_call(
+    thread: &Entity<AcpThread>,
+    cx: &mut TestAppContext,
+) -> usize {
+    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
+
+    let subscription = cx.update(|cx| {
+        cx.subscribe(thread, move |thread, _, cx| {
+            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
+                if matches!(entry, AgentThreadEntry::ToolCall(_)) {
+                    return tx.try_send(ix).unwrap();
+                }
+            }
+        })
+    });
+
+    select! {
+        // We have to use a smol timer here because
+        // cx.background_executor().timer isn't real in the test context
+        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+            panic!("Timeout waiting for tool call")
+        }
+        ix = rx.next().fuse() => {
+            drop(subscription);
+            ix.unwrap()
+        }
+    }
+}

crates/agent_servers/src/gemini.rs 🔗

@@ -1,4 +1,4 @@
-use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path};
+use crate::stdio_agent_server::StdioAgentServer;
 use crate::{AgentServerCommand, AgentServerVersion};
 use anyhow::{Context as _, Result};
 use gpui::{AsyncApp, Entity};
@@ -38,35 +38,15 @@ impl StdioAgentServer for Gemini {
         project: &Entity<Project>,
         cx: &mut AsyncApp,
     ) -> Result<AgentServerCommand> {
-        let custom_command = cx.read_global(|settings: &SettingsStore, _| {
-            let settings = settings.get::<AllAgentServersSettings>(None);
-            settings
-                .gemini
-                .as_ref()
-                .map(|gemini_settings| AgentServerCommand {
-                    path: gemini_settings.command.path.clone(),
-                    args: gemini_settings
-                        .command
-                        .args
-                        .iter()
-                        .cloned()
-                        .chain(std::iter::once(ACP_ARG.into()))
-                        .collect(),
-                    env: gemini_settings.command.env.clone(),
-                })
+        let settings = cx.read_global(|settings: &SettingsStore, _| {
+            settings.get::<AllAgentServersSettings>(None).gemini.clone()
         })?;
 
-        if let Some(custom_command) = custom_command {
-            return Ok(custom_command);
-        }
-
-        if let Some(path) = find_bin_in_path("gemini", project, cx).await {
-            return Ok(AgentServerCommand {
-                path,
-                args: vec![ACP_ARG.into()],
-                env: None,
-            });
-        }
+        if let Some(command) =
+            AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
+        {
+            return Ok(command);
+        };
 
         let (fs, node_runtime) = project.update(cx, |project, _| {
             (project.fs().clone(), project.node_runtime().cloned())
@@ -121,381 +101,23 @@ impl StdioAgentServer for Gemini {
 }
 
 #[cfg(test)]
-mod test {
-    use std::{path::Path, time::Duration};
-
-    use acp_thread::{
-        AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
-        ToolCallStatus,
-    };
-    use agentic_coding_protocol as acp;
-    use anyhow::Result;
-    use futures::{FutureExt, StreamExt, channel::mpsc, select};
-    use gpui::{AsyncApp, Entity, TestAppContext};
-    use indoc::indoc;
-    use project::{FakeFs, Project};
-    use serde_json::json;
-    use settings::SettingsStore;
-    use util::path;
-
-    use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer};
-
-    pub async fn gemini_acp_thread(
-        project: Entity<Project>,
-        current_dir: impl AsRef<Path>,
-        cx: &mut TestAppContext,
-    ) -> Entity<AcpThread> {
-        #[derive(Clone)]
-        struct DevGemini;
-
-        impl StdioAgentServer for DevGemini {
-            async fn command(
-                &self,
-                _project: &Entity<Project>,
-                _cx: &mut AsyncApp,
-            ) -> Result<AgentServerCommand> {
-                let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
-                    .join("../../../gemini-cli/packages/cli")
-                    .to_string_lossy()
-                    .to_string();
-
-                Ok(AgentServerCommand {
-                    path: "node".into(),
-                    args: vec![cli_path, "--experimental-acp".into()],
-                    env: None,
-                })
-            }
-
-            async fn version(&self, _command: &AgentServerCommand) -> Result<AgentServerVersion> {
-                Ok(AgentServerVersion::Supported)
-            }
-
-            fn logo(&self) -> ui::IconName {
-                ui::IconName::AiGemini
-            }
-
-            fn name(&self) -> &'static str {
-                "test"
-            }
-
-            fn empty_state_headline(&self) -> &'static str {
-                "test"
-            }
-
-            fn empty_state_message(&self) -> &'static str {
-                "test"
-            }
-
-            fn supports_always_allow(&self) -> bool {
-                true
-            }
-        }
-
-        let thread = cx
-            .update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx))
-            .await
-            .unwrap();
-
-        thread
-            .update(cx, |thread, _| thread.initialize())
-            .await
-            .unwrap();
-        thread
-    }
-
-    fn init_test(cx: &mut TestAppContext) {
-        env_logger::try_init().ok();
-        cx.update(|cx| {
-            let settings_store = SettingsStore::test(cx);
-            cx.set_global(settings_store);
-            Project::init_settings(cx);
-            language::init(cx);
-        });
-    }
-
-    #[gpui::test]
-    #[cfg_attr(not(feature = "gemini"), ignore)]
-    async fn test_gemini_basic(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        cx.executor().allow_parking();
-
-        let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs, [], cx).await;
-        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
-        thread
-            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
-            .await
-            .unwrap();
-
-        thread.read_with(cx, |thread, _| {
-            assert_eq!(thread.entries().len(), 2);
-            assert!(matches!(
-                thread.entries()[0],
-                AgentThreadEntry::UserMessage(_)
-            ));
-            assert!(matches!(
-                thread.entries()[1],
-                AgentThreadEntry::AssistantMessage(_)
-            ));
-        });
-    }
-
-    #[gpui::test]
-    #[cfg_attr(not(feature = "gemini"), ignore)]
-    async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        cx.executor().allow_parking();
-        let tempdir = tempfile::tempdir().unwrap();
-        std::fs::write(
-            tempdir.path().join("foo.rs"),
-            indoc! {"
-                fn main() {
-                    println!(\"Hello, world!\");
-                }
-            "},
-        )
-        .expect("failed to write file");
-        let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
-        let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
-        thread
-            .update(cx, |thread, cx| {
-                thread.send(
-                    acp::SendUserMessageParams {
-                        chunks: vec![
-                            acp::UserMessageChunk::Text {
-                                text: "Read the file ".into(),
-                            },
-                            acp::UserMessageChunk::Path {
-                                path: Path::new("foo.rs").into(),
-                            },
-                            acp::UserMessageChunk::Text {
-                                text: " and tell me what the content of the println! is".into(),
-                            },
-                        ],
-                    },
-                    cx,
-                )
-            })
-            .await
-            .unwrap();
-
-        thread.read_with(cx, |thread, cx| {
-            assert_eq!(thread.entries().len(), 3);
-            assert!(matches!(
-                thread.entries()[0],
-                AgentThreadEntry::UserMessage(_)
-            ));
-            assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
-            let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
-                panic!("Expected AssistantMessage")
-            };
-            assert!(
-                assistant_message.to_markdown(cx).contains("Hello, world!"),
-                "unexpected assistant message: {:?}",
-                assistant_message.to_markdown(cx)
-            );
-        });
-    }
-
-    #[gpui::test]
-    #[cfg_attr(not(feature = "gemini"), ignore)]
-    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        cx.executor().allow_parking();
-
-        let fs = FakeFs::new(cx.executor());
-        fs.insert_tree(
-            path!("/private/tmp"),
-            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
-        )
-        .await;
-        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
-        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
-        thread
-            .update(cx, |thread, cx| {
-                thread.send_raw(
-                    "Read the '/private/tmp/foo' file and tell me what you see.",
-                    cx,
-                )
-            })
-            .await
-            .unwrap();
-        thread.read_with(cx, |thread, _cx| {
-            assert!(matches!(
-                &thread.entries()[2],
-                AgentThreadEntry::ToolCall(ToolCall {
-                    status: ToolCallStatus::Allowed { .. },
-                    ..
-                })
-            ));
-
-            assert!(matches!(
-                thread.entries()[3],
-                AgentThreadEntry::AssistantMessage(_)
-            ));
-        });
-    }
-
-    #[gpui::test]
-    #[cfg_attr(not(feature = "gemini"), ignore)]
-    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        cx.executor().allow_parking();
-
-        let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
-        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
-        let full_turn = thread.update(cx, |thread, cx| {
-            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
-        });
-
-        run_until_first_tool_call(&thread, cx).await;
-
-        let tool_call_id = thread.read_with(cx, |thread, _cx| {
-            let AgentThreadEntry::ToolCall(ToolCall {
-                id,
-                status:
-                    ToolCallStatus::WaitingForConfirmation {
-                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
-                        ..
-                    },
-                ..
-            }) = &thread.entries()[2]
-            else {
-                panic!();
-            };
-
-            assert_eq!(root_command, "echo");
-
-            *id
-        });
-
-        thread.update(cx, |thread, cx| {
-            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
-
-            assert!(matches!(
-                &thread.entries()[2],
-                AgentThreadEntry::ToolCall(ToolCall {
-                    status: ToolCallStatus::Allowed { .. },
-                    ..
-                })
-            ));
-        });
-
-        full_turn.await.unwrap();
-
-        thread.read_with(cx, |thread, cx| {
-            let AgentThreadEntry::ToolCall(ToolCall {
-                content: Some(ToolCallContent::Markdown { markdown }),
-                status: ToolCallStatus::Allowed { .. },
-                ..
-            }) = &thread.entries()[2]
-            else {
-                panic!();
-            };
-
-            markdown.read_with(cx, |md, _cx| {
-                assert!(
-                    md.source().contains("Hello, world!"),
-                    r#"Expected '{}' to contain "Hello, world!""#,
-                    md.source()
-                );
-            });
-        });
-    }
-
-    #[gpui::test]
-    #[cfg_attr(not(feature = "gemini"), ignore)]
-    async fn test_gemini_cancel(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        cx.executor().allow_parking();
-
-        let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
-        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
-        let full_turn = thread.update(cx, |thread, cx| {
-            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
-        });
-
-        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
-
-        thread.read_with(cx, |thread, _cx| {
-            let AgentThreadEntry::ToolCall(ToolCall {
-                id,
-                status:
-                    ToolCallStatus::WaitingForConfirmation {
-                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
-                        ..
-                    },
-                ..
-            }) = &thread.entries()[first_tool_call_ix]
-            else {
-                panic!("{:?}", thread.entries()[1]);
-            };
-
-            assert_eq!(root_command, "echo");
-
-            *id
-        });
-
-        thread
-            .update(cx, |thread, cx| thread.cancel(cx))
-            .await
-            .unwrap();
-        full_turn.await.unwrap();
-        thread.read_with(cx, |thread, _| {
-            let AgentThreadEntry::ToolCall(ToolCall {
-                status: ToolCallStatus::Canceled,
-                ..
-            }) = &thread.entries()[first_tool_call_ix]
-            else {
-                panic!();
-            };
-        });
-
-        thread
-            .update(cx, |thread, cx| {
-                thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
-            })
-            .await
-            .unwrap();
-        thread.read_with(cx, |thread, _| {
-            assert!(matches!(
-                &thread.entries().last().unwrap(),
-                AgentThreadEntry::AssistantMessage(..),
-            ))
-        });
-    }
-
-    async fn run_until_first_tool_call(
-        thread: &Entity<AcpThread>,
-        cx: &mut TestAppContext,
-    ) -> usize {
-        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
-
-        let subscription = cx.update(|cx| {
-            cx.subscribe(thread, move |thread, _, cx| {
-                for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
-                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
-                        return tx.try_send(ix).unwrap();
-                    }
-                }
-            })
-        });
-
-        select! {
-            _ =  cx.executor().timer(Duration::from_secs(10)).fuse() => {
-                panic!("Timeout waiting for tool call")
-            }
-            ix = rx.next().fuse() => {
-                drop(subscription);
-                ix.unwrap()
-            }
+pub(crate) mod tests {
+    use super::*;
+    use crate::AgentServerCommand;
+    use std::path::Path;
+
+    crate::common_e2e_tests!(Gemini);
+
+    pub fn local_command() -> AgentServerCommand {
+        let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
+            .join("../../../gemini-cli/packages/cli")
+            .to_string_lossy()
+            .to_string();
+
+        AgentServerCommand {
+            path: "node".into(),
+            args: vec![cli_path, ACP_ARG.into()],
+            env: None,
         }
     }
 }

crates/agent_servers/src/settings.rs 🔗

@@ -12,6 +12,7 @@ pub fn init(cx: &mut App) {
 #[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
 pub struct AllAgentServersSettings {
     pub gemini: Option<AgentServerSettings>,
+    pub claude: Option<AgentServerSettings>,
 }
 
 #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]

crates/agent_servers/src/stdio_agent_server.rs 🔗

@@ -4,11 +4,8 @@ use agentic_coding_protocol as acp;
 use anyhow::{Result, anyhow};
 use gpui::{App, AsyncApp, Entity, Task, prelude::*};
 use project::Project;
-use std::{
-    path::{Path, PathBuf},
-    sync::Arc,
-};
-use util::{ResultExt, paths};
+use std::path::Path;
+use util::ResultExt;
 
 pub trait StdioAgentServer: Send + Clone {
     fn logo(&self) -> ui::IconName;
@@ -120,50 +117,3 @@ impl<T: StdioAgentServer + 'static> AgentServer for T {
         })
     }
 }
-
-pub async fn find_bin_in_path(
-    bin_name: &'static str,
-    project: &Entity<Project>,
-    cx: &mut AsyncApp,
-) -> Option<PathBuf> {
-    let (env_task, root_dir) = project
-        .update(cx, |project, cx| {
-            let worktree = project.visible_worktrees(cx).next();
-            match worktree {
-                Some(worktree) => {
-                    let env_task = project.environment().update(cx, |env, cx| {
-                        env.get_worktree_environment(worktree.clone(), cx)
-                    });
-
-                    let path = worktree.read(cx).abs_path();
-                    (env_task, path)
-                }
-                None => {
-                    let path: Arc<Path> = paths::home_dir().as_path().into();
-                    let env_task = project.environment().update(cx, |env, cx| {
-                        env.get_directory_environment(path.clone(), cx)
-                    });
-                    (env_task, path)
-                }
-            }
-        })
-        .log_err()?;
-
-    cx.background_executor()
-        .spawn(async move {
-            let which_result = if cfg!(windows) {
-                which::which(bin_name)
-            } else {
-                let env = env_task.await.unwrap_or_default();
-                let shell_path = env.get("PATH").cloned();
-                which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
-            };
-
-            if let Err(which::Error::CannotFindBinaryPath) = which_result {
-                return None;
-            }
-
-            which_result.log_err()
-        })
-        .await
-}