ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  32};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use smol::{
  36    fs,
  37    process::{self, Child, Stdio},
  38};
  39use std::{
  40    any::TypeId,
  41    collections::VecDeque,
  42    fmt, iter,
  43    ops::ControlFlow,
  44    path::{Path, PathBuf},
  45    sync::{
  46        Arc, Weak,
  47        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  48    },
  49    time::{Duration, Instant},
  50};
  51use tempfile::TempDir;
  52use util::{
  53    ResultExt,
  54    paths::{PathStyle, RemotePathBuf},
  55};
  56
/// Identifier for a project associated with an SSH remote.
///
/// A plain `u64` newtype. The derives make it usable as an ordered or hashed map
/// key, and (de)serializable with serde.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  61
/// What is needed to spawn `ssh` processes against one remote host.
///
/// On non-Windows platforms, every invocation is pointed at the control socket in
/// `socket_path` (passed to ssh as `ControlPath`, with `ControlMaster=no`).
/// On Windows, password prompts are answered through `SSH_ASKPASS`, configured by the
/// environment variables stored in `envs`.
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  70
// One local port forward, in the shape of an `ssh -L` specification
// (`[local_host:]local_port:remote_host:remote_port`).
//
// `None` hosts are omitted when serializing. When ssh arguments are built
// (`SshConnectionOptions::additional_args`), a `None` host becomes `localhost`.
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by the
// `JsonSchema` derive as schema descriptions, which would change the generated schema.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    // Local bind address: the optional first field of a `-L` spec.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    // Local port to listen on.
    pub local_port: u16,
    // Destination host. Per `ssh -L` semantics it is resolved on the remote side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    // Destination port on `remote_host`.
    pub remote_port: u16,
}
  80
/// Everything needed to reach a remote host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the remote machine.
    pub host: String,
    /// Login user, if given (`-l USER` or `USER@host`). The value may itself contain `@`.
    pub username: Option<String>,
    /// Explicit SSH port, if given (`-p PORT` or `host:PORT`).
    pub port: Option<u16>,
    /// Password. `parse_command_line` always leaves this `None`.
    pub password: Option<String>,
    /// Extra arguments passed through to `ssh` verbatim.
    pub args: Option<Vec<String>>,
    /// Local port forwards, rendered as `-L` options by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// NOTE(review): presumably a user-facing display name. Not read in this part of the file.
    pub nickname: Option<String>,
    /// NOTE(review): presumably chooses uploading the server binary over ssh instead of
    /// fetching it on the remote. Not read in this part of the file; confirm.
    pub upload_binary_over_ssh: bool,
}
  93
/// Command-line arguments and environment for launching `ssh`.
///
/// When built by `SshSocket::ssh_args`:
/// - the last argument is the `ssh://` destination URL;
/// - `envs` is `Some` only on Windows.
pub struct SshArgs {
    /// Arguments to pass to the `ssh` executable.
    pub arguments: Vec<String>,
    /// Extra environment variables for the `ssh` process, if any.
    pub envs: Option<HashMap<String, String>>,
}
  98
  99#[macro_export]
 100macro_rules! shell_script {
 101    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 102        format!(
 103            $fmt,
 104            $(
 105                $name = shlex::try_quote($arg).unwrap()
 106            ),+
 107        )
 108    }};
 109}
 110
 111fn parse_port_number(port_str: &str) -> Result<u16> {
 112    port_str
 113        .parse()
 114        .with_context(|| format!("parsing port number: {port_str}"))
 115}
 116
 117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 118    let parts: Vec<&str> = spec.split(':').collect();
 119
 120    match parts.len() {
 121        4 => {
 122            let local_port = parse_port_number(parts[1])?;
 123            let remote_port = parse_port_number(parts[3])?;
 124
 125            Ok(SshPortForwardOption {
 126                local_host: Some(parts[0].to_string()),
 127                local_port,
 128                remote_host: Some(parts[2].to_string()),
 129                remote_port,
 130            })
 131        }
 132        3 => {
 133            let local_port = parse_port_number(parts[0])?;
 134            let remote_port = parse_port_number(parts[2])?;
 135
 136            Ok(SshPortForwardOption {
 137                local_host: None,
 138                local_port,
 139                remote_host: Some(parts[1].to_string()),
 140                remote_port,
 141            })
 142        }
 143        _ => anyhow::bail!("Invalid port forward format"),
 144    }
 145}
 146
 147impl SshConnectionOptions {
 148    pub fn parse_command_line(input: &str) -> Result<Self> {
 149        let input = input.trim_start_matches("ssh ");
 150        let mut hostname: Option<String> = None;
 151        let mut username: Option<String> = None;
 152        let mut port: Option<u16> = None;
 153        let mut args = Vec::new();
 154        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 155
 156        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 157        const ALLOWED_OPTS: &[&str] = &[
 158            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 159        ];
 160        const ALLOWED_ARGS: &[&str] = &[
 161            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 162            "-w",
 163        ];
 164
 165        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 166
 167        'outer: while let Some(arg) = tokens.next() {
 168            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 169                args.push(arg.to_string());
 170                continue;
 171            }
 172            if arg == "-p" {
 173                port = tokens.next().and_then(|arg| arg.parse().ok());
 174                continue;
 175            } else if let Some(p) = arg.strip_prefix("-p") {
 176                port = p.parse().ok();
 177                continue;
 178            }
 179            if arg == "-l" {
 180                username = tokens.next();
 181                continue;
 182            } else if let Some(l) = arg.strip_prefix("-l") {
 183                username = Some(l.to_string());
 184                continue;
 185            }
 186            if arg == "-L" || arg.starts_with("-L") {
 187                let forward_spec = if arg == "-L" {
 188                    tokens.next()
 189                } else {
 190                    Some(arg.strip_prefix("-L").unwrap().to_string())
 191                };
 192
 193                if let Some(spec) = forward_spec {
 194                    port_forwards.push(parse_port_forward_spec(&spec)?);
 195                } else {
 196                    anyhow::bail!("Missing port forward format");
 197                }
 198            }
 199
 200            for a in ALLOWED_ARGS {
 201                if arg == *a {
 202                    args.push(arg);
 203                    if let Some(next) = tokens.next() {
 204                        args.push(next);
 205                    }
 206                    continue 'outer;
 207                } else if arg.starts_with(a) {
 208                    args.push(arg);
 209                    continue 'outer;
 210                }
 211            }
 212            if arg.starts_with("-") || hostname.is_some() {
 213                anyhow::bail!("unsupported argument: {:?}", arg);
 214            }
 215            let mut input = &arg as &str;
 216            // Destination might be: username1@username2@ip2@ip1
 217            if let Some((u, rest)) = input.rsplit_once('@') {
 218                input = rest;
 219                username = Some(u.to_string());
 220            }
 221            if let Some((rest, p)) = input.split_once(':') {
 222                input = rest;
 223                port = p.parse().ok()
 224            }
 225            hostname = Some(input.to_string())
 226        }
 227
 228        let Some(hostname) = hostname else {
 229            anyhow::bail!("missing hostname");
 230        };
 231
 232        let port_forwards = match port_forwards.len() {
 233            0 => None,
 234            _ => Some(port_forwards),
 235        };
 236
 237        Ok(Self {
 238            host: hostname.to_string(),
 239            username: username.clone(),
 240            port,
 241            port_forwards,
 242            args: Some(args),
 243            password: None,
 244            nickname: None,
 245            upload_binary_over_ssh: false,
 246        })
 247    }
 248
 249    pub fn ssh_url(&self) -> String {
 250        let mut result = String::from("ssh://");
 251        if let Some(username) = &self.username {
 252            // Username might be: username1@username2@ip2
 253            let username = urlencoding::encode(username);
 254            result.push_str(&username);
 255            result.push('@');
 256        }
 257        result.push_str(&self.host);
 258        if let Some(port) = self.port {
 259            result.push(':');
 260            result.push_str(&port.to_string());
 261        }
 262        result
 263    }
 264
 265    pub fn additional_args(&self) -> Vec<String> {
 266        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 267
 268        if let Some(forwards) = &self.port_forwards {
 269            args.extend(forwards.iter().map(|pf| {
 270                let local_host = match &pf.local_host {
 271                    Some(host) => host,
 272                    None => "localhost",
 273                };
 274                let remote_host = match &pf.remote_host {
 275                    Some(host) => host,
 276                    None => "localhost",
 277                };
 278
 279                format!(
 280                    "-L{}:{}:{}:{}",
 281                    local_host, pf.local_port, remote_host, pf.remote_port
 282                )
 283            }));
 284        }
 285
 286        args
 287    }
 288
 289    fn scp_url(&self) -> String {
 290        if let Some(username) = &self.username {
 291            format!("{}@{}", username, self.host)
 292        } else {
 293            self.host.clone()
 294        }
 295    }
 296
 297    pub fn connection_string(&self) -> String {
 298        let host = if let Some(username) = &self.username {
 299            format!("{}@{}", username, self.host)
 300        } else {
 301            self.host.clone()
 302        };
 303        if let Some(port) = &self.port {
 304            format!("{}:{}", host, port)
 305        } else {
 306            host
 307        }
 308    }
 309}
 310
/// Operating system and CPU architecture of a remote machine.
///
/// As produced by `SshSocket::platform`:
/// - `os` is `"macos"` or `"linux"`;
/// - `arch` is `"aarch64"` or `"x86_64"`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
 316
/// Hooks through which the SSH client asks the embedding application for input and
/// resources, and reports progress.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password in response to `prompt`. The answer is delivered through `tx`.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for the server binary that matches `platform`,
    /// `release_channel` and `version`.
    ///
    /// NOTE(review): the meaning of the returned string pair, and of `None`, is defined
    /// by implementors and is not visible in this part of the file.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Presumably downloads the server binary for `platform` / `release_channel` /
    /// `version` to the local machine and resolves to its local path. Confirm against
    /// implementors.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Updates the user-visible connection status. `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 336
 337impl SshSocket {
 338    #[cfg(not(target_os = "windows"))]
 339    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 340        Ok(Self {
 341            connection_options: options,
 342            socket_path,
 343        })
 344    }
 345
 346    #[cfg(target_os = "windows")]
 347    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 348        let askpass_script = temp_dir.path().join("askpass.bat");
 349        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 350        let mut envs = HashMap::default();
 351        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 352        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 353        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 354        Ok(Self {
 355            connection_options: options,
 356            envs,
 357        })
 358    }
 359
 360    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 361    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 362    // and passes -l as an argument to sh, not to ls.
 363    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 364    // into a machine. You must use `cd` to get back to $HOME.
 365    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 366    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 367        let mut command = util::command::new_smol_command("ssh");
 368        let to_run = iter::once(&program)
 369            .chain(args.iter())
 370            .map(|token| {
 371                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 372                debug_assert!(
 373                    !token.contains('\n'),
 374                    "multiline arguments do not work in all shells"
 375                );
 376                shlex::try_quote(token).unwrap()
 377            })
 378            .join(" ");
 379        let to_run = format!("cd; {to_run}");
 380        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 381        self.ssh_options(&mut command)
 382            .arg(self.connection_options.ssh_url())
 383            .arg(to_run);
 384        command
 385    }
 386
 387    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 388        let output = self.ssh_command(program, args).output().await?;
 389        anyhow::ensure!(
 390            output.status.success(),
 391            "failed to run command: {}",
 392            String::from_utf8_lossy(&output.stderr)
 393        );
 394        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 395    }
 396
 397    #[cfg(not(target_os = "windows"))]
 398    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 399        command
 400            .stdin(Stdio::piped())
 401            .stdout(Stdio::piped())
 402            .stderr(Stdio::piped())
 403            .args(["-o", "ControlMaster=no", "-o"])
 404            .arg(format!("ControlPath={}", self.socket_path.display()))
 405    }
 406
 407    #[cfg(target_os = "windows")]
 408    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 409        command
 410            .stdin(Stdio::piped())
 411            .stdout(Stdio::piped())
 412            .stderr(Stdio::piped())
 413            .envs(self.envs.clone())
 414    }
 415
 416    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 417    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 418    #[cfg(not(target_os = "windows"))]
 419    fn ssh_args(&self) -> SshArgs {
 420        SshArgs {
 421            arguments: vec![
 422                "-o".to_string(),
 423                "ControlMaster=no".to_string(),
 424                "-o".to_string(),
 425                format!("ControlPath={}", self.socket_path.display()),
 426                self.connection_options.ssh_url(),
 427            ],
 428            envs: None,
 429        }
 430    }
 431
 432    #[cfg(target_os = "windows")]
 433    fn ssh_args(&self) -> SshArgs {
 434        SshArgs {
 435            arguments: vec![self.connection_options.ssh_url()],
 436            envs: Some(self.envs.clone()),
 437        }
 438    }
 439
 440    async fn platform(&self) -> Result<SshPlatform> {
 441        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 442        let Some((os, arch)) = uname.split_once(" ") else {
 443            anyhow::bail!("unknown uname: {uname:?}")
 444        };
 445
 446        let os = match os.trim() {
 447            "Darwin" => "macos",
 448            "Linux" => "linux",
 449            _ => anyhow::bail!(
 450                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 451            ),
 452        };
 453        // exclude armv5,6,7 as they are 32-bit.
 454        let arch = if arch.starts_with("armv8")
 455            || arch.starts_with("armv9")
 456            || arch.starts_with("arm64")
 457            || arch.starts_with("aarch64")
 458        {
 459            "aarch64"
 460        } else if arch.starts_with("x86") {
 461            "x86_64"
 462        } else {
 463            anyhow::bail!(
 464                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 465            )
 466        };
 467
 468        Ok(SshPlatform { os, arch })
 469    }
 470}
 471
// NOTE(review): the heartbeat loop that reads MAX_MISSED_HEARTBEATS and
// HEARTBEAT_INTERVAL is not visible in this part of the file. Their descriptions are
// inferred from the names.

/// Presumably: consecutive missed heartbeats tolerated before the connection is treated as lost.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Presumably: the interval between heartbeats.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout passed to `client.ping` when the client is first connected, and to
/// `client.resync` after a reconnect.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Once `reconnect` would exceed this many attempts, it moves to `State::ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 477
/// Internal lifecycle state of an [`SshRemoteClient`]'s connection.
///
/// The variants that hold a live connection also own the tasks driving it, so
/// replacing or dropping the state drops those tasks.
enum State {
    /// Initial state while `SshRemoteClient::new` sets up the first connection.
    Connecting,
    /// Connection established. `new` enters this state after the initial ping succeeds.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Task returned by `SshRemoteClient::monitor` for the proxy I/O.
        multiplex_task: Task<Result<()>>,
        /// Task returned by `SshRemoteClient::heartbeat`.
        heartbeat_task: Task<Result<()>>,
    },
    /// Still connected, but one or more consecutive heartbeats were missed.
    HeartbeatMissed {
        /// Consecutive misses so far. It starts at 1 (see `State::heartbeat_missed`).
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The most recent reconnect attempt failed. The old connection and delegate are
    /// kept so that `reconnect` can try again.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Why the last attempt failed.
        error: anyhow::Error,
        /// Number of reconnect attempts made so far.
        attempts: usize,
    },
    /// `reconnect` gave up after `MAX_RECONNECT_ATTEMPTS` attempts.
    ReconnectExhausted,
    /// NOTE(review): presumably the remote server process is not running. The code that
    /// enters this state is not visible in this part of the file.
    ServerNotRunning,
}
 507
 508impl fmt::Display for State {
 509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 510        match self {
 511            Self::Connecting => write!(f, "connecting"),
 512            Self::Connected { .. } => write!(f, "connected"),
 513            Self::Reconnecting => write!(f, "reconnecting"),
 514            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 515            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 516            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 517            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 518        }
 519    }
 520}
 521
 522impl State {
 523    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 524        match self {
 525            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 526            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 527            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 528            _ => None,
 529        }
 530    }
 531
 532    fn can_reconnect(&self) -> bool {
 533        match self {
 534            Self::Connected { .. }
 535            | Self::HeartbeatMissed { .. }
 536            | Self::ReconnectFailed { .. } => true,
 537            State::Connecting
 538            | State::Reconnecting
 539            | State::ReconnectExhausted
 540            | State::ServerNotRunning => false,
 541        }
 542    }
 543
 544    fn is_reconnect_failed(&self) -> bool {
 545        matches!(self, Self::ReconnectFailed { .. })
 546    }
 547
 548    fn is_reconnect_exhausted(&self) -> bool {
 549        matches!(self, Self::ReconnectExhausted { .. })
 550    }
 551
 552    fn is_server_not_running(&self) -> bool {
 553        matches!(self, Self::ServerNotRunning)
 554    }
 555
 556    fn is_reconnecting(&self) -> bool {
 557        matches!(self, Self::Reconnecting { .. })
 558    }
 559
 560    fn heartbeat_recovered(self) -> Self {
 561        match self {
 562            Self::HeartbeatMissed {
 563                ssh_connection,
 564                delegate,
 565                multiplex_task,
 566                heartbeat_task,
 567                ..
 568            } => Self::Connected {
 569                ssh_connection,
 570                delegate,
 571                multiplex_task,
 572                heartbeat_task,
 573            },
 574            _ => self,
 575        }
 576    }
 577
 578    fn heartbeat_missed(self) -> Self {
 579        match self {
 580            Self::Connected {
 581                ssh_connection,
 582                delegate,
 583                multiplex_task,
 584                heartbeat_task,
 585            } => Self::HeartbeatMissed {
 586                missed_heartbeats: 1,
 587                ssh_connection,
 588                delegate,
 589                multiplex_task,
 590                heartbeat_task,
 591            },
 592            Self::HeartbeatMissed {
 593                missed_heartbeats,
 594                ssh_connection,
 595                delegate,
 596                multiplex_task,
 597                heartbeat_task,
 598            } => Self::HeartbeatMissed {
 599                missed_heartbeats: missed_heartbeats + 1,
 600                ssh_connection,
 601                delegate,
 602                multiplex_task,
 603                heartbeat_task,
 604            },
 605            _ => self,
 606        }
 607    }
 608}
 609
/// The state of the ssh connection.
///
/// This is the public, simplified view of the client's internal connection state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is established.
    Connected,
    /// Still connected, but at least one heartbeat was missed.
    HeartbeatMissed,
    /// A reconnect is in progress, or the last attempt failed and may be retried.
    Reconnecting,
    /// Reconnect attempts were exhausted, or the server is not running.
    Disconnected,
}
 619
 620impl From<&State> for ConnectionState {
 621    fn from(value: &State) -> Self {
 622        match value {
 623            State::Connecting => Self::Connecting,
 624            State::Connected { .. } => Self::Connected,
 625            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 626            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 627            State::ReconnectExhausted => Self::Disconnected,
 628            State::ServerNotRunning => Self::Disconnected,
 629        }
 630    }
 631}
 632
/// Client half of an SSH remote connection.
///
/// Holds an RPC channel client plus the lifecycle state of the underlying ssh connection.
pub struct SshRemoteClient {
    /// RPC client. Its incoming and outgoing envelope channels are wired to the ssh proxy.
    client: Arc<ChannelClient>,
    /// Output of `ConnectionIdentifier::to_string`. It is passed to `start_proxy` so that
    /// reconnects can re-join the same project.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Path style reported by the ssh connection when the client was created.
    path_style: PathStyle,
    /// Current connection state. It is `None` while taken, e.g. after `shutdown_processes`.
    state: Arc<Mutex<Option<State>>>,
}
 640
/// Events emitted by an [`SshRemoteClient`] entity.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): presumably emitted when the connection is lost for good. The
    /// emitting code is not visible in this part of the file.
    Disconnected,
}
 645
// Lets `SshRemoteClient` entities emit `SshRemoteEvent`s to gpui subscribers.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 647
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// A process-unique id, allocated by `ConnectionIdentifier::setup`.
    Setup(u64),
    /// Wraps a workspace id (rendered as `workspace-<id>`).
    Workspace(i64),
}
 654
/// Counter behind `ConnectionIdentifier::setup`. It starts at 1 and is incremented atomically.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 656
 657impl ConnectionIdentifier {
 658    pub fn setup() -> Self {
 659        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 660    }
 661    // This string gets used in a socket name, and so must be relatively short.
 662    // The total length of:
 663    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 664    // Must be less than about 100 characters
 665    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 666    // So our strings should be at most 20 characters or so.
 667    fn to_string(&self, cx: &App) -> String {
 668        let identifier_prefix = match ReleaseChannel::global(cx) {
 669            ReleaseChannel::Stable => "".to_string(),
 670            release_channel => format!("{}-", release_channel.dev_name()),
 671        };
 672        match self {
 673            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 674            Self::Workspace(workspace_id) => {
 675                format!("{identifier_prefix}workspace-{workspace_id}",)
 676            }
 677        }
 678    }
 679}
 680
 681impl SshRemoteClient {
    /// Connects to the remote described by `connection_options` and creates the client entity.
    ///
    /// Setup sequence:
    /// 1. Build the RPC `ChannelClient`.
    /// 2. Obtain an ssh connection from the app-global `ConnectionPool`.
    /// 3. Create the entity in `State::Connecting`.
    /// 4. Start the proxy and wrap it in the `monitor` task.
    /// 5. Ping the server within `HEARTBEAT_TIMEOUT`.
    /// 6. Start the heartbeat task and switch to `State::Connected`.
    ///
    /// The returned task resolves to:
    /// - `Ok(None)` if `cancellation` completes first, i.e. a value is sent or its sender
    ///   is dropped. The in-progress setup future is then dropped.
    /// - An error if obtaining the connection, updating the app, or the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            // Boxed, and therefore `Unpin`, so it can be raced against `cancellation` in
            // `select!` below.
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // The sender goes to the proxy and the receiver to the heartbeat task.
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Ask the app-global connection pool for an ssh connection.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    // Presumably the "is reconnect" flag; `reconnect` passes `true` here.
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Only report success once the server has actually answered.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 756
 757    pub fn shutdown_processes<T: RequestMessage>(
 758        &self,
 759        shutdown_request: Option<T>,
 760        executor: BackgroundExecutor,
 761    ) -> Option<impl Future<Output = ()> + use<T>> {
 762        let state = self.state.lock().take()?;
 763        log::info!("shutting down ssh processes");
 764
 765        let State::Connected {
 766            multiplex_task,
 767            heartbeat_task,
 768            ssh_connection,
 769            delegate,
 770        } = state
 771        else {
 772            return None;
 773        };
 774
 775        let client = self.client.clone();
 776
 777        Some(async move {
 778            if let Some(shutdown_request) = shutdown_request {
 779                client.send(shutdown_request).log_err();
 780                // We wait 50ms instead of waiting for a response, because
 781                // waiting for a response would require us to wait on the main thread
 782                // which we want to avoid in an `on_app_quit` callback.
 783                executor.timer(Duration::from_millis(50)).await;
 784            }
 785
 786            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 787            // child of master_process.
 788            drop(multiplex_task);
 789            // Now drop the rest of state, which kills master process.
 790            drop(heartbeat_task);
 791            drop(ssh_connection);
 792            drop(delegate);
 793        })
 794    }
 795
 796    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 797        let mut lock = self.state.lock();
 798
 799        let can_reconnect = lock
 800            .as_ref()
 801            .map(|state| state.can_reconnect())
 802            .unwrap_or(false);
 803        if !can_reconnect {
 804            log::info!("aborting reconnect, because not in state that allows reconnecting");
 805            let error = if let Some(state) = lock.as_ref() {
 806                format!("invalid state, cannot reconnect while in state {state}")
 807            } else {
 808                "no state set".to_string()
 809            };
 810            anyhow::bail!(error);
 811        }
 812
 813        let state = lock.take().unwrap();
 814        let (attempts, ssh_connection, delegate) = match state {
 815            State::Connected {
 816                ssh_connection,
 817                delegate,
 818                multiplex_task,
 819                heartbeat_task,
 820            }
 821            | State::HeartbeatMissed {
 822                ssh_connection,
 823                delegate,
 824                multiplex_task,
 825                heartbeat_task,
 826                ..
 827            } => {
 828                drop(multiplex_task);
 829                drop(heartbeat_task);
 830                (0, ssh_connection, delegate)
 831            }
 832            State::ReconnectFailed {
 833                attempts,
 834                ssh_connection,
 835                delegate,
 836                ..
 837            } => (attempts, ssh_connection, delegate),
 838            State::Connecting
 839            | State::Reconnecting
 840            | State::ReconnectExhausted
 841            | State::ServerNotRunning => unreachable!(),
 842        };
 843
 844        let attempts = attempts + 1;
 845        if attempts > MAX_RECONNECT_ATTEMPTS {
 846            log::error!(
 847                "Failed to reconnect to after {} attempts, giving up",
 848                MAX_RECONNECT_ATTEMPTS
 849            );
 850            drop(lock);
 851            self.set_state(State::ReconnectExhausted, cx);
 852            return Ok(());
 853        }
 854        drop(lock);
 855
 856        self.set_state(State::Reconnecting, cx);
 857
 858        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 859
 860        let unique_identifier = self.unique_identifier.clone();
 861        let client = self.client.clone();
 862        let reconnect_task = cx.spawn(async move |this, cx| {
 863            macro_rules! failed {
 864                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 865                    return State::ReconnectFailed {
 866                        error: anyhow!($error),
 867                        attempts: $attempts,
 868                        ssh_connection: $ssh_connection,
 869                        delegate: $delegate,
 870                    };
 871                };
 872            }
 873
 874            if let Err(error) = ssh_connection
 875                .kill()
 876                .await
 877                .context("Failed to kill ssh process")
 878            {
 879                failed!(error, attempts, ssh_connection, delegate);
 880            };
 881
 882            let connection_options = ssh_connection.connection_options();
 883
 884            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 885            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 886            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 887
 888            let (ssh_connection, io_task) = match async {
 889                let ssh_connection = cx
 890                    .update_global(|pool: &mut ConnectionPool, cx| {
 891                        pool.connect(connection_options, &delegate, cx)
 892                    })?
 893                    .await
 894                    .map_err(|error| error.cloned())?;
 895
 896                let io_task = ssh_connection.start_proxy(
 897                    unique_identifier,
 898                    true,
 899                    incoming_tx,
 900                    outgoing_rx,
 901                    connection_activity_tx,
 902                    delegate.clone(),
 903                    cx,
 904                );
 905                anyhow::Ok((ssh_connection, io_task))
 906            }
 907            .await
 908            {
 909                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 910                Err(error) => {
 911                    failed!(error, attempts, ssh_connection, delegate);
 912                }
 913            };
 914
 915            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 916            client.reconnect(incoming_rx, outgoing_tx, &cx);
 917
 918            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 919                failed!(error, attempts, ssh_connection, delegate);
 920            };
 921
 922            State::Connected {
 923                ssh_connection,
 924                delegate,
 925                multiplex_task,
 926                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 927            }
 928        });
 929
 930        cx.spawn(async move |this, cx| {
 931            let new_state = reconnect_task.await;
 932            this.update(cx, |this, cx| {
 933                this.try_set_state(cx, |old_state| {
 934                    if old_state.is_reconnecting() {
 935                        match &new_state {
 936                            State::Connecting
 937                            | State::Reconnecting { .. }
 938                            | State::HeartbeatMissed { .. }
 939                            | State::ServerNotRunning => {}
 940                            State::Connected { .. } => {
 941                                log::info!("Successfully reconnected");
 942                            }
 943                            State::ReconnectFailed {
 944                                error, attempts, ..
 945                            } => {
 946                                log::error!(
 947                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 948                                    attempts,
 949                                    error
 950                                );
 951                            }
 952                            State::ReconnectExhausted => {
 953                                log::error!("Reconnect attempt failed and all attempts exhausted");
 954                            }
 955                        }
 956                        Some(new_state)
 957                    } else {
 958                        None
 959                    }
 960                });
 961
 962                if this.state_is(State::is_reconnect_failed) {
 963                    this.reconnect(cx)
 964                } else if this.state_is(State::is_reconnect_exhausted) {
 965                    Ok(())
 966                } else {
 967                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 968                    Ok(())
 969                }
 970            })
 971        })
 972        .detach_and_log_err(cx);
 973
 974        Ok(())
 975    }
 976
 977    fn heartbeat(
 978        this: WeakEntity<Self>,
 979        mut connection_activity_rx: mpsc::Receiver<()>,
 980        cx: &mut AsyncApp,
 981    ) -> Task<Result<()>> {
 982        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 983            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 984        };
 985
 986        cx.spawn(async move |cx| {
 987                let mut missed_heartbeats = 0;
 988
 989                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 990                futures::pin_mut!(keepalive_timer);
 991
 992                loop {
 993                    select_biased! {
 994                        result = connection_activity_rx.next().fuse() => {
 995                            if result.is_none() {
 996                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 997                                return Ok(());
 998                            }
 999
1000                            if missed_heartbeats != 0 {
1001                                missed_heartbeats = 0;
1002                                let _ =this.update(cx, |this, mut cx| {
1003                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1004                                })?;
1005                            }
1006                        }
1007                        _ = keepalive_timer => {
1008                            log::debug!("Sending heartbeat to server...");
1009
1010                            let result = select_biased! {
1011                                _ = connection_activity_rx.next().fuse() => {
1012                                    Ok(())
1013                                }
1014                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1015                                    ping_result
1016                                }
1017                            };
1018
1019                            if result.is_err() {
1020                                missed_heartbeats += 1;
1021                                log::warn!(
1022                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1023                                    HEARTBEAT_TIMEOUT,
1024                                    missed_heartbeats,
1025                                    MAX_MISSED_HEARTBEATS
1026                                );
1027                            } else if missed_heartbeats != 0 {
1028                                missed_heartbeats = 0;
1029                            } else {
1030                                continue;
1031                            }
1032
1033                            let result = this.update(cx, |this, mut cx| {
1034                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1035                            })?;
1036                            if result.is_break() {
1037                                return Ok(());
1038                            }
1039                        }
1040                    }
1041
1042                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1043                }
1044
1045        })
1046    }
1047
1048    fn handle_heartbeat_result(
1049        &mut self,
1050        missed_heartbeats: usize,
1051        cx: &mut Context<Self>,
1052    ) -> ControlFlow<()> {
1053        let state = self.state.lock().take().unwrap();
1054        let next_state = if missed_heartbeats > 0 {
1055            state.heartbeat_missed()
1056        } else {
1057            state.heartbeat_recovered()
1058        };
1059
1060        self.set_state(next_state, cx);
1061
1062        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1063            log::error!(
1064                "Missed last {} heartbeats. Reconnecting...",
1065                missed_heartbeats
1066            );
1067
1068            self.reconnect(cx)
1069                .context("failed to start reconnect process after missing heartbeats")
1070                .log_err();
1071            ControlFlow::Break(())
1072        } else {
1073            ControlFlow::Continue(())
1074        }
1075    }
1076
1077    fn monitor(
1078        this: WeakEntity<Self>,
1079        io_task: Task<Result<i32>>,
1080        cx: &AsyncApp,
1081    ) -> Task<Result<()>> {
1082        cx.spawn(async move |cx| {
1083            let result = io_task.await;
1084
1085            match result {
1086                Ok(exit_code) => {
1087                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1088                        match error {
1089                            ProxyLaunchError::ServerNotRunning => {
1090                                log::error!("failed to reconnect because server is not running");
1091                                this.update(cx, |this, cx| {
1092                                    this.set_state(State::ServerNotRunning, cx);
1093                                })?;
1094                            }
1095                        }
1096                    } else if exit_code > 0 {
1097                        log::error!("proxy process terminated unexpectedly");
1098                        this.update(cx, |this, cx| {
1099                            this.reconnect(cx).ok();
1100                        })?;
1101                    }
1102                }
1103                Err(error) => {
1104                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1105                    this.update(cx, |this, cx| {
1106                        this.reconnect(cx).ok();
1107                    })?;
1108                }
1109            }
1110
1111            Ok(())
1112        })
1113    }
1114
1115    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1116        self.state.lock().as_ref().map_or(false, check)
1117    }
1118
1119    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1120        let mut lock = self.state.lock();
1121        let new_state = lock.as_ref().and_then(map);
1122
1123        if let Some(new_state) = new_state {
1124            lock.replace(new_state);
1125            cx.notify();
1126        }
1127    }
1128
1129    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1130        log::info!("setting state to '{}'", &state);
1131
1132        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1133        let is_server_not_running = state.is_server_not_running();
1134        self.state.lock().replace(state);
1135
1136        if is_reconnect_exhausted || is_server_not_running {
1137            cx.emit(SshRemoteEvent::Disconnected);
1138        }
1139        cx.notify();
1140    }
1141
1142    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
1143        self.client.subscribe_to_entity(remote_id, entity);
1144    }
1145
1146    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1147        self.state
1148            .lock()
1149            .as_ref()
1150            .and_then(|state| state.ssh_connection())
1151            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1152    }
1153
1154    pub fn upload_directory(
1155        &self,
1156        src_path: PathBuf,
1157        dest_path: RemotePathBuf,
1158        cx: &App,
1159    ) -> Task<Result<()>> {
1160        let state = self.state.lock();
1161        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1162            return Task::ready(Err(anyhow!("no ssh connection")));
1163        };
1164        connection.upload_directory(src_path, dest_path, cx)
1165    }
1166
1167    pub fn proto_client(&self) -> AnyProtoClient {
1168        self.client.clone().into()
1169    }
1170
1171    pub fn connection_string(&self) -> String {
1172        self.connection_options.connection_string()
1173    }
1174
1175    pub fn connection_options(&self) -> SshConnectionOptions {
1176        self.connection_options.clone()
1177    }
1178
1179    pub fn connection_state(&self) -> ConnectionState {
1180        self.state
1181            .lock()
1182            .as_ref()
1183            .map(ConnectionState::from)
1184            .unwrap_or(ConnectionState::Disconnected)
1185    }
1186
1187    pub fn is_disconnected(&self) -> bool {
1188        self.connection_state() == ConnectionState::Disconnected
1189    }
1190
1191    pub fn path_style(&self) -> PathStyle {
1192        self.path_style
1193    }
1194
1195    #[cfg(any(test, feature = "test-support"))]
1196    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1197        let opts = self.connection_options();
1198        client_cx.spawn(async move |cx| {
1199            let connection = cx
1200                .update_global(|c: &mut ConnectionPool, _| {
1201                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1202                        c.clone()
1203                    } else {
1204                        panic!("missing test connection")
1205                    }
1206                })
1207                .unwrap()
1208                .await
1209                .unwrap();
1210
1211            connection.simulate_disconnect(&cx);
1212        })
1213    }
1214
1215    #[cfg(any(test, feature = "test-support"))]
1216    pub fn fake_server(
1217        client_cx: &mut gpui::TestAppContext,
1218        server_cx: &mut gpui::TestAppContext,
1219    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1220        let port = client_cx
1221            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1222        let opts = SshConnectionOptions {
1223            host: "<fake>".to_string(),
1224            port: Some(port),
1225            ..Default::default()
1226        };
1227        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1228        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1229        let server_client =
1230            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1231        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1232            connection_options: opts.clone(),
1233            server_cx: fake::SendableCx::new(server_cx),
1234            server_channel: server_client.clone(),
1235        });
1236
1237        client_cx.update(|cx| {
1238            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1239                c.connections.insert(
1240                    opts.clone(),
1241                    ConnectionPoolEntry::Connecting(
1242                        cx.background_spawn({
1243                            let connection = connection.clone();
1244                            async move { Ok(connection.clone()) }
1245                        })
1246                        .shared(),
1247                    ),
1248                );
1249            })
1250        });
1251
1252        (opts, server_client)
1253    }
1254
1255    #[cfg(any(test, feature = "test-support"))]
1256    pub async fn fake_client(
1257        opts: SshConnectionOptions,
1258        client_cx: &mut gpui::TestAppContext,
1259    ) -> Entity<Self> {
1260        let (_tx, rx) = oneshot::channel();
1261        client_cx
1262            .update(|cx| {
1263                Self::new(
1264                    ConnectionIdentifier::setup(),
1265                    opts,
1266                    rx,
1267                    Arc::new(fake::Delegate),
1268                    cx,
1269                )
1270            })
1271            .await
1272            .unwrap()
1273            .unwrap()
1274    }
1275}
1276
/// An entry in the global [`ConnectionPool`], keyed by `SshConnectionOptions`.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. Every caller that asks for the same options
    /// receives a clone of this shared task and therefore the same result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it alive. When it
    /// can no longer be upgraded, or has been killed, `connect` discards the entry and
    /// starts a fresh connection.
    Connected(Weak<dyn RemoteConnection>),
}

/// Global registry that deduplicates SSH connections per set of connection options.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
1288
1289impl ConnectionPool {
1290    pub fn connect(
1291        &mut self,
1292        opts: SshConnectionOptions,
1293        delegate: &Arc<dyn SshClientDelegate>,
1294        cx: &mut App,
1295    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1296        let connection = self.connections.get(&opts);
1297        match connection {
1298            Some(ConnectionPoolEntry::Connecting(task)) => {
1299                let delegate = delegate.clone();
1300                cx.spawn(async move |cx| {
1301                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1302                })
1303                .detach();
1304                return task.clone();
1305            }
1306            Some(ConnectionPoolEntry::Connected(ssh)) => {
1307                if let Some(ssh) = ssh.upgrade() {
1308                    if !ssh.has_been_killed() {
1309                        return Task::ready(Ok(ssh)).shared();
1310                    }
1311                }
1312                self.connections.remove(&opts);
1313            }
1314            None => {}
1315        }
1316
1317        let task = cx
1318            .spawn({
1319                let opts = opts.clone();
1320                let delegate = delegate.clone();
1321                async move |cx| {
1322                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1323                        .await
1324                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1325
1326                    cx.update_global(|pool: &mut Self, _| {
1327                        debug_assert!(matches!(
1328                            pool.connections.get(&opts),
1329                            Some(ConnectionPoolEntry::Connecting(_))
1330                        ));
1331                        match connection {
1332                            Ok(connection) => {
1333                                pool.connections.insert(
1334                                    opts.clone(),
1335                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1336                                );
1337                                Ok(connection)
1338                            }
1339                            Err(error) => {
1340                                pool.connections.remove(&opts);
1341                                Err(Arc::new(error))
1342                            }
1343                        }
1344                    })?
1345                }
1346            })
1347            .shared();
1348
1349        self.connections
1350            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1351        task
1352    }
1353}
1354
1355impl From<SshRemoteClient> for AnyProtoClient {
1356    fn from(client: SshRemoteClient) -> Self {
1357        AnyProtoClient::new(client.client.clone())
1358    }
1359}
1360
/// A transport to a remote host that can launch the remote proxy and copy files.
///
/// `SshRemoteConnection` is the real implementation; test builds also provide a fake one.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Launches the remote proxy and returns a task that shuttles envelopes between the
    /// given channels and the proxy. The task resolves to the proxy's exit code.
    ///
    /// `reconnect` is forwarded to the proxy; the SSH implementation appends `--reconnect`
    /// to the proxy command. In the SSH implementation, `connection_activity_tx` is
    /// signalled each time a message is read from the proxy.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Recursively copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the underlying connection. `ConnectionPool` no longer hands out a
    /// connection once `has_been_killed` reports `true`.
    async fn kill(&self) -> Result<()>;
    /// Reports whether `kill` has already been performed on this connection.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run additional `ssh` commands against this host.
    ///
    /// On Windows, where `ControlMaster`/`ControlPath` are unsupported, we need to use
    /// `SSH_ASKPASS` to provide the password to ssh. On other platforms, we use the
    /// `ControlPath` option to point ssh at the master process's control socket, so that
    /// the already-authenticated connection is reused.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style (Posix or Windows) of the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test-only hook to simulate losing the connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1390
/// An SSH connection backed by a long-running master `ssh` process.
struct SshRemoteConnection {
    /// Connection options plus the means to run further ssh/scp commands against the host.
    socket: SshSocket,
    /// The master ssh process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Location of the server binary on the remote host. `new` fills this in from
    /// `ensure_server_binary`, and `start_proxy` fails while it is `None`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform information reported by the remote host.
    ssh_platform: SshPlatform,
    /// Derived from `ssh_platform.os`: `"windows"` maps to Windows, everything else to Posix.
    ssh_path_style: PathStyle,
    /// Temporary directory (holds the control socket on non-Windows hosts); kept alive for
    /// as long as the connection exists.
    _temp_dir: TempDir,
}
1399
1400#[async_trait(?Send)]
1401impl RemoteConnection for SshRemoteConnection {
1402    async fn kill(&self) -> Result<()> {
1403        let Some(mut process) = self.master_process.lock().take() else {
1404            return Ok(());
1405        };
1406        process.kill().ok();
1407        process.status().await?;
1408        Ok(())
1409    }
1410
1411    fn has_been_killed(&self) -> bool {
1412        self.master_process.lock().is_none()
1413    }
1414
1415    fn ssh_args(&self) -> SshArgs {
1416        self.socket.ssh_args()
1417    }
1418
1419    fn connection_options(&self) -> SshConnectionOptions {
1420        self.socket.connection_options.clone()
1421    }
1422
1423    fn upload_directory(
1424        &self,
1425        src_path: PathBuf,
1426        dest_path: RemotePathBuf,
1427        cx: &App,
1428    ) -> Task<Result<()>> {
1429        let mut command = util::command::new_smol_command("scp");
1430        let output = self
1431            .socket
1432            .ssh_options(&mut command)
1433            .args(
1434                self.socket
1435                    .connection_options
1436                    .port
1437                    .map(|port| vec!["-P".to_string(), port.to_string()])
1438                    .unwrap_or_default(),
1439            )
1440            .arg("-C")
1441            .arg("-r")
1442            .arg(&src_path)
1443            .arg(format!(
1444                "{}:{}",
1445                self.socket.connection_options.scp_url(),
1446                dest_path.to_string()
1447            ))
1448            .output();
1449
1450        cx.background_spawn(async move {
1451            let output = output.await?;
1452
1453            anyhow::ensure!(
1454                output.status.success(),
1455                "failed to upload directory {} -> {}: {}",
1456                src_path.display(),
1457                dest_path.to_string(),
1458                String::from_utf8_lossy(&output.stderr)
1459            );
1460
1461            Ok(())
1462        })
1463    }
1464
1465    fn start_proxy(
1466        &self,
1467        unique_identifier: String,
1468        reconnect: bool,
1469        incoming_tx: UnboundedSender<Envelope>,
1470        outgoing_rx: UnboundedReceiver<Envelope>,
1471        connection_activity_tx: Sender<()>,
1472        delegate: Arc<dyn SshClientDelegate>,
1473        cx: &mut AsyncApp,
1474    ) -> Task<Result<i32>> {
1475        delegate.set_status(Some("Starting proxy"), cx);
1476
1477        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1478            return Task::ready(Err(anyhow!("Remote binary path not set")));
1479        };
1480
1481        let mut start_proxy_command = shell_script!(
1482            "exec {binary_path} proxy --identifier {identifier}",
1483            binary_path = &remote_binary_path.to_string(),
1484            identifier = &unique_identifier,
1485        );
1486
1487        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1488            if let Some(value) = std::env::var(env_var).ok() {
1489                start_proxy_command = format!(
1490                    "{}={} {} ",
1491                    env_var,
1492                    shlex::try_quote(&value).unwrap(),
1493                    start_proxy_command,
1494                );
1495            }
1496        }
1497
1498        if reconnect {
1499            start_proxy_command.push_str(" --reconnect");
1500        }
1501
1502        let ssh_proxy_process = match self
1503            .socket
1504            .ssh_command("sh", &["-c", &start_proxy_command])
1505            // IMPORTANT: we kill this process when we drop the task that uses it.
1506            .kill_on_drop(true)
1507            .spawn()
1508        {
1509            Ok(process) => process,
1510            Err(error) => {
1511                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1512            }
1513        };
1514
1515        Self::multiplex(
1516            ssh_proxy_process,
1517            incoming_tx,
1518            outgoing_rx,
1519            connection_activity_tx,
1520            &cx,
1521        )
1522    }
1523
1524    fn path_style(&self) -> PathStyle {
1525        self.ssh_path_style
1526    }
1527}
1528
1529impl SshRemoteConnection {
    /// Establishes a new SSH connection to the host described by `connection_options`.
    ///
    /// Steps:
    /// 1. Spawns a master `ssh -N` process. On non-Windows hosts it also uses
    ///    `ControlMaster=yes` and a `ControlPath` socket in a temp dir. Password prompts are
    ///    routed through an askpass session to `delegate`.
    /// 2. Waits until either the askpass session ends (user cancelled / timed out) or the
    ///    master process closes its stdout, which is taken to mean authentication finished.
    /// 3. If the master process has already exited, fails with its stderr output.
    /// 4. Detects the remote platform and ensures the server binary is present on the host.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the temp dir, askpass session, or ssh process cannot be created;
    /// - the user cancels or the prompt times out;
    /// - ssh exits during setup;
    /// - platform detection or server-binary setup fails.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Forward ssh password prompts to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does nothing except establish the connection
        // and keep it open, allowing other ssh commands to reuse it via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" pairs with the `ControlPath=...` argument appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // NOTE(review): `anyhow::bail!` inside a `select_biased!` branch returns from `new`
        // itself, so `result` below is always `Ok`. The "Failed to connect to host" context
        // is therefore never attached. Fixing this would change user-visible error text.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    // No explicit kill: the process was spawned with `kill_on_drop(true)`.
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If ssh already exited, surface its stderr as the connection error.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        // The server binary is selected by this app's release channel, version and commit.
        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1660
1661    fn multiplex(
1662        mut ssh_proxy_process: Child,
1663        incoming_tx: UnboundedSender<Envelope>,
1664        mut outgoing_rx: UnboundedReceiver<Envelope>,
1665        mut connection_activity_tx: Sender<()>,
1666        cx: &AsyncApp,
1667    ) -> Task<Result<i32>> {
1668        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1669        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1670        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1671
1672        let mut stdin_buffer = Vec::new();
1673        let mut stdout_buffer = Vec::new();
1674        let mut stderr_buffer = Vec::new();
1675        let mut stderr_offset = 0;
1676
1677        let stdin_task = cx.background_spawn(async move {
1678            while let Some(outgoing) = outgoing_rx.next().await {
1679                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1680            }
1681            anyhow::Ok(())
1682        });
1683
1684        let stdout_task = cx.background_spawn({
1685            let mut connection_activity_tx = connection_activity_tx.clone();
1686            async move {
1687                loop {
1688                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1689                    let len = child_stdout.read(&mut stdout_buffer).await?;
1690
1691                    if len == 0 {
1692                        return anyhow::Ok(());
1693                    }
1694
1695                    if len < MESSAGE_LEN_SIZE {
1696                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1697                    }
1698
1699                    let message_len = message_len_from_buffer(&stdout_buffer);
1700                    let envelope =
1701                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1702                            .await?;
1703                    connection_activity_tx.try_send(()).ok();
1704                    incoming_tx.unbounded_send(envelope).ok();
1705                }
1706            }
1707        });
1708
1709        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1710            loop {
1711                stderr_buffer.resize(stderr_offset + 1024, 0);
1712
1713                let len = child_stderr
1714                    .read(&mut stderr_buffer[stderr_offset..])
1715                    .await?;
1716                if len == 0 {
1717                    return anyhow::Ok(());
1718                }
1719
1720                stderr_offset += len;
1721                let mut start_ix = 0;
1722                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1723                    .iter()
1724                    .position(|b| b == &b'\n')
1725                {
1726                    let line_ix = start_ix + ix;
1727                    let content = &stderr_buffer[start_ix..line_ix];
1728                    start_ix = line_ix + 1;
1729                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1730                        record.log(log::logger())
1731                    } else {
1732                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1733                    }
1734                }
1735                stderr_buffer.drain(0..start_ix);
1736                stderr_offset -= start_ix;
1737
1738                connection_activity_tx.try_send(()).ok();
1739            }
1740        });
1741
1742        cx.background_spawn(async move {
1743            let result = futures::select! {
1744                result = stdin_task.fuse() => {
1745                    result.context("stdin")
1746                }
1747                result = stdout_task.fuse() => {
1748                    result.context("stdout")
1749                }
1750                result = stderr_task.fuse() => {
1751                    result.context("stderr")
1752                }
1753            };
1754
1755            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1756            match result {
1757                Ok(_) => Ok(status),
1758                Err(error) => Err(error),
1759            }
1760        })
1761    }
1762
1763    #[allow(unused)]
1764    async fn ensure_server_binary(
1765        &self,
1766        delegate: &Arc<dyn SshClientDelegate>,
1767        release_channel: ReleaseChannel,
1768        version: SemanticVersion,
1769        commit: Option<AppCommitSha>,
1770        cx: &mut AsyncApp,
1771    ) -> Result<RemotePathBuf> {
1772        let version_str = match release_channel {
1773            ReleaseChannel::Nightly => {
1774                let commit = commit.map(|s| s.full()).unwrap_or_default();
1775                format!("{}-{}", version, commit)
1776            }
1777            ReleaseChannel::Dev => "build".to_string(),
1778            _ => version.to_string(),
1779        };
1780        let binary_name = format!(
1781            "zed-remote-server-{}-{}",
1782            release_channel.dev_name(),
1783            version_str
1784        );
1785        let dst_path = RemotePathBuf::new(
1786            paths::remote_server_dir_relative().join(binary_name),
1787            self.ssh_path_style,
1788        );
1789
1790        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1791        #[cfg(debug_assertions)]
1792        if let Some(build_remote_server) = build_remote_server {
1793            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1794            let tmp_path = RemotePathBuf::new(
1795                paths::remote_server_dir_relative().join(format!(
1796                    "download-{}-{}",
1797                    std::process::id(),
1798                    src_path.file_name().unwrap().to_string_lossy()
1799                )),
1800                self.ssh_path_style,
1801            );
1802            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1803                .await?;
1804            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1805                .await?;
1806            return Ok(dst_path);
1807        }
1808
1809        if self
1810            .socket
1811            .run_command(&dst_path.to_string(), &["version"])
1812            .await
1813            .is_ok()
1814        {
1815            return Ok(dst_path);
1816        }
1817
1818        let wanted_version = cx.update(|cx| match release_channel {
1819            ReleaseChannel::Nightly => Ok(None),
1820            ReleaseChannel::Dev => {
1821                anyhow::bail!(
1822                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1823                    dst_path
1824                )
1825            }
1826            _ => Ok(Some(AppVersion::global(cx))),
1827        })??;
1828
1829        let tmp_path_gz = RemotePathBuf::new(
1830            PathBuf::from(format!(
1831                "{}-download-{}.gz",
1832                dst_path.to_string(),
1833                std::process::id()
1834            )),
1835            self.ssh_path_style,
1836        );
1837        if !self.socket.connection_options.upload_binary_over_ssh {
1838            if let Some((url, body)) = delegate
1839                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1840                .await?
1841            {
1842                match self
1843                    .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1844                    .await
1845                {
1846                    Ok(_) => {
1847                        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1848                            .await?;
1849                        return Ok(dst_path);
1850                    }
1851                    Err(e) => {
1852                        log::error!(
1853                            "Failed to download binary on server, attempting to upload server: {}",
1854                            e
1855                        )
1856                    }
1857                }
1858            }
1859        }
1860
1861        let src_path = delegate
1862            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1863            .await?;
1864        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1865            .await?;
1866        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1867            .await?;
1868        return Ok(dst_path);
1869    }
1870
1871    async fn download_binary_on_server(
1872        &self,
1873        url: &str,
1874        body: &str,
1875        tmp_path_gz: &RemotePathBuf,
1876        delegate: &Arc<dyn SshClientDelegate>,
1877        cx: &mut AsyncApp,
1878    ) -> Result<()> {
1879        if let Some(parent) = tmp_path_gz.parent() {
1880            self.socket
1881                .run_command(
1882                    "sh",
1883                    &[
1884                        "-c",
1885                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1886                    ],
1887                )
1888                .await?;
1889        }
1890
1891        delegate.set_status(Some("Downloading remote development server on host"), cx);
1892
1893        match self
1894            .socket
1895            .run_command(
1896                "curl",
1897                &[
1898                    "-f",
1899                    "-L",
1900                    "-X",
1901                    "GET",
1902                    "-H",
1903                    "Content-Type: application/json",
1904                    "-d",
1905                    &body,
1906                    &url,
1907                    "-o",
1908                    &tmp_path_gz.to_string(),
1909                ],
1910            )
1911            .await
1912        {
1913            Ok(_) => {}
1914            Err(e) => {
1915                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1916                    return Err(e);
1917                }
1918
1919                match self
1920                    .socket
1921                    .run_command(
1922                        "wget",
1923                        &[
1924                            "--method=GET",
1925                            "--header=Content-Type: application/json",
1926                            "--body-data",
1927                            &body,
1928                            &url,
1929                            "-O",
1930                            &tmp_path_gz.to_string(),
1931                        ],
1932                    )
1933                    .await
1934                {
1935                    Ok(_) => {}
1936                    Err(e) => {
1937                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1938                            return Err(e);
1939                        } else {
1940                            anyhow::bail!("Neither curl nor wget is available");
1941                        }
1942                    }
1943                }
1944            }
1945        }
1946
1947        Ok(())
1948    }
1949
1950    async fn upload_local_server_binary(
1951        &self,
1952        src_path: &Path,
1953        tmp_path_gz: &RemotePathBuf,
1954        delegate: &Arc<dyn SshClientDelegate>,
1955        cx: &mut AsyncApp,
1956    ) -> Result<()> {
1957        if let Some(parent) = tmp_path_gz.parent() {
1958            self.socket
1959                .run_command(
1960                    "sh",
1961                    &[
1962                        "-c",
1963                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1964                    ],
1965                )
1966                .await?;
1967        }
1968
1969        let src_stat = fs::metadata(&src_path).await?;
1970        let size = src_stat.len();
1971
1972        let t0 = Instant::now();
1973        delegate.set_status(Some("Uploading remote development server"), cx);
1974        log::info!(
1975            "uploading remote development server to {:?} ({}kb)",
1976            tmp_path_gz,
1977            size / 1024
1978        );
1979        self.upload_file(&src_path, &tmp_path_gz)
1980            .await
1981            .context("failed to upload server binary")?;
1982        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1983        Ok(())
1984    }
1985
1986    async fn extract_server_binary(
1987        &self,
1988        dst_path: &RemotePathBuf,
1989        tmp_path: &RemotePathBuf,
1990        delegate: &Arc<dyn SshClientDelegate>,
1991        cx: &mut AsyncApp,
1992    ) -> Result<()> {
1993        delegate.set_status(Some("Extracting remote development server"), cx);
1994        let server_mode = 0o755;
1995
1996        let orig_tmp_path = tmp_path.to_string();
1997        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
1998            shell_script!(
1999                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2000                server_mode = &format!("{:o}", server_mode),
2001                dst_path = &dst_path.to_string(),
2002            )
2003        } else {
2004            shell_script!(
2005                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2006                server_mode = &format!("{:o}", server_mode),
2007                dst_path = &dst_path.to_string()
2008            )
2009        };
2010        self.socket.run_command("sh", &["-c", &script]).await?;
2011        Ok(())
2012    }
2013
2014    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2015        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2016        let mut command = util::command::new_smol_command("scp");
2017        let output = self
2018            .socket
2019            .ssh_options(&mut command)
2020            .args(
2021                self.socket
2022                    .connection_options
2023                    .port
2024                    .map(|port| vec!["-P".to_string(), port.to_string()])
2025                    .unwrap_or_default(),
2026            )
2027            .arg(src_path)
2028            .arg(format!(
2029                "{}:{}",
2030                self.socket.connection_options.scp_url(),
2031                dest_path.to_string()
2032            ))
2033            .output()
2034            .await?;
2035
2036        anyhow::ensure!(
2037            output.status.success(),
2038            "failed to upload file {} -> {}: {}",
2039            src_path.display(),
2040            dest_path.to_string(),
2041            String::from_utf8_lossy(&output.stderr)
2042        );
2043        Ok(())
2044    }
2045
    /// Debug-only: builds the `remote_server` binary for the remote platform from local source
    /// and returns the path of the file to upload.
    ///
    /// `build_remote_server` is the value of `ZED_BUILD_REMOTE_SERVER`. Substrings in it select
    /// options:
    /// - `mold`: add `-C link-arg=-fuse-ld=mold` to `RUSTFLAGS`;
    /// - `cross`: when the remote arch/OS differs from the local one, cross-compile with
    ///   cross.rs (Docker) instead of the default `cargo zigbuild` (which needs `zig`);
    /// - `nocompress`: skip compression and return the raw binary path.
    ///
    /// When the remote arch and OS match the local ones, a plain `cargo build` is used. Linux
    /// targets are built for musl with `+crt-static`. The returned path is absolute for the
    /// compressed `.gz` archive, and relative to the current directory otherwise.
    ///
    /// # Errors
    /// Fails for remote OSes other than `linux`/`macos`, when a required tool is missing, or
    /// when any build/compression command exits unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion (killed if the future is dropped), passing its stderr
        // through to ours, and fails on a non-success exit status.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // Rust target triple for the remote platform.
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" => "unknown-linux-musl",
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS; a value that cannot be read (not Unicode) is
        // logged and treated as empty.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        // Native build: the remote shares this machine's arch and OS.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Cross build with cross.rs, installed from git on demand.
            if build_remote_server.contains("cross") {
                #[cfg(target_os = "windows")]
                use util::paths::SanitizedPath;

                delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
                log::info!("installing cross");
                run_cmd(Command::new("cargo").args([
                    "install",
                    "cross",
                    "--git",
                    "https://github.com/cross-rs/cross",
                ]))
                .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote server binary from source for {} with Docker",
                        &triple
                    )),
                    cx,
                );
                log::info!("building remote server binary from source for {}", &triple);

                // On Windows, the binding needs to be set to the canonical path
                #[cfg(target_os = "windows")]
                let src =
                    SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
                #[cfg(not(target_os = "windows"))]
                let src = "./target";
                // Bind-mount the local target directory into the build container.
                run_cmd(
                    Command::new("cross")
                        .args([
                            "build",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env(
                            "CROSS_CONTAINER_OPTS",
                            format!("--mount type=bind,src={src},dst=/app/target"),
                        )
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            } else {
                // Default cross build: cargo-zigbuild, which requires `zig` on $PATH.
                let which = cx
                    .background_spawn(async move { which::which("zig") })
                    .await;

                if which.is_err() {
                    #[cfg(not(target_os = "windows"))]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                    #[cfg(target_os = "windows")]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                }

                delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
                log::info!("adding rustup target");
                run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

                delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
                log::info!("installing cargo-zigbuild");
                run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
                    .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote binary from source for {triple} with Zig"
                    )),
                    cx,
                );
                log::info!("building remote binary from source for {triple} with Zig");
                run_cmd(
                    Command::new("cargo")
                        .args([
                            "zigbuild",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            }
        };
        // Where cargo places the debug binary for `--target-dir target/remote_server`.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            // gzip replaces `bin_path` with `bin_path.gz`; `-f` overwrites an existing archive.
            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2257}
2258
/// Pending requests, keyed by the id of the request envelope.
///
/// Each entry receives the response envelope together with a oneshot sender. The requester
/// fires or drops that sender once it is done with the response, which lets the
/// incoming-message loop continue.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2260
/// Envelope-based RPC client: sends envelopes through an outgoing channel and dispatches
/// envelopes arriving on an incoming channel to pending requests or registered handlers.
///
/// Sent envelopes can be kept in `buffer` until the peer acknowledges them, so they can be
/// replayed (e.g. after `reconnect` + `resync`).
pub struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Sink for outgoing envelopes; swapped out by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Sent envelopes not yet acknowledged by the peer, oldest first; trimmed when an incoming
    /// `ack_id` covers them.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests awaiting a response.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope (presumably sent back as the ack id; see
    /// `send_buffered`).
    max_received: AtomicU32,
    /// Label prefixed to this client's log lines.
    name: &'static str,
    /// Task dispatching incoming envelopes; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
2271
2272impl ChannelClient {
2273    pub fn new(
2274        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2275        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2276        cx: &App,
2277        name: &'static str,
2278    ) -> Arc<Self> {
2279        Arc::new_cyclic(|this| Self {
2280            outgoing_tx: Mutex::new(outgoing_tx),
2281            next_message_id: AtomicU32::new(0),
2282            max_received: AtomicU32::new(0),
2283            response_channels: ResponseChannels::default(),
2284            message_handlers: Default::default(),
2285            buffer: Mutex::new(VecDeque::new()),
2286            name,
2287            task: Mutex::new(Self::start_handling_messages(
2288                this.clone(),
2289                incoming_rx,
2290                &cx.to_async(),
2291            )),
2292        })
2293    }
2294
2295    fn start_handling_messages(
2296        this: Weak<Self>,
2297        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2298        cx: &AsyncApp,
2299    ) -> Task<Result<()>> {
2300        cx.spawn(async move |cx| {
2301            let peer_id = PeerId { owner_id: 0, id: 0 };
2302            while let Some(incoming) = incoming_rx.next().await {
2303                let Some(this) = this.upgrade() else {
2304                    return anyhow::Ok(());
2305                };
2306                if let Some(ack_id) = incoming.ack_id {
2307                    let mut buffer = this.buffer.lock();
2308                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2309                        buffer.pop_front();
2310                    }
2311                }
2312                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2313                {
2314                    log::debug!(
2315                        "{}:ssh message received. name:FlushBufferedMessages",
2316                        this.name
2317                    );
2318                    {
2319                        let buffer = this.buffer.lock();
2320                        for envelope in buffer.iter() {
2321                            this.outgoing_tx
2322                                .lock()
2323                                .unbounded_send(envelope.clone())
2324                                .ok();
2325                        }
2326                    }
2327                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2328                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2329                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2330                    continue;
2331                }
2332
2333                this.max_received.store(incoming.id, SeqCst);
2334
2335                if let Some(request_id) = incoming.responding_to {
2336                    let request_id = MessageId(request_id);
2337                    let sender = this.response_channels.lock().remove(&request_id);
2338                    if let Some(sender) = sender {
2339                        let (tx, rx) = oneshot::channel();
2340                        if incoming.payload.is_some() {
2341                            sender.send((incoming, tx)).ok();
2342                        }
2343                        rx.await.ok();
2344                    }
2345                } else if let Some(envelope) =
2346                    build_typed_envelope(peer_id, Instant::now(), incoming)
2347                {
2348                    let type_name = envelope.payload_type_name();
2349                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2350                        &this.message_handlers,
2351                        envelope,
2352                        this.clone().into(),
2353                        cx.clone(),
2354                    ) {
2355                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2356                        cx.foreground_executor()
2357                            .spawn(async move {
2358                                match future.await {
2359                                    Ok(_) => {
2360                                        log::debug!(
2361                                            "{}:ssh message handled. name:{type_name}",
2362                                            this.name
2363                                        );
2364                                    }
2365                                    Err(error) => {
2366                                        log::error!(
2367                                            "{}:error handling message. type:{}, error:{}",
2368                                            this.name,
2369                                            type_name,
2370                                            format!("{error:#}").lines().fold(
2371                                                String::new(),
2372                                                |mut message, line| {
2373                                                    if !message.is_empty() {
2374                                                        message.push(' ');
2375                                                    }
2376                                                    message.push_str(line);
2377                                                    message
2378                                                }
2379                                            )
2380                                        );
2381                                    }
2382                                }
2383                            })
2384                            .detach()
2385                    } else {
2386                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2387                    }
2388                }
2389            }
2390            anyhow::Ok(())
2391        })
2392    }
2393
2394    pub fn reconnect(
2395        self: &Arc<Self>,
2396        incoming_rx: UnboundedReceiver<Envelope>,
2397        outgoing_tx: UnboundedSender<Envelope>,
2398        cx: &AsyncApp,
2399    ) {
2400        *self.outgoing_tx.lock() = outgoing_tx;
2401        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2402    }
2403
2404    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2405        let id = (TypeId::of::<E>(), remote_id);
2406
2407        let mut message_handlers = self.message_handlers.lock();
2408        if message_handlers
2409            .entities_by_type_and_remote_id
2410            .contains_key(&id)
2411        {
2412            panic!("already subscribed to entity");
2413        }
2414
2415        message_handlers.entities_by_type_and_remote_id.insert(
2416            id,
2417            EntityMessageSubscriber::Entity {
2418                handle: entity.downgrade().into(),
2419            },
2420        );
2421    }
2422
2423    pub fn request<T: RequestMessage>(
2424        &self,
2425        payload: T,
2426    ) -> impl 'static + Future<Output = Result<T::Response>> {
2427        self.request_internal(payload, true)
2428    }
2429
2430    fn request_internal<T: RequestMessage>(
2431        &self,
2432        payload: T,
2433        use_buffer: bool,
2434    ) -> impl 'static + Future<Output = Result<T::Response>> {
2435        log::debug!("ssh request start. name:{}", T::NAME);
2436        let response =
2437            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2438        async move {
2439            let response = response.await?;
2440            log::debug!("ssh request finish. name:{}", T::NAME);
2441            T::Response::from_envelope(response).context("received a response of the wrong type")
2442        }
2443    }
2444
2445    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2446        smol::future::or(
2447            async {
2448                self.request_internal(proto::FlushBufferedMessages {}, false)
2449                    .await?;
2450
2451                for envelope in self.buffer.lock().iter() {
2452                    self.outgoing_tx
2453                        .lock()
2454                        .unbounded_send(envelope.clone())
2455                        .ok();
2456                }
2457                Ok(())
2458            },
2459            async {
2460                smol::Timer::after(timeout).await;
2461                anyhow::bail!("Timed out resyncing remote client")
2462            },
2463        )
2464        .await
2465    }
2466
2467    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2468        smol::future::or(
2469            async {
2470                self.request(proto::Ping {}).await?;
2471                Ok(())
2472            },
2473            async {
2474                smol::Timer::after(timeout).await;
2475                anyhow::bail!("Timed out pinging remote client")
2476            },
2477        )
2478        .await
2479    }
2480
2481    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2482        log::debug!("ssh send name:{}", T::NAME);
2483        self.send_dynamic(payload.into_envelope(0, None, None))
2484    }
2485
2486    fn request_dynamic(
2487        &self,
2488        mut envelope: proto::Envelope,
2489        type_name: &'static str,
2490        use_buffer: bool,
2491    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2492        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2493        let (tx, rx) = oneshot::channel();
2494        let mut response_channels_lock = self.response_channels.lock();
2495        response_channels_lock.insert(MessageId(envelope.id), tx);
2496        drop(response_channels_lock);
2497
2498        let result = if use_buffer {
2499            self.send_buffered(envelope)
2500        } else {
2501            self.send_unbuffered(envelope)
2502        };
2503        async move {
2504            if let Err(error) = &result {
2505                log::error!("failed to send message: {error}");
2506                anyhow::bail!("failed to send message: {error}");
2507            }
2508
2509            let response = rx.await.context("connection lost")?.0;
2510            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2511                return Err(RpcError::from_proto(error, type_name));
2512            }
2513            Ok(response)
2514        }
2515    }
2516
2517    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2518        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2519        self.send_buffered(envelope)
2520    }
2521
2522    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2523        envelope.ack_id = Some(self.max_received.load(SeqCst));
2524        self.buffer.lock().push_back(envelope.clone());
2525        // ignore errors on send (happen while we're reconnecting)
2526        // assume that the global "disconnected" overlay is sufficient.
2527        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2528        Ok(())
2529    }
2530
2531    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2532        envelope.ack_id = Some(self.max_received.load(SeqCst));
2533        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2534        Ok(())
2535    }
2536}
2537
2538impl ProtoClient for ChannelClient {
2539    fn request(
2540        &self,
2541        envelope: proto::Envelope,
2542        request_type: &'static str,
2543    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2544        self.request_dynamic(envelope, request_type, true).boxed()
2545    }
2546
2547    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2548        self.send_dynamic(envelope)
2549    }
2550
2551    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2552        self.send_dynamic(envelope)
2553    }
2554
2555    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2556        &self.message_handlers
2557    }
2558
2559    fn is_via_collab(&self) -> bool {
2560        false
2561    }
2562}
2563
/// In-process stand-in for an SSH connection, compiled only for tests.
///
/// No SSH process is spawned. `FakeRemoteConnection` relays envelopes directly between the
/// client's channels and a server-side `ChannelClient` living in the same process.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A `RemoteConnection` whose "remote" end is a `ChannelClient` in the same process.
    pub(super) struct FakeRemoteConnection {
        /// Returned as-is (cloned) from `connection_options`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side client that is reconnected to fresh in-memory channels.
        pub(super) server_channel: Arc<ChannelClient>,
        /// App context passed to `server_channel.reconnect`.
        pub(super) server_cx: SendableCx,
    }

    /// Wraps a test `AsyncApp` so it can be stored in a `Send + Sync` struct.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to kill: there is no underlying process.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// No SSH arguments or environment are needed for the fake.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        /// Not supported by the fake. Panics if called.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects `server_channel` to channels whose opposite halves are dropped
        /// immediately.
        ///
        /// The opposite halves are bound to `_`, so the server is presumably left with a dead
        /// connection (NOTE(review): depends on how `reconnect` handles closed channels).
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            // `cx` is already `&AsyncApp`; pass it directly rather than re-borrowing it.
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Wires the client's channels to `server_channel` through a background relay task.
        ///
        /// The task does two things:
        /// - forwards server→client envelopes to `client_incoming_tx`, notifying
        ///   `connection_activity_tx` for each one;
        /// - forwards client→server envelopes into the server's incoming channel.
        ///
        /// The task resolves to `Ok(1)` as soon as either source stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    // Biased: server→client traffic is polled before client→server traffic.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            // Best-effort activity signal; a full channel is ignored.
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Delegate whose methods must never be reached in tests, except `set_status` (a no-op).
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}