diff --git a/Cargo.lock b/Cargo.lock index e35fd6d1d97f103023e2400573720a066a62b4a3..f744c9bc2d66b54cdbfdea63aa56a5b3bf6d365d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,6 +275,7 @@ dependencies = [ "nix 0.29.0", "project", "release_channel", + "remote", "reqwest_client", "serde", "serde_json", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 58a8aa33830f12ffb713490c87c47133cc2ad96f..32bb8abde9aa5f67563780a7fe4993028f0df346 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -117,7 +117,7 @@ pub trait AgentConnection { &self, _method: &acp::AuthMethodId, _cx: &App, - ) -> Option { + ) -> Option>> { None } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 7151f0084b1cb7d9b206f57551ce715ef67483f7..5fbf1e821cb4a41f09c433ec05fdde9fbbde1a9f 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -39,6 +39,7 @@ language_model.workspace = true log.workspace = true project.workspace = true release_channel.workspace = true +remote.workspace = true reqwest_client = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index e56db9df927ab3cdf838587f1cb4f9514eb5a758..dbcaabed1cf1971a6e281d8d31f8dad25dfb7434 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -10,20 +10,20 @@ use collections::HashMap; use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncBufReadExt as _; use futures::io::BufReader; -use project::agent_server_store::AgentServerCommand; +use project::agent_server_store::{AgentServerCommand, AgentServerStore}; use project::{AgentId, Project}; +use remote::remote_client::Interactive; use serde::Deserialize; use settings::Settings as _; -use task::{ShellBuilder, SpawnInTerminal}; -use util::ResultExt as _; -use util::path_list::PathList; -use util::process::Child; - use std::path::PathBuf; use std::process::Stdio; use std::rc::Rc; use std::{any::Any, cell::RefCell}; +use task::{ShellBuilder, SpawnInTerminal}; use thiserror::Error; +use util::ResultExt as _; +use util::path_list::PathList; +use util::process::Child; use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity}; @@ -46,7 +46,7 @@ pub struct AcpConnection { connection: Rc, sessions: Rc>>, auth_methods: Vec, - command: AgentServerCommand, + agent_server_store: WeakEntity, agent_capabilities: acp::AgentCapabilities, default_mode: Option, default_model: Option, @@ -167,6 +167,7 @@ pub async fn connect( agent_id: AgentId, project: Entity, command: AgentServerCommand, + agent_server_store: WeakEntity, default_mode: Option, default_model: Option, default_config_options: HashMap, @@ -176,6 +177,7 @@ pub async fn connect( agent_id, project, command.clone(), + agent_server_store, default_mode, default_model, default_config_options, @@ -192,23 +194,52 @@ impl AcpConnection { agent_id: AgentId, project: Entity, command: AgentServerCommand, + agent_server_store: WeakEntity, default_mode: Option, default_model: Option, default_config_options: HashMap, cx: &mut AsyncApp, ) -> Result { + let root_dir = project.read_with(cx, |project, cx| { + project + .default_path_list(cx) + .ordered_paths() + .next() + .cloned() + }); + let original_command = command.clone(); + let (path, args, env) = project + .read_with(cx, |project, cx| { + project.remote_client().and_then(|client| { + let template = client + .read(cx) + .build_command_with_options( + Some(command.path.display().to_string()), + &command.args, + &command.env.clone().into_iter().flatten().collect(), + root_dir.as_ref().map(|path| path.display().to_string()), + None, + Interactive::No, + ) + .log_err()?; + Some((template.program, template.args, template.env)) + }) + }) + .unwrap_or_else(|| { + ( + command.path.display().to_string(), + command.args, + command.env.unwrap_or_default(), + ) + }); + let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone()); let builder = ShellBuilder::new(&shell, cfg!(windows)).non_interactive(); - let mut child = - builder.build_std_command(Some(command.path.display().to_string()), &command.args); - child.envs(command.env.iter().flatten()); - if let Some(cwd) = project.update(cx, |project, cx| { + let mut child = builder.build_std_command(Some(path.clone()), &args); + child.envs(env.clone()); + if let Some(cwd) = project.read_with(cx, |project, _cx| { if project.is_local() { - project - .default_path_list(cx) - .ordered_paths() - .next() - .cloned() + root_dir.as_ref() } else { None } @@ -220,11 +251,7 @@ impl AcpConnection { let stdout = child.stdout.take().context("Failed to take stdout")?; let stdin = child.stdin.take().context("Failed to take stdin")?; let stderr = child.stderr.take().context("Failed to take stderr")?; - log::debug!( - "Spawning external agent server: {:?}, {:?}", - command.path, - command.args - ); + log::debug!("Spawning external agent server: {:?}, {:?}", path, args); log::trace!("Spawned (pid: {})", child.id()); let sessions = Rc::new(RefCell::new(HashMap::default())); @@ -342,13 +369,13 @@ impl AcpConnection { // TODO: Remove this override once Google team releases their official auth methods let auth_methods = if agent_id.0.as_ref() == GEMINI_ID { - let mut args = command.args.clone(); - args.retain(|a| a != "--experimental-acp" && a != "--acp"); + let mut gemini_args = original_command.args.clone(); + gemini_args.retain(|a| a != "--experimental-acp" && a != "--acp"); let value = serde_json::json!({ "label": "gemini /auth", - "command": command.path.to_string_lossy().into_owned(), - "args": args, - "env": command.env.clone().unwrap_or_default(), + "command": original_command.path.to_string_lossy(), + "args": gemini_args, + "env": original_command.env.unwrap_or_default(), }); let meta = acp::Meta::from_iter([("terminal-auth".to_string(), value)]); vec![acp::AuthMethod::Agent( @@ -362,7 +389,7 @@ impl AcpConnection { Ok(Self { id: agent_id, auth_methods, - command, + agent_server_store, connection, telemetry_id, sessions, @@ -494,18 +521,12 @@ fn terminal_auth_task( agent_id: &AgentId, method: &acp::AuthMethodTerminal, ) -> SpawnInTerminal { - let mut args = command.args.clone(); - args.extend(method.args.clone()); - - let mut env = command.env.clone().unwrap_or_default(); - env.extend(method.env.clone()); - acp_thread::build_terminal_auth_task( terminal_auth_task_id(agent_id, &method.id), method.name.clone(), command.path.to_string_lossy().into_owned(), - args, - env, + command.args.clone(), + command.env.clone().unwrap_or_default(), ) } @@ -890,7 +911,7 @@ impl AgentConnection for AcpConnection { &self, method_id: &acp::AuthMethodId, cx: &App, - ) -> Option { + ) -> Option>> { let method = self .auth_methods .iter() @@ -898,9 +919,28 @@ impl AgentConnection for AcpConnection { match method { acp::AuthMethod::Terminal(terminal) if cx.has_flag::() => { - Some(terminal_auth_task(&self.command, &self.id, terminal)) + let agent_id = self.id.clone(); + let terminal = terminal.clone(); + let store = self.agent_server_store.clone(); + Some(cx.spawn(async move |cx| { + let command = store + .update(cx, |store, cx| { + let agent = store + .get_external_agent(&agent_id) + .context("Agent server not found")?; + anyhow::Ok(agent.get_command( + terminal.args.clone(), + HashMap::from_iter(terminal.env.clone()), + &mut cx.to_async(), + )) + })? + .context("Failed to get agent command")? + .await?; + Ok(terminal_auth_task(&command, &agent_id, &terminal)) + })) } - _ => meta_terminal_auth_task(&self.id, method_id, method), + _ => meta_terminal_auth_task(&self.id, method_id, method) + .map(|task| Task::ready(Ok(task))), } } @@ -1075,39 +1115,32 @@ mod tests { use super::*; #[test] - fn terminal_auth_task_reuses_command_and_merges_args_and_env() { + fn terminal_auth_task_builds_spawn_from_prebuilt_command() { let command = AgentServerCommand { path: "/path/to/agent".into(), - args: vec!["--acp".into(), "--verbose".into()], + args: vec!["--acp".into(), "--verbose".into(), "/auth".into()], env: Some(HashMap::from_iter([ ("BASE".into(), "1".into()), - ("SHARED".into(), "base".into()), + ("SHARED".into(), "override".into()), + ("EXTRA".into(), "2".into()), ])), }; - let method = acp::AuthMethodTerminal::new("login", "Login") - .args(vec!["/auth".into()]) - .env(std::collections::HashMap::from_iter([ - ("EXTRA".into(), "2".into()), - ("SHARED".into(), "override".into()), - ])); + let method = acp::AuthMethodTerminal::new("login", "Login"); - let terminal_auth_task = terminal_auth_task(&command, &AgentId::new("test-agent"), &method); + let task = terminal_auth_task(&command, &AgentId::new("test-agent"), &method); + assert_eq!(task.command.as_deref(), Some("/path/to/agent")); + assert_eq!(task.args, vec!["--acp", "--verbose", "/auth"]); assert_eq!( - terminal_auth_task.command.as_deref(), - Some("/path/to/agent") - ); - assert_eq!(terminal_auth_task.args, vec!["--acp", "--verbose", "/auth"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([ ("BASE".into(), "1".into()), ("SHARED".into(), "override".into()), ("EXTRA".into(), "2".into()), ]) ); - assert_eq!(terminal_auth_task.label, "Login"); - assert_eq!(terminal_auth_task.command_label, "Login"); + assert_eq!(task.label, "Login"); + assert_eq!(task.command_label, "Login"); } #[test] @@ -1127,21 +1160,17 @@ mod tests { )])), ); - let terminal_auth_task = - meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method) - .expect("expected legacy terminal auth task"); + let task = meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method) + .expect("expected legacy terminal auth task"); + assert_eq!(task.id.0, "external-agent-test-agent-legacy-login-login"); + assert_eq!(task.command.as_deref(), Some("legacy-agent")); + assert_eq!(task.args, vec!["auth", "--interactive"]); assert_eq!( - terminal_auth_task.id.0, - "external-agent-test-agent-legacy-login-login" - ); - assert_eq!(terminal_auth_task.command.as_deref(), Some("legacy-agent")); - assert_eq!(terminal_auth_task.args, vec!["auth", "--interactive"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([("AUTH_MODE".into(), "interactive".into())]) ); - assert_eq!(terminal_auth_task.label, "legacy /auth"); + assert_eq!(task.label, "legacy /auth"); } #[test] @@ -1186,30 +1215,30 @@ mod tests { let command = AgentServerCommand { path: "/path/to/agent".into(), - args: vec!["--acp".into()], - env: Some(HashMap::from_iter([("BASE".into(), "1".into())])), + args: vec!["--acp".into(), "/auth".into()], + env: Some(HashMap::from_iter([ + ("BASE".into(), "1".into()), + ("AUTH_MODE".into(), "first-class".into()), + ])), }; - let terminal_auth_task = match &method { + let task = match &method { acp::AuthMethod::Terminal(terminal) => { terminal_auth_task(&command, &AgentId::new("test-agent"), terminal) } _ => unreachable!(), }; + assert_eq!(task.command.as_deref(), Some("/path/to/agent")); + assert_eq!(task.args, vec!["--acp", "/auth"]); assert_eq!( - terminal_auth_task.command.as_deref(), - Some("/path/to/agent") - ); - assert_eq!(terminal_auth_task.args, vec!["--acp", "/auth"]); - assert_eq!( - terminal_auth_task.env, + task.env, HashMap::from_iter([ ("BASE".into(), "1".into()), ("AUTH_MODE".into(), "first-class".into()), ]) ); - assert_eq!(terminal_auth_task.label, "Login"); + assert_eq!(task.label, "Login"); } } diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index fb8d0a515244576d2cf02e4989cbd71beca448c7..151ddcefcfb0b839199c21d826a4c9f6836f876b 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -360,17 +360,17 @@ impl AgentServer for CustomAgentServer { let agent = store.get_external_agent(&agent_id).with_context(|| { format!("Custom agent server `{}` is not registered", agent_id) })?; - anyhow::Ok(agent.get_command( - extra_env, - delegate.new_version_available, - &mut cx.to_async(), - )) + if let Some(new_version_available_tx) = delegate.new_version_available { + agent.set_new_version_available_tx(new_version_available_tx); + } + anyhow::Ok(agent.get_command(vec![], extra_env, &mut cx.to_async())) })?? .await?; let connection = crate::acp::connect( agent_id, project, command, + store.clone(), default_mode, default_model, default_config_options, diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index 80190858151b2cf79500290a95ee0d0b6a4e8c97..1bad3c55646f9e912f79210db4afcde89c00e68a 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -1510,24 +1510,30 @@ impl ConversationView { let agent_telemetry_id = connection.telemetry_id(); - if let Some(login) = connection.terminal_auth_task(&method, cx) { + if let Some(login_task) = connection.terminal_auth_task(&method, cx) { configuration_view.take(); pending_auth_method.replace(method.clone()); let project = self.project.clone(); - let authenticate = Self::spawn_external_agent_login( - login, - workspace, - project, - method.clone(), - false, - window, - cx, - ); cx.notify(); self.auth_task = Some(cx.spawn_in(window, { async move |this, cx| { - let result = authenticate.await; + let result = async { + let login = login_task.await?; + this.update_in(cx, |_this, window, cx| { + Self::spawn_external_agent_login( + login, + workspace, + project, + method.clone(), + false, + window, + cx, + ) + })? + .await + } + .await; match &result { Ok(_) => telemetry::event!( diff --git a/crates/project/src/agent_server_store.rs b/crates/project/src/agent_server_store.rs index 0b6bb2b739f677ca1f4f3d5558538372ec6e86ff..5a9721d827cf3d189c7954f0698b662e5aaf4852 100644 --- a/crates/project/src/agent_server_store.rs +++ b/crates/project/src/agent_server_store.rs @@ -1,4 +1,3 @@ -use remote::Interactive; use std::{ any::Any, path::{Path, PathBuf}, @@ -116,9 +115,9 @@ pub enum ExternalAgentSource { pub trait ExternalAgentServer { fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task>; @@ -800,11 +799,10 @@ impl AgentServerStore { if no_browser { extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned()); } - anyhow::Ok(agent.get_command( - extra_env, - new_version_available_tx, - &mut cx.to_async(), - )) + if let Some(new_version_available_tx) = new_version_available_tx { + agent.set_new_version_available_tx(new_version_available_tx); + } + anyhow::Ok(agent.get_command(vec![], extra_env, &mut cx.to_async())) })? .await?; Ok(proto::AgentServerCommand { @@ -986,16 +984,15 @@ impl ExternalAgentServer for RemoteExternalAgentServer { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { let project_id = self.project_id; let name = self.name.to_string(); let upstream_client = self.upstream_client.downgrade(); let worktree_store = self.worktree_store.clone(); - self.new_version_available_tx = new_version_available_tx; cx.spawn(async move |cx| { let root_dir = worktree_store.read_with(cx, |worktree_store, cx| { crate::Project::default_visible_worktree_paths(worktree_store, cx) @@ -1015,22 +1012,13 @@ impl ExternalAgentServer for RemoteExternalAgentServer { }) })? .await?; - let root_dir = response.root_dir; + response.args.extend(extra_args); response.env.extend(extra_env); - let command = upstream_client.update(cx, |client, _| { - client.build_command_with_options( - Some(response.path), - &response.args, - &response.env.into_iter().collect(), - Some(root_dir.clone()), - None, - Interactive::No, - ) - })??; + Ok(AgentServerCommand { - path: command.program.into(), - args: command.args, - env: Some(command.env), + path: response.path.into(), + args: response.args, + env: Some(response.env.into_iter().collect()), }) }) } @@ -1162,12 +1150,11 @@ impl ExternalAgentServer for LocalExtensionArchiveAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let fs = self.fs.clone(); let http_client = self.http_client.clone(); let node_runtime = self.node_runtime.clone(); @@ -1309,9 +1296,12 @@ impl ExternalAgentServer for LocalExtensionArchiveAgent { } }; + let mut args = target_config.args.clone(); + args.extend(extra_args); + let command = AgentServerCommand { path: cmd_path, - args: target_config.args.clone(), + args, env: Some(env), }; @@ -1354,12 +1344,11 @@ impl ExternalAgentServer for LocalRegistryArchiveAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let fs = self.fs.clone(); let http_client = self.http_client.clone(); let node_runtime = self.node_runtime.clone(); @@ -1486,9 +1475,12 @@ impl ExternalAgentServer for LocalRegistryArchiveAgent { } }; + let mut args = target_config.args.clone(); + args.extend(extra_args); + let command = AgentServerCommand { path: cmd_path, - args: target_config.args.clone(), + args, env: Some(env), }; @@ -1530,12 +1522,11 @@ impl ExternalAgentServer for LocalRegistryNpxAgent { } fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { - self.new_version_available_tx = new_version_available_tx; let node_runtime = self.node_runtime.clone(); let project_environment = self.project_environment.downgrade(); let package = self.package.clone(); @@ -1566,9 +1557,12 @@ impl ExternalAgentServer for LocalRegistryNpxAgent { env.extend(extra_env); env.extend(settings_env); + let mut args = npm_command.args; + args.extend(extra_args); + let command = AgentServerCommand { path: npm_command.path, - args: npm_command.args, + args, env: Some(env), }; @@ -1592,9 +1586,9 @@ struct LocalCustomAgent { impl ExternalAgentServer for LocalCustomAgent { fn get_command( - &mut self, + &self, + extra_args: Vec, extra_env: HashMap, - _new_version_available_tx: Option>>, cx: &mut AsyncApp, ) -> Task> { let mut command = self.command.clone(); @@ -1609,6 +1603,7 @@ impl ExternalAgentServer for LocalCustomAgent { env.extend(command.env.unwrap_or_default()); env.extend(extra_env); command.env = Some(env); + command.args.extend(extra_args); Ok(command) }) } diff --git a/crates/project/tests/integration/ext_agent_tests.rs b/crates/project/tests/integration/ext_agent_tests.rs index bd4acf2b3e9419b62ff676331383b48f98874345..82135485d3f262e5984ddbd003b69b828839d4bc 100644 --- a/crates/project/tests/integration/ext_agent_tests.rs +++ b/crates/project/tests/integration/ext_agent_tests.rs @@ -8,9 +8,9 @@ struct NoopExternalAgent; impl ExternalAgentServer for NoopExternalAgent { fn get_command( - &mut self, + &self, + _extra_args: Vec, _extra_env: HashMap, - _new_version_available_tx: Option>>, _cx: &mut AsyncApp, ) -> Task> { Task::ready(Ok(AgentServerCommand { diff --git a/crates/project/tests/integration/extension_agent_tests.rs b/crates/project/tests/integration/extension_agent_tests.rs index 577bc3b2901c52f4f47d9d0c82ef89fc66e2c21a..5af2cd229c476a261a5d37666e00d2dff3b293b4 100644 --- a/crates/project/tests/integration/extension_agent_tests.rs +++ b/crates/project/tests/integration/extension_agent_tests.rs @@ -24,9 +24,9 @@ struct NoopExternalAgent; impl ExternalAgentServer for NoopExternalAgent { fn get_command( - &mut self, + &self, + _extra_args: Vec, _extra_env: HashMap, - _new_version_available_tx: Option>>, _cx: &mut AsyncApp, ) -> Task> { Task::ready(Ok(AgentServerCommand { diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index f0f23577d31075ab815d6dba1cdbdccd275c184a..571c5e7ea1aa5d623cf70d8fd06252bd0860de1b 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2256,8 +2256,8 @@ async fn test_remote_external_agent_server( .get_external_agent(&"foo".into()) .unwrap() .get_command( + vec![], HashMap::from_iter([("OTHER_VAR".into(), "other-val".into())]), - None, &mut cx.to_async(), ) }) @@ -2267,8 +2267,8 @@ async fn test_remote_external_agent_server( assert_eq!( command, AgentServerCommand { - path: "mock".into(), - args: vec!["foo-cli".into(), "--flag".into()], + path: "foo-cli".into(), + args: vec!["--flag".into()], env: Some(HashMap::from_iter([ ("NO_BROWSER".into(), "1".into()), ("VAR".into(), "val".into()),