ssh.rs

   1use crate::{
   2    RemoteClientDelegate, RemotePlatform,
   3    remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
   4    transport::{parse_platform, parse_shell},
   5};
   6use anyhow::{Context as _, Result, anyhow};
   7use async_trait::async_trait;
   8use collections::HashMap;
   9use futures::{
  10    AsyncReadExt as _, FutureExt as _,
  11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  12    select_biased,
  13};
  14use gpui::{App, AppContext as _, AsyncApp, Task};
  15use parking_lot::Mutex;
  16use paths::remote_server_dir_relative;
  17use release_channel::{AppVersion, ReleaseChannel};
  18use rpc::proto::Envelope;
  19use semver::Version;
  20pub use settings::SshPortForwardOption;
  21use smol::{
  22    fs,
  23    process::{self, Child, Stdio},
  24};
  25use std::{
  26    net::IpAddr,
  27    path::{Path, PathBuf},
  28    sync::Arc,
  29    time::Instant,
  30};
  31use tempfile::TempDir;
  32use util::{
  33    paths::{PathStyle, RemotePathBuf},
  34    rel_path::RelPath,
  35    shell::{Shell, ShellKind},
  36    shell_builder::ShellBuilder,
  37};
  38
/// A [`RemoteConnection`] to a host reached over SSH.
///
/// Fields are dropped in declaration order, so `_temp_dir` (which, on non-Windows,
/// contains the master's control socket file) is removed only after `socket` and
/// `master_process` have been dropped.
pub(crate) struct SshRemoteConnection {
    /// Connection options plus the state needed to run further ssh commands.
    socket: SshSocket,
    /// The master ssh process; becomes `None` once `kill` has taken it.
    master_process: Mutex<Option<MasterProcess>>,
    /// Path of the remote server binary; set by `new` after `ensure_server_binary`
    /// succeeds. `start_proxy` fails while this is `None`.
    remote_binary_path: Option<Arc<RelPath>>,
    /// Remote platform detected while connecting.
    ssh_platform: RemotePlatform,
    /// `PathStyle::Windows` when the detected remote OS is "windows", else `Posix`.
    ssh_path_style: PathStyle,
    /// The remote shell as reported by the host.
    ssh_shell: String,
    /// Kind of `ssh_shell`, used when building remote commands.
    ssh_shell_kind: ShellKind,
    /// Fallback shell; always "/bin/sh" for SSH connections.
    ssh_default_system_shell: String,
    /// Session temp dir, held only so it lives exactly as long as the connection.
    _temp_dir: TempDir,
}
  50
/// The host part of an SSH destination: either a literal IP address or a hostname.
///
/// The string conversions try to parse an [`IpAddr`] first and fall back to
/// [`SshConnectionHost::Hostname`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SshConnectionHost {
    /// A literal IPv4 or IPv6 address.
    IpAddr(IpAddr),
    /// Any host string that did not parse as an IP address.
    Hostname(String),
}
  56
  57impl SshConnectionHost {
  58    pub fn to_bracketed_string(&self) -> String {
  59        match self {
  60            Self::IpAddr(IpAddr::V4(ip)) => ip.to_string(),
  61            Self::IpAddr(IpAddr::V6(ip)) => format!("[{}]", ip),
  62            Self::Hostname(hostname) => hostname.clone(),
  63        }
  64    }
  65
  66    pub fn to_string(&self) -> String {
  67        match self {
  68            Self::IpAddr(ip) => ip.to_string(),
  69            Self::Hostname(hostname) => hostname.clone(),
  70        }
  71    }
  72}
  73
  74impl From<&str> for SshConnectionHost {
  75    fn from(value: &str) -> Self {
  76        if let Ok(address) = value.parse() {
  77            Self::IpAddr(address)
  78        } else {
  79            Self::Hostname(value.to_string())
  80        }
  81    }
  82}
  83
  84impl From<String> for SshConnectionHost {
  85    fn from(value: String) -> Self {
  86        if let Ok(address) = value.parse() {
  87            Self::IpAddr(address)
  88        } else {
  89            Self::Hostname(value)
  90        }
  91    }
  92}
  93
  94impl Default for SshConnectionHost {
  95    fn default() -> Self {
  96        Self::Hostname(Default::default())
  97    }
  98}
  99
/// Everything needed to open an SSH connection to a remote host.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host to connect to.
    pub host: SshConnectionHost,
    /// Remote user name, if specified.
    pub username: Option<String>,
    /// Remote port, if specified; also passed to scp/sftp as `-P <port>`.
    pub port: Option<u16>,
    /// Password for the connection, if known. Never populated from settings.
    pub password: Option<String>,
    /// Additional command-line arguments (presumably surfaced to ssh via
    /// `additional_args()` — confirm).
    pub args: Option<Vec<String>>,
    /// Port forwards configured for this connection.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,
    /// Used as curl/wget `--connect-timeout` when the server binary is downloaded
    /// directly on the host (10 when unset there).
    pub connection_timeout: Option<u16>,

    /// Optional nickname for the connection (presumably for display).
    pub nickname: Option<String>,
    /// When `true`, never download the server binary on the host; always download it
    /// locally and upload it over ssh.
    pub upload_binary_over_ssh: bool,
}
 113
 114impl From<settings::SshConnection> for SshConnectionOptions {
 115    fn from(val: settings::SshConnection) -> Self {
 116        SshConnectionOptions {
 117            host: val.host.to_string().into(),
 118            username: val.username,
 119            port: val.port,
 120            password: None,
 121            args: Some(val.args),
 122            nickname: val.nickname,
 123            upload_binary_over_ssh: val.upload_binary_over_ssh.unwrap_or_default(),
 124            port_forwards: val.port_forwards,
 125            connection_timeout: val.connection_timeout,
 126        }
 127    }
 128}
 129
/// Connection options plus the per-platform state needed to run additional ssh
/// commands against an established connection.
struct SshSocket {
    /// The options the connection was opened with.
    connection_options: SshConnectionOptions,
    /// Control socket of the master process; `SshRemoteConnection::new` passes
    /// `ssh.sock` inside the session temp dir.
    #[cfg(not(target_os = "windows"))]
    socket_path: std::path::PathBuf,
    /// Environment variables handed to commands built via `build_command`.
    envs: HashMap<String, String>,
    /// Kept alive for the socket's lifetime; presumably serves the askpass password
    /// to ssh child processes on Windows, where there is no control master — confirm.
    #[cfg(target_os = "windows")]
    _proxy: askpass::PasswordProxy,
}
 138
/// The long-lived ssh process that establishes the connection and keeps it open.
struct MasterProcess {
    process: Child,
}
 142
 143#[cfg(not(target_os = "windows"))]
 144impl MasterProcess {
 145    pub fn new(
 146        askpass_script_path: &std::ffi::OsStr,
 147        additional_args: Vec<String>,
 148        socket_path: &std::path::Path,
 149        destination: &str,
 150    ) -> Result<Self> {
 151        let args = [
 152            "-N",
 153            "-o",
 154            "ControlPersist=no",
 155            "-o",
 156            "ControlMaster=yes",
 157            "-o",
 158        ];
 159
 160        let mut master_process = util::command::new_smol_command("ssh");
 161        master_process
 162            .kill_on_drop(true)
 163            .stdin(Stdio::null())
 164            .stdout(Stdio::piped())
 165            .stderr(Stdio::piped())
 166            .env("SSH_ASKPASS_REQUIRE", "force")
 167            .env("SSH_ASKPASS", askpass_script_path)
 168            .args(additional_args)
 169            .args(args);
 170
 171        master_process.arg(format!("ControlPath={}", socket_path.display()));
 172
 173        let process = master_process.arg(&destination).spawn()?;
 174
 175        Ok(MasterProcess { process })
 176    }
 177
 178    pub async fn wait_connected(&mut self) -> Result<()> {
 179        let Some(mut stdout) = self.process.stdout.take() else {
 180            anyhow::bail!("ssh process stdout capture failed");
 181        };
 182
 183        let mut output = Vec::new();
 184        stdout.read_to_end(&mut output).await?;
 185        Ok(())
 186    }
 187}
 188
 189#[cfg(target_os = "windows")]
 190impl MasterProcess {
 191    const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
 192
 193    pub fn new(
 194        askpass_script_path: &std::ffi::OsStr,
 195        additional_args: Vec<String>,
 196        destination: &str,
 197    ) -> Result<Self> {
 198        // On Windows, `ControlMaster` and `ControlPath` are not supported:
 199        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
 200        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
 201        //
 202        // Using an ugly workaround to detect connection establishment
 203        // -N doesn't work with JumpHosts as windows openssh never closes stdin in that case
 204        let args = [
 205            "-t",
 206            &format!("echo '{}'; exec $0", Self::CONNECTION_ESTABLISHED_MAGIC),
 207        ];
 208
 209        let mut master_process = util::command::new_smol_command("ssh");
 210        master_process
 211            .kill_on_drop(true)
 212            .stdin(Stdio::null())
 213            .stdout(Stdio::piped())
 214            .stderr(Stdio::piped())
 215            .env("SSH_ASKPASS_REQUIRE", "force")
 216            .env("SSH_ASKPASS", askpass_script_path)
 217            .args(additional_args)
 218            .arg(destination)
 219            .args(args);
 220
 221        let process = master_process.spawn()?;
 222
 223        Ok(MasterProcess { process })
 224    }
 225
 226    pub async fn wait_connected(&mut self) -> Result<()> {
 227        use smol::io::AsyncBufReadExt;
 228
 229        let Some(stdout) = self.process.stdout.take() else {
 230            anyhow::bail!("ssh process stdout capture failed");
 231        };
 232
 233        let mut reader = smol::io::BufReader::new(stdout);
 234
 235        let mut line = String::new();
 236
 237        loop {
 238            let n = reader.read_line(&mut line).await?;
 239            if n == 0 {
 240                anyhow::bail!("ssh process exited before connection established");
 241            }
 242
 243            if line.contains(Self::CONNECTION_ESTABLISHED_MAGIC) {
 244                return Ok(());
 245            }
 246        }
 247    }
 248}
 249
 250impl AsRef<Child> for MasterProcess {
 251    fn as_ref(&self) -> &Child {
 252        &self.process
 253    }
 254}
 255
 256impl AsMut<Child> for MasterProcess {
 257    fn as_mut(&mut self) -> &mut Child {
 258        &mut self.process
 259    }
 260}
 261
 262#[async_trait(?Send)]
 263impl RemoteConnection for SshRemoteConnection {
 264    async fn kill(&self) -> Result<()> {
 265        let Some(mut process) = self.master_process.lock().take() else {
 266            return Ok(());
 267        };
 268        process.as_mut().kill().ok();
 269        process.as_mut().status().await?;
 270        Ok(())
 271    }
 272
 273    fn has_been_killed(&self) -> bool {
 274        self.master_process.lock().is_none()
 275    }
 276
 277    fn connection_options(&self) -> RemoteConnectionOptions {
 278        RemoteConnectionOptions::Ssh(self.socket.connection_options.clone())
 279    }
 280
 281    fn shell(&self) -> String {
 282        self.ssh_shell.clone()
 283    }
 284
 285    fn default_system_shell(&self) -> String {
 286        self.ssh_default_system_shell.clone()
 287    }
 288
 289    fn build_command(
 290        &self,
 291        input_program: Option<String>,
 292        input_args: &[String],
 293        input_env: &HashMap<String, String>,
 294        working_dir: Option<String>,
 295        port_forward: Option<(u16, String, u16)>,
 296    ) -> Result<CommandTemplate> {
 297        let Self {
 298            ssh_path_style,
 299            socket,
 300            ssh_shell_kind,
 301            ssh_shell,
 302            ..
 303        } = self;
 304        let env = socket.envs.clone();
 305        build_command(
 306            input_program,
 307            input_args,
 308            input_env,
 309            working_dir,
 310            port_forward,
 311            env,
 312            *ssh_path_style,
 313            ssh_shell,
 314            *ssh_shell_kind,
 315            socket.ssh_args(),
 316        )
 317    }
 318
 319    fn build_forward_ports_command(
 320        &self,
 321        forwards: Vec<(u16, String, u16)>,
 322    ) -> Result<CommandTemplate> {
 323        let Self { socket, .. } = self;
 324        let mut args = socket.ssh_args();
 325        args.push("-N".into());
 326        for (local_port, host, remote_port) in forwards {
 327            args.push("-L".into());
 328            args.push(format!("{local_port}:{host}:{remote_port}"));
 329        }
 330        Ok(CommandTemplate {
 331            program: "ssh".into(),
 332            args,
 333            env: Default::default(),
 334        })
 335    }
 336
 337    fn upload_directory(
 338        &self,
 339        src_path: PathBuf,
 340        dest_path: RemotePathBuf,
 341        cx: &App,
 342    ) -> Task<Result<()>> {
 343        let dest_path_str = dest_path.to_string();
 344        let src_path_display = src_path.display().to_string();
 345
 346        let mut sftp_command = self.build_sftp_command();
 347        let mut scp_command =
 348            self.build_scp_command(&src_path, &dest_path_str, Some(&["-C", "-r"]));
 349
 350        cx.background_spawn(async move {
 351            // We will try SFTP first, and if that fails, we will fall back to SCP.
 352            // If SCP fails also, we give up and return an error.
 353            // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 354            // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 355            // This is for example the case on Windows as evidenced by this implementation snippet:
 356            // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 357            if Self::is_sftp_available().await {
 358                log::debug!("using SFTP for directory upload");
 359                let mut child = sftp_command.spawn()?;
 360                if let Some(mut stdin) = child.stdin.take() {
 361                    use futures::AsyncWriteExt;
 362                    let sftp_batch = format!("put -r \"{src_path_display}\" \"{dest_path_str}\"\n");
 363                    stdin.write_all(sftp_batch.as_bytes()).await?;
 364                    stdin.flush().await?;
 365                }
 366
 367                let output = child.output().await?;
 368                if output.status.success() {
 369                    return Ok(());
 370                }
 371
 372                let stderr = String::from_utf8_lossy(&output.stderr);
 373                log::debug!("failed to upload directory via SFTP {src_path_display} -> {dest_path_str}: {stderr}");
 374            }
 375
 376            log::debug!("using SCP for directory upload");
 377            let output = scp_command.output().await?;
 378
 379            if output.status.success() {
 380                return Ok(());
 381            }
 382
 383            let stderr = String::from_utf8_lossy(&output.stderr);
 384            log::debug!("failed to upload directory via SCP {src_path_display} -> {dest_path_str}: {stderr}");
 385
 386            anyhow::bail!(
 387                "failed to upload directory via SFTP/SCP {} -> {}: {}",
 388                src_path_display,
 389                dest_path_str,
 390                stderr,
 391            );
 392        })
 393    }
 394
 395    fn start_proxy(
 396        &self,
 397        unique_identifier: String,
 398        reconnect: bool,
 399        incoming_tx: UnboundedSender<Envelope>,
 400        outgoing_rx: UnboundedReceiver<Envelope>,
 401        connection_activity_tx: Sender<()>,
 402        delegate: Arc<dyn RemoteClientDelegate>,
 403        cx: &mut AsyncApp,
 404    ) -> Task<Result<i32>> {
 405        delegate.set_status(Some("Starting proxy"), cx);
 406
 407        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 408            return Task::ready(Err(anyhow!("Remote binary path not set")));
 409        };
 410
 411        let mut proxy_args = vec![];
 412        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
 413            if let Some(value) = std::env::var(env_var).ok() {
 414                proxy_args.push(format!("{}='{}'", env_var, value));
 415            }
 416        }
 417        proxy_args.push(remote_binary_path.display(self.path_style()).into_owned());
 418        proxy_args.push("proxy".to_owned());
 419        proxy_args.push("--identifier".to_owned());
 420        proxy_args.push(unique_identifier);
 421
 422        if reconnect {
 423            proxy_args.push("--reconnect".to_owned());
 424        }
 425
 426        let ssh_proxy_process = match self
 427            .socket
 428            .ssh_command(self.ssh_shell_kind, "env", &proxy_args, false)
 429            // IMPORTANT: we kill this process when we drop the task that uses it.
 430            .kill_on_drop(true)
 431            .spawn()
 432        {
 433            Ok(process) => process,
 434            Err(error) => {
 435                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 436            }
 437        };
 438
 439        super::handle_rpc_messages_over_child_process_stdio(
 440            ssh_proxy_process,
 441            incoming_tx,
 442            outgoing_rx,
 443            connection_activity_tx,
 444            cx,
 445        )
 446    }
 447
 448    fn path_style(&self) -> PathStyle {
 449        self.ssh_path_style
 450    }
 451
 452    fn has_wsl_interop(&self) -> bool {
 453        false
 454    }
 455}
 456
 457impl SshRemoteConnection {
 458    pub(crate) async fn new(
 459        connection_options: SshConnectionOptions,
 460        delegate: Arc<dyn RemoteClientDelegate>,
 461        cx: &mut AsyncApp,
 462    ) -> Result<Self> {
 463        use askpass::AskPassResult;
 464
 465        let destination = connection_options.ssh_destination();
 466
 467        let temp_dir = tempfile::Builder::new()
 468            .prefix("zed-ssh-session")
 469            .tempdir()?;
 470        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
 471            let delegate = delegate.clone();
 472            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
 473        });
 474
 475        let mut askpass =
 476            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
 477
 478        delegate.set_status(Some("Connecting"), cx);
 479
 480        // Start the master SSH process, which does not do anything except for establish
 481        // the connection and keep it open, allowing other ssh commands to reuse it
 482        // via a control socket.
 483        #[cfg(not(target_os = "windows"))]
 484        let socket_path = temp_dir.path().join("ssh.sock");
 485
 486        #[cfg(target_os = "windows")]
 487        let mut master_process = MasterProcess::new(
 488            askpass.script_path().as_ref(),
 489            connection_options.additional_args(),
 490            &destination,
 491        )?;
 492        #[cfg(not(target_os = "windows"))]
 493        let mut master_process = MasterProcess::new(
 494            askpass.script_path().as_ref(),
 495            connection_options.additional_args(),
 496            &socket_path,
 497            &destination,
 498        )?;
 499
 500        let result = select_biased! {
 501            result = askpass.run().fuse() => {
 502                match result {
 503                    AskPassResult::CancelledByUser => {
 504                        master_process.as_mut().kill().ok();
 505                        anyhow::bail!("SSH connection canceled")
 506                    }
 507                    AskPassResult::Timedout => {
 508                        anyhow::bail!("connecting to host timed out")
 509                    }
 510                }
 511            }
 512            _ = master_process.wait_connected().fuse() => {
 513                anyhow::Ok(())
 514            }
 515        };
 516
 517        if let Err(e) = result {
 518            return Err(e.context("Failed to connect to host"));
 519        }
 520
 521        if master_process.as_mut().try_status()?.is_some() {
 522            let mut output = Vec::new();
 523            output.clear();
 524            let mut stderr = master_process.as_mut().stderr.take().unwrap();
 525            stderr.read_to_end(&mut output).await?;
 526
 527            let error_message = format!(
 528                "failed to connect: {}",
 529                String::from_utf8_lossy(&output).trim()
 530            );
 531            anyhow::bail!(error_message);
 532        }
 533
 534        #[cfg(not(target_os = "windows"))]
 535        let socket = SshSocket::new(connection_options, socket_path).await?;
 536        #[cfg(target_os = "windows")]
 537        let socket = SshSocket::new(
 538            connection_options,
 539            askpass
 540                .get_password()
 541                .or_else(|| askpass::EncryptedPassword::try_from("").ok())
 542                .context("Failed to fetch askpass password")?,
 543            cx.background_executor().clone(),
 544        )
 545        .await?;
 546        drop(askpass);
 547
 548        let ssh_shell = socket.shell().await;
 549        log::info!("Remote shell discovered: {}", ssh_shell);
 550        let ssh_platform = socket.platform(ShellKind::new(&ssh_shell, false)).await?;
 551        log::info!("Remote platform discovered: {:?}", ssh_platform);
 552        let ssh_path_style = match ssh_platform.os {
 553            "windows" => PathStyle::Windows,
 554            _ => PathStyle::Posix,
 555        };
 556        let ssh_default_system_shell = String::from("/bin/sh");
 557        let ssh_shell_kind = ShellKind::new(
 558            &ssh_shell,
 559            match ssh_platform.os {
 560                "windows" => true,
 561                _ => false,
 562            },
 563        );
 564
 565        let mut this = Self {
 566            socket,
 567            master_process: Mutex::new(Some(master_process)),
 568            _temp_dir: temp_dir,
 569            remote_binary_path: None,
 570            ssh_path_style,
 571            ssh_platform,
 572            ssh_shell,
 573            ssh_shell_kind,
 574            ssh_default_system_shell,
 575        };
 576
 577        let (release_channel, version) =
 578            cx.update(|cx| (ReleaseChannel::global(cx), AppVersion::global(cx)))?;
 579        this.remote_binary_path = Some(
 580            this.ensure_server_binary(&delegate, release_channel, version, cx)
 581                .await?,
 582        );
 583
 584        Ok(this)
 585    }
 586
 587    async fn ensure_server_binary(
 588        &self,
 589        delegate: &Arc<dyn RemoteClientDelegate>,
 590        release_channel: ReleaseChannel,
 591        version: Version,
 592        cx: &mut AsyncApp,
 593    ) -> Result<Arc<RelPath>> {
 594        let version_str = match release_channel {
 595            ReleaseChannel::Dev => "build".to_string(),
 596            _ => version.to_string(),
 597        };
 598        let binary_name = format!(
 599            "zed-remote-server-{}-{}",
 600            release_channel.dev_name(),
 601            version_str
 602        );
 603        let dst_path =
 604            paths::remote_server_dir_relative().join(RelPath::unix(&binary_name).unwrap());
 605
 606        #[cfg(debug_assertions)]
 607        if let Some(remote_server_path) =
 608            super::build_remote_server_from_source(&self.ssh_platform, delegate.as_ref(), cx)
 609                .await?
 610        {
 611            let tmp_path = paths::remote_server_dir_relative().join(
 612                RelPath::unix(&format!(
 613                    "download-{}-{}",
 614                    std::process::id(),
 615                    remote_server_path.file_name().unwrap().to_string_lossy()
 616                ))
 617                .unwrap(),
 618            );
 619            self.upload_local_server_binary(&remote_server_path, &tmp_path, delegate, cx)
 620                .await?;
 621            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
 622                .await?;
 623            return Ok(dst_path);
 624        }
 625
 626        if self
 627            .socket
 628            .run_command(
 629                self.ssh_shell_kind,
 630                &dst_path.display(self.path_style()),
 631                &["version"],
 632                true,
 633            )
 634            .await
 635            .is_ok()
 636        {
 637            return Ok(dst_path);
 638        }
 639
 640        let wanted_version = cx.update(|cx| match release_channel {
 641            ReleaseChannel::Nightly => Ok(None),
 642            ReleaseChannel::Dev => {
 643                anyhow::bail!(
 644                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
 645                    dst_path
 646                )
 647            }
 648            _ => Ok(Some(AppVersion::global(cx))),
 649        })??;
 650
 651        let tmp_path_gz = remote_server_dir_relative().join(
 652            RelPath::unix(&format!(
 653                "{}-download-{}.gz",
 654                binary_name,
 655                std::process::id()
 656            ))
 657            .unwrap(),
 658        );
 659        if !self.socket.connection_options.upload_binary_over_ssh
 660            && let Some(url) = delegate
 661                .get_download_url(
 662                    self.ssh_platform,
 663                    release_channel,
 664                    wanted_version.clone(),
 665                    cx,
 666                )
 667                .await?
 668        {
 669            match self
 670                .download_binary_on_server(&url, &tmp_path_gz, delegate, cx)
 671                .await
 672            {
 673                Ok(_) => {
 674                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 675                        .await
 676                        .context("extracting server binary")?;
 677                    return Ok(dst_path);
 678                }
 679                Err(e) => {
 680                    log::error!(
 681                        "Failed to download binary on server, attempting to download locally and then upload it the server: {e:#}",
 682                    )
 683                }
 684            }
 685        }
 686
 687        let src_path = delegate
 688            .download_server_binary_locally(
 689                self.ssh_platform,
 690                release_channel,
 691                wanted_version.clone(),
 692                cx,
 693            )
 694            .await
 695            .context("downloading server binary locally")?;
 696        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
 697            .await
 698            .context("uploading server binary")?;
 699        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 700            .await
 701            .context("extracting server binary")?;
 702        Ok(dst_path)
 703    }
 704
 705    async fn download_binary_on_server(
 706        &self,
 707        url: &str,
 708        tmp_path_gz: &RelPath,
 709        delegate: &Arc<dyn RemoteClientDelegate>,
 710        cx: &mut AsyncApp,
 711    ) -> Result<()> {
 712        if let Some(parent) = tmp_path_gz.parent() {
 713            self.socket
 714                .run_command(
 715                    self.ssh_shell_kind,
 716                    "mkdir",
 717                    &["-p", parent.display(self.path_style()).as_ref()],
 718                    true,
 719                )
 720                .await?;
 721        }
 722
 723        delegate.set_status(Some("Downloading remote development server on host"), cx);
 724
 725        let connection_timeout = self
 726            .socket
 727            .connection_options
 728            .connection_timeout
 729            .unwrap_or(10)
 730            .to_string();
 731
 732        match self
 733            .socket
 734            .run_command(
 735                self.ssh_shell_kind,
 736                "curl",
 737                &[
 738                    "-f",
 739                    "-L",
 740                    "--connect-timeout",
 741                    &connection_timeout,
 742                    url,
 743                    "-o",
 744                    &tmp_path_gz.display(self.path_style()),
 745                ],
 746                true,
 747            )
 748            .await
 749        {
 750            Ok(_) => {}
 751            Err(e) => {
 752                if self
 753                    .socket
 754                    .run_command(self.ssh_shell_kind, "which", &["curl"], true)
 755                    .await
 756                    .is_ok()
 757                {
 758                    return Err(e);
 759                }
 760
 761                log::info!("curl is not available, trying wget");
 762                match self
 763                    .socket
 764                    .run_command(
 765                        self.ssh_shell_kind,
 766                        "wget",
 767                        &[
 768                            "--connect-timeout",
 769                            &connection_timeout,
 770                            "--tries",
 771                            "1",
 772                            url,
 773                            "-O",
 774                            &tmp_path_gz.display(self.path_style()),
 775                        ],
 776                        true,
 777                    )
 778                    .await
 779                {
 780                    Ok(_) => {}
 781                    Err(e) => {
 782                        if self
 783                            .socket
 784                            .run_command(self.ssh_shell_kind, "which", &["wget"], true)
 785                            .await
 786                            .is_ok()
 787                        {
 788                            return Err(e);
 789                        } else {
 790                            anyhow::bail!("Neither curl nor wget is available");
 791                        }
 792                    }
 793                }
 794            }
 795        }
 796
 797        Ok(())
 798    }
 799
 800    async fn upload_local_server_binary(
 801        &self,
 802        src_path: &Path,
 803        tmp_path_gz: &RelPath,
 804        delegate: &Arc<dyn RemoteClientDelegate>,
 805        cx: &mut AsyncApp,
 806    ) -> Result<()> {
 807        if let Some(parent) = tmp_path_gz.parent() {
 808            self.socket
 809                .run_command(
 810                    self.ssh_shell_kind,
 811                    "mkdir",
 812                    &["-p", parent.display(self.path_style()).as_ref()],
 813                    true,
 814                )
 815                .await?;
 816        }
 817
 818        let src_stat = fs::metadata(&src_path).await?;
 819        let size = src_stat.len();
 820
 821        let t0 = Instant::now();
 822        delegate.set_status(Some("Uploading remote development server"), cx);
 823        log::info!(
 824            "uploading remote development server to {:?} ({}kb)",
 825            tmp_path_gz,
 826            size / 1024
 827        );
 828        self.upload_file(src_path, tmp_path_gz)
 829            .await
 830            .context("failed to upload server binary")?;
 831        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 832        Ok(())
 833    }
 834
 835    async fn extract_server_binary(
 836        &self,
 837        dst_path: &RelPath,
 838        tmp_path: &RelPath,
 839        delegate: &Arc<dyn RemoteClientDelegate>,
 840        cx: &mut AsyncApp,
 841    ) -> Result<()> {
 842        delegate.set_status(Some("Extracting remote development server"), cx);
 843        let server_mode = 0o755;
 844
 845        let shell_kind = ShellKind::Posix;
 846        let orig_tmp_path = tmp_path.display(self.path_style());
 847        let server_mode = format!("{:o}", server_mode);
 848        let server_mode = shell_kind
 849            .try_quote(&server_mode)
 850            .context("shell quoting")?;
 851        let dst_path = dst_path.display(self.path_style());
 852        let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
 853        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 854            let orig_tmp_path = shell_kind
 855                .try_quote(&orig_tmp_path)
 856                .context("shell quoting")?;
 857            let tmp_path = shell_kind.try_quote(&tmp_path).context("shell quoting")?;
 858            format!(
 859                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 860            )
 861        } else {
 862            let orig_tmp_path = shell_kind
 863                .try_quote(&orig_tmp_path)
 864                .context("shell quoting")?;
 865            format!("chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",)
 866        };
 867        let args = shell_kind.args_for_shell(false, script.to_string());
 868        self.socket
 869            .run_command(shell_kind, "sh", &args, true)
 870            .await?;
 871        Ok(())
 872    }
 873
 874    fn build_scp_command(
 875        &self,
 876        src_path: &Path,
 877        dest_path_str: &str,
 878        args: Option<&[&str]>,
 879    ) -> process::Command {
 880        let mut command = util::command::new_smol_command("scp");
 881        self.socket.ssh_options(&mut command, false).args(
 882            self.socket
 883                .connection_options
 884                .port
 885                .map(|port| vec!["-P".to_string(), port.to_string()])
 886                .unwrap_or_default(),
 887        );
 888        if let Some(args) = args {
 889            command.args(args);
 890        }
 891        command.arg(src_path).arg(format!(
 892            "{}:{}",
 893            self.socket.connection_options.scp_destination(),
 894            dest_path_str
 895        ));
 896        command
 897    }
 898
 899    fn build_sftp_command(&self) -> process::Command {
 900        let mut command = util::command::new_smol_command("sftp");
 901        self.socket.ssh_options(&mut command, false).args(
 902            self.socket
 903                .connection_options
 904                .port
 905                .map(|port| vec!["-P".to_string(), port.to_string()])
 906                .unwrap_or_default(),
 907        );
 908        command.arg("-b").arg("-");
 909        command.arg(self.socket.connection_options.scp_destination());
 910        command.stdin(Stdio::piped());
 911        command
 912    }
 913
 914    async fn upload_file(&self, src_path: &Path, dest_path: &RelPath) -> Result<()> {
 915        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
 916
 917        let src_path_display = src_path.display().to_string();
 918        let dest_path_str = dest_path.display(self.path_style());
 919
 920        // We will try SFTP first, and if that fails, we will fall back to SCP.
 921        // If SCP fails also, we give up and return an error.
 922        // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 923        // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 924        // This is for example the case on Windows as evidenced by this implementation snippet:
 925        // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 926        if Self::is_sftp_available().await {
 927            log::debug!("using SFTP for file upload");
 928            let mut command = self.build_sftp_command();
 929            let sftp_batch = format!("put {src_path_display} {dest_path_str}\n");
 930
 931            let mut child = command.spawn()?;
 932            if let Some(mut stdin) = child.stdin.take() {
 933                use futures::AsyncWriteExt;
 934                stdin.write_all(sftp_batch.as_bytes()).await?;
 935                stdin.flush().await?;
 936            }
 937
 938            let output = child.output().await?;
 939            if output.status.success() {
 940                return Ok(());
 941            }
 942
 943            let stderr = String::from_utf8_lossy(&output.stderr);
 944            log::debug!(
 945                "failed to upload file via SFTP {src_path_display} -> {dest_path_str}: {stderr}"
 946            );
 947        }
 948
 949        log::debug!("using SCP for file upload");
 950        let mut command = self.build_scp_command(src_path, &dest_path_str, None);
 951        let output = command.output().await?;
 952
 953        if output.status.success() {
 954            return Ok(());
 955        }
 956
 957        let stderr = String::from_utf8_lossy(&output.stderr);
 958        log::debug!(
 959            "failed to upload file via SCP {src_path_display} -> {dest_path_str}: {stderr}",
 960        );
 961        anyhow::bail!(
 962            "failed to upload file via STFP/SCP {} -> {}: {}",
 963            src_path_display,
 964            dest_path_str,
 965            stderr,
 966        );
 967    }
 968
 969    async fn is_sftp_available() -> bool {
 970        which::which("sftp").is_ok()
 971    }
 972}
 973
impl SshSocket {
    /// Creates a socket handle whose ssh invocations go through the control
    /// socket at `socket_path` (non-Windows). No extra environment is set.
    #[cfg(not(target_os = "windows"))]
    async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
        Ok(Self {
            connection_options: options,
            envs: HashMap::default(),
            socket_path,
        })
    }

    /// Creates a socket handle that supplies `password` to ssh through an
    /// askpass proxy (Windows).
    ///
    /// The proxy's script path is exported as `SSH_ASKPASS`, with
    /// `SSH_ASKPASS_REQUIRE=force`. These variables are applied to every
    /// command by `ssh_options`. The proxy is stored in the returned value,
    /// presumably so the askpass script stays usable for later commands.
    #[cfg(target_os = "windows")]
    async fn new(
        options: SshConnectionOptions,
        password: askpass::EncryptedPassword,
        executor: gpui::BackgroundExecutor,
    ) -> Result<Self> {
        let mut envs = HashMap::default();
        // Every askpass request is answered with the same stored password.
        let get_password =
            move |_| Task::ready(std::ops::ControlFlow::Continue(Ok(password.clone())));

        let _proxy = askpass::PasswordProxy::new(get_password, executor).await?;
        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
        envs.insert(
            "SSH_ASKPASS".into(),
            _proxy.script_path().as_ref().display().to_string(),
        );

        Ok(Self {
            connection_options: options,
            envs,
            _proxy,
        })
    }

    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
    // and passes -l as an argument to sh, not to ls.
    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
    // into a machine. You must use `cd` to get back to $HOME.
    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
    /// Builds an `ssh` command that runs `program` with `args` on the remote host.
    ///
    /// The program and each argument are quoted for `shell_kind` and joined
    /// into one remote command string, prefixed with `cd<separator> `. The
    /// connection flags include port forwards. When `allow_pseudo_tty` is
    /// false, `-T` is passed to disable pseudo-terminal allocation.
    ///
    /// # Panics
    /// Panics if the program or an argument cannot be quoted for `shell_kind`.
    /// In debug builds, it also panics if an argument contains a newline.
    fn ssh_command(
        &self,
        shell_kind: ShellKind,
        program: &str,
        args: &[impl AsRef<str>],
        allow_pseudo_tty: bool,
    ) -> process::Command {
        let mut command = util::command::new_smol_command("ssh");
        let program = shell_kind.prepend_command_prefix(program);
        let mut to_run = shell_kind
            .try_quote_prefix_aware(&program)
            .expect("shell quoting")
            .into_owned();
        for arg in args {
            // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
            debug_assert!(
                !arg.as_ref().contains('\n'),
                "multiline arguments do not work in all shells"
            );
            to_run.push(' ');
            to_run.push_str(&shell_kind.try_quote(arg.as_ref()).expect("shell quoting"));
        }
        // Leading bare `cd` returns to the home directory (see warning above).
        let separator = shell_kind.sequential_commands_separator();
        let to_run = format!("cd{separator} {to_run}");
        self.ssh_options(&mut command, true)
            .arg(self.connection_options.ssh_destination());
        if !allow_pseudo_tty {
            command.arg("-T");
        }
        command.arg(to_run);
        log::debug!("ssh {:?}", command);
        command
    }

    /// Runs `program` remotely via `ssh_command` and returns its stdout,
    /// lossily decoded as UTF-8.
    ///
    /// # Errors
    /// Fails if ssh cannot be spawned. It also fails on a non-zero exit status;
    /// that error includes the command and its stderr.
    async fn run_command(
        &self,
        shell_kind: ShellKind,
        program: &str,
        args: &[impl AsRef<str>],
        allow_pseudo_tty: bool,
    ) -> Result<String> {
        let mut command = self.ssh_command(shell_kind, program, args, allow_pseudo_tty);
        let output = command.output().await?;
        anyhow::ensure!(
            output.status.success(),
            "failed to run command {command:?}: {}",
            String::from_utf8_lossy(&output.stderr)
        );
        Ok(String::from_utf8_lossy(&output.stdout).to_string())
    }

    /// Appends the shared ssh flags to `command` and pipes all three stdio
    /// streams (non-Windows).
    ///
    /// With `include_port_forwards`, the full `additional_args` are used.
    /// Otherwise the scp-safe subset is used. The command is then pointed at
    /// the control socket with `ControlMaster=no`.
    #[cfg(not(target_os = "windows"))]
    fn ssh_options<'a>(
        &self,
        command: &'a mut process::Command,
        include_port_forwards: bool,
    ) -> &'a mut process::Command {
        let args = if include_port_forwards {
            self.connection_options.additional_args()
        } else {
            self.connection_options.additional_args_for_scp()
        };

        command
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .args(args)
            .args(["-o", "ControlMaster=no", "-o"])
            .arg(format!("ControlPath={}", self.socket_path.display()))
    }

    /// Appends the shared ssh flags to `command` and pipes all three stdio
    /// streams (Windows).
    ///
    /// Flag selection matches the non-Windows variant. Instead of a control
    /// socket, the askpass environment from `new` is applied.
    #[cfg(target_os = "windows")]
    fn ssh_options<'a>(
        &self,
        command: &'a mut process::Command,
        include_port_forwards: bool,
    ) -> &'a mut process::Command {
        let args = if include_port_forwards {
            self.connection_options.additional_args()
        } else {
            self.connection_options.additional_args_for_scp()
        };

        command
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .args(args)
            .envs(self.envs.clone())
    }

    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    // On every other platform, `ControlPath` points ssh at the control socket
    // at `socket_path`. With `ControlMaster=no`, ssh reuses an existing master
    // connection on that socket rather than starting a new one, so no password
    // prompt is needed here. This presumably relies on the master process
    // having been started elsewhere in this file.
    /// Returns the ssh flags plus destination, as a standalone argument list.
    #[cfg(not(target_os = "windows"))]
    fn ssh_args(&self) -> Vec<String> {
        let mut arguments = self.connection_options.additional_args();
        arguments.extend(vec![
            "-o".to_string(),
            "ControlMaster=no".to_string(),
            "-o".to_string(),
            format!("ControlPath={}", self.socket_path.display()),
            self.connection_options.ssh_destination(),
        ]);
        arguments
    }

    /// Returns the ssh flags plus destination, as a standalone argument list.
    #[cfg(target_os = "windows")]
    fn ssh_args(&self) -> Vec<String> {
        let mut arguments = self.connection_options.additional_args();
        arguments.push(self.connection_options.ssh_destination());
        arguments
    }

    /// Detects the remote OS/architecture by running `uname -sm`.
    async fn platform(&self, shell: ShellKind) -> Result<RemotePlatform> {
        let output = self.run_command(shell, "uname", &["-sm"], false).await?;
        parse_platform(&output)
    }

    /// Detects the remote login shell from `$SHELL`. It falls back to `sh`
    /// (and logs the error) if the probe command fails.
    async fn shell(&self) -> String {
        const DEFAULT_SHELL: &str = "sh";
        match self
            .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
            .await
        {
            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
            Err(e) => {
                log::error!("Failed to detect remote shell: {e}");
                DEFAULT_SHELL.to_owned()
            }
        }
    }
}
1147
1148fn parse_port_number(port_str: &str) -> Result<u16> {
1149    port_str
1150        .parse()
1151        .with_context(|| format!("parsing port number: {port_str}"))
1152}
1153
1154fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1155    let parts: Vec<&str> = spec.split(':').collect();
1156
1157    match parts.len() {
1158        4 => {
1159            let local_port = parse_port_number(parts[1])?;
1160            let remote_port = parse_port_number(parts[3])?;
1161
1162            Ok(SshPortForwardOption {
1163                local_host: Some(parts[0].to_string()),
1164                local_port,
1165                remote_host: Some(parts[2].to_string()),
1166                remote_port,
1167            })
1168        }
1169        3 => {
1170            let local_port = parse_port_number(parts[0])?;
1171            let remote_port = parse_port_number(parts[2])?;
1172
1173            Ok(SshPortForwardOption {
1174                local_host: None,
1175                local_port,
1176                remote_host: Some(parts[1].to_string()),
1177                remote_port,
1178            })
1179        }
1180        _ => anyhow::bail!("Invalid port forward format"),
1181    }
1182}
1183
1184impl SshConnectionOptions {
1185    pub fn parse_command_line(input: &str) -> Result<Self> {
1186        let input = input.trim_start_matches("ssh ");
1187        let mut hostname: Option<String> = None;
1188        let mut username: Option<String> = None;
1189        let mut port: Option<u16> = None;
1190        let mut args = Vec::new();
1191        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1192
1193        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1194        const ALLOWED_OPTS: &[&str] = &[
1195            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1196        ];
1197        const ALLOWED_ARGS: &[&str] = &[
1198            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1199            "-w",
1200        ];
1201
1202        let mut tokens = ShellKind::Posix
1203            .split(input)
1204            .context("invalid input")?
1205            .into_iter();
1206
1207        'outer: while let Some(arg) = tokens.next() {
1208            if ALLOWED_OPTS.contains(&(&arg as &str)) {
1209                args.push(arg.to_string());
1210                continue;
1211            }
1212            if arg == "-p" {
1213                port = tokens.next().and_then(|arg| arg.parse().ok());
1214                continue;
1215            } else if let Some(p) = arg.strip_prefix("-p") {
1216                port = p.parse().ok();
1217                continue;
1218            }
1219            if arg == "-l" {
1220                username = tokens.next();
1221                continue;
1222            } else if let Some(l) = arg.strip_prefix("-l") {
1223                username = Some(l.to_string());
1224                continue;
1225            }
1226            if arg == "-L" || arg.starts_with("-L") {
1227                let forward_spec = if arg == "-L" {
1228                    tokens.next()
1229                } else {
1230                    Some(arg.strip_prefix("-L").unwrap().to_string())
1231                };
1232
1233                if let Some(spec) = forward_spec {
1234                    port_forwards.push(parse_port_forward_spec(&spec)?);
1235                } else {
1236                    anyhow::bail!("Missing port forward format");
1237                }
1238            }
1239
1240            for a in ALLOWED_ARGS {
1241                if arg == *a {
1242                    args.push(arg);
1243                    if let Some(next) = tokens.next() {
1244                        args.push(next);
1245                    }
1246                    continue 'outer;
1247                } else if arg.starts_with(a) {
1248                    args.push(arg);
1249                    continue 'outer;
1250                }
1251            }
1252            if arg.starts_with("-") || hostname.is_some() {
1253                anyhow::bail!("unsupported argument: {:?}", arg);
1254            }
1255            let mut input = &arg as &str;
1256            // Destination might be: username1@username2@ip2@ip1
1257            if let Some((u, rest)) = input.rsplit_once('@') {
1258                input = rest;
1259                username = Some(u.to_string());
1260            }
1261
1262            // Handle port parsing, accounting for IPv6 addresses
1263            // IPv6 addresses can be: 2001:db8::1 or [2001:db8::1]:22
1264            if input.starts_with('[') {
1265                if let Some((rest, p)) = input.rsplit_once("]:") {
1266                    input = rest.strip_prefix('[').unwrap_or(rest);
1267                    port = p.parse().ok();
1268                } else if input.ends_with(']') {
1269                    input = input.strip_prefix('[').unwrap_or(input);
1270                    input = input.strip_suffix(']').unwrap_or(input);
1271                }
1272            } else if let Some((rest, p)) = input.rsplit_once(':')
1273                && !rest.contains(":")
1274            {
1275                input = rest;
1276                port = p.parse().ok();
1277            }
1278
1279            hostname = Some(input.to_string())
1280        }
1281
1282        let Some(hostname) = hostname else {
1283            anyhow::bail!("missing hostname");
1284        };
1285
1286        let port_forwards = match port_forwards.len() {
1287            0 => None,
1288            _ => Some(port_forwards),
1289        };
1290
1291        Ok(Self {
1292            host: hostname.into(),
1293            username,
1294            port,
1295            port_forwards,
1296            args: Some(args),
1297            password: None,
1298            nickname: None,
1299            upload_binary_over_ssh: false,
1300            connection_timeout: None,
1301        })
1302    }
1303
1304    pub fn ssh_destination(&self) -> String {
1305        let mut result = String::default();
1306        if let Some(username) = &self.username {
1307            // Username might be: username1@username2@ip2
1308            let username = urlencoding::encode(username);
1309            result.push_str(&username);
1310            result.push('@');
1311        }
1312
1313        result.push_str(&self.host.to_string());
1314        result
1315    }
1316
1317    pub fn additional_args_for_scp(&self) -> Vec<String> {
1318        self.args.iter().flatten().cloned().collect::<Vec<String>>()
1319    }
1320
1321    pub fn additional_args(&self) -> Vec<String> {
1322        let mut args = self.additional_args_for_scp();
1323
1324        if let Some(timeout) = self.connection_timeout {
1325            args.extend(["-o".to_string(), format!("ConnectTimeout={}", timeout)]);
1326        }
1327
1328        if let Some(port) = self.port {
1329            args.push("-p".to_string());
1330            args.push(port.to_string());
1331        }
1332
1333        if let Some(forwards) = &self.port_forwards {
1334            args.extend(forwards.iter().map(|pf| {
1335                let local_host = match &pf.local_host {
1336                    Some(host) => host,
1337                    None => "localhost",
1338                };
1339                let remote_host = match &pf.remote_host {
1340                    Some(host) => host,
1341                    None => "localhost",
1342                };
1343
1344                format!(
1345                    "-L{}:{}:{}:{}",
1346                    local_host, pf.local_port, remote_host, pf.remote_port
1347                )
1348            }));
1349        }
1350
1351        args
1352    }
1353
1354    fn scp_destination(&self) -> String {
1355        if let Some(username) = &self.username {
1356            format!("{}@{}", username, self.host.to_bracketed_string())
1357        } else {
1358            self.host.to_string()
1359        }
1360    }
1361
1362    pub fn connection_string(&self) -> String {
1363        let host = if let Some(port) = &self.port {
1364            format!("{}:{}", self.host.to_bracketed_string(), port)
1365        } else {
1366            self.host.to_string()
1367        };
1368
1369        if let Some(username) = &self.username {
1370            format!("{}@{}", username, host)
1371        } else {
1372            host
1373        }
1374    }
1375}
1376
1377fn build_command(
1378    input_program: Option<String>,
1379    input_args: &[String],
1380    input_env: &HashMap<String, String>,
1381    working_dir: Option<String>,
1382    port_forward: Option<(u16, String, u16)>,
1383    ssh_env: HashMap<String, String>,
1384    ssh_path_style: PathStyle,
1385    ssh_shell: &str,
1386    ssh_shell_kind: ShellKind,
1387    ssh_args: Vec<String>,
1388) -> Result<CommandTemplate> {
1389    use std::fmt::Write as _;
1390
1391    let mut exec = String::new();
1392    if let Some(working_dir) = working_dir {
1393        let working_dir = RemotePathBuf::new(working_dir, ssh_path_style).to_string();
1394
1395        // shlex will wrap the command in single quotes (''), disabling ~ expansion,
1396        // replace with something that works
1397        const TILDE_PREFIX: &'static str = "~/";
1398        if working_dir.starts_with(TILDE_PREFIX) {
1399            let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
1400            write!(
1401                exec,
1402                "cd \"$HOME/{working_dir}\" {} ",
1403                ssh_shell_kind.sequential_and_commands_separator()
1404            )?;
1405        } else {
1406            write!(
1407                exec,
1408                "cd \"{working_dir}\" {} ",
1409                ssh_shell_kind.sequential_and_commands_separator()
1410            )?;
1411        }
1412    } else {
1413        write!(
1414            exec,
1415            "cd {} ",
1416            ssh_shell_kind.sequential_and_commands_separator()
1417        )?;
1418    };
1419    write!(exec, "exec env ")?;
1420
1421    for (k, v) in input_env.iter() {
1422        write!(
1423            exec,
1424            "{}={} ",
1425            k,
1426            ssh_shell_kind.try_quote(v).context("shell quoting")?
1427        )?;
1428    }
1429
1430    if let Some(input_program) = input_program {
1431        write!(
1432            exec,
1433            "{}",
1434            ssh_shell_kind
1435                .try_quote_prefix_aware(&input_program)
1436                .context("shell quoting")?
1437        )?;
1438        for arg in input_args {
1439            let arg = ssh_shell_kind.try_quote(&arg).context("shell quoting")?;
1440            write!(exec, " {}", &arg)?;
1441        }
1442    } else {
1443        write!(exec, "{ssh_shell} -l")?;
1444    };
1445    let (command, command_args) = ShellBuilder::new(&Shell::Program(ssh_shell.to_owned()), false)
1446        .build(Some(exec.clone()), &[]);
1447
1448    let mut args = Vec::new();
1449    args.extend(ssh_args);
1450
1451    if let Some((local_port, host, remote_port)) = port_forward {
1452        args.push("-L".into());
1453        args.push(format!("{local_port}:{host}:{remote_port}"));
1454    }
1455
1456    args.push("-t".into());
1457    args.push(command);
1458    args.extend(command_args);
1459
1460    Ok(CommandTemplate {
1461        program: "ssh".into(),
1462        args,
1463        env: ssh_env,
1464    })
1465}
1466
1467#[cfg(test)]
1468mod tests {
1469    use super::*;
1470
1471    #[test]
1472    fn test_build_command() -> Result<()> {
1473        let mut input_env = HashMap::default();
1474        input_env.insert("INPUT_VA".to_string(), "val".to_string());
1475        let mut env = HashMap::default();
1476        env.insert("SSH_VAR".to_string(), "ssh-val".to_string());
1477
1478        let command = build_command(
1479            Some("remote_program".to_string()),
1480            &["arg1".to_string(), "arg2".to_string()],
1481            &input_env,
1482            Some("~/work".to_string()),
1483            None,
1484            env.clone(),
1485            PathStyle::Posix,
1486            "/bin/fish",
1487            ShellKind::Fish,
1488            vec!["-p".to_string(), "2222".to_string()],
1489        )?;
1490
1491        assert_eq!(command.program, "ssh");
1492        assert_eq!(
1493            command.args.iter().map(String::as_str).collect::<Vec<_>>(),
1494            [
1495                "-p",
1496                "2222",
1497                "-t",
1498                "/bin/fish",
1499                "-i",
1500                "-c",
1501                "cd \"$HOME/work\" && exec env INPUT_VA=val remote_program arg1 arg2"
1502            ]
1503        );
1504        assert_eq!(command.env, env);
1505
1506        let mut input_env = HashMap::default();
1507        input_env.insert("INPUT_VA".to_string(), "val".to_string());
1508        let mut env = HashMap::default();
1509        env.insert("SSH_VAR".to_string(), "ssh-val".to_string());
1510
1511        let command = build_command(
1512            None,
1513            &["arg1".to_string(), "arg2".to_string()],
1514            &input_env,
1515            None,
1516            Some((1, "foo".to_owned(), 2)),
1517            env.clone(),
1518            PathStyle::Posix,
1519            "/bin/fish",
1520            ShellKind::Fish,
1521            vec!["-p".to_string(), "2222".to_string()],
1522        )?;
1523
1524        assert_eq!(command.program, "ssh");
1525        assert_eq!(
1526            command.args.iter().map(String::as_str).collect::<Vec<_>>(),
1527            [
1528                "-p",
1529                "2222",
1530                "-L",
1531                "1:foo:2",
1532                "-t",
1533                "/bin/fish",
1534                "-i",
1535                "-c",
1536                "cd && exec env INPUT_VA=val /bin/fish -l"
1537            ]
1538        );
1539        assert_eq!(command.env, env);
1540
1541        Ok(())
1542    }
1543
1544    #[test]
1545    fn scp_args_exclude_port_forward_flags() {
1546        let options = SshConnectionOptions {
1547            host: "example.com".into(),
1548            args: Some(vec![
1549                "-p".to_string(),
1550                "2222".to_string(),
1551                "-o".to_string(),
1552                "StrictHostKeyChecking=no".to_string(),
1553            ]),
1554            port_forwards: Some(vec![SshPortForwardOption {
1555                local_host: Some("127.0.0.1".to_string()),
1556                local_port: 8080,
1557                remote_host: Some("127.0.0.1".to_string()),
1558                remote_port: 80,
1559            }]),
1560            ..Default::default()
1561        };
1562
1563        let ssh_args = options.additional_args();
1564        assert!(
1565            ssh_args.iter().any(|arg| arg.starts_with("-L")),
1566            "expected ssh args to include port-forward: {ssh_args:?}"
1567        );
1568
1569        let scp_args = options.additional_args_for_scp();
1570        assert_eq!(
1571            scp_args,
1572            vec![
1573                "-p".to_string(),
1574                "2222".to_string(),
1575                "-o".to_string(),
1576                "StrictHostKeyChecking=no".to_string(),
1577            ]
1578        );
1579    }
1580
1581    #[test]
1582    fn test_host_parsing() -> Result<()> {
1583        let opts = SshConnectionOptions::parse_command_line("user@2001:db8::1")?;
1584        assert_eq!(opts.host, "2001:db8::1".into());
1585        assert_eq!(opts.username, Some("user".to_string()));
1586        assert_eq!(opts.port, None);
1587
1588        let opts = SshConnectionOptions::parse_command_line("user@[2001:db8::1]:2222")?;
1589        assert_eq!(opts.host, "2001:db8::1".into());
1590        assert_eq!(opts.username, Some("user".to_string()));
1591        assert_eq!(opts.port, Some(2222));
1592
1593        let opts = SshConnectionOptions::parse_command_line("user@[2001:db8::1]")?;
1594        assert_eq!(opts.host, "2001:db8::1".into());
1595        assert_eq!(opts.username, Some("user".to_string()));
1596        assert_eq!(opts.port, None);
1597
1598        let opts = SshConnectionOptions::parse_command_line("2001:db8::1")?;
1599        assert_eq!(opts.host, "2001:db8::1".into());
1600        assert_eq!(opts.username, None);
1601        assert_eq!(opts.port, None);
1602
1603        let opts = SshConnectionOptions::parse_command_line("[2001:db8::1]:2222")?;
1604        assert_eq!(opts.host, "2001:db8::1".into());
1605        assert_eq!(opts.username, None);
1606        assert_eq!(opts.port, Some(2222));
1607
1608        let opts = SshConnectionOptions::parse_command_line("user@example.com:2222")?;
1609        assert_eq!(opts.host, "example.com".into());
1610        assert_eq!(opts.username, Some("user".to_string()));
1611        assert_eq!(opts.port, Some(2222));
1612
1613        let opts = SshConnectionOptions::parse_command_line("user@192.168.1.1:2222")?;
1614        assert_eq!(opts.host, "192.168.1.1".into());
1615        assert_eq!(opts.username, Some("user".to_string()));
1616        assert_eq!(opts.port, Some(2222));
1617
1618        Ok(())
1619    }
1620}