ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  30    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  31};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use smol::{
  35    fs,
  36    process::{self, Child, Stdio},
  37};
  38use std::{
  39    collections::VecDeque,
  40    fmt, iter,
  41    ops::ControlFlow,
  42    path::{Path, PathBuf},
  43    sync::{
  44        Arc, Weak,
  45        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  46    },
  47    time::{Duration, Instant},
  48};
  49use tempfile::TempDir;
  50use util::{
  51    ResultExt,
  52    paths::{PathStyle, RemotePathBuf},
  53};
  54
/// Newtype wrapper around the `u64` that identifies an SSH project.
///
/// Derives ordering, hashing and serde support so it can be used as a map key and serialized.
// NOTE(review): how these ids are allocated is not visible here — confirm against callers.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  59
/// Everything needed to spawn further `ssh` commands for one configured connection.
///
/// The platform-specific part differs: on non-Windows targets it is the path passed to ssh
/// as `ControlPath`; on Windows it is the set of environment variables (askpass
/// configuration) applied to every spawned `ssh` process.
#[derive(Clone)]
pub struct SshSocket {
    /// Destination and extra ssh arguments used to build each command line.
    connection_options: SshConnectionOptions,
    /// Path handed to ssh via `-o ControlPath=...`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment (`SSH_ASKPASS`, `SSH_ASKPASS_REQUIRE`, `ZED_SSH_ASKPASS`) set on spawned
    /// `ssh` processes.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  68
// One local port forward, rendered on the ssh command line as
// `-L<local_host>:<local_port>:<remote_host>:<remote_port>`.
//
// A host that is `None` is rendered as `localhost` by
// `SshConnectionOptions::additional_args`, and is omitted when serializing.
//
// NOTE(review): plain `//` comments are used on purpose — doc comments on a type deriving
// `JsonSchema` are presumably emitted as schema descriptions; confirm before converting.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    // Local bind address; `None` means `localhost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    // Local port to listen on.
    pub local_port: u16,
    // Host the remote side connects to; `None` means `localhost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    // Port on `remote_host` that traffic is forwarded to.
    pub remote_port: u16,
}
  78
/// Describes where and how to connect over ssh.
///
/// Built either by hand or from a user-typed command line via
/// [`SshConnectionOptions::parse_command_line`].
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the destination.
    pub host: String,
    /// Login name; when present it is placed before `@` in the generated URLs.
    pub username: Option<String>,
    /// Destination port; when present it is appended after `:` in the generated URLs.
    pub port: Option<u16>,
    /// Password for the connection. `parse_command_line` always leaves this `None`.
    // NOTE(review): where the password is consumed is not visible here — confirm.
    pub password: Option<String>,
    /// Extra ssh flags passed through verbatim (see `additional_args`).
    pub args: Option<Vec<String>>,
    /// Local port forwards, rendered as `-L...` arguments by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// Presumably a user-facing label for this connection — TODO confirm against the UI code.
    pub nickname: Option<String>,
    /// Presumably selects uploading the server binary from this machine instead of
    /// downloading it on the remote — TODO confirm against the connection setup code.
    pub upload_binary_over_ssh: bool,
}
  91
/// The argument list and optional environment needed to invoke `ssh` for a connection,
/// as produced by `SshSocket::ssh_args`.
pub struct SshArgs {
    /// Arguments to pass to `ssh`, ending with the destination URL.
    pub arguments: Vec<String>,
    /// Environment variables to set on the `ssh` process; `None` on non-Windows targets.
    pub envs: Option<HashMap<String, String>>,
}
  96
  97#[macro_export]
  98macro_rules! shell_script {
  99    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 100        format!(
 101            $fmt,
 102            $(
 103                $name = shlex::try_quote($arg).unwrap()
 104            ),+
 105        )
 106    }};
 107}
 108
 109fn parse_port_number(port_str: &str) -> Result<u16> {
 110    port_str
 111        .parse()
 112        .with_context(|| format!("parsing port number: {port_str}"))
 113}
 114
 115fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 116    let parts: Vec<&str> = spec.split(':').collect();
 117
 118    match parts.len() {
 119        4 => {
 120            let local_port = parse_port_number(parts[1])?;
 121            let remote_port = parse_port_number(parts[3])?;
 122
 123            Ok(SshPortForwardOption {
 124                local_host: Some(parts[0].to_string()),
 125                local_port,
 126                remote_host: Some(parts[2].to_string()),
 127                remote_port,
 128            })
 129        }
 130        3 => {
 131            let local_port = parse_port_number(parts[0])?;
 132            let remote_port = parse_port_number(parts[2])?;
 133
 134            Ok(SshPortForwardOption {
 135                local_host: None,
 136                local_port,
 137                remote_host: Some(parts[1].to_string()),
 138                remote_port,
 139            })
 140        }
 141        _ => anyhow::bail!("Invalid port forward format"),
 142    }
 143}
 144
 145impl SshConnectionOptions {
 146    pub fn parse_command_line(input: &str) -> Result<Self> {
 147        let input = input.trim_start_matches("ssh ");
 148        let mut hostname: Option<String> = None;
 149        let mut username: Option<String> = None;
 150        let mut port: Option<u16> = None;
 151        let mut args = Vec::new();
 152        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 153
 154        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 155        const ALLOWED_OPTS: &[&str] = &[
 156            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 157        ];
 158        const ALLOWED_ARGS: &[&str] = &[
 159            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 160            "-w",
 161        ];
 162
 163        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 164
 165        'outer: while let Some(arg) = tokens.next() {
 166            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 167                args.push(arg.to_string());
 168                continue;
 169            }
 170            if arg == "-p" {
 171                port = tokens.next().and_then(|arg| arg.parse().ok());
 172                continue;
 173            } else if let Some(p) = arg.strip_prefix("-p") {
 174                port = p.parse().ok();
 175                continue;
 176            }
 177            if arg == "-l" {
 178                username = tokens.next();
 179                continue;
 180            } else if let Some(l) = arg.strip_prefix("-l") {
 181                username = Some(l.to_string());
 182                continue;
 183            }
 184            if arg == "-L" || arg.starts_with("-L") {
 185                let forward_spec = if arg == "-L" {
 186                    tokens.next()
 187                } else {
 188                    Some(arg.strip_prefix("-L").unwrap().to_string())
 189                };
 190
 191                if let Some(spec) = forward_spec {
 192                    port_forwards.push(parse_port_forward_spec(&spec)?);
 193                } else {
 194                    anyhow::bail!("Missing port forward format");
 195                }
 196            }
 197
 198            for a in ALLOWED_ARGS {
 199                if arg == *a {
 200                    args.push(arg);
 201                    if let Some(next) = tokens.next() {
 202                        args.push(next);
 203                    }
 204                    continue 'outer;
 205                } else if arg.starts_with(a) {
 206                    args.push(arg);
 207                    continue 'outer;
 208                }
 209            }
 210            if arg.starts_with("-") || hostname.is_some() {
 211                anyhow::bail!("unsupported argument: {:?}", arg);
 212            }
 213            let mut input = &arg as &str;
 214            // Destination might be: username1@username2@ip2@ip1
 215            if let Some((u, rest)) = input.rsplit_once('@') {
 216                input = rest;
 217                username = Some(u.to_string());
 218            }
 219            if let Some((rest, p)) = input.split_once(':') {
 220                input = rest;
 221                port = p.parse().ok()
 222            }
 223            hostname = Some(input.to_string())
 224        }
 225
 226        let Some(hostname) = hostname else {
 227            anyhow::bail!("missing hostname");
 228        };
 229
 230        let port_forwards = match port_forwards.len() {
 231            0 => None,
 232            _ => Some(port_forwards),
 233        };
 234
 235        Ok(Self {
 236            host: hostname.to_string(),
 237            username: username.clone(),
 238            port,
 239            port_forwards,
 240            args: Some(args),
 241            password: None,
 242            nickname: None,
 243            upload_binary_over_ssh: false,
 244        })
 245    }
 246
 247    pub fn ssh_url(&self) -> String {
 248        let mut result = String::from("ssh://");
 249        if let Some(username) = &self.username {
 250            // Username might be: username1@username2@ip2
 251            let username = urlencoding::encode(username);
 252            result.push_str(&username);
 253            result.push('@');
 254        }
 255        result.push_str(&self.host);
 256        if let Some(port) = self.port {
 257            result.push(':');
 258            result.push_str(&port.to_string());
 259        }
 260        result
 261    }
 262
 263    pub fn additional_args(&self) -> Vec<String> {
 264        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 265
 266        if let Some(forwards) = &self.port_forwards {
 267            args.extend(forwards.iter().map(|pf| {
 268                let local_host = match &pf.local_host {
 269                    Some(host) => host,
 270                    None => "localhost",
 271                };
 272                let remote_host = match &pf.remote_host {
 273                    Some(host) => host,
 274                    None => "localhost",
 275                };
 276
 277                format!(
 278                    "-L{}:{}:{}:{}",
 279                    local_host, pf.local_port, remote_host, pf.remote_port
 280                )
 281            }));
 282        }
 283
 284        args
 285    }
 286
 287    fn scp_url(&self) -> String {
 288        if let Some(username) = &self.username {
 289            format!("{}@{}", username, self.host)
 290        } else {
 291            self.host.clone()
 292        }
 293    }
 294
 295    pub fn connection_string(&self) -> String {
 296        let host = if let Some(username) = &self.username {
 297            format!("{}@{}", username, self.host)
 298        } else {
 299            self.host.clone()
 300        };
 301        if let Some(port) = &self.port {
 302            format!("{}:{}", host, port)
 303        } else {
 304            host
 305        }
 306    }
 307}
 308
/// Operating system and CPU architecture of the remote host, as detected by
/// `SshSocket::platform` from `uname -sm`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// `"macos"` or `"linux"`.
    pub os: &'static str,
    /// `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
 314
/// Callbacks through which the ssh client interacts with its host application.
///
/// Implementations must be `Send + Sync` because the delegate is shared as
/// `Arc<dyn SshClientDelegate>`.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password using `prompt`.
    // NOTE(review): presumably the answer is delivered by sending it on `tx` — confirm
    // against an implementation.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for the server binary matching `platform`,
    /// `release_channel` and (optionally) `version`.
    // NOTE(review): the meaning of the returned `(String, String)` pair and of `None` is not
    // visible here — presumably a URL plus request body, with `None` meaning "no download
    // available"; confirm against callers.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Fetches the server binary for `platform` onto the local machine and resolves to the
    /// path of the downloaded file.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message.
    // NOTE(review): `None` presumably clears the current status — confirm.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 334
 335impl SshSocket {
 336    #[cfg(not(target_os = "windows"))]
 337    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 338        Ok(Self {
 339            connection_options: options,
 340            socket_path,
 341        })
 342    }
 343
 344    #[cfg(target_os = "windows")]
 345    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 346        let askpass_script = temp_dir.path().join("askpass.bat");
 347        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 348        let mut envs = HashMap::default();
 349        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 350        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 351        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 352        Ok(Self {
 353            connection_options: options,
 354            envs,
 355        })
 356    }
 357
 358    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 359    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 360    // and passes -l as an argument to sh, not to ls.
 361    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 362    // into a machine. You must use `cd` to get back to $HOME.
 363    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 364    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 365        let mut command = util::command::new_smol_command("ssh");
 366        let to_run = iter::once(&program)
 367            .chain(args.iter())
 368            .map(|token| {
 369                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 370                debug_assert!(
 371                    !token.contains('\n'),
 372                    "multiline arguments do not work in all shells"
 373                );
 374                shlex::try_quote(token).unwrap()
 375            })
 376            .join(" ");
 377        let to_run = format!("cd; {to_run}");
 378        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 379        self.ssh_options(&mut command)
 380            .arg(self.connection_options.ssh_url())
 381            .arg(to_run);
 382        command
 383    }
 384
 385    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 386        let output = self.ssh_command(program, args).output().await?;
 387        anyhow::ensure!(
 388            output.status.success(),
 389            "failed to run command: {}",
 390            String::from_utf8_lossy(&output.stderr)
 391        );
 392        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 393    }
 394
 395    #[cfg(not(target_os = "windows"))]
 396    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 397        command
 398            .stdin(Stdio::piped())
 399            .stdout(Stdio::piped())
 400            .stderr(Stdio::piped())
 401            .args(self.connection_options.additional_args())
 402            .args(["-o", "ControlMaster=no", "-o"])
 403            .arg(format!("ControlPath={}", self.socket_path.display()))
 404    }
 405
 406    #[cfg(target_os = "windows")]
 407    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 408        command
 409            .stdin(Stdio::piped())
 410            .stdout(Stdio::piped())
 411            .stderr(Stdio::piped())
 412            .args(self.connection_options.additional_args())
 413            .envs(self.envs.clone())
 414    }
 415
 416    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 417    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 418    #[cfg(not(target_os = "windows"))]
 419    fn ssh_args(&self) -> SshArgs {
 420        let mut arguments = self.connection_options.additional_args();
 421        arguments.extend(vec![
 422            "-o".to_string(),
 423            "ControlMaster=no".to_string(),
 424            "-o".to_string(),
 425            format!("ControlPath={}", self.socket_path.display()),
 426            self.connection_options.ssh_url(),
 427        ]);
 428        SshArgs {
 429            arguments,
 430            envs: None,
 431        }
 432    }
 433
 434    #[cfg(target_os = "windows")]
 435    fn ssh_args(&self) -> SshArgs {
 436        let mut arguments = self.connection_options.additional_args();
 437        arguments.push(self.connection_options.ssh_url());
 438        SshArgs {
 439            arguments,
 440            envs: Some(self.envs.clone()),
 441        }
 442    }
 443
 444    async fn platform(&self) -> Result<SshPlatform> {
 445        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 446        let Some((os, arch)) = uname.split_once(" ") else {
 447            anyhow::bail!("unknown uname: {uname:?}")
 448        };
 449
 450        let os = match os.trim() {
 451            "Darwin" => "macos",
 452            "Linux" => "linux",
 453            _ => anyhow::bail!(
 454                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 455            ),
 456        };
 457        // exclude armv5,6,7 as they are 32-bit.
 458        let arch = if arch.starts_with("armv8")
 459            || arch.starts_with("armv9")
 460            || arch.starts_with("arm64")
 461            || arch.starts_with("aarch64")
 462        {
 463            "aarch64"
 464        } else if arch.starts_with("x86") {
 465            "x86_64"
 466        } else {
 467            anyhow::bail!(
 468                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 469            )
 470        };
 471
 472        Ok(SshPlatform { os, arch })
 473    }
 474}
 475
// Presumably the number of consecutive missed heartbeats tolerated before a reconnect is
// triggered — TODO confirm against the heartbeat loop.
const MAX_MISSED_HEARTBEATS: usize = 5;
// Presumably the delay between two heartbeat pings — TODO confirm against the heartbeat loop.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
// Timeout for a single ping; also used for the initial ping in `SshRemoteClient::new`.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

// `SshRemoteClient::reconnect` gives up and enters `State::ReconnectExhausted` once the
// attempt counter exceeds this value.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 481
/// Internal connection state machine of `SshRemoteClient`.
///
/// Variants that represent a live connection own the connection, the delegate and the
/// background tasks, so dropping the state releases them.
enum State {
    /// Initial state, set when the client entity is created and before the first ping
    /// has succeeded.
    Connecting,
    /// The connection is up and both background tasks are running.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but one or more heartbeats have been missed in a row.
    HeartbeatMissed {
        /// Number of heartbeats missed so far; starts at 1 and is incremented by
        /// `State::heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; the connection and delegate are kept so that
    /// another attempt can be made.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Why the last attempt failed.
        error: anyhow::Error,
        /// Number of attempts made so far.
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts were made; no further reconnects happen.
    ReconnectExhausted,
    /// Terminal state from which no reconnect is attempted.
    // NOTE(review): presumably entered when the remote server process is found not to be
    // running — confirm where this state is set.
    ServerNotRunning,
}
 511
 512impl fmt::Display for State {
 513    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 514        match self {
 515            Self::Connecting => write!(f, "connecting"),
 516            Self::Connected { .. } => write!(f, "connected"),
 517            Self::Reconnecting => write!(f, "reconnecting"),
 518            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 519            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 520            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 521            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 522        }
 523    }
 524}
 525
 526impl State {
 527    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 528        match self {
 529            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 530            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 531            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 532            _ => None,
 533        }
 534    }
 535
 536    fn can_reconnect(&self) -> bool {
 537        match self {
 538            Self::Connected { .. }
 539            | Self::HeartbeatMissed { .. }
 540            | Self::ReconnectFailed { .. } => true,
 541            State::Connecting
 542            | State::Reconnecting
 543            | State::ReconnectExhausted
 544            | State::ServerNotRunning => false,
 545        }
 546    }
 547
 548    fn is_reconnect_failed(&self) -> bool {
 549        matches!(self, Self::ReconnectFailed { .. })
 550    }
 551
 552    fn is_reconnect_exhausted(&self) -> bool {
 553        matches!(self, Self::ReconnectExhausted { .. })
 554    }
 555
 556    fn is_server_not_running(&self) -> bool {
 557        matches!(self, Self::ServerNotRunning)
 558    }
 559
 560    fn is_reconnecting(&self) -> bool {
 561        matches!(self, Self::Reconnecting { .. })
 562    }
 563
 564    fn heartbeat_recovered(self) -> Self {
 565        match self {
 566            Self::HeartbeatMissed {
 567                ssh_connection,
 568                delegate,
 569                multiplex_task,
 570                heartbeat_task,
 571                ..
 572            } => Self::Connected {
 573                ssh_connection,
 574                delegate,
 575                multiplex_task,
 576                heartbeat_task,
 577            },
 578            _ => self,
 579        }
 580    }
 581
 582    fn heartbeat_missed(self) -> Self {
 583        match self {
 584            Self::Connected {
 585                ssh_connection,
 586                delegate,
 587                multiplex_task,
 588                heartbeat_task,
 589            } => Self::HeartbeatMissed {
 590                missed_heartbeats: 1,
 591                ssh_connection,
 592                delegate,
 593                multiplex_task,
 594                heartbeat_task,
 595            },
 596            Self::HeartbeatMissed {
 597                missed_heartbeats,
 598                ssh_connection,
 599                delegate,
 600                multiplex_task,
 601                heartbeat_task,
 602            } => Self::HeartbeatMissed {
 603                missed_heartbeats: missed_heartbeats + 1,
 604                ssh_connection,
 605                delegate,
 606                multiplex_task,
 607                heartbeat_task,
 608            },
 609            _ => self,
 610        }
 611    }
 612}
 613
/// The publicly visible state of the ssh connection.
///
/// This is a payload-free summary of the internal `State`: a failed reconnect attempt is
/// still reported as `Reconnecting`, and both exhausted reconnects and a server that is not
/// running are reported as `Disconnected`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is up but at least one heartbeat has been missed.
    HeartbeatMissed,
    /// A reconnect is in progress, or the last attempt failed and may be retried.
    Reconnecting,
    /// The connection is gone and will not be re-established automatically.
    Disconnected,
}
 623
 624impl From<&State> for ConnectionState {
 625    fn from(value: &State) -> Self {
 626        match value {
 627            State::Connecting => Self::Connecting,
 628            State::Connected { .. } => Self::Connected,
 629            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 630            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 631            State::ReconnectExhausted => Self::Disconnected,
 632            State::ServerNotRunning => Self::Disconnected,
 633        }
 634    }
 635}
 636
/// Client side of a remote ssh session: owns the message channel client and tracks the
/// connection state, including reconnects.
pub struct SshRemoteClient {
    /// Channel-backed client used to send messages to, and ping, the remote server.
    client: Arc<ChannelClient>,
    /// String form of the `ConnectionIdentifier`; passed to `start_proxy` on connect and
    /// again on reconnect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection when the client was created.
    path_style: PathStyle,
    /// Current state; `None` once the state has been taken out (e.g. by
    /// `shutdown_processes`).
    state: Arc<Mutex<Option<State>>>,
}
 644
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection to the remote has been lost.
    // NOTE(review): the exact conditions under which this is emitted are not visible here.
    Disconnected,
}

// Marker impl: lets `SshRemoteClient` entities emit `SshRemoteEvent`s.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 651
// Identifies the socket on the remote server so that reconnects
// can re-join the same project.
pub enum ConnectionIdentifier {
    // Identifier drawn from a process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    // Identifier derived from a caller-supplied workspace id.
    // NOTE(review): presumably the persisted workspace id — confirm against callers.
    Workspace(i64),
}
 658
// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 660
 661impl ConnectionIdentifier {
 662    pub fn setup() -> Self {
 663        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 664    }
 665
 666    // This string gets used in a socket name, and so must be relatively short.
 667    // The total length of:
 668    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 669    // Must be less than about 100 characters
 670    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 671    // So our strings should be at most 20 characters or so.
 672    fn to_string(&self, cx: &App) -> String {
 673        let identifier_prefix = match ReleaseChannel::global(cx) {
 674            ReleaseChannel::Stable => "".to_string(),
 675            release_channel => format!("{}-", release_channel.dev_name()),
 676        };
 677        match self {
 678            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 679            Self::Workspace(workspace_id) => {
 680                format!("{identifier_prefix}workspace-{workspace_id}",)
 681            }
 682        }
 683    }
 684}
 685
 686impl SshRemoteClient {
    /// Connects to the remote described by `connection_options` and returns the client
    /// entity once the first ping has succeeded.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(client))` when the connection was established and answered a ping;
    /// - `Ok(None)` when `cancellation` resolved first;
    /// - `Err(_)` when connecting, starting the proxy or the initial ping failed.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        // Rendered up front because it needs `&App`, which is not available inside the task.
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Message channels between the `ChannelClient` and the proxy, plus a
                // capacity-1 channel handed to both the proxy and the heartbeat task.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Connections are obtained through the global `ConnectionPool`.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                // The entity starts out as `Connecting`; it is only switched to `Connected`
                // below, after the first ping has succeeded.
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // NOTE(review): the `false` is presumably the "reconnect" flag (`reconnect`
                // passes `true` in the same position) — confirm against `start_proxy`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                // Verify the server actually responds before reporting success.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Whichever finishes first wins. Any resolution of `cancellation` counts,
            // whatever value it carries.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 761
 762    pub fn proto_client_from_channels(
 763        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 764        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 765        cx: &App,
 766        name: &'static str,
 767    ) -> AnyProtoClient {
 768        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 769    }
 770
    /// Takes the current state out of the client and returns a future that shuts the
    /// connection down.
    ///
    /// Returns `None` when there is no state to take, or when the taken state is not
    /// `Connected`. In the latter case the state is still removed and dropped.
    ///
    /// The returned future optionally sends `shutdown_request`, waits 50ms, and then drops
    /// the background tasks, the connection and the delegate in a fixed order.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        // `take()` leaves `None` behind, so a second call returns `None` immediately.
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                // Best effort: a failure to send is only logged.
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
 809
 810    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 811        let mut lock = self.state.lock();
 812
 813        let can_reconnect = lock
 814            .as_ref()
 815            .map(|state| state.can_reconnect())
 816            .unwrap_or(false);
 817        if !can_reconnect {
 818            log::info!("aborting reconnect, because not in state that allows reconnecting");
 819            let error = if let Some(state) = lock.as_ref() {
 820                format!("invalid state, cannot reconnect while in state {state}")
 821            } else {
 822                "no state set".to_string()
 823            };
 824            anyhow::bail!(error);
 825        }
 826
 827        let state = lock.take().unwrap();
 828        let (attempts, ssh_connection, delegate) = match state {
 829            State::Connected {
 830                ssh_connection,
 831                delegate,
 832                multiplex_task,
 833                heartbeat_task,
 834            }
 835            | State::HeartbeatMissed {
 836                ssh_connection,
 837                delegate,
 838                multiplex_task,
 839                heartbeat_task,
 840                ..
 841            } => {
 842                drop(multiplex_task);
 843                drop(heartbeat_task);
 844                (0, ssh_connection, delegate)
 845            }
 846            State::ReconnectFailed {
 847                attempts,
 848                ssh_connection,
 849                delegate,
 850                ..
 851            } => (attempts, ssh_connection, delegate),
 852            State::Connecting
 853            | State::Reconnecting
 854            | State::ReconnectExhausted
 855            | State::ServerNotRunning => unreachable!(),
 856        };
 857
 858        let attempts = attempts + 1;
 859        if attempts > MAX_RECONNECT_ATTEMPTS {
 860            log::error!(
 861                "Failed to reconnect to after {} attempts, giving up",
 862                MAX_RECONNECT_ATTEMPTS
 863            );
 864            drop(lock);
 865            self.set_state(State::ReconnectExhausted, cx);
 866            return Ok(());
 867        }
 868        drop(lock);
 869
 870        self.set_state(State::Reconnecting, cx);
 871
 872        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 873
 874        let unique_identifier = self.unique_identifier.clone();
 875        let client = self.client.clone();
 876        let reconnect_task = cx.spawn(async move |this, cx| {
 877            macro_rules! failed {
 878                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 879                    return State::ReconnectFailed {
 880                        error: anyhow!($error),
 881                        attempts: $attempts,
 882                        ssh_connection: $ssh_connection,
 883                        delegate: $delegate,
 884                    };
 885                };
 886            }
 887
 888            if let Err(error) = ssh_connection
 889                .kill()
 890                .await
 891                .context("Failed to kill ssh process")
 892            {
 893                failed!(error, attempts, ssh_connection, delegate);
 894            };
 895
 896            let connection_options = ssh_connection.connection_options();
 897
 898            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 899            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 900            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 901
 902            let (ssh_connection, io_task) = match async {
 903                let ssh_connection = cx
 904                    .update_global(|pool: &mut ConnectionPool, cx| {
 905                        pool.connect(connection_options, &delegate, cx)
 906                    })?
 907                    .await
 908                    .map_err(|error| error.cloned())?;
 909
 910                let io_task = ssh_connection.start_proxy(
 911                    unique_identifier,
 912                    true,
 913                    incoming_tx,
 914                    outgoing_rx,
 915                    connection_activity_tx,
 916                    delegate.clone(),
 917                    cx,
 918                );
 919                anyhow::Ok((ssh_connection, io_task))
 920            }
 921            .await
 922            {
 923                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 924                Err(error) => {
 925                    failed!(error, attempts, ssh_connection, delegate);
 926                }
 927            };
 928
 929            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 930            client.reconnect(incoming_rx, outgoing_tx, cx);
 931
 932            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 933                failed!(error, attempts, ssh_connection, delegate);
 934            };
 935
 936            State::Connected {
 937                ssh_connection,
 938                delegate,
 939                multiplex_task,
 940                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 941            }
 942        });
 943
 944        cx.spawn(async move |this, cx| {
 945            let new_state = reconnect_task.await;
 946            this.update(cx, |this, cx| {
 947                this.try_set_state(cx, |old_state| {
 948                    if old_state.is_reconnecting() {
 949                        match &new_state {
 950                            State::Connecting
 951                            | State::Reconnecting { .. }
 952                            | State::HeartbeatMissed { .. }
 953                            | State::ServerNotRunning => {}
 954                            State::Connected { .. } => {
 955                                log::info!("Successfully reconnected");
 956                            }
 957                            State::ReconnectFailed {
 958                                error, attempts, ..
 959                            } => {
 960                                log::error!(
 961                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 962                                    attempts,
 963                                    error
 964                                );
 965                            }
 966                            State::ReconnectExhausted => {
 967                                log::error!("Reconnect attempt failed and all attempts exhausted");
 968                            }
 969                        }
 970                        Some(new_state)
 971                    } else {
 972                        None
 973                    }
 974                });
 975
 976                if this.state_is(State::is_reconnect_failed) {
 977                    this.reconnect(cx)
 978                } else if this.state_is(State::is_reconnect_exhausted) {
 979                    Ok(())
 980                } else {
 981                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 982                    Ok(())
 983                }
 984            })
 985        })
 986        .detach_and_log_err(cx);
 987
 988        Ok(())
 989    }
 990
 991    fn heartbeat(
 992        this: WeakEntity<Self>,
 993        mut connection_activity_rx: mpsc::Receiver<()>,
 994        cx: &mut AsyncApp,
 995    ) -> Task<Result<()>> {
 996        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 997            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 998        };
 999
1000        cx.spawn(async move |cx| {
1001            let mut missed_heartbeats = 0;
1002
1003            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1004            futures::pin_mut!(keepalive_timer);
1005
1006            loop {
1007                select_biased! {
1008                    result = connection_activity_rx.next().fuse() => {
1009                        if result.is_none() {
1010                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1011                            return Ok(());
1012                        }
1013
1014                        if missed_heartbeats != 0 {
1015                            missed_heartbeats = 0;
1016                            let _ =this.update(cx, |this, cx| {
1017                                this.handle_heartbeat_result(missed_heartbeats, cx)
1018                            })?;
1019                        }
1020                    }
1021                    _ = keepalive_timer => {
1022                        log::debug!("Sending heartbeat to server...");
1023
1024                        let result = select_biased! {
1025                            _ = connection_activity_rx.next().fuse() => {
1026                                Ok(())
1027                            }
1028                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1029                                ping_result
1030                            }
1031                        };
1032
1033                        if result.is_err() {
1034                            missed_heartbeats += 1;
1035                            log::warn!(
1036                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1037                                HEARTBEAT_TIMEOUT,
1038                                missed_heartbeats,
1039                                MAX_MISSED_HEARTBEATS
1040                            );
1041                        } else if missed_heartbeats != 0 {
1042                            missed_heartbeats = 0;
1043                        } else {
1044                            continue;
1045                        }
1046
1047                        let result = this.update(cx, |this, cx| {
1048                            this.handle_heartbeat_result(missed_heartbeats, cx)
1049                        })?;
1050                        if result.is_break() {
1051                            return Ok(());
1052                        }
1053                    }
1054                }
1055
1056                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1057            }
1058        })
1059    }
1060
1061    fn handle_heartbeat_result(
1062        &mut self,
1063        missed_heartbeats: usize,
1064        cx: &mut Context<Self>,
1065    ) -> ControlFlow<()> {
1066        let state = self.state.lock().take().unwrap();
1067        let next_state = if missed_heartbeats > 0 {
1068            state.heartbeat_missed()
1069        } else {
1070            state.heartbeat_recovered()
1071        };
1072
1073        self.set_state(next_state, cx);
1074
1075        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1076            log::error!(
1077                "Missed last {} heartbeats. Reconnecting...",
1078                missed_heartbeats
1079            );
1080
1081            self.reconnect(cx)
1082                .context("failed to start reconnect process after missing heartbeats")
1083                .log_err();
1084            ControlFlow::Break(())
1085        } else {
1086            ControlFlow::Continue(())
1087        }
1088    }
1089
1090    fn monitor(
1091        this: WeakEntity<Self>,
1092        io_task: Task<Result<i32>>,
1093        cx: &AsyncApp,
1094    ) -> Task<Result<()>> {
1095        cx.spawn(async move |cx| {
1096            let result = io_task.await;
1097
1098            match result {
1099                Ok(exit_code) => {
1100                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1101                        match error {
1102                            ProxyLaunchError::ServerNotRunning => {
1103                                log::error!("failed to reconnect because server is not running");
1104                                this.update(cx, |this, cx| {
1105                                    this.set_state(State::ServerNotRunning, cx);
1106                                })?;
1107                            }
1108                        }
1109                    } else if exit_code > 0 {
1110                        log::error!("proxy process terminated unexpectedly");
1111                        this.update(cx, |this, cx| {
1112                            this.reconnect(cx).ok();
1113                        })?;
1114                    }
1115                }
1116                Err(error) => {
1117                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1118                    this.update(cx, |this, cx| {
1119                        this.reconnect(cx).ok();
1120                    })?;
1121                }
1122            }
1123
1124            Ok(())
1125        })
1126    }
1127
1128    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1129        self.state.lock().as_ref().is_some_and(check)
1130    }
1131
1132    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1133        let mut lock = self.state.lock();
1134        let new_state = lock.as_ref().and_then(map);
1135
1136        if let Some(new_state) = new_state {
1137            lock.replace(new_state);
1138            cx.notify();
1139        }
1140    }
1141
1142    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1143        log::info!("setting state to '{}'", &state);
1144
1145        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1146        let is_server_not_running = state.is_server_not_running();
1147        self.state.lock().replace(state);
1148
1149        if is_reconnect_exhausted || is_server_not_running {
1150            cx.emit(SshRemoteEvent::Disconnected);
1151        }
1152        cx.notify();
1153    }
1154
1155    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1156        self.state
1157            .lock()
1158            .as_ref()
1159            .and_then(|state| state.ssh_connection())
1160            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1161    }
1162
1163    pub fn upload_directory(
1164        &self,
1165        src_path: PathBuf,
1166        dest_path: RemotePathBuf,
1167        cx: &App,
1168    ) -> Task<Result<()>> {
1169        let state = self.state.lock();
1170        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1171            return Task::ready(Err(anyhow!("no ssh connection")));
1172        };
1173        connection.upload_directory(src_path, dest_path, cx)
1174    }
1175
1176    pub fn proto_client(&self) -> AnyProtoClient {
1177        self.client.clone().into()
1178    }
1179
    /// Returns the connection string derived from this client's connection options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
1183
    /// Returns a copy of the options this client was configured with.
    pub fn connection_options(&self) -> SshConnectionOptions {
        self.connection_options.clone()
    }
1187
1188    pub fn connection_state(&self) -> ConnectionState {
1189        self.state
1190            .lock()
1191            .as_ref()
1192            .map(ConnectionState::from)
1193            .unwrap_or(ConnectionState::Disconnected)
1194    }
1195
    /// Returns `true` when `connection_state()` is `ConnectionState::Disconnected`.
    pub fn is_disconnected(&self) -> bool {
        self.connection_state() == ConnectionState::Disconnected
    }
1199
    /// Returns the path style stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
1203
1204    #[cfg(any(test, feature = "test-support"))]
1205    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1206        let opts = self.connection_options();
1207        client_cx.spawn(async move |cx| {
1208            let connection = cx
1209                .update_global(|c: &mut ConnectionPool, _| {
1210                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1211                        c.clone()
1212                    } else {
1213                        panic!("missing test connection")
1214                    }
1215                })
1216                .unwrap()
1217                .await
1218                .unwrap();
1219
1220            connection.simulate_disconnect(cx);
1221        })
1222    }
1223
1224    #[cfg(any(test, feature = "test-support"))]
1225    pub fn fake_server(
1226        client_cx: &mut gpui::TestAppContext,
1227        server_cx: &mut gpui::TestAppContext,
1228    ) -> (SshConnectionOptions, AnyProtoClient) {
1229        let port = client_cx
1230            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1231        let opts = SshConnectionOptions {
1232            host: "<fake>".to_string(),
1233            port: Some(port),
1234            ..Default::default()
1235        };
1236        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1237        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1238        let server_client =
1239            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1240        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1241            connection_options: opts.clone(),
1242            server_cx: fake::SendableCx::new(server_cx),
1243            server_channel: server_client.clone(),
1244        });
1245
1246        client_cx.update(|cx| {
1247            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1248                c.connections.insert(
1249                    opts.clone(),
1250                    ConnectionPoolEntry::Connecting(
1251                        cx.background_spawn({
1252                            let connection = connection.clone();
1253                            async move { Ok(connection.clone()) }
1254                        })
1255                        .shared(),
1256                    ),
1257                );
1258            })
1259        });
1260
1261        (opts, server_client.into())
1262    }
1263
1264    #[cfg(any(test, feature = "test-support"))]
1265    pub async fn fake_client(
1266        opts: SshConnectionOptions,
1267        client_cx: &mut gpui::TestAppContext,
1268    ) -> Entity<Self> {
1269        let (_tx, rx) = oneshot::channel();
1270        client_cx
1271            .update(|cx| {
1272                Self::new(
1273                    ConnectionIdentifier::setup(),
1274                    opts,
1275                    rx,
1276                    Arc::new(fake::Delegate),
1277                    cx,
1278                )
1279            })
1280            .await
1281            .unwrap()
1282            .unwrap()
1283    }
1284}
1285
/// An entry in the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by
    /// every caller interested in its result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is stored, so the
    /// pool does not keep the connection alive by itself.
    Connected(Weak<dyn RemoteConnection>),
}
1290
/// Global registry of remote connections, keyed by the options used to
/// create them.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Marker impl so the pool can be stored and looked up as a gpui global.
impl Global for ConnectionPool {}
1297
1298impl ConnectionPool {
1299    pub fn connect(
1300        &mut self,
1301        opts: SshConnectionOptions,
1302        delegate: &Arc<dyn SshClientDelegate>,
1303        cx: &mut App,
1304    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1305        let connection = self.connections.get(&opts);
1306        match connection {
1307            Some(ConnectionPoolEntry::Connecting(task)) => {
1308                let delegate = delegate.clone();
1309                cx.spawn(async move |cx| {
1310                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1311                })
1312                .detach();
1313                return task.clone();
1314            }
1315            Some(ConnectionPoolEntry::Connected(ssh)) => {
1316                if let Some(ssh) = ssh.upgrade()
1317                    && !ssh.has_been_killed()
1318                {
1319                    return Task::ready(Ok(ssh)).shared();
1320                }
1321                self.connections.remove(&opts);
1322            }
1323            None => {}
1324        }
1325
1326        let task = cx
1327            .spawn({
1328                let opts = opts.clone();
1329                let delegate = delegate.clone();
1330                async move |cx| {
1331                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1332                        .await
1333                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1334
1335                    cx.update_global(|pool: &mut Self, _| {
1336                        debug_assert!(matches!(
1337                            pool.connections.get(&opts),
1338                            Some(ConnectionPoolEntry::Connecting(_))
1339                        ));
1340                        match connection {
1341                            Ok(connection) => {
1342                                pool.connections.insert(
1343                                    opts.clone(),
1344                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1345                                );
1346                                Ok(connection)
1347                            }
1348                            Err(error) => {
1349                                pool.connections.remove(&opts);
1350                                Err(Arc::new(error))
1351                            }
1352                        }
1353                    })?
1354                }
1355            })
1356            .shared();
1357
1358        self.connections
1359            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1360        task
1361    }
1362}
1363
1364impl From<SshRemoteClient> for AnyProtoClient {
1365    fn from(client: SshRemoteClient) -> Self {
1366        AnyProtoClient::new(client.client.clone())
1367    }
1368}
1369
/// Abstraction over an established connection to a remote host.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy for this connection, wiring it to the given message
    /// channels. `reconnect` indicates whether this is a reconnect attempt.
    /// The returned task resolves to the proxy's exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns whether the connection has already been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further `ssh` commands over this
    /// connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On other platforms, we use the `ControlPath` option to create a socket
    /// file that ssh can use to reuse the already-established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style used by the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook for simulating a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1399
/// A [`RemoteConnection`] backed by a real `ssh` master process.
struct SshRemoteConnection {
    /// Socket/option bundle used to issue further ssh commands.
    socket: SshSocket,
    /// The master `ssh` process; `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote; set at the end of `new`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform's OS.
    ssh_path_style: PathStyle,
    /// Temporary directory created for this session; held so that it lives as
    /// long as the connection does.
    _temp_dir: TempDir,
}
1408
1409#[async_trait(?Send)]
1410impl RemoteConnection for SshRemoteConnection {
1411    async fn kill(&self) -> Result<()> {
1412        let Some(mut process) = self.master_process.lock().take() else {
1413            return Ok(());
1414        };
1415        process.kill().ok();
1416        process.status().await?;
1417        Ok(())
1418    }
1419
1420    fn has_been_killed(&self) -> bool {
1421        self.master_process.lock().is_none()
1422    }
1423
1424    fn ssh_args(&self) -> SshArgs {
1425        self.socket.ssh_args()
1426    }
1427
1428    fn connection_options(&self) -> SshConnectionOptions {
1429        self.socket.connection_options.clone()
1430    }
1431
1432    fn upload_directory(
1433        &self,
1434        src_path: PathBuf,
1435        dest_path: RemotePathBuf,
1436        cx: &App,
1437    ) -> Task<Result<()>> {
1438        let mut command = util::command::new_smol_command("scp");
1439        let output = self
1440            .socket
1441            .ssh_options(&mut command)
1442            .args(
1443                self.socket
1444                    .connection_options
1445                    .port
1446                    .map(|port| vec!["-P".to_string(), port.to_string()])
1447                    .unwrap_or_default(),
1448            )
1449            .arg("-C")
1450            .arg("-r")
1451            .arg(&src_path)
1452            .arg(format!(
1453                "{}:{}",
1454                self.socket.connection_options.scp_url(),
1455                dest_path
1456            ))
1457            .output();
1458
1459        cx.background_spawn(async move {
1460            let output = output.await?;
1461
1462            anyhow::ensure!(
1463                output.status.success(),
1464                "failed to upload directory {} -> {}: {}",
1465                src_path.display(),
1466                dest_path.to_string(),
1467                String::from_utf8_lossy(&output.stderr)
1468            );
1469
1470            Ok(())
1471        })
1472    }
1473
1474    fn start_proxy(
1475        &self,
1476        unique_identifier: String,
1477        reconnect: bool,
1478        incoming_tx: UnboundedSender<Envelope>,
1479        outgoing_rx: UnboundedReceiver<Envelope>,
1480        connection_activity_tx: Sender<()>,
1481        delegate: Arc<dyn SshClientDelegate>,
1482        cx: &mut AsyncApp,
1483    ) -> Task<Result<i32>> {
1484        delegate.set_status(Some("Starting proxy"), cx);
1485
1486        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1487            return Task::ready(Err(anyhow!("Remote binary path not set")));
1488        };
1489
1490        let mut start_proxy_command = shell_script!(
1491            "exec {binary_path} proxy --identifier {identifier}",
1492            binary_path = &remote_binary_path.to_string(),
1493            identifier = &unique_identifier,
1494        );
1495
1496        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1497            if let Some(value) = std::env::var(env_var).ok() {
1498                start_proxy_command = format!(
1499                    "{}={} {} ",
1500                    env_var,
1501                    shlex::try_quote(&value).unwrap(),
1502                    start_proxy_command,
1503                );
1504            }
1505        }
1506
1507        if reconnect {
1508            start_proxy_command.push_str(" --reconnect");
1509        }
1510
1511        let ssh_proxy_process = match self
1512            .socket
1513            .ssh_command("sh", &["-c", &start_proxy_command])
1514            // IMPORTANT: we kill this process when we drop the task that uses it.
1515            .kill_on_drop(true)
1516            .spawn()
1517        {
1518            Ok(process) => process,
1519            Err(error) => {
1520                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1521            }
1522        };
1523
1524        Self::multiplex(
1525            ssh_proxy_process,
1526            incoming_tx,
1527            outgoing_rx,
1528            connection_activity_tx,
1529            cx,
1530        )
1531    }
1532
1533    fn path_style(&self) -> PathStyle {
1534        self.ssh_path_style
1535    }
1536}
1537
1538impl SshRemoteConnection {
    /// Establishes a new ssh connection described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. create a temporary directory and an askpass session that forwards
    ///    password prompts to `delegate`;
    /// 2. spawn the master `ssh` process (with control-socket options on
    ///    non-Windows platforms);
    /// 3. wait until either the askpass session ends or the master process
    ///    closes its stdout;
    /// 4. if the master process has already exited, fail with its stderr output;
    /// 5. build the `SshSocket`, query the remote platform, and make sure the
    ///    server binary is available on the remote.
    ///
    /// # Errors
    ///
    /// Fails when the user cancels, the askpass session times out, the master
    /// process exits, or any of the setup steps above returns an error.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts raised by ssh are routed to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument
            // appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Biased select: the askpass outcome is checked before stdout completion.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        // No explicit kill here; the process was spawned with
                        // `kill_on_drop(true)`.
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, report its stderr as the
        // failure reason.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        // Any OS other than "windows" is treated as using POSIX paths.
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        // The binary path is only known once the server binary has been ensured.
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1669
1670    fn multiplex(
1671        mut ssh_proxy_process: Child,
1672        incoming_tx: UnboundedSender<Envelope>,
1673        mut outgoing_rx: UnboundedReceiver<Envelope>,
1674        mut connection_activity_tx: Sender<()>,
1675        cx: &AsyncApp,
1676    ) -> Task<Result<i32>> {
1677        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1678        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1679        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1680
1681        let mut stdin_buffer = Vec::new();
1682        let mut stdout_buffer = Vec::new();
1683        let mut stderr_buffer = Vec::new();
1684        let mut stderr_offset = 0;
1685
1686        let stdin_task = cx.background_spawn(async move {
1687            while let Some(outgoing) = outgoing_rx.next().await {
1688                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1689            }
1690            anyhow::Ok(())
1691        });
1692
1693        let stdout_task = cx.background_spawn({
1694            let mut connection_activity_tx = connection_activity_tx.clone();
1695            async move {
1696                loop {
1697                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1698                    let len = child_stdout.read(&mut stdout_buffer).await?;
1699
1700                    if len == 0 {
1701                        return anyhow::Ok(());
1702                    }
1703
1704                    if len < MESSAGE_LEN_SIZE {
1705                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1706                    }
1707
1708                    let message_len = message_len_from_buffer(&stdout_buffer);
1709                    let envelope =
1710                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1711                            .await?;
1712                    connection_activity_tx.try_send(()).ok();
1713                    incoming_tx.unbounded_send(envelope).ok();
1714                }
1715            }
1716        });
1717
1718        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1719            loop {
1720                stderr_buffer.resize(stderr_offset + 1024, 0);
1721
1722                let len = child_stderr
1723                    .read(&mut stderr_buffer[stderr_offset..])
1724                    .await?;
1725                if len == 0 {
1726                    return anyhow::Ok(());
1727                }
1728
1729                stderr_offset += len;
1730                let mut start_ix = 0;
1731                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1732                    .iter()
1733                    .position(|b| b == &b'\n')
1734                {
1735                    let line_ix = start_ix + ix;
1736                    let content = &stderr_buffer[start_ix..line_ix];
1737                    start_ix = line_ix + 1;
1738                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1739                        record.log(log::logger())
1740                    } else {
1741                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1742                    }
1743                }
1744                stderr_buffer.drain(0..start_ix);
1745                stderr_offset -= start_ix;
1746
1747                connection_activity_tx.try_send(()).ok();
1748            }
1749        });
1750
1751        cx.background_spawn(async move {
1752            let result = futures::select! {
1753                result = stdin_task.fuse() => {
1754                    result.context("stdin")
1755                }
1756                result = stdout_task.fuse() => {
1757                    result.context("stdout")
1758                }
1759                result = stderr_task.fuse() => {
1760                    result.context("stderr")
1761                }
1762            };
1763
1764            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1765            match result {
1766                Ok(_) => Ok(status),
1767                Err(error) => Err(error),
1768            }
1769        })
1770    }
1771
1772    #[allow(unused)]
1773    async fn ensure_server_binary(
1774        &self,
1775        delegate: &Arc<dyn SshClientDelegate>,
1776        release_channel: ReleaseChannel,
1777        version: SemanticVersion,
1778        commit: Option<AppCommitSha>,
1779        cx: &mut AsyncApp,
1780    ) -> Result<RemotePathBuf> {
1781        let version_str = match release_channel {
1782            ReleaseChannel::Nightly => {
1783                let commit = commit.map(|s| s.full()).unwrap_or_default();
1784                format!("{}-{}", version, commit)
1785            }
1786            ReleaseChannel::Dev => "build".to_string(),
1787            _ => version.to_string(),
1788        };
1789        let binary_name = format!(
1790            "zed-remote-server-{}-{}",
1791            release_channel.dev_name(),
1792            version_str
1793        );
1794        let dst_path = RemotePathBuf::new(
1795            paths::remote_server_dir_relative().join(binary_name),
1796            self.ssh_path_style,
1797        );
1798
1799        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1800        #[cfg(debug_assertions)]
1801        if let Some(build_remote_server) = build_remote_server {
1802            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1803            let tmp_path = RemotePathBuf::new(
1804                paths::remote_server_dir_relative().join(format!(
1805                    "download-{}-{}",
1806                    std::process::id(),
1807                    src_path.file_name().unwrap().to_string_lossy()
1808                )),
1809                self.ssh_path_style,
1810            );
1811            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1812                .await?;
1813            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1814                .await?;
1815            return Ok(dst_path);
1816        }
1817
1818        if self
1819            .socket
1820            .run_command(&dst_path.to_string(), &["version"])
1821            .await
1822            .is_ok()
1823        {
1824            return Ok(dst_path);
1825        }
1826
1827        let wanted_version = cx.update(|cx| match release_channel {
1828            ReleaseChannel::Nightly => Ok(None),
1829            ReleaseChannel::Dev => {
1830                anyhow::bail!(
1831                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1832                    dst_path
1833                )
1834            }
1835            _ => Ok(Some(AppVersion::global(cx))),
1836        })??;
1837
1838        let tmp_path_gz = RemotePathBuf::new(
1839            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
1840            self.ssh_path_style,
1841        );
1842        if !self.socket.connection_options.upload_binary_over_ssh
1843            && let Some((url, body)) = delegate
1844                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1845                .await?
1846        {
1847            match self
1848                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1849                .await
1850            {
1851                Ok(_) => {
1852                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1853                        .await?;
1854                    return Ok(dst_path);
1855                }
1856                Err(e) => {
1857                    log::error!(
1858                        "Failed to download binary on server, attempting to upload server: {}",
1859                        e
1860                    )
1861                }
1862            }
1863        }
1864
1865        let src_path = delegate
1866            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1867            .await?;
1868        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1869            .await?;
1870        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1871            .await?;
1872        Ok(dst_path)
1873    }
1874
1875    async fn download_binary_on_server(
1876        &self,
1877        url: &str,
1878        body: &str,
1879        tmp_path_gz: &RemotePathBuf,
1880        delegate: &Arc<dyn SshClientDelegate>,
1881        cx: &mut AsyncApp,
1882    ) -> Result<()> {
1883        if let Some(parent) = tmp_path_gz.parent() {
1884            self.socket
1885                .run_command(
1886                    "sh",
1887                    &[
1888                        "-c",
1889                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1890                    ],
1891                )
1892                .await?;
1893        }
1894
1895        delegate.set_status(Some("Downloading remote development server on host"), cx);
1896
1897        match self
1898            .socket
1899            .run_command(
1900                "curl",
1901                &[
1902                    "-f",
1903                    "-L",
1904                    "-X",
1905                    "GET",
1906                    "-H",
1907                    "Content-Type: application/json",
1908                    "-d",
1909                    body,
1910                    url,
1911                    "-o",
1912                    &tmp_path_gz.to_string(),
1913                ],
1914            )
1915            .await
1916        {
1917            Ok(_) => {}
1918            Err(e) => {
1919                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1920                    return Err(e);
1921                }
1922
1923                match self
1924                    .socket
1925                    .run_command(
1926                        "wget",
1927                        &[
1928                            "--method=GET",
1929                            "--header=Content-Type: application/json",
1930                            "--body-data",
1931                            body,
1932                            url,
1933                            "-O",
1934                            &tmp_path_gz.to_string(),
1935                        ],
1936                    )
1937                    .await
1938                {
1939                    Ok(_) => {}
1940                    Err(e) => {
1941                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1942                            return Err(e);
1943                        } else {
1944                            anyhow::bail!("Neither curl nor wget is available");
1945                        }
1946                    }
1947                }
1948            }
1949        }
1950
1951        Ok(())
1952    }
1953
1954    async fn upload_local_server_binary(
1955        &self,
1956        src_path: &Path,
1957        tmp_path_gz: &RemotePathBuf,
1958        delegate: &Arc<dyn SshClientDelegate>,
1959        cx: &mut AsyncApp,
1960    ) -> Result<()> {
1961        if let Some(parent) = tmp_path_gz.parent() {
1962            self.socket
1963                .run_command(
1964                    "sh",
1965                    &[
1966                        "-c",
1967                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1968                    ],
1969                )
1970                .await?;
1971        }
1972
1973        let src_stat = fs::metadata(&src_path).await?;
1974        let size = src_stat.len();
1975
1976        let t0 = Instant::now();
1977        delegate.set_status(Some("Uploading remote development server"), cx);
1978        log::info!(
1979            "uploading remote development server to {:?} ({}kb)",
1980            tmp_path_gz,
1981            size / 1024
1982        );
1983        self.upload_file(src_path, tmp_path_gz)
1984            .await
1985            .context("failed to upload server binary")?;
1986        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1987        Ok(())
1988    }
1989
1990    async fn extract_server_binary(
1991        &self,
1992        dst_path: &RemotePathBuf,
1993        tmp_path: &RemotePathBuf,
1994        delegate: &Arc<dyn SshClientDelegate>,
1995        cx: &mut AsyncApp,
1996    ) -> Result<()> {
1997        delegate.set_status(Some("Extracting remote development server"), cx);
1998        let server_mode = 0o755;
1999
2000        let orig_tmp_path = tmp_path.to_string();
2001        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2002            shell_script!(
2003                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2004                server_mode = &format!("{:o}", server_mode),
2005                dst_path = &dst_path.to_string(),
2006            )
2007        } else {
2008            shell_script!(
2009                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2010                server_mode = &format!("{:o}", server_mode),
2011                dst_path = &dst_path.to_string()
2012            )
2013        };
2014        self.socket.run_command("sh", &["-c", &script]).await?;
2015        Ok(())
2016    }
2017
2018    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2019        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2020        let mut command = util::command::new_smol_command("scp");
2021        let output = self
2022            .socket
2023            .ssh_options(&mut command)
2024            .args(
2025                self.socket
2026                    .connection_options
2027                    .port
2028                    .map(|port| vec!["-P".to_string(), port.to_string()])
2029                    .unwrap_or_default(),
2030            )
2031            .arg(src_path)
2032            .arg(format!(
2033                "{}:{}",
2034                self.socket.connection_options.scp_url(),
2035                dest_path
2036            ))
2037            .output()
2038            .await?;
2039
2040        anyhow::ensure!(
2041            output.status.success(),
2042            "failed to upload file {} -> {}: {}",
2043            src_path.display(),
2044            dest_path.to_string(),
2045            String::from_utf8_lossy(&output.stderr)
2046        );
2047        Ok(())
2048    }
2049
2050    #[cfg(debug_assertions)]
2051    async fn build_local(
2052        &self,
2053        build_remote_server: String,
2054        delegate: &Arc<dyn SshClientDelegate>,
2055        cx: &mut AsyncApp,
2056    ) -> Result<PathBuf> {
2057        use smol::process::{Command, Stdio};
2058        use std::env::VarError;
2059
2060        async fn run_cmd(command: &mut Command) -> Result<()> {
2061            let output = command
2062                .kill_on_drop(true)
2063                .stderr(Stdio::inherit())
2064                .output()
2065                .await?;
2066            anyhow::ensure!(
2067                output.status.success(),
2068                "Failed to run command: {command:?}"
2069            );
2070            Ok(())
2071        }
2072
2073        let use_musl = !build_remote_server.contains("nomusl");
2074        let triple = format!(
2075            "{}-{}",
2076            self.ssh_platform.arch,
2077            match self.ssh_platform.os {
2078                "linux" =>
2079                    if use_musl {
2080                        "unknown-linux-musl"
2081                    } else {
2082                        "unknown-linux-gnu"
2083                    },
2084                "macos" => "apple-darwin",
2085                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
2086            }
2087        );
2088        let mut rust_flags = match std::env::var("RUSTFLAGS") {
2089            Ok(val) => val,
2090            Err(VarError::NotPresent) => String::new(),
2091            Err(e) => {
2092                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
2093                String::new()
2094            }
2095        };
2096        if self.ssh_platform.os == "linux" && use_musl {
2097            rust_flags.push_str(" -C target-feature=+crt-static");
2098        }
2099        if build_remote_server.contains("mold") {
2100            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
2101        }
2102
2103        if self.ssh_platform.arch == std::env::consts::ARCH
2104            && self.ssh_platform.os == std::env::consts::OS
2105        {
2106            delegate.set_status(Some("Building remote server binary from source"), cx);
2107            log::info!("building remote server binary from source");
2108            run_cmd(
2109                Command::new("cargo")
2110                    .args([
2111                        "build",
2112                        "--package",
2113                        "remote_server",
2114                        "--features",
2115                        "debug-embed",
2116                        "--target-dir",
2117                        "target/remote_server",
2118                        "--target",
2119                        &triple,
2120                    ])
2121                    .env("RUSTFLAGS", &rust_flags),
2122            )
2123            .await?;
2124        } else if build_remote_server.contains("cross") {
2125            #[cfg(target_os = "windows")]
2126            use util::paths::SanitizedPath;
2127
2128            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
2129            log::info!("installing cross");
2130            run_cmd(Command::new("cargo").args([
2131                "install",
2132                "cross",
2133                "--git",
2134                "https://github.com/cross-rs/cross",
2135            ]))
2136            .await?;
2137
2138            delegate.set_status(
2139                Some(&format!(
2140                    "Building remote server binary from source for {} with Docker",
2141                    &triple
2142                )),
2143                cx,
2144            );
2145            log::info!("building remote server binary from source for {}", &triple);
2146
2147            // On Windows, the binding needs to be set to the canonical path
2148            #[cfg(target_os = "windows")]
2149            let src =
2150                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
2151            #[cfg(not(target_os = "windows"))]
2152            let src = "./target";
2153            run_cmd(
2154                Command::new("cross")
2155                    .args([
2156                        "build",
2157                        "--package",
2158                        "remote_server",
2159                        "--features",
2160                        "debug-embed",
2161                        "--target-dir",
2162                        "target/remote_server",
2163                        "--target",
2164                        &triple,
2165                    ])
2166                    .env(
2167                        "CROSS_CONTAINER_OPTS",
2168                        format!("--mount type=bind,src={src},dst=/app/target"),
2169                    )
2170                    .env("RUSTFLAGS", &rust_flags),
2171            )
2172            .await?;
2173        } else {
2174            let which = cx
2175                .background_spawn(async move { which::which("zig") })
2176                .await;
2177
2178            if which.is_err() {
2179                #[cfg(not(target_os = "windows"))]
2180                {
2181                    anyhow::bail!(
2182                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2183                    )
2184                }
2185                #[cfg(target_os = "windows")]
2186                {
2187                    anyhow::bail!(
2188                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2189                    )
2190                }
2191            }
2192
2193            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
2194            log::info!("adding rustup target");
2195            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;
2196
2197            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
2198            log::info!("installing cargo-zigbuild");
2199            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
2200
2201            delegate.set_status(
2202                Some(&format!(
2203                    "Building remote binary from source for {triple} with Zig"
2204                )),
2205                cx,
2206            );
2207            log::info!("building remote binary from source for {triple} with Zig");
2208            run_cmd(
2209                Command::new("cargo")
2210                    .args([
2211                        "zigbuild",
2212                        "--package",
2213                        "remote_server",
2214                        "--features",
2215                        "debug-embed",
2216                        "--target-dir",
2217                        "target/remote_server",
2218                        "--target",
2219                        &triple,
2220                    ])
2221                    .env("RUSTFLAGS", &rust_flags),
2222            )
2223            .await?;
2224        };
2225        let bin_path = Path::new("target")
2226            .join("remote_server")
2227            .join(&triple)
2228            .join("debug")
2229            .join("remote_server");
2230
2231        let path = if !build_remote_server.contains("nocompress") {
2232            delegate.set_status(Some("Compressing binary"), cx);
2233
2234            #[cfg(not(target_os = "windows"))]
2235            {
2236                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
2237            }
2238            #[cfg(target_os = "windows")]
2239            {
2240                // On Windows, we use 7z to compress the binary
2241                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
2242                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
2243                if smol::fs::metadata(&gz_path).await.is_ok() {
2244                    smol::fs::remove_file(&gz_path).await?;
2245                }
2246                run_cmd(Command::new(seven_zip).args([
2247                    "a",
2248                    "-tgzip",
2249                    &gz_path,
2250                    &bin_path.to_string_lossy(),
2251                ]))
2252                .await?;
2253            }
2254
2255            let mut archive_path = bin_path;
2256            archive_path.set_extension("gz");
2257            std::env::current_dir()?.join(archive_path)
2258        } else {
2259            bin_path
2260        };
2261
2262        Ok(path)
2263    }
2264}
2265
2266type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2267
2268struct ChannelClient {
2269    next_message_id: AtomicU32,
2270    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
2271    buffer: Mutex<VecDeque<Envelope>>,
2272    response_channels: ResponseChannels,
2273    message_handlers: Mutex<ProtoMessageHandlerSet>,
2274    max_received: AtomicU32,
2275    name: &'static str,
2276    task: Mutex<Task<Result<()>>>,
2277}
2278
2279impl ChannelClient {
2280    fn new(
2281        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2282        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2283        cx: &App,
2284        name: &'static str,
2285    ) -> Arc<Self> {
2286        Arc::new_cyclic(|this| Self {
2287            outgoing_tx: Mutex::new(outgoing_tx),
2288            next_message_id: AtomicU32::new(0),
2289            max_received: AtomicU32::new(0),
2290            response_channels: ResponseChannels::default(),
2291            message_handlers: Default::default(),
2292            buffer: Mutex::new(VecDeque::new()),
2293            name,
2294            task: Mutex::new(Self::start_handling_messages(
2295                this.clone(),
2296                incoming_rx,
2297                &cx.to_async(),
2298            )),
2299        })
2300    }
2301
2302    fn start_handling_messages(
2303        this: Weak<Self>,
2304        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2305        cx: &AsyncApp,
2306    ) -> Task<Result<()>> {
2307        cx.spawn(async move |cx| {
2308            let peer_id = PeerId { owner_id: 0, id: 0 };
2309            while let Some(incoming) = incoming_rx.next().await {
2310                let Some(this) = this.upgrade() else {
2311                    return anyhow::Ok(());
2312                };
2313                if let Some(ack_id) = incoming.ack_id {
2314                    let mut buffer = this.buffer.lock();
2315                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2316                        buffer.pop_front();
2317                    }
2318                }
2319                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2320                {
2321                    log::debug!(
2322                        "{}:ssh message received. name:FlushBufferedMessages",
2323                        this.name
2324                    );
2325                    {
2326                        let buffer = this.buffer.lock();
2327                        for envelope in buffer.iter() {
2328                            this.outgoing_tx
2329                                .lock()
2330                                .unbounded_send(envelope.clone())
2331                                .ok();
2332                        }
2333                    }
2334                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2335                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2336                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2337                    continue;
2338                }
2339
2340                this.max_received.store(incoming.id, SeqCst);
2341
2342                if let Some(request_id) = incoming.responding_to {
2343                    let request_id = MessageId(request_id);
2344                    let sender = this.response_channels.lock().remove(&request_id);
2345                    if let Some(sender) = sender {
2346                        let (tx, rx) = oneshot::channel();
2347                        if incoming.payload.is_some() {
2348                            sender.send((incoming, tx)).ok();
2349                        }
2350                        rx.await.ok();
2351                    }
2352                } else if let Some(envelope) =
2353                    build_typed_envelope(peer_id, Instant::now(), incoming)
2354                {
2355                    let type_name = envelope.payload_type_name();
2356                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2357                        &this.message_handlers,
2358                        envelope,
2359                        this.clone().into(),
2360                        cx.clone(),
2361                    ) {
2362                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2363                        cx.foreground_executor()
2364                            .spawn(async move {
2365                                match future.await {
2366                                    Ok(_) => {
2367                                        log::debug!(
2368                                            "{}:ssh message handled. name:{type_name}",
2369                                            this.name
2370                                        );
2371                                    }
2372                                    Err(error) => {
2373                                        log::error!(
2374                                            "{}:error handling message. type:{}, error:{}",
2375                                            this.name,
2376                                            type_name,
2377                                            format!("{error:#}").lines().fold(
2378                                                String::new(),
2379                                                |mut message, line| {
2380                                                    if !message.is_empty() {
2381                                                        message.push(' ');
2382                                                    }
2383                                                    message.push_str(line);
2384                                                    message
2385                                                }
2386                                            )
2387                                        );
2388                                    }
2389                                }
2390                            })
2391                            .detach()
2392                    } else {
2393                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2394                    }
2395                }
2396            }
2397            anyhow::Ok(())
2398        })
2399    }
2400
2401    fn reconnect(
2402        self: &Arc<Self>,
2403        incoming_rx: UnboundedReceiver<Envelope>,
2404        outgoing_tx: UnboundedSender<Envelope>,
2405        cx: &AsyncApp,
2406    ) {
2407        *self.outgoing_tx.lock() = outgoing_tx;
2408        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2409    }
2410
2411    fn request<T: RequestMessage>(
2412        &self,
2413        payload: T,
2414    ) -> impl 'static + Future<Output = Result<T::Response>> {
2415        self.request_internal(payload, true)
2416    }
2417
2418    fn request_internal<T: RequestMessage>(
2419        &self,
2420        payload: T,
2421        use_buffer: bool,
2422    ) -> impl 'static + Future<Output = Result<T::Response>> {
2423        log::debug!("ssh request start. name:{}", T::NAME);
2424        let response =
2425            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2426        async move {
2427            let response = response.await?;
2428            log::debug!("ssh request finish. name:{}", T::NAME);
2429            T::Response::from_envelope(response).context("received a response of the wrong type")
2430        }
2431    }
2432
2433    async fn resync(&self, timeout: Duration) -> Result<()> {
2434        smol::future::or(
2435            async {
2436                self.request_internal(proto::FlushBufferedMessages {}, false)
2437                    .await?;
2438
2439                for envelope in self.buffer.lock().iter() {
2440                    self.outgoing_tx
2441                        .lock()
2442                        .unbounded_send(envelope.clone())
2443                        .ok();
2444                }
2445                Ok(())
2446            },
2447            async {
2448                smol::Timer::after(timeout).await;
2449                anyhow::bail!("Timed out resyncing remote client")
2450            },
2451        )
2452        .await
2453    }
2454
2455    async fn ping(&self, timeout: Duration) -> Result<()> {
2456        smol::future::or(
2457            async {
2458                self.request(proto::Ping {}).await?;
2459                Ok(())
2460            },
2461            async {
2462                smol::Timer::after(timeout).await;
2463                anyhow::bail!("Timed out pinging remote client")
2464            },
2465        )
2466        .await
2467    }
2468
2469    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2470        log::debug!("ssh send name:{}", T::NAME);
2471        self.send_dynamic(payload.into_envelope(0, None, None))
2472    }
2473
2474    fn request_dynamic(
2475        &self,
2476        mut envelope: proto::Envelope,
2477        type_name: &'static str,
2478        use_buffer: bool,
2479    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2480        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2481        let (tx, rx) = oneshot::channel();
2482        let mut response_channels_lock = self.response_channels.lock();
2483        response_channels_lock.insert(MessageId(envelope.id), tx);
2484        drop(response_channels_lock);
2485
2486        let result = if use_buffer {
2487            self.send_buffered(envelope)
2488        } else {
2489            self.send_unbuffered(envelope)
2490        };
2491        async move {
2492            if let Err(error) = &result {
2493                log::error!("failed to send message: {error}");
2494                anyhow::bail!("failed to send message: {error}");
2495            }
2496
2497            let response = rx.await.context("connection lost")?.0;
2498            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2499                return Err(RpcError::from_proto(error, type_name));
2500            }
2501            Ok(response)
2502        }
2503    }
2504
2505    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2506        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2507        self.send_buffered(envelope)
2508    }
2509
2510    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2511        envelope.ack_id = Some(self.max_received.load(SeqCst));
2512        self.buffer.lock().push_back(envelope.clone());
2513        // ignore errors on send (happen while we're reconnecting)
2514        // assume that the global "disconnected" overlay is sufficient.
2515        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2516        Ok(())
2517    }
2518
2519    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2520        envelope.ack_id = Some(self.max_received.load(SeqCst));
2521        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2522        Ok(())
2523    }
2524}
2525
2526impl ProtoClient for ChannelClient {
2527    fn request(
2528        &self,
2529        envelope: proto::Envelope,
2530        request_type: &'static str,
2531    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2532        self.request_dynamic(envelope, request_type, true).boxed()
2533    }
2534
2535    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2536        self.send_dynamic(envelope)
2537    }
2538
2539    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2540        self.send_dynamic(envelope)
2541    }
2542
2543    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2544        &self.message_handlers
2545    }
2546
2547    fn is_via_collab(&self) -> bool {
2548        false
2549    }
2550}
2551
/// Test-only stand-ins for a remote connection and its client delegate.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A `RemoteConnection` that pipes envelopes between a client and an
    /// in-process server `ChannelClient` over channels.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper around an `AsyncApp` that is declared `Send` and `Sync` below.
    pub(super) struct SendableCx(AsyncApp);

    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            let async_cx = cx.to_async();
            Self(async_cx)
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            let Self(inner) = self;
            inner.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to kill for the fake connection; always succeeds.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake connection never reports itself as killed.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Returns empty arguments and no environment.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                envs: None,
                arguments: Vec::new(),
            }
        }

        /// Not supported by the fake connection; panics if called.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        /// Returns a clone of the stored connection options.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server channel to a sender and a receiver whose
        /// opposite ends have already been dropped.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            // Only one half of each channel is kept; the other half is dropped
            // at the end of its statement.
            let outgoing_tx = mpsc::unbounded::<Envelope>().0;
            let incoming_rx = mpsc::unbounded::<Envelope>().1;
            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &server_cx);
        }

        /// Reconnects the server channel to a fresh pair of channels and spawns
        /// a background task that forwards envelopes in both directions.
        ///
        /// The task resolves to `Ok(1)` as soon as either outgoing stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);

            cx.background_spawn(async move {
                loop {
                    // Biased select: server-to-client traffic is polled first.
                    select_biased! {
                        from_server = server_outgoing_rx.next().fuse() => {
                            match from_server {
                                Some(envelope) => {
                                    // Signal activity, then forward; failures of
                                    // either step are ignored.
                                    connection_activity_tx.try_send(()).ok();
                                    client_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                        from_client = client_outgoing_rx.next().fuse() => {
                            match from_client {
                                Some(envelope) => {
                                    server_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                    }
                }
            })
        }

        /// Uses the path style returned by `PathStyle::current()`.
        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// An `SshClientDelegate` whose methods panic if called, except
    /// `set_status`, which does nothing.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        /// Panics if called.
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        /// Panics if called.
        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        /// Panics if called.
        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}