ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  32};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use smol::{
  36    fs,
  37    process::{self, Child, Stdio},
  38};
  39use std::{
  40    any::TypeId,
  41    collections::VecDeque,
  42    fmt, iter,
  43    ops::ControlFlow,
  44    path::{Path, PathBuf},
  45    sync::{
  46        Arc, Weak,
  47        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  48    },
  49    time::{Duration, Instant},
  50};
  51use tempfile::TempDir;
  52use util::{
  53    ResultExt,
  54    paths::{PathStyle, RemotePathBuf},
  55};
  56
/// Newtype wrapper around a `u64` identifier.
///
/// NOTE(review): presumably identifies a project opened over SSH — confirm at the use sites,
/// which are not visible in this part of the file.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  61
/// Bundles the connection options with the platform-specific data that is attached to every
/// `ssh` command built from this value.
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    // Non-Windows: path handed to ssh as `-o ControlPath=<path>`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    // Windows: environment variables set on every spawned ssh command.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  70
/// A single local port forward, i.e. one `-L[local_host:]local_port:remote_host:remote_port`
/// argument to ssh.
///
/// The optional hosts are omitted from the serialized form when they are `None`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind host; `None` when the forward spec did not name one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port of the forward.
    pub local_port: u16,
    /// Host on the remote side of the forward.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Port on the remote side of the forward.
    pub remote_port: u16,
}
  80
/// Everything needed to describe an ssh destination and the extra arguments to pass to ssh.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the destination.
    pub host: String,
    /// Login name, if one was given.
    pub username: Option<String>,
    /// Port, if one was given.
    pub port: Option<u16>,
    /// NOTE(review): presumably a password used to answer ssh prompts — confirm at use sites.
    pub password: Option<String>,
    /// Additional ssh command-line arguments, passed through as-is.
    pub args: Option<Vec<String>>,
    /// Local port forwards, rendered as `-L...` arguments.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// NOTE(review): presumably a user-facing display name — confirm at use sites.
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary through ssh instead of
    /// downloading it on the remote — confirm at use sites.
    pub upload_binary_over_ssh: bool,
}
  93
/// Arguments (and, optionally, environment variables) needed to invoke `ssh` for a connection.
pub struct SshArgs {
    /// Command-line arguments, in order.
    pub arguments: Vec<String>,
    /// Environment variables to set on the ssh process, if any.
    pub envs: Option<HashMap<String, String>>,
}
  98
  99#[macro_export]
 100macro_rules! shell_script {
 101    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 102        format!(
 103            $fmt,
 104            $(
 105                $name = shlex::try_quote($arg).unwrap()
 106            ),+
 107        )
 108    }};
 109}
 110
 111fn parse_port_number(port_str: &str) -> Result<u16> {
 112    port_str
 113        .parse()
 114        .with_context(|| format!("parsing port number: {port_str}"))
 115}
 116
 117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 118    let parts: Vec<&str> = spec.split(':').collect();
 119
 120    match parts.len() {
 121        4 => {
 122            let local_port = parse_port_number(parts[1])?;
 123            let remote_port = parse_port_number(parts[3])?;
 124
 125            Ok(SshPortForwardOption {
 126                local_host: Some(parts[0].to_string()),
 127                local_port,
 128                remote_host: Some(parts[2].to_string()),
 129                remote_port,
 130            })
 131        }
 132        3 => {
 133            let local_port = parse_port_number(parts[0])?;
 134            let remote_port = parse_port_number(parts[2])?;
 135
 136            Ok(SshPortForwardOption {
 137                local_host: None,
 138                local_port,
 139                remote_host: Some(parts[1].to_string()),
 140                remote_port,
 141            })
 142        }
 143        _ => anyhow::bail!("Invalid port forward format"),
 144    }
 145}
 146
 147impl SshConnectionOptions {
 148    pub fn parse_command_line(input: &str) -> Result<Self> {
 149        let input = input.trim_start_matches("ssh ");
 150        let mut hostname: Option<String> = None;
 151        let mut username: Option<String> = None;
 152        let mut port: Option<u16> = None;
 153        let mut args = Vec::new();
 154        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 155
 156        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 157        const ALLOWED_OPTS: &[&str] = &[
 158            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 159        ];
 160        const ALLOWED_ARGS: &[&str] = &[
 161            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 162            "-w",
 163        ];
 164
 165        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 166
 167        'outer: while let Some(arg) = tokens.next() {
 168            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 169                args.push(arg.to_string());
 170                continue;
 171            }
 172            if arg == "-p" {
 173                port = tokens.next().and_then(|arg| arg.parse().ok());
 174                continue;
 175            } else if let Some(p) = arg.strip_prefix("-p") {
 176                port = p.parse().ok();
 177                continue;
 178            }
 179            if arg == "-l" {
 180                username = tokens.next();
 181                continue;
 182            } else if let Some(l) = arg.strip_prefix("-l") {
 183                username = Some(l.to_string());
 184                continue;
 185            }
 186            if arg == "-L" || arg.starts_with("-L") {
 187                let forward_spec = if arg == "-L" {
 188                    tokens.next()
 189                } else {
 190                    Some(arg.strip_prefix("-L").unwrap().to_string())
 191                };
 192
 193                if let Some(spec) = forward_spec {
 194                    port_forwards.push(parse_port_forward_spec(&spec)?);
 195                } else {
 196                    anyhow::bail!("Missing port forward format");
 197                }
 198            }
 199
 200            for a in ALLOWED_ARGS {
 201                if arg == *a {
 202                    args.push(arg);
 203                    if let Some(next) = tokens.next() {
 204                        args.push(next);
 205                    }
 206                    continue 'outer;
 207                } else if arg.starts_with(a) {
 208                    args.push(arg);
 209                    continue 'outer;
 210                }
 211            }
 212            if arg.starts_with("-") || hostname.is_some() {
 213                anyhow::bail!("unsupported argument: {:?}", arg);
 214            }
 215            let mut input = &arg as &str;
 216            // Destination might be: username1@username2@ip2@ip1
 217            if let Some((u, rest)) = input.rsplit_once('@') {
 218                input = rest;
 219                username = Some(u.to_string());
 220            }
 221            if let Some((rest, p)) = input.split_once(':') {
 222                input = rest;
 223                port = p.parse().ok()
 224            }
 225            hostname = Some(input.to_string())
 226        }
 227
 228        let Some(hostname) = hostname else {
 229            anyhow::bail!("missing hostname");
 230        };
 231
 232        let port_forwards = match port_forwards.len() {
 233            0 => None,
 234            _ => Some(port_forwards),
 235        };
 236
 237        Ok(Self {
 238            host: hostname.to_string(),
 239            username: username.clone(),
 240            port,
 241            port_forwards,
 242            args: Some(args),
 243            password: None,
 244            nickname: None,
 245            upload_binary_over_ssh: false,
 246        })
 247    }
 248
 249    pub fn ssh_url(&self) -> String {
 250        let mut result = String::from("ssh://");
 251        if let Some(username) = &self.username {
 252            // Username might be: username1@username2@ip2
 253            let username = urlencoding::encode(username);
 254            result.push_str(&username);
 255            result.push('@');
 256        }
 257        result.push_str(&self.host);
 258        if let Some(port) = self.port {
 259            result.push(':');
 260            result.push_str(&port.to_string());
 261        }
 262        result
 263    }
 264
 265    pub fn additional_args(&self) -> Vec<String> {
 266        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 267
 268        if let Some(forwards) = &self.port_forwards {
 269            args.extend(forwards.iter().map(|pf| {
 270                let local_host = match &pf.local_host {
 271                    Some(host) => host,
 272                    None => "localhost",
 273                };
 274                let remote_host = match &pf.remote_host {
 275                    Some(host) => host,
 276                    None => "localhost",
 277                };
 278
 279                format!(
 280                    "-L{}:{}:{}:{}",
 281                    local_host, pf.local_port, remote_host, pf.remote_port
 282                )
 283            }));
 284        }
 285
 286        args
 287    }
 288
 289    fn scp_url(&self) -> String {
 290        if let Some(username) = &self.username {
 291            format!("{}@{}", username, self.host)
 292        } else {
 293            self.host.clone()
 294        }
 295    }
 296
 297    pub fn connection_string(&self) -> String {
 298        let host = if let Some(username) = &self.username {
 299            format!("{}@{}", username, self.host)
 300        } else {
 301            self.host.clone()
 302        };
 303        if let Some(port) = &self.port {
 304            format!("{}:{}", host, port)
 305        } else {
 306            host
 307        }
 308    }
 309}
 310
/// Operating system and CPU architecture of a remote host, as normalized names.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// Normalized OS name, e.g. `"macos"` or `"linux"`.
    pub os: &'static str,
    /// Normalized architecture name, e.g. `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
 316
/// Callbacks the ssh client uses to interact with its host application.
///
/// NOTE(review): no implementation is visible here; the per-method notes below are inferred
/// from the signatures and should be confirmed against the implementors.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password using `prompt`; presumably the answer is sent back through `tx`.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for the server binary matching `platform`,
    /// `release_channel` and `version`.
    ///
    /// NOTE(review): the meaning of the two returned strings is not visible here — confirm.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Obtains the server binary for `platform` on the local machine and resolves to its path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message; presumably `None` clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 336
 337impl SshSocket {
 338    #[cfg(not(target_os = "windows"))]
 339    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 340        Ok(Self {
 341            connection_options: options,
 342            socket_path,
 343        })
 344    }
 345
 346    #[cfg(target_os = "windows")]
 347    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 348        let askpass_script = temp_dir.path().join("askpass.bat");
 349        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 350        let mut envs = HashMap::default();
 351        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 352        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 353        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 354        Ok(Self {
 355            connection_options: options,
 356            envs,
 357        })
 358    }
 359
 360    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 361    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 362    // and passes -l as an argument to sh, not to ls.
 363    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 364    // into a machine. You must use `cd` to get back to $HOME.
 365    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 366    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 367        let mut command = util::command::new_smol_command("ssh");
 368        let to_run = iter::once(&program)
 369            .chain(args.iter())
 370            .map(|token| {
 371                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 372                debug_assert!(
 373                    !token.contains('\n'),
 374                    "multiline arguments do not work in all shells"
 375                );
 376                shlex::try_quote(token).unwrap()
 377            })
 378            .join(" ");
 379        let to_run = format!("cd; {to_run}");
 380        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 381        self.ssh_options(&mut command)
 382            .arg(self.connection_options.ssh_url())
 383            .arg(to_run);
 384        command
 385    }
 386
 387    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 388        let output = self.ssh_command(program, args).output().await?;
 389        anyhow::ensure!(
 390            output.status.success(),
 391            "failed to run command: {}",
 392            String::from_utf8_lossy(&output.stderr)
 393        );
 394        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 395    }
 396
 397    #[cfg(not(target_os = "windows"))]
 398    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 399        command
 400            .stdin(Stdio::piped())
 401            .stdout(Stdio::piped())
 402            .stderr(Stdio::piped())
 403            .args(["-o", "ControlMaster=no", "-o"])
 404            .arg(format!("ControlPath={}", self.socket_path.display()))
 405    }
 406
 407    #[cfg(target_os = "windows")]
 408    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 409        command
 410            .stdin(Stdio::piped())
 411            .stdout(Stdio::piped())
 412            .stderr(Stdio::piped())
 413            .envs(self.envs.clone())
 414    }
 415
 416    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 417    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 418    #[cfg(not(target_os = "windows"))]
 419    fn ssh_args(&self) -> SshArgs {
 420        SshArgs {
 421            arguments: vec![
 422                "-o".to_string(),
 423                "ControlMaster=no".to_string(),
 424                "-o".to_string(),
 425                format!("ControlPath={}", self.socket_path.display()),
 426                self.connection_options.ssh_url(),
 427            ],
 428            envs: None,
 429        }
 430    }
 431
 432    #[cfg(target_os = "windows")]
 433    fn ssh_args(&self) -> SshArgs {
 434        SshArgs {
 435            arguments: vec![self.connection_options.ssh_url()],
 436            envs: Some(self.envs.clone()),
 437        }
 438    }
 439
 440    async fn platform(&self) -> Result<SshPlatform> {
 441        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 442        let Some((os, arch)) = uname.split_once(" ") else {
 443            anyhow::bail!("unknown uname: {uname:?}")
 444        };
 445
 446        let os = match os.trim() {
 447            "Darwin" => "macos",
 448            "Linux" => "linux",
 449            _ => anyhow::bail!(
 450                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 451            ),
 452        };
 453        // exclude armv5,6,7 as they are 32-bit.
 454        let arch = if arch.starts_with("armv8")
 455            || arch.starts_with("armv9")
 456            || arch.starts_with("arm64")
 457            || arch.starts_with("aarch64")
 458        {
 459            "aarch64"
 460        } else if arch.starts_with("x86") {
 461            "x86_64"
 462        } else {
 463            anyhow::bail!(
 464                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 465            )
 466        };
 467
 468        Ok(SshPlatform { os, arch })
 469    }
 470}
 471
// Heartbeat tuning.
// NOTE(review): the heartbeat loop is not visible in this part of the file; presumably a
// heartbeat is sent every `HEARTBEAT_INTERVAL` and the connection is treated as lost after
// `MAX_MISSED_HEARTBEATS` consecutive misses — confirm against `heartbeat`.
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
// Also used as the timeout for the initial `ping` and for `resync` after a reconnect.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

// `reconnect` gives up (state `ReconnectExhausted`) once the attempt count exceeds this.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 477
/// Internal connection state machine of `SshRemoteClient`.
///
/// The variants that hold an `ssh_connection` are the ones from which a reconnect can be
/// started (see `can_reconnect`).
enum State {
    /// Initial state, before the first successful ping.
    Connecting,
    /// The connection is up; the tasks are kept alive by being stored here.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Same as `Connected`, but `missed_heartbeats` consecutive heartbeats have been missed.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The last reconnect attempt failed with `error`; `attempts` counts attempts so far.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Reconnecting was abandoned after too many attempts.
    ReconnectExhausted,
    /// NOTE(review): presumably the remote server process is gone and reconnecting is
    /// pointless — the transition into this state is not visible here.
    ServerNotRunning,
}
 507
 508impl fmt::Display for State {
 509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 510        match self {
 511            Self::Connecting => write!(f, "connecting"),
 512            Self::Connected { .. } => write!(f, "connected"),
 513            Self::Reconnecting => write!(f, "reconnecting"),
 514            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 515            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 516            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 517            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 518        }
 519    }
 520}
 521
 522impl State {
 523    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 524        match self {
 525            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 526            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 527            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 528            _ => None,
 529        }
 530    }
 531
 532    fn can_reconnect(&self) -> bool {
 533        match self {
 534            Self::Connected { .. }
 535            | Self::HeartbeatMissed { .. }
 536            | Self::ReconnectFailed { .. } => true,
 537            State::Connecting
 538            | State::Reconnecting
 539            | State::ReconnectExhausted
 540            | State::ServerNotRunning => false,
 541        }
 542    }
 543
 544    fn is_reconnect_failed(&self) -> bool {
 545        matches!(self, Self::ReconnectFailed { .. })
 546    }
 547
 548    fn is_reconnect_exhausted(&self) -> bool {
 549        matches!(self, Self::ReconnectExhausted { .. })
 550    }
 551
 552    fn is_server_not_running(&self) -> bool {
 553        matches!(self, Self::ServerNotRunning)
 554    }
 555
 556    fn is_reconnecting(&self) -> bool {
 557        matches!(self, Self::Reconnecting { .. })
 558    }
 559
 560    fn heartbeat_recovered(self) -> Self {
 561        match self {
 562            Self::HeartbeatMissed {
 563                ssh_connection,
 564                delegate,
 565                multiplex_task,
 566                heartbeat_task,
 567                ..
 568            } => Self::Connected {
 569                ssh_connection,
 570                delegate,
 571                multiplex_task,
 572                heartbeat_task,
 573            },
 574            _ => self,
 575        }
 576    }
 577
 578    fn heartbeat_missed(self) -> Self {
 579        match self {
 580            Self::Connected {
 581                ssh_connection,
 582                delegate,
 583                multiplex_task,
 584                heartbeat_task,
 585            } => Self::HeartbeatMissed {
 586                missed_heartbeats: 1,
 587                ssh_connection,
 588                delegate,
 589                multiplex_task,
 590                heartbeat_task,
 591            },
 592            Self::HeartbeatMissed {
 593                missed_heartbeats,
 594                ssh_connection,
 595                delegate,
 596                multiplex_task,
 597                heartbeat_task,
 598            } => Self::HeartbeatMissed {
 599                missed_heartbeats: missed_heartbeats + 1,
 600                ssh_connection,
 601                delegate,
 602                multiplex_task,
 603                heartbeat_task,
 604            },
 605            _ => self,
 606        }
 607    }
 608}
 609
/// The state of the ssh connection.
///
/// This is the public, data-free view of the internal `State`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is up but at least one heartbeat was missed.
    HeartbeatMissed,
    /// A reconnect is in progress, or the previous attempt failed.
    Reconnecting,
    /// Reconnecting was given up, or the server is not running.
    Disconnected,
}
 619
 620impl From<&State> for ConnectionState {
 621    fn from(value: &State) -> Self {
 622        match value {
 623            State::Connecting => Self::Connecting,
 624            State::Connected { .. } => Self::Connected,
 625            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 626            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 627            State::ReconnectExhausted => Self::Disconnected,
 628            State::ServerNotRunning => Self::Disconnected,
 629        }
 630    }
 631}
 632
/// Client side of a remote ssh session.
pub struct SshRemoteClient {
    // Message channel to the remote; the same instance is reused across reconnects.
    client: Arc<ChannelClient>,
    // String form of the `ConnectionIdentifier` this client was created with.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    // Path style reported by the ssh connection when the client was created.
    path_style: PathStyle,
    // Current state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
 640
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): presumably emitted when the connection is lost for good — the emitting
    /// code is not visible here.
    Disconnected,
}

// Lets `SshRemoteClient` entities emit `SshRemoteEvent`s.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 647
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Identifier drawn from a process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    /// Identifier supplied by the caller.
    /// NOTE(review): presumably a workspace id — confirm at the call sites.
    Workspace(i64),
}
 654
// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 656
 657impl ConnectionIdentifier {
 658    pub fn setup() -> Self {
 659        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 660    }
 661    // This string gets used in a socket name, and so must be relatively short.
 662    // The total length of:
 663    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 664    // Must be less than about 100 characters
 665    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 666    // So our strings should be at most 20 characters or so.
 667    fn to_string(&self, cx: &App) -> String {
 668        let identifier_prefix = match ReleaseChannel::global(cx) {
 669            ReleaseChannel::Stable => "".to_string(),
 670            release_channel => format!("{}-", release_channel.dev_name()),
 671        };
 672        match self {
 673            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 674            Self::Workspace(workspace_id) => {
 675                format!("{identifier_prefix}workspace-{workspace_id}",)
 676            }
 677        }
 678    }
 679}
 680
 681impl SshRemoteClient {
    /// Establishes a new remote session and returns the client entity for it.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(client))` once the connection is up and the first ping succeeded;
    /// - `Ok(None)` if `cancellation` resolves first (NOTE(review): with `futures` oneshot
    ///   channels this presumably includes the sender being dropped — confirm);
    /// - `Err(_)` if any setup step fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Channels between the `ChannelClient` and the proxy started below.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Connections are obtained through the global `ConnectionPool`, which is
                // created on first use.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // The second argument is `false` here and `true` in `reconnect`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // The session only counts as established once the remote answers a ping.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                // Storing the tasks in the state keeps them alive for as long as we are connected.
                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Whichever finishes first wins: cancellation or the setup above.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 756
 757    pub fn shutdown_processes<T: RequestMessage>(
 758        &self,
 759        shutdown_request: Option<T>,
 760        executor: BackgroundExecutor,
 761    ) -> Option<impl Future<Output = ()> + use<T>> {
 762        let state = self.state.lock().take()?;
 763        log::info!("shutting down ssh processes");
 764
 765        let State::Connected {
 766            multiplex_task,
 767            heartbeat_task,
 768            ssh_connection,
 769            delegate,
 770        } = state
 771        else {
 772            return None;
 773        };
 774
 775        let client = self.client.clone();
 776
 777        Some(async move {
 778            if let Some(shutdown_request) = shutdown_request {
 779                client.send(shutdown_request).log_err();
 780                // We wait 50ms instead of waiting for a response, because
 781                // waiting for a response would require us to wait on the main thread
 782                // which we want to avoid in an `on_app_quit` callback.
 783                executor.timer(Duration::from_millis(50)).await;
 784            }
 785
 786            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 787            // child of master_process.
 788            drop(multiplex_task);
 789            // Now drop the rest of state, which kills master process.
 790            drop(heartbeat_task);
 791            drop(ssh_connection);
 792            drop(delegate);
 793        })
 794    }
 795
 796    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 797        let mut lock = self.state.lock();
 798
 799        let can_reconnect = lock
 800            .as_ref()
 801            .map(|state| state.can_reconnect())
 802            .unwrap_or(false);
 803        if !can_reconnect {
 804            log::info!("aborting reconnect, because not in state that allows reconnecting");
 805            let error = if let Some(state) = lock.as_ref() {
 806                format!("invalid state, cannot reconnect while in state {state}")
 807            } else {
 808                "no state set".to_string()
 809            };
 810            anyhow::bail!(error);
 811        }
 812
 813        let state = lock.take().unwrap();
 814        let (attempts, ssh_connection, delegate) = match state {
 815            State::Connected {
 816                ssh_connection,
 817                delegate,
 818                multiplex_task,
 819                heartbeat_task,
 820            }
 821            | State::HeartbeatMissed {
 822                ssh_connection,
 823                delegate,
 824                multiplex_task,
 825                heartbeat_task,
 826                ..
 827            } => {
 828                drop(multiplex_task);
 829                drop(heartbeat_task);
 830                (0, ssh_connection, delegate)
 831            }
 832            State::ReconnectFailed {
 833                attempts,
 834                ssh_connection,
 835                delegate,
 836                ..
 837            } => (attempts, ssh_connection, delegate),
 838            State::Connecting
 839            | State::Reconnecting
 840            | State::ReconnectExhausted
 841            | State::ServerNotRunning => unreachable!(),
 842        };
 843
 844        let attempts = attempts + 1;
 845        if attempts > MAX_RECONNECT_ATTEMPTS {
 846            log::error!(
 847                "Failed to reconnect to after {} attempts, giving up",
 848                MAX_RECONNECT_ATTEMPTS
 849            );
 850            drop(lock);
 851            self.set_state(State::ReconnectExhausted, cx);
 852            return Ok(());
 853        }
 854        drop(lock);
 855
 856        self.set_state(State::Reconnecting, cx);
 857
 858        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 859
 860        let unique_identifier = self.unique_identifier.clone();
 861        let client = self.client.clone();
 862        let reconnect_task = cx.spawn(async move |this, cx| {
 863            macro_rules! failed {
 864                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 865                    return State::ReconnectFailed {
 866                        error: anyhow!($error),
 867                        attempts: $attempts,
 868                        ssh_connection: $ssh_connection,
 869                        delegate: $delegate,
 870                    };
 871                };
 872            }
 873
 874            if let Err(error) = ssh_connection
 875                .kill()
 876                .await
 877                .context("Failed to kill ssh process")
 878            {
 879                failed!(error, attempts, ssh_connection, delegate);
 880            };
 881
 882            let connection_options = ssh_connection.connection_options();
 883
 884            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 885            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 886            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 887
 888            let (ssh_connection, io_task) = match async {
 889                let ssh_connection = cx
 890                    .update_global(|pool: &mut ConnectionPool, cx| {
 891                        pool.connect(connection_options, &delegate, cx)
 892                    })?
 893                    .await
 894                    .map_err(|error| error.cloned())?;
 895
 896                let io_task = ssh_connection.start_proxy(
 897                    unique_identifier,
 898                    true,
 899                    incoming_tx,
 900                    outgoing_rx,
 901                    connection_activity_tx,
 902                    delegate.clone(),
 903                    cx,
 904                );
 905                anyhow::Ok((ssh_connection, io_task))
 906            }
 907            .await
 908            {
 909                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 910                Err(error) => {
 911                    failed!(error, attempts, ssh_connection, delegate);
 912                }
 913            };
 914
 915            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 916            client.reconnect(incoming_rx, outgoing_tx, &cx);
 917
 918            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 919                failed!(error, attempts, ssh_connection, delegate);
 920            };
 921
 922            State::Connected {
 923                ssh_connection,
 924                delegate,
 925                multiplex_task,
 926                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 927            }
 928        });
 929
 930        cx.spawn(async move |this, cx| {
 931            let new_state = reconnect_task.await;
 932            this.update(cx, |this, cx| {
 933                this.try_set_state(cx, |old_state| {
 934                    if old_state.is_reconnecting() {
 935                        match &new_state {
 936                            State::Connecting
 937                            | State::Reconnecting { .. }
 938                            | State::HeartbeatMissed { .. }
 939                            | State::ServerNotRunning => {}
 940                            State::Connected { .. } => {
 941                                log::info!("Successfully reconnected");
 942                            }
 943                            State::ReconnectFailed {
 944                                error, attempts, ..
 945                            } => {
 946                                log::error!(
 947                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 948                                    attempts,
 949                                    error
 950                                );
 951                            }
 952                            State::ReconnectExhausted => {
 953                                log::error!("Reconnect attempt failed and all attempts exhausted");
 954                            }
 955                        }
 956                        Some(new_state)
 957                    } else {
 958                        None
 959                    }
 960                });
 961
 962                if this.state_is(State::is_reconnect_failed) {
 963                    this.reconnect(cx)
 964                } else if this.state_is(State::is_reconnect_exhausted) {
 965                    Ok(())
 966                } else {
 967                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 968                    Ok(())
 969                }
 970            })
 971        })
 972        .detach_and_log_err(cx);
 973
 974        Ok(())
 975    }
 976
 977    fn heartbeat(
 978        this: WeakEntity<Self>,
 979        mut connection_activity_rx: mpsc::Receiver<()>,
 980        cx: &mut AsyncApp,
 981    ) -> Task<Result<()>> {
 982        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 983            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 984        };
 985
 986        cx.spawn(async move |cx| {
 987                let mut missed_heartbeats = 0;
 988
 989                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 990                futures::pin_mut!(keepalive_timer);
 991
 992                loop {
 993                    select_biased! {
 994                        result = connection_activity_rx.next().fuse() => {
 995                            if result.is_none() {
 996                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 997                                return Ok(());
 998                            }
 999
1000                            if missed_heartbeats != 0 {
1001                                missed_heartbeats = 0;
1002                                let _ =this.update(cx, |this, mut cx| {
1003                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1004                                })?;
1005                            }
1006                        }
1007                        _ = keepalive_timer => {
1008                            log::debug!("Sending heartbeat to server...");
1009
1010                            let result = select_biased! {
1011                                _ = connection_activity_rx.next().fuse() => {
1012                                    Ok(())
1013                                }
1014                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1015                                    ping_result
1016                                }
1017                            };
1018
1019                            if result.is_err() {
1020                                missed_heartbeats += 1;
1021                                log::warn!(
1022                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1023                                    HEARTBEAT_TIMEOUT,
1024                                    missed_heartbeats,
1025                                    MAX_MISSED_HEARTBEATS
1026                                );
1027                            } else if missed_heartbeats != 0 {
1028                                missed_heartbeats = 0;
1029                            } else {
1030                                continue;
1031                            }
1032
1033                            let result = this.update(cx, |this, mut cx| {
1034                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1035                            })?;
1036                            if result.is_break() {
1037                                return Ok(());
1038                            }
1039                        }
1040                    }
1041
1042                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1043                }
1044
1045        })
1046    }
1047
1048    fn handle_heartbeat_result(
1049        &mut self,
1050        missed_heartbeats: usize,
1051        cx: &mut Context<Self>,
1052    ) -> ControlFlow<()> {
1053        let state = self.state.lock().take().unwrap();
1054        let next_state = if missed_heartbeats > 0 {
1055            state.heartbeat_missed()
1056        } else {
1057            state.heartbeat_recovered()
1058        };
1059
1060        self.set_state(next_state, cx);
1061
1062        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1063            log::error!(
1064                "Missed last {} heartbeats. Reconnecting...",
1065                missed_heartbeats
1066            );
1067
1068            self.reconnect(cx)
1069                .context("failed to start reconnect process after missing heartbeats")
1070                .log_err();
1071            ControlFlow::Break(())
1072        } else {
1073            ControlFlow::Continue(())
1074        }
1075    }
1076
1077    fn monitor(
1078        this: WeakEntity<Self>,
1079        io_task: Task<Result<i32>>,
1080        cx: &AsyncApp,
1081    ) -> Task<Result<()>> {
1082        cx.spawn(async move |cx| {
1083            let result = io_task.await;
1084
1085            match result {
1086                Ok(exit_code) => {
1087                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1088                        match error {
1089                            ProxyLaunchError::ServerNotRunning => {
1090                                log::error!("failed to reconnect because server is not running");
1091                                this.update(cx, |this, cx| {
1092                                    this.set_state(State::ServerNotRunning, cx);
1093                                })?;
1094                            }
1095                        }
1096                    } else if exit_code > 0 {
1097                        log::error!("proxy process terminated unexpectedly");
1098                        this.update(cx, |this, cx| {
1099                            this.reconnect(cx).ok();
1100                        })?;
1101                    }
1102                }
1103                Err(error) => {
1104                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1105                    this.update(cx, |this, cx| {
1106                        this.reconnect(cx).ok();
1107                    })?;
1108                }
1109            }
1110
1111            Ok(())
1112        })
1113    }
1114
1115    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1116        self.state.lock().as_ref().map_or(false, check)
1117    }
1118
1119    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1120        let mut lock = self.state.lock();
1121        let new_state = lock.as_ref().and_then(map);
1122
1123        if let Some(new_state) = new_state {
1124            lock.replace(new_state);
1125            cx.notify();
1126        }
1127    }
1128
1129    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1130        log::info!("setting state to '{}'", &state);
1131
1132        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1133        let is_server_not_running = state.is_server_not_running();
1134        self.state.lock().replace(state);
1135
1136        if is_reconnect_exhausted || is_server_not_running {
1137            cx.emit(SshRemoteEvent::Disconnected);
1138        }
1139        cx.notify();
1140    }
1141
    /// Forwards to the underlying channel client's `subscribe_to_entity` for `entity` under
    /// `remote_id`.
    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
        self.client.subscribe_to_entity(remote_id, entity);
    }
1145
1146    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1147        self.state
1148            .lock()
1149            .as_ref()
1150            .and_then(|state| state.ssh_connection())
1151            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1152    }
1153
1154    pub fn upload_directory(
1155        &self,
1156        src_path: PathBuf,
1157        dest_path: RemotePathBuf,
1158        cx: &App,
1159    ) -> Task<Result<()>> {
1160        let state = self.state.lock();
1161        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1162            return Task::ready(Err(anyhow!("no ssh connection")));
1163        };
1164        connection.upload_directory(src_path, dest_path, cx)
1165    }
1166
    /// Returns a clone of the underlying client, converted into an `AnyProtoClient`.
    pub fn proto_client(&self) -> AnyProtoClient {
        self.client.clone().into()
    }
1170
    /// Returns the connection string produced by this client's connection options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
1174
    /// Returns a copy of the options this client was configured with.
    pub fn connection_options(&self) -> SshConnectionOptions {
        self.connection_options.clone()
    }
1178
1179    pub fn connection_state(&self) -> ConnectionState {
1180        self.state
1181            .lock()
1182            .as_ref()
1183            .map(ConnectionState::from)
1184            .unwrap_or(ConnectionState::Disconnected)
1185    }
1186
    /// Returns `true` when `connection_state()` is `ConnectionState::Disconnected`.
    pub fn is_disconnected(&self) -> bool {
        self.connection_state() == ConnectionState::Disconnected
    }
1190
    /// Returns the path style stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
1194
1195    #[cfg(any(test, feature = "test-support"))]
1196    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1197        let opts = self.connection_options();
1198        client_cx.spawn(async move |cx| {
1199            let connection = cx
1200                .update_global(|c: &mut ConnectionPool, _| {
1201                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1202                        c.clone()
1203                    } else {
1204                        panic!("missing test connection")
1205                    }
1206                })
1207                .unwrap()
1208                .await
1209                .unwrap();
1210
1211            connection.simulate_disconnect(&cx);
1212        })
1213    }
1214
1215    #[cfg(any(test, feature = "test-support"))]
1216    pub fn fake_server(
1217        client_cx: &mut gpui::TestAppContext,
1218        server_cx: &mut gpui::TestAppContext,
1219    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1220        let port = client_cx
1221            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1222        let opts = SshConnectionOptions {
1223            host: "<fake>".to_string(),
1224            port: Some(port),
1225            ..Default::default()
1226        };
1227        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1228        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1229        let server_client =
1230            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1231        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1232            connection_options: opts.clone(),
1233            server_cx: fake::SendableCx::new(server_cx),
1234            server_channel: server_client.clone(),
1235        });
1236
1237        client_cx.update(|cx| {
1238            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1239                c.connections.insert(
1240                    opts.clone(),
1241                    ConnectionPoolEntry::Connecting(
1242                        cx.background_spawn({
1243                            let connection = connection.clone();
1244                            async move { Ok(connection.clone()) }
1245                        })
1246                        .shared(),
1247                    ),
1248                );
1249            })
1250        });
1251
1252        (opts, server_client)
1253    }
1254
1255    #[cfg(any(test, feature = "test-support"))]
1256    pub async fn fake_client(
1257        opts: SshConnectionOptions,
1258        client_cx: &mut gpui::TestAppContext,
1259    ) -> Entity<Self> {
1260        let (_tx, rx) = oneshot::channel();
1261        client_cx
1262            .update(|cx| {
1263                Self::new(
1264                    ConnectionIdentifier::setup(),
1265                    opts,
1266                    rx,
1267                    Arc::new(fake::Delegate),
1268                    cx,
1269                )
1270            })
1271            .await
1272            .unwrap()
1273            .unwrap()
1274    }
1275}
1276
/// A slot in the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by several callers
    /// and resolves to the connection or a shared error.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak handle is stored, so the pool entry does
    /// not keep the connection alive on its own.
    Connected(Weak<dyn RemoteConnection>),
}
1281
/// Global registry of remote connections, used to share one connection (or one in-flight
/// connection attempt) between callers that use the same options.
#[derive(Default)]
struct ConnectionPool {
    /// Pending and established connections, keyed by the options used to create them.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Registered as a gpui global so it can be reached through `update_global` / `default_global`.
impl Global for ConnectionPool {}
1288
impl ConnectionPool {
    /// Returns a shared task that resolves to the connection for `opts`.
    ///
    /// - If an attempt for `opts` is already in flight, that attempt's task is returned and
    ///   the delegate is told it is waiting for the existing attempt.
    /// - If a live connection that has not been killed is pooled, a ready task for it is
    ///   returned.
    /// - Otherwise a new `SshRemoteConnection` is started and recorded as `Connecting`. When
    ///   it finishes, the entry becomes `Connected` (weak handle) on success or is removed on
    ///   failure.
    pub fn connect(
        &mut self,
        opts: SshConnectionOptions,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
        let connection = self.connections.get(&opts);
        match connection {
            Some(ConnectionPoolEntry::Connecting(task)) => {
                // Piggy-back on the attempt that is already running.
                let delegate = delegate.clone();
                cx.spawn(async move |cx| {
                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
                })
                .detach();
                return task.clone();
            }
            Some(ConnectionPoolEntry::Connected(ssh)) => {
                if let Some(ssh) = ssh.upgrade() {
                    if !ssh.has_been_killed() {
                        return Task::ready(Ok(ssh)).shared();
                    }
                }
                // The pooled connection was dropped or killed: forget it and reconnect below.
                self.connections.remove(&opts);
            }
            None => {}
        }

        let task = cx
            .spawn({
                let opts = opts.clone();
                let delegate = delegate.clone();
                async move |cx| {
                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
                        .await
                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);

                    // Record the outcome in the pool before handing it to the waiters.
                    cx.update_global(|pool: &mut Self, _| {
                        debug_assert!(matches!(
                            pool.connections.get(&opts),
                            Some(ConnectionPoolEntry::Connecting(_))
                        ));
                        match connection {
                            Ok(connection) => {
                                pool.connections.insert(
                                    opts.clone(),
                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
                                );
                                Ok(connection)
                            }
                            Err(error) => {
                                pool.connections.remove(&opts);
                                Err(Arc::new(error))
                            }
                        }
                    })?
                }
            })
            .shared();

        // Publish the in-flight attempt so concurrent callers share it.
        self.connections
            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
        task
    }
}
1354
/// Converts an `SshRemoteClient` into an `AnyProtoClient` wrapping a clone of its inner
/// channel client.
impl From<SshRemoteClient> for AnyProtoClient {
    fn from(client: SshRemoteClient) -> Self {
        AnyProtoClient::new(client.client.clone())
    }
}
1360
/// Abstraction over a connection to a remote host, implemented by the real ssh connection
/// and by the fake connection used in tests.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy for this connection and returns a task resolving to its exit code.
    ///
    /// Messages received from the remote are sent to `incoming_tx`, messages taken from
    /// `outgoing_rx` are written to the remote, and `connection_activity_tx` is signalled
    /// when a message arrives. `reconnect` marks the start as a reconnect of an existing
    /// session identified by `unique_identifier`.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns `true` once the connection has been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further `ssh` commands over this connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
    /// reuse the already-established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style used by the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1390
/// A connection backed by a long-running master `ssh` process.
struct SshRemoteConnection {
    /// Builds the `ssh` invocations that run over this connection.
    socket: SshSocket,
    /// The master `ssh` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote host; filled in at the end of `new` and
    /// required by `start_proxy`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform (`Windows` when its OS is "windows",
    /// `Posix` otherwise).
    ssh_path_style: PathStyle,
    /// Temporary directory created for this session (holds the control socket path on
    /// non-Windows hosts); stored here so it lives as long as the connection.
    _temp_dir: TempDir,
}
1399
1400#[async_trait(?Send)]
1401impl RemoteConnection for SshRemoteConnection {
1402    async fn kill(&self) -> Result<()> {
1403        let Some(mut process) = self.master_process.lock().take() else {
1404            return Ok(());
1405        };
1406        process.kill().ok();
1407        process.status().await?;
1408        Ok(())
1409    }
1410
1411    fn has_been_killed(&self) -> bool {
1412        self.master_process.lock().is_none()
1413    }
1414
1415    fn ssh_args(&self) -> SshArgs {
1416        self.socket.ssh_args()
1417    }
1418
1419    fn connection_options(&self) -> SshConnectionOptions {
1420        self.socket.connection_options.clone()
1421    }
1422
1423    fn upload_directory(
1424        &self,
1425        src_path: PathBuf,
1426        dest_path: RemotePathBuf,
1427        cx: &App,
1428    ) -> Task<Result<()>> {
1429        let mut command = util::command::new_smol_command("scp");
1430        let output = self
1431            .socket
1432            .ssh_options(&mut command)
1433            .args(
1434                self.socket
1435                    .connection_options
1436                    .port
1437                    .map(|port| vec!["-P".to_string(), port.to_string()])
1438                    .unwrap_or_default(),
1439            )
1440            .arg("-C")
1441            .arg("-r")
1442            .arg(&src_path)
1443            .arg(format!(
1444                "{}:{}",
1445                self.socket.connection_options.scp_url(),
1446                dest_path.to_string()
1447            ))
1448            .output();
1449
1450        cx.background_spawn(async move {
1451            let output = output.await?;
1452
1453            anyhow::ensure!(
1454                output.status.success(),
1455                "failed to upload directory {} -> {}: {}",
1456                src_path.display(),
1457                dest_path.to_string(),
1458                String::from_utf8_lossy(&output.stderr)
1459            );
1460
1461            Ok(())
1462        })
1463    }
1464
1465    fn start_proxy(
1466        &self,
1467        unique_identifier: String,
1468        reconnect: bool,
1469        incoming_tx: UnboundedSender<Envelope>,
1470        outgoing_rx: UnboundedReceiver<Envelope>,
1471        connection_activity_tx: Sender<()>,
1472        delegate: Arc<dyn SshClientDelegate>,
1473        cx: &mut AsyncApp,
1474    ) -> Task<Result<i32>> {
1475        delegate.set_status(Some("Starting proxy"), cx);
1476
1477        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1478            return Task::ready(Err(anyhow!("Remote binary path not set")));
1479        };
1480
1481        let mut start_proxy_command = shell_script!(
1482            "exec {binary_path} proxy --identifier {identifier}",
1483            binary_path = &remote_binary_path.to_string(),
1484            identifier = &unique_identifier,
1485        );
1486
1487        if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1488            start_proxy_command = format!(
1489                "RUST_LOG={} {}",
1490                shlex::try_quote(&rust_log).unwrap(),
1491                start_proxy_command
1492            )
1493        }
1494        if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1495            start_proxy_command = format!(
1496                "RUST_BACKTRACE={} {}",
1497                shlex::try_quote(&rust_backtrace).unwrap(),
1498                start_proxy_command
1499            )
1500        }
1501        if reconnect {
1502            start_proxy_command.push_str(" --reconnect");
1503        }
1504
1505        let ssh_proxy_process = match self
1506            .socket
1507            .ssh_command("sh", &["-c", &start_proxy_command])
1508            // IMPORTANT: we kill this process when we drop the task that uses it.
1509            .kill_on_drop(true)
1510            .spawn()
1511        {
1512            Ok(process) => process,
1513            Err(error) => {
1514                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1515            }
1516        };
1517
1518        Self::multiplex(
1519            ssh_proxy_process,
1520            incoming_tx,
1521            outgoing_rx,
1522            connection_activity_tx,
1523            &cx,
1524        )
1525    }
1526
1527    fn path_style(&self) -> PathStyle {
1528        self.ssh_path_style
1529    }
1530}
1531
1532impl SshRemoteConnection {
    /// Establishes a new ssh connection described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. create a temporary directory and an askpass session that forwards password
    ///    prompts to `delegate`;
    /// 2. spawn the master `ssh` process (with `ControlMaster`/`ControlPath` on non-Windows
    ///    hosts);
    /// 3. wait until either askpass finishes (cancelled or timed out) or the master process
    ///    closes its stdout;
    /// 4. fail with the process's stderr if the master process has already exited;
    /// 5. build the `SshSocket`, query the remote platform, and make sure the server binary
    ///    is available, recording its remote path.
    ///
    /// Errors from any step are returned to the caller.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument added below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // NOTE(review): both askpass arms `bail!` straight out of this function, so `result`
        // appears to only ever hold `Ok(())` and the "Failed to connect to host" context
        // below looks unreachable — confirm whether that context was meant to be attached.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, the connection failed; report its stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        // Non-Windows hosts reuse the control socket; on Windows the socket is built from the
        // temp dir and the password collected by askpass.
        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        // Filled in last because `ensure_server_binary` is a method on the connection itself.
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1663
    /// Bridges the stdio of the SSH proxy process with the in-process message channels.
    ///
    /// Three background tasks are spawned:
    /// - stdin: writes every envelope received on `outgoing_rx` to the child's stdin;
    /// - stdout: reads length-prefixed envelopes from the child's stdout and forwards
    ///   them on `incoming_tx`;
    /// - stderr: splits the child's stderr into lines and logs each one.
    ///
    /// The stdout and stderr tasks also signal `connection_activity_tx` after they
    /// have read data (best effort, a failed `try_send` is ignored).
    ///
    /// The returned task waits for the first of the three tasks to finish and then for
    /// the process to exit. It resolves to the exit code (`1` when `code()` is `None`)
    /// if that task succeeded, or to that task's error with the context
    /// "stdin"/"stdout"/"stderr" attached. A failure to obtain the exit status is
    /// returned as an error as well.
    ///
    /// # Panics
    ///
    /// Panics if the child's stdin, stdout or stderr handle is not available.
    fn multiplex(
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        mut connection_activity_tx: Sender<()>,
        cx: &AsyncApp,
    ) -> Task<Result<i32>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        // Scratch buffers, each one owned by exactly one of the tasks below.
        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Number of valid (not yet consumed) bytes at the front of `stderr_buffer`.
        let mut stderr_offset = 0;

        // Forward outgoing envelopes to the child until the channel is closed.
        let stdin_task = cx.background_spawn(async move {
            while let Some(outgoing) = outgoing_rx.next().await {
                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
            }
            anyhow::Ok(())
        });

        let stdout_task = cx.background_spawn({
            let mut connection_activity_tx = connection_activity_tx.clone();
            async move {
                loop {
                    // Read the length prefix first.
                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                    let len = child_stdout.read(&mut stdout_buffer).await?;

                    // A zero-length read means the child closed its stdout.
                    if len == 0 {
                        return anyhow::Ok(());
                    }

                    // Short read: fetch the remaining bytes of the length prefix.
                    if len < MESSAGE_LEN_SIZE {
                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                    }

                    let message_len = message_len_from_buffer(&stdout_buffer);
                    let envelope =
                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                            .await?;
                    // Both sends are best effort; errors are deliberately ignored.
                    connection_activity_tx.try_send(()).ok();
                    incoming_tx.unbounded_send(envelope).ok();
                }
            }
        });

        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
            loop {
                // Make room for up to 1024 more bytes after the unconsumed data.
                stderr_buffer.resize(stderr_offset + 1024, 0);

                let len = child_stderr
                    .read(&mut stderr_buffer[stderr_offset..])
                    .await?;
                if len == 0 {
                    return anyhow::Ok(());
                }

                stderr_offset += len;
                let mut start_ix = 0;
                // Handle every complete (newline-terminated) line currently buffered.
                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                    .iter()
                    .position(|b| b == &b'\n')
                {
                    let line_ix = start_ix + ix;
                    let content = &stderr_buffer[start_ix..line_ix];
                    start_ix = line_ix + 1;
                    // Lines that parse as a JSON `LogRecord` go to the logger; anything
                    // else is echoed to our own stderr with a "(remote)" prefix.
                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                        record.log(log::logger())
                    } else {
                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                    }
                }
                // Drop the consumed lines; a trailing partial line stays at the front.
                stderr_buffer.drain(0..start_ix);
                stderr_offset -= start_ix;

                connection_activity_tx.try_send(()).ok();
            }
        });

        cx.background_spawn(async move {
            // Resolves as soon as the first of the three tasks finishes.
            let result = futures::select! {
                result = stdin_task.fuse() => {
                    result.context("stdin")
                }
                result = stdout_task.fuse() => {
                    result.context("stdout")
                }
                result = stderr_task.fuse() => {
                    result.context("stderr")
                }
            };

            // `code()` may be `None`; that case is reported as exit code 1.
            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
            match result {
                Ok(_) => Ok(status),
                Err(error) => Err(error),
            }
        })
    }
1765
1766    #[allow(unused)]
1767    async fn ensure_server_binary(
1768        &self,
1769        delegate: &Arc<dyn SshClientDelegate>,
1770        release_channel: ReleaseChannel,
1771        version: SemanticVersion,
1772        commit: Option<AppCommitSha>,
1773        cx: &mut AsyncApp,
1774    ) -> Result<RemotePathBuf> {
1775        let version_str = match release_channel {
1776            ReleaseChannel::Nightly => {
1777                let commit = commit.map(|s| s.full()).unwrap_or_default();
1778                format!("{}-{}", version, commit)
1779            }
1780            ReleaseChannel::Dev => "build".to_string(),
1781            _ => version.to_string(),
1782        };
1783        let binary_name = format!(
1784            "zed-remote-server-{}-{}",
1785            release_channel.dev_name(),
1786            version_str
1787        );
1788        let dst_path = RemotePathBuf::new(
1789            paths::remote_server_dir_relative().join(binary_name),
1790            self.ssh_path_style,
1791        );
1792
1793        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1794        #[cfg(debug_assertions)]
1795        if let Some(build_remote_server) = build_remote_server {
1796            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1797            let tmp_path = RemotePathBuf::new(
1798                paths::remote_server_dir_relative().join(format!(
1799                    "download-{}-{}",
1800                    std::process::id(),
1801                    src_path.file_name().unwrap().to_string_lossy()
1802                )),
1803                self.ssh_path_style,
1804            );
1805            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1806                .await?;
1807            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1808                .await?;
1809            return Ok(dst_path);
1810        }
1811
1812        if self
1813            .socket
1814            .run_command(&dst_path.to_string(), &["version"])
1815            .await
1816            .is_ok()
1817        {
1818            return Ok(dst_path);
1819        }
1820
1821        let wanted_version = cx.update(|cx| match release_channel {
1822            ReleaseChannel::Nightly => Ok(None),
1823            ReleaseChannel::Dev => {
1824                anyhow::bail!(
1825                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1826                    dst_path
1827                )
1828            }
1829            _ => Ok(Some(AppVersion::global(cx))),
1830        })??;
1831
1832        let tmp_path_gz = RemotePathBuf::new(
1833            PathBuf::from(format!(
1834                "{}-download-{}.gz",
1835                dst_path.to_string(),
1836                std::process::id()
1837            )),
1838            self.ssh_path_style,
1839        );
1840        if !self.socket.connection_options.upload_binary_over_ssh {
1841            if let Some((url, body)) = delegate
1842                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1843                .await?
1844            {
1845                match self
1846                    .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1847                    .await
1848                {
1849                    Ok(_) => {
1850                        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1851                            .await?;
1852                        return Ok(dst_path);
1853                    }
1854                    Err(e) => {
1855                        log::error!(
1856                            "Failed to download binary on server, attempting to upload server: {}",
1857                            e
1858                        )
1859                    }
1860                }
1861            }
1862        }
1863
1864        let src_path = delegate
1865            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1866            .await?;
1867        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1868            .await?;
1869        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1870            .await?;
1871        return Ok(dst_path);
1872    }
1873
1874    async fn download_binary_on_server(
1875        &self,
1876        url: &str,
1877        body: &str,
1878        tmp_path_gz: &RemotePathBuf,
1879        delegate: &Arc<dyn SshClientDelegate>,
1880        cx: &mut AsyncApp,
1881    ) -> Result<()> {
1882        if let Some(parent) = tmp_path_gz.parent() {
1883            self.socket
1884                .run_command(
1885                    "sh",
1886                    &[
1887                        "-c",
1888                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1889                    ],
1890                )
1891                .await?;
1892        }
1893
1894        delegate.set_status(Some("Downloading remote development server on host"), cx);
1895
1896        match self
1897            .socket
1898            .run_command(
1899                "curl",
1900                &[
1901                    "-f",
1902                    "-L",
1903                    "-X",
1904                    "GET",
1905                    "-H",
1906                    "Content-Type: application/json",
1907                    "-d",
1908                    &body,
1909                    &url,
1910                    "-o",
1911                    &tmp_path_gz.to_string(),
1912                ],
1913            )
1914            .await
1915        {
1916            Ok(_) => {}
1917            Err(e) => {
1918                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1919                    return Err(e);
1920                }
1921
1922                match self
1923                    .socket
1924                    .run_command(
1925                        "wget",
1926                        &[
1927                            "--method=GET",
1928                            "--header=Content-Type: application/json",
1929                            "--body-data",
1930                            &body,
1931                            &url,
1932                            "-O",
1933                            &tmp_path_gz.to_string(),
1934                        ],
1935                    )
1936                    .await
1937                {
1938                    Ok(_) => {}
1939                    Err(e) => {
1940                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1941                            return Err(e);
1942                        } else {
1943                            anyhow::bail!("Neither curl nor wget is available");
1944                        }
1945                    }
1946                }
1947            }
1948        }
1949
1950        Ok(())
1951    }
1952
1953    async fn upload_local_server_binary(
1954        &self,
1955        src_path: &Path,
1956        tmp_path_gz: &RemotePathBuf,
1957        delegate: &Arc<dyn SshClientDelegate>,
1958        cx: &mut AsyncApp,
1959    ) -> Result<()> {
1960        if let Some(parent) = tmp_path_gz.parent() {
1961            self.socket
1962                .run_command(
1963                    "sh",
1964                    &[
1965                        "-c",
1966                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1967                    ],
1968                )
1969                .await?;
1970        }
1971
1972        let src_stat = fs::metadata(&src_path).await?;
1973        let size = src_stat.len();
1974
1975        let t0 = Instant::now();
1976        delegate.set_status(Some("Uploading remote development server"), cx);
1977        log::info!(
1978            "uploading remote development server to {:?} ({}kb)",
1979            tmp_path_gz,
1980            size / 1024
1981        );
1982        self.upload_file(&src_path, &tmp_path_gz)
1983            .await
1984            .context("failed to upload server binary")?;
1985        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1986        Ok(())
1987    }
1988
1989    async fn extract_server_binary(
1990        &self,
1991        dst_path: &RemotePathBuf,
1992        tmp_path: &RemotePathBuf,
1993        delegate: &Arc<dyn SshClientDelegate>,
1994        cx: &mut AsyncApp,
1995    ) -> Result<()> {
1996        delegate.set_status(Some("Extracting remote development server"), cx);
1997        let server_mode = 0o755;
1998
1999        let orig_tmp_path = tmp_path.to_string();
2000        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2001            shell_script!(
2002                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2003                server_mode = &format!("{:o}", server_mode),
2004                dst_path = &dst_path.to_string(),
2005            )
2006        } else {
2007            shell_script!(
2008                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2009                server_mode = &format!("{:o}", server_mode),
2010                dst_path = &dst_path.to_string()
2011            )
2012        };
2013        self.socket.run_command("sh", &["-c", &script]).await?;
2014        Ok(())
2015    }
2016
2017    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2018        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2019        let mut command = util::command::new_smol_command("scp");
2020        let output = self
2021            .socket
2022            .ssh_options(&mut command)
2023            .args(
2024                self.socket
2025                    .connection_options
2026                    .port
2027                    .map(|port| vec!["-P".to_string(), port.to_string()])
2028                    .unwrap_or_default(),
2029            )
2030            .arg(src_path)
2031            .arg(format!(
2032                "{}:{}",
2033                self.socket.connection_options.scp_url(),
2034                dest_path.to_string()
2035            ))
2036            .output()
2037            .await?;
2038
2039        anyhow::ensure!(
2040            output.status.success(),
2041            "failed to upload file {} -> {}: {}",
2042            src_path.display(),
2043            dest_path.to_string(),
2044            String::from_utf8_lossy(&output.stderr)
2045        );
2046        Ok(())
2047    }
2048
    /// Builds the `remote_server` binary from source for the remote platform and
    /// returns the local path of the result. Only compiled into debug builds.
    ///
    /// `build_remote_server` (the value of `ZED_BUILD_REMOTE_SERVER`) is searched for
    /// these substrings:
    /// - `nomusl`: on Linux, target `unknown-linux-gnu` instead of `unknown-linux-musl`;
    /// - `mold`: append `-C link-arg=-fuse-ld=mold` to `RUSTFLAGS`;
    /// - `cross`: cross-compile with `cross` instead of `cargo zigbuild`;
    /// - `nocompress`: skip compressing the binary.
    ///
    /// When the remote architecture and OS match the local ones, a plain `cargo build`
    /// is used. Unless `nocompress` is given, the returned path is the absolute path of
    /// the `.gz` archive; otherwise it is the path of the binary relative to the
    /// current directory.
    ///
    /// # Errors
    ///
    /// Fails for remote operating systems other than "linux" and "macos", when a
    /// required tool (`zig`, or `7z.exe` on Windows) is missing, or when any of the
    /// spawned commands fails.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion with stderr inherited from this process and
        // turns a non-zero exit status into an error.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // musl is the default on Linux; "nomusl" opts out.
        let use_musl = !build_remote_server.contains("nomusl");
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS; an unreadable value is logged and ignored.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        // Same architecture and OS as this machine: a regular cargo build suffices.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            if build_remote_server.contains("cross") {
                #[cfg(target_os = "windows")]
                use util::paths::SanitizedPath;

                delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
                log::info!("installing cross");
                run_cmd(Command::new("cargo").args([
                    "install",
                    "cross",
                    "--git",
                    "https://github.com/cross-rs/cross",
                ]))
                .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote server binary from source for {} with Docker",
                        &triple
                    )),
                    cx,
                );
                log::info!("building remote server binary from source for {}", &triple);

                // On Windows, the binding needs to be set to the canonical path
                #[cfg(target_os = "windows")]
                let src =
                    SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
                #[cfg(not(target_os = "windows"))]
                let src = "./target";
                run_cmd(
                    Command::new("cross")
                        .args([
                            "build",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        // Bind-mount the local target directory into the container.
                        .env(
                            "CROSS_CONTAINER_OPTS",
                            format!("--mount type=bind,src={src},dst=/app/target"),
                        )
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            } else {
                // Default cross-compilation path: cargo-zigbuild, which needs `zig`.
                let which = cx
                    .background_spawn(async move { which::which("zig") })
                    .await;

                if which.is_err() {
                    #[cfg(not(target_os = "windows"))]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                    #[cfg(target_os = "windows")]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                }

                delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
                log::info!("adding rustup target");
                run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

                delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
                log::info!("installing cargo-zigbuild");
                run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
                    .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote binary from source for {triple} with Zig"
                    )),
                    cx,
                );
                log::info!("building remote binary from source for {triple} with Zig");
                run_cmd(
                    Command::new("cargo")
                        .args([
                            "zigbuild",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            }
        };
        // All three build variants use `--target-dir target/remote_server`, so the
        // debug binary is expected here.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-9", "-f", &bin_path.to_string_lossy()]))
                    .await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove an archive left over from a previous build before creating a new one.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            // The archive sits next to the binary, with a `.gz` extension.
            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2267}
2268
/// Pending requests keyed by the id of the request message. Each value is the sender
/// used to hand over the response envelope together with a second one-shot sender.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2270
/// Client side of an envelope-based message channel: sends envelopes on an outgoing
/// channel and dispatches the envelopes it receives to pending requests or handlers.
pub struct ChannelClient {
    // Source of ids for outgoing envelopes (incremented per message).
    next_message_id: AtomicU32,
    // Sender for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Envelopes kept for re-sending: entries are dropped once acknowledged and
    // re-sent when a flush or resync happens.
    buffer: Mutex<VecDeque<Envelope>>,
    // Response channels of requests that are still waiting for a reply.
    response_channels: ResponseChannels,
    // Handlers used to dispatch incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Id of the most recently processed incoming envelope.
    max_received: AtomicU32,
    // Prefix used in log messages.
    name: &'static str,
    // Task that processes incoming envelopes; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
2281
2282impl ChannelClient {
2283    pub fn new(
2284        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2285        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2286        cx: &App,
2287        name: &'static str,
2288    ) -> Arc<Self> {
2289        Arc::new_cyclic(|this| Self {
2290            outgoing_tx: Mutex::new(outgoing_tx),
2291            next_message_id: AtomicU32::new(0),
2292            max_received: AtomicU32::new(0),
2293            response_channels: ResponseChannels::default(),
2294            message_handlers: Default::default(),
2295            buffer: Mutex::new(VecDeque::new()),
2296            name,
2297            task: Mutex::new(Self::start_handling_messages(
2298                this.clone(),
2299                incoming_rx,
2300                &cx.to_async(),
2301            )),
2302        })
2303    }
2304
    /// Spawns the task that processes incoming envelopes.
    ///
    /// The task runs until `incoming_rx` is closed or the client behind `this` has
    /// been dropped. For every envelope it:
    /// 1. drops buffered outgoing messages covered by the envelope's `ack_id`;
    /// 2. answers a `FlushBufferedMessages` payload by re-sending the buffer followed
    ///    by an `Ack`, and skips the remaining steps;
    /// 3. records the envelope's id in `max_received`;
    /// 4. delivers responses (`responding_to` set) to the waiting request, and
    ///    dispatches every other message to the registered message handlers.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncApp,
    ) -> Task<Result<()>> {
        cx.spawn(async move |cx| {
            let peer_id = PeerId { owner_id: 0, id: 0 };
            while let Some(incoming) = incoming_rx.next().await {
                // Stop once the client itself is gone.
                let Some(this) = this.upgrade() else {
                    return anyhow::Ok(());
                };
                // Everything up to and including `ack_id` no longer needs re-sending.
                if let Some(ack_id) = incoming.ack_id {
                    let mut buffer = this.buffer.lock();
                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                        buffer.pop_front();
                    }
                }
                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
                {
                    log::debug!(
                        "{}:ssh message received. name:FlushBufferedMessages",
                        this.name
                    );
                    // Re-send everything still buffered; the inner scope releases the
                    // buffer lock before the Ack is sent.
                    {
                        let buffer = this.buffer.lock();
                        for envelope in buffer.iter() {
                            this.outgoing_tx
                                .lock()
                                .unbounded_send(envelope.clone())
                                .ok();
                        }
                    }
                    // Reply with an Ack that responds to the flush request.
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    continue;
                }

                this.max_received.store(incoming.id, SeqCst);

                if let Some(request_id) = incoming.responding_to {
                    let request_id = MessageId(request_id);
                    let sender = this.response_channels.lock().remove(&request_id);
                    if let Some(sender) = sender {
                        let (tx, rx) = oneshot::channel();
                        if incoming.payload.is_some() {
                            sender.send((incoming, tx)).ok();
                        }
                        // Resolves once `tx` has been used or dropped; presumably this
                        // keeps the next message from being processed before the
                        // requester has seen its response — TODO confirm.
                        rx.await.ok();
                    }
                } else if let Some(envelope) =
                    build_typed_envelope(peer_id, Instant::now(), incoming)
                {
                    let type_name = envelope.payload_type_name();
                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
                        &this.message_handlers,
                        envelope,
                        this.clone().into(),
                        cx.clone(),
                    ) {
                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
                        // Run the handler as a detached task so this loop does not
                        // wait for it.
                        cx.foreground_executor()
                            .spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!(
                                            "{}:ssh message handled. name:{type_name}",
                                            this.name
                                        );
                                    }
                                    Err(error) => {
                                        // Join the lines of the error message with
                                        // spaces so it is logged on a single line.
                                        log::error!(
                                            "{}:error handling message. type:{}, error:{}",
                                            this.name,
                                            type_name,
                                            format!("{error:#}").lines().fold(
                                                String::new(),
                                                |mut message, line| {
                                                    if !message.is_empty() {
                                                        message.push(' ');
                                                    }
                                                    message.push_str(line);
                                                    message
                                                }
                                            )
                                        );
                                    }
                                }
                            })
                            .detach()
                    } else {
                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                    }
                }
            }
            anyhow::Ok(())
        })
    }
2403
2404    pub fn reconnect(
2405        self: &Arc<Self>,
2406        incoming_rx: UnboundedReceiver<Envelope>,
2407        outgoing_tx: UnboundedSender<Envelope>,
2408        cx: &AsyncApp,
2409    ) {
2410        *self.outgoing_tx.lock() = outgoing_tx;
2411        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2412    }
2413
2414    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2415        let id = (TypeId::of::<E>(), remote_id);
2416
2417        let mut message_handlers = self.message_handlers.lock();
2418        if message_handlers
2419            .entities_by_type_and_remote_id
2420            .contains_key(&id)
2421        {
2422            panic!("already subscribed to entity");
2423        }
2424
2425        message_handlers.entities_by_type_and_remote_id.insert(
2426            id,
2427            EntityMessageSubscriber::Entity {
2428                handle: entity.downgrade().into(),
2429            },
2430        );
2431    }
2432
2433    pub fn request<T: RequestMessage>(
2434        &self,
2435        payload: T,
2436    ) -> impl 'static + Future<Output = Result<T::Response>> {
2437        self.request_internal(payload, true)
2438    }
2439
2440    fn request_internal<T: RequestMessage>(
2441        &self,
2442        payload: T,
2443        use_buffer: bool,
2444    ) -> impl 'static + Future<Output = Result<T::Response>> {
2445        log::debug!("ssh request start. name:{}", T::NAME);
2446        let response =
2447            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2448        async move {
2449            let response = response.await?;
2450            log::debug!("ssh request finish. name:{}", T::NAME);
2451            T::Response::from_envelope(response).context("received a response of the wrong type")
2452        }
2453    }
2454
2455    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2456        smol::future::or(
2457            async {
2458                self.request_internal(proto::FlushBufferedMessages {}, false)
2459                    .await?;
2460
2461                for envelope in self.buffer.lock().iter() {
2462                    self.outgoing_tx
2463                        .lock()
2464                        .unbounded_send(envelope.clone())
2465                        .ok();
2466                }
2467                Ok(())
2468            },
2469            async {
2470                smol::Timer::after(timeout).await;
2471                anyhow::bail!("Timeout detected")
2472            },
2473        )
2474        .await
2475    }
2476
2477    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2478        smol::future::or(
2479            async {
2480                self.request(proto::Ping {}).await?;
2481                Ok(())
2482            },
2483            async {
2484                smol::Timer::after(timeout).await;
2485                anyhow::bail!("Timeout detected")
2486            },
2487        )
2488        .await
2489    }
2490
2491    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2492        log::debug!("ssh send name:{}", T::NAME);
2493        self.send_dynamic(payload.into_envelope(0, None, None))
2494    }
2495
2496    fn request_dynamic(
2497        &self,
2498        mut envelope: proto::Envelope,
2499        type_name: &'static str,
2500        use_buffer: bool,
2501    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2502        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2503        let (tx, rx) = oneshot::channel();
2504        let mut response_channels_lock = self.response_channels.lock();
2505        response_channels_lock.insert(MessageId(envelope.id), tx);
2506        drop(response_channels_lock);
2507
2508        let result = if use_buffer {
2509            self.send_buffered(envelope)
2510        } else {
2511            self.send_unbuffered(envelope)
2512        };
2513        async move {
2514            if let Err(error) = &result {
2515                log::error!("failed to send message: {error}");
2516                anyhow::bail!("failed to send message: {error}");
2517            }
2518
2519            let response = rx.await.context("connection lost")?.0;
2520            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2521                return Err(RpcError::from_proto(error, type_name));
2522            }
2523            Ok(response)
2524        }
2525    }
2526
2527    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2528        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2529        self.send_buffered(envelope)
2530    }
2531
2532    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2533        envelope.ack_id = Some(self.max_received.load(SeqCst));
2534        self.buffer.lock().push_back(envelope.clone());
2535        // ignore errors on send (happen while we're reconnecting)
2536        // assume that the global "disconnected" overlay is sufficient.
2537        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2538        Ok(())
2539    }
2540
2541    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2542        envelope.ack_id = Some(self.max_received.load(SeqCst));
2543        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2544        Ok(())
2545    }
2546}
2547
2548impl ProtoClient for ChannelClient {
2549    fn request(
2550        &self,
2551        envelope: proto::Envelope,
2552        request_type: &'static str,
2553    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2554        self.request_dynamic(envelope, request_type, true).boxed()
2555    }
2556
2557    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2558        self.send_dynamic(envelope)
2559    }
2560
2561    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2562        self.send_dynamic(envelope)
2563    }
2564
2565    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2566        &self.message_handlers
2567    }
2568
2569    fn is_via_collab(&self) -> bool {
2570        false
2571    }
2572}
2573
/// Test-only fakes: a `RemoteConnection` that shuttles envelopes to an
/// in-process server `ChannelClient`, and a no-op `SshClientDelegate`.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// Fake connection whose "remote" end is the in-process `server_channel`.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper around an `AsyncApp` that is marked `Send` and `Sync` below.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            let Self(inner) = self;
            inner.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to kill for the fake; always succeeds.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake never reports itself as killed.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Returns empty SSH arguments with no environment.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: vec![],
                envs: None,
            }
        }

        /// Not expected to be called on the fake.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        /// Returns a copy of the options this fake was created with.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server channel to channel halves whose counterparts
        /// have already been dropped.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            // The `_` patterns drop the opposite halves right away, so both
            // channels are already closed when `reconnect` receives them.
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &server_cx);
        }

        /// Reconnects the server channel to a fresh pair of channels and
        /// spawns a background task that forwards envelopes in both
        /// directions.
        ///
        /// The task resolves to `Ok(1)` as soon as either source stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);

            cx.background_spawn(async move {
                loop {
                    // Biased: server-to-client traffic is polled first.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            match server_to_client {
                                Some(envelope) => {
                                    connection_activity_tx.try_send(()).ok();
                                    client_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            match client_to_server {
                                Some(envelope) => {
                                    server_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                    }
                }
            })
        }

        /// Uses the path style returned by `PathStyle::current()`.
        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Delegate for the fake connection; only `set_status` may be called.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _prompt: String,
            _tx: oneshot::Sender<String>,
            _cx: &mut AsyncApp,
        ) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
    }
}