ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  30    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  31};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use smol::{
  35    fs,
  36    process::{self, Child, Stdio},
  37};
  38use std::{
  39    collections::VecDeque,
  40    fmt, iter,
  41    ops::ControlFlow,
  42    path::{Path, PathBuf},
  43    sync::{
  44        Arc, Weak,
  45        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  46    },
  47    time::{Duration, Instant},
  48};
  49use tempfile::TempDir;
  50use util::{
  51    ResultExt,
  52    paths::{PathStyle, RemotePathBuf},
  53};
  54
  55#[derive(
  56    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  57)]
  58pub struct SshProjectId(pub u64);
  59
/// Handle used to run `ssh` commands against a remote host.
///
/// Besides the connection options it carries the platform-specific data that
/// `ssh_options`/`ssh_args` attach to every invocation: the control socket
/// path on non-Windows targets, and a set of environment variables on Windows.
#[derive(Clone)]
pub struct SshSocket {
    /// Options describing the remote host; used to build the `ssh://` URL and
    /// the additional command-line arguments.
    connection_options: SshConnectionOptions,
    /// Path handed to ssh as `-o ControlPath=...` (non-Windows only).
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables applied to every spawned ssh command (Windows only).
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  68
  69#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
  70pub struct SshPortForwardOption {
  71    #[serde(skip_serializing_if = "Option::is_none")]
  72    pub local_host: Option<String>,
  73    pub local_port: u16,
  74    #[serde(skip_serializing_if = "Option::is_none")]
  75    pub remote_host: Option<String>,
  76    pub remote_port: u16,
  77}
  78
  79#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  80pub struct SshConnectionOptions {
  81    pub host: String,
  82    pub username: Option<String>,
  83    pub port: Option<u16>,
  84    pub password: Option<String>,
  85    pub args: Option<Vec<String>>,
  86    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  87
  88    pub nickname: Option<String>,
  89    pub upload_binary_over_ssh: bool,
  90}
  91
  92#[derive(Debug, Clone, PartialEq, Eq)]
  93pub struct SshArgs {
  94    pub arguments: Vec<String>,
  95    pub envs: Option<HashMap<String, String>>,
  96}
  97
  98#[derive(Debug, Clone, PartialEq, Eq)]
  99pub struct SshInfo {
 100    pub args: SshArgs,
 101    pub path_style: PathStyle,
 102    pub shell: String,
 103}
 104
 105#[macro_export]
 106macro_rules! shell_script {
 107    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 108        format!(
 109            $fmt,
 110            $(
 111                $name = shlex::try_quote($arg).unwrap()
 112            ),+
 113        )
 114    }};
 115}
 116
 117fn parse_port_number(port_str: &str) -> Result<u16> {
 118    port_str
 119        .parse()
 120        .with_context(|| format!("parsing port number: {port_str}"))
 121}
 122
 123fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 124    let parts: Vec<&str> = spec.split(':').collect();
 125
 126    match parts.len() {
 127        4 => {
 128            let local_port = parse_port_number(parts[1])?;
 129            let remote_port = parse_port_number(parts[3])?;
 130
 131            Ok(SshPortForwardOption {
 132                local_host: Some(parts[0].to_string()),
 133                local_port,
 134                remote_host: Some(parts[2].to_string()),
 135                remote_port,
 136            })
 137        }
 138        3 => {
 139            let local_port = parse_port_number(parts[0])?;
 140            let remote_port = parse_port_number(parts[2])?;
 141
 142            Ok(SshPortForwardOption {
 143                local_host: None,
 144                local_port,
 145                remote_host: Some(parts[1].to_string()),
 146                remote_port,
 147            })
 148        }
 149        _ => anyhow::bail!("Invalid port forward format"),
 150    }
 151}
 152
 153impl SshConnectionOptions {
 154    pub fn parse_command_line(input: &str) -> Result<Self> {
 155        let input = input.trim_start_matches("ssh ");
 156        let mut hostname: Option<String> = None;
 157        let mut username: Option<String> = None;
 158        let mut port: Option<u16> = None;
 159        let mut args = Vec::new();
 160        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 161
 162        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 163        const ALLOWED_OPTS: &[&str] = &[
 164            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 165        ];
 166        const ALLOWED_ARGS: &[&str] = &[
 167            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 168            "-w",
 169        ];
 170
 171        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 172
 173        'outer: while let Some(arg) = tokens.next() {
 174            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 175                args.push(arg.to_string());
 176                continue;
 177            }
 178            if arg == "-p" {
 179                port = tokens.next().and_then(|arg| arg.parse().ok());
 180                continue;
 181            } else if let Some(p) = arg.strip_prefix("-p") {
 182                port = p.parse().ok();
 183                continue;
 184            }
 185            if arg == "-l" {
 186                username = tokens.next();
 187                continue;
 188            } else if let Some(l) = arg.strip_prefix("-l") {
 189                username = Some(l.to_string());
 190                continue;
 191            }
 192            if arg == "-L" || arg.starts_with("-L") {
 193                let forward_spec = if arg == "-L" {
 194                    tokens.next()
 195                } else {
 196                    Some(arg.strip_prefix("-L").unwrap().to_string())
 197                };
 198
 199                if let Some(spec) = forward_spec {
 200                    port_forwards.push(parse_port_forward_spec(&spec)?);
 201                } else {
 202                    anyhow::bail!("Missing port forward format");
 203                }
 204            }
 205
 206            for a in ALLOWED_ARGS {
 207                if arg == *a {
 208                    args.push(arg);
 209                    if let Some(next) = tokens.next() {
 210                        args.push(next);
 211                    }
 212                    continue 'outer;
 213                } else if arg.starts_with(a) {
 214                    args.push(arg);
 215                    continue 'outer;
 216                }
 217            }
 218            if arg.starts_with("-") || hostname.is_some() {
 219                anyhow::bail!("unsupported argument: {:?}", arg);
 220            }
 221            let mut input = &arg as &str;
 222            // Destination might be: username1@username2@ip2@ip1
 223            if let Some((u, rest)) = input.rsplit_once('@') {
 224                input = rest;
 225                username = Some(u.to_string());
 226            }
 227            if let Some((rest, p)) = input.split_once(':') {
 228                input = rest;
 229                port = p.parse().ok()
 230            }
 231            hostname = Some(input.to_string())
 232        }
 233
 234        let Some(hostname) = hostname else {
 235            anyhow::bail!("missing hostname");
 236        };
 237
 238        let port_forwards = match port_forwards.len() {
 239            0 => None,
 240            _ => Some(port_forwards),
 241        };
 242
 243        Ok(Self {
 244            host: hostname,
 245            username,
 246            port,
 247            port_forwards,
 248            args: Some(args),
 249            password: None,
 250            nickname: None,
 251            upload_binary_over_ssh: false,
 252        })
 253    }
 254
 255    pub fn ssh_url(&self) -> String {
 256        let mut result = String::from("ssh://");
 257        if let Some(username) = &self.username {
 258            // Username might be: username1@username2@ip2
 259            let username = urlencoding::encode(username);
 260            result.push_str(&username);
 261            result.push('@');
 262        }
 263        result.push_str(&self.host);
 264        if let Some(port) = self.port {
 265            result.push(':');
 266            result.push_str(&port.to_string());
 267        }
 268        result
 269    }
 270
 271    pub fn additional_args(&self) -> Vec<String> {
 272        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 273
 274        if let Some(forwards) = &self.port_forwards {
 275            args.extend(forwards.iter().map(|pf| {
 276                let local_host = match &pf.local_host {
 277                    Some(host) => host,
 278                    None => "localhost",
 279                };
 280                let remote_host = match &pf.remote_host {
 281                    Some(host) => host,
 282                    None => "localhost",
 283                };
 284
 285                format!(
 286                    "-L{}:{}:{}:{}",
 287                    local_host, pf.local_port, remote_host, pf.remote_port
 288                )
 289            }));
 290        }
 291
 292        args
 293    }
 294
 295    fn scp_url(&self) -> String {
 296        if let Some(username) = &self.username {
 297            format!("{}@{}", username, self.host)
 298        } else {
 299            self.host.clone()
 300        }
 301    }
 302
 303    pub fn connection_string(&self) -> String {
 304        let host = if let Some(username) = &self.username {
 305            format!("{}@{}", username, self.host)
 306        } else {
 307            self.host.clone()
 308        };
 309        if let Some(port) = &self.port {
 310            format!("{}:{}", host, port)
 311        } else {
 312            host
 313        }
 314    }
 315}
 316
 317#[derive(Copy, Clone, Debug)]
 318pub struct SshPlatform {
 319    pub os: &'static str,
 320    pub arch: &'static str,
 321}
 322
/// Callbacks through which the SSH client talks to the embedding application.
///
/// Implementations must be `Send + Sync`; the client holds them as
/// `Arc<dyn SshClientDelegate>`.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password for `prompt`. The answer is presumably delivered
    /// through `tx` — confirm against the implementations.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Looks up download parameters for the server binary matching `platform`,
    /// `release_channel` and `version`.
    ///
    /// NOTE(review): the meaning of the returned `(String, String)` pair is
    /// not visible here (looks like a URL plus a request body) — confirm.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Obtains the server binary for `platform`, `release_channel` and
    /// `version` on the local machine and resolves to its path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message; `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 342
 343impl SshSocket {
 344    #[cfg(not(target_os = "windows"))]
 345    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 346        Ok(Self {
 347            connection_options: options,
 348            socket_path,
 349        })
 350    }
 351
 352    #[cfg(target_os = "windows")]
 353    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 354        let askpass_script = temp_dir.path().join("askpass.bat");
 355        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 356        let mut envs = HashMap::default();
 357        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 358        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 359        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 360        Ok(Self {
 361            connection_options: options,
 362            envs,
 363        })
 364    }
 365
 366    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 367    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 368    // and passes -l as an argument to sh, not to ls.
 369    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 370    // into a machine. You must use `cd` to get back to $HOME.
 371    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 372    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 373        let mut command = util::command::new_smol_command("ssh");
 374        let to_run = iter::once(&program)
 375            .chain(args.iter())
 376            .map(|token| {
 377                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 378                debug_assert!(
 379                    !token.contains('\n'),
 380                    "multiline arguments do not work in all shells"
 381                );
 382                shlex::try_quote(token).unwrap()
 383            })
 384            .join(" ");
 385        let to_run = format!("cd; {to_run}");
 386        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 387        self.ssh_options(&mut command)
 388            .arg(self.connection_options.ssh_url())
 389            .arg(to_run);
 390        command
 391    }
 392
 393    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 394        let output = self.ssh_command(program, args).output().await?;
 395        anyhow::ensure!(
 396            output.status.success(),
 397            "failed to run command: {}",
 398            String::from_utf8_lossy(&output.stderr)
 399        );
 400        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 401    }
 402
 403    #[cfg(not(target_os = "windows"))]
 404    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 405        command
 406            .stdin(Stdio::piped())
 407            .stdout(Stdio::piped())
 408            .stderr(Stdio::piped())
 409            .args(self.connection_options.additional_args())
 410            .args(["-o", "ControlMaster=no", "-o"])
 411            .arg(format!("ControlPath={}", self.socket_path.display()))
 412    }
 413
 414    #[cfg(target_os = "windows")]
 415    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 416        command
 417            .stdin(Stdio::piped())
 418            .stdout(Stdio::piped())
 419            .stderr(Stdio::piped())
 420            .args(self.connection_options.additional_args())
 421            .envs(self.envs.clone())
 422    }
 423
 424    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 425    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 426    #[cfg(not(target_os = "windows"))]
 427    fn ssh_args(&self) -> SshArgs {
 428        let mut arguments = self.connection_options.additional_args();
 429        arguments.extend(vec![
 430            "-o".to_string(),
 431            "ControlMaster=no".to_string(),
 432            "-o".to_string(),
 433            format!("ControlPath={}", self.socket_path.display()),
 434            self.connection_options.ssh_url(),
 435        ]);
 436        SshArgs {
 437            arguments,
 438            envs: None,
 439        }
 440    }
 441
 442    #[cfg(target_os = "windows")]
 443    fn ssh_args(&self) -> SshArgs {
 444        let mut arguments = self.connection_options.additional_args();
 445        arguments.push(self.connection_options.ssh_url());
 446        SshArgs {
 447            arguments,
 448            envs: Some(self.envs.clone()),
 449        }
 450    }
 451
 452    async fn platform(&self) -> Result<SshPlatform> {
 453        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 454        let Some((os, arch)) = uname.split_once(" ") else {
 455            anyhow::bail!("unknown uname: {uname:?}")
 456        };
 457
 458        let os = match os.trim() {
 459            "Darwin" => "macos",
 460            "Linux" => "linux",
 461            _ => anyhow::bail!(
 462                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 463            ),
 464        };
 465        // exclude armv5,6,7 as they are 32-bit.
 466        let arch = if arch.starts_with("armv8")
 467            || arch.starts_with("armv9")
 468            || arch.starts_with("arm64")
 469            || arch.starts_with("aarch64")
 470        {
 471            "aarch64"
 472        } else if arch.starts_with("x86") {
 473            "x86_64"
 474        } else {
 475            anyhow::bail!(
 476                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 477            )
 478        };
 479
 480        Ok(SshPlatform { os, arch })
 481    }
 482
 483    async fn shell(&self) -> String {
 484        match self.run_command("sh", &["-c", "echo $SHELL"]).await {
 485            Ok(shell) => shell.trim().to_owned(),
 486            Err(e) => {
 487                log::error!("Failed to get shell: {e}");
 488                "sh".to_owned()
 489            }
 490        }
 491    }
 492}
 493
// Heartbeat tuning. Only `HEARTBEAT_TIMEOUT` is used in the visible part of
// this file (as the ping timeout when a connection is first established); the
// other two are presumably consumed by the heartbeat loop — TODO confirm.
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

// Upper bound on reconnect attempts: `reconnect` moves to
// `State::ReconnectExhausted` once the attempt counter exceeds it.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 499
// Internal connection state machine of `SshRemoteClient`.
//
// The variants that carry `ssh_connection` and `delegate` are the ones from
// which `reconnect` may start (see `State::can_reconnect`).
enum State {
    // Initial state, set when the client entity is created.
    Connecting,
    // The initial ping succeeded; both background tasks are running.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    // Like `Connected`, but with a count of consecutive missed heartbeats
    // (see `State::heartbeat_missed` / `State::heartbeat_recovered`).
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    // A reconnect attempt is in flight.
    Reconnecting,
    // A reconnect attempt failed; `attempts` is the number made so far and
    // `error` is the failure of the latest one.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    // More than `MAX_RECONNECT_ATTEMPTS` attempts were made; no further
    // reconnects are allowed from here.
    ReconnectExhausted,
    // NOTE(review): set outside the visible part of this file; the name
    // suggests the remote server process is gone — confirm.
    ServerNotRunning,
}
 529
 530impl fmt::Display for State {
 531    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 532        match self {
 533            Self::Connecting => write!(f, "connecting"),
 534            Self::Connected { .. } => write!(f, "connected"),
 535            Self::Reconnecting => write!(f, "reconnecting"),
 536            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 537            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 538            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 539            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 540        }
 541    }
 542}
 543
 544impl State {
 545    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 546        match self {
 547            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 548            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 549            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 550            _ => None,
 551        }
 552    }
 553
 554    fn can_reconnect(&self) -> bool {
 555        match self {
 556            Self::Connected { .. }
 557            | Self::HeartbeatMissed { .. }
 558            | Self::ReconnectFailed { .. } => true,
 559            State::Connecting
 560            | State::Reconnecting
 561            | State::ReconnectExhausted
 562            | State::ServerNotRunning => false,
 563        }
 564    }
 565
 566    fn is_reconnect_failed(&self) -> bool {
 567        matches!(self, Self::ReconnectFailed { .. })
 568    }
 569
 570    fn is_reconnect_exhausted(&self) -> bool {
 571        matches!(self, Self::ReconnectExhausted { .. })
 572    }
 573
 574    fn is_server_not_running(&self) -> bool {
 575        matches!(self, Self::ServerNotRunning)
 576    }
 577
 578    fn is_reconnecting(&self) -> bool {
 579        matches!(self, Self::Reconnecting { .. })
 580    }
 581
 582    fn heartbeat_recovered(self) -> Self {
 583        match self {
 584            Self::HeartbeatMissed {
 585                ssh_connection,
 586                delegate,
 587                multiplex_task,
 588                heartbeat_task,
 589                ..
 590            } => Self::Connected {
 591                ssh_connection,
 592                delegate,
 593                multiplex_task,
 594                heartbeat_task,
 595            },
 596            _ => self,
 597        }
 598    }
 599
 600    fn heartbeat_missed(self) -> Self {
 601        match self {
 602            Self::Connected {
 603                ssh_connection,
 604                delegate,
 605                multiplex_task,
 606                heartbeat_task,
 607            } => Self::HeartbeatMissed {
 608                missed_heartbeats: 1,
 609                ssh_connection,
 610                delegate,
 611                multiplex_task,
 612                heartbeat_task,
 613            },
 614            Self::HeartbeatMissed {
 615                missed_heartbeats,
 616                ssh_connection,
 617                delegate,
 618                multiplex_task,
 619                heartbeat_task,
 620            } => Self::HeartbeatMissed {
 621                missed_heartbeats: missed_heartbeats + 1,
 622                ssh_connection,
 623                delegate,
 624                multiplex_task,
 625                heartbeat_task,
 626            },
 627            _ => self,
 628        }
 629    }
 630}
 631
 632/// The state of the ssh connection.
 633#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 634pub enum ConnectionState {
 635    Connecting,
 636    Connected,
 637    HeartbeatMissed,
 638    Reconnecting,
 639    Disconnected,
 640}
 641
 642impl From<&State> for ConnectionState {
 643    fn from(value: &State) -> Self {
 644        match value {
 645            State::Connecting => Self::Connecting,
 646            State::Connected { .. } => Self::Connected,
 647            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 648            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 649            State::ReconnectExhausted => Self::Disconnected,
 650            State::ServerNotRunning => Self::Disconnected,
 651        }
 652    }
 653}
 654
/// Client side of an SSH remote session.
pub struct SshRemoteClient {
    /// Channel-based protocol client used to exchange messages with the remote.
    client: Arc<ChannelClient>,
    /// String form of the `ConnectionIdentifier` this client was created with;
    /// passed to `start_proxy` when connecting.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection when it was established.
    path_style: PathStyle,
    /// Current connection state. `None` while a transition or shutdown has
    /// taken the state out of the slot.
    state: Arc<Mutex<Option<State>>>,
}
 662
/// Events emitted by [`SshRemoteClient`].
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): the emission site is outside the visible part of this
    /// file; the name indicates the connection was lost — confirm.
    Disconnected,
}

// Lets `SshRemoteClient` emit `SshRemoteEvent`s through gpui.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 669
// Identifies the socket on the remote server so that reconnects
// can re-join the same project.
pub enum ConnectionIdentifier {
    // Identifier drawn from the process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    // Identifier supplied by the caller; presumably a workspace id — TODO confirm.
    Workspace(i64),
}
 676
// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 678
 679impl ConnectionIdentifier {
 680    pub fn setup() -> Self {
 681        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 682    }
 683
 684    // This string gets used in a socket name, and so must be relatively short.
 685    // The total length of:
 686    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 687    // Must be less than about 100 characters
 688    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 689    // So our strings should be at most 20 characters or so.
 690    fn to_string(&self, cx: &App) -> String {
 691        let identifier_prefix = match ReleaseChannel::global(cx) {
 692            ReleaseChannel::Stable => "".to_string(),
 693            release_channel => format!("{}-", release_channel.dev_name()),
 694        };
 695        match self {
 696            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 697            Self::Workspace(workspace_id) => {
 698                format!("{identifier_prefix}workspace-{workspace_id}",)
 699            }
 700        }
 701    }
 702}
 703
 704impl SshRemoteClient {
    /// Establishes a new SSH remote session and returns the client entity.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(entity))` once the connection is up and the first ping succeeded,
    /// - `Ok(None)` if `cancellation` completes first,
    /// - `Err(_)` if connecting, starting the client, or the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Channels between the protocol client and the proxy process,
                // plus a bounded channel on which the proxy signals activity.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Obtain a connection from the global pool for these options.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                // The entity starts out in `Connecting`; it is switched to
                // `Connected` only after the ping below succeeds.
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                // Monitoring must be running before the ping so the I/O task is
                // being driven while we wait for the reply.
                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                // Heartbeats start only once the connection is known to work.
                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Race the connection attempt against cancellation.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 779
 780    pub fn proto_client_from_channels(
 781        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 782        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 783        cx: &App,
 784        name: &'static str,
 785    ) -> AnyProtoClient {
 786        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 787    }
 788
    /// Takes the current state and, if it is `Connected`, returns a future
    /// that shuts the session down.
    ///
    /// The future optionally sends `shutdown_request`, waits 50ms, and then
    /// drops the tasks, the connection, and the delegate in a fixed order.
    ///
    /// Returns `None` when there is no state or the state is not `Connected`.
    ///
    /// NOTE(review): the state is taken out of the slot before it is matched,
    /// so a non-`Connected` state is dropped here as well, without the
    /// shutdown request being sent — confirm this is intended.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
 827
    /// Tears down the current connection and kicks off an asynchronous reconnect attempt.
    ///
    /// Returns an error (and leaves the state untouched) when no state is set or when
    /// `State::can_reconnect` reports that the current state does not allow reconnecting.
    /// Otherwise the state is switched to `Reconnecting` (or `ReconnectExhausted` once
    /// `MAX_RECONNECT_ATTEMPTS` is exceeded) and `Ok(())` is returned immediately; the
    /// outcome of the attempt is applied later by a detached task.
    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            anyhow::bail!(error);
        }

        // Take ownership of the current state; `self.state` holds `None` until one of the
        // `set_state` calls below stores the follow-up state.
        let state = lock.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Coming from a live connection: drop its background tasks and start
                // counting attempts from zero.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            // A previous attempt failed: carry its attempt counter forward.
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            // NOTE(review): relies on `can_reconnect()` having rejected these states above.
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // `set_state` locks `self.state` itself, so the guard must be released first.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        drop(lock);

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        // This task performs the actual reconnect and resolves to the state that should
        // follow `Reconnecting`: either `Connected` or `ReconnectFailed`.
        let reconnect_task = cx.spawn(async move |this, cx| {
            // Early-returns a `ReconnectFailed` state carrying everything needed to retry.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            // Kill the old connection before asking the pool for a new one.
            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            // Fresh channels for the new proxy; the existing client is re-pointed at them below.
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                // The `true` argument is the `reconnect` flag of `start_proxy`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    // On failure the *old* connection handle is kept in the failed state.
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
            client.reconnect(incoming_rx, outgoing_tx, cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
            }
        });

        // Detached follow-up: apply the result of the attempt and retry if it failed.
        cx.spawn(async move |this, cx| {
            let new_state = reconnect_task.await;
            this.update(cx, |this, cx| {
                // Only apply the result if we are still in `Reconnecting`; if the state
                // changed in the meantime the result is discarded.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                if this.state_is(State::is_reconnect_failed) {
                    // Start the next attempt; the attempt counter travels in the
                    // `ReconnectFailed` state that was just stored.
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
1008
1009    fn heartbeat(
1010        this: WeakEntity<Self>,
1011        mut connection_activity_rx: mpsc::Receiver<()>,
1012        cx: &mut AsyncApp,
1013    ) -> Task<Result<()>> {
1014        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
1015            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
1016        };
1017
1018        cx.spawn(async move |cx| {
1019            let mut missed_heartbeats = 0;
1020
1021            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1022            futures::pin_mut!(keepalive_timer);
1023
1024            loop {
1025                select_biased! {
1026                    result = connection_activity_rx.next().fuse() => {
1027                        if result.is_none() {
1028                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1029                            return Ok(());
1030                        }
1031
1032                        if missed_heartbeats != 0 {
1033                            missed_heartbeats = 0;
1034                            let _ =this.update(cx, |this, cx| {
1035                                this.handle_heartbeat_result(missed_heartbeats, cx)
1036                            })?;
1037                        }
1038                    }
1039                    _ = keepalive_timer => {
1040                        log::debug!("Sending heartbeat to server...");
1041
1042                        let result = select_biased! {
1043                            _ = connection_activity_rx.next().fuse() => {
1044                                Ok(())
1045                            }
1046                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1047                                ping_result
1048                            }
1049                        };
1050
1051                        if result.is_err() {
1052                            missed_heartbeats += 1;
1053                            log::warn!(
1054                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1055                                HEARTBEAT_TIMEOUT,
1056                                missed_heartbeats,
1057                                MAX_MISSED_HEARTBEATS
1058                            );
1059                        } else if missed_heartbeats != 0 {
1060                            missed_heartbeats = 0;
1061                        } else {
1062                            continue;
1063                        }
1064
1065                        let result = this.update(cx, |this, cx| {
1066                            this.handle_heartbeat_result(missed_heartbeats, cx)
1067                        })?;
1068                        if result.is_break() {
1069                            return Ok(());
1070                        }
1071                    }
1072                }
1073
1074                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1075            }
1076        })
1077    }
1078
1079    fn handle_heartbeat_result(
1080        &mut self,
1081        missed_heartbeats: usize,
1082        cx: &mut Context<Self>,
1083    ) -> ControlFlow<()> {
1084        let state = self.state.lock().take().unwrap();
1085        let next_state = if missed_heartbeats > 0 {
1086            state.heartbeat_missed()
1087        } else {
1088            state.heartbeat_recovered()
1089        };
1090
1091        self.set_state(next_state, cx);
1092
1093        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1094            log::error!(
1095                "Missed last {} heartbeats. Reconnecting...",
1096                missed_heartbeats
1097            );
1098
1099            self.reconnect(cx)
1100                .context("failed to start reconnect process after missing heartbeats")
1101                .log_err();
1102            ControlFlow::Break(())
1103        } else {
1104            ControlFlow::Continue(())
1105        }
1106    }
1107
1108    fn monitor(
1109        this: WeakEntity<Self>,
1110        io_task: Task<Result<i32>>,
1111        cx: &AsyncApp,
1112    ) -> Task<Result<()>> {
1113        cx.spawn(async move |cx| {
1114            let result = io_task.await;
1115
1116            match result {
1117                Ok(exit_code) => {
1118                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1119                        match error {
1120                            ProxyLaunchError::ServerNotRunning => {
1121                                log::error!("failed to reconnect because server is not running");
1122                                this.update(cx, |this, cx| {
1123                                    this.set_state(State::ServerNotRunning, cx);
1124                                })?;
1125                            }
1126                        }
1127                    } else if exit_code > 0 {
1128                        log::error!("proxy process terminated unexpectedly");
1129                        this.update(cx, |this, cx| {
1130                            this.reconnect(cx).ok();
1131                        })?;
1132                    }
1133                }
1134                Err(error) => {
1135                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1136                    this.update(cx, |this, cx| {
1137                        this.reconnect(cx).ok();
1138                    })?;
1139                }
1140            }
1141
1142            Ok(())
1143        })
1144    }
1145
1146    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1147        self.state.lock().as_ref().is_some_and(check)
1148    }
1149
1150    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1151        let mut lock = self.state.lock();
1152        let new_state = lock.as_ref().and_then(map);
1153
1154        if let Some(new_state) = new_state {
1155            lock.replace(new_state);
1156            cx.notify();
1157        }
1158    }
1159
1160    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1161        log::info!("setting state to '{}'", &state);
1162
1163        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1164        let is_server_not_running = state.is_server_not_running();
1165        self.state.lock().replace(state);
1166
1167        if is_reconnect_exhausted || is_server_not_running {
1168            cx.emit(SshRemoteEvent::Disconnected);
1169        }
1170        cx.notify();
1171    }
1172
1173    pub fn ssh_info(&self) -> Option<SshInfo> {
1174        self.state
1175            .lock()
1176            .as_ref()
1177            .and_then(|state| state.ssh_connection())
1178            .map(|ssh_connection| SshInfo {
1179                args: ssh_connection.ssh_args(),
1180                path_style: ssh_connection.path_style(),
1181                shell: ssh_connection.shell(),
1182            })
1183    }
1184
1185    pub fn upload_directory(
1186        &self,
1187        src_path: PathBuf,
1188        dest_path: RemotePathBuf,
1189        cx: &App,
1190    ) -> Task<Result<()>> {
1191        let state = self.state.lock();
1192        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1193            return Task::ready(Err(anyhow!("no ssh connection")));
1194        };
1195        connection.upload_directory(src_path, dest_path, cx)
1196    }
1197
    /// Returns a type-erased handle to this client's underlying channel client.
    pub fn proto_client(&self) -> AnyProtoClient {
        self.client.clone().into()
    }
1201
    /// Returns the connection string derived from this client's connection options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
1205
    /// Returns a copy of the options this client was created with.
    pub fn connection_options(&self) -> SshConnectionOptions {
        self.connection_options.clone()
    }
1209
1210    pub fn connection_state(&self) -> ConnectionState {
1211        self.state
1212            .lock()
1213            .as_ref()
1214            .map(ConnectionState::from)
1215            .unwrap_or(ConnectionState::Disconnected)
1216    }
1217
    /// Returns `true` when `connection_state()` is `ConnectionState::Disconnected`.
    pub fn is_disconnected(&self) -> bool {
        self.connection_state() == ConnectionState::Disconnected
    }
1221
    /// Returns the path style stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
1225
1226    #[cfg(any(test, feature = "test-support"))]
1227    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1228        let opts = self.connection_options();
1229        client_cx.spawn(async move |cx| {
1230            let connection = cx
1231                .update_global(|c: &mut ConnectionPool, _| {
1232                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1233                        c.clone()
1234                    } else {
1235                        panic!("missing test connection")
1236                    }
1237                })
1238                .unwrap()
1239                .await
1240                .unwrap();
1241
1242            connection.simulate_disconnect(cx);
1243        })
1244    }
1245
1246    #[cfg(any(test, feature = "test-support"))]
1247    pub fn fake_server(
1248        client_cx: &mut gpui::TestAppContext,
1249        server_cx: &mut gpui::TestAppContext,
1250    ) -> (SshConnectionOptions, AnyProtoClient) {
1251        let port = client_cx
1252            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1253        let opts = SshConnectionOptions {
1254            host: "<fake>".to_string(),
1255            port: Some(port),
1256            ..Default::default()
1257        };
1258        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1259        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1260        let server_client =
1261            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1262        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1263            connection_options: opts.clone(),
1264            server_cx: fake::SendableCx::new(server_cx),
1265            server_channel: server_client.clone(),
1266        });
1267
1268        client_cx.update(|cx| {
1269            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1270                c.connections.insert(
1271                    opts.clone(),
1272                    ConnectionPoolEntry::Connecting(
1273                        cx.background_spawn({
1274                            let connection = connection.clone();
1275                            async move { Ok(connection.clone()) }
1276                        })
1277                        .shared(),
1278                    ),
1279                );
1280            })
1281        });
1282
1283        (opts, server_client.into())
1284    }
1285
1286    #[cfg(any(test, feature = "test-support"))]
1287    pub async fn fake_client(
1288        opts: SshConnectionOptions,
1289        client_cx: &mut gpui::TestAppContext,
1290    ) -> Entity<Self> {
1291        let (_tx, rx) = oneshot::channel();
1292        client_cx
1293            .update(|cx| {
1294                Self::new(
1295                    ConnectionIdentifier::setup(),
1296                    opts,
1297                    rx,
1298                    Arc::new(fake::Delegate),
1299                    cx,
1300                )
1301            })
1302            .await
1303            .unwrap()
1304            .unwrap()
1305    }
1306}
1307
/// An entry of the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in progress; the shared task can be awaited by every caller
    /// asking for the same options.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is kept, so the pool itself
    /// does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1312
/// Global registry of remote connections, keyed by the options used to create them.
#[derive(Default)]
struct ConnectionPool {
    /// Pending or established connection for each set of connection options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Marker impl so the pool can be stored and accessed as a gpui global.
impl Global for ConnectionPool {}
1319
1320impl ConnectionPool {
1321    pub fn connect(
1322        &mut self,
1323        opts: SshConnectionOptions,
1324        delegate: &Arc<dyn SshClientDelegate>,
1325        cx: &mut App,
1326    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1327        let connection = self.connections.get(&opts);
1328        match connection {
1329            Some(ConnectionPoolEntry::Connecting(task)) => {
1330                let delegate = delegate.clone();
1331                cx.spawn(async move |cx| {
1332                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1333                })
1334                .detach();
1335                return task.clone();
1336            }
1337            Some(ConnectionPoolEntry::Connected(ssh)) => {
1338                if let Some(ssh) = ssh.upgrade()
1339                    && !ssh.has_been_killed()
1340                {
1341                    return Task::ready(Ok(ssh)).shared();
1342                }
1343                self.connections.remove(&opts);
1344            }
1345            None => {}
1346        }
1347
1348        let task = cx
1349            .spawn({
1350                let opts = opts.clone();
1351                let delegate = delegate.clone();
1352                async move |cx| {
1353                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1354                        .await
1355                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1356
1357                    cx.update_global(|pool: &mut Self, _| {
1358                        debug_assert!(matches!(
1359                            pool.connections.get(&opts),
1360                            Some(ConnectionPoolEntry::Connecting(_))
1361                        ));
1362                        match connection {
1363                            Ok(connection) => {
1364                                pool.connections.insert(
1365                                    opts.clone(),
1366                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1367                                );
1368                                Ok(connection)
1369                            }
1370                            Err(error) => {
1371                                pool.connections.remove(&opts);
1372                                Err(Arc::new(error))
1373                            }
1374                        }
1375                    })?
1376                }
1377            })
1378            .shared();
1379
1380        self.connections
1381            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1382        task
1383    }
1384}
1385
1386impl From<SshRemoteClient> for AnyProtoClient {
1387    fn from(client: SshRemoteClient) -> Self {
1388        AnyProtoClient::new(client.client)
1389    }
1390}
1391
/// Abstraction over an established connection to a remote host, implemented by the real
/// ssh connection and by the fake connection used in tests.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy that carries protocol messages for this connection.
    ///
    /// Messages received from the remote are sent to `incoming_tx`, and messages taken
    /// from `outgoing_rx` are sent to the remote. `reconnect` indicates that the proxy is
    /// being started as part of a reconnect. Callers treat the `i32` the returned task
    /// resolves to as the proxy's exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns whether the connection has been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further ssh commands over this connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
    /// reuse the already established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style of the remote.
    fn path_style(&self) -> PathStyle;
    /// Returns the shell of the remote.
    fn shell(&self) -> String;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1422
/// A connection to a remote host backed by a long-running master `ssh` process.
struct SshRemoteConnection {
    /// Used to run further ssh commands for this connection.
    socket: SshSocket,
    /// The master ssh process. `kill` takes it out of the mutex, so `None` means the
    /// connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote; `start_proxy` fails while this is `None`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform's OS.
    ssh_path_style: PathStyle,
    /// Shell reported by the remote.
    ssh_shell: String,
    /// Temporary directory created for this session; kept here so that it lives as long
    /// as the connection (on non-Windows platforms it contains the ssh control socket).
    _temp_dir: TempDir,
}
1432
1433#[async_trait(?Send)]
1434impl RemoteConnection for SshRemoteConnection {
1435    async fn kill(&self) -> Result<()> {
1436        let Some(mut process) = self.master_process.lock().take() else {
1437            return Ok(());
1438        };
1439        process.kill().ok();
1440        process.status().await?;
1441        Ok(())
1442    }
1443
1444    fn has_been_killed(&self) -> bool {
1445        self.master_process.lock().is_none()
1446    }
1447
1448    fn ssh_args(&self) -> SshArgs {
1449        self.socket.ssh_args()
1450    }
1451
1452    fn connection_options(&self) -> SshConnectionOptions {
1453        self.socket.connection_options.clone()
1454    }
1455
1456    fn shell(&self) -> String {
1457        self.ssh_shell.clone()
1458    }
1459
1460    fn upload_directory(
1461        &self,
1462        src_path: PathBuf,
1463        dest_path: RemotePathBuf,
1464        cx: &App,
1465    ) -> Task<Result<()>> {
1466        let mut command = util::command::new_smol_command("scp");
1467        let output = self
1468            .socket
1469            .ssh_options(&mut command)
1470            .args(
1471                self.socket
1472                    .connection_options
1473                    .port
1474                    .map(|port| vec!["-P".to_string(), port.to_string()])
1475                    .unwrap_or_default(),
1476            )
1477            .arg("-C")
1478            .arg("-r")
1479            .arg(&src_path)
1480            .arg(format!(
1481                "{}:{}",
1482                self.socket.connection_options.scp_url(),
1483                dest_path
1484            ))
1485            .output();
1486
1487        cx.background_spawn(async move {
1488            let output = output.await?;
1489
1490            anyhow::ensure!(
1491                output.status.success(),
1492                "failed to upload directory {} -> {}: {}",
1493                src_path.display(),
1494                dest_path.to_string(),
1495                String::from_utf8_lossy(&output.stderr)
1496            );
1497
1498            Ok(())
1499        })
1500    }
1501
1502    fn start_proxy(
1503        &self,
1504        unique_identifier: String,
1505        reconnect: bool,
1506        incoming_tx: UnboundedSender<Envelope>,
1507        outgoing_rx: UnboundedReceiver<Envelope>,
1508        connection_activity_tx: Sender<()>,
1509        delegate: Arc<dyn SshClientDelegate>,
1510        cx: &mut AsyncApp,
1511    ) -> Task<Result<i32>> {
1512        delegate.set_status(Some("Starting proxy"), cx);
1513
1514        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1515            return Task::ready(Err(anyhow!("Remote binary path not set")));
1516        };
1517
1518        let mut start_proxy_command = shell_script!(
1519            "exec {binary_path} proxy --identifier {identifier}",
1520            binary_path = &remote_binary_path.to_string(),
1521            identifier = &unique_identifier,
1522        );
1523
1524        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1525            if let Some(value) = std::env::var(env_var).ok() {
1526                start_proxy_command = format!(
1527                    "{}={} {} ",
1528                    env_var,
1529                    shlex::try_quote(&value).unwrap(),
1530                    start_proxy_command,
1531                );
1532            }
1533        }
1534
1535        if reconnect {
1536            start_proxy_command.push_str(" --reconnect");
1537        }
1538
1539        let ssh_proxy_process = match self
1540            .socket
1541            .ssh_command("sh", &["-c", &start_proxy_command])
1542            // IMPORTANT: we kill this process when we drop the task that uses it.
1543            .kill_on_drop(true)
1544            .spawn()
1545        {
1546            Ok(process) => process,
1547            Err(error) => {
1548                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1549            }
1550        };
1551
1552        Self::multiplex(
1553            ssh_proxy_process,
1554            incoming_tx,
1555            outgoing_rx,
1556            connection_activity_tx,
1557            cx,
1558        )
1559    }
1560
1561    fn path_style(&self) -> PathStyle {
1562        self.ssh_path_style
1563    }
1564}
1565
1566impl SshRemoteConnection {
    /// Establishes a new ssh connection to the host described by `connection_options`.
    ///
    /// Spawns the master `ssh` process (prompting for credentials through the delegate via
    /// an askpass session), waits for authentication to finish, queries the remote
    /// platform and shell, and finally makes sure the server binary is available on the
    /// remote.
    ///
    /// Fails when the user cancels the password prompt, the prompt times out, the master
    /// process exits during connection setup, or any of the later setup steps fail.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts issued by ssh are forwarded to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument added below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Race the askpass session against ssh closing its stdout; the askpass outcome is
        // checked first when both are ready. The result of reading stdout is ignored.
        // NOTE(review): the `bail!` calls below appear to return from this function
        // directly, so `result` looks like it can only ever be `Ok` and the
        // "Failed to connect to host" context below would never be attached — confirm.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, the connection failed: report what it
        // wrote to stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        // Every OS other than "windows" is treated as using POSIX paths.
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };
        let ssh_shell = socket.shell().await;

        // `remote_binary_path` is filled in below, once the connection value exists and
        // can be used to ensure the server binary is present.
        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
            ssh_shell,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1699
1700    fn multiplex(
1701        mut ssh_proxy_process: Child,
1702        incoming_tx: UnboundedSender<Envelope>,
1703        mut outgoing_rx: UnboundedReceiver<Envelope>,
1704        mut connection_activity_tx: Sender<()>,
1705        cx: &AsyncApp,
1706    ) -> Task<Result<i32>> {
1707        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1708        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1709        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1710
1711        let mut stdin_buffer = Vec::new();
1712        let mut stdout_buffer = Vec::new();
1713        let mut stderr_buffer = Vec::new();
1714        let mut stderr_offset = 0;
1715
1716        let stdin_task = cx.background_spawn(async move {
1717            while let Some(outgoing) = outgoing_rx.next().await {
1718                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1719            }
1720            anyhow::Ok(())
1721        });
1722
1723        let stdout_task = cx.background_spawn({
1724            let mut connection_activity_tx = connection_activity_tx.clone();
1725            async move {
1726                loop {
1727                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1728                    let len = child_stdout.read(&mut stdout_buffer).await?;
1729
1730                    if len == 0 {
1731                        return anyhow::Ok(());
1732                    }
1733
1734                    if len < MESSAGE_LEN_SIZE {
1735                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1736                    }
1737
1738                    let message_len = message_len_from_buffer(&stdout_buffer);
1739                    let envelope =
1740                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1741                            .await?;
1742                    connection_activity_tx.try_send(()).ok();
1743                    incoming_tx.unbounded_send(envelope).ok();
1744                }
1745            }
1746        });
1747
1748        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1749            loop {
1750                stderr_buffer.resize(stderr_offset + 1024, 0);
1751
1752                let len = child_stderr
1753                    .read(&mut stderr_buffer[stderr_offset..])
1754                    .await?;
1755                if len == 0 {
1756                    return anyhow::Ok(());
1757                }
1758
1759                stderr_offset += len;
1760                let mut start_ix = 0;
1761                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1762                    .iter()
1763                    .position(|b| b == &b'\n')
1764                {
1765                    let line_ix = start_ix + ix;
1766                    let content = &stderr_buffer[start_ix..line_ix];
1767                    start_ix = line_ix + 1;
1768                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1769                        record.log(log::logger())
1770                    } else {
1771                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1772                    }
1773                }
1774                stderr_buffer.drain(0..start_ix);
1775                stderr_offset -= start_ix;
1776
1777                connection_activity_tx.try_send(()).ok();
1778            }
1779        });
1780
1781        cx.background_spawn(async move {
1782            let result = futures::select! {
1783                result = stdin_task.fuse() => {
1784                    result.context("stdin")
1785                }
1786                result = stdout_task.fuse() => {
1787                    result.context("stdout")
1788                }
1789                result = stderr_task.fuse() => {
1790                    result.context("stderr")
1791                }
1792            };
1793
1794            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1795            match result {
1796                Ok(_) => Ok(status),
1797                Err(error) => Err(error),
1798            }
1799        })
1800    }
1801
1802    #[allow(unused)]
1803    async fn ensure_server_binary(
1804        &self,
1805        delegate: &Arc<dyn SshClientDelegate>,
1806        release_channel: ReleaseChannel,
1807        version: SemanticVersion,
1808        commit: Option<AppCommitSha>,
1809        cx: &mut AsyncApp,
1810    ) -> Result<RemotePathBuf> {
1811        let version_str = match release_channel {
1812            ReleaseChannel::Nightly => {
1813                let commit = commit.map(|s| s.full()).unwrap_or_default();
1814                format!("{}-{}", version, commit)
1815            }
1816            ReleaseChannel::Dev => "build".to_string(),
1817            _ => version.to_string(),
1818        };
1819        let binary_name = format!(
1820            "zed-remote-server-{}-{}",
1821            release_channel.dev_name(),
1822            version_str
1823        );
1824        let dst_path = RemotePathBuf::new(
1825            paths::remote_server_dir_relative().join(binary_name),
1826            self.ssh_path_style,
1827        );
1828
1829        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1830        #[cfg(debug_assertions)]
1831        if let Some(build_remote_server) = build_remote_server {
1832            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1833            let tmp_path = RemotePathBuf::new(
1834                paths::remote_server_dir_relative().join(format!(
1835                    "download-{}-{}",
1836                    std::process::id(),
1837                    src_path.file_name().unwrap().to_string_lossy()
1838                )),
1839                self.ssh_path_style,
1840            );
1841            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1842                .await?;
1843            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1844                .await?;
1845            return Ok(dst_path);
1846        }
1847
1848        if self
1849            .socket
1850            .run_command(&dst_path.to_string(), &["version"])
1851            .await
1852            .is_ok()
1853        {
1854            return Ok(dst_path);
1855        }
1856
1857        let wanted_version = cx.update(|cx| match release_channel {
1858            ReleaseChannel::Nightly => Ok(None),
1859            ReleaseChannel::Dev => {
1860                anyhow::bail!(
1861                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1862                    dst_path
1863                )
1864            }
1865            _ => Ok(Some(AppVersion::global(cx))),
1866        })??;
1867
1868        let tmp_path_gz = RemotePathBuf::new(
1869            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
1870            self.ssh_path_style,
1871        );
1872        if !self.socket.connection_options.upload_binary_over_ssh
1873            && let Some((url, body)) = delegate
1874                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1875                .await?
1876        {
1877            match self
1878                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1879                .await
1880            {
1881                Ok(_) => {
1882                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1883                        .await?;
1884                    return Ok(dst_path);
1885                }
1886                Err(e) => {
1887                    log::error!(
1888                        "Failed to download binary on server, attempting to upload server: {}",
1889                        e
1890                    )
1891                }
1892            }
1893        }
1894
1895        let src_path = delegate
1896            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1897            .await?;
1898        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1899            .await?;
1900        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1901            .await?;
1902        Ok(dst_path)
1903    }
1904
1905    async fn download_binary_on_server(
1906        &self,
1907        url: &str,
1908        body: &str,
1909        tmp_path_gz: &RemotePathBuf,
1910        delegate: &Arc<dyn SshClientDelegate>,
1911        cx: &mut AsyncApp,
1912    ) -> Result<()> {
1913        if let Some(parent) = tmp_path_gz.parent() {
1914            self.socket
1915                .run_command(
1916                    "sh",
1917                    &[
1918                        "-c",
1919                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1920                    ],
1921                )
1922                .await?;
1923        }
1924
1925        delegate.set_status(Some("Downloading remote development server on host"), cx);
1926
1927        match self
1928            .socket
1929            .run_command(
1930                "curl",
1931                &[
1932                    "-f",
1933                    "-L",
1934                    "-X",
1935                    "GET",
1936                    "-H",
1937                    "Content-Type: application/json",
1938                    "-d",
1939                    body,
1940                    url,
1941                    "-o",
1942                    &tmp_path_gz.to_string(),
1943                ],
1944            )
1945            .await
1946        {
1947            Ok(_) => {}
1948            Err(e) => {
1949                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1950                    return Err(e);
1951                }
1952
1953                match self
1954                    .socket
1955                    .run_command(
1956                        "wget",
1957                        &[
1958                            "--method=GET",
1959                            "--header=Content-Type: application/json",
1960                            "--body-data",
1961                            body,
1962                            url,
1963                            "-O",
1964                            &tmp_path_gz.to_string(),
1965                        ],
1966                    )
1967                    .await
1968                {
1969                    Ok(_) => {}
1970                    Err(e) => {
1971                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1972                            return Err(e);
1973                        } else {
1974                            anyhow::bail!("Neither curl nor wget is available");
1975                        }
1976                    }
1977                }
1978            }
1979        }
1980
1981        Ok(())
1982    }
1983
1984    async fn upload_local_server_binary(
1985        &self,
1986        src_path: &Path,
1987        tmp_path_gz: &RemotePathBuf,
1988        delegate: &Arc<dyn SshClientDelegate>,
1989        cx: &mut AsyncApp,
1990    ) -> Result<()> {
1991        if let Some(parent) = tmp_path_gz.parent() {
1992            self.socket
1993                .run_command(
1994                    "sh",
1995                    &[
1996                        "-c",
1997                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1998                    ],
1999                )
2000                .await?;
2001        }
2002
2003        let src_stat = fs::metadata(&src_path).await?;
2004        let size = src_stat.len();
2005
2006        let t0 = Instant::now();
2007        delegate.set_status(Some("Uploading remote development server"), cx);
2008        log::info!(
2009            "uploading remote development server to {:?} ({}kb)",
2010            tmp_path_gz,
2011            size / 1024
2012        );
2013        self.upload_file(src_path, tmp_path_gz)
2014            .await
2015            .context("failed to upload server binary")?;
2016        log::info!("uploaded remote development server in {:?}", t0.elapsed());
2017        Ok(())
2018    }
2019
2020    async fn extract_server_binary(
2021        &self,
2022        dst_path: &RemotePathBuf,
2023        tmp_path: &RemotePathBuf,
2024        delegate: &Arc<dyn SshClientDelegate>,
2025        cx: &mut AsyncApp,
2026    ) -> Result<()> {
2027        delegate.set_status(Some("Extracting remote development server"), cx);
2028        let server_mode = 0o755;
2029
2030        let orig_tmp_path = tmp_path.to_string();
2031        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2032            shell_script!(
2033                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2034                server_mode = &format!("{:o}", server_mode),
2035                dst_path = &dst_path.to_string(),
2036            )
2037        } else {
2038            shell_script!(
2039                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2040                server_mode = &format!("{:o}", server_mode),
2041                dst_path = &dst_path.to_string()
2042            )
2043        };
2044        self.socket.run_command("sh", &["-c", &script]).await?;
2045        Ok(())
2046    }
2047
2048    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2049        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2050        let mut command = util::command::new_smol_command("scp");
2051        let output = self
2052            .socket
2053            .ssh_options(&mut command)
2054            .args(
2055                self.socket
2056                    .connection_options
2057                    .port
2058                    .map(|port| vec!["-P".to_string(), port.to_string()])
2059                    .unwrap_or_default(),
2060            )
2061            .arg(src_path)
2062            .arg(format!(
2063                "{}:{}",
2064                self.socket.connection_options.scp_url(),
2065                dest_path
2066            ))
2067            .output()
2068            .await?;
2069
2070        anyhow::ensure!(
2071            output.status.success(),
2072            "failed to upload file {} -> {}: {}",
2073            src_path.display(),
2074            dest_path.to_string(),
2075            String::from_utf8_lossy(&output.stderr)
2076        );
2077        Ok(())
2078    }
2079
    /// Builds the `remote_server` package from source for the remote platform and
    /// returns the path of the resulting artifact. Only compiled in debug builds.
    ///
    /// `build_remote_server` is a free-form option string (the caller passes the
    /// value of `ZED_BUILD_REMOTE_SERVER`); it is only inspected for substrings:
    /// - `nomusl`: on Linux, target `unknown-linux-gnu` instead of the default
    ///   `unknown-linux-musl` (which also adds `+crt-static` to `RUSTFLAGS`);
    /// - `mold`: link with `-fuse-ld=mold`;
    /// - `cross`: when the remote platform differs from the local one, build with
    ///   `cross` instead of `cargo zigbuild`;
    /// - `nocompress`: skip compressing the built binary.
    ///
    /// Returns the absolute path of the `.gz` archive when compressing, otherwise
    /// the (relative) path of the binary under `target/remote_server`.
    ///
    /// # Errors
    ///
    /// Fails when the remote OS is neither `linux` nor `macos`, when a required
    /// tool (`zig`, and `7z.exe` on Windows) is missing, or when any spawned
    /// command exits unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion and turns a non-zero exit status into an
        // error. The child is killed if the future is dropped.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // Target triple for the remote platform: `<arch>-<vendor/os/abi suffix>`.
        let use_musl = !build_remote_server.contains("nomusl");
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS (if readable) and append our own flags.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        // Pick a build strategy: plain cargo when the remote platform matches the
        // local one, `cross` when requested, `cargo zigbuild` otherwise.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    )
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // `which::which` is looked up on a background task rather than inline.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "zigbuild",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        };
        // All three strategies write the binary to the same location.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove a stale archive from a previous build before recreating it.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            // The archive sits next to the binary, with a `.gz` extension.
            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2294}
2295
/// Pending requests, keyed by the id of the request message.
///
/// Each entry holds the sender used to deliver the response envelope to the
/// waiting requester. The envelope is delivered together with a second
/// `oneshot::Sender<()>`; the message loop waits on the matching receiver before
/// it handles the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2297
/// Client side of the envelope channel to the remote server: assigns message
/// ids, tracks in-flight requests, and dispatches incoming messages.
struct ChannelClient {
    /// Source of ids for outgoing envelopes (incremented for each one).
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; swapped for a fresh one on reconnect.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Outgoing envelopes kept for replay. Entries are dropped once the peer
    /// acknowledges them and are re-sent when a flush is requested.
    /// NOTE(review): presumably filled by the buffered send path — confirm.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests still waiting for a response.
    response_channels: ResponseChannels,
    /// Handlers consulted for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope (flush requests excluded).
    max_received: AtomicU32,
    /// Label used as a prefix in log messages.
    name: &'static str,
    /// The task running the incoming-message loop; replaced on reconnect.
    task: Mutex<Task<Result<()>>>,
}
2308
2309impl ChannelClient {
2310    fn new(
2311        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2312        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2313        cx: &App,
2314        name: &'static str,
2315    ) -> Arc<Self> {
2316        Arc::new_cyclic(|this| Self {
2317            outgoing_tx: Mutex::new(outgoing_tx),
2318            next_message_id: AtomicU32::new(0),
2319            max_received: AtomicU32::new(0),
2320            response_channels: ResponseChannels::default(),
2321            message_handlers: Default::default(),
2322            buffer: Mutex::new(VecDeque::new()),
2323            name,
2324            task: Mutex::new(Self::start_handling_messages(
2325                this.clone(),
2326                incoming_rx,
2327                &cx.to_async(),
2328            )),
2329        })
2330    }
2331
    /// Spawns the loop that processes every envelope arriving on `incoming_rx`.
    ///
    /// For each envelope, in order:
    /// 1. an `ack_id` prunes acknowledged envelopes from the replay buffer;
    /// 2. a `FlushBufferedMessages` payload re-sends the whole replay buffer,
    ///    answers with an `Ack`, and skips the remaining steps;
    /// 3. the envelope's id is recorded in `max_received`;
    /// 4. a response (`responding_to` set) is handed to the waiting requester;
    ///    otherwise the envelope is dispatched to a registered handler, or an
    ///    error response is sent when none is registered.
    ///
    /// The loop ends when the channel closes or when the client (held weakly via
    /// `this`) has been dropped.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncApp,
    ) -> Task<Result<()>> {
        cx.spawn(async move |cx| {
            // Fixed peer id attached to every typed envelope built below.
            let peer_id = PeerId { owner_id: 0, id: 0 };
            while let Some(incoming) = incoming_rx.next().await {
                // Stop once the owning client is gone.
                let Some(this) = this.upgrade() else {
                    return anyhow::Ok(());
                };
                // Everything up to and including `ack_id` no longer needs replaying.
                if let Some(ack_id) = incoming.ack_id {
                    let mut buffer = this.buffer.lock();
                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                        buffer.pop_front();
                    }
                }
                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
                {
                    log::debug!(
                        "{}:ssh message received. name:FlushBufferedMessages",
                        this.name
                    );
                    // Re-send everything still buffered; the inner scope releases
                    // the buffer lock before the ack is sent.
                    {
                        let buffer = this.buffer.lock();
                        for envelope in buffer.iter() {
                            this.outgoing_tx
                                .lock()
                                .unbounded_send(envelope.clone())
                                .ok();
                        }
                    }
                    // Acknowledge the flush request with a freshly numbered envelope.
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    continue;
                }

                this.max_received.store(incoming.id, SeqCst);

                if let Some(request_id) = incoming.responding_to {
                    let request_id = MessageId(request_id);
                    let sender = this.response_channels.lock().remove(&request_id);
                    if let Some(sender) = sender {
                        let (tx, rx) = oneshot::channel();
                        // Responses without a payload are not delivered; `sender`
                        // and `tx` are then simply dropped.
                        if incoming.payload.is_some() {
                            sender.send((incoming, tx)).ok();
                        }
                        // Wait until the requester signals (or drops) `tx` before
                        // moving on to the next incoming message.
                        rx.await.ok();
                    }
                } else if let Some(envelope) =
                    build_typed_envelope(peer_id, Instant::now(), incoming)
                {
                    let type_name = envelope.payload_type_name();
                    let message_id = envelope.message_id();
                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
                        &this.message_handlers,
                        envelope,
                        this.clone().into(),
                        cx.clone(),
                    ) {
                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
                        // Run the handler detached so the loop is not blocked on it.
                        cx.foreground_executor()
                            .spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!(
                                            "{}:ssh message handled. name:{type_name}",
                                            this.name
                                        );
                                    }
                                    Err(error) => {
                                        // The error chain is flattened onto a single
                                        // line by joining its lines with spaces.
                                        log::error!(
                                            "{}:error handling message. type:{}, error:{}",
                                            this.name,
                                            type_name,
                                            format!("{error:#}").lines().fold(
                                                String::new(),
                                                |mut message, line| {
                                                    if !message.is_empty() {
                                                        message.push(' ');
                                                    }
                                                    message.push_str(line);
                                                    message
                                                }
                                            )
                                        );
                                    }
                                }
                            })
                            .detach()
                    } else {
                        // No handler: tell the peer so that it is not left waiting.
                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
                            message_id,
                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
                        ) {
                            log::error!(
                                "{}:error sending error response for {type_name}:{e:#}",
                                this.name
                            );
                        }
                    }
                }
            }
            anyhow::Ok(())
        })
    }
2440
2441    fn reconnect(
2442        self: &Arc<Self>,
2443        incoming_rx: UnboundedReceiver<Envelope>,
2444        outgoing_tx: UnboundedSender<Envelope>,
2445        cx: &AsyncApp,
2446    ) {
2447        *self.outgoing_tx.lock() = outgoing_tx;
2448        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2449    }
2450
2451    fn request<T: RequestMessage>(
2452        &self,
2453        payload: T,
2454    ) -> impl 'static + Future<Output = Result<T::Response>> {
2455        self.request_internal(payload, true)
2456    }
2457
2458    fn request_internal<T: RequestMessage>(
2459        &self,
2460        payload: T,
2461        use_buffer: bool,
2462    ) -> impl 'static + Future<Output = Result<T::Response>> {
2463        log::debug!("ssh request start. name:{}", T::NAME);
2464        let response =
2465            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2466        async move {
2467            let response = response.await?;
2468            log::debug!("ssh request finish. name:{}", T::NAME);
2469            T::Response::from_envelope(response).context("received a response of the wrong type")
2470        }
2471    }
2472
2473    async fn resync(&self, timeout: Duration) -> Result<()> {
2474        smol::future::or(
2475            async {
2476                self.request_internal(proto::FlushBufferedMessages {}, false)
2477                    .await?;
2478
2479                for envelope in self.buffer.lock().iter() {
2480                    self.outgoing_tx
2481                        .lock()
2482                        .unbounded_send(envelope.clone())
2483                        .ok();
2484                }
2485                Ok(())
2486            },
2487            async {
2488                smol::Timer::after(timeout).await;
2489                anyhow::bail!("Timed out resyncing remote client")
2490            },
2491        )
2492        .await
2493    }
2494
2495    async fn ping(&self, timeout: Duration) -> Result<()> {
2496        smol::future::or(
2497            async {
2498                self.request(proto::Ping {}).await?;
2499                Ok(())
2500            },
2501            async {
2502                smol::Timer::after(timeout).await;
2503                anyhow::bail!("Timed out pinging remote client")
2504            },
2505        )
2506        .await
2507    }
2508
2509    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2510        log::debug!("ssh send name:{}", T::NAME);
2511        self.send_dynamic(payload.into_envelope(0, None, None))
2512    }
2513
2514    fn request_dynamic(
2515        &self,
2516        mut envelope: proto::Envelope,
2517        type_name: &'static str,
2518        use_buffer: bool,
2519    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2520        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2521        let (tx, rx) = oneshot::channel();
2522        let mut response_channels_lock = self.response_channels.lock();
2523        response_channels_lock.insert(MessageId(envelope.id), tx);
2524        drop(response_channels_lock);
2525
2526        let result = if use_buffer {
2527            self.send_buffered(envelope)
2528        } else {
2529            self.send_unbuffered(envelope)
2530        };
2531        async move {
2532            if let Err(error) = &result {
2533                log::error!("failed to send message: {error}");
2534                anyhow::bail!("failed to send message: {error}");
2535            }
2536
2537            let response = rx.await.context("connection lost")?.0;
2538            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2539                return Err(RpcError::from_proto(error, type_name));
2540            }
2541            Ok(response)
2542        }
2543    }
2544
2545    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2546        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2547        self.send_buffered(envelope)
2548    }
2549
2550    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2551        envelope.ack_id = Some(self.max_received.load(SeqCst));
2552        self.buffer.lock().push_back(envelope.clone());
2553        // ignore errors on send (happen while we're reconnecting)
2554        // assume that the global "disconnected" overlay is sufficient.
2555        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2556        Ok(())
2557    }
2558
2559    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2560        envelope.ack_id = Some(self.max_received.load(SeqCst));
2561        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2562        Ok(())
2563    }
2564}
2565
2566impl ProtoClient for ChannelClient {
2567    fn request(
2568        &self,
2569        envelope: proto::Envelope,
2570        request_type: &'static str,
2571    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2572        self.request_dynamic(envelope, request_type, true).boxed()
2573    }
2574
2575    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2576        self.send_dynamic(envelope)
2577    }
2578
2579    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2580        self.send_dynamic(envelope)
2581    }
2582
2583    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2584        &self.message_handlers
2585    }
2586
2587    fn is_via_collab(&self) -> bool {
2588        false
2589    }
2590}
2591
2592#[cfg(any(test, feature = "test-support"))]
2593mod fake {
2594    use std::{path::PathBuf, sync::Arc};
2595
2596    use anyhow::Result;
2597    use async_trait::async_trait;
2598    use futures::{
2599        FutureExt, SinkExt, StreamExt,
2600        channel::{
2601            mpsc::{self, Sender},
2602            oneshot,
2603        },
2604        select_biased,
2605    };
2606    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
2607    use release_channel::ReleaseChannel;
2608    use rpc::proto::Envelope;
2609    use util::paths::{PathStyle, RemotePathBuf};
2610
2611    use super::{
2612        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
2613        SshPlatform,
2614    };
2615
    /// Test-only stand-in for a remote connection; see its `RemoteConnection` impl in this
    /// module.
    pub(super) struct FakeRemoteConnection {
        /// Options handed back (cloned) by `connection_options()`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side channel client that `start_proxy` and `simulate_disconnect` rewire
        /// through `ChannelClient::reconnect`.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Async app context passed to `server_channel.reconnect`.
        pub(super) server_cx: SendableCx,
    }
2621
2622    pub(super) struct SendableCx(AsyncApp);
2623    impl SendableCx {
2624        // SAFETY: When run in test mode, GPUI is always single threaded.
2625        pub(super) fn new(cx: &TestAppContext) -> Self {
2626            Self(cx.to_async())
2627        }
2628
2629        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
2630        fn get(&self, _: &AsyncApp) -> AsyncApp {
2631            self.0.clone()
2632        }
2633    }
2634
2635    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
2636    unsafe impl Send for SendableCx {}
2637    unsafe impl Sync for SendableCx {}
2638
2639    #[async_trait(?Send)]
2640    impl RemoteConnection for FakeRemoteConnection {
2641        async fn kill(&self) -> Result<()> {
2642            Ok(())
2643        }
2644
2645        fn has_been_killed(&self) -> bool {
2646            false
2647        }
2648
2649        fn ssh_args(&self) -> SshArgs {
2650            SshArgs {
2651                arguments: Vec::new(),
2652                envs: None,
2653            }
2654        }
2655
2656        fn upload_directory(
2657            &self,
2658            _src_path: PathBuf,
2659            _dest_path: RemotePathBuf,
2660            _cx: &App,
2661        ) -> Task<Result<()>> {
2662            unreachable!()
2663        }
2664
2665        fn connection_options(&self) -> SshConnectionOptions {
2666            self.connection_options.clone()
2667        }
2668
2669        fn simulate_disconnect(&self, cx: &AsyncApp) {
2670            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
2671            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
2672            self.server_channel
2673                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
2674        }
2675
2676        fn start_proxy(
2677            &self,
2678            _unique_identifier: String,
2679            _reconnect: bool,
2680            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
2681            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
2682            mut connection_activity_tx: Sender<()>,
2683            _delegate: Arc<dyn SshClientDelegate>,
2684            cx: &mut AsyncApp,
2685        ) -> Task<Result<i32>> {
2686            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
2687            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
2688
2689            self.server_channel.reconnect(
2690                server_incoming_rx,
2691                server_outgoing_tx,
2692                &self.server_cx.get(cx),
2693            );
2694
2695            cx.background_spawn(async move {
2696                loop {
2697                    select_biased! {
2698                        server_to_client = server_outgoing_rx.next().fuse() => {
2699                            let Some(server_to_client) = server_to_client else {
2700                                return Ok(1)
2701                            };
2702                            connection_activity_tx.try_send(()).ok();
2703                            client_incoming_tx.send(server_to_client).await.ok();
2704                        }
2705                        client_to_server = client_outgoing_rx.next().fuse() => {
2706                            let Some(client_to_server) = client_to_server else {
2707                                return Ok(1)
2708                            };
2709                            server_incoming_tx.send(client_to_server).await.ok();
2710                        }
2711                    }
2712                }
2713            })
2714        }
2715
2716        fn path_style(&self) -> PathStyle {
2717            PathStyle::current()
2718        }
2719
2720        fn shell(&self) -> String {
2721            "sh".to_owned()
2722        }
2723    }
2724
2725    pub(super) struct Delegate;
2726
2727    impl SshClientDelegate for Delegate {
2728        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
2729            unreachable!()
2730        }
2731
2732        fn download_server_binary_locally(
2733            &self,
2734            _: SshPlatform,
2735            _: ReleaseChannel,
2736            _: Option<SemanticVersion>,
2737            _: &mut AsyncApp,
2738        ) -> Task<Result<PathBuf>> {
2739            unreachable!()
2740        }
2741
2742        fn get_download_params(
2743            &self,
2744            _platform: SshPlatform,
2745            _release_channel: ReleaseChannel,
2746            _version: Option<SemanticVersion>,
2747            _cx: &mut AsyncApp,
2748        ) -> Task<Result<Option<(String, String)>>> {
2749            unreachable!()
2750        }
2751
2752        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
2753    }
2754}