ssh.rs

   1use crate::{
   2    RemoteClientDelegate, RemotePlatform,
   3    remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
   4    transport::{parse_platform, parse_shell},
   5};
   6use anyhow::{Context as _, Result, anyhow};
   7use async_trait::async_trait;
   8use collections::HashMap;
   9use futures::{
  10    AsyncReadExt as _, FutureExt as _,
  11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  12    select_biased,
  13};
  14use gpui::{App, AppContext as _, AsyncApp, Task};
  15use parking_lot::Mutex;
  16use paths::remote_server_dir_relative;
  17use release_channel::{AppVersion, ReleaseChannel};
  18use rpc::proto::Envelope;
  19use semver::Version;
  20pub use settings::SshPortForwardOption;
  21use smol::{
  22    fs,
  23    process::{self, Child, Stdio},
  24};
  25use std::{
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28    time::Instant,
  29};
  30use tempfile::TempDir;
  31use util::{
  32    paths::{PathStyle, RemotePathBuf},
  33    rel_path::RelPath,
  34    shell::ShellKind,
  35};
  36
/// An SSH-backed remote connection, together with the facts discovered about
/// the remote host while connecting.
pub(crate) struct SshRemoteConnection {
    /// Connection options and environment used to build follow-up `ssh` commands.
    socket: SshSocket,
    /// The master `ssh` process; becomes `None` once `kill` has taken it.
    master_process: Mutex<Option<MasterProcess>>,
    /// Path of the remote server binary; `None` until `new` has ensured it exists.
    remote_binary_path: Option<Arc<RelPath>>,
    /// Platform reported by the remote host.
    ssh_platform: RemotePlatform,
    /// Path style of the remote host: Windows when the platform's OS is
    /// `"windows"`, Posix otherwise.
    ssh_path_style: PathStyle,
    /// Shell discovered on the remote host.
    ssh_shell: String,
    /// Shell kind derived from `ssh_shell` and the remote OS.
    ssh_shell_kind: ShellKind,
    /// Default system shell reported for the remote (`new` sets this to `/bin/sh`).
    ssh_default_system_shell: String,
    /// Temporary directory created for this session; on non-Windows hosts the
    /// `ssh.sock` control socket path is located inside it.
    _temp_dir: TempDir,
}
  48
/// User-facing options describing how to reach a host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host to connect to.
    pub host: String,
    /// Optional user name.
    pub username: Option<String>,
    /// Optional port; when set it is passed to `scp`/`sftp` via `-P`.
    pub port: Option<u16>,
    /// Optional password. The conversion from `settings::SshConnection`
    /// always leaves this as `None`.
    pub password: Option<String>,
    /// Extra arguments taken from the settings.
    /// NOTE(review): presumably forwarded to `ssh` via `additional_args()` — confirm.
    pub args: Option<Vec<String>>,
    /// Optional port forwards taken from the settings.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,
    /// Connect timeout handed to the remote `curl`/`wget` download as
    /// `--connect-timeout`; 10 is used when unset.
    pub connection_timeout: Option<u16>,

    /// Optional nickname taken from the settings.
    pub nickname: Option<String>,
    /// When `true`, the server binary is not downloaded directly on the
    /// remote host; it is downloaded locally and uploaded instead.
    pub upload_binary_over_ssh: bool,
}
  62
  63impl From<settings::SshConnection> for SshConnectionOptions {
  64    fn from(val: settings::SshConnection) -> Self {
  65        SshConnectionOptions {
  66            host: val.host.into(),
  67            username: val.username,
  68            port: val.port,
  69            password: None,
  70            args: Some(val.args),
  71            nickname: val.nickname,
  72            upload_binary_over_ssh: val.upload_binary_over_ssh.unwrap_or_default(),
  73            port_forwards: val.port_forwards,
  74            connection_timeout: val.connection_timeout,
  75        }
  76    }
  77}
  78
/// Everything needed to issue further `ssh`/`scp`/`sftp` commands against an
/// already established connection.
struct SshSocket {
    /// The options this connection was opened with.
    connection_options: SshConnectionOptions,
    /// Socket path handed to `SshSocket::new` (non-Windows only).
    /// NOTE(review): presumably the `ControlPath` socket of the master process — confirm.
    #[cfg(not(target_os = "windows"))]
    socket_path: std::path::PathBuf,
    /// Environment variables; `build_command` passes a clone of these on.
    envs: HashMap<String, String>,
    /// Password proxy held for the lifetime of the socket (Windows only).
    /// NOTE(review): presumably answers password prompts for follow-up commands — confirm.
    #[cfg(target_os = "windows")]
    _proxy: askpass::PasswordProxy,
}
  87
/// Wrapper around the spawned master `ssh` child process.
struct MasterProcess {
    /// The spawned `ssh` child; it is configured with `kill_on_drop(true)`.
    process: Child,
}
  91
  92#[cfg(not(target_os = "windows"))]
  93impl MasterProcess {
  94    pub fn new(
  95        askpass_script_path: &std::ffi::OsStr,
  96        additional_args: Vec<String>,
  97        socket_path: &std::path::Path,
  98        url: &str,
  99    ) -> Result<Self> {
 100        let args = [
 101            "-N",
 102            "-o",
 103            "ControlPersist=no",
 104            "-o",
 105            "ControlMaster=yes",
 106            "-o",
 107        ];
 108
 109        let mut master_process = util::command::new_smol_command("ssh");
 110        master_process
 111            .kill_on_drop(true)
 112            .stdin(Stdio::null())
 113            .stdout(Stdio::piped())
 114            .stderr(Stdio::piped())
 115            .env("SSH_ASKPASS_REQUIRE", "force")
 116            .env("SSH_ASKPASS", askpass_script_path)
 117            .args(additional_args)
 118            .args(args);
 119
 120        master_process.arg(format!("ControlPath={}", socket_path.display()));
 121
 122        let process = master_process.arg(&url).spawn()?;
 123
 124        Ok(MasterProcess { process })
 125    }
 126
 127    pub async fn wait_connected(&mut self) -> Result<()> {
 128        let Some(mut stdout) = self.process.stdout.take() else {
 129            anyhow::bail!("ssh process stdout capture failed");
 130        };
 131
 132        let mut output = Vec::new();
 133        stdout.read_to_end(&mut output).await?;
 134        Ok(())
 135    }
 136}
 137
 138#[cfg(target_os = "windows")]
 139impl MasterProcess {
 140    const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
 141
 142    pub fn new(
 143        askpass_script_path: &std::ffi::OsStr,
 144        additional_args: Vec<String>,
 145        url: &str,
 146    ) -> Result<Self> {
 147        // On Windows, `ControlMaster` and `ControlPath` are not supported:
 148        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
 149        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
 150        //
 151        // Using an ugly workaround to detect connection establishment
 152        // -N doesn't work with JumpHosts as windows openssh never closes stdin in that case
 153        let args = [
 154            "-t",
 155            &format!("echo '{}'; exec $0", Self::CONNECTION_ESTABLISHED_MAGIC),
 156        ];
 157
 158        let mut master_process = util::command::new_smol_command("ssh");
 159        master_process
 160            .kill_on_drop(true)
 161            .stdin(Stdio::null())
 162            .stdout(Stdio::piped())
 163            .stderr(Stdio::piped())
 164            .env("SSH_ASKPASS_REQUIRE", "force")
 165            .env("SSH_ASKPASS", askpass_script_path)
 166            .args(additional_args)
 167            .arg(url)
 168            .args(args);
 169
 170        let process = master_process.spawn()?;
 171
 172        Ok(MasterProcess { process })
 173    }
 174
 175    pub async fn wait_connected(&mut self) -> Result<()> {
 176        use smol::io::AsyncBufReadExt;
 177
 178        let Some(stdout) = self.process.stdout.take() else {
 179            anyhow::bail!("ssh process stdout capture failed");
 180        };
 181
 182        let mut reader = smol::io::BufReader::new(stdout);
 183
 184        let mut line = String::new();
 185
 186        loop {
 187            let n = reader.read_line(&mut line).await?;
 188            if n == 0 {
 189                anyhow::bail!("ssh process exited before connection established");
 190            }
 191
 192            if line.contains(Self::CONNECTION_ESTABLISHED_MAGIC) {
 193                return Ok(());
 194            }
 195        }
 196    }
 197}
 198
 199impl AsRef<Child> for MasterProcess {
 200    fn as_ref(&self) -> &Child {
 201        &self.process
 202    }
 203}
 204
 205impl AsMut<Child> for MasterProcess {
 206    fn as_mut(&mut self) -> &mut Child {
 207        &mut self.process
 208    }
 209}
 210
 211#[async_trait(?Send)]
 212impl RemoteConnection for SshRemoteConnection {
 213    async fn kill(&self) -> Result<()> {
 214        let Some(mut process) = self.master_process.lock().take() else {
 215            return Ok(());
 216        };
 217        process.as_mut().kill().ok();
 218        process.as_mut().status().await?;
 219        Ok(())
 220    }
 221
 222    fn has_been_killed(&self) -> bool {
 223        self.master_process.lock().is_none()
 224    }
 225
 226    fn connection_options(&self) -> RemoteConnectionOptions {
 227        RemoteConnectionOptions::Ssh(self.socket.connection_options.clone())
 228    }
 229
 230    fn shell(&self) -> String {
 231        self.ssh_shell.clone()
 232    }
 233
 234    fn default_system_shell(&self) -> String {
 235        self.ssh_default_system_shell.clone()
 236    }
 237
 238    fn build_command(
 239        &self,
 240        input_program: Option<String>,
 241        input_args: &[String],
 242        input_env: &HashMap<String, String>,
 243        working_dir: Option<String>,
 244        port_forward: Option<(u16, String, u16)>,
 245    ) -> Result<CommandTemplate> {
 246        let Self {
 247            ssh_path_style,
 248            socket,
 249            ssh_shell_kind,
 250            ssh_shell,
 251            ..
 252        } = self;
 253        let env = socket.envs.clone();
 254        build_command(
 255            input_program,
 256            input_args,
 257            input_env,
 258            working_dir,
 259            port_forward,
 260            env,
 261            *ssh_path_style,
 262            ssh_shell,
 263            *ssh_shell_kind,
 264            socket.ssh_args(),
 265        )
 266    }
 267
 268    fn build_forward_ports_command(
 269        &self,
 270        forwards: Vec<(u16, String, u16)>,
 271    ) -> Result<CommandTemplate> {
 272        let Self { socket, .. } = self;
 273        let mut args = socket.ssh_args();
 274        args.push("-N".into());
 275        for (local_port, host, remote_port) in forwards {
 276            args.push("-L".into());
 277            args.push(format!("{local_port}:{host}:{remote_port}"));
 278        }
 279        Ok(CommandTemplate {
 280            program: "ssh".into(),
 281            args,
 282            env: Default::default(),
 283        })
 284    }
 285
 286    fn upload_directory(
 287        &self,
 288        src_path: PathBuf,
 289        dest_path: RemotePathBuf,
 290        cx: &App,
 291    ) -> Task<Result<()>> {
 292        let dest_path_str = dest_path.to_string();
 293        let src_path_display = src_path.display().to_string();
 294
 295        let mut sftp_command = self.build_sftp_command();
 296        let mut scp_command =
 297            self.build_scp_command(&src_path, &dest_path_str, Some(&["-C", "-r"]));
 298
 299        cx.background_spawn(async move {
 300            // We will try SFTP first, and if that fails, we will fall back to SCP.
 301            // If SCP fails also, we give up and return an error.
 302            // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 303            // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 304            // This is for example the case on Windows as evidenced by this implementation snippet:
 305            // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 306            if Self::is_sftp_available().await {
 307                log::debug!("using SFTP for directory upload");
 308                let mut child = sftp_command.spawn()?;
 309                if let Some(mut stdin) = child.stdin.take() {
 310                    use futures::AsyncWriteExt;
 311                    let sftp_batch = format!("put -r \"{src_path_display}\" \"{dest_path_str}\"\n");
 312                    stdin.write_all(sftp_batch.as_bytes()).await?;
 313                    stdin.flush().await?;
 314                }
 315
 316                let output = child.output().await?;
 317                if output.status.success() {
 318                    return Ok(());
 319                }
 320
 321                let stderr = String::from_utf8_lossy(&output.stderr);
 322                log::debug!("failed to upload directory via SFTP {src_path_display} -> {dest_path_str}: {stderr}");
 323            }
 324
 325            log::debug!("using SCP for directory upload");
 326            let output = scp_command.output().await?;
 327
 328            if output.status.success() {
 329                return Ok(());
 330            }
 331
 332            let stderr = String::from_utf8_lossy(&output.stderr);
 333            log::debug!("failed to upload directory via SCP {src_path_display} -> {dest_path_str}: {stderr}");
 334
 335            anyhow::bail!(
 336                "failed to upload directory via SFTP/SCP {} -> {}: {}",
 337                src_path_display,
 338                dest_path_str,
 339                stderr,
 340            );
 341        })
 342    }
 343
 344    fn start_proxy(
 345        &self,
 346        unique_identifier: String,
 347        reconnect: bool,
 348        incoming_tx: UnboundedSender<Envelope>,
 349        outgoing_rx: UnboundedReceiver<Envelope>,
 350        connection_activity_tx: Sender<()>,
 351        delegate: Arc<dyn RemoteClientDelegate>,
 352        cx: &mut AsyncApp,
 353    ) -> Task<Result<i32>> {
 354        delegate.set_status(Some("Starting proxy"), cx);
 355
 356        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 357            return Task::ready(Err(anyhow!("Remote binary path not set")));
 358        };
 359
 360        let mut proxy_args = vec![];
 361        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
 362            if let Some(value) = std::env::var(env_var).ok() {
 363                proxy_args.push(format!("{}='{}'", env_var, value));
 364            }
 365        }
 366        proxy_args.push(remote_binary_path.display(self.path_style()).into_owned());
 367        proxy_args.push("proxy".to_owned());
 368        proxy_args.push("--identifier".to_owned());
 369        proxy_args.push(unique_identifier);
 370
 371        if reconnect {
 372            proxy_args.push("--reconnect".to_owned());
 373        }
 374
 375        let ssh_proxy_process = match self
 376            .socket
 377            .ssh_command(self.ssh_shell_kind, "env", &proxy_args, false)
 378            // IMPORTANT: we kill this process when we drop the task that uses it.
 379            .kill_on_drop(true)
 380            .spawn()
 381        {
 382            Ok(process) => process,
 383            Err(error) => {
 384                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 385            }
 386        };
 387
 388        super::handle_rpc_messages_over_child_process_stdio(
 389            ssh_proxy_process,
 390            incoming_tx,
 391            outgoing_rx,
 392            connection_activity_tx,
 393            cx,
 394        )
 395    }
 396
 397    fn path_style(&self) -> PathStyle {
 398        self.ssh_path_style
 399    }
 400
    /// Always returns `false` for SSH connections.
    fn has_wsl_interop(&self) -> bool {
        false
    }
 404}
 405
 406impl SshRemoteConnection {
 407    pub(crate) async fn new(
 408        connection_options: SshConnectionOptions,
 409        delegate: Arc<dyn RemoteClientDelegate>,
 410        cx: &mut AsyncApp,
 411    ) -> Result<Self> {
 412        use askpass::AskPassResult;
 413
 414        let url = connection_options.ssh_url();
 415
 416        let temp_dir = tempfile::Builder::new()
 417            .prefix("zed-ssh-session")
 418            .tempdir()?;
 419        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
 420            let delegate = delegate.clone();
 421            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
 422        });
 423
 424        let mut askpass =
 425            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
 426
 427        delegate.set_status(Some("Connecting"), cx);
 428
 429        // Start the master SSH process, which does not do anything except for establish
 430        // the connection and keep it open, allowing other ssh commands to reuse it
 431        // via a control socket.
 432        #[cfg(not(target_os = "windows"))]
 433        let socket_path = temp_dir.path().join("ssh.sock");
 434
 435        #[cfg(target_os = "windows")]
 436        let mut master_process = MasterProcess::new(
 437            askpass.script_path().as_ref(),
 438            connection_options.additional_args(),
 439            &url,
 440        )?;
 441        #[cfg(not(target_os = "windows"))]
 442        let mut master_process = MasterProcess::new(
 443            askpass.script_path().as_ref(),
 444            connection_options.additional_args(),
 445            &socket_path,
 446            &url,
 447        )?;
 448
 449        let result = select_biased! {
 450            result = askpass.run().fuse() => {
 451                match result {
 452                    AskPassResult::CancelledByUser => {
 453                        master_process.as_mut().kill().ok();
 454                        anyhow::bail!("SSH connection canceled")
 455                    }
 456                    AskPassResult::Timedout => {
 457                        anyhow::bail!("connecting to host timed out")
 458                    }
 459                }
 460            }
 461            _ = master_process.wait_connected().fuse() => {
 462                anyhow::Ok(())
 463            }
 464        };
 465
 466        if let Err(e) = result {
 467            return Err(e.context("Failed to connect to host"));
 468        }
 469
 470        if master_process.as_mut().try_status()?.is_some() {
 471            let mut output = Vec::new();
 472            output.clear();
 473            let mut stderr = master_process.as_mut().stderr.take().unwrap();
 474            stderr.read_to_end(&mut output).await?;
 475
 476            let error_message = format!(
 477                "failed to connect: {}",
 478                String::from_utf8_lossy(&output).trim()
 479            );
 480            anyhow::bail!(error_message);
 481        }
 482
 483        #[cfg(not(target_os = "windows"))]
 484        let socket = SshSocket::new(connection_options, socket_path).await?;
 485        #[cfg(target_os = "windows")]
 486        let socket = SshSocket::new(
 487            connection_options,
 488            askpass
 489                .get_password()
 490                .or_else(|| askpass::EncryptedPassword::try_from("").ok())
 491                .context("Failed to fetch askpass password")?,
 492            cx.background_executor().clone(),
 493        )
 494        .await?;
 495        drop(askpass);
 496
 497        let ssh_shell = socket.shell().await;
 498        log::info!("Remote shell discovered: {}", ssh_shell);
 499        let ssh_platform = socket.platform(ShellKind::new(&ssh_shell, false)).await?;
 500        log::info!("Remote platform discovered: {:?}", ssh_platform);
 501        let ssh_path_style = match ssh_platform.os {
 502            "windows" => PathStyle::Windows,
 503            _ => PathStyle::Posix,
 504        };
 505        let ssh_default_system_shell = String::from("/bin/sh");
 506        let ssh_shell_kind = ShellKind::new(
 507            &ssh_shell,
 508            match ssh_platform.os {
 509                "windows" => true,
 510                _ => false,
 511            },
 512        );
 513
 514        let mut this = Self {
 515            socket,
 516            master_process: Mutex::new(Some(master_process)),
 517            _temp_dir: temp_dir,
 518            remote_binary_path: None,
 519            ssh_path_style,
 520            ssh_platform,
 521            ssh_shell,
 522            ssh_shell_kind,
 523            ssh_default_system_shell,
 524        };
 525
 526        let (release_channel, version) =
 527            cx.update(|cx| (ReleaseChannel::global(cx), AppVersion::global(cx)))?;
 528        this.remote_binary_path = Some(
 529            this.ensure_server_binary(&delegate, release_channel, version, cx)
 530                .await?,
 531        );
 532
 533        Ok(this)
 534    }
 535
    /// Makes sure a remote server binary matching `release_channel` and
    /// `version` exists on the remote host and returns its relative path.
    ///
    /// The steps, in order:
    /// 1. In debug builds, if `build_remote_server_from_source` yields a
    ///    locally built binary, upload and install it.
    /// 2. If running `<binary> version` on the remote succeeds, reuse it.
    /// 3. Unless `upload_binary_over_ssh` is set, try to download the binary
    ///    directly on the remote host.
    /// 4. Otherwise (or if step 3 fails), download it locally and upload it.
    ///
    /// # Errors
    ///
    /// On the `Dev` channel this fails when no usable binary is already on
    /// the remote. Otherwise errors from the download, upload or extraction
    /// steps are propagated.
    async fn ensure_server_binary(
        &self,
        delegate: &Arc<dyn RemoteClientDelegate>,
        release_channel: ReleaseChannel,
        version: Version,
        cx: &mut AsyncApp,
    ) -> Result<Arc<RelPath>> {
        // Dev builds use a fixed "build" suffix instead of a version number.
        let version_str = match release_channel {
            ReleaseChannel::Dev => "build".to_string(),
            _ => version.to_string(),
        };
        let binary_name = format!(
            "zed-remote-server-{}-{}",
            release_channel.dev_name(),
            version_str
        );
        let dst_path =
            paths::remote_server_dir_relative().join(RelPath::unix(&binary_name).unwrap());

        // Debug builds only: prefer a server binary built from local source.
        #[cfg(debug_assertions)]
        if let Some(remote_server_path) =
            super::build_remote_server_from_source(&self.ssh_platform, delegate.as_ref(), cx)
                .await?
        {
            // Upload under a name that includes the local process id, then
            // move it into place.
            let tmp_path = paths::remote_server_dir_relative().join(
                RelPath::unix(&format!(
                    "download-{}-{}",
                    std::process::id(),
                    remote_server_path.file_name().unwrap().to_string_lossy()
                ))
                .unwrap(),
            );
            self.upload_local_server_binary(&remote_server_path, &tmp_path, delegate, cx)
                .await?;
            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
                .await?;
            return Ok(dst_path);
        }

        // Reuse an existing binary if it can be executed on the remote.
        if self
            .socket
            .run_command(
                self.ssh_shell_kind,
                &dst_path.display(self.path_style()),
                &["version"],
                true,
            )
            .await
            .is_ok()
        {
            return Ok(dst_path);
        }

        // Nightly requests no specific version; Dev cannot be downloaded at
        // all; every other channel requests the running app's version.
        let wanted_version = cx.update(|cx| match release_channel {
            ReleaseChannel::Nightly => Ok(None),
            ReleaseChannel::Dev => {
                anyhow::bail!(
                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
                    dst_path
                )
            }
            _ => Ok(Some(AppVersion::global(cx))),
        })??;

        // Temporary remote path for the compressed download; the `.gz` suffix
        // is what makes `extract_server_binary` gunzip it.
        let tmp_path_gz = remote_server_dir_relative().join(
            RelPath::unix(&format!(
                "{}-download-{}.gz",
                binary_name,
                std::process::id()
            ))
            .unwrap(),
        );
        // First choice: let the remote host fetch the binary itself.
        if !self.socket.connection_options.upload_binary_over_ssh
            && let Some(url) = delegate
                .get_download_url(
                    self.ssh_platform,
                    release_channel,
                    wanted_version.clone(),
                    cx,
                )
                .await?
        {
            match self
                .download_binary_on_server(&url, &tmp_path_gz, delegate, cx)
                .await
            {
                Ok(_) => {
                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
                        .await
                        .context("extracting server binary")?;
                    return Ok(dst_path);
                }
                Err(e) => {
                    // A failed remote download is not fatal; fall through to
                    // the local download + upload path.
                    log::error!(
                        "Failed to download binary on server, attempting to download locally and then upload it the server: {e:#}",
                    )
                }
            }
        }

        // Fallback: download locally, upload, then extract on the remote.
        let src_path = delegate
            .download_server_binary_locally(
                self.ssh_platform,
                release_channel,
                wanted_version.clone(),
                cx,
            )
            .await
            .context("downloading server binary locally")?;
        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
            .await
            .context("uploading server binary")?;
        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
            .await
            .context("extracting server binary")?;
        Ok(dst_path)
    }
 653
 654    async fn download_binary_on_server(
 655        &self,
 656        url: &str,
 657        tmp_path_gz: &RelPath,
 658        delegate: &Arc<dyn RemoteClientDelegate>,
 659        cx: &mut AsyncApp,
 660    ) -> Result<()> {
 661        if let Some(parent) = tmp_path_gz.parent() {
 662            self.socket
 663                .run_command(
 664                    self.ssh_shell_kind,
 665                    "mkdir",
 666                    &["-p", parent.display(self.path_style()).as_ref()],
 667                    true,
 668                )
 669                .await?;
 670        }
 671
 672        delegate.set_status(Some("Downloading remote development server on host"), cx);
 673
 674        let connection_timeout = self
 675            .socket
 676            .connection_options
 677            .connection_timeout
 678            .unwrap_or(10)
 679            .to_string();
 680
 681        match self
 682            .socket
 683            .run_command(
 684                self.ssh_shell_kind,
 685                "curl",
 686                &[
 687                    "-f",
 688                    "-L",
 689                    "--connect-timeout",
 690                    &connection_timeout,
 691                    url,
 692                    "-o",
 693                    &tmp_path_gz.display(self.path_style()),
 694                ],
 695                true,
 696            )
 697            .await
 698        {
 699            Ok(_) => {}
 700            Err(e) => {
 701                if self
 702                    .socket
 703                    .run_command(self.ssh_shell_kind, "which", &["curl"], true)
 704                    .await
 705                    .is_ok()
 706                {
 707                    return Err(e);
 708                }
 709
 710                log::info!("curl is not available, trying wget");
 711                match self
 712                    .socket
 713                    .run_command(
 714                        self.ssh_shell_kind,
 715                        "wget",
 716                        &[
 717                            "--connect-timeout",
 718                            &connection_timeout,
 719                            "--tries",
 720                            "1",
 721                            url,
 722                            "-O",
 723                            &tmp_path_gz.display(self.path_style()),
 724                        ],
 725                        true,
 726                    )
 727                    .await
 728                {
 729                    Ok(_) => {}
 730                    Err(e) => {
 731                        if self
 732                            .socket
 733                            .run_command(self.ssh_shell_kind, "which", &["wget"], true)
 734                            .await
 735                            .is_ok()
 736                        {
 737                            return Err(e);
 738                        } else {
 739                            anyhow::bail!("Neither curl nor wget is available");
 740                        }
 741                    }
 742                }
 743            }
 744        }
 745
 746        Ok(())
 747    }
 748
 749    async fn upload_local_server_binary(
 750        &self,
 751        src_path: &Path,
 752        tmp_path_gz: &RelPath,
 753        delegate: &Arc<dyn RemoteClientDelegate>,
 754        cx: &mut AsyncApp,
 755    ) -> Result<()> {
 756        if let Some(parent) = tmp_path_gz.parent() {
 757            self.socket
 758                .run_command(
 759                    self.ssh_shell_kind,
 760                    "mkdir",
 761                    &["-p", parent.display(self.path_style()).as_ref()],
 762                    true,
 763                )
 764                .await?;
 765        }
 766
 767        let src_stat = fs::metadata(&src_path).await?;
 768        let size = src_stat.len();
 769
 770        let t0 = Instant::now();
 771        delegate.set_status(Some("Uploading remote development server"), cx);
 772        log::info!(
 773            "uploading remote development server to {:?} ({}kb)",
 774            tmp_path_gz,
 775            size / 1024
 776        );
 777        self.upload_file(src_path, tmp_path_gz)
 778            .await
 779            .context("failed to upload server binary")?;
 780        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 781        Ok(())
 782    }
 783
 784    async fn extract_server_binary(
 785        &self,
 786        dst_path: &RelPath,
 787        tmp_path: &RelPath,
 788        delegate: &Arc<dyn RemoteClientDelegate>,
 789        cx: &mut AsyncApp,
 790    ) -> Result<()> {
 791        delegate.set_status(Some("Extracting remote development server"), cx);
 792        let server_mode = 0o755;
 793
 794        let shell_kind = ShellKind::Posix;
 795        let orig_tmp_path = tmp_path.display(self.path_style());
 796        let server_mode = format!("{:o}", server_mode);
 797        let server_mode = shell_kind
 798            .try_quote(&server_mode)
 799            .context("shell quoting")?;
 800        let dst_path = dst_path.display(self.path_style());
 801        let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
 802        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 803            let orig_tmp_path = shell_kind
 804                .try_quote(&orig_tmp_path)
 805                .context("shell quoting")?;
 806            let tmp_path = shell_kind.try_quote(&tmp_path).context("shell quoting")?;
 807            format!(
 808                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 809            )
 810        } else {
 811            let orig_tmp_path = shell_kind
 812                .try_quote(&orig_tmp_path)
 813                .context("shell quoting")?;
 814            format!("chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",)
 815        };
 816        let args = shell_kind.args_for_shell(false, script.to_string());
 817        self.socket
 818            .run_command(shell_kind, "sh", &args, true)
 819            .await?;
 820        Ok(())
 821    }
 822
 823    fn build_scp_command(
 824        &self,
 825        src_path: &Path,
 826        dest_path_str: &str,
 827        args: Option<&[&str]>,
 828    ) -> process::Command {
 829        let mut command = util::command::new_smol_command("scp");
 830        self.socket.ssh_options(&mut command, false).args(
 831            self.socket
 832                .connection_options
 833                .port
 834                .map(|port| vec!["-P".to_string(), port.to_string()])
 835                .unwrap_or_default(),
 836        );
 837        if let Some(args) = args {
 838            command.args(args);
 839        }
 840        command.arg(src_path).arg(format!(
 841            "{}:{}",
 842            self.socket.connection_options.scp_url(),
 843            dest_path_str
 844        ));
 845        command
 846    }
 847
 848    fn build_sftp_command(&self) -> process::Command {
 849        let mut command = util::command::new_smol_command("sftp");
 850        self.socket.ssh_options(&mut command, false).args(
 851            self.socket
 852                .connection_options
 853                .port
 854                .map(|port| vec!["-P".to_string(), port.to_string()])
 855                .unwrap_or_default(),
 856        );
 857        command.arg("-b").arg("-");
 858        command.arg(self.socket.connection_options.scp_url());
 859        command.stdin(Stdio::piped());
 860        command
 861    }
 862
 863    async fn upload_file(&self, src_path: &Path, dest_path: &RelPath) -> Result<()> {
 864        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
 865
 866        let src_path_display = src_path.display().to_string();
 867        let dest_path_str = dest_path.display(self.path_style());
 868
 869        // We will try SFTP first, and if that fails, we will fall back to SCP.
 870        // If SCP fails also, we give up and return an error.
 871        // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 872        // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 873        // This is for example the case on Windows as evidenced by this implementation snippet:
 874        // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 875        if Self::is_sftp_available().await {
 876            log::debug!("using SFTP for file upload");
 877            let mut command = self.build_sftp_command();
 878            let sftp_batch = format!("put {src_path_display} {dest_path_str}\n");
 879
 880            let mut child = command.spawn()?;
 881            if let Some(mut stdin) = child.stdin.take() {
 882                use futures::AsyncWriteExt;
 883                stdin.write_all(sftp_batch.as_bytes()).await?;
 884                stdin.flush().await?;
 885            }
 886
 887            let output = child.output().await?;
 888            if output.status.success() {
 889                return Ok(());
 890            }
 891
 892            let stderr = String::from_utf8_lossy(&output.stderr);
 893            log::debug!(
 894                "failed to upload file via SFTP {src_path_display} -> {dest_path_str}: {stderr}"
 895            );
 896        }
 897
 898        log::debug!("using SCP for file upload");
 899        let mut command = self.build_scp_command(src_path, &dest_path_str, None);
 900        let output = command.output().await?;
 901
 902        if output.status.success() {
 903            return Ok(());
 904        }
 905
 906        let stderr = String::from_utf8_lossy(&output.stderr);
 907        log::debug!(
 908            "failed to upload file via SCP {src_path_display} -> {dest_path_str}: {stderr}",
 909        );
 910        anyhow::bail!(
 911            "failed to upload file via STFP/SCP {} -> {}: {}",
 912            src_path_display,
 913            dest_path_str,
 914            stderr,
 915        );
 916    }
 917
 918    async fn is_sftp_available() -> bool {
 919        which::which("sftp").is_ok()
 920    }
 921}
 922
 923impl SshSocket {
 924    #[cfg(not(target_os = "windows"))]
 925    async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 926        Ok(Self {
 927            connection_options: options,
 928            envs: HashMap::default(),
 929            socket_path,
 930        })
 931    }
 932
 933    #[cfg(target_os = "windows")]
 934    async fn new(
 935        options: SshConnectionOptions,
 936        password: askpass::EncryptedPassword,
 937        executor: gpui::BackgroundExecutor,
 938    ) -> Result<Self> {
 939        let mut envs = HashMap::default();
 940        let get_password =
 941            move |_| Task::ready(std::ops::ControlFlow::Continue(Ok(password.clone())));
 942
 943        let _proxy = askpass::PasswordProxy::new(get_password, executor).await?;
 944        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 945        envs.insert(
 946            "SSH_ASKPASS".into(),
 947            _proxy.script_path().as_ref().display().to_string(),
 948        );
 949
 950        Ok(Self {
 951            connection_options: options,
 952            envs,
 953            _proxy,
 954        })
 955    }
 956
 957    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 958    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 959    // and passes -l as an argument to sh, not to ls.
 960    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 961    // into a machine. You must use `cd` to get back to $HOME.
 962    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 963    fn ssh_command(
 964        &self,
 965        shell_kind: ShellKind,
 966        program: &str,
 967        args: &[impl AsRef<str>],
 968        allow_pseudo_tty: bool,
 969    ) -> process::Command {
 970        let mut command = util::command::new_smol_command("ssh");
 971        let program = shell_kind.prepend_command_prefix(program);
 972        let mut to_run = shell_kind
 973            .try_quote_prefix_aware(&program)
 974            .expect("shell quoting")
 975            .into_owned();
 976        for arg in args {
 977            // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 978            debug_assert!(
 979                !arg.as_ref().contains('\n'),
 980                "multiline arguments do not work in all shells"
 981            );
 982            to_run.push(' ');
 983            to_run.push_str(&shell_kind.try_quote(arg.as_ref()).expect("shell quoting"));
 984        }
 985        let separator = shell_kind.sequential_commands_separator();
 986        let to_run = format!("cd{separator} {to_run}");
 987        self.ssh_options(&mut command, true)
 988            .arg(self.connection_options.ssh_url());
 989        if !allow_pseudo_tty {
 990            command.arg("-T");
 991        }
 992        command.arg(to_run);
 993        log::debug!("ssh {:?}", command);
 994        command
 995    }
 996
 997    async fn run_command(
 998        &self,
 999        shell_kind: ShellKind,
1000        program: &str,
1001        args: &[impl AsRef<str>],
1002        allow_pseudo_tty: bool,
1003    ) -> Result<String> {
1004        let mut command = self.ssh_command(shell_kind, program, args, allow_pseudo_tty);
1005        let output = command.output().await?;
1006        anyhow::ensure!(
1007            output.status.success(),
1008            "failed to run command {command:?}: {}",
1009            String::from_utf8_lossy(&output.stderr)
1010        );
1011        Ok(String::from_utf8_lossy(&output.stdout).to_string())
1012    }
1013
1014    #[cfg(not(target_os = "windows"))]
1015    fn ssh_options<'a>(
1016        &self,
1017        command: &'a mut process::Command,
1018        include_port_forwards: bool,
1019    ) -> &'a mut process::Command {
1020        let args = if include_port_forwards {
1021            self.connection_options.additional_args()
1022        } else {
1023            self.connection_options.additional_args_for_scp()
1024        };
1025
1026        command
1027            .stdin(Stdio::piped())
1028            .stdout(Stdio::piped())
1029            .stderr(Stdio::piped())
1030            .args(args)
1031            .args(["-o", "ControlMaster=no", "-o"])
1032            .arg(format!("ControlPath={}", self.socket_path.display()))
1033    }
1034
1035    #[cfg(target_os = "windows")]
1036    fn ssh_options<'a>(
1037        &self,
1038        command: &'a mut process::Command,
1039        include_port_forwards: bool,
1040    ) -> &'a mut process::Command {
1041        let args = if include_port_forwards {
1042            self.connection_options.additional_args()
1043        } else {
1044            self.connection_options.additional_args_for_scp()
1045        };
1046
1047        command
1048            .stdin(Stdio::piped())
1049            .stdout(Stdio::piped())
1050            .stderr(Stdio::piped())
1051            .args(args)
1052            .envs(self.envs.clone())
1053    }
1054
    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    // On other platforms, we pass the `ControlPath` option pointing at a socket file so that
    // ssh can reuse the already-established master connection.
1057    #[cfg(not(target_os = "windows"))]
1058    fn ssh_args(&self) -> Vec<String> {
1059        let mut arguments = self.connection_options.additional_args();
1060        arguments.extend(vec![
1061            "-o".to_string(),
1062            "ControlMaster=no".to_string(),
1063            "-o".to_string(),
1064            format!("ControlPath={}", self.socket_path.display()),
1065            self.connection_options.ssh_url(),
1066        ]);
1067        arguments
1068    }
1069
1070    #[cfg(target_os = "windows")]
1071    fn ssh_args(&self) -> Vec<String> {
1072        let mut arguments = self.connection_options.additional_args();
1073        arguments.push(self.connection_options.ssh_url());
1074        arguments
1075    }
1076
1077    async fn platform(&self, shell: ShellKind) -> Result<RemotePlatform> {
1078        let output = self.run_command(shell, "uname", &["-sm"], false).await?;
1079        parse_platform(&output)
1080    }
1081
1082    async fn shell(&self) -> String {
1083        const DEFAULT_SHELL: &str = "sh";
1084        match self
1085            .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
1086            .await
1087        {
1088            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
1089            Err(e) => {
1090                log::error!("Failed to detect remote shell: {e}");
1091                DEFAULT_SHELL.to_owned()
1092            }
1093        }
1094    }
1095}
1096
1097fn parse_port_number(port_str: &str) -> Result<u16> {
1098    port_str
1099        .parse()
1100        .with_context(|| format!("parsing port number: {port_str}"))
1101}
1102
1103fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1104    let parts: Vec<&str> = spec.split(':').collect();
1105
1106    match parts.len() {
1107        4 => {
1108            let local_port = parse_port_number(parts[1])?;
1109            let remote_port = parse_port_number(parts[3])?;
1110
1111            Ok(SshPortForwardOption {
1112                local_host: Some(parts[0].to_string()),
1113                local_port,
1114                remote_host: Some(parts[2].to_string()),
1115                remote_port,
1116            })
1117        }
1118        3 => {
1119            let local_port = parse_port_number(parts[0])?;
1120            let remote_port = parse_port_number(parts[2])?;
1121
1122            Ok(SshPortForwardOption {
1123                local_host: None,
1124                local_port,
1125                remote_host: Some(parts[1].to_string()),
1126                remote_port,
1127            })
1128        }
1129        _ => anyhow::bail!("Invalid port forward format"),
1130    }
1131}
1132
1133impl SshConnectionOptions {
1134    pub fn parse_command_line(input: &str) -> Result<Self> {
1135        let input = input.trim_start_matches("ssh ");
1136        let mut hostname: Option<String> = None;
1137        let mut username: Option<String> = None;
1138        let mut port: Option<u16> = None;
1139        let mut args = Vec::new();
1140        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1141
1142        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1143        const ALLOWED_OPTS: &[&str] = &[
1144            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1145        ];
1146        const ALLOWED_ARGS: &[&str] = &[
1147            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1148            "-w",
1149        ];
1150
1151        let mut tokens = ShellKind::Posix
1152            .split(input)
1153            .context("invalid input")?
1154            .into_iter();
1155
1156        'outer: while let Some(arg) = tokens.next() {
1157            if ALLOWED_OPTS.contains(&(&arg as &str)) {
1158                args.push(arg.to_string());
1159                continue;
1160            }
1161            if arg == "-p" {
1162                port = tokens.next().and_then(|arg| arg.parse().ok());
1163                continue;
1164            } else if let Some(p) = arg.strip_prefix("-p") {
1165                port = p.parse().ok();
1166                continue;
1167            }
1168            if arg == "-l" {
1169                username = tokens.next();
1170                continue;
1171            } else if let Some(l) = arg.strip_prefix("-l") {
1172                username = Some(l.to_string());
1173                continue;
1174            }
1175            if arg == "-L" || arg.starts_with("-L") {
1176                let forward_spec = if arg == "-L" {
1177                    tokens.next()
1178                } else {
1179                    Some(arg.strip_prefix("-L").unwrap().to_string())
1180                };
1181
1182                if let Some(spec) = forward_spec {
1183                    port_forwards.push(parse_port_forward_spec(&spec)?);
1184                } else {
1185                    anyhow::bail!("Missing port forward format");
1186                }
1187            }
1188
1189            for a in ALLOWED_ARGS {
1190                if arg == *a {
1191                    args.push(arg);
1192                    if let Some(next) = tokens.next() {
1193                        args.push(next);
1194                    }
1195                    continue 'outer;
1196                } else if arg.starts_with(a) {
1197                    args.push(arg);
1198                    continue 'outer;
1199                }
1200            }
1201            if arg.starts_with("-") || hostname.is_some() {
1202                anyhow::bail!("unsupported argument: {:?}", arg);
1203            }
1204            let mut input = &arg as &str;
1205            // Destination might be: username1@username2@ip2@ip1
1206            if let Some((u, rest)) = input.rsplit_once('@') {
1207                input = rest;
1208                username = Some(u.to_string());
1209            }
1210            if let Some((rest, p)) = input.split_once(':') {
1211                input = rest;
1212                port = p.parse().ok()
1213            }
1214            hostname = Some(input.to_string())
1215        }
1216
1217        let Some(hostname) = hostname else {
1218            anyhow::bail!("missing hostname");
1219        };
1220
1221        let port_forwards = match port_forwards.len() {
1222            0 => None,
1223            _ => Some(port_forwards),
1224        };
1225
1226        Ok(Self {
1227            host: hostname,
1228            username,
1229            port,
1230            port_forwards,
1231            args: Some(args),
1232            password: None,
1233            nickname: None,
1234            upload_binary_over_ssh: false,
1235            connection_timeout: None,
1236        })
1237    }
1238
1239    pub fn ssh_url(&self) -> String {
1240        let mut result = String::from("ssh://");
1241        if let Some(username) = &self.username {
1242            // Username might be: username1@username2@ip2
1243            let username = urlencoding::encode(username);
1244            result.push_str(&username);
1245            result.push('@');
1246        }
1247        result.push_str(&self.host);
1248        if let Some(port) = self.port {
1249            result.push(':');
1250            result.push_str(&port.to_string());
1251        }
1252        result
1253    }
1254
1255    pub fn additional_args_for_scp(&self) -> Vec<String> {
1256        self.args.iter().flatten().cloned().collect::<Vec<String>>()
1257    }
1258
1259    pub fn additional_args(&self) -> Vec<String> {
1260        let mut args = self.additional_args_for_scp();
1261
1262        if let Some(timeout) = self.connection_timeout {
1263            args.extend(["-o".to_string(), format!("ConnectTimeout={}", timeout)]);
1264        }
1265
1266        if let Some(forwards) = &self.port_forwards {
1267            args.extend(forwards.iter().map(|pf| {
1268                let local_host = match &pf.local_host {
1269                    Some(host) => host,
1270                    None => "localhost",
1271                };
1272                let remote_host = match &pf.remote_host {
1273                    Some(host) => host,
1274                    None => "localhost",
1275                };
1276
1277                format!(
1278                    "-L{}:{}:{}:{}",
1279                    local_host, pf.local_port, remote_host, pf.remote_port
1280                )
1281            }));
1282        }
1283
1284        args
1285    }
1286
1287    fn scp_url(&self) -> String {
1288        if let Some(username) = &self.username {
1289            format!("{}@{}", username, self.host)
1290        } else {
1291            self.host.clone()
1292        }
1293    }
1294
1295    pub fn connection_string(&self) -> String {
1296        let host = if let Some(username) = &self.username {
1297            format!("{}@{}", username, self.host)
1298        } else {
1299            self.host.clone()
1300        };
1301        if let Some(port) = &self.port {
1302            format!("{}:{}", host, port)
1303        } else {
1304            host
1305        }
1306    }
1307}
1308
1309fn build_command(
1310    input_program: Option<String>,
1311    input_args: &[String],
1312    input_env: &HashMap<String, String>,
1313    working_dir: Option<String>,
1314    port_forward: Option<(u16, String, u16)>,
1315    ssh_env: HashMap<String, String>,
1316    ssh_path_style: PathStyle,
1317    ssh_shell: &str,
1318    ssh_shell_kind: ShellKind,
1319    ssh_args: Vec<String>,
1320) -> Result<CommandTemplate> {
1321    use std::fmt::Write as _;
1322
1323    let mut exec = String::new();
1324    if let Some(working_dir) = working_dir {
1325        let working_dir = RemotePathBuf::new(working_dir, ssh_path_style).to_string();
1326
1327        // shlex will wrap the command in single quotes (''), disabling ~ expansion,
1328        // replace with something that works
1329        const TILDE_PREFIX: &'static str = "~/";
1330        if working_dir.starts_with(TILDE_PREFIX) {
1331            let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
1332            write!(
1333                exec,
1334                "cd \"$HOME/{working_dir}\" {} ",
1335                ssh_shell_kind.sequential_and_commands_separator()
1336            )?;
1337        } else {
1338            write!(
1339                exec,
1340                "cd \"{working_dir}\" {} ",
1341                ssh_shell_kind.sequential_and_commands_separator()
1342            )?;
1343        }
1344    } else {
1345        write!(
1346            exec,
1347            "cd {} ",
1348            ssh_shell_kind.sequential_and_commands_separator()
1349        )?;
1350    };
1351    write!(exec, "exec env ")?;
1352
1353    for (k, v) in input_env.iter() {
1354        write!(
1355            exec,
1356            "{}={} ",
1357            k,
1358            ssh_shell_kind.try_quote(v).context("shell quoting")?
1359        )?;
1360    }
1361
1362    if let Some(input_program) = input_program {
1363        write!(
1364            exec,
1365            "{}",
1366            ssh_shell_kind
1367                .try_quote_prefix_aware(&input_program)
1368                .context("shell quoting")?
1369        )?;
1370        for arg in input_args {
1371            let arg = ssh_shell_kind.try_quote(&arg).context("shell quoting")?;
1372            write!(exec, " {}", &arg)?;
1373        }
1374    } else {
1375        write!(exec, "{ssh_shell} -l")?;
1376    };
1377
1378    let mut args = Vec::new();
1379    args.extend(ssh_args);
1380
1381    if let Some((local_port, host, remote_port)) = port_forward {
1382        args.push("-L".into());
1383        args.push(format!("{local_port}:{host}:{remote_port}"));
1384    }
1385
1386    args.push("-t".into());
1387    args.push(exec);
1388
1389    Ok(CommandTemplate {
1390        program: "ssh".into(),
1391        args,
1392        env: ssh_env,
1393    })
1394}
1395
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment map holding exactly one `key=value` entry.
    fn single_entry_env(key: &str, value: &str) -> HashMap<String, String> {
        let mut env = HashMap::default();
        env.insert(key.to_string(), value.to_string());
        env
    }

    /// Collects a template's arguments as string slices for comparison.
    fn arg_strs(command: &CommandTemplate) -> Vec<&str> {
        command.args.iter().map(String::as_str).collect()
    }

    /// Converts string literals into owned strings.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_build_command() -> Result<()> {
        let input_env = single_entry_env("INPUT_VA", "val");
        let ssh_env = single_entry_env("SSH_VAR", "ssh-val");

        // An explicit program with a `~/`-relative working directory.
        let with_program = build_command(
            Some("remote_program".to_string()),
            &owned(&["arg1", "arg2"]),
            &input_env,
            Some("~/work".to_string()),
            None,
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            owned(&["-p", "2222"]),
        )?;

        assert_eq!(with_program.program, "ssh");
        assert_eq!(
            arg_strs(&with_program),
            [
                "-p",
                "2222",
                "-t",
                "cd \"$HOME/work\" && exec env INPUT_VA=val remote_program arg1 arg2"
            ]
        );
        assert_eq!(with_program.env, ssh_env);

        // No program: a login shell is started and a port forward is added.
        let login_shell = build_command(
            None,
            &owned(&["arg1", "arg2"]),
            &input_env,
            None,
            Some((1, "foo".to_owned(), 2)),
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            owned(&["-p", "2222"]),
        )?;

        assert_eq!(login_shell.program, "ssh");
        assert_eq!(
            arg_strs(&login_shell),
            [
                "-p",
                "2222",
                "-L",
                "1:foo:2",
                "-t",
                "cd && exec env INPUT_VA=val /bin/fish -l"
            ]
        );
        assert_eq!(login_shell.env, ssh_env);

        Ok(())
    }

    #[test]
    fn scp_args_exclude_port_forward_flags() {
        let user_args = owned(&["-p", "2222", "-o", "StrictHostKeyChecking=no"]);
        let forward = SshPortForwardOption {
            local_host: Some("127.0.0.1".to_string()),
            local_port: 8080,
            remote_host: Some("127.0.0.1".to_string()),
            remote_port: 80,
        };
        let options = SshConnectionOptions {
            host: "example.com".into(),
            args: Some(user_args.clone()),
            port_forwards: Some(vec![forward]),
            ..Default::default()
        };

        let ssh_args = options.additional_args();
        assert!(
            ssh_args.iter().any(|arg| arg.starts_with("-L")),
            "expected ssh args to include port-forward: {ssh_args:?}"
        );

        let scp_args = options.additional_args_for_scp();
        assert_eq!(scp_args, user_args);
    }
}