agent_servers: Inherit codex api key environment vars for remote (#47850)

Lukas Wirth created

Closes https://github.com/zed-industries/zed/issues/46786

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/agent_servers/src/codex.rs  | 23 +++++--
crates/remote/src/transport/ssh.rs | 87 ++++++++++++-------------------
2 files changed, 51 insertions(+), 59 deletions(-)

Detailed changes

crates/agent_servers/src/codex.rs 🔗

@@ -16,12 +16,8 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
 #[derive(Clone)]
 pub struct Codex;
 
-#[cfg(test)]
-pub(crate) mod tests {
-    use super::*;
-
-    crate::common_e2e_tests!(async |_, _| Codex, allow_option_id = "proceed_once");
-}
+const CODEX_API_KEY_VAR_NAME: &str = "CODEX_API_KEY";
+const OPEN_AI_API_KEY_VAR_NAME: &str = "OPEN_AI_API_KEY";
 
 impl AgentServer for Codex {
     fn name(&self) -> SharedString {
@@ -217,7 +213,7 @@ impl AgentServer for Codex {
         let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
         let is_remote = delegate.project.read(cx).is_via_remote_server();
         let store = delegate.store.downgrade();
-        let extra_env = load_proxy_env(cx);
+        let mut extra_env = load_proxy_env(cx);
         let default_mode = self.default_mode(cx);
         let default_model = self.default_model(cx);
         let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
@@ -228,6 +224,12 @@ impl AgentServer for Codex {
                 .map(|s| s.default_config_options.clone())
                 .unwrap_or_default()
         });
+        if let Ok(api_key) = std::env::var(CODEX_API_KEY_VAR_NAME) {
+            extra_env.insert(CODEX_API_KEY_VAR_NAME.into(), api_key);
+        }
+        if let Ok(api_key) = std::env::var(OPEN_AI_API_KEY_VAR_NAME) {
+            extra_env.insert(OPEN_AI_API_KEY_VAR_NAME.into(), api_key);
+        }
 
         cx.spawn(async move |cx| {
             let (command, root_dir, login) = store
@@ -264,3 +266,10 @@ impl AgentServer for Codex {
         self
     }
 }
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use super::*;
+
+    crate::common_e2e_tests!(async |_, _| Codex, allow_option_id = "proceed_once");
+}

crates/remote/src/transport/ssh.rs 🔗

@@ -128,10 +128,11 @@ impl From<settings::SshConnection> for SshConnectionOptions {
 
 struct SshSocket {
     connection_options: SshConnectionOptions,
-    #[cfg(not(target_os = "windows"))]
+    #[cfg(not(windows))]
     socket_path: std::path::PathBuf,
+    /// Extra environment variables needed for the ssh process
     envs: HashMap<String, String>,
-    #[cfg(target_os = "windows")]
+    #[cfg(windows)]
     _proxy: askpass::PasswordProxy,
 }
 
@@ -139,7 +140,7 @@ struct MasterProcess {
     process: Child,
 }
 
-#[cfg(not(target_os = "windows"))]
+#[cfg(not(windows))]
 impl MasterProcess {
     pub fn new(
         askpass_script_path: &std::ffi::OsStr,
@@ -185,7 +186,7 @@ impl MasterProcess {
     }
 }
 
-#[cfg(target_os = "windows")]
+#[cfg(windows)]
 impl MasterProcess {
     const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
 
@@ -519,16 +520,16 @@ impl SshRemoteConnection {
         // Start the master SSH process, which does not do anything except for establish
         // the connection and keep it open, allowing other ssh commands to reuse it
         // via a control socket.
-        #[cfg(not(target_os = "windows"))]
+        #[cfg(not(windows))]
         let socket_path = temp_dir.path().join("ssh.sock");
 
-        #[cfg(target_os = "windows")]
+        #[cfg(windows)]
         let mut master_process = MasterProcess::new(
             askpass.script_path().as_ref(),
             connection_options.additional_args(),
             &destination,
         )?;
-        #[cfg(not(target_os = "windows"))]
+        #[cfg(not(windows))]
         let mut master_process = MasterProcess::new(
             askpass.script_path().as_ref(),
             connection_options.additional_args(),
@@ -570,9 +571,9 @@ impl SshRemoteConnection {
             anyhow::bail!(error_message);
         }
 
-        #[cfg(not(target_os = "windows"))]
+        #[cfg(not(windows))]
         let socket = SshSocket::new(connection_options, socket_path).await?;
-        #[cfg(target_os = "windows")]
+        #[cfg(windows)]
         let socket = SshSocket::new(
             connection_options,
             askpass
@@ -1084,7 +1085,7 @@ impl SshRemoteConnection {
 }
 
 impl SshSocket {
-    #[cfg(not(target_os = "windows"))]
+    #[cfg(not(windows))]
     async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
         Ok(Self {
             connection_options: options,
@@ -1093,7 +1094,7 @@ impl SshSocket {
         })
     }
 
-    #[cfg(target_os = "windows")]
+    #[cfg(windows)]
     async fn new(
         options: SshConnectionOptions,
         password: askpass::EncryptedPassword,
@@ -1179,7 +1180,6 @@ impl SshSocket {
         Ok(String::from_utf8_lossy(&output.stdout).to_string())
     }
 
-    #[cfg(not(target_os = "windows"))]
     fn ssh_options<'a>(
         &self,
         command: &'a mut process::Command,
@@ -1191,40 +1191,28 @@ impl SshSocket {
             self.connection_options.additional_args_for_scp()
         };
 
-        command
+        let cmd = command
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .stderr(Stdio::piped())
-            .args(args)
-            .args(["-o", "ControlMaster=no", "-o"])
-            .arg(format!("ControlPath={}", self.socket_path.display()))
-    }
-
-    #[cfg(target_os = "windows")]
-    fn ssh_options<'a>(
-        &self,
-        command: &'a mut process::Command,
-        include_port_forwards: bool,
-    ) -> &'a mut process::Command {
-        let args = if include_port_forwards {
-            self.connection_options.additional_args()
-        } else {
-            self.connection_options.additional_args_for_scp()
-        };
+            .args(args);
 
-        command
-            .stdin(Stdio::piped())
-            .stdout(Stdio::piped())
-            .stderr(Stdio::piped())
-            .args(args)
-            .envs(self.envs.clone())
+        if cfg!(windows) {
+            cmd.envs(self.envs.clone());
+        }
+        #[cfg(not(windows))]
+        {
+            cmd.args(["-o", "ControlMaster=no", "-o"])
+                .arg(format!("ControlPath={}", self.socket_path.display()));
+        }
+        cmd
     }
 
     // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
     // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
-    #[cfg(not(target_os = "windows"))]
     fn ssh_args(&self) -> Vec<String> {
         let mut arguments = self.connection_options.additional_args();
+        #[cfg(not(windows))]
         arguments.extend(vec![
             "-o".to_string(),
             "ControlMaster=no".to_string(),
@@ -1232,12 +1220,7 @@ impl SshSocket {
             format!("ControlPath={}", self.socket_path.display()),
             self.connection_options.ssh_destination(),
         ]);
-        arguments
-    }
-
-    #[cfg(target_os = "windows")]
-    fn ssh_args(&self) -> Vec<String> {
-        let mut arguments = self.connection_options.additional_args();
+        #[cfg(windows)]
         arguments.push(self.connection_options.ssh_destination());
         arguments
     }
@@ -1358,26 +1341,26 @@ fn parse_port_number(port_str: &str) -> Result<u16> {
 fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
     let parts: Vec<&str> = spec.split(':').collect();
 
-    match parts.len() {
-        4 => {
-            let local_port = parse_port_number(parts[1])?;
-            let remote_port = parse_port_number(parts[3])?;
+    match *parts {
+        [a, b, c, d] => {
+            let local_port = parse_port_number(b)?;
+            let remote_port = parse_port_number(d)?;
 
             Ok(SshPortForwardOption {
-                local_host: Some(parts[0].to_string()),
+                local_host: Some(a.to_string()),
                 local_port,
-                remote_host: Some(parts[2].to_string()),
+                remote_host: Some(c.to_string()),
                 remote_port,
             })
         }
-        3 => {
-            let local_port = parse_port_number(parts[0])?;
-            let remote_port = parse_port_number(parts[2])?;
+        [a, b, c] => {
+            let local_port = parse_port_number(a)?;
+            let remote_port = parse_port_number(c)?;
 
             Ok(SshPortForwardOption {
                 local_host: None,
                 local_port,
-                remote_host: Some(parts[1].to_string()),
+                remote_host: Some(b.to_string()),
                 remote_port,
             })
         }