ssh.rs

   1use crate::{
   2    RemoteClientDelegate, RemotePlatform,
   3    remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
   4};
   5use anyhow::{Context as _, Result, anyhow};
   6use async_trait::async_trait;
   7use collections::HashMap;
   8use futures::{
   9    AsyncReadExt as _, FutureExt as _,
  10    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  11    select_biased,
  12};
  13use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task};
  14use itertools::Itertools;
  15use parking_lot::Mutex;
  16use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  17use rpc::proto::Envelope;
  18use schemars::JsonSchema;
  19use serde::{Deserialize, Serialize};
  20use smol::{
  21    fs,
  22    process::{self, Child, Stdio},
  23};
  24use std::{
  25    iter,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28    time::Instant,
  29};
  30use tempfile::TempDir;
  31use util::paths::{PathStyle, RemotePathBuf};
  32
/// An SSH-backed [`RemoteConnection`].
///
/// Owns the long-lived master `ssh -N` process that keeps the authenticated
/// connection open, plus facts about the remote host that are probed once in
/// [`SshRemoteConnection::new`].
pub(crate) struct SshRemoteConnection {
    /// How to reach the host and run further `ssh`/`scp` commands against it.
    socket: SshSocket,
    /// The master `ssh -N` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Location of the remote server binary; filled in at the end of `new`.
    remote_binary_path: Option<RemotePathBuf>,
    /// OS/architecture detected from `uname -sm` on the host.
    ssh_platform: RemotePlatform,
    /// Path style derived from `ssh_platform.os`.
    ssh_path_style: PathStyle,
    /// The remote `$SHELL` (or `sh` if it could not be queried).
    ssh_shell: String,
    /// Always `/bin/sh` for SSH connections.
    ssh_default_system_shell: String,
    /// Session temp dir holding the control socket (and, on Windows, the askpass
    /// script). Declared last: fields drop in declaration order, so it outlives
    /// the other fields.
    _temp_dir: TempDir,
}
  43
  44#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  45pub struct SshConnectionOptions {
  46    pub host: String,
  47    pub username: Option<String>,
  48    pub port: Option<u16>,
  49    pub password: Option<String>,
  50    pub args: Option<Vec<String>>,
  51    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  52
  53    pub nickname: Option<String>,
  54    pub upload_binary_over_ssh: bool,
  55}
  56
/// A single local port forward in the shape of an `ssh -L` spec:
/// `[local_host:]local_port:remote_host:remote_port`.
///
/// Unset hosts are omitted when serialized.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind address; `None` when the spec did not include one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port of the forward.
    pub local_port: u16,
    /// Destination host of the forward.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Destination port of the forward.
    pub remote_port: u16,
}
  66
/// Everything needed to spawn further `ssh`/`scp` commands against a host
/// whose master connection is already established.
#[derive(Clone)]
struct SshSocket {
    /// How to reach the host.
    connection_options: SshConnectionOptions,
    /// Control socket of the master connection, passed as `ControlPath`.
    /// Absent on Windows, whose OpenSSH does not support `ControlMaster`/`ControlPath`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables for spawned commands; populated only on Windows
    /// (askpass setup) and empty elsewhere.
    envs: HashMap<String, String>,
}
  74
  75macro_rules! shell_script {
  76    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
  77        format!(
  78            $fmt,
  79            $(
  80                $name = shlex::try_quote($arg).unwrap()
  81            ),+
  82        )
  83    }};
  84}
  85
  86#[async_trait(?Send)]
  87impl RemoteConnection for SshRemoteConnection {
  88    async fn kill(&self) -> Result<()> {
  89        let Some(mut process) = self.master_process.lock().take() else {
  90            return Ok(());
  91        };
  92        process.kill().ok();
  93        process.status().await?;
  94        Ok(())
  95    }
  96
  97    fn has_been_killed(&self) -> bool {
  98        self.master_process.lock().is_none()
  99    }
 100
 101    fn connection_options(&self) -> RemoteConnectionOptions {
 102        RemoteConnectionOptions::Ssh(self.socket.connection_options.clone())
 103    }
 104
 105    fn shell(&self) -> String {
 106        self.ssh_shell.clone()
 107    }
 108
 109    fn default_system_shell(&self) -> String {
 110        self.ssh_default_system_shell.clone()
 111    }
 112
 113    fn build_command(
 114        &self,
 115        input_program: Option<String>,
 116        input_args: &[String],
 117        input_env: &HashMap<String, String>,
 118        working_dir: Option<String>,
 119        port_forward: Option<(u16, String, u16)>,
 120    ) -> Result<CommandTemplate> {
 121        let Self {
 122            ssh_path_style,
 123            socket,
 124            ssh_shell,
 125            ..
 126        } = self;
 127        let env = socket.envs.clone();
 128        build_command(
 129            input_program,
 130            input_args,
 131            input_env,
 132            working_dir,
 133            port_forward,
 134            env,
 135            *ssh_path_style,
 136            ssh_shell,
 137            socket.ssh_args(),
 138        )
 139    }
 140
 141    fn upload_directory(
 142        &self,
 143        src_path: PathBuf,
 144        dest_path: RemotePathBuf,
 145        cx: &App,
 146    ) -> Task<Result<()>> {
 147        let mut command = util::command::new_smol_command("scp");
 148        let output = self
 149            .socket
 150            .ssh_options(&mut command)
 151            .args(
 152                self.socket
 153                    .connection_options
 154                    .port
 155                    .map(|port| vec!["-P".to_string(), port.to_string()])
 156                    .unwrap_or_default(),
 157            )
 158            .arg("-C")
 159            .arg("-r")
 160            .arg(&src_path)
 161            .arg(format!(
 162                "{}:{}",
 163                self.socket.connection_options.scp_url(),
 164                dest_path
 165            ))
 166            .output();
 167
 168        cx.background_spawn(async move {
 169            let output = output.await?;
 170
 171            anyhow::ensure!(
 172                output.status.success(),
 173                "failed to upload directory {} -> {}: {}",
 174                src_path.display(),
 175                dest_path.to_string(),
 176                String::from_utf8_lossy(&output.stderr)
 177            );
 178
 179            Ok(())
 180        })
 181    }
 182
 183    fn start_proxy(
 184        &self,
 185        unique_identifier: String,
 186        reconnect: bool,
 187        incoming_tx: UnboundedSender<Envelope>,
 188        outgoing_rx: UnboundedReceiver<Envelope>,
 189        connection_activity_tx: Sender<()>,
 190        delegate: Arc<dyn RemoteClientDelegate>,
 191        cx: &mut AsyncApp,
 192    ) -> Task<Result<i32>> {
 193        delegate.set_status(Some("Starting proxy"), cx);
 194
 195        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 196            return Task::ready(Err(anyhow!("Remote binary path not set")));
 197        };
 198
 199        let mut start_proxy_command = shell_script!(
 200            "exec {binary_path} proxy --identifier {identifier}",
 201            binary_path = &remote_binary_path.to_string(),
 202            identifier = &unique_identifier,
 203        );
 204
 205        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
 206            if let Some(value) = std::env::var(env_var).ok() {
 207                start_proxy_command = format!(
 208                    "{}={} {} ",
 209                    env_var,
 210                    shlex::try_quote(&value).unwrap(),
 211                    start_proxy_command,
 212                );
 213            }
 214        }
 215
 216        if reconnect {
 217            start_proxy_command.push_str(" --reconnect");
 218        }
 219
 220        let ssh_proxy_process = match self
 221            .socket
 222            .ssh_command("sh", &["-c", &start_proxy_command])
 223            // IMPORTANT: we kill this process when we drop the task that uses it.
 224            .kill_on_drop(true)
 225            .spawn()
 226        {
 227            Ok(process) => process,
 228            Err(error) => {
 229                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 230            }
 231        };
 232
 233        super::handle_rpc_messages_over_child_process_stdio(
 234            ssh_proxy_process,
 235            incoming_tx,
 236            outgoing_rx,
 237            connection_activity_tx,
 238            cx,
 239        )
 240    }
 241
 242    fn path_style(&self) -> PathStyle {
 243        self.ssh_path_style
 244    }
 245}
 246
 247impl SshRemoteConnection {
 248    pub(crate) async fn new(
 249        connection_options: SshConnectionOptions,
 250        delegate: Arc<dyn RemoteClientDelegate>,
 251        cx: &mut AsyncApp,
 252    ) -> Result<Self> {
 253        use askpass::AskPassResult;
 254
 255        delegate.set_status(Some("Connecting"), cx);
 256
 257        let url = connection_options.ssh_url();
 258
 259        let temp_dir = tempfile::Builder::new()
 260            .prefix("zed-ssh-session")
 261            .tempdir()?;
 262        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
 263            let delegate = delegate.clone();
 264            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
 265        });
 266
 267        let mut askpass =
 268            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
 269
 270        // Start the master SSH process, which does not do anything except for establish
 271        // the connection and keep it open, allowing other ssh commands to reuse it
 272        // via a control socket.
 273        #[cfg(not(target_os = "windows"))]
 274        let socket_path = temp_dir.path().join("ssh.sock");
 275
 276        let mut master_process = {
 277            #[cfg(not(target_os = "windows"))]
 278            let args = [
 279                "-N",
 280                "-o",
 281                "ControlPersist=no",
 282                "-o",
 283                "ControlMaster=yes",
 284                "-o",
 285            ];
 286            // On Windows, `ControlMaster` and `ControlPath` are not supported:
 287            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
 288            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
 289            #[cfg(target_os = "windows")]
 290            let args = ["-N"];
 291            let mut master_process = util::command::new_smol_command("ssh");
 292            master_process
 293                .kill_on_drop(true)
 294                .stdin(Stdio::null())
 295                .stdout(Stdio::piped())
 296                .stderr(Stdio::piped())
 297                .env("SSH_ASKPASS_REQUIRE", "force")
 298                .env("SSH_ASKPASS", askpass.script_path())
 299                .args(connection_options.additional_args())
 300                .args(args);
 301            #[cfg(not(target_os = "windows"))]
 302            master_process.arg(format!("ControlPath={}", socket_path.display()));
 303            master_process.arg(&url).spawn()?
 304        };
 305        // Wait for this ssh process to close its stdout, indicating that authentication
 306        // has completed.
 307        let mut stdout = master_process.stdout.take().unwrap();
 308        let mut output = Vec::new();
 309
 310        let result = select_biased! {
 311            result = askpass.run().fuse() => {
 312                match result {
 313                    AskPassResult::CancelledByUser => {
 314                        master_process.kill().ok();
 315                        anyhow::bail!("SSH connection canceled")
 316                    }
 317                    AskPassResult::Timedout => {
 318                        anyhow::bail!("connecting to host timed out")
 319                    }
 320                }
 321            }
 322            _ = stdout.read_to_end(&mut output).fuse() => {
 323                anyhow::Ok(())
 324            }
 325        };
 326
 327        if let Err(e) = result {
 328            return Err(e.context("Failed to connect to host"));
 329        }
 330
 331        if master_process.try_status()?.is_some() {
 332            output.clear();
 333            let mut stderr = master_process.stderr.take().unwrap();
 334            stderr.read_to_end(&mut output).await?;
 335
 336            let error_message = format!(
 337                "failed to connect: {}",
 338                String::from_utf8_lossy(&output).trim()
 339            );
 340            anyhow::bail!(error_message);
 341        }
 342
 343        #[cfg(not(target_os = "windows"))]
 344        let socket = SshSocket::new(connection_options, socket_path)?;
 345        #[cfg(target_os = "windows")]
 346        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
 347        drop(askpass);
 348
 349        let ssh_platform = socket.platform().await?;
 350        let ssh_path_style = match ssh_platform.os {
 351            "windows" => PathStyle::Windows,
 352            _ => PathStyle::Posix,
 353        };
 354        let ssh_shell = socket.shell().await;
 355        let ssh_default_system_shell = String::from("/bin/sh");
 356
 357        let mut this = Self {
 358            socket,
 359            master_process: Mutex::new(Some(master_process)),
 360            _temp_dir: temp_dir,
 361            remote_binary_path: None,
 362            ssh_path_style,
 363            ssh_platform,
 364            ssh_shell,
 365            ssh_default_system_shell,
 366        };
 367
 368        let (release_channel, version, commit) = cx.update(|cx| {
 369            (
 370                ReleaseChannel::global(cx),
 371                AppVersion::global(cx),
 372                AppCommitSha::try_global(cx),
 373            )
 374        })?;
 375        this.remote_binary_path = Some(
 376            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
 377                .await?,
 378        );
 379
 380        Ok(this)
 381    }
 382
 383    async fn ensure_server_binary(
 384        &self,
 385        delegate: &Arc<dyn RemoteClientDelegate>,
 386        release_channel: ReleaseChannel,
 387        version: SemanticVersion,
 388        commit: Option<AppCommitSha>,
 389        cx: &mut AsyncApp,
 390    ) -> Result<RemotePathBuf> {
 391        let version_str = match release_channel {
 392            ReleaseChannel::Nightly => {
 393                let commit = commit.map(|s| s.full()).unwrap_or_default();
 394                format!("{}-{}", version, commit)
 395            }
 396            ReleaseChannel::Dev => "build".to_string(),
 397            _ => version.to_string(),
 398        };
 399        let binary_name = format!(
 400            "zed-remote-server-{}-{}",
 401            release_channel.dev_name(),
 402            version_str
 403        );
 404        let dst_path = RemotePathBuf::new(
 405            paths::remote_server_dir_relative().join(binary_name),
 406            self.ssh_path_style,
 407        );
 408
 409        #[cfg(debug_assertions)]
 410        if let Some(remote_server_path) =
 411            super::build_remote_server_from_source(&self.ssh_platform, delegate.as_ref(), cx)
 412                .await?
 413        {
 414            let tmp_path = RemotePathBuf::new(
 415                paths::remote_server_dir_relative().join(format!(
 416                    "download-{}-{}",
 417                    std::process::id(),
 418                    remote_server_path.file_name().unwrap().to_string_lossy()
 419                )),
 420                self.ssh_path_style,
 421            );
 422            self.upload_local_server_binary(&remote_server_path, &tmp_path, delegate, cx)
 423                .await?;
 424            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
 425                .await?;
 426            return Ok(dst_path);
 427        }
 428
 429        if self
 430            .socket
 431            .run_command(&dst_path.to_string(), &["version"])
 432            .await
 433            .is_ok()
 434        {
 435            return Ok(dst_path);
 436        }
 437
 438        let wanted_version = cx.update(|cx| match release_channel {
 439            ReleaseChannel::Nightly => Ok(None),
 440            ReleaseChannel::Dev => {
 441                anyhow::bail!(
 442                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
 443                    dst_path
 444                )
 445            }
 446            _ => Ok(Some(AppVersion::global(cx))),
 447        })??;
 448
 449        let tmp_path_gz = RemotePathBuf::new(
 450            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
 451            self.ssh_path_style,
 452        );
 453        if !self.socket.connection_options.upload_binary_over_ssh
 454            && let Some((url, body)) = delegate
 455                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
 456                .await?
 457        {
 458            match self
 459                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
 460                .await
 461            {
 462                Ok(_) => {
 463                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 464                        .await?;
 465                    return Ok(dst_path);
 466                }
 467                Err(e) => {
 468                    log::error!(
 469                        "Failed to download binary on server, attempting to upload server: {}",
 470                        e
 471                    )
 472                }
 473            }
 474        }
 475
 476        let src_path = delegate
 477            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
 478            .await?;
 479        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
 480            .await?;
 481        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 482            .await?;
 483        Ok(dst_path)
 484    }
 485
 486    async fn download_binary_on_server(
 487        &self,
 488        url: &str,
 489        body: &str,
 490        tmp_path_gz: &RemotePathBuf,
 491        delegate: &Arc<dyn RemoteClientDelegate>,
 492        cx: &mut AsyncApp,
 493    ) -> Result<()> {
 494        if let Some(parent) = tmp_path_gz.parent() {
 495            self.socket
 496                .run_command(
 497                    "sh",
 498                    &[
 499                        "-c",
 500                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
 501                    ],
 502                )
 503                .await?;
 504        }
 505
 506        delegate.set_status(Some("Downloading remote development server on host"), cx);
 507
 508        match self
 509            .socket
 510            .run_command(
 511                "curl",
 512                &[
 513                    "-f",
 514                    "-L",
 515                    "-X",
 516                    "GET",
 517                    "-H",
 518                    "Content-Type: application/json",
 519                    "-d",
 520                    body,
 521                    url,
 522                    "-o",
 523                    &tmp_path_gz.to_string(),
 524                ],
 525            )
 526            .await
 527        {
 528            Ok(_) => {}
 529            Err(e) => {
 530                if self.socket.run_command("which", &["curl"]).await.is_ok() {
 531                    return Err(e);
 532                }
 533
 534                match self
 535                    .socket
 536                    .run_command(
 537                        "wget",
 538                        &[
 539                            "--method=GET",
 540                            "--header=Content-Type: application/json",
 541                            "--body-data",
 542                            body,
 543                            url,
 544                            "-O",
 545                            &tmp_path_gz.to_string(),
 546                        ],
 547                    )
 548                    .await
 549                {
 550                    Ok(_) => {}
 551                    Err(e) => {
 552                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
 553                            return Err(e);
 554                        } else {
 555                            anyhow::bail!("Neither curl nor wget is available");
 556                        }
 557                    }
 558                }
 559            }
 560        }
 561
 562        Ok(())
 563    }
 564
 565    async fn upload_local_server_binary(
 566        &self,
 567        src_path: &Path,
 568        tmp_path_gz: &RemotePathBuf,
 569        delegate: &Arc<dyn RemoteClientDelegate>,
 570        cx: &mut AsyncApp,
 571    ) -> Result<()> {
 572        if let Some(parent) = tmp_path_gz.parent() {
 573            self.socket
 574                .run_command(
 575                    "sh",
 576                    &[
 577                        "-c",
 578                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
 579                    ],
 580                )
 581                .await?;
 582        }
 583
 584        let src_stat = fs::metadata(&src_path).await?;
 585        let size = src_stat.len();
 586
 587        let t0 = Instant::now();
 588        delegate.set_status(Some("Uploading remote development server"), cx);
 589        log::info!(
 590            "uploading remote development server to {:?} ({}kb)",
 591            tmp_path_gz,
 592            size / 1024
 593        );
 594        self.upload_file(src_path, tmp_path_gz)
 595            .await
 596            .context("failed to upload server binary")?;
 597        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 598        Ok(())
 599    }
 600
 601    async fn extract_server_binary(
 602        &self,
 603        dst_path: &RemotePathBuf,
 604        tmp_path: &RemotePathBuf,
 605        delegate: &Arc<dyn RemoteClientDelegate>,
 606        cx: &mut AsyncApp,
 607    ) -> Result<()> {
 608        delegate.set_status(Some("Extracting remote development server"), cx);
 609        let server_mode = 0o755;
 610
 611        let orig_tmp_path = tmp_path.to_string();
 612        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 613            shell_script!(
 614                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 615                server_mode = &format!("{:o}", server_mode),
 616                dst_path = &dst_path.to_string(),
 617            )
 618        } else {
 619            shell_script!(
 620                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
 621                server_mode = &format!("{:o}", server_mode),
 622                dst_path = &dst_path.to_string()
 623            )
 624        };
 625        self.socket.run_command("sh", &["-c", &script]).await?;
 626        Ok(())
 627    }
 628
 629    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
 630        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
 631        let mut command = util::command::new_smol_command("scp");
 632        let output = self
 633            .socket
 634            .ssh_options(&mut command)
 635            .args(
 636                self.socket
 637                    .connection_options
 638                    .port
 639                    .map(|port| vec!["-P".to_string(), port.to_string()])
 640                    .unwrap_or_default(),
 641            )
 642            .arg(src_path)
 643            .arg(format!(
 644                "{}:{}",
 645                self.socket.connection_options.scp_url(),
 646                dest_path
 647            ))
 648            .output()
 649            .await?;
 650
 651        anyhow::ensure!(
 652            output.status.success(),
 653            "failed to upload file {} -> {}: {}",
 654            src_path.display(),
 655            dest_path.to_string(),
 656            String::from_utf8_lossy(&output.stderr)
 657        );
 658        Ok(())
 659    }
 660}
 661
 662impl SshSocket {
 663    #[cfg(not(target_os = "windows"))]
 664    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 665        Ok(Self {
 666            connection_options: options,
 667            envs: HashMap::default(),
 668            socket_path,
 669        })
 670    }
 671
 672    #[cfg(target_os = "windows")]
 673    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 674        let askpass_script = temp_dir.path().join("askpass.bat");
 675        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 676        let mut envs = HashMap::default();
 677        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 678        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 679        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 680        Ok(Self {
 681            connection_options: options,
 682            envs,
 683        })
 684    }
 685
 686    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 687    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 688    // and passes -l as an argument to sh, not to ls.
 689    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 690    // into a machine. You must use `cd` to get back to $HOME.
 691    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 692    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 693        let mut command = util::command::new_smol_command("ssh");
 694        let to_run = iter::once(&program)
 695            .chain(args.iter())
 696            .map(|token| {
 697                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 698                debug_assert!(
 699                    !token.contains('\n'),
 700                    "multiline arguments do not work in all shells"
 701                );
 702                shlex::try_quote(token).unwrap()
 703            })
 704            .join(" ");
 705        let to_run = format!("cd; {to_run}");
 706        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 707        self.ssh_options(&mut command)
 708            .arg(self.connection_options.ssh_url())
 709            .arg(to_run);
 710        command
 711    }
 712
 713    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 714        let output = self.ssh_command(program, args).output().await?;
 715        anyhow::ensure!(
 716            output.status.success(),
 717            "failed to run command: {}",
 718            String::from_utf8_lossy(&output.stderr)
 719        );
 720        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 721    }
 722
 723    #[cfg(not(target_os = "windows"))]
 724    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 725        command
 726            .stdin(Stdio::piped())
 727            .stdout(Stdio::piped())
 728            .stderr(Stdio::piped())
 729            .args(self.connection_options.additional_args())
 730            .args(["-o", "ControlMaster=no", "-o"])
 731            .arg(format!("ControlPath={}", self.socket_path.display()))
 732    }
 733
 734    #[cfg(target_os = "windows")]
 735    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 736        command
 737            .stdin(Stdio::piped())
 738            .stdout(Stdio::piped())
 739            .stderr(Stdio::piped())
 740            .args(self.connection_options.additional_args())
 741            .envs(self.envs.clone())
 742    }
 743
 744    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 745    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 746    #[cfg(not(target_os = "windows"))]
 747    fn ssh_args(&self) -> Vec<String> {
 748        let mut arguments = self.connection_options.additional_args();
 749        arguments.extend(vec![
 750            "-o".to_string(),
 751            "ControlMaster=no".to_string(),
 752            "-o".to_string(),
 753            format!("ControlPath={}", self.socket_path.display()),
 754            self.connection_options.ssh_url(),
 755        ]);
 756        arguments
 757    }
 758
 759    #[cfg(target_os = "windows")]
 760    fn ssh_args(&self) -> Vec<String> {
 761        let mut arguments = self.connection_options.additional_args();
 762        arguments.push(self.connection_options.ssh_url());
 763        arguments
 764    }
 765
 766    async fn platform(&self) -> Result<RemotePlatform> {
 767        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 768        let Some((os, arch)) = uname.split_once(" ") else {
 769            anyhow::bail!("unknown uname: {uname:?}")
 770        };
 771
 772        let os = match os.trim() {
 773            "Darwin" => "macos",
 774            "Linux" => "linux",
 775            _ => anyhow::bail!(
 776                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 777            ),
 778        };
 779        // exclude armv5,6,7 as they are 32-bit.
 780        let arch = if arch.starts_with("armv8")
 781            || arch.starts_with("armv9")
 782            || arch.starts_with("arm64")
 783            || arch.starts_with("aarch64")
 784        {
 785            "aarch64"
 786        } else if arch.starts_with("x86") {
 787            "x86_64"
 788        } else {
 789            anyhow::bail!(
 790                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 791            )
 792        };
 793
 794        Ok(RemotePlatform { os, arch })
 795    }
 796
 797    async fn shell(&self) -> String {
 798        match self.run_command("sh", &["-c", "echo $SHELL"]).await {
 799            Ok(shell) => shell.trim().to_owned(),
 800            Err(e) => {
 801                log::error!("Failed to get shell: {e}");
 802                "sh".to_owned()
 803            }
 804        }
 805    }
 806}
 807
 808fn parse_port_number(port_str: &str) -> Result<u16> {
 809    port_str
 810        .parse()
 811        .with_context(|| format!("parsing port number: {port_str}"))
 812}
 813
 814fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 815    let parts: Vec<&str> = spec.split(':').collect();
 816
 817    match parts.len() {
 818        4 => {
 819            let local_port = parse_port_number(parts[1])?;
 820            let remote_port = parse_port_number(parts[3])?;
 821
 822            Ok(SshPortForwardOption {
 823                local_host: Some(parts[0].to_string()),
 824                local_port,
 825                remote_host: Some(parts[2].to_string()),
 826                remote_port,
 827            })
 828        }
 829        3 => {
 830            let local_port = parse_port_number(parts[0])?;
 831            let remote_port = parse_port_number(parts[2])?;
 832
 833            Ok(SshPortForwardOption {
 834                local_host: None,
 835                local_port,
 836                remote_host: Some(parts[1].to_string()),
 837                remote_port,
 838            })
 839        }
 840        _ => anyhow::bail!("Invalid port forward format"),
 841    }
 842}
 843
 844impl SshConnectionOptions {
 845    pub fn parse_command_line(input: &str) -> Result<Self> {
 846        let input = input.trim_start_matches("ssh ");
 847        let mut hostname: Option<String> = None;
 848        let mut username: Option<String> = None;
 849        let mut port: Option<u16> = None;
 850        let mut args = Vec::new();
 851        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 852
 853        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 854        const ALLOWED_OPTS: &[&str] = &[
 855            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 856        ];
 857        const ALLOWED_ARGS: &[&str] = &[
 858            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 859            "-w",
 860        ];
 861
 862        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 863
 864        'outer: while let Some(arg) = tokens.next() {
 865            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 866                args.push(arg.to_string());
 867                continue;
 868            }
 869            if arg == "-p" {
 870                port = tokens.next().and_then(|arg| arg.parse().ok());
 871                continue;
 872            } else if let Some(p) = arg.strip_prefix("-p") {
 873                port = p.parse().ok();
 874                continue;
 875            }
 876            if arg == "-l" {
 877                username = tokens.next();
 878                continue;
 879            } else if let Some(l) = arg.strip_prefix("-l") {
 880                username = Some(l.to_string());
 881                continue;
 882            }
 883            if arg == "-L" || arg.starts_with("-L") {
 884                let forward_spec = if arg == "-L" {
 885                    tokens.next()
 886                } else {
 887                    Some(arg.strip_prefix("-L").unwrap().to_string())
 888                };
 889
 890                if let Some(spec) = forward_spec {
 891                    port_forwards.push(parse_port_forward_spec(&spec)?);
 892                } else {
 893                    anyhow::bail!("Missing port forward format");
 894                }
 895            }
 896
 897            for a in ALLOWED_ARGS {
 898                if arg == *a {
 899                    args.push(arg);
 900                    if let Some(next) = tokens.next() {
 901                        args.push(next);
 902                    }
 903                    continue 'outer;
 904                } else if arg.starts_with(a) {
 905                    args.push(arg);
 906                    continue 'outer;
 907                }
 908            }
 909            if arg.starts_with("-") || hostname.is_some() {
 910                anyhow::bail!("unsupported argument: {:?}", arg);
 911            }
 912            let mut input = &arg as &str;
 913            // Destination might be: username1@username2@ip2@ip1
 914            if let Some((u, rest)) = input.rsplit_once('@') {
 915                input = rest;
 916                username = Some(u.to_string());
 917            }
 918            if let Some((rest, p)) = input.split_once(':') {
 919                input = rest;
 920                port = p.parse().ok()
 921            }
 922            hostname = Some(input.to_string())
 923        }
 924
 925        let Some(hostname) = hostname else {
 926            anyhow::bail!("missing hostname");
 927        };
 928
 929        let port_forwards = match port_forwards.len() {
 930            0 => None,
 931            _ => Some(port_forwards),
 932        };
 933
 934        Ok(Self {
 935            host: hostname,
 936            username,
 937            port,
 938            port_forwards,
 939            args: Some(args),
 940            password: None,
 941            nickname: None,
 942            upload_binary_over_ssh: false,
 943        })
 944    }
 945
 946    pub fn ssh_url(&self) -> String {
 947        let mut result = String::from("ssh://");
 948        if let Some(username) = &self.username {
 949            // Username might be: username1@username2@ip2
 950            let username = urlencoding::encode(username);
 951            result.push_str(&username);
 952            result.push('@');
 953        }
 954        result.push_str(&self.host);
 955        if let Some(port) = self.port {
 956            result.push(':');
 957            result.push_str(&port.to_string());
 958        }
 959        result
 960    }
 961
 962    pub fn additional_args(&self) -> Vec<String> {
 963        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 964
 965        if let Some(forwards) = &self.port_forwards {
 966            args.extend(forwards.iter().map(|pf| {
 967                let local_host = match &pf.local_host {
 968                    Some(host) => host,
 969                    None => "localhost",
 970                };
 971                let remote_host = match &pf.remote_host {
 972                    Some(host) => host,
 973                    None => "localhost",
 974                };
 975
 976                format!(
 977                    "-L{}:{}:{}:{}",
 978                    local_host, pf.local_port, remote_host, pf.remote_port
 979                )
 980            }));
 981        }
 982
 983        args
 984    }
 985
 986    fn scp_url(&self) -> String {
 987        if let Some(username) = &self.username {
 988            format!("{}@{}", username, self.host)
 989        } else {
 990            self.host.clone()
 991        }
 992    }
 993
 994    pub fn connection_string(&self) -> String {
 995        let host = if let Some(username) = &self.username {
 996            format!("{}@{}", username, self.host)
 997        } else {
 998            self.host.clone()
 999        };
1000        if let Some(port) = &self.port {
1001            format!("{}:{}", host, port)
1002        } else {
1003            host
1004        }
1005    }
1006}
1007
1008fn build_command(
1009    input_program: Option<String>,
1010    input_args: &[String],
1011    input_env: &HashMap<String, String>,
1012    working_dir: Option<String>,
1013    port_forward: Option<(u16, String, u16)>,
1014    ssh_env: HashMap<String, String>,
1015    ssh_path_style: PathStyle,
1016    ssh_shell: &str,
1017    ssh_args: Vec<String>,
1018) -> Result<CommandTemplate> {
1019    use std::fmt::Write as _;
1020
1021    let mut exec = String::from("exec env -C ");
1022    if let Some(working_dir) = working_dir {
1023        let working_dir = RemotePathBuf::new(working_dir.into(), ssh_path_style).to_string();
1024
1025        // shlex will wrap the command in single quotes (''), disabling ~ expansion,
1026        // replace with with something that works
1027        const TILDE_PREFIX: &'static str = "~/";
1028        if working_dir.starts_with(TILDE_PREFIX) {
1029            let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
1030            write!(exec, "\"$HOME/{working_dir}\" ",).unwrap();
1031        } else {
1032            write!(exec, "\"{working_dir}\" ",).unwrap();
1033        }
1034    } else {
1035        write!(exec, "\"$HOME\" ").unwrap();
1036    };
1037
1038    for (k, v) in input_env.iter() {
1039        if let Some((k, v)) = shlex::try_quote(k).ok().zip(shlex::try_quote(v).ok()) {
1040            write!(exec, "{}={} ", k, v).unwrap();
1041        }
1042    }
1043
1044    write!(exec, "{ssh_shell} ").unwrap();
1045    if let Some(input_program) = input_program {
1046        let mut script = shlex::try_quote(&input_program)?.into_owned();
1047        for arg in input_args {
1048            let arg = shlex::try_quote(&arg)?;
1049            script.push_str(" ");
1050            script.push_str(&arg);
1051        }
1052        write!(exec, "-c {}", shlex::try_quote(&script).unwrap()).unwrap();
1053    } else {
1054        write!(exec, "-l").unwrap();
1055    };
1056
1057    let mut args = Vec::new();
1058    args.extend(ssh_args);
1059
1060    if let Some((local_port, host, remote_port)) = port_forward {
1061        args.push("-L".into());
1062        args.push(format!("{local_port}:{host}:{remote_port}"));
1063    }
1064
1065    args.push("-t".into());
1066    args.push(exec);
1067    Ok(CommandTemplate {
1068        program: "ssh".into(),
1069        args,
1070        env: ssh_env,
1071    })
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a map holding exactly one `key = value` entry.
    fn env_with(key: &str, value: &str) -> HashMap<String, String> {
        let mut map = HashMap::default();
        map.insert(key.to_string(), value.to_string());
        map
    }

    /// Borrows the generated `ssh` arguments as `&str` for comparison with literals.
    fn args_of(command: &CommandTemplate) -> Vec<&str> {
        command.args.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_build_command() -> Result<()> {
        let ssh_args = || vec!["-p".to_string(), "2222".to_string()];
        let program_args = ["arg1".to_string(), "arg2".to_string()];

        // Explicit program, `~/`-relative working directory, no port forward.
        let input_env = env_with("INPUT_VA", "val");
        let ssh_env = env_with("SSH_VAR", "ssh-val");
        let with_program = build_command(
            Some("remote_program".to_string()),
            &program_args,
            &input_env,
            Some("~/work".to_string()),
            None,
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ssh_args(),
        )?;

        assert_eq!(with_program.program, "ssh");
        assert_eq!(
            args_of(&with_program),
            [
                "-p",
                "2222",
                "-t",
                "exec env -C \"$HOME/work\" INPUT_VA=val /bin/fish -c 'remote_program arg1 arg2'"
            ]
        );
        assert_eq!(with_program.env, ssh_env);

        // No program (shell started with `-l`), default directory, one port forward.
        let input_env = env_with("INPUT_VA", "val");
        let ssh_env = env_with("SSH_VAR", "ssh-val");
        let login_shell = build_command(
            None,
            &program_args,
            &input_env,
            None,
            Some((1, "foo".to_owned(), 2)),
            ssh_env.clone(),
            PathStyle::Posix,
            "/bin/fish",
            ssh_args(),
        )?;

        assert_eq!(login_shell.program, "ssh");
        assert_eq!(
            args_of(&login_shell),
            [
                "-p",
                "2222",
                "-L",
                "1:foo:2",
                "-t",
                "exec env -C \"$HOME\" INPUT_VA=val /bin/fish -l"
            ]
        );
        assert_eq!(login_shell.env, ssh_env);

        Ok(())
    }
}