ssh.rs

   1use crate::{
   2    RemoteClientDelegate, RemotePlatform,
   3    json_log::LogRecord,
   4    protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
   5    remote_client::{CommandTemplate, RemoteConnection},
   6};
   7use anyhow::{Context as _, Result, anyhow};
   8use async_trait::async_trait;
   9use collections::HashMap;
  10use futures::{
  11    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  12    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  13    select_biased,
  14};
  15use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task};
  16use itertools::Itertools;
  17use parking_lot::Mutex;
  18use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  19use rpc::proto::Envelope;
  20use schemars::JsonSchema;
  21use serde::{Deserialize, Serialize};
  22use smol::{
  23    fs,
  24    process::{self, Child, Stdio},
  25};
  26use std::{
  27    iter,
  28    path::{Path, PathBuf},
  29    sync::Arc,
  30    time::Instant,
  31};
  32use tempfile::TempDir;
  33use util::{
  34    get_default_system_shell,
  35    paths::{PathStyle, RemotePathBuf},
  36};
  37
/// A [`RemoteConnection`] backed by the system `ssh` binary.
///
/// A long-lived "master" `ssh` process keeps the connection open; further commands on the
/// host are issued through [`SshSocket`].
pub(crate) struct SshRemoteConnection {
    /// Used to run commands on the host (`run_command`, `ssh_command`, `ssh_args`, ...).
    socket: SshSocket,
    /// The master `ssh -N` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Remote path of the server binary; filled in by `SshRemoteConnection::new` after
    /// `ensure_server_binary` succeeds. `start_proxy` fails while this is `None`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the host (from `SshSocket::platform`).
    ssh_platform: RemotePlatform,
    /// Path style derived from `ssh_platform.os` (`Windows` for "windows", else `Posix`).
    ssh_path_style: PathStyle,
    /// Remote shell (from `SshSocket::shell`); used to start interactive login shells.
    ssh_shell: String,
    /// Per-session temporary directory (on non-Windows platforms it holds the master's
    /// control socket). Stored only so it lives as long as the connection.
    ///
    /// Fields are dropped in declaration order, so this is removed after `master_process`.
    _temp_dir: TempDir,
}
  47
  48#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  49pub struct SshConnectionOptions {
  50    pub host: String,
  51    pub username: Option<String>,
  52    pub port: Option<u16>,
  53    pub password: Option<String>,
  54    pub args: Option<Vec<String>>,
  55    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  56
  57    pub nickname: Option<String>,
  58    pub upload_binary_over_ssh: bool,
  59}
  60
// A single port forward requested for an SSH connection. It derives `Serialize`,
// `Deserialize` and `JsonSchema`, so it is read from / written to structured data
// (presumably user settings — confirm). Plain `//` comments are used on purpose: `///` doc
// comments would be picked up by `JsonSchema` and change the generated schema.
// NOTE(review): the fields look like the parts of an
// `ssh -L [local_host:]local_port:remote_host:remote_port` spec — confirm where they are applied.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    // Optional local bind address; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    pub local_port: u16,
    // Optional remote target host; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    pub remote_port: u16,
}
  70
/// Handle used to issue further `ssh` / `scp` commands against an established connection.
#[derive(Clone)]
struct SshSocket {
    /// Options the connection was created with.
    connection_options: SshConnectionOptions,
    /// `ControlPath` of the master `ssh` process. Absent on Windows, where OpenSSH does not
    /// support `ControlMaster`/`ControlPath`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables for spawned commands; cloned into the `CommandTemplate` that
    /// `build_command` returns. NOTE(review): populated by `SshSocket::new`, which is not
    /// visible here — confirm its contents.
    envs: HashMap<String, String>,
}
  78
  79macro_rules! shell_script {
  80    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
  81        format!(
  82            $fmt,
  83            $(
  84                $name = shlex::try_quote($arg).unwrap()
  85            ),+
  86        )
  87    }};
  88}
  89
  90#[async_trait(?Send)]
  91impl RemoteConnection for SshRemoteConnection {
  92    async fn kill(&self) -> Result<()> {
  93        let Some(mut process) = self.master_process.lock().take() else {
  94            return Ok(());
  95        };
  96        process.kill().ok();
  97        process.status().await?;
  98        Ok(())
  99    }
 100
 101    fn has_been_killed(&self) -> bool {
 102        self.master_process.lock().is_none()
 103    }
 104
 105    fn connection_options(&self) -> SshConnectionOptions {
 106        self.socket.connection_options.clone()
 107    }
 108
 109    fn shell(&self) -> String {
 110        self.ssh_shell.clone()
 111    }
 112
 113    fn build_command(
 114        &self,
 115        input_program: Option<String>,
 116        input_args: &[String],
 117        input_env: &HashMap<String, String>,
 118        working_dir: Option<String>,
 119        activation_script: Option<String>,
 120        port_forward: Option<(u16, String, u16)>,
 121    ) -> Result<CommandTemplate> {
 122        use std::fmt::Write as _;
 123
 124        let mut script = String::new();
 125        if let Some(working_dir) = working_dir {
 126            let working_dir =
 127                RemotePathBuf::new(working_dir.into(), self.ssh_path_style).to_string();
 128
 129            // shlex will wrap the command in single quotes (''), disabling ~ expansion,
 130            // replace ith with something that works
 131            const TILDE_PREFIX: &'static str = "~/";
 132            if working_dir.starts_with(TILDE_PREFIX) {
 133                let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
 134                write!(&mut script, "cd \"$HOME/{working_dir}\"; ").unwrap();
 135            } else {
 136                write!(&mut script, "cd \"{working_dir}\"; ").unwrap();
 137            }
 138        } else {
 139            write!(&mut script, "cd; ").unwrap();
 140        };
 141        if let Some(activation_script) = activation_script {
 142            write!(&mut script, " {activation_script};").unwrap();
 143        }
 144
 145        for (k, v) in input_env.iter() {
 146            if let Some((k, v)) = shlex::try_quote(k).ok().zip(shlex::try_quote(v).ok()) {
 147                write!(&mut script, "{}={} ", k, v).unwrap();
 148            }
 149        }
 150
 151        let shell = &self.ssh_shell;
 152
 153        if let Some(input_program) = input_program {
 154            let command = shlex::try_quote(&input_program)?;
 155            script.push_str(&command);
 156            for arg in input_args {
 157                let arg = shlex::try_quote(&arg)?;
 158                script.push_str(" ");
 159                script.push_str(&arg);
 160            }
 161        } else {
 162            write!(&mut script, "exec {shell} -l").unwrap();
 163        };
 164
 165        let sys_shell = get_default_system_shell();
 166        let shell_invocation = format!("{sys_shell} -c {}", shlex::try_quote(&script).unwrap());
 167
 168        let mut args = Vec::new();
 169        args.extend(self.socket.ssh_args());
 170
 171        if let Some((local_port, host, remote_port)) = port_forward {
 172            args.push("-L".into());
 173            args.push(format!("{local_port}:{host}:{remote_port}"));
 174        }
 175
 176        args.push("-t".into());
 177        args.push(shell_invocation);
 178        Ok(CommandTemplate {
 179            program: "ssh".into(),
 180            args,
 181            env: self.socket.envs.clone(),
 182        })
 183    }
 184
 185    fn upload_directory(
 186        &self,
 187        src_path: PathBuf,
 188        dest_path: RemotePathBuf,
 189        cx: &App,
 190    ) -> Task<Result<()>> {
 191        let mut command = util::command::new_smol_command("scp");
 192        let output = self
 193            .socket
 194            .ssh_options(&mut command)
 195            .args(
 196                self.socket
 197                    .connection_options
 198                    .port
 199                    .map(|port| vec!["-P".to_string(), port.to_string()])
 200                    .unwrap_or_default(),
 201            )
 202            .arg("-C")
 203            .arg("-r")
 204            .arg(&src_path)
 205            .arg(format!(
 206                "{}:{}",
 207                self.socket.connection_options.scp_url(),
 208                dest_path
 209            ))
 210            .output();
 211
 212        cx.background_spawn(async move {
 213            let output = output.await?;
 214
 215            anyhow::ensure!(
 216                output.status.success(),
 217                "failed to upload directory {} -> {}: {}",
 218                src_path.display(),
 219                dest_path.to_string(),
 220                String::from_utf8_lossy(&output.stderr)
 221            );
 222
 223            Ok(())
 224        })
 225    }
 226
 227    fn start_proxy(
 228        &self,
 229        unique_identifier: String,
 230        reconnect: bool,
 231        incoming_tx: UnboundedSender<Envelope>,
 232        outgoing_rx: UnboundedReceiver<Envelope>,
 233        connection_activity_tx: Sender<()>,
 234        delegate: Arc<dyn RemoteClientDelegate>,
 235        cx: &mut AsyncApp,
 236    ) -> Task<Result<i32>> {
 237        delegate.set_status(Some("Starting proxy"), cx);
 238
 239        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
 240            return Task::ready(Err(anyhow!("Remote binary path not set")));
 241        };
 242
 243        let mut start_proxy_command = shell_script!(
 244            "exec {binary_path} proxy --identifier {identifier}",
 245            binary_path = &remote_binary_path.to_string(),
 246            identifier = &unique_identifier,
 247        );
 248
 249        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
 250            if let Some(value) = std::env::var(env_var).ok() {
 251                start_proxy_command = format!(
 252                    "{}={} {} ",
 253                    env_var,
 254                    shlex::try_quote(&value).unwrap(),
 255                    start_proxy_command,
 256                );
 257            }
 258        }
 259
 260        if reconnect {
 261            start_proxy_command.push_str(" --reconnect");
 262        }
 263
 264        let ssh_proxy_process = match self
 265            .socket
 266            .ssh_command("sh", &["-lc", &start_proxy_command])
 267            // IMPORTANT: we kill this process when we drop the task that uses it.
 268            .kill_on_drop(true)
 269            .spawn()
 270        {
 271            Ok(process) => process,
 272            Err(error) => {
 273                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
 274            }
 275        };
 276
 277        Self::multiplex(
 278            ssh_proxy_process,
 279            incoming_tx,
 280            outgoing_rx,
 281            connection_activity_tx,
 282            cx,
 283        )
 284    }
 285
 286    fn path_style(&self) -> PathStyle {
 287        self.ssh_path_style
 288    }
 289}
 290
 291impl SshRemoteConnection {
    /// Establishes a new SSH connection and makes sure a matching remote server binary is
    /// installed on the host.
    ///
    /// Steps, in order:
    /// 1. Create a temporary directory (prefix `zed-ssh-session`) that is stored in the
    ///    returned value. On non-Windows platforms it holds the `ssh.sock` control socket.
    /// 2. Start an askpass session. `ssh` is pointed at its script via `SSH_ASKPASS`, and
    ///    prompts are forwarded to `delegate.ask_password`.
    /// 3. Spawn the "master" `ssh -N` process. On non-Windows platforms it runs with
    ///    `ControlMaster=yes`, `ControlPersist=no` and a `ControlPath` pointing at the socket,
    ///    so later `ssh` invocations can reuse the connection.
    /// 4. Wait until the master process closes its stdout, which is taken to mean that
    ///    authentication has completed. If the askpass session ends first (cancelled by the
    ///    user or timed out), fail instead.
    /// 5. If the master process has already exited, fail with its stderr output.
    /// 6. Query the remote platform and shell, then ensure the server binary is present.
    ///
    /// # Errors
    ///
    /// Returns an error if any step fails. User cancellation ("SSH connection canceled") and
    /// askpass timeout ("connecting to host timed out") are wrapped with
    /// "Failed to connect to host"; an early ssh exit reports "failed to connect: <stderr>".
    pub(crate) async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        // Session-scoped scratch directory; moved into `Self` below so it outlives this call.
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password/passphrase prompts raised through askpass are forwarded to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" takes the `ControlPath=...` argument added below as its value.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            // `SSH_ASKPASS_REQUIRE=force` makes OpenSSH use the askpass script for prompts.
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // `select_biased!` polls the askpass branch first, so askpass wins if both branches
        // are ready. The outcome of reading stdout is deliberately ignored here.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If ssh has already exited, the connection failed: report its stderr.
        // NOTE(review): stdout closing and the process becoming reapable are not atomic, so
        // a failing ssh may not be observed as exited yet here — confirm whether that case
        // is caught by the later steps.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        // The askpass session is dropped once the socket exists (on Windows, after the
        // password it collected has been handed to the socket).
        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };
        let ssh_shell = socket.shell().await;

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
            ssh_shell,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
 424
 425    fn multiplex(
 426        mut ssh_proxy_process: Child,
 427        incoming_tx: UnboundedSender<Envelope>,
 428        mut outgoing_rx: UnboundedReceiver<Envelope>,
 429        mut connection_activity_tx: Sender<()>,
 430        cx: &AsyncApp,
 431    ) -> Task<Result<i32>> {
 432        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 433        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 434        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 435
 436        let mut stdin_buffer = Vec::new();
 437        let mut stdout_buffer = Vec::new();
 438        let mut stderr_buffer = Vec::new();
 439        let mut stderr_offset = 0;
 440
 441        let stdin_task = cx.background_spawn(async move {
 442            while let Some(outgoing) = outgoing_rx.next().await {
 443                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 444            }
 445            anyhow::Ok(())
 446        });
 447
 448        let stdout_task = cx.background_spawn({
 449            let mut connection_activity_tx = connection_activity_tx.clone();
 450            async move {
 451                loop {
 452                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 453                    let len = child_stdout.read(&mut stdout_buffer).await?;
 454
 455                    if len == 0 {
 456                        return anyhow::Ok(());
 457                    }
 458
 459                    if len < MESSAGE_LEN_SIZE {
 460                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 461                    }
 462
 463                    let message_len = message_len_from_buffer(&stdout_buffer);
 464                    let envelope =
 465                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
 466                            .await?;
 467                    connection_activity_tx.try_send(()).ok();
 468                    incoming_tx.unbounded_send(envelope).ok();
 469                }
 470            }
 471        });
 472
 473        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
 474            loop {
 475                stderr_buffer.resize(stderr_offset + 1024, 0);
 476
 477                let len = child_stderr
 478                    .read(&mut stderr_buffer[stderr_offset..])
 479                    .await?;
 480                if len == 0 {
 481                    return anyhow::Ok(());
 482                }
 483
 484                stderr_offset += len;
 485                let mut start_ix = 0;
 486                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
 487                    .iter()
 488                    .position(|b| b == &b'\n')
 489                {
 490                    let line_ix = start_ix + ix;
 491                    let content = &stderr_buffer[start_ix..line_ix];
 492                    start_ix = line_ix + 1;
 493                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 494                        record.log(log::logger())
 495                    } else {
 496                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 497                    }
 498                }
 499                stderr_buffer.drain(0..start_ix);
 500                stderr_offset -= start_ix;
 501
 502                connection_activity_tx.try_send(()).ok();
 503            }
 504        });
 505
 506        cx.background_spawn(async move {
 507            let result = futures::select! {
 508                result = stdin_task.fuse() => {
 509                    result.context("stdin")
 510                }
 511                result = stdout_task.fuse() => {
 512                    result.context("stdout")
 513                }
 514                result = stderr_task.fuse() => {
 515                    result.context("stderr")
 516                }
 517            };
 518
 519            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
 520            match result {
 521                Ok(_) => Ok(status),
 522                Err(error) => Err(error),
 523            }
 524        })
 525    }
 526
 527    #[allow(unused)]
 528    async fn ensure_server_binary(
 529        &self,
 530        delegate: &Arc<dyn RemoteClientDelegate>,
 531        release_channel: ReleaseChannel,
 532        version: SemanticVersion,
 533        commit: Option<AppCommitSha>,
 534        cx: &mut AsyncApp,
 535    ) -> Result<RemotePathBuf> {
 536        let version_str = match release_channel {
 537            ReleaseChannel::Nightly => {
 538                let commit = commit.map(|s| s.full()).unwrap_or_default();
 539                format!("{}-{}", version, commit)
 540            }
 541            ReleaseChannel::Dev => "build".to_string(),
 542            _ => version.to_string(),
 543        };
 544        let binary_name = format!(
 545            "zed-remote-server-{}-{}",
 546            release_channel.dev_name(),
 547            version_str
 548        );
 549        let dst_path = RemotePathBuf::new(
 550            paths::remote_server_dir_relative().join(binary_name),
 551            self.ssh_path_style,
 552        );
 553
 554        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
 555        #[cfg(debug_assertions)]
 556        if let Some(build_remote_server) = build_remote_server {
 557            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
 558            let tmp_path = RemotePathBuf::new(
 559                paths::remote_server_dir_relative().join(format!(
 560                    "download-{}-{}",
 561                    std::process::id(),
 562                    src_path.file_name().unwrap().to_string_lossy()
 563                )),
 564                self.ssh_path_style,
 565            );
 566            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
 567                .await?;
 568            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
 569                .await?;
 570            return Ok(dst_path);
 571        }
 572
 573        if self
 574            .socket
 575            .run_command(&dst_path.to_string(), &["version"])
 576            .await
 577            .is_ok()
 578        {
 579            return Ok(dst_path);
 580        }
 581
 582        let wanted_version = cx.update(|cx| match release_channel {
 583            ReleaseChannel::Nightly => Ok(None),
 584            ReleaseChannel::Dev => {
 585                anyhow::bail!(
 586                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
 587                    dst_path
 588                )
 589            }
 590            _ => Ok(Some(AppVersion::global(cx))),
 591        })??;
 592
 593        let tmp_path_gz = RemotePathBuf::new(
 594            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
 595            self.ssh_path_style,
 596        );
 597        if !self.socket.connection_options.upload_binary_over_ssh
 598            && let Some((url, body)) = delegate
 599                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
 600                .await?
 601        {
 602            match self
 603                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
 604                .await
 605            {
 606                Ok(_) => {
 607                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 608                        .await?;
 609                    return Ok(dst_path);
 610                }
 611                Err(e) => {
 612                    log::error!(
 613                        "Failed to download binary on server, attempting to upload server: {}",
 614                        e
 615                    )
 616                }
 617            }
 618        }
 619
 620        let src_path = delegate
 621            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
 622            .await?;
 623        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
 624            .await?;
 625        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
 626            .await?;
 627        Ok(dst_path)
 628    }
 629
 630    async fn download_binary_on_server(
 631        &self,
 632        url: &str,
 633        body: &str,
 634        tmp_path_gz: &RemotePathBuf,
 635        delegate: &Arc<dyn RemoteClientDelegate>,
 636        cx: &mut AsyncApp,
 637    ) -> Result<()> {
 638        if let Some(parent) = tmp_path_gz.parent() {
 639            self.socket
 640                .run_command(
 641                    "sh",
 642                    &[
 643                        "-lc",
 644                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
 645                    ],
 646                )
 647                .await?;
 648        }
 649
 650        delegate.set_status(Some("Downloading remote development server on host"), cx);
 651
 652        match self
 653            .socket
 654            .run_command(
 655                "curl",
 656                &[
 657                    "-f",
 658                    "-L",
 659                    "-X",
 660                    "GET",
 661                    "-H",
 662                    "Content-Type: application/json",
 663                    "-d",
 664                    body,
 665                    url,
 666                    "-o",
 667                    &tmp_path_gz.to_string(),
 668                ],
 669            )
 670            .await
 671        {
 672            Ok(_) => {}
 673            Err(e) => {
 674                if self.socket.run_command("which", &["curl"]).await.is_ok() {
 675                    return Err(e);
 676                }
 677
 678                match self
 679                    .socket
 680                    .run_command(
 681                        "wget",
 682                        &[
 683                            "--method=GET",
 684                            "--header=Content-Type: application/json",
 685                            "--body-data",
 686                            body,
 687                            url,
 688                            "-O",
 689                            &tmp_path_gz.to_string(),
 690                        ],
 691                    )
 692                    .await
 693                {
 694                    Ok(_) => {}
 695                    Err(e) => {
 696                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
 697                            return Err(e);
 698                        } else {
 699                            anyhow::bail!("Neither curl nor wget is available");
 700                        }
 701                    }
 702                }
 703            }
 704        }
 705
 706        Ok(())
 707    }
 708
 709    async fn upload_local_server_binary(
 710        &self,
 711        src_path: &Path,
 712        tmp_path_gz: &RemotePathBuf,
 713        delegate: &Arc<dyn RemoteClientDelegate>,
 714        cx: &mut AsyncApp,
 715    ) -> Result<()> {
 716        if let Some(parent) = tmp_path_gz.parent() {
 717            self.socket
 718                .run_command(
 719                    "sh",
 720                    &[
 721                        "-lc",
 722                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
 723                    ],
 724                )
 725                .await?;
 726        }
 727
 728        let src_stat = fs::metadata(&src_path).await?;
 729        let size = src_stat.len();
 730
 731        let t0 = Instant::now();
 732        delegate.set_status(Some("Uploading remote development server"), cx);
 733        log::info!(
 734            "uploading remote development server to {:?} ({}kb)",
 735            tmp_path_gz,
 736            size / 1024
 737        );
 738        self.upload_file(src_path, tmp_path_gz)
 739            .await
 740            .context("failed to upload server binary")?;
 741        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 742        Ok(())
 743    }
 744
 745    async fn extract_server_binary(
 746        &self,
 747        dst_path: &RemotePathBuf,
 748        tmp_path: &RemotePathBuf,
 749        delegate: &Arc<dyn RemoteClientDelegate>,
 750        cx: &mut AsyncApp,
 751    ) -> Result<()> {
 752        delegate.set_status(Some("Extracting remote development server"), cx);
 753        let server_mode = 0o755;
 754
 755        let orig_tmp_path = tmp_path.to_string();
 756        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
 757            shell_script!(
 758                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
 759                server_mode = &format!("{:o}", server_mode),
 760                dst_path = &dst_path.to_string(),
 761            )
 762        } else {
 763            shell_script!(
 764                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
 765                server_mode = &format!("{:o}", server_mode),
 766                dst_path = &dst_path.to_string()
 767            )
 768        };
 769        self.socket.run_command("sh", &["-lc", &script]).await?;
 770        Ok(())
 771    }
 772
 773    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
 774        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
 775        let mut command = util::command::new_smol_command("scp");
 776        let output = self
 777            .socket
 778            .ssh_options(&mut command)
 779            .args(
 780                self.socket
 781                    .connection_options
 782                    .port
 783                    .map(|port| vec!["-P".to_string(), port.to_string()])
 784                    .unwrap_or_default(),
 785            )
 786            .arg(src_path)
 787            .arg(format!(
 788                "{}:{}",
 789                self.socket.connection_options.scp_url(),
 790                dest_path
 791            ))
 792            .output()
 793            .await?;
 794
 795        anyhow::ensure!(
 796            output.status.success(),
 797            "failed to upload file {} -> {}: {}",
 798            src_path.display(),
 799            dest_path.to_string(),
 800            String::from_utf8_lossy(&output.stderr)
 801        );
 802        Ok(())
 803    }
 804
    /// Builds the `remote_server` binary from source for the remote platform (debug builds only).
    ///
    /// The Rust target triple is derived from `self.ssh_platform`. `build_remote_server` is
    /// treated as a set of substring flags:
    /// - `nomusl`: on Linux, target `*-unknown-linux-gnu` instead of `*-unknown-linux-musl`
    ///   (musl builds additionally get `-C target-feature=+crt-static` in `RUSTFLAGS`);
    /// - `mold`: link with mold (`-C link-arg=-fuse-ld=mold`);
    /// - `cross`: when the remote platform differs from the local one, build with cross.rs
    ///   (Docker) instead of the default cargo-zigbuild;
    /// - `nocompress`: skip gzip-compressing the resulting binary.
    ///
    /// When the remote arch and OS match the local ones, a plain `cargo build` is used.
    /// Progress is reported through `delegate.set_status`.
    ///
    /// Returns the absolute path of `remote_server.gz` when compressing; with `nocompress`
    /// it returns the uncompressed binary's path relative to the current directory.
    ///
    /// # Errors
    ///
    /// Fails if the remote OS is neither `linux` nor `macos`, if `zig` is not on `$PATH`
    /// for the zig path, if `7z.exe` is missing (Windows, when compressing), or if any
    /// build/install/compress command exits unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion with its stderr forwarded to ours (so build output is
        // visible); stdout is collected by `output()` but not inspected. `kill_on_drop` kills
        // the child if this future is dropped before it finishes. Fails on non-zero exit.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        let use_musl = !build_remote_server.contains("nomusl");
        // The remote arch string is used verbatim as the first component of the triple.
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS (ignored, with an error log, if it is not valid
        // Unicode) and append target-specific flags below.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        // Same arch and OS as this machine: a regular cargo build for the target is enough.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::new(&smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            // The local `target` directory is bind-mounted into the container at
            // /app/target via CROSS_CONTAINER_OPTS.
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    )
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Default cross-compilation path: cargo-zigbuild, which needs `zig` on $PATH.
            // The lookup runs on a background thread.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "zigbuild",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        };
        // All three build paths write to target/remote_server/<triple>/debug/remote_server
        // (relative to the current directory).
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            // `gzip -f` replaces the binary with `remote_server.gz`, overwriting any
            // existing archive.
            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove any stale archive first; presumably `7z a` would otherwise update
                // the existing file rather than replace it.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
1019}
1020
1021impl SshSocket {
1022    #[cfg(not(target_os = "windows"))]
1023    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
1024        Ok(Self {
1025            connection_options: options,
1026            envs: HashMap::default(),
1027            socket_path,
1028        })
1029    }
1030
1031    #[cfg(target_os = "windows")]
1032    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
1033        let askpass_script = temp_dir.path().join("askpass.bat");
1034        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
1035        let mut envs = HashMap::default();
1036        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
1037        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
1038        envs.insert("ZED_SSH_ASKPASS".into(), secret);
1039        Ok(Self {
1040            connection_options: options,
1041            envs,
1042        })
1043    }
1044
1045    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
1046    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
1047    // and passes -l as an argument to sh, not to ls.
1048    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
1049    // into a machine. You must use `cd` to get back to $HOME.
1050    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
1051    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
1052        let mut command = util::command::new_smol_command("ssh");
1053        let to_run = iter::once(&program)
1054            .chain(args.iter())
1055            .map(|token| {
1056                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
1057                debug_assert!(
1058                    !token.contains('\n'),
1059                    "multiline arguments do not work in all shells"
1060                );
1061                shlex::try_quote(token).unwrap()
1062            })
1063            .join(" ");
1064        let to_run = format!("cd; {to_run}");
1065        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
1066        self.ssh_options(&mut command)
1067            .arg(self.connection_options.ssh_url())
1068            .arg(to_run);
1069        command
1070    }
1071
1072    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
1073        let output = self.ssh_command(program, args).output().await?;
1074        anyhow::ensure!(
1075            output.status.success(),
1076            "failed to run command: {}",
1077            String::from_utf8_lossy(&output.stderr)
1078        );
1079        Ok(String::from_utf8_lossy(&output.stdout).to_string())
1080    }
1081
1082    #[cfg(not(target_os = "windows"))]
1083    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
1084        command
1085            .stdin(Stdio::piped())
1086            .stdout(Stdio::piped())
1087            .stderr(Stdio::piped())
1088            .args(self.connection_options.additional_args())
1089            .args(["-o", "ControlMaster=no", "-o"])
1090            .arg(format!("ControlPath={}", self.socket_path.display()))
1091    }
1092
1093    #[cfg(target_os = "windows")]
1094    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
1095        command
1096            .stdin(Stdio::piped())
1097            .stdout(Stdio::piped())
1098            .stderr(Stdio::piped())
1099            .args(self.connection_options.additional_args())
1100            .envs(self.envs.clone())
1101    }
1102
1103    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
1104    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
1105    #[cfg(not(target_os = "windows"))]
1106    fn ssh_args(&self) -> Vec<String> {
1107        let mut arguments = self.connection_options.additional_args();
1108        arguments.extend(vec![
1109            "-o".to_string(),
1110            "ControlMaster=no".to_string(),
1111            "-o".to_string(),
1112            format!("ControlPath={}", self.socket_path.display()),
1113            self.connection_options.ssh_url(),
1114        ]);
1115        arguments
1116    }
1117
1118    #[cfg(target_os = "windows")]
1119    fn ssh_args(&self) -> Vec<String> {
1120        let mut arguments = self.connection_options.additional_args();
1121        arguments.push(self.connection_options.ssh_url());
1122        arguments
1123    }
1124
1125    async fn platform(&self) -> Result<RemotePlatform> {
1126        let uname = self.run_command("sh", &["-lc", "uname -sm"]).await?;
1127        let Some((os, arch)) = uname.split_once(" ") else {
1128            anyhow::bail!("unknown uname: {uname:?}")
1129        };
1130
1131        let os = match os.trim() {
1132            "Darwin" => "macos",
1133            "Linux" => "linux",
1134            _ => anyhow::bail!(
1135                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
1136            ),
1137        };
1138        // exclude armv5,6,7 as they are 32-bit.
1139        let arch = if arch.starts_with("armv8")
1140            || arch.starts_with("armv9")
1141            || arch.starts_with("arm64")
1142            || arch.starts_with("aarch64")
1143        {
1144            "aarch64"
1145        } else if arch.starts_with("x86") {
1146            "x86_64"
1147        } else {
1148            anyhow::bail!(
1149                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
1150            )
1151        };
1152
1153        Ok(RemotePlatform { os, arch })
1154    }
1155
1156    async fn shell(&self) -> String {
1157        match self.run_command("sh", &["-lc", "echo $SHELL"]).await {
1158            Ok(shell) => shell.trim().to_owned(),
1159            Err(e) => {
1160                log::error!("Failed to get shell: {e}");
1161                "sh".to_owned()
1162            }
1163        }
1164    }
1165}
1166
1167fn parse_port_number(port_str: &str) -> Result<u16> {
1168    port_str
1169        .parse()
1170        .with_context(|| format!("parsing port number: {port_str}"))
1171}
1172
1173fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1174    let parts: Vec<&str> = spec.split(':').collect();
1175
1176    match parts.len() {
1177        4 => {
1178            let local_port = parse_port_number(parts[1])?;
1179            let remote_port = parse_port_number(parts[3])?;
1180
1181            Ok(SshPortForwardOption {
1182                local_host: Some(parts[0].to_string()),
1183                local_port,
1184                remote_host: Some(parts[2].to_string()),
1185                remote_port,
1186            })
1187        }
1188        3 => {
1189            let local_port = parse_port_number(parts[0])?;
1190            let remote_port = parse_port_number(parts[2])?;
1191
1192            Ok(SshPortForwardOption {
1193                local_host: None,
1194                local_port,
1195                remote_host: Some(parts[1].to_string()),
1196                remote_port,
1197            })
1198        }
1199        _ => anyhow::bail!("Invalid port forward format"),
1200    }
1201}
1202
1203impl SshConnectionOptions {
1204    pub fn parse_command_line(input: &str) -> Result<Self> {
1205        let input = input.trim_start_matches("ssh ");
1206        let mut hostname: Option<String> = None;
1207        let mut username: Option<String> = None;
1208        let mut port: Option<u16> = None;
1209        let mut args = Vec::new();
1210        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1211
1212        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1213        const ALLOWED_OPTS: &[&str] = &[
1214            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1215        ];
1216        const ALLOWED_ARGS: &[&str] = &[
1217            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1218            "-w",
1219        ];
1220
1221        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
1222
1223        'outer: while let Some(arg) = tokens.next() {
1224            if ALLOWED_OPTS.contains(&(&arg as &str)) {
1225                args.push(arg.to_string());
1226                continue;
1227            }
1228            if arg == "-p" {
1229                port = tokens.next().and_then(|arg| arg.parse().ok());
1230                continue;
1231            } else if let Some(p) = arg.strip_prefix("-p") {
1232                port = p.parse().ok();
1233                continue;
1234            }
1235            if arg == "-l" {
1236                username = tokens.next();
1237                continue;
1238            } else if let Some(l) = arg.strip_prefix("-l") {
1239                username = Some(l.to_string());
1240                continue;
1241            }
1242            if arg == "-L" || arg.starts_with("-L") {
1243                let forward_spec = if arg == "-L" {
1244                    tokens.next()
1245                } else {
1246                    Some(arg.strip_prefix("-L").unwrap().to_string())
1247                };
1248
1249                if let Some(spec) = forward_spec {
1250                    port_forwards.push(parse_port_forward_spec(&spec)?);
1251                } else {
1252                    anyhow::bail!("Missing port forward format");
1253                }
1254            }
1255
1256            for a in ALLOWED_ARGS {
1257                if arg == *a {
1258                    args.push(arg);
1259                    if let Some(next) = tokens.next() {
1260                        args.push(next);
1261                    }
1262                    continue 'outer;
1263                } else if arg.starts_with(a) {
1264                    args.push(arg);
1265                    continue 'outer;
1266                }
1267            }
1268            if arg.starts_with("-") || hostname.is_some() {
1269                anyhow::bail!("unsupported argument: {:?}", arg);
1270            }
1271            let mut input = &arg as &str;
1272            // Destination might be: username1@username2@ip2@ip1
1273            if let Some((u, rest)) = input.rsplit_once('@') {
1274                input = rest;
1275                username = Some(u.to_string());
1276            }
1277            if let Some((rest, p)) = input.split_once(':') {
1278                input = rest;
1279                port = p.parse().ok()
1280            }
1281            hostname = Some(input.to_string())
1282        }
1283
1284        let Some(hostname) = hostname else {
1285            anyhow::bail!("missing hostname");
1286        };
1287
1288        let port_forwards = match port_forwards.len() {
1289            0 => None,
1290            _ => Some(port_forwards),
1291        };
1292
1293        Ok(Self {
1294            host: hostname,
1295            username,
1296            port,
1297            port_forwards,
1298            args: Some(args),
1299            password: None,
1300            nickname: None,
1301            upload_binary_over_ssh: false,
1302        })
1303    }
1304
1305    pub fn ssh_url(&self) -> String {
1306        let mut result = String::from("ssh://");
1307        if let Some(username) = &self.username {
1308            // Username might be: username1@username2@ip2
1309            let username = urlencoding::encode(username);
1310            result.push_str(&username);
1311            result.push('@');
1312        }
1313        result.push_str(&self.host);
1314        if let Some(port) = self.port {
1315            result.push(':');
1316            result.push_str(&port.to_string());
1317        }
1318        result
1319    }
1320
1321    pub fn additional_args(&self) -> Vec<String> {
1322        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
1323
1324        if let Some(forwards) = &self.port_forwards {
1325            args.extend(forwards.iter().map(|pf| {
1326                let local_host = match &pf.local_host {
1327                    Some(host) => host,
1328                    None => "localhost",
1329                };
1330                let remote_host = match &pf.remote_host {
1331                    Some(host) => host,
1332                    None => "localhost",
1333                };
1334
1335                format!(
1336                    "-L{}:{}:{}:{}",
1337                    local_host, pf.local_port, remote_host, pf.remote_port
1338                )
1339            }));
1340        }
1341
1342        args
1343    }
1344
1345    fn scp_url(&self) -> String {
1346        if let Some(username) = &self.username {
1347            format!("{}@{}", username, self.host)
1348        } else {
1349            self.host.clone()
1350        }
1351    }
1352
1353    pub fn connection_string(&self) -> String {
1354        let host = if let Some(username) = &self.username {
1355            format!("{}@{}", username, self.host)
1356        } else {
1357            self.host.clone()
1358        };
1359        if let Some(port) = &self.port {
1360            format!("{}:{}", host, port)
1361        } else {
1362            host
1363        }
1364    }
1365}