From d81b73f3d6800c945b2effb72287ccf4658284d5 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 8 Apr 2026 20:21:02 +0200 Subject: [PATCH] acp: Better handling of terminal auth on remote connections (#53396) We were incorrectly wrapping new terminal auth methods in double ssh calls. Only affected ACP beta users, but important for testing and stabilizing the feature. We moved the ssh wrapping to be only added in the acp process creation where it was needed. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --------- Co-authored-by: Bennet Bo Fenner --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 2 +- crates/agent_servers/Cargo.toml | 1 + crates/agent_servers/src/acp.rs | 183 ++++++++++-------- crates/agent_servers/src/custom.rs | 10 +- crates/agent_ui/src/conversation_view.rs | 28 +-- crates/project/src/agent_server_store.rs | 73 ++++--- .../tests/integration/ext_agent_tests.rs | 4 +- .../integration/extension_agent_tests.rs | 4 +- .../remote_server/src/remote_editing_tests.rs | 6 +- 10 files changed, 172 insertions(+), 140 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e35fd6d1d97f103023e2400573720a066a62b4a3..f744c9bc2d66b54cdbfdea63aa56a5b3bf6d365d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,6 +275,7 @@ dependencies = [ "nix 0.29.0", "project", "release_channel", + "remote", "reqwest_client", "serde", "serde_json", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 58a8aa33830f12ffb713490c87c47133cc2ad96f..32bb8abde9aa5f67563780a7fe4993028f0df346 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -117,7 +117,7 @@ pub trait AgentConnection { &self, _method: &acp::AuthMethodId, _cx: &App, - ) -> Option { + ) -> Option>> { None } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 7151f0084b1cb7d9b206f57551ce715ef67483f7..5fbf1e821cb4a41f09c433ec05fdde9fbbde1a9f 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -39,6 +39,7 @@ language_model.workspace = true log.workspace = true project.workspace = true release_channel.workspace = true +remote.workspace = true reqwest_client = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index e56db9df927ab3cdf838587f1cb4f9514eb5a758..dbcaabed1cf1971a6e281d8d31f8dad25dfb7434 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -10,20 +10,20 @@ use collections::HashMap; use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncBufReadExt as _; use futures::io::BufReader; -use project::agent_server_store::AgentServerCommand; +use project::agent_server_store::{AgentServerCommand, AgentServerStore}; use project::{AgentId, Project}; +use remote::remote_client::Interactive; use serde::Deserialize; use settings::Settings as _; -use task::{ShellBuilder, SpawnInTerminal}; -use util::ResultExt as _; -use util::path_list::PathList; -use util::process::Child; - use std::path::PathBuf; use std::process::Stdio; use std::rc::Rc; use std::{any::Any, cell::RefCell}; +use task::{ShellBuilder, SpawnInTerminal}; use thiserror::Error; +use util::ResultExt as _; +use util::path_list::PathList; +use util::process::Child; use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity}; @@ -46,7 +46,7 @@ pub struct AcpConnection { connection: Rc, sessions: Rc>>, auth_methods: Vec, - command: AgentServerCommand, + agent_server_store: WeakEntity, agent_capabilities: acp::AgentCapabilities, default_mode: Option, default_model: Option, @@ -167,6 +167,7 @@ pub async fn connect( agent_id: AgentId, project: Entity, command: AgentServerCommand, + agent_server_store: WeakEntity, default_mode: Option, default_model: Option, default_config_options: HashMap, @@ -176,6 +177,7 @@ pub async fn connect( agent_id, project, command.clone(), + agent_server_store, default_mode, default_model, default_config_options, @@ -192,23 +194,52 @@ impl AcpConnection { agent_id: AgentId, project: Entity, command: AgentServerCommand, + agent_server_store: WeakEntity, default_mode: Option, default_model: Option, default_config_options: HashMap, cx: &mut AsyncApp, ) -> Result { + let root_dir = project.read_with(cx, |project, cx| { + project + .default_path_list(cx) + .ordered_paths() + .next() + .cloned() + }); + let original_command = command.clone(); + let (path, args, env) = project + .read_with(cx, |project, cx| { + project.remote_client().and_then(|client| { + let template = client + .read(cx) + .build_command_with_options( + Some(command.path.display().to_string()), + &command.args, + &command.env.clone().into_iter().flatten().collect(), + root_dir.as_ref().map(|path| path.display().to_string()), + None, + Interactive::No, + ) + .log_err()?; + Some((template.program, template.args, template.env)) + }) + }) + .unwrap_or_else(|| { + ( + command.path.display().to_string(), + command.args, + command.env.unwrap_or_default(), + ) + }); + let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone()); let builder = ShellBuilder::new(&shell, cfg!(windows)).non_interactive(); - let mut child = - builder.build_std_command(Some(command.path.display().to_string()), &command.args); - child.envs(command.env.iter().flatten()); - if let Some(cwd) = project.update(cx, |project, cx| { + let mut child = builder.build_std_command(Some(path.clone()), &args); + child.envs(env.clone()); + if let Some(cwd) = project.read_with(cx, |project, _cx| { if project.is_local() { - project - .default_path_list(cx) - .ordered_paths() - .next() - .cloned() + root_dir.as_ref() } else { None } @@ -220,11 +251,7 @@ impl AcpConnection { let stdout = child.stdout.take().context("Failed to take stdout")?; let stdin = child.stdin.take().context("Failed to take stdin")?; let stderr = child.stderr.take().context("Failed to take stderr")?; - log::debug!( - "Spawning external agent server: {:?}, {:?}", - command.path, - command.args - ); + log::debug!("Spawning external agent server: {:?}, {:?}", path, args); log::trace!("Spawned (pid: {})", child.id()); let sessions = Rc::new(RefCell::new(HashMap::default())); @@ -342,13 +369,13 @@ impl AcpConnection { // TODO: Remove this override once Google team releases their official auth methods let auth_methods = if agent_id.0.as_ref() == GEMINI_ID { - let mut args = command.args.clone(); - args.retain(|a| a != "--experimental-acp" && a != "--acp"); + let mut gemini_args = original_command.args.clone(); + gemini_args.retain(|a| a != "--experimental-acp" && a != "--acp"); let value = serde_json::json!({ "label": "gemini /auth", - "command": command.path.to_string_lossy().into_owned(), - "args": args, - "env": command.env.clone().unwrap_or_default(), + "command": original_command.path.to_string_lossy(), + "args": gemini_args, + "env": original_command.env.unwrap_or_default(), }); let meta = acp::Meta::from_iter([("terminal-auth".to_string(), value)]); vec![acp::AuthMethod::Agent( @@ -362,7 +389,7 @@ impl AcpConnection { Ok(Self { id: agent_id, auth_methods, - command, + agent_server_store, connection, telemetry_id, sessions, @@ -494,18 +521,12 @@ fn terminal_auth_task( agent_id: &AgentId, method: &acp::AuthMethodTerminal, ) -> SpawnInTerminal { - let mut args = command.args.clone(); - args.extend(method.args.clone()); - - let mut env = command.env.clone().unwrap_or_default(); - env.extend(method.env.clone()); - acp_thread::build_terminal_auth_task( terminal_auth_task_id(agent_id, &method.id), method.name.clone(), command.path.to_string_lossy().into_owned(), - args, - env, + command.args.clone(), + command.env.clone().unwrap_or_default(), ) } @@ -890,7 +911,7 @@ impl AgentConnection for AcpConnection { &self, method_id: &acp::AuthMethodId, cx: &App, - ) -> Option { + ) -> Option>> { let method = self .auth_methods .iter() @@ -898,9 +919,28 @@ impl AgentConnection for AcpConnection { match method { acp::AuthMethod::Terminal(terminal) if cx.has_flag::() => { - Some(terminal_auth_task(&self.command, &self.id, terminal)) + let agent_id = self.id.clone(); + let terminal = terminal.clone(); + let store = self.agent_server_store.clone(); + Some(cx.spawn(async move |cx| { + let command = store + .update(cx, |store, cx| { + let agent = store + .get_external_agent(&agent_id) + .context("Agent server not found")?; + anyhow::Ok(agent.get_command( + terminal.args.clone(), + HashMap::from_iter(terminal.env.clone()), + &mut cx.to_async(), + )) + })? + .context("Failed to get agent command")? + .await?; + Ok(terminal_auth_task(&command, &agent_id, &terminal)) + })) } - _ => meta_terminal_auth_task(&self.id, method_id, method), + _ => meta_terminal_auth_task(&self.id, method_id, method) + .map(|task| Task::ready(Ok(task))), } } @@ -1075,39 +1115,32 @@ mod tests { use super::*; #[test] - fn terminal_auth_task_reuses_command_and_merges_args_and_env() { + fn terminal_auth_task_builds_spawn_from_prebuilt_command() { let command = AgentServerCommand { path: "/path/to/agent".into(), - args: vec!["--acp".into(), "--verbose".into()], + args: vec!["--acp".into(), "--verbose".into(), "/auth".into()], env: Some(HashMap::from_iter([ ("BASE".into(), "1".into()), - ("SHARED".into(), "base".into()), + ("SHARED".into(), "override".into()), + ("EXTRA".into(), "2".into()), ])), }; - let method = acp::AuthMethodTerminal::new("login", "Login") - .args(vec!["/auth".into()]) - .env(std::collections::HashMap::from_iter([ - ("EXTRA".into(), "2".into()), - ("SHARED".into(), "override".into()), - ])); + let method = acp::AuthMethodTerminal::new("login", "Login"); - let terminal_auth_task = terminal_auth_task(&command, &AgentId::new("test-agent"), &method); + let task = terminal_auth_task(&command, &AgentId::new("test-agent"), &method); + assert_eq!(task.command.as_deref(), Some("/path/to/agent")); + assert_eq!(task.args, vec!["--acp", "--verbose", "/auth"]); assert_eq!( - terminal_auth_task.command.as_deref(), - Some("/path/to/agent") - ); - assert_eq!(terminal_auth_task.args, vec!["--acp", "--verbose", "/auth"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([ ("BASE".into(), "1".into()), ("SHARED".into(), "override".into()), ("EXTRA".into(), "2".into()), ]) ); - assert_eq!(terminal_auth_task.label, "Login"); - assert_eq!(terminal_auth_task.command_label, "Login"); + assert_eq!(task.label, "Login"); + assert_eq!(task.command_label, "Login"); } #[test] @@ -1127,21 +1160,17 @@ mod tests { )])), ); - let terminal_auth_task = - meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method) - .expect("expected legacy terminal auth task"); + let task = meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method) + .expect("expected legacy terminal auth task"); + assert_eq!(task.id.0, "external-agent-test-agent-legacy-login-login"); + assert_eq!(task.command.as_deref(), Some("legacy-agent")); + assert_eq!(task.args, vec!["auth", "--interactive"]); assert_eq!( - terminal_auth_task.id.0, - "external-agent-test-agent-legacy-login-login" - ); - assert_eq!(terminal_auth_task.command.as_deref(), Some("legacy-agent")); - assert_eq!(terminal_auth_task.args, vec!["auth", "--interactive"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([("AUTH_MODE".into(), "interactive".into())]) ); - assert_eq!(terminal_auth_task.label, "legacy /auth"); + assert_eq!(task.label, "legacy /auth"); } #[test] @@ -1186,30 +1215,30 @@ mod tests { let command = AgentServerCommand { path: "/path/to/agent".into(), - args: vec!["--acp".into()], - env: Some(HashMap::from_iter([("BASE".into(), "1".into())])), + args: vec!["--acp".into(), "/auth".into()], + env: Some(HashMap::from_iter([ + ("BASE".into(), "1".into()), + ("AUTH_MODE".into(), "first-class".into()), + ])), }; - let terminal_auth_task = match &method { + let task = match &method { acp::AuthMethod::Terminal(terminal) => { terminal_auth_task(&command, &AgentId::new("test-agent"), terminal) } _ => unreachable!(), }; + assert_eq!(task.command.as_deref(), Some("/path/to/agent")); + assert_eq!(task.args, vec!["--acp", "/auth"]); assert_eq!( - terminal_auth_task.command.as_deref(), - Some("/path/to/agent") - ); - assert_eq!(terminal_auth_task.args, vec!["--acp", "/auth"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([ ("BASE".into(), "1".into()), ("AUTH_MODE".into(), "first-class".into()), ]) ); - assert_eq!(terminal_auth_task.label, "Login"); + assert_eq!(task.label, "Login"); } } diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index fb8d0a515244576d2cf02e4989cbd71beca448c7..151ddcefcfb0b839199c21d826a4c9f6836f876b 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -360,17 +360,17 @@ impl AgentServer for CustomAgentServer { let agent = store.get_external_agent(&agent_id).with_context(|| { format!("Custom agent server `{}` is not registered", agent_id) })?; - anyhow::Ok(agent.get_command( - extra_env, - delegate.new_version_available, - &mut cx.to_async(), - )) + if let Some(new_version_available_tx) = delegate.new_version_available { + agent.set_new_version_available_tx(new_version_available_tx); + } + anyhow::Ok(agent.get_command(vec![], extra_env, &mut cx.to_async())) })?? .await?; let connection = crate::acp::connect( agent_id, project, command, + store.clone(), default_mode, default_model, default_config_options, diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index 80190858151b2cf79500290a95ee0d0b6a4e8c97..1bad3c55646f9e912f79210db4afcde89c00e68a 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -1510,24 +1510,30 @@ impl ConversationView { let agent_telemetry_id = connection.telemetry_id(); - if let Some(login) = connection.terminal_auth_task(&method, cx) { + if let Some(login_task) = connection.terminal_auth_task(&method, cx) { configuration_view.take(); pending_auth_method.replace(method.clone()); let project = self.project.clone(); - let authenticate = Self::spawn_external_agent_login( - login, - workspace, - project, - method.clone(), - false, - window, - cx, - ); cx.notify(); self.auth_task = Some(cx.spawn_in(window, { async move |this, cx| { - let result = authenticate.await; + let result = async { + let login = login_task.await?; + this.update_in(cx, |_this, window, cx| { + Self::spawn_external_agent_login( + login, + workspace, + project, + method.clone(), + false, + window, + cx, + ) + })? + .await + } + .await; match &result { Ok(_) => telemetry::event!( diff --git a/crates/project/src/agent_server_store.rs b/crates/project/src/agent_server_store.rs index 0b6bb2b739f677ca1f4f3d5558538372ec6e86ff..5a9721d827cf3d189c7954f0698b662e5aaf4852 100644 --- a/crates/project/src/agent_server_store.rs +++ b/crates/project/src/agent_server_store.rs @@ -1,4 +1,3 @@ -use remote::Interactive; use std::{ any::Any, path::{Path, PathBuf}, @@ -116,9 +115,9 @@ pub enum ExternalAgentSource { pub trait ExternalAgentServer { fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task>; @@ -800,11 +799,10 @@ impl AgentServerStore { if no_browser { extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned()); } - anyhow::Ok(agent.get_command( - extra_env, - new_version_available_tx, - &mut cx.to_async(), - )) + if let Some(new_version_available_tx) = new_version_available_tx { + agent.set_new_version_available_tx(new_version_available_tx); + } + anyhow::Ok(agent.get_command(vec![], extra_env, &mut cx.to_async())) })? .await?; Ok(proto::AgentServerCommand { @@ -986,16 +984,15 @@ impl ExternalAgentServer for RemoteExternalAgentServer { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { let project_id = self.project_id; let name = self.name.to_string(); let upstream_client = self.upstream_client.downgrade(); let worktree_store = self.worktree_store.clone(); - self.new_version_available_tx = new_version_available_tx; cx.spawn(async move |cx| { let root_dir = worktree_store.read_with(cx, |worktree_store, cx| { crate::Project::default_visible_worktree_paths(worktree_store, cx) @@ -1015,22 +1012,13 @@ impl ExternalAgentServer for RemoteExternalAgentServer { }) })? .await?; - let root_dir = response.root_dir; + response.args.extend(extra_args); response.env.extend(extra_env); - let command = upstream_client.update(cx, |client, _| { - client.build_command_with_options( - Some(response.path), - &response.args, - &response.env.into_iter().collect(), - Some(root_dir.clone()), - None, - Interactive::No, - ) - })??; + Ok(AgentServerCommand { - path: command.program.into(), - args: command.args, - env: Some(command.env), + path: response.path.into(), + args: response.args, + env: Some(response.env.into_iter().collect()), }) }) } @@ -1162,12 +1150,11 @@ impl ExternalAgentServer for LocalExtensionArchiveAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let fs = self.fs.clone(); let http_client = self.http_client.clone(); let node_runtime = self.node_runtime.clone(); @@ -1309,9 +1296,12 @@ impl ExternalAgentServer for LocalExtensionArchiveAgent { } }; + let mut args = target_config.args.clone(); + args.extend(extra_args); + let command = AgentServerCommand { path: cmd_path, - args: target_config.args.clone(), + args, env: Some(env), }; @@ -1354,12 +1344,11 @@ impl ExternalAgentServer for LocalRegistryArchiveAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let fs = self.fs.clone(); let http_client = self.http_client.clone(); let node_runtime = self.node_runtime.clone(); @@ -1486,9 +1475,12 @@ impl ExternalAgentServer for LocalRegistryArchiveAgent { } }; + let mut args = target_config.args.clone(); + args.extend(extra_args); + let command = AgentServerCommand { path: cmd_path, - args: target_config.args.clone(), + args, env: Some(env), }; @@ -1530,12 +1522,11 @@ impl ExternalAgentServer for LocalRegistryNpxAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let node_runtime = self.node_runtime.clone(); let project_environment = self.project_environment.downgrade(); let package = self.package.clone(); @@ -1566,9 +1557,12 @@ impl ExternalAgentServer for LocalRegistryNpxAgent { env.extend(extra_env); env.extend(settings_env); + let mut args = npm_command.args; + args.extend(extra_args); + let command = AgentServerCommand { path: npm_command.path, - args: npm_command.args, + args, env: Some(env), }; @@ -1592,9 +1586,9 @@ struct LocalCustomAgent { impl ExternalAgentServer for LocalCustomAgent { fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - _new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { let mut command = self.command.clone(); @@ -1609,6 +1603,7 @@ impl ExternalAgentServer for LocalCustomAgent { env.extend(command.env.unwrap_or_default()); env.extend(extra_env); command.env = Some(env); + command.args.extend(extra_args); Ok(command) }) } diff --git a/crates/project/tests/integration/ext_agent_tests.rs b/crates/project/tests/integration/ext_agent_tests.rs index bd4acf2b3e9419b62ff676331383b48f98874345..82135485d3f262e5984ddbd003b69b828839d4bc 100644 --- a/crates/project/tests/integration/ext_agent_tests.rs +++ b/crates/project/tests/integration/ext_agent_tests.rs @@ -8,9 +8,9 @@ struct NoopExternalAgent; impl ExternalAgentServer for NoopExternalAgent { fn get_command( - &mut self, + &self, + _extra_args: Vec, _extra_env: HashMap, - _new_version_available_tx: Option>>, _cx: &mut AsyncApp, ) -> Task> { Task::ready(Ok(AgentServerCommand { diff --git a/crates/project/tests/integration/extension_agent_tests.rs b/crates/project/tests/integration/extension_agent_tests.rs index 577bc3b2901c52f4f47d9d0c82ef89fc66e2c21a..5af2cd229c476a261a5d37666e00d2dff3b293b4 100644 --- a/crates/project/tests/integration/extension_agent_tests.rs +++ b/crates/project/tests/integration/extension_agent_tests.rs @@ -24,9 +24,9 @@ struct NoopExternalAgent; impl ExternalAgentServer for NoopExternalAgent { fn get_command( - &mut self, + &self, + _extra_args: Vec, _extra_env: HashMap, - _new_version_available_tx: Option>>, _cx: &mut AsyncApp, ) -> Task> { Task::ready(Ok(AgentServerCommand { diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index f0f23577d31075ab815d6dba1cdbdccd275c184a..571c5e7ea1aa5d623cf70d8fd06252bd0860de1b 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2256,8 +2256,8 @@ async fn test_remote_external_agent_server( .get_external_agent(&"foo".into()) .unwrap() .get_command( + vec![], HashMap::from_iter([("OTHER_VAR".into(), "other-val".into())]), - None, &mut cx.to_async(), ) }) @@ -2267,8 +2267,8 @@ async fn test_remote_external_agent_server( assert_eq!( command, AgentServerCommand { - path: "mock".into(), - args: vec!["foo-cli".into(), "--flag".into()], + path: "foo-cli".into(), + args: vec!["--flag".into()], env: Some(HashMap::from_iter([ ("NO_BROWSER".into(), "1".into()), ("VAR".into(), "val".into()),