ssh.rs

   1use crate::{
   2    RemoteClientDelegate, RemotePlatform,
   3    remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
   4    transport::{parse_platform, parse_shell},
   5};
   6use anyhow::{Context as _, Result, anyhow};
   7use async_trait::async_trait;
   8use collections::HashMap;
   9use futures::{
  10    AsyncReadExt as _, FutureExt as _,
  11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  12    select_biased,
  13};
  14use gpui::{App, AppContext as _, AsyncApp, Task};
  15use parking_lot::Mutex;
  16use paths::remote_server_dir_relative;
  17use release_channel::{AppVersion, ReleaseChannel};
  18use rpc::proto::Envelope;
  19use semver::Version;
  20pub use settings::SshPortForwardOption;
  21use smol::{
  22    fs,
  23    process::{self, Child, Stdio},
  24};
  25use std::{
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28    time::Instant,
  29};
  30use tempfile::TempDir;
  31use util::{
  32    paths::{PathStyle, RemotePathBuf},
  33    rel_path::RelPath,
  34    shell::ShellKind,
  35};
  36
  37pub(crate) struct SshRemoteConnection {
  38    socket: SshSocket,
  39    master_process: Mutex<Option<MasterProcess>>,
  40    remote_binary_path: Option<Arc<RelPath>>,
  41    ssh_platform: RemotePlatform,
  42    ssh_path_style: PathStyle,
  43    ssh_shell: String,
  44    ssh_shell_kind: ShellKind,
  45    ssh_default_system_shell: String,
  46    _temp_dir: TempDir,
  47}
  48
  49#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  50pub struct SshConnectionOptions {
  51    pub host: String,
  52    pub username: Option<String>,
  53    pub port: Option<u16>,
  54    pub password: Option<String>,
  55    pub args: Option<Vec<String>>,
  56    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  57
  58    pub nickname: Option<String>,
  59    pub upload_binary_over_ssh: bool,
  60}
  61
  62impl From<settings::SshConnection> for SshConnectionOptions {
  63    fn from(val: settings::SshConnection) -> Self {
  64        SshConnectionOptions {
  65            host: val.host.into(),
  66            username: val.username,
  67            port: val.port,
  68            password: None,
  69            args: Some(val.args),
  70            nickname: val.nickname,
  71            upload_binary_over_ssh: val.upload_binary_over_ssh.unwrap_or_default(),
  72            port_forwards: val.port_forwards,
  73        }
  74    }
  75}
  76
  77struct SshSocket {
  78    connection_options: SshConnectionOptions,
  79    #[cfg(not(target_os = "windows"))]
  80    socket_path: std::path::PathBuf,
  81    envs: HashMap<String, String>,
  82    #[cfg(target_os = "windows")]
  83    _proxy: askpass::PasswordProxy,
  84}
  85
  86struct MasterProcess {
  87    process: Child,
  88}
  89
  90#[cfg(not(target_os = "windows"))]
  91impl MasterProcess {
  92    pub fn new(
  93        askpass_script_path: &std::ffi::OsStr,
  94        additional_args: Vec<String>,
  95        socket_path: &std::path::Path,
  96        url: &str,
  97    ) -> Result<Self> {
  98        let args = [
  99            "-N",
 100            "-o",
 101            "ControlPersist=no",
 102            "-o",
 103            "ControlMaster=yes",
 104            "-o",
 105        ];
 106
 107        let mut master_process = util::command::new_smol_command("ssh");
 108        master_process
 109            .kill_on_drop(true)
 110            .stdin(Stdio::null())
 111            .stdout(Stdio::piped())
 112            .stderr(Stdio::piped())
 113            .env("SSH_ASKPASS_REQUIRE", "force")
 114            .env("SSH_ASKPASS", askpass_script_path)
 115            .args(additional_args)
 116            .args(args);
 117
 118        master_process.arg(format!("ControlPath='{}'", socket_path.display()));
 119
 120        let process = master_process.arg(&url).spawn()?;
 121
 122        Ok(MasterProcess { process })
 123    }
 124
 125    pub async fn wait_connected(&mut self) -> Result<()> {
 126        let Some(mut stdout) = self.process.stdout.take() else {
 127            anyhow::bail!("ssh process stdout capture failed");
 128        };
 129
 130        let mut output = Vec::new();
 131        stdout.read_to_end(&mut output).await?;
 132        Ok(())
 133    }
 134}
 135
 136#[cfg(target_os = "windows")]
 137impl MasterProcess {
 138    const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
 139
 140    pub fn new(
 141        askpass_script_path: &std::ffi::OsStr,
 142        additional_args: Vec<String>,
 143        url: &str,
 144    ) -> Result<Self> {
 145        // On Windows, `ControlMaster` and `ControlPath` are not supported:
 146        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
 147        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
 148        //
 149        // Using an ugly workaround to detect connection establishment
 150        // -N doesn't work with JumpHosts as windows openssh never closes stdin in that case
 151        let args = [
 152            "-t",
 153            &format!("echo '{}'; exec $0", Self::CONNECTION_ESTABLISHED_MAGIC),
 154        ];
 155
 156        let mut master_process = util::command::new_smol_command("ssh");
 157        master_process
 158            .kill_on_drop(true)
 159            .stdin(Stdio::null())
 160            .stdout(Stdio::piped())
 161            .stderr(Stdio::piped())
 162            .env("SSH_ASKPASS_REQUIRE", "force")
 163            .env("SSH_ASKPASS", askpass_script_path)
 164            .args(additional_args)
 165            .arg(url)
 166            .args(args);
 167
 168        let process = master_process.spawn()?;
 169
 170        Ok(MasterProcess { process })
 171    }
 172
 173    pub async fn wait_connected(&mut self) -> Result<()> {
 174        use smol::io::AsyncBufReadExt;
 175
 176        let Some(stdout) = self.process.stdout.take() else {
 177            anyhow::bail!("ssh process stdout capture failed");
 178        };
 179
 180        let mut reader = smol::io::BufReader::new(stdout);
 181
 182        let mut line = String::new();
 183
 184        loop {
 185            let n = reader.read_line(&mut line).await?;
 186            if n == 0 {
 187                anyhow::bail!("ssh process exited before connection established");
 188            }
 189
 190            if line.contains(Self::CONNECTION_ESTABLISHED_MAGIC) {
 191                return Ok(());
 192            }
 193        }
 194    }
 195}
 196
 197impl AsRef<Child> for MasterProcess {
 198    fn as_ref(&self) -> &Child {
 199        &self.process
 200    }
 201}
 202
 203impl AsMut<Child> for MasterProcess {
 204    fn as_mut(&mut self) -> &mut Child {
 205        &mut self.process
 206    }
 207}
 208
 209#[async_trait(?Send)]
 210impl RemoteConnection for SshRemoteConnection {
 211    async fn kill(&self) -> Result<()> {
 212        let Some(mut process) = self.master_process.lock().take() else {
 213            return Ok(());
 214        };
 215        process.as_mut().kill().ok();
 216        process.as_mut().status().await?;
 217        Ok(())
 218    }
 219
 220    fn has_been_killed(&self) -> bool {
 221        self.master_process.lock().is_none()
 222    }
 223
 224    fn connection_options(&self) -> RemoteConnectionOptions {
 225        RemoteConnectionOptions::Ssh(self.socket.connection_options.clone())
 226    }
 227
 228    fn shell(&self) -> String {
 229        self.ssh_shell.clone()
 230    }
 231
 232    fn default_system_shell(&self) -> String {
 233        self.ssh_default_system_shell.clone()
 234    }
 235
 236    fn build_command(
 237        &self,
 238        input_program: Option<String>,
 239        input_args: &[String],
 240        input_env: &HashMap<String, String>,
 241        working_dir: Option<String>,
 242        port_forward: Option<(u16, String, u16)>,
 243    ) -> Result<CommandTemplate> {
 244        let Self {
 245            ssh_path_style,
 246            socket,
 247            ssh_shell_kind,
 248            ssh_shell,
 249            ..
 250        } = self;
 251        let env = socket.envs.clone();
 252        build_command(
 253            input_program,
 254            input_args,
 255            input_env,
 256            working_dir,
 257            port_forward,
 258            env,
 259            *ssh_path_style,
 260            ssh_shell,
 261            *ssh_shell_kind,
 262            socket.ssh_args(),
 263        )
 264    }
 265
 266    fn build_forward_ports_command(
 267        &self,
 268        forwards: Vec<(u16, String, u16)>,
 269    ) -> Result<CommandTemplate> {
 270        let Self { socket, .. } = self;
 271        let mut args = socket.ssh_args();
 272        args.push("-N".into());
 273        for (local_port, host, remote_port) in forwards {
 274            args.push("-L".into());
 275            args.push(format!("{local_port}:{host}:{remote_port}"));
 276        }
 277        Ok(CommandTemplate {
 278            program: "ssh".into(),
 279            args,
 280            env: Default::default(),
 281        })
 282    }
 283
 284    fn upload_directory(
 285        &self,
 286        src_path: PathBuf,
 287        dest_path: RemotePathBuf,
 288        cx: &App,
 289    ) -> Task<Result<()>> {
 290        let dest_path_str = dest_path.to_string();
 291        let src_path_display = src_path.display().to_string();
 292
 293        let mut sftp_command = self.build_sftp_command();
 294        let mut scp_command =
 295            self.build_scp_command(&src_path, &dest_path_str, Some(&["-C", "-r"]));
 296
 297        cx.background_spawn(async move {
 298            // We will try SFTP first, and if that fails, we will fall back to SCP.
 299            // If SCP fails also, we give up and return an error.
 300            // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 301            // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 302            // This is for example the case on Windows as evidenced by this implementation snippet:
 303            // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 304            if Self::is_sftp_available().await {
 305                log::debug!("using SFTP for directory upload");
 306                let mut child = sftp_command.spawn()?;
 307                if let Some(mut stdin) = child.stdin.take() {
 308                    use futures::AsyncWriteExt;
 309                    let sftp_batch = format!("put -r \"{src_path_display}\" \"{dest_path_str}\"\n");
 310                    stdin.write_all(sftp_batch.as_bytes()).await?;
 311                    stdin.flush().await?;
 312                }
 313
 314                let output = child.output().await?;
 315                if output.status.success() {
 316                    return Ok(());
 317                }
 318
 319                let stderr = String::from_utf8_lossy(&output.stderr);
 320                log::debug!("failed to upload directory via SFTP {src_path_display} -> {dest_path_str}: {stderr}");
 321            }
 322
 323            log::debug!("using SCP for directory upload");
 324            let output = scp_command.output().await?;
 325
 326            if output.status.success() {
 327                return Ok(());
 328            }
 329
 330            let stderr = String::from_utf8_lossy(&output.stderr);
 331            log::debug!("failed to upload directory via SCP {src_path_display} -> {dest_path_str}: {stderr}");
 332
 333            anyhow::bail!(
 334                "failed to upload directory via SFTP/SCP {} -> {}: {}",
 335                src_path_display,
 336                dest_path_str,
 337                stderr,
 338            );
 339        })
 340    }
 341
 342    fn start_proxy(
 343        &self,
 344        unique_identifier: String,
 345        reconnect: bool,
 346        incoming_tx: UnboundedSender<Envelope>,
 347        outgoing_rx: UnboundedReceiver<Envelope>,
 348        connection_activity_tx: Sender<()>,
 349        delegate: Arc<dyn RemoteClientDelegate>,
 350        cx: &mut AsyncApp,
 351    ) -> Task<Result<i32>> {
 352        delegate.set_status(Some("Starting proxy"), cx);
 353
 354        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 355            return Task::ready(Err(anyhow!("Remote binary path not set")));
 356        };
 357
 358        let mut proxy_args = vec![];
 359        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
 360            if let Some(value) = std::env::var(env_var).ok() {
 361                proxy_args.push(format!("{}='{}'", env_var, value));
 362            }
 363        }
 364        proxy_args.push(remote_binary_path.display(self.path_style()).into_owned());
 365        proxy_args.push("proxy".to_owned());
 366        proxy_args.push("--identifier".to_owned());
 367        proxy_args.push(unique_identifier);
 368
 369        if reconnect {
 370            proxy_args.push("--reconnect".to_owned());
 371        }
 372
 373        let ssh_proxy_process = match self
 374            .socket
 375            .ssh_command(self.ssh_shell_kind, "env", &proxy_args, false)
 376            // IMPORTANT: we kill this process when we drop the task that uses it.
 377            .kill_on_drop(true)
 378            .spawn()
 379        {
 380            Ok(process) => process,
 381            Err(error) => {
 382                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 383            }
 384        };
 385
 386        super::handle_rpc_messages_over_child_process_stdio(
 387            ssh_proxy_process,
 388            incoming_tx,
 389            outgoing_rx,
 390            connection_activity_tx,
 391            cx,
 392        )
 393    }
 394
 395    fn path_style(&self) -> PathStyle {
 396        self.ssh_path_style
 397    }
 398
 399    fn has_wsl_interop(&self) -> bool {
 400        false
 401    }
 402}
 403
 404impl SshRemoteConnection {
 405    pub(crate) async fn new(
 406        connection_options: SshConnectionOptions,
 407        delegate: Arc<dyn RemoteClientDelegate>,
 408        cx: &mut AsyncApp,
 409    ) -> Result<Self> {
 410        use askpass::AskPassResult;
 411
 412        let url = connection_options.ssh_url();
 413
 414        let temp_dir = tempfile::Builder::new()
 415            .prefix("zed-ssh-session")
 416            .tempdir()?;
 417        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
 418            let delegate = delegate.clone();
 419            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
 420        });
 421
 422        let mut askpass =
 423            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
 424
 425        delegate.set_status(Some("Connecting"), cx);
 426
 427        // Start the master SSH process, which does not do anything except for establish
 428        // the connection and keep it open, allowing other ssh commands to reuse it
 429        // via a control socket.
 430        #[cfg(not(target_os = "windows"))]
 431        let socket_path = temp_dir.path().join("ssh.sock");
 432
 433        #[cfg(target_os = "windows")]
 434        let mut master_process = MasterProcess::new(
 435            askpass.script_path().as_ref(),
 436            connection_options.additional_args(),
 437            &url,
 438        )?;
 439        #[cfg(not(target_os = "windows"))]
 440        let mut master_process = MasterProcess::new(
 441            askpass.script_path().as_ref(),
 442            connection_options.additional_args(),
 443            &socket_path,
 444            &url,
 445        )?;
 446
 447        let result = select_biased! {
 448            result = askpass.run().fuse() => {
 449                match result {
 450                    AskPassResult::CancelledByUser => {
 451                        master_process.as_mut().kill().ok();
 452                        anyhow::bail!("SSH connection canceled")
 453                    }
 454                    AskPassResult::Timedout => {
 455                        anyhow::bail!("connecting to host timed out")
 456                    }
 457                }
 458            }
 459            _ = master_process.wait_connected().fuse() => {
 460                anyhow::Ok(())
 461            }
 462        };
 463
 464        if let Err(e) = result {
 465            return Err(e.context("Failed to connect to host"));
 466        }
 467
 468        if master_process.as_mut().try_status()?.is_some() {
 469            let mut output = Vec::new();
 470            output.clear();
 471            let mut stderr = master_process.as_mut().stderr.take().unwrap();
 472            stderr.read_to_end(&mut output).await?;
 473
 474            let error_message = format!(
 475                "failed to connect: {}",
 476                String::from_utf8_lossy(&output).trim()
 477            );
 478            anyhow::bail!(error_message);
 479        }
 480
 481        #[cfg(not(target_os = "windows"))]
 482        let socket = SshSocket::new(connection_options, socket_path).await?;
 483        #[cfg(target_os = "windows")]
 484        let socket = SshSocket::new(
 485            connection_options,
 486            askpass
 487                .get_password()
 488                .or_else(|| askpass::EncryptedPassword::try_from("").ok())
 489                .context("Failed to fetch askpass password")?,
 490            cx.background_executor().clone(),
 491        )
 492        .await?;
 493        drop(askpass);
 494
 495        let ssh_shell = socket.shell().await;
 496        log::info!("Remote shell discovered: {}", ssh_shell);
 497        let ssh_platform = socket.platform(ShellKind::new(&ssh_shell, false)).await?;
 498        log::info!("Remote platform discovered: {:?}", ssh_platform);
 499        let ssh_path_style = match ssh_platform.os {
 500            "windows" => PathStyle::Windows,
 501            _ => PathStyle::Posix,
 502        };
 503        let ssh_default_system_shell = String::from("/bin/sh");
 504        let ssh_shell_kind = ShellKind::new(
 505            &ssh_shell,
 506            match ssh_platform.os {
 507                "windows" => true,
 508                _ => false,
 509            },
 510        );
 511
 512        let mut this = Self {
 513            socket,
 514            master_process: Mutex::new(Some(master_process)),
 515            _temp_dir: temp_dir,
 516            remote_binary_path: None,
 517            ssh_path_style,
 518            ssh_platform,
 519            ssh_shell,
 520            ssh_shell_kind,
 521            ssh_default_system_shell,
 522        };
 523
 524        let (release_channel, version) =
 525            cx.update(|cx| (ReleaseChannel::global(cx), AppVersion::global(cx)))?;
 526        this.remote_binary_path = Some(
 527            this.ensure_server_binary(&delegate, release_channel, version, cx)
 528                .await?,
 529        );
 530
 531        Ok(this)
 532    }
 533
 534    async fn ensure_server_binary(
 535        &self,
 536        delegate: &Arc<dyn RemoteClientDelegate>,
 537        release_channel: ReleaseChannel,
 538        version: Version,
 539        cx: &mut AsyncApp,
 540    ) -> Result<Arc<RelPath>> {
 541        let version_str = match release_channel {
 542            ReleaseChannel::Dev => "build".to_string(),
 543            _ => version.to_string(),
 544        };
 545        let binary_name = format!(
 546            "zed-remote-server-{}-{}",
 547            release_channel.dev_name(),
 548            version_str
 549        );
 550        let dst_path =
 551            paths::remote_server_dir_relative().join(RelPath::unix(&binary_name).unwrap());
 552
 553        #[cfg(debug_assertions)]
 554        if let Some(remote_server_path) =
 555            super::build_remote_server_from_source(&self.ssh_platform, delegate.as_ref(), cx)
 556                .await?
 557        {
 558            let tmp_path = paths::remote_server_dir_relative().join(
 559                RelPath::unix(&format!(
 560                    "download-{}-{}",
 561                    std::process::id(),
 562                    remote_server_path.file_name().unwrap().to_string_lossy()
 563                ))
 564                .unwrap(),
 565            );
 566            self.upload_local_server_binary(&remote_server_path, &tmp_path, delegate, cx)
 567                .await?;
 568            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
 569                .await?;
 570            return Ok(dst_path);
 571        }
 572
 573        if self
 574            .socket
 575            .run_command(
 576                self.ssh_shell_kind,
 577                &dst_path.display(self.path_style()),
 578                &["version"],
 579                true,
 580            )
 581            .await
 582            .is_ok()
 583        {
 584            return Ok(dst_path);
 585        }
 586
 587        let wanted_version = cx.update(|cx| match release_channel {
 588            ReleaseChannel::Nightly => Ok(None),
 589            ReleaseChannel::Dev => {
 590                anyhow::bail!(
 591                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
 592                    dst_path
 593                )
 594            }
 595            _ => Ok(Some(AppVersion::global(cx))),
 596        })??;
 597
 598        let tmp_path_gz = remote_server_dir_relative().join(
 599            RelPath::unix(&format!(
 600                "{}-download-{}.gz",
 601                binary_name,
 602                std::process::id()
 603            ))
 604            .unwrap(),
 605        );
 606        if !self.socket.connection_options.upload_binary_over_ssh
 607            && let Some(url) = delegate
 608                .get_download_url(
 609                    self.ssh_platform,
 610                    release_channel,
 611                    wanted_version.clone(),
 612                    cx,
 613                )
 614                .await?
 615        {
 616            match self
 617                .download_binary_on_server(&url, &tmp_path_gz, delegate, cx)
 618                .await
 619            {
 620                Ok(_) => {
 621                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 622                        .await
 623                        .context("extracting server binary")?;
 624                    return Ok(dst_path);
 625                }
 626                Err(e) => {
 627                    log::error!(
 628                        "Failed to download binary on server, attempting to download locally and then upload it the server: {e:#}",
 629                    )
 630                }
 631            }
 632        }
 633
 634        let src_path = delegate
 635            .download_server_binary_locally(
 636                self.ssh_platform,
 637                release_channel,
 638                wanted_version.clone(),
 639                cx,
 640            )
 641            .await
 642            .context("downloading server binary locally")?;
 643        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
 644            .await
 645            .context("uploading server binary")?;
 646        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 647            .await
 648            .context("extracting server binary")?;
 649        Ok(dst_path)
 650    }
 651
 652    async fn download_binary_on_server(
 653        &self,
 654        url: &str,
 655        tmp_path_gz: &RelPath,
 656        delegate: &Arc<dyn RemoteClientDelegate>,
 657        cx: &mut AsyncApp,
 658    ) -> Result<()> {
 659        if let Some(parent) = tmp_path_gz.parent() {
 660            self.socket
 661                .run_command(
 662                    self.ssh_shell_kind,
 663                    "mkdir",
 664                    &["-p", parent.display(self.path_style()).as_ref()],
 665                    true,
 666                )
 667                .await?;
 668        }
 669
 670        delegate.set_status(Some("Downloading remote development server on host"), cx);
 671
 672        const CONNECT_TIMEOUT_SECS: &str = "10";
 673
 674        match self
 675            .socket
 676            .run_command(
 677                self.ssh_shell_kind,
 678                "curl",
 679                &[
 680                    "-f",
 681                    "-L",
 682                    "--connect-timeout",
 683                    CONNECT_TIMEOUT_SECS,
 684                    url,
 685                    "-o",
 686                    &tmp_path_gz.display(self.path_style()),
 687                ],
 688                true,
 689            )
 690            .await
 691        {
 692            Ok(_) => {}
 693            Err(e) => {
 694                if self
 695                    .socket
 696                    .run_command(self.ssh_shell_kind, "which", &["curl"], true)
 697                    .await
 698                    .is_ok()
 699                {
 700                    return Err(e);
 701                }
 702
 703                log::info!("curl is not available, trying wget");
 704                match self
 705                    .socket
 706                    .run_command(
 707                        self.ssh_shell_kind,
 708                        "wget",
 709                        &[
 710                            "--connect-timeout",
 711                            CONNECT_TIMEOUT_SECS,
 712                            "--tries",
 713                            "1",
 714                            url,
 715                            "-O",
 716                            &tmp_path_gz.display(self.path_style()),
 717                        ],
 718                        true,
 719                    )
 720                    .await
 721                {
 722                    Ok(_) => {}
 723                    Err(e) => {
 724                        if self
 725                            .socket
 726                            .run_command(self.ssh_shell_kind, "which", &["wget"], true)
 727                            .await
 728                            .is_ok()
 729                        {
 730                            return Err(e);
 731                        } else {
 732                            anyhow::bail!("Neither curl nor wget is available");
 733                        }
 734                    }
 735                }
 736            }
 737        }
 738
 739        Ok(())
 740    }
 741
 742    async fn upload_local_server_binary(
 743        &self,
 744        src_path: &Path,
 745        tmp_path_gz: &RelPath,
 746        delegate: &Arc<dyn RemoteClientDelegate>,
 747        cx: &mut AsyncApp,
 748    ) -> Result<()> {
 749        if let Some(parent) = tmp_path_gz.parent() {
 750            self.socket
 751                .run_command(
 752                    self.ssh_shell_kind,
 753                    "mkdir",
 754                    &["-p", parent.display(self.path_style()).as_ref()],
 755                    true,
 756                )
 757                .await?;
 758        }
 759
 760        let src_stat = fs::metadata(&src_path).await?;
 761        let size = src_stat.len();
 762
 763        let t0 = Instant::now();
 764        delegate.set_status(Some("Uploading remote development server"), cx);
 765        log::info!(
 766            "uploading remote development server to {:?} ({}kb)",
 767            tmp_path_gz,
 768            size / 1024
 769        );
 770        self.upload_file(src_path, tmp_path_gz)
 771            .await
 772            .context("failed to upload server binary")?;
 773        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 774        Ok(())
 775    }
 776
 777    async fn extract_server_binary(
 778        &self,
 779        dst_path: &RelPath,
 780        tmp_path: &RelPath,
 781        delegate: &Arc<dyn RemoteClientDelegate>,
 782        cx: &mut AsyncApp,
 783    ) -> Result<()> {
 784        delegate.set_status(Some("Extracting remote development server"), cx);
 785        let server_mode = 0o755;
 786
 787        let shell_kind = ShellKind::Posix;
 788        let orig_tmp_path = tmp_path.display(self.path_style());
 789        let server_mode = format!("{:o}", server_mode);
 790        let server_mode = shell_kind
 791            .try_quote(&server_mode)
 792            .context("shell quoting")?;
 793        let dst_path = dst_path.display(self.path_style());
 794        let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
 795        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 796            let orig_tmp_path = shell_kind
 797                .try_quote(&orig_tmp_path)
 798                .context("shell quoting")?;
 799            let tmp_path = shell_kind.try_quote(&tmp_path).context("shell quoting")?;
 800            format!(
 801                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 802            )
 803        } else {
 804            let orig_tmp_path = shell_kind
 805                .try_quote(&orig_tmp_path)
 806                .context("shell quoting")?;
 807            format!("chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",)
 808        };
 809        let args = shell_kind.args_for_shell(false, script.to_string());
 810        self.socket
 811            .run_command(shell_kind, "sh", &args, true)
 812            .await?;
 813        Ok(())
 814    }
 815
 816    fn build_scp_command(
 817        &self,
 818        src_path: &Path,
 819        dest_path_str: &str,
 820        args: Option<&[&str]>,
 821    ) -> process::Command {
 822        let mut command = util::command::new_smol_command("scp");
 823        self.socket.ssh_options(&mut command, false).args(
 824            self.socket
 825                .connection_options
 826                .port
 827                .map(|port| vec!["-P".to_string(), port.to_string()])
 828                .unwrap_or_default(),
 829        );
 830        if let Some(args) = args {
 831            command.args(args);
 832        }
 833        command.arg(src_path).arg(format!(
 834            "{}:{}",
 835            self.socket.connection_options.scp_url(),
 836            dest_path_str
 837        ));
 838        command
 839    }
 840
 841    fn build_sftp_command(&self) -> process::Command {
 842        let mut command = util::command::new_smol_command("sftp");
 843        self.socket.ssh_options(&mut command, false).args(
 844            self.socket
 845                .connection_options
 846                .port
 847                .map(|port| vec!["-P".to_string(), port.to_string()])
 848                .unwrap_or_default(),
 849        );
 850        command.arg("-b").arg("-");
 851        command.arg(self.socket.connection_options.scp_url());
 852        command.stdin(Stdio::piped());
 853        command
 854    }
 855
 856    async fn upload_file(&self, src_path: &Path, dest_path: &RelPath) -> Result<()> {
 857        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
 858
 859        let src_path_display = src_path.display().to_string();
 860        let dest_path_str = dest_path.display(self.path_style());
 861
 862        // We will try SFTP first, and if that fails, we will fall back to SCP.
 863        // If SCP fails also, we give up and return an error.
 864        // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 865        // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 866        // This is for example the case on Windows as evidenced by this implementation snippet:
 867        // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 868        if Self::is_sftp_available().await {
 869            log::debug!("using SFTP for file upload");
 870            let mut command = self.build_sftp_command();
 871            let sftp_batch = format!("put {src_path_display} {dest_path_str}\n");
 872
 873            let mut child = command.spawn()?;
 874            if let Some(mut stdin) = child.stdin.take() {
 875                use futures::AsyncWriteExt;
 876                stdin.write_all(sftp_batch.as_bytes()).await?;
 877                stdin.flush().await?;
 878            }
 879
 880            let output = child.output().await?;
 881            if output.status.success() {
 882                return Ok(());
 883            }
 884
 885            let stderr = String::from_utf8_lossy(&output.stderr);
 886            log::debug!(
 887                "failed to upload file via SFTP {src_path_display} -> {dest_path_str}: {stderr}"
 888            );
 889        }
 890
 891        log::debug!("using SCP for file upload");
 892        let mut command = self.build_scp_command(src_path, &dest_path_str, None);
 893        let output = command.output().await?;
 894
 895        if output.status.success() {
 896            return Ok(());
 897        }
 898
 899        let stderr = String::from_utf8_lossy(&output.stderr);
 900        log::debug!(
 901            "failed to upload file via SCP {src_path_display} -> {dest_path_str}: {stderr}",
 902        );
 903        anyhow::bail!(
 904            "failed to upload file via STFP/SCP {} -> {}: {}",
 905            src_path_display,
 906            dest_path_str,
 907            stderr,
 908        );
 909    }
 910
 911    async fn is_sftp_available() -> bool {
 912        which::which("sftp").is_ok()
 913    }
 914}
 915
 916impl SshSocket {
 917    #[cfg(not(target_os = "windows"))]
 918    async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 919        Ok(Self {
 920            connection_options: options,
 921            envs: HashMap::default(),
 922            socket_path,
 923        })
 924    }
 925
 926    #[cfg(target_os = "windows")]
 927    async fn new(
 928        options: SshConnectionOptions,
 929        password: askpass::EncryptedPassword,
 930        executor: gpui::BackgroundExecutor,
 931    ) -> Result<Self> {
 932        let mut envs = HashMap::default();
 933        let get_password =
 934            move |_| Task::ready(std::ops::ControlFlow::Continue(Ok(password.clone())));
 935
 936        let _proxy = askpass::PasswordProxy::new(get_password, executor).await?;
 937        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 938        envs.insert(
 939            "SSH_ASKPASS".into(),
 940            _proxy.script_path().as_ref().display().to_string(),
 941        );
 942
 943        Ok(Self {
 944            connection_options: options,
 945            envs,
 946            _proxy,
 947        })
 948    }
 949
 950    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 951    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 952    // and passes -l as an argument to sh, not to ls.
 953    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 954    // into a machine. You must use `cd` to get back to $HOME.
 955    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 956    fn ssh_command(
 957        &self,
 958        shell_kind: ShellKind,
 959        program: &str,
 960        args: &[impl AsRef<str>],
 961        allow_pseudo_tty: bool,
 962    ) -> process::Command {
 963        let mut command = util::command::new_smol_command("ssh");
 964        let program = shell_kind.prepend_command_prefix(program);
 965        let mut to_run = shell_kind
 966            .try_quote_prefix_aware(&program)
 967            .expect("shell quoting")
 968            .into_owned();
 969        for arg in args {
 970            // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 971            debug_assert!(
 972                !arg.as_ref().contains('\n'),
 973                "multiline arguments do not work in all shells"
 974            );
 975            to_run.push(' ');
 976            to_run.push_str(&shell_kind.try_quote(arg.as_ref()).expect("shell quoting"));
 977        }
 978        let separator = shell_kind.sequential_commands_separator();
 979        let to_run = format!("cd{separator} {to_run}");
 980        self.ssh_options(&mut command, true)
 981            .arg(self.connection_options.ssh_url());
 982        if !allow_pseudo_tty {
 983            command.arg("-T");
 984        }
 985        command.arg(to_run);
 986        log::debug!("ssh {:?}", command);
 987        command
 988    }
 989
 990    async fn run_command(
 991        &self,
 992        shell_kind: ShellKind,
 993        program: &str,
 994        args: &[impl AsRef<str>],
 995        allow_pseudo_tty: bool,
 996    ) -> Result<String> {
 997        let mut command = self.ssh_command(shell_kind, program, args, allow_pseudo_tty);
 998        let output = command.output().await?;
 999        anyhow::ensure!(
1000            output.status.success(),
1001            "failed to run command {command:?}: {}",
1002            String::from_utf8_lossy(&output.stderr)
1003        );
1004        Ok(String::from_utf8_lossy(&output.stdout).to_string())
1005    }
1006
1007    #[cfg(not(target_os = "windows"))]
1008    fn ssh_options<'a>(
1009        &self,
1010        command: &'a mut process::Command,
1011        include_port_forwards: bool,
1012    ) -> &'a mut process::Command {
1013        let args = if include_port_forwards {
1014            self.connection_options.additional_args()
1015        } else {
1016            self.connection_options.additional_args_for_scp()
1017        };
1018
1019        command
1020            .stdin(Stdio::piped())
1021            .stdout(Stdio::piped())
1022            .stderr(Stdio::piped())
1023            .args(args)
1024            .args(["-o", "ControlMaster=no", "-o"])
1025            .arg(format!("ControlPath={}", self.socket_path.display()))
1026    }
1027
1028    #[cfg(target_os = "windows")]
1029    fn ssh_options<'a>(
1030        &self,
1031        command: &'a mut process::Command,
1032        include_port_forwards: bool,
1033    ) -> &'a mut process::Command {
1034        let args = if include_port_forwards {
1035            self.connection_options.additional_args()
1036        } else {
1037            self.connection_options.additional_args_for_scp()
1038        };
1039
1040        command
1041            .stdin(Stdio::piped())
1042            .stdout(Stdio::piped())
1043            .stderr(Stdio::piped())
1044            .args(args)
1045            .envs(self.envs.clone())
1046    }
1047
1048    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
1049    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
1050    #[cfg(not(target_os = "windows"))]
1051    fn ssh_args(&self) -> Vec<String> {
1052        let mut arguments = self.connection_options.additional_args();
1053        arguments.extend(vec![
1054            "-o".to_string(),
1055            "ControlMaster=no".to_string(),
1056            "-o".to_string(),
1057            format!("ControlPath={}", self.socket_path.display()),
1058            self.connection_options.ssh_url(),
1059        ]);
1060        arguments
1061    }
1062
1063    #[cfg(target_os = "windows")]
1064    fn ssh_args(&self) -> Vec<String> {
1065        let mut arguments = self.connection_options.additional_args();
1066        arguments.push(self.connection_options.ssh_url());
1067        arguments
1068    }
1069
1070    async fn platform(&self, shell: ShellKind) -> Result<RemotePlatform> {
1071        let output = self.run_command(shell, "uname", &["-sm"], false).await?;
1072        parse_platform(&output)
1073    }
1074
1075    async fn shell(&self) -> String {
1076        const DEFAULT_SHELL: &str = "sh";
1077        match self
1078            .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
1079            .await
1080        {
1081            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
1082            Err(e) => {
1083                log::error!("Failed to detect remote shell: {e}");
1084                DEFAULT_SHELL.to_owned()
1085            }
1086        }
1087    }
1088}
1089
1090fn parse_port_number(port_str: &str) -> Result<u16> {
1091    port_str
1092        .parse()
1093        .with_context(|| format!("parsing port number: {port_str}"))
1094}
1095
1096fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1097    let parts: Vec<&str> = spec.split(':').collect();
1098
1099    match parts.len() {
1100        4 => {
1101            let local_port = parse_port_number(parts[1])?;
1102            let remote_port = parse_port_number(parts[3])?;
1103
1104            Ok(SshPortForwardOption {
1105                local_host: Some(parts[0].to_string()),
1106                local_port,
1107                remote_host: Some(parts[2].to_string()),
1108                remote_port,
1109            })
1110        }
1111        3 => {
1112            let local_port = parse_port_number(parts[0])?;
1113            let remote_port = parse_port_number(parts[2])?;
1114
1115            Ok(SshPortForwardOption {
1116                local_host: None,
1117                local_port,
1118                remote_host: Some(parts[1].to_string()),
1119                remote_port,
1120            })
1121        }
1122        _ => anyhow::bail!("Invalid port forward format"),
1123    }
1124}
1125
1126impl SshConnectionOptions {
1127    pub fn parse_command_line(input: &str) -> Result<Self> {
1128        let input = input.trim_start_matches("ssh ");
1129        let mut hostname: Option<String> = None;
1130        let mut username: Option<String> = None;
1131        let mut port: Option<u16> = None;
1132        let mut args = Vec::new();
1133        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1134
1135        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1136        const ALLOWED_OPTS: &[&str] = &[
1137            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1138        ];
1139        const ALLOWED_ARGS: &[&str] = &[
1140            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1141            "-w",
1142        ];
1143
1144        let mut tokens = ShellKind::Posix
1145            .split(input)
1146            .context("invalid input")?
1147            .into_iter();
1148
1149        'outer: while let Some(arg) = tokens.next() {
1150            if ALLOWED_OPTS.contains(&(&arg as &str)) {
1151                args.push(arg.to_string());
1152                continue;
1153            }
1154            if arg == "-p" {
1155                port = tokens.next().and_then(|arg| arg.parse().ok());
1156                continue;
1157            } else if let Some(p) = arg.strip_prefix("-p") {
1158                port = p.parse().ok();
1159                continue;
1160            }
1161            if arg == "-l" {
1162                username = tokens.next();
1163                continue;
1164            } else if let Some(l) = arg.strip_prefix("-l") {
1165                username = Some(l.to_string());
1166                continue;
1167            }
1168            if arg == "-L" || arg.starts_with("-L") {
1169                let forward_spec = if arg == "-L" {
1170                    tokens.next()
1171                } else {
1172                    Some(arg.strip_prefix("-L").unwrap().to_string())
1173                };
1174
1175                if let Some(spec) = forward_spec {
1176                    port_forwards.push(parse_port_forward_spec(&spec)?);
1177                } else {
1178                    anyhow::bail!("Missing port forward format");
1179                }
1180            }
1181
1182            for a in ALLOWED_ARGS {
1183                if arg == *a {
1184                    args.push(arg);
1185                    if let Some(next) = tokens.next() {
1186                        args.push(next);
1187                    }
1188                    continue 'outer;
1189                } else if arg.starts_with(a) {
1190                    args.push(arg);
1191                    continue 'outer;
1192                }
1193            }
1194            if arg.starts_with("-") || hostname.is_some() {
1195                anyhow::bail!("unsupported argument: {:?}", arg);
1196            }
1197            let mut input = &arg as &str;
1198            // Destination might be: username1@username2@ip2@ip1
1199            if let Some((u, rest)) = input.rsplit_once('@') {
1200                input = rest;
1201                username = Some(u.to_string());
1202            }
1203            if let Some((rest, p)) = input.split_once(':') {
1204                input = rest;
1205                port = p.parse().ok()
1206            }
1207            hostname = Some(input.to_string())
1208        }
1209
1210        let Some(hostname) = hostname else {
1211            anyhow::bail!("missing hostname");
1212        };
1213
1214        let port_forwards = match port_forwards.len() {
1215            0 => None,
1216            _ => Some(port_forwards),
1217        };
1218
1219        Ok(Self {
1220            host: hostname,
1221            username,
1222            port,
1223            port_forwards,
1224            args: Some(args),
1225            password: None,
1226            nickname: None,
1227            upload_binary_over_ssh: false,
1228        })
1229    }
1230
1231    pub fn ssh_url(&self) -> String {
1232        let mut result = String::from("ssh://");
1233        if let Some(username) = &self.username {
1234            // Username might be: username1@username2@ip2
1235            let username = urlencoding::encode(username);
1236            result.push_str(&username);
1237            result.push('@');
1238        }
1239        result.push_str(&self.host);
1240        if let Some(port) = self.port {
1241            result.push(':');
1242            result.push_str(&port.to_string());
1243        }
1244        result
1245    }
1246
1247    pub fn additional_args_for_scp(&self) -> Vec<String> {
1248        self.args.iter().flatten().cloned().collect::<Vec<String>>()
1249    }
1250
1251    pub fn additional_args(&self) -> Vec<String> {
1252        let mut args = self.additional_args_for_scp();
1253
1254        if let Some(forwards) = &self.port_forwards {
1255            args.extend(forwards.iter().map(|pf| {
1256                let local_host = match &pf.local_host {
1257                    Some(host) => host,
1258                    None => "localhost",
1259                };
1260                let remote_host = match &pf.remote_host {
1261                    Some(host) => host,
1262                    None => "localhost",
1263                };
1264
1265                format!(
1266                    "-L{}:{}:{}:{}",
1267                    local_host, pf.local_port, remote_host, pf.remote_port
1268                )
1269            }));
1270        }
1271
1272        args
1273    }
1274
1275    fn scp_url(&self) -> String {
1276        if let Some(username) = &self.username {
1277            format!("{}@{}", username, self.host)
1278        } else {
1279            self.host.clone()
1280        }
1281    }
1282
1283    pub fn connection_string(&self) -> String {
1284        let host = if let Some(username) = &self.username {
1285            format!("{}@{}", username, self.host)
1286        } else {
1287            self.host.clone()
1288        };
1289        if let Some(port) = &self.port {
1290            format!("{}:{}", host, port)
1291        } else {
1292            host
1293        }
1294    }
1295}
1296
1297fn build_command(
1298    input_program: Option<String>,
1299    input_args: &[String],
1300    input_env: &HashMap<String, String>,
1301    working_dir: Option<String>,
1302    port_forward: Option<(u16, String, u16)>,
1303    ssh_env: HashMap<String, String>,
1304    ssh_path_style: PathStyle,
1305    ssh_shell: &str,
1306    ssh_shell_kind: ShellKind,
1307    ssh_args: Vec<String>,
1308) -> Result<CommandTemplate> {
1309    use std::fmt::Write as _;
1310
1311    let mut exec = String::new();
1312    if let Some(working_dir) = working_dir {
1313        let working_dir = RemotePathBuf::new(working_dir, ssh_path_style).to_string();
1314
1315        // shlex will wrap the command in single quotes (''), disabling ~ expansion,
1316        // replace with something that works
1317        const TILDE_PREFIX: &'static str = "~/";
1318        if working_dir.starts_with(TILDE_PREFIX) {
1319            let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
1320            write!(
1321                exec,
1322                "cd \"$HOME/{working_dir}\" {} ",
1323                ssh_shell_kind.sequential_and_commands_separator()
1324            )?;
1325        } else {
1326            write!(
1327                exec,
1328                "cd \"{working_dir}\" {} ",
1329                ssh_shell_kind.sequential_and_commands_separator()
1330            )?;
1331        }
1332    } else {
1333        write!(
1334            exec,
1335            "cd {} ",
1336            ssh_shell_kind.sequential_and_commands_separator()
1337        )?;
1338    };
1339    write!(exec, "exec env ")?;
1340
1341    for (k, v) in input_env.iter() {
1342        write!(
1343            exec,
1344            "{}={} ",
1345            k,
1346            ssh_shell_kind.try_quote(v).context("shell quoting")?
1347        )?;
1348    }
1349
1350    if let Some(input_program) = input_program {
1351        write!(
1352            exec,
1353            "{}",
1354            ssh_shell_kind
1355                .try_quote_prefix_aware(&input_program)
1356                .context("shell quoting")?
1357        )?;
1358        for arg in input_args {
1359            let arg = ssh_shell_kind.try_quote(&arg).context("shell quoting")?;
1360            write!(exec, " {}", &arg)?;
1361        }
1362    } else {
1363        write!(exec, "{ssh_shell} -l")?;
1364    };
1365
1366    let mut args = Vec::new();
1367    args.extend(ssh_args);
1368
1369    if let Some((local_port, host, remote_port)) = port_forward {
1370        args.push("-L".into());
1371        args.push(format!("{local_port}:{host}:{remote_port}"));
1372    }
1373
1374    args.push("-t".into());
1375    args.push(exec);
1376    Ok(CommandTemplate {
1377        program: "ssh".into(),
1378        args,
1379        env: ssh_env,
1380    })
1381}
1382
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `(input_env, ssh_env)`. The first holds the variable meant for the
    /// remote program; the second holds the variable meant for the local `ssh`
    /// process.
    fn sample_envs() -> (HashMap<String, String>, HashMap<String, String>) {
        let mut input_env = HashMap::default();
        input_env.insert("INPUT_VA".to_string(), "val".to_string());
        let mut ssh_env = HashMap::default();
        ssh_env.insert("SSH_VAR".to_string(), "ssh-val".to_string());
        (input_env, ssh_env)
    }

    /// Borrows a template's argument list as `&str`s for comparison.
    fn args_of(command: &CommandTemplate) -> Vec<&str> {
        command.args.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_build_command() -> Result<()> {
        let (input_env, ssh_env) = sample_envs();
        let program_args = ["arg1".to_string(), "arg2".to_string()];
        let port_args = || vec!["-p".to_string(), "2222".to_string()];

        // Explicit program, run from a `~/`-relative working directory.
        let with_program = build_command(
            Some("remote_program".to_string()),
            &program_args,
            &input_env,
            Some("~/work".to_string()),
            None,
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            port_args(),
        )?;
        assert_eq!(with_program.program, "ssh");
        assert_eq!(
            args_of(&with_program),
            [
                "-p",
                "2222",
                "-t",
                "cd \"$HOME/work\" && exec env INPUT_VA=val remote_program arg1 arg2"
            ]
        );
        assert_eq!(with_program.env, ssh_env);

        // No program: a login shell is started (arguments are dropped) and a port
        // forward is added.
        let login_shell = build_command(
            None,
            &program_args,
            &input_env,
            None,
            Some((1, "foo".to_owned(), 2)),
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            port_args(),
        )?;
        assert_eq!(login_shell.program, "ssh");
        assert_eq!(
            args_of(&login_shell),
            [
                "-p",
                "2222",
                "-L",
                "1:foo:2",
                "-t",
                "cd && exec env INPUT_VA=val /bin/fish -l"
            ]
        );
        assert_eq!(login_shell.env, ssh_env);

        Ok(())
    }

    #[test]
    fn scp_args_exclude_port_forward_flags() {
        let user_args = vec![
            "-p".to_string(),
            "2222".to_string(),
            "-o".to_string(),
            "StrictHostKeyChecking=no".to_string(),
        ];
        let options = SshConnectionOptions {
            host: "example.com".into(),
            args: Some(user_args.clone()),
            port_forwards: Some(vec![SshPortForwardOption {
                local_host: Some("127.0.0.1".to_string()),
                local_port: 8080,
                remote_host: Some("127.0.0.1".to_string()),
                remote_port: 80,
            }]),
            ..Default::default()
        };

        let ssh_args = options.additional_args();
        let has_forward = ssh_args.iter().any(|arg| arg.starts_with("-L"));
        assert!(
            has_forward,
            "expected ssh args to include port-forward: {ssh_args:?}"
        );

        assert_eq!(options.additional_args_for_scp(), user_args);
    }
}