ssh.rs

   1use crate::{
   2    RemoteArch, RemoteClientDelegate, RemoteOs, RemotePlatform,
   3    remote_client::{CommandTemplate, Interactive, RemoteConnection, RemoteConnectionOptions},
   4    transport::{parse_platform, parse_shell},
   5};
   6use anyhow::{Context as _, Result, anyhow};
   7use async_trait::async_trait;
   8use collections::HashMap;
   9use futures::{
  10    AsyncReadExt as _, FutureExt as _,
  11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  12    select_biased,
  13};
  14use gpui::{App, AppContext as _, AsyncApp, Task};
  15use parking_lot::Mutex;
  16use paths::remote_server_dir_relative;
  17use release_channel::{AppVersion, ReleaseChannel};
  18use rpc::proto::Envelope;
  19use semver::Version;
  20pub use settings::SshPortForwardOption;
  21use smol::fs;
  22use std::{
  23    net::IpAddr,
  24    path::{Path, PathBuf},
  25    sync::Arc,
  26    time::Instant,
  27};
  28use tempfile::TempDir;
  29use util::command::{Child, Stdio};
  30use util::{
  31    paths::{PathStyle, RemotePathBuf},
  32    rel_path::RelPath,
  33    shell::ShellKind,
  34};
  35
/// An SSH connection to a remote host.
///
/// Holds the long-lived master `ssh` process that established the connection, what is
/// needed to spawn further `ssh` commands against the host, and what was discovered about
/// the remote (platform, shell, path style) while connecting.
pub(crate) struct SshRemoteConnection {
    /// Connection options and extra environment used when spawning further `ssh` commands.
    socket: SshSocket,
    /// The master `ssh` process; taken (left as `None`) by `kill`.
    master_process: Mutex<Option<MasterProcess>>,
    /// Location of the remote server binary as a relative path (NOTE(review): presumably
    /// relative to the remote user's home directory). Filled in by `new`; `start_proxy`
    /// fails while it is `None`.
    remote_binary_path: Option<Arc<RelPath>>,
    /// OS/architecture detected on the remote host.
    ssh_platform: RemotePlatform,
    /// `PathStyle::Windows` for Windows remotes, `PathStyle::Posix` otherwise.
    ssh_path_style: PathStyle,
    /// The remote shell discovered while connecting.
    ssh_shell: String,
    /// Classification of `ssh_shell`.
    ssh_shell_kind: ShellKind,
    /// Value returned by `default_system_shell`: the discovered shell on Windows remotes,
    /// `/bin/sh` otherwise.
    ssh_default_system_shell: String,
    /// Local temporary directory for this session (on non-Windows it contains the control
    /// socket); held so it is not deleted while the connection is alive.
    _temp_dir: TempDir,
}
  47
  48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  49pub enum SshConnectionHost {
  50    IpAddr(IpAddr),
  51    Hostname(String),
  52}
  53
  54impl SshConnectionHost {
  55    pub fn to_bracketed_string(&self) -> String {
  56        match self {
  57            Self::IpAddr(IpAddr::V4(ip)) => ip.to_string(),
  58            Self::IpAddr(IpAddr::V6(ip)) => format!("[{}]", ip),
  59            Self::Hostname(hostname) => hostname.clone(),
  60        }
  61    }
  62
  63    pub fn to_string(&self) -> String {
  64        match self {
  65            Self::IpAddr(ip) => ip.to_string(),
  66            Self::Hostname(hostname) => hostname.clone(),
  67        }
  68    }
  69}
  70
  71impl From<&str> for SshConnectionHost {
  72    fn from(value: &str) -> Self {
  73        if let Ok(address) = value.parse() {
  74            Self::IpAddr(address)
  75        } else {
  76            Self::Hostname(value.to_string())
  77        }
  78    }
  79}
  80
  81impl From<String> for SshConnectionHost {
  82    fn from(value: String) -> Self {
  83        if let Ok(address) = value.parse() {
  84            Self::IpAddr(address)
  85        } else {
  86            Self::Hostname(value)
  87        }
  88    }
  89}
  90
  91impl Default for SshConnectionHost {
  92    fn default() -> Self {
  93        Self::Hostname(Default::default())
  94    }
  95}
  96
  97fn bracket_ipv6(host: &str) -> String {
  98    if host.contains(':') && !host.starts_with('[') {
  99        format!("[{}]", host)
 100    } else {
 101        host.to_string()
 102    }
 103}
 104
/// Everything needed to open an SSH connection to a host, as configured by the user.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host to connect to (IP address or hostname/alias).
    pub host: SshConnectionHost,
    /// Remote user name, if specified.
    pub username: Option<String>,
    /// Remote port, if specified.
    pub port: Option<u16>,
    /// Password, if any. Always `None` when converted from settings.
    pub password: Option<String>,
    /// Extra arguments from settings (NOTE(review): presumably forwarded to `ssh`).
    pub args: Option<Vec<String>>,
    /// Port forwards configured for this connection.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,
    /// Connection timeout. Also used as the `--connect-timeout` (seconds) of the remote
    /// `curl`/`wget` download of the server binary, where it defaults to 10.
    pub connection_timeout: Option<u16>,

    /// User-chosen nickname for the connection (NOTE(review): presumably display-only).
    pub nickname: Option<String>,
    /// When set, the server binary is downloaded on this machine and uploaded over SSH
    /// instead of being downloaded by the remote host.
    pub upload_binary_over_ssh: bool,
}
 118
 119impl From<settings::SshConnection> for SshConnectionOptions {
 120    fn from(val: settings::SshConnection) -> Self {
 121        SshConnectionOptions {
 122            host: val.host.to_string().into(),
 123            username: val.username,
 124            port: val.port,
 125            password: None,
 126            args: Some(val.args),
 127            nickname: val.nickname,
 128            upload_binary_over_ssh: val.upload_binary_over_ssh.unwrap_or_default(),
 129            port_forwards: val.port_forwards,
 130            connection_timeout: val.connection_timeout,
 131        }
 132    }
 133}
 134
/// Information needed to run further `ssh` commands against an already-connected host.
struct SshSocket {
    /// Options of the connection this socket belongs to.
    connection_options: SshConnectionOptions,
    /// Control socket of the master `ssh` process (`ControlPath`). Absent on Windows, where
    /// OpenSSH does not support `ControlMaster`/`ControlPath`.
    #[cfg(not(windows))]
    socket_path: std::path::PathBuf,
    /// Extra environment variables needed for the ssh process
    envs: HashMap<String, String>,
    /// Windows only: askpass password proxy (NOTE(review): presumably kept alive so later
    /// `ssh` invocations can authenticate without prompting again — confirm).
    #[cfg(windows)]
    _proxy: askpass::PasswordProxy,
}
 144
/// The long-lived `ssh` process that establishes the connection to the host and keeps it open.
struct MasterProcess {
    /// The spawned `ssh` child process.
    process: Child,
}
 148
 149#[cfg(not(windows))]
 150impl MasterProcess {
 151    pub fn new(
 152        askpass_script_path: &std::ffi::OsStr,
 153        additional_args: Vec<String>,
 154        socket_path: &std::path::Path,
 155        destination: &str,
 156    ) -> Result<Self> {
 157        let args = [
 158            "-N",
 159            "-o",
 160            "ControlPersist=no",
 161            "-o",
 162            "ControlMaster=yes",
 163            "-o",
 164        ];
 165
 166        let mut master_process = util::command::new_command("ssh");
 167        master_process
 168            .kill_on_drop(true)
 169            .stdin(Stdio::null())
 170            .stdout(Stdio::piped())
 171            .stderr(Stdio::piped())
 172            .env("SSH_ASKPASS_REQUIRE", "force")
 173            .env("SSH_ASKPASS", askpass_script_path)
 174            .args(additional_args)
 175            .args(args);
 176
 177        master_process.arg(format!("ControlPath={}", socket_path.display()));
 178
 179        let process = master_process.arg(&destination).spawn()?;
 180
 181        Ok(MasterProcess { process })
 182    }
 183
 184    pub async fn wait_connected(&mut self) -> Result<()> {
 185        let Some(mut stdout) = self.process.stdout.take() else {
 186            anyhow::bail!("ssh process stdout capture failed");
 187        };
 188
 189        let mut output = Vec::new();
 190        stdout.read_to_end(&mut output).await?;
 191        Ok(())
 192    }
 193}
 194
 195#[cfg(windows)]
 196impl MasterProcess {
 197    const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
 198
 199    pub fn new(
 200        askpass_script_path: &std::ffi::OsStr,
 201        additional_args: Vec<String>,
 202        destination: &str,
 203    ) -> Result<Self> {
 204        // On Windows, `ControlMaster` and `ControlPath` are not supported:
 205        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
 206        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
 207        //
 208        // Using an ugly workaround to detect connection establishment
 209        // -N doesn't work with JumpHosts as windows openssh never closes stdin in that case
 210        let args = [
 211            "-t",
 212            &format!("echo '{}'; exec $0", Self::CONNECTION_ESTABLISHED_MAGIC),
 213        ];
 214
 215        let mut master_process = util::command::new_command("ssh");
 216        master_process
 217            .kill_on_drop(true)
 218            .stdin(Stdio::null())
 219            .stdout(Stdio::piped())
 220            .stderr(Stdio::piped())
 221            .env("SSH_ASKPASS_REQUIRE", "force")
 222            .env("SSH_ASKPASS", askpass_script_path)
 223            .args(additional_args)
 224            .arg(destination)
 225            .args(args);
 226
 227        let process = master_process.spawn()?;
 228
 229        Ok(MasterProcess { process })
 230    }
 231
 232    pub async fn wait_connected(&mut self) -> Result<()> {
 233        use smol::io::AsyncBufReadExt;
 234
 235        let Some(stdout) = self.process.stdout.take() else {
 236            anyhow::bail!("ssh process stdout capture failed");
 237        };
 238
 239        let mut reader = smol::io::BufReader::new(stdout);
 240
 241        let mut line = String::new();
 242
 243        loop {
 244            let n = reader.read_line(&mut line).await?;
 245            if n == 0 {
 246                anyhow::bail!("ssh process exited before connection established");
 247            }
 248
 249            if line.contains(Self::CONNECTION_ESTABLISHED_MAGIC) {
 250                return Ok(());
 251            }
 252        }
 253    }
 254}
 255
 256impl AsRef<Child> for MasterProcess {
 257    fn as_ref(&self) -> &Child {
 258        &self.process
 259    }
 260}
 261
 262impl AsMut<Child> for MasterProcess {
 263    fn as_mut(&mut self) -> &mut Child {
 264        &mut self.process
 265    }
 266}
 267
 268#[async_trait(?Send)]
 269impl RemoteConnection for SshRemoteConnection {
 270    async fn kill(&self) -> Result<()> {
 271        let Some(mut process) = self.master_process.lock().take() else {
 272            return Ok(());
 273        };
 274        process.as_mut().kill().ok();
 275        process.as_mut().status().await?;
 276        Ok(())
 277    }
 278
 279    fn has_been_killed(&self) -> bool {
 280        self.master_process.lock().is_none()
 281    }
 282
 283    fn connection_options(&self) -> RemoteConnectionOptions {
 284        RemoteConnectionOptions::Ssh(self.socket.connection_options.clone())
 285    }
 286
 287    fn shell(&self) -> String {
 288        self.ssh_shell.clone()
 289    }
 290
 291    fn default_system_shell(&self) -> String {
 292        self.ssh_default_system_shell.clone()
 293    }
 294
 295    fn build_command(
 296        &self,
 297        input_program: Option<String>,
 298        input_args: &[String],
 299        input_env: &HashMap<String, String>,
 300        working_dir: Option<String>,
 301        port_forward: Option<(u16, String, u16)>,
 302        interactive: Interactive,
 303    ) -> Result<CommandTemplate> {
 304        let Self {
 305            ssh_path_style,
 306            socket,
 307            ssh_shell_kind,
 308            ssh_shell,
 309            ..
 310        } = self;
 311        let env = socket.envs.clone();
 312
 313        if self.ssh_platform.os.is_windows() {
 314            build_command_windows(
 315                input_program,
 316                input_args,
 317                input_env,
 318                working_dir,
 319                port_forward,
 320                env,
 321                *ssh_path_style,
 322                ssh_shell,
 323                *ssh_shell_kind,
 324                socket.ssh_command_options(),
 325                &socket.connection_options.ssh_destination(),
 326                interactive,
 327            )
 328        } else {
 329            build_command_posix(
 330                input_program,
 331                input_args,
 332                input_env,
 333                working_dir,
 334                port_forward,
 335                env,
 336                *ssh_path_style,
 337                ssh_shell,
 338                *ssh_shell_kind,
 339                socket.ssh_command_options(),
 340                &socket.connection_options.ssh_destination(),
 341                interactive,
 342            )
 343        }
 344    }
 345
 346    fn build_forward_ports_command(
 347        &self,
 348        forwards: Vec<(u16, String, u16)>,
 349    ) -> Result<CommandTemplate> {
 350        let Self { socket, .. } = self;
 351        let mut args = socket.ssh_command_options();
 352        args.push("-N".into());
 353        for (local_port, host, remote_port) in forwards {
 354            args.push("-L".into());
 355            args.push(format!(
 356                "{}:{}:{}",
 357                local_port,
 358                bracket_ipv6(&host),
 359                remote_port
 360            ));
 361        }
 362        args.push(socket.connection_options.ssh_destination());
 363        Ok(CommandTemplate {
 364            program: "ssh".into(),
 365            args,
 366            env: Default::default(),
 367        })
 368    }
 369
 370    fn upload_directory(
 371        &self,
 372        src_path: PathBuf,
 373        dest_path: RemotePathBuf,
 374        cx: &App,
 375    ) -> Task<Result<()>> {
 376        let dest_path_str = dest_path.to_string();
 377        let src_path_display = src_path.display().to_string();
 378
 379        let mut sftp_command = self.build_sftp_command();
 380        let mut scp_command =
 381            self.build_scp_command(&src_path, &dest_path_str, Some(&["-C", "-r"]));
 382
 383        cx.background_spawn(async move {
 384            // We will try SFTP first, and if that fails, we will fall back to SCP.
 385            // If SCP fails also, we give up and return an error.
 386            // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
 387            // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
 388            // This is for example the case on Windows as evidenced by this implementation snippet:
 389            // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
 390            if Self::is_sftp_available().await {
 391                log::debug!("using SFTP for directory upload");
 392                let mut child = sftp_command.spawn()?;
 393                if let Some(mut stdin) = child.stdin.take() {
 394                    use futures::AsyncWriteExt;
 395                    let sftp_batch = format!("put -r \"{src_path_display}\" \"{dest_path_str}\"\n");
 396                    stdin.write_all(sftp_batch.as_bytes()).await?;
 397                    stdin.flush().await?;
 398                }
 399
 400                let output = child.output().await?;
 401                if output.status.success() {
 402                    return Ok(());
 403                }
 404
 405                let stderr = String::from_utf8_lossy(&output.stderr);
 406                log::debug!("failed to upload directory via SFTP {src_path_display} -> {dest_path_str}: {stderr}");
 407            }
 408
 409            log::debug!("using SCP for directory upload");
 410            let output = scp_command.output().await?;
 411
 412            if output.status.success() {
 413                return Ok(());
 414            }
 415
 416            let stderr = String::from_utf8_lossy(&output.stderr);
 417            log::debug!("failed to upload directory via SCP {src_path_display} -> {dest_path_str}: {stderr}");
 418
 419            anyhow::bail!(
 420                "failed to upload directory via SFTP/SCP {} -> {}: {}",
 421                src_path_display,
 422                dest_path_str,
 423                stderr,
 424            );
 425        })
 426    }
 427
 428    fn start_proxy(
 429        &self,
 430        unique_identifier: String,
 431        reconnect: bool,
 432        incoming_tx: UnboundedSender<Envelope>,
 433        outgoing_rx: UnboundedReceiver<Envelope>,
 434        connection_activity_tx: Sender<()>,
 435        delegate: Arc<dyn RemoteClientDelegate>,
 436        cx: &mut AsyncApp,
 437    ) -> Task<Result<i32>> {
 438        const VARS: [&str; 3] = ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"];
 439        delegate.set_status(Some("Starting proxy"), cx);
 440
 441        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 442            return Task::ready(Err(anyhow!("Remote binary path not set")));
 443        };
 444
 445        let mut ssh_command = if self.ssh_platform.os.is_windows() {
 446            // TODO: Set the `VARS` environment variables, we do not have `env` on windows
 447            // so this needs a different approach
 448            let mut proxy_args = vec![];
 449            proxy_args.push("proxy".to_owned());
 450            proxy_args.push("--identifier".to_owned());
 451            proxy_args.push(unique_identifier);
 452
 453            if reconnect {
 454                proxy_args.push("--reconnect".to_owned());
 455            }
 456            self.socket.ssh_command(
 457                self.ssh_shell_kind,
 458                &remote_binary_path.display(self.path_style()),
 459                &proxy_args,
 460                false,
 461            )
 462        } else {
 463            let mut proxy_args = vec![];
 464            for env_var in VARS {
 465                if let Some(value) = std::env::var(env_var).ok() {
 466                    proxy_args.push(format!("{}='{}'", env_var, value));
 467                }
 468            }
 469            proxy_args.push(remote_binary_path.display(self.path_style()).into_owned());
 470            proxy_args.push("proxy".to_owned());
 471            proxy_args.push("--identifier".to_owned());
 472            proxy_args.push(unique_identifier);
 473
 474            if reconnect {
 475                proxy_args.push("--reconnect".to_owned());
 476            }
 477            self.socket
 478                .ssh_command(self.ssh_shell_kind, "env", &proxy_args, false)
 479        };
 480
 481        let ssh_proxy_process = match ssh_command
 482            // IMPORTANT: we kill this process when we drop the task that uses it.
 483            .kill_on_drop(true)
 484            .spawn()
 485        {
 486            Ok(process) => process,
 487            Err(error) => {
 488                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 489            }
 490        };
 491
 492        super::handle_rpc_messages_over_child_process_stdio(
 493            ssh_proxy_process,
 494            incoming_tx,
 495            outgoing_rx,
 496            connection_activity_tx,
 497            cx,
 498        )
 499    }
 500
 501    fn path_style(&self) -> PathStyle {
 502        self.ssh_path_style
 503    }
 504
 505    fn has_wsl_interop(&self) -> bool {
 506        false
 507    }
 508}
 509
 510impl SshRemoteConnection {
 511    pub(crate) async fn new(
 512        connection_options: SshConnectionOptions,
 513        delegate: Arc<dyn RemoteClientDelegate>,
 514        cx: &mut AsyncApp,
 515    ) -> Result<Self> {
 516        use askpass::AskPassResult;
 517
 518        let destination = connection_options.ssh_destination();
 519
 520        let temp_dir = tempfile::Builder::new()
 521            .prefix("zed-ssh-session")
 522            .tempdir()?;
 523        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
 524            let delegate = delegate.clone();
 525            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
 526        });
 527
 528        let mut askpass =
 529            askpass::AskPassSession::new(cx.background_executor().clone(), askpass_delegate)
 530                .await?;
 531
 532        delegate.set_status(Some("Connecting"), cx);
 533
 534        // Start the master SSH process, which does not do anything except for establish
 535        // the connection and keep it open, allowing other ssh commands to reuse it
 536        // via a control socket.
 537        #[cfg(not(windows))]
 538        let socket_path = temp_dir.path().join("ssh.sock");
 539
 540        #[cfg(windows)]
 541        let mut master_process = MasterProcess::new(
 542            askpass.script_path().as_ref(),
 543            connection_options.additional_args(),
 544            &destination,
 545        )?;
 546        #[cfg(not(windows))]
 547        let mut master_process = MasterProcess::new(
 548            askpass.script_path().as_ref(),
 549            connection_options.additional_args(),
 550            &socket_path,
 551            &destination,
 552        )?;
 553
 554        let result = select_biased! {
 555            result = askpass.run().fuse() => {
 556                match result {
 557                    AskPassResult::CancelledByUser => {
 558                        master_process.as_mut().kill().ok();
 559                        anyhow::bail!("SSH connection canceled")
 560                    }
 561                    AskPassResult::Timedout => {
 562                        anyhow::bail!("connecting to host timed out")
 563                    }
 564                }
 565            }
 566            _ = master_process.wait_connected().fuse() => {
 567                anyhow::Ok(())
 568            }
 569        };
 570
 571        if let Err(e) = result {
 572            return Err(e.context("Failed to connect to host"));
 573        }
 574
 575        if master_process.as_mut().try_status()?.is_some() {
 576            let mut output = Vec::new();
 577            output.clear();
 578            let mut stderr = master_process.as_mut().stderr.take().unwrap();
 579            stderr.read_to_end(&mut output).await?;
 580
 581            let error_message = format!(
 582                "failed to connect: {}",
 583                String::from_utf8_lossy(&output).trim()
 584            );
 585            anyhow::bail!(error_message);
 586        }
 587
 588        #[cfg(not(windows))]
 589        let socket = SshSocket::new(connection_options, socket_path).await?;
 590        #[cfg(windows)]
 591        let socket = SshSocket::new(
 592            connection_options,
 593            askpass
 594                .get_password()
 595                .or_else(|| askpass::EncryptedPassword::try_from("").ok())
 596                .context("Failed to fetch askpass password")?,
 597            cx.background_executor().clone(),
 598        )
 599        .await?;
 600        drop(askpass);
 601
 602        let is_windows = socket.probe_is_windows().await;
 603        log::info!("Remote is windows: {}", is_windows);
 604
 605        let ssh_shell = socket.shell(is_windows).await;
 606        log::info!("Remote shell discovered: {}", ssh_shell);
 607
 608        let ssh_shell_kind = ShellKind::new(&ssh_shell, is_windows);
 609        let ssh_platform = socket.platform(ssh_shell_kind, is_windows).await?;
 610        log::info!("Remote platform discovered: {:?}", ssh_platform);
 611
 612        let (ssh_path_style, ssh_default_system_shell) = match ssh_platform.os {
 613            RemoteOs::Windows => (PathStyle::Windows, ssh_shell.clone()),
 614            _ => (PathStyle::Posix, String::from("/bin/sh")),
 615        };
 616
 617        let mut this = Self {
 618            socket,
 619            master_process: Mutex::new(Some(master_process)),
 620            _temp_dir: temp_dir,
 621            remote_binary_path: None,
 622            ssh_path_style,
 623            ssh_platform,
 624            ssh_shell,
 625            ssh_shell_kind,
 626            ssh_default_system_shell,
 627        };
 628
 629        let (release_channel, version) =
 630            cx.update(|cx| (ReleaseChannel::global(cx), AppVersion::global(cx)));
 631        this.remote_binary_path = Some(
 632            this.ensure_server_binary(&delegate, release_channel, version, cx)
 633                .await?,
 634        );
 635
 636        Ok(this)
 637    }
 638
 639    async fn ensure_server_binary(
 640        &self,
 641        delegate: &Arc<dyn RemoteClientDelegate>,
 642        release_channel: ReleaseChannel,
 643        version: Version,
 644        cx: &mut AsyncApp,
 645    ) -> Result<Arc<RelPath>> {
 646        let version_str = match release_channel {
 647            ReleaseChannel::Dev => "build".to_string(),
 648            _ => version.to_string(),
 649        };
 650        let binary_name = format!(
 651            "zed-remote-server-{}-{}{}",
 652            release_channel.dev_name(),
 653            version_str,
 654            if self.ssh_platform.os.is_windows() {
 655                ".exe"
 656            } else {
 657                ""
 658            }
 659        );
 660        let dst_path =
 661            paths::remote_server_dir_relative().join(RelPath::unix(&binary_name).unwrap());
 662
 663        let binary_exists_on_server = self
 664            .socket
 665            .run_command(
 666                self.ssh_shell_kind,
 667                &dst_path.display(self.path_style()),
 668                &["version"],
 669                true,
 670            )
 671            .await
 672            .is_ok();
 673
 674        #[cfg(any(debug_assertions, feature = "build-remote-server-binary"))]
 675        if let Some(remote_server_path) = super::build_remote_server_from_source(
 676            &self.ssh_platform,
 677            delegate.as_ref(),
 678            binary_exists_on_server,
 679            cx,
 680        )
 681        .await?
 682        {
 683            let tmp_path = paths::remote_server_dir_relative().join(
 684                RelPath::unix(&format!(
 685                    "download-{}-{}",
 686                    std::process::id(),
 687                    remote_server_path.file_name().unwrap().to_string_lossy()
 688                ))
 689                .unwrap(),
 690            );
 691            self.upload_local_server_binary(&remote_server_path, &tmp_path, delegate, cx)
 692                .await?;
 693            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
 694                .await?;
 695            return Ok(dst_path);
 696        }
 697
 698        if binary_exists_on_server {
 699            return Ok(dst_path);
 700        }
 701
 702        let wanted_version = cx.update(|cx| match release_channel {
 703            ReleaseChannel::Nightly => Ok(None),
 704            ReleaseChannel::Dev => {
 705                anyhow::bail!(
 706                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
 707                    dst_path
 708                )
 709            }
 710            _ => Ok(Some(AppVersion::global(cx))),
 711        })?;
 712
 713        let tmp_path_compressed = remote_server_dir_relative().join(
 714            RelPath::unix(&format!(
 715                "{}-download-{}.{}",
 716                binary_name,
 717                std::process::id(),
 718                if self.ssh_platform.os.is_windows() {
 719                    "zip"
 720                } else {
 721                    "gz"
 722                }
 723            ))
 724            .unwrap(),
 725        );
 726        if !self.socket.connection_options.upload_binary_over_ssh
 727            && let Some(url) = delegate
 728                .get_download_url(
 729                    self.ssh_platform,
 730                    release_channel,
 731                    wanted_version.clone(),
 732                    cx,
 733                )
 734                .await?
 735        {
 736            match self
 737                .download_binary_on_server(&url, &tmp_path_compressed, delegate, cx)
 738                .await
 739            {
 740                Ok(_) => {
 741                    self.extract_server_binary(&dst_path, &tmp_path_compressed, delegate, cx)
 742                        .await
 743                        .context("extracting server binary")?;
 744                    return Ok(dst_path);
 745                }
 746                Err(e) => {
 747                    log::error!(
 748                        "Failed to download binary on server, attempting to download locally and then upload it the server: {e:#}",
 749                    )
 750                }
 751            }
 752        }
 753
 754        let src_path = delegate
 755            .download_server_binary_locally(
 756                self.ssh_platform,
 757                release_channel,
 758                wanted_version.clone(),
 759                cx,
 760            )
 761            .await
 762            .context("downloading server binary locally")?;
 763        self.upload_local_server_binary(&src_path, &tmp_path_compressed, delegate, cx)
 764            .await
 765            .context("uploading server binary")?;
 766        self.extract_server_binary(&dst_path, &tmp_path_compressed, delegate, cx)
 767            .await
 768            .context("extracting server binary")?;
 769        Ok(dst_path)
 770    }
 771
 772    async fn download_binary_on_server(
 773        &self,
 774        url: &str,
 775        tmp_path: &RelPath,
 776        delegate: &Arc<dyn RemoteClientDelegate>,
 777        cx: &mut AsyncApp,
 778    ) -> Result<()> {
 779        if let Some(parent) = tmp_path.parent() {
 780            let res = self
 781                .socket
 782                .run_command(
 783                    self.ssh_shell_kind,
 784                    "mkdir",
 785                    &["-p", parent.display(self.path_style()).as_ref()],
 786                    true,
 787                )
 788                .await;
 789            if !self.ssh_platform.os.is_windows() {
 790                // mkdir fails on windows if the path already exists ...
 791                res?;
 792            }
 793        }
 794
 795        delegate.set_status(Some("Downloading remote development server on host"), cx);
 796
 797        let connection_timeout = self
 798            .socket
 799            .connection_options
 800            .connection_timeout
 801            .unwrap_or(10)
 802            .to_string();
 803
 804        match self
 805            .socket
 806            .run_command(
 807                self.ssh_shell_kind,
 808                "curl",
 809                &[
 810                    "-f",
 811                    "-L",
 812                    "--connect-timeout",
 813                    &connection_timeout,
 814                    url,
 815                    "-o",
 816                    &tmp_path.display(self.path_style()),
 817                ],
 818                true,
 819            )
 820            .await
 821        {
 822            Ok(_) => {}
 823            Err(e) => {
 824                if self
 825                    .socket
 826                    .run_command(self.ssh_shell_kind, "which", &["curl"], true)
 827                    .await
 828                    .is_ok()
 829                {
 830                    return Err(e);
 831                }
 832
 833                log::info!("curl is not available, trying wget");
 834                match self
 835                    .socket
 836                    .run_command(
 837                        self.ssh_shell_kind,
 838                        "wget",
 839                        &[
 840                            "--connect-timeout",
 841                            &connection_timeout,
 842                            "--tries",
 843                            "1",
 844                            url,
 845                            "-O",
 846                            &tmp_path.display(self.path_style()),
 847                        ],
 848                        true,
 849                    )
 850                    .await
 851                {
 852                    Ok(_) => {}
 853                    Err(e) => {
 854                        if self
 855                            .socket
 856                            .run_command(self.ssh_shell_kind, "which", &["wget"], true)
 857                            .await
 858                            .is_ok()
 859                        {
 860                            return Err(e);
 861                        } else {
 862                            anyhow::bail!("Neither curl nor wget is available");
 863                        }
 864                    }
 865                }
 866            }
 867        }
 868
 869        Ok(())
 870    }
 871
 872    async fn upload_local_server_binary(
 873        &self,
 874        src_path: &Path,
 875        tmp_path: &RelPath,
 876        delegate: &Arc<dyn RemoteClientDelegate>,
 877        cx: &mut AsyncApp,
 878    ) -> Result<()> {
 879        if let Some(parent) = tmp_path.parent() {
 880            let res = self
 881                .socket
 882                .run_command(
 883                    self.ssh_shell_kind,
 884                    "mkdir",
 885                    &["-p", parent.display(self.path_style()).as_ref()],
 886                    true,
 887                )
 888                .await;
 889            if !self.ssh_platform.os.is_windows() {
 890                // mkdir fails on windows if the path already exists ...
 891                res?;
 892            }
 893        }
 894
 895        let src_stat = fs::metadata(&src_path)
 896            .await
 897            .with_context(|| format!("failed to get metadata for {:?}", src_path))?;
 898        let size = src_stat.len();
 899
 900        let t0 = Instant::now();
 901        delegate.set_status(Some("Uploading remote development server"), cx);
 902        log::info!(
 903            "uploading remote development server to {:?} ({}kb)",
 904            tmp_path,
 905            size / 1024
 906        );
 907        self.upload_file(src_path, tmp_path)
 908            .await
 909            .context("failed to upload server binary")?;
 910        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 911        Ok(())
 912    }
 913
 914    async fn extract_server_binary(
 915        &self,
 916        dst_path: &RelPath,
 917        tmp_path: &RelPath,
 918        delegate: &Arc<dyn RemoteClientDelegate>,
 919        cx: &mut AsyncApp,
 920    ) -> Result<()> {
 921        delegate.set_status(Some("Extracting remote development server"), cx);
 922
 923        if self.ssh_platform.os.is_windows() {
 924            self.extract_server_binary_windows(dst_path, tmp_path).await
 925        } else {
 926            self.extract_server_binary_posix(dst_path, tmp_path).await
 927        }
 928    }
 929
 930    async fn extract_server_binary_posix(
 931        &self,
 932        dst_path: &RelPath,
 933        tmp_path: &RelPath,
 934    ) -> Result<()> {
 935        let shell_kind = ShellKind::Posix;
 936        let server_mode = 0o755;
 937        let orig_tmp_path = tmp_path.display(self.path_style());
 938        let server_mode = format!("{:o}", server_mode);
 939        let server_mode = shell_kind
 940            .try_quote(&server_mode)
 941            .context("shell quoting")?;
 942        let dst_path = dst_path.display(self.path_style());
 943        let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
 944        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 945            let orig_tmp_path = shell_kind
 946                .try_quote(&orig_tmp_path)
 947                .context("shell quoting")?;
 948            let tmp_path = shell_kind.try_quote(&tmp_path).context("shell quoting")?;
 949            format!(
 950                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 951            )
 952        } else {
 953            let orig_tmp_path = shell_kind
 954                .try_quote(&orig_tmp_path)
 955                .context("shell quoting")?;
 956            format!("chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",)
 957        };
 958        let args = shell_kind.args_for_shell(false, script.to_string());
 959        self.socket
 960            .run_command(self.ssh_shell_kind, "sh", &args, true)
 961            .await?;
 962        Ok(())
 963    }
 964
 965    async fn extract_server_binary_windows(
 966        &self,
 967        dst_path: &RelPath,
 968        tmp_path: &RelPath,
 969    ) -> Result<()> {
 970        let shell_kind = ShellKind::Pwsh;
 971        let orig_tmp_path = tmp_path.display(self.path_style());
 972        let dst_path = dst_path.display(self.path_style());
 973        let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
 974
 975        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".zip") {
 976            let orig_tmp_path = shell_kind
 977                .try_quote(&orig_tmp_path)
 978                .context("shell quoting")?;
 979            let tmp_path = shell_kind.try_quote(tmp_path).context("shell quoting")?;
 980            let tmp_exe_path = format!("{tmp_path}\\remote_server.exe");
 981            let tmp_exe_path = shell_kind
 982                .try_quote(&tmp_exe_path)
 983                .context("shell quoting")?;
 984            format!(
 985                "Expand-Archive -Force -Path {orig_tmp_path} -DestinationPath {tmp_path} -ErrorAction Stop; Move-Item -Force {tmp_exe_path} {dst_path}; Remove-Item -Force {tmp_path} -Recurse; Remove-Item -Force {orig_tmp_path}",
 986            )
 987        } else {
 988            let orig_tmp_path = shell_kind
 989                .try_quote(&orig_tmp_path)
 990                .context("shell quoting")?;
 991            format!("Move-Item -Force {orig_tmp_path} {dst_path}")
 992        };
 993
 994        let args = shell_kind.args_for_shell(false, script);
 995        self.socket
 996            .run_command(self.ssh_shell_kind, "powershell", &args, true)
 997            .await?;
 998        Ok(())
 999    }
1000
1001    fn build_scp_command(
1002        &self,
1003        src_path: &Path,
1004        dest_path_str: &str,
1005        args: Option<&[&str]>,
1006    ) -> util::command::Command {
1007        let mut command = util::command::new_command("scp");
1008        self.socket.ssh_options(&mut command, false).args(
1009            self.socket
1010                .connection_options
1011                .port
1012                .map(|port| vec!["-P".to_string(), port.to_string()])
1013                .unwrap_or_default(),
1014        );
1015        if let Some(args) = args {
1016            command.args(args);
1017        }
1018        command.arg(src_path).arg(format!(
1019            "{}:{}",
1020            self.socket.connection_options.scp_destination(),
1021            dest_path_str
1022        ));
1023        command
1024    }
1025
1026    fn build_sftp_command(&self) -> util::command::Command {
1027        let mut command = util::command::new_command("sftp");
1028        self.socket.ssh_options(&mut command, false).args(
1029            self.socket
1030                .connection_options
1031                .port
1032                .map(|port| vec!["-P".to_string(), port.to_string()])
1033                .unwrap_or_default(),
1034        );
1035        command.arg("-b").arg("-");
1036        command.arg(self.socket.connection_options.scp_destination());
1037        command.stdin(Stdio::piped());
1038        command
1039    }
1040
1041    async fn upload_file(&self, src_path: &Path, dest_path: &RelPath) -> Result<()> {
1042        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
1043
1044        let src_path_display = src_path.display().to_string();
1045        let dest_path_str = dest_path.display(self.path_style());
1046
1047        // We will try SFTP first, and if that fails, we will fall back to SCP.
1048        // If SCP fails also, we give up and return an error.
1049        // The reason we allow a fallback from SFTP to SCP is that if the user has to specify a password,
1050        // depending on the implementation of SSH stack, SFTP may disable interactive password prompts in batch mode.
1051        // This is for example the case on Windows as evidenced by this implementation snippet:
1052        // https://github.com/PowerShell/openssh-portable/blob/b8c08ef9da9450a94a9c5ef717d96a7bd83f3332/sshconnect2.c#L417
1053        if Self::is_sftp_available().await {
1054            log::debug!("using SFTP for file upload");
1055            let mut command = self.build_sftp_command();
1056            let sftp_batch = format!("put {src_path_display} {dest_path_str}\n");
1057
1058            let mut child = command.spawn()?;
1059            if let Some(mut stdin) = child.stdin.take() {
1060                use futures::AsyncWriteExt;
1061                stdin.write_all(sftp_batch.as_bytes()).await?;
1062                stdin.flush().await?;
1063            }
1064
1065            let output = child.output().await?;
1066            if output.status.success() {
1067                return Ok(());
1068            }
1069
1070            let stderr = String::from_utf8_lossy(&output.stderr);
1071            log::debug!(
1072                "failed to upload file via SFTP {src_path_display} -> {dest_path_str}: {stderr}"
1073            );
1074        }
1075
1076        log::debug!("using SCP for file upload");
1077        let mut command = self.build_scp_command(src_path, &dest_path_str, None);
1078        let output = command.output().await?;
1079
1080        if output.status.success() {
1081            return Ok(());
1082        }
1083
1084        let stderr = String::from_utf8_lossy(&output.stderr);
1085        log::debug!(
1086            "failed to upload file via SCP {src_path_display} -> {dest_path_str}: {stderr}",
1087        );
1088        anyhow::bail!(
1089            "failed to upload file via STFP/SCP {} -> {}: {}",
1090            src_path_display,
1091            dest_path_str,
1092            stderr,
1093        );
1094    }
1095
1096    async fn is_sftp_available() -> bool {
1097        which::which("sftp").is_ok()
1098    }
1099}
1100
1101impl SshSocket {
1102    #[cfg(not(windows))]
1103    async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
1104        Ok(Self {
1105            connection_options: options,
1106            envs: HashMap::default(),
1107            socket_path,
1108        })
1109    }
1110
1111    #[cfg(windows)]
1112    async fn new(
1113        options: SshConnectionOptions,
1114        password: askpass::EncryptedPassword,
1115        executor: gpui::BackgroundExecutor,
1116    ) -> Result<Self> {
1117        let mut envs = HashMap::default();
1118        let get_password =
1119            move |_| Task::ready(std::ops::ControlFlow::Continue(Ok(password.clone())));
1120
1121        let _proxy = askpass::PasswordProxy::new(Box::new(get_password), executor).await?;
1122        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
1123        envs.insert(
1124            "SSH_ASKPASS".into(),
1125            _proxy.script_path().as_ref().display().to_string(),
1126        );
1127
1128        Ok(Self {
1129            connection_options: options,
1130            envs,
1131            _proxy,
1132        })
1133    }
1134
1135    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
1136    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
1137    // and passes -l as an argument to sh, not to ls.
1138    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
1139    // into a machine. You must use `cd` to get back to $HOME.
1140    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
1141    fn ssh_command(
1142        &self,
1143        shell_kind: ShellKind,
1144        program: &str,
1145        args: &[impl AsRef<str>],
1146        allow_pseudo_tty: bool,
1147    ) -> util::command::Command {
1148        let mut command = util::command::new_command("ssh");
1149        let program = shell_kind.prepend_command_prefix(program);
1150        let mut to_run = shell_kind
1151            .try_quote_prefix_aware(&program)
1152            .expect("shell quoting")
1153            .into_owned();
1154        for arg in args {
1155            // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
1156            debug_assert!(
1157                !arg.as_ref().contains('\n'),
1158                "multiline arguments do not work in all shells"
1159            );
1160            to_run.push(' ');
1161            to_run.push_str(&shell_kind.try_quote(arg.as_ref()).expect("shell quoting"));
1162        }
1163        let to_run = if shell_kind == ShellKind::Cmd {
1164            to_run // 'cd' prints the current directory in CMD
1165        } else {
1166            let separator = shell_kind.sequential_commands_separator();
1167            format!("cd{separator} {to_run}")
1168        };
1169        self.ssh_options(&mut command, true)
1170            .arg(self.connection_options.ssh_destination());
1171        if !allow_pseudo_tty {
1172            command.arg("-T");
1173        }
1174        command.arg(to_run);
1175        log::debug!("ssh {:?}", command);
1176        command
1177    }
1178
1179    async fn run_command(
1180        &self,
1181        shell_kind: ShellKind,
1182        program: &str,
1183        args: &[impl AsRef<str>],
1184        allow_pseudo_tty: bool,
1185    ) -> Result<String> {
1186        let mut command = self.ssh_command(shell_kind, program, args, allow_pseudo_tty);
1187        let output = command.output().await?;
1188        log::debug!("{:?}: {:?}", command, output);
1189        anyhow::ensure!(
1190            output.status.success(),
1191            "failed to run command {command:?}: {}",
1192            String::from_utf8_lossy(&output.stderr)
1193        );
1194        Ok(String::from_utf8_lossy(&output.stdout).to_string())
1195    }
1196
1197    fn ssh_options<'a>(
1198        &self,
1199        command: &'a mut util::command::Command,
1200        include_port_forwards: bool,
1201    ) -> &'a mut util::command::Command {
1202        let args = if include_port_forwards {
1203            self.connection_options.additional_args()
1204        } else {
1205            self.connection_options.additional_args_for_scp()
1206        };
1207
1208        let cmd = command
1209            .stdin(Stdio::piped())
1210            .stdout(Stdio::piped())
1211            .stderr(Stdio::piped())
1212            .args(args);
1213
1214        if cfg!(windows) {
1215            cmd.envs(self.envs.clone());
1216        }
1217        #[cfg(not(windows))]
1218        {
1219            cmd.args(["-o", "ControlMaster=no", "-o"])
1220                .arg(format!("ControlPath={}", self.socket_path.display()));
1221        }
1222        cmd
1223    }
1224
1225    // Returns the SSH command-line options (without the destination) for building commands.
1226    // On Linux, this includes the ControlPath option to reuse the existing connection.
1227    // Note: The destination must be added separately after all options to ensure proper
1228    // SSH command structure: ssh [options] destination [command]
1229    fn ssh_command_options(&self) -> Vec<String> {
1230        let arguments = self.connection_options.additional_args();
1231        #[cfg(not(windows))]
1232        let arguments = {
1233            let mut args = arguments;
1234            args.extend(vec![
1235                "-o".to_string(),
1236                "ControlMaster=no".to_string(),
1237                "-o".to_string(),
1238                format!("ControlPath={}", self.socket_path.display()),
1239            ]);
1240            args
1241        };
1242        arguments
1243    }
1244
1245    async fn platform(&self, shell: ShellKind, is_windows: bool) -> Result<RemotePlatform> {
1246        if is_windows {
1247            self.platform_windows(shell).await
1248        } else {
1249            self.platform_posix(shell).await
1250        }
1251    }
1252
1253    async fn platform_posix(&self, shell: ShellKind) -> Result<RemotePlatform> {
1254        let output = self
1255            .run_command(shell, "uname", &["-sm"], false)
1256            .await
1257            .context("Failed to run 'uname -sm' to determine platform")?;
1258        parse_platform(&output)
1259    }
1260
1261    async fn platform_windows(&self, shell: ShellKind) -> Result<RemotePlatform> {
1262        let output = self
1263            .run_command(
1264                shell,
1265                "cmd.exe",
1266                &["/c", "echo", "%PROCESSOR_ARCHITECTURE%"],
1267                false,
1268            )
1269            .await
1270            .context(
1271                "Failed to run 'echo %PROCESSOR_ARCHITECTURE%' to determine Windows architecture",
1272            )?;
1273
1274        Ok(RemotePlatform {
1275            os: RemoteOs::Windows,
1276            arch: match output.trim() {
1277                "AMD64" => RemoteArch::X86_64,
1278                "ARM64" => RemoteArch::Aarch64,
1279                arch => anyhow::bail!(
1280                    "Prebuilt remote servers are not yet available for windows-{arch}. See https://zed.dev/docs/remote-development"
1281                ),
1282            },
1283        })
1284    }
1285
1286    /// Probes whether the remote host is running Windows.
1287    ///
1288    /// This is done by attempting to run a simple Windows-specific command.
1289    /// If it succeeds and returns Windows-like output, we assume it's Windows.
1290    async fn probe_is_windows(&self) -> bool {
1291        match self
1292            .run_command(ShellKind::Cmd, "cmd.exe", &["/c", "ver"], false)
1293            .await
1294        {
1295            // Windows 'ver' command outputs something like "Microsoft Windows [Version 10.0.19045.5011]"
1296            Ok(output) => output.trim().contains("indows"),
1297            Err(_) => false,
1298        }
1299    }
1300
1301    async fn shell(&self, is_windows: bool) -> String {
1302        if is_windows {
1303            self.shell_windows().await
1304        } else {
1305            self.shell_posix().await
1306        }
1307    }
1308
1309    async fn shell_posix(&self) -> String {
1310        const DEFAULT_SHELL: &str = "sh";
1311        match self
1312            .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
1313            .await
1314        {
1315            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
1316            Err(e) => {
1317                log::error!("Failed to detect remote shell: {e}");
1318                DEFAULT_SHELL.to_owned()
1319            }
1320        }
1321    }
1322
1323    async fn shell_windows(&self) -> String {
1324        const DEFAULT_SHELL: &str = "cmd.exe";
1325
1326        // We detect the shell used by the SSH session by running the following command in PowerShell:
1327        // (Get-CimInstance Win32_Process -Filter "ProcessId = $((Get-CimInstance Win32_Process -Filter ProcessId=$PID).ParentProcessId)").Name
1328        // This prints the name of PowerShell's parent process (which will be the shell that SSH launched).
1329        // We pass it as a Base64 encoded string since we don't yet know how to correctly quote that command.
1330        // (We'd need to know what the shell is to do that...)
1331        match self
1332            .run_command(
1333                ShellKind::Cmd,
1334                "powershell",
1335                &[
1336                    "-E",
1337                    "KABHAGUAdAAtAEMAaQBtAEkAbgBzAHQAYQBuAGMAZQAgAFcAaQBuADMAMgBfAFAAcgBvAGMAZQBzAHMAIAAtAEYAaQBsAHQAZQByACAAIgBQAHIAbwBjAGUAcwBzAEkAZAAgAD0AIAAkACgAKABHAGUAdAAtAEMAaQBtAEkAbgBzAHQAYQBuAGMAZQAgAFcAaQBuADMAMgBfAFAAcgBvAGMAZQBzAHMAIAAtAEYAaQBsAHQAZQByACAAUAByAG8AYwBlAHMAcwBJAGQAPQAkAFAASQBEACkALgBQAGEAcgBlAG4AdABQAHIAbwBjAGUAcwBzAEkAZAApACIAKQAuAE4AYQBtAGUA",
1338                ],
1339                false,
1340            )
1341            .await
1342        {
1343            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
1344            Err(e) => {
1345                log::error!("Failed to detect remote shell: {e}");
1346                DEFAULT_SHELL.to_owned()
1347            }
1348        }
1349    }
1350}
1351
1352fn parse_port_number(port_str: &str) -> Result<u16> {
1353    port_str
1354        .parse()
1355        .with_context(|| format!("parsing port number: {port_str}"))
1356}
1357
1358fn split_port_forward_tokens(spec: &str) -> Result<Vec<String>> {
1359    let mut tokens = Vec::new();
1360    let mut chars = spec.chars().peekable();
1361
1362    while chars.peek().is_some() {
1363        if chars.peek() == Some(&'[') {
1364            chars.next();
1365            let mut bracket_content = String::new();
1366            loop {
1367                match chars.next() {
1368                    Some(']') => break,
1369                    Some(ch) => bracket_content.push(ch),
1370                    None => anyhow::bail!("Unmatched '[' in port forward spec: {spec}"),
1371                }
1372            }
1373            tokens.push(bracket_content);
1374            if chars.peek() == Some(&':') {
1375                chars.next();
1376            }
1377        } else {
1378            let mut token = String::new();
1379            for ch in chars.by_ref() {
1380                if ch == ':' {
1381                    break;
1382                }
1383                token.push(ch);
1384            }
1385            tokens.push(token);
1386        }
1387    }
1388
1389    Ok(tokens)
1390}
1391
1392fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1393    let tokens = if spec.contains('[') {
1394        split_port_forward_tokens(spec)?
1395    } else {
1396        spec.split(':').map(String::from).collect()
1397    };
1398
1399    match tokens.len() {
1400        4 => {
1401            let local_port = parse_port_number(&tokens[1])?;
1402            let remote_port = parse_port_number(&tokens[3])?;
1403
1404            Ok(SshPortForwardOption {
1405                local_host: Some(tokens[0].clone()),
1406                local_port,
1407                remote_host: Some(tokens[2].clone()),
1408                remote_port,
1409            })
1410        }
1411        3 => {
1412            let local_port = parse_port_number(&tokens[0])?;
1413            let remote_port = parse_port_number(&tokens[2])?;
1414
1415            Ok(SshPortForwardOption {
1416                local_host: None,
1417                local_port,
1418                remote_host: Some(tokens[1].clone()),
1419                remote_port,
1420            })
1421        }
1422        _ => anyhow::bail!("Invalid port forward format: {spec}"),
1423    }
1424}
1425
1426impl SshConnectionOptions {
1427    pub fn parse_command_line(input: &str) -> Result<Self> {
1428        let input = input.trim_start_matches("ssh ");
1429        let mut hostname: Option<String> = None;
1430        let mut username: Option<String> = None;
1431        let mut port: Option<u16> = None;
1432        let mut args = Vec::new();
1433        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1434
1435        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1436        const ALLOWED_OPTS: &[&str] = &[
1437            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1438        ];
1439        const ALLOWED_ARGS: &[&str] = &[
1440            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1441            "-w",
1442        ];
1443
1444        let mut tokens = ShellKind::Posix
1445            .split(input)
1446            .context("invalid input")?
1447            .into_iter();
1448
1449        'outer: while let Some(arg) = tokens.next() {
1450            if ALLOWED_OPTS.contains(&(&arg as &str)) {
1451                args.push(arg.to_string());
1452                continue;
1453            }
1454            if arg == "-p" {
1455                port = tokens.next().and_then(|arg| arg.parse().ok());
1456                continue;
1457            } else if let Some(p) = arg.strip_prefix("-p") {
1458                port = p.parse().ok();
1459                continue;
1460            }
1461            if arg == "-l" {
1462                username = tokens.next();
1463                continue;
1464            } else if let Some(l) = arg.strip_prefix("-l") {
1465                username = Some(l.to_string());
1466                continue;
1467            }
1468            if arg == "-L" || arg.starts_with("-L") {
1469                let forward_spec = if arg == "-L" {
1470                    tokens.next()
1471                } else {
1472                    Some(arg.strip_prefix("-L").unwrap().to_string())
1473                };
1474
1475                if let Some(spec) = forward_spec {
1476                    port_forwards.push(parse_port_forward_spec(&spec)?);
1477                } else {
1478                    anyhow::bail!("Missing port forward format");
1479                }
1480            }
1481
1482            for a in ALLOWED_ARGS {
1483                if arg == *a {
1484                    args.push(arg);
1485                    if let Some(next) = tokens.next() {
1486                        args.push(next);
1487                    }
1488                    continue 'outer;
1489                } else if arg.starts_with(a) {
1490                    args.push(arg);
1491                    continue 'outer;
1492                }
1493            }
1494            if arg.starts_with("-") || hostname.is_some() {
1495                anyhow::bail!("unsupported argument: {:?}", arg);
1496            }
1497            let mut input = &arg as &str;
1498            // Destination might be: username1@username2@ip2@ip1
1499            if let Some((u, rest)) = input.rsplit_once('@') {
1500                input = rest;
1501                username = Some(u.to_string());
1502            }
1503
1504            // Handle port parsing, accounting for IPv6 addresses
1505            // IPv6 addresses can be: 2001:db8::1 or [2001:db8::1]:22
1506            if input.starts_with('[') {
1507                if let Some((rest, p)) = input.rsplit_once("]:") {
1508                    input = rest.strip_prefix('[').unwrap_or(rest);
1509                    port = p.parse().ok();
1510                } else if input.ends_with(']') {
1511                    input = input.strip_prefix('[').unwrap_or(input);
1512                    input = input.strip_suffix(']').unwrap_or(input);
1513                }
1514            } else if let Some((rest, p)) = input.rsplit_once(':')
1515                && !rest.contains(":")
1516            {
1517                input = rest;
1518                port = p.parse().ok();
1519            }
1520
1521            hostname = Some(input.to_string())
1522        }
1523
1524        let Some(hostname) = hostname else {
1525            anyhow::bail!("missing hostname");
1526        };
1527
1528        let port_forwards = match port_forwards.len() {
1529            0 => None,
1530            _ => Some(port_forwards),
1531        };
1532
1533        Ok(Self {
1534            host: hostname.into(),
1535            username,
1536            port,
1537            port_forwards,
1538            args: Some(args),
1539            password: None,
1540            nickname: None,
1541            upload_binary_over_ssh: false,
1542            connection_timeout: None,
1543        })
1544    }
1545
1546    pub fn ssh_destination(&self) -> String {
1547        let mut result = String::default();
1548        if let Some(username) = &self.username {
1549            // Username might be: username1@username2@ip2
1550            let username = urlencoding::encode(username);
1551            result.push_str(&username);
1552            result.push('@');
1553        }
1554
1555        result.push_str(&self.host.to_string());
1556        result
1557    }
1558
1559    pub fn additional_args_for_scp(&self) -> Vec<String> {
1560        self.args.iter().flatten().cloned().collect::<Vec<String>>()
1561    }
1562
1563    pub fn additional_args(&self) -> Vec<String> {
1564        let mut args = self.additional_args_for_scp();
1565
1566        if let Some(timeout) = self.connection_timeout {
1567            args.extend(["-o".to_string(), format!("ConnectTimeout={}", timeout)]);
1568        }
1569
1570        if let Some(port) = self.port {
1571            args.push("-p".to_string());
1572            args.push(port.to_string());
1573        }
1574
1575        if let Some(forwards) = &self.port_forwards {
1576            args.extend(forwards.iter().map(|pf| {
1577                let local_host = match &pf.local_host {
1578                    Some(host) => host,
1579                    None => "localhost",
1580                };
1581                let remote_host = match &pf.remote_host {
1582                    Some(host) => host,
1583                    None => "localhost",
1584                };
1585
1586                format!(
1587                    "-L{}:{}:{}:{}",
1588                    bracket_ipv6(local_host),
1589                    pf.local_port,
1590                    bracket_ipv6(remote_host),
1591                    pf.remote_port
1592                )
1593            }));
1594        }
1595
1596        args
1597    }
1598
1599    fn scp_destination(&self) -> String {
1600        if let Some(username) = &self.username {
1601            format!("{}@{}", username, self.host.to_bracketed_string())
1602        } else {
1603            self.host.to_string()
1604        }
1605    }
1606
1607    pub fn connection_string(&self) -> String {
1608        let host = if let Some(port) = &self.port {
1609            format!("{}:{}", self.host.to_bracketed_string(), port)
1610        } else {
1611            self.host.to_string()
1612        };
1613
1614        if let Some(username) = &self.username {
1615            format!("{}@{}", username, host)
1616        } else {
1617            host
1618        }
1619    }
1620}
1621
1622fn build_command_posix(
1623    input_program: Option<String>,
1624    input_args: &[String],
1625    input_env: &HashMap<String, String>,
1626    working_dir: Option<String>,
1627    port_forward: Option<(u16, String, u16)>,
1628    ssh_env: HashMap<String, String>,
1629    ssh_path_style: PathStyle,
1630    ssh_shell: &str,
1631    ssh_shell_kind: ShellKind,
1632    ssh_options: Vec<String>,
1633    ssh_destination: &str,
1634    interactive: Interactive,
1635) -> Result<CommandTemplate> {
1636    use std::fmt::Write as _;
1637
1638    let mut exec = String::new();
1639    if let Some(working_dir) = working_dir {
1640        let working_dir = RemotePathBuf::new(working_dir, ssh_path_style).to_string();
1641
1642        // shlex will wrap the command in single quotes (''), disabling ~ expansion,
1643        // replace with something that works
1644        const TILDE_PREFIX: &'static str = "~/";
1645        if working_dir.starts_with(TILDE_PREFIX) {
1646            let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
1647            write!(
1648                exec,
1649                "cd \"$HOME/{working_dir}\" {} ",
1650                ssh_shell_kind.sequential_and_commands_separator()
1651            )?;
1652        } else {
1653            write!(
1654                exec,
1655                "cd \"{working_dir}\" {} ",
1656                ssh_shell_kind.sequential_and_commands_separator()
1657            )?;
1658        }
1659    } else {
1660        write!(
1661            exec,
1662            "cd {} ",
1663            ssh_shell_kind.sequential_and_commands_separator()
1664        )?;
1665    };
1666    write!(exec, "exec env ")?;
1667
1668    for (k, v) in input_env.iter() {
1669        write!(
1670            exec,
1671            "{}={} ",
1672            k,
1673            ssh_shell_kind.try_quote(v).context("shell quoting")?
1674        )?;
1675    }
1676
1677    if let Some(input_program) = input_program {
1678        write!(
1679            exec,
1680            "{}",
1681            ssh_shell_kind
1682                .try_quote_prefix_aware(&input_program)
1683                .context("shell quoting")?
1684        )?;
1685        for arg in input_args {
1686            let arg = ssh_shell_kind.try_quote(&arg).context("shell quoting")?;
1687            write!(exec, " {}", &arg)?;
1688        }
1689    } else {
1690        write!(exec, "{ssh_shell} -l")?;
1691    };
1692
1693    let mut args = Vec::new();
1694    args.extend(ssh_options);
1695
1696    if let Some((local_port, host, remote_port)) = port_forward {
1697        args.push("-L".into());
1698        args.push(format!(
1699            "{}:{}:{}",
1700            local_port,
1701            bracket_ipv6(&host),
1702            remote_port
1703        ));
1704    }
1705
1706    // -q suppresses the "Connection to ... closed." message that SSH prints when
1707    // the connection terminates with -t (pseudo-terminal allocation)
1708    args.push("-q".into());
1709    match interactive {
1710        // -t forces pseudo-TTY allocation (for interactive use)
1711        Interactive::Yes => args.push("-t".into()),
1712        // -T disables pseudo-TTY allocation (for non-interactive piped stdio)
1713        Interactive::No => args.push("-T".into()),
1714    }
1715    // The destination must come after all options but before the command
1716    args.push(ssh_destination.into());
1717    args.push(exec);
1718
1719    Ok(CommandTemplate {
1720        program: "ssh".into(),
1721        args,
1722        env: ssh_env,
1723    })
1724}
1725
1726fn build_command_windows(
1727    input_program: Option<String>,
1728    input_args: &[String],
1729    _input_env: &HashMap<String, String>,
1730    working_dir: Option<String>,
1731    port_forward: Option<(u16, String, u16)>,
1732    ssh_env: HashMap<String, String>,
1733    ssh_path_style: PathStyle,
1734    ssh_shell: &str,
1735    _ssh_shell_kind: ShellKind,
1736    ssh_options: Vec<String>,
1737    ssh_destination: &str,
1738    interactive: Interactive,
1739) -> Result<CommandTemplate> {
1740    use base64::Engine as _;
1741    use std::fmt::Write as _;
1742
1743    let mut exec = String::new();
1744    let shell_kind = ShellKind::PowerShell;
1745
1746    if let Some(working_dir) = working_dir {
1747        let working_dir = RemotePathBuf::new(working_dir, ssh_path_style).to_string();
1748
1749        write!(
1750            exec,
1751            "Set-Location -Path {} {} ",
1752            shell_kind
1753                .try_quote(&working_dir)
1754                .context("shell quoting")?,
1755            shell_kind.sequential_and_commands_separator()
1756        )?;
1757    }
1758
1759    // Windows OpenSSH has an 8K character limit for command lines. Sending a lot of environment variables easily puts us over the limit.
1760    // Until we have a better solution for this, we just won't set environment variables for now.
1761    // for (k, v) in input_env.iter() {
1762    //     write!(
1763    //         exec,
1764    //         "$env:{}={} {} ",
1765    //         k,
1766    //         shell_kind.try_quote(v).context("shell quoting")?,
1767    //         shell_kind.sequential_and_commands_separator()
1768    //     )?;
1769    // }
1770
1771    if let Some(input_program) = input_program {
1772        write!(
1773            exec,
1774            "{}",
1775            shell_kind
1776                .try_quote_prefix_aware(&shell_kind.prepend_command_prefix(&input_program))
1777                .context("shell quoting")?
1778        )?;
1779        for arg in input_args {
1780            let arg = shell_kind.try_quote(arg).context("shell quoting")?;
1781            write!(exec, " {}", &arg)?;
1782        }
1783    } else {
1784        // Launch an interactive shell session
1785        write!(exec, "{ssh_shell}")?;
1786    };
1787
1788    let mut args = Vec::new();
1789    args.extend(ssh_options);
1790
1791    if let Some((local_port, host, remote_port)) = port_forward {
1792        args.push("-L".into());
1793        args.push(format!(
1794            "{}:{}:{}",
1795            local_port,
1796            bracket_ipv6(&host),
1797            remote_port
1798        ));
1799    }
1800
1801    // -q suppresses the "Connection to ... closed." message that SSH prints when
1802    // the connection terminates with -t (pseudo-terminal allocation)
1803    args.push("-q".into());
1804    match interactive {
1805        // -t forces pseudo-TTY allocation (for interactive use)
1806        Interactive::Yes => args.push("-t".into()),
1807        // -T disables pseudo-TTY allocation (for non-interactive piped stdio)
1808        Interactive::No => args.push("-T".into()),
1809    }
1810
1811    // The destination must come after all options but before the command
1812    args.push(ssh_destination.into());
1813
1814    // Windows OpenSSH server incorrectly escapes the command string when the PTY is used.
1815    // The simplest way to work around this is to use a base64 encoded command, which doesn't require escaping.
1816    let utf16_bytes: Vec<u16> = exec.encode_utf16().collect();
1817    let byte_slice: Vec<u8> = utf16_bytes.iter().flat_map(|&u| u.to_le_bytes()).collect();
1818    let base64_encoded = base64::engine::general_purpose::STANDARD.encode(&byte_slice);
1819
1820    args.push(format!("powershell.exe -E {}", base64_encoded));
1821
1822    Ok(CommandTemplate {
1823        program: "ssh".into(),
1824        args,
1825        env: ssh_env,
1826    })
1827}
1828
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_command() -> Result<()> {
        let mut input_env = HashMap::default();
        input_env.insert("INPUT_VA".to_string(), "val".to_string());
        let mut env = HashMap::default();
        env.insert("SSH_VAR".to_string(), "ssh-val".to_string());

        // Non-interactive command: `Interactive::No` should use -T.
        let command = build_command_posix(
            Some("remote_program".to_string()),
            &["arg1".to_string(), "arg2".to_string()],
            &input_env,
            Some("~/work".to_string()),
            None,
            env.clone(),
            PathStyle::Posix,
            "/bin/bash",
            ShellKind::Posix,
            vec!["-o".to_string(), "ControlMaster=auto".to_string()],
            "user@host",
            Interactive::No,
        )?;
        assert_eq!(command.program, "ssh");
        // Should contain -T for non-interactive
        assert!(command.args.iter().any(|arg| arg == "-T"));
        assert!(!command.args.iter().any(|arg| arg == "-t"));

        // Interactive command: `Interactive::Yes` should use -t.
        let command = build_command_posix(
            Some("remote_program".to_string()),
            &["arg1".to_string(), "arg2".to_string()],
            &input_env,
            Some("~/work".to_string()),
            None,
            env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            vec!["-p".to_string(), "2222".to_string()],
            "user@host",
            Interactive::Yes,
        )?;

        assert_eq!(command.program, "ssh");
        assert_eq!(
            command.args.iter().map(String::as_str).collect::<Vec<_>>(),
            [
                "-p",
                "2222",
                "-q",
                "-t",
                "user@host",
                "cd \"$HOME/work\" && exec env INPUT_VA=val remote_program arg1 arg2"
            ]
        );
        assert_eq!(command.env, env);

        // No program and a port forward: launches a login shell and emits -L.
        let command = build_command_posix(
            None,
            &[],
            &input_env,
            None,
            Some((1, "foo".to_owned(), 2)),
            env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ShellKind::Fish,
            vec!["-p".to_string(), "2222".to_string()],
            "user@host",
            Interactive::Yes,
        )?;

        assert_eq!(command.program, "ssh");
        assert_eq!(
            command.args.iter().map(String::as_str).collect::<Vec<_>>(),
            [
                "-p",
                "2222",
                "-L",
                "1:foo:2",
                "-q",
                "-t",
                "user@host",
                "cd && exec env INPUT_VA=val /bin/fish -l"
            ]
        );
        assert_eq!(command.env, env);

        Ok(())
    }

    #[test]
    fn test_build_command_windows_interactive_shell() -> Result<()> {
        use base64::Engine as _;

        let mut input_env = HashMap::default();
        input_env.insert("INPUT_VA".to_string(), "val".to_string());
        let mut env = HashMap::default();
        env.insert("SSH_VAR".to_string(), "ssh-val".to_string());

        let command = build_command_windows(
            None,
            &[],
            &input_env,
            None,
            Some((8080, "::1".to_owned(), 80)),
            env.clone(),
            PathStyle::Windows,
            "powershell.exe",
            ShellKind::PowerShell,
            vec!["-p".to_string(), "2222".to_string()],
            "user@host",
            Interactive::No,
        )?;

        assert_eq!(command.program, "ssh");
        assert_eq!(command.env, env);

        let (remote_command, ssh_args) = command
            .args
            .split_last()
            .expect("the remote command is always the last ssh argument");
        assert_eq!(
            ssh_args.iter().map(String::as_str).collect::<Vec<_>>(),
            ["-p", "2222", "-L", "8080:[::1]:80", "-q", "-T", "user@host"]
        );

        // The remote command is `powershell.exe -E <base64 of the UTF-16LE script>`.
        let encoded = remote_command
            .strip_prefix("powershell.exe -E ")
            .expect("remote command should be an encoded PowerShell invocation");
        let bytes = base64::engine::general_purpose::STANDARD.decode(encoded)?;
        assert_eq!(bytes.len() % 2, 0, "UTF-16LE payload must have even length");
        let units: Vec<u16> = bytes
            .chunks_exact(2)
            .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
            .collect();

        // Without a program the remote shell itself is launched, and `input_env`
        // is intentionally not forwarded on Windows.
        assert_eq!(String::from_utf16(&units)?, "powershell.exe");

        Ok(())
    }

    #[test]
    fn scp_args_exclude_port_forward_flags() {
        let options = SshConnectionOptions {
            host: "example.com".into(),
            args: Some(vec![
                "-p".to_string(),
                "2222".to_string(),
                "-o".to_string(),
                "StrictHostKeyChecking=no".to_string(),
            ]),
            port_forwards: Some(vec![SshPortForwardOption {
                local_host: Some("127.0.0.1".to_string()),
                local_port: 8080,
                remote_host: Some("127.0.0.1".to_string()),
                remote_port: 80,
            }]),
            ..Default::default()
        };

        let ssh_args = options.additional_args();
        assert!(
            ssh_args.iter().any(|arg| arg.starts_with("-L")),
            "expected ssh args to include port-forward: {ssh_args:?}"
        );

        let scp_args = options.additional_args_for_scp();
        assert_eq!(
            scp_args,
            vec![
                "-p".to_string(),
                "2222".to_string(),
                "-o".to_string(),
                "StrictHostKeyChecking=no".to_string(),
            ]
        );
    }

    #[test]
    fn test_host_parsing() -> Result<()> {
        let opts = SshConnectionOptions::parse_command_line("user@2001:db8::1")?;
        assert_eq!(opts.host, "2001:db8::1".into());
        assert_eq!(opts.username, Some("user".to_string()));
        assert_eq!(opts.port, None);

        let opts = SshConnectionOptions::parse_command_line("user@[2001:db8::1]:2222")?;
        assert_eq!(opts.host, "2001:db8::1".into());
        assert_eq!(opts.username, Some("user".to_string()));
        assert_eq!(opts.port, Some(2222));

        let opts = SshConnectionOptions::parse_command_line("user@[2001:db8::1]")?;
        assert_eq!(opts.host, "2001:db8::1".into());
        assert_eq!(opts.username, Some("user".to_string()));
        assert_eq!(opts.port, None);

        let opts = SshConnectionOptions::parse_command_line("2001:db8::1")?;
        assert_eq!(opts.host, "2001:db8::1".into());
        assert_eq!(opts.username, None);
        assert_eq!(opts.port, None);

        let opts = SshConnectionOptions::parse_command_line("[2001:db8::1]:2222")?;
        assert_eq!(opts.host, "2001:db8::1".into());
        assert_eq!(opts.username, None);
        assert_eq!(opts.port, Some(2222));

        let opts = SshConnectionOptions::parse_command_line("user@example.com:2222")?;
        assert_eq!(opts.host, "example.com".into());
        assert_eq!(opts.username, Some("user".to_string()));
        assert_eq!(opts.port, Some(2222));

        let opts = SshConnectionOptions::parse_command_line("user@192.168.1.1:2222")?;
        assert_eq!(opts.host, "192.168.1.1".into());
        assert_eq!(opts.username, Some("user".to_string()));
        assert_eq!(opts.port, Some(2222));

        Ok(())
    }

    #[test]
    fn test_parse_port_forward_spec_ipv6() -> Result<()> {
        let pf = parse_port_forward_spec("[::1]:8080:[::1]:80")?;
        assert_eq!(pf.local_host, Some("::1".to_string()));
        assert_eq!(pf.local_port, 8080);
        assert_eq!(pf.remote_host, Some("::1".to_string()));
        assert_eq!(pf.remote_port, 80);

        let pf = parse_port_forward_spec("8080:[::1]:80")?;
        assert_eq!(pf.local_host, None);
        assert_eq!(pf.local_port, 8080);
        assert_eq!(pf.remote_host, Some("::1".to_string()));
        assert_eq!(pf.remote_port, 80);

        let pf = parse_port_forward_spec("[2001:db8::1]:3000:[fe80::1]:4000")?;
        assert_eq!(pf.local_host, Some("2001:db8::1".to_string()));
        assert_eq!(pf.local_port, 3000);
        assert_eq!(pf.remote_host, Some("fe80::1".to_string()));
        assert_eq!(pf.remote_port, 4000);

        let pf = parse_port_forward_spec("127.0.0.1:8080:localhost:80")?;
        assert_eq!(pf.local_host, Some("127.0.0.1".to_string()));
        assert_eq!(pf.local_port, 8080);
        assert_eq!(pf.remote_host, Some("localhost".to_string()));
        assert_eq!(pf.remote_port, 80);

        Ok(())
    }

    #[test]
    fn test_port_forward_ipv6_formatting() {
        let options = SshConnectionOptions {
            host: "example.com".into(),
            port_forwards: Some(vec![SshPortForwardOption {
                local_host: Some("::1".to_string()),
                local_port: 8080,
                remote_host: Some("::1".to_string()),
                remote_port: 80,
            }]),
            ..Default::default()
        };

        let args = options.additional_args();
        assert!(
            args.iter().any(|arg| arg == "-L[::1]:8080:[::1]:80"),
            "expected bracketed IPv6 in -L flag: {args:?}"
        );
    }

    #[test]
    fn test_build_command_with_ipv6_port_forward() -> Result<()> {
        let command = build_command_posix(
            None,
            &[],
            &HashMap::default(),
            None,
            Some((8080, "::1".to_owned(), 80)),
            HashMap::default(),
            PathStyle::Posix,
            "/bin/bash",
            ShellKind::Posix,
            vec![],
            "user@host",
            Interactive::No,
        )?;

        assert!(
            command.args.iter().any(|arg| arg == "8080:[::1]:80"),
            "expected bracketed IPv6 in port forward arg: {:?}",
            command.args
        );

        Ok(())
    }
}