ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  32};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use smol::{
  36    fs,
  37    process::{self, Child, Stdio},
  38};
  39use std::{
  40    any::TypeId,
  41    collections::VecDeque,
  42    fmt, iter,
  43    ops::ControlFlow,
  44    path::{Path, PathBuf},
  45    sync::{
  46        Arc, Weak,
  47        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  48    },
  49    time::{Duration, Instant},
  50};
  51use tempfile::TempDir;
  52use util::{
  53    ResultExt,
  54    paths::{PathStyle, RemotePathBuf},
  55};
  56
  /// Newtype wrapper around a `u64` identifier.
  ///
  /// NOTE(review): presumably identifies a project opened over SSH; its uses
  /// are not visible in this part of the file.
  57#[derive(
  58    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  59)]
  60pub struct SshProjectId(pub u64);
  61
  /// Handle used to build `ssh` invocations for one set of connection options.
  ///
  /// The platform-specific field is what `SshSocket::ssh_options` and
  /// `SshSocket::ssh_args` attach to every command: a `ControlPath` socket on
  /// non-Windows targets, askpass-related environment variables on Windows.
  62#[derive(Clone)]
  63pub struct SshSocket {
        /// Options of the connection that commands are issued against.
  64    connection_options: SshConnectionOptions,
        /// Passed to ssh as `-o ControlPath=<socket_path>` (non-Windows only).
  65    #[cfg(not(target_os = "windows"))]
  66    socket_path: PathBuf,
        /// Environment applied to each ssh command (Windows only); populated in
        /// `SshSocket::new` with `SSH_ASKPASS_REQUIRE`, `SSH_ASKPASS` and
        /// `ZED_SSH_ASKPASS`.
  67    #[cfg(target_os = "windows")]
  68    envs: HashMap<String, String>,
  69}
  70
  /// One local port forward, i.e. the argument of an ssh `-L` option.
  ///
  /// `SshConnectionOptions::additional_args` renders it as
  /// `-L{local_host}:{local_port}:{remote_host}:{remote_port}`, substituting
  /// `localhost` for a host that is `None`.
  71#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
  72pub struct SshPortForwardOption {
        /// Local host/bind address; omitted from serialized output when `None`.
  73    #[serde(skip_serializing_if = "Option::is_none")]
  74    pub local_host: Option<String>,
        /// Local port of the forward.
  75    pub local_port: u16,
        /// Remote host; omitted from serialized output when `None`.
  76    #[serde(skip_serializing_if = "Option::is_none")]
  77    pub remote_host: Option<String>,
        /// Remote port of the forward.
  78    pub remote_port: u16,
  79}
  80
  /// Everything needed to address and open an SSH connection.
  ///
  /// `SshConnectionOptions::parse_command_line` builds a value from an `ssh ...`
  /// command line; that parser always leaves `password` and `nickname` as
  /// `None` and `upload_binary_over_ssh` as `false`.
  81#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  82pub struct SshConnectionOptions {
        /// Host name or address of the destination.
  83    pub host: String,
        /// Login name; may itself contain `@` (it is percent-encoded in `ssh_url`).
  84    pub username: Option<String>,
        /// Port to connect to, when one was given.
  85    pub port: Option<u16>,
        /// NOTE(review): presumably a stored password for the connection; not
        /// read anywhere in this part of the file.
  86    pub password: Option<String>,
        /// Extra ssh command-line arguments, passed through by `additional_args`.
  87    pub args: Option<Vec<String>>,
        /// Local port forwards, rendered as `-L...` arguments by `additional_args`.
  88    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  89
        /// NOTE(review): presumably a user-facing label; not read in this part
        /// of the file.
  90    pub nickname: Option<String>,
        /// NOTE(review): presumably selects uploading the server binary through
        /// the SSH connection; not read in this part of the file.
  91    pub upload_binary_over_ssh: bool,
  92}
  93
  /// Arguments and environment for launching `ssh`, as produced by
  /// `SshSocket::ssh_args`.
  94pub struct SshArgs {
        /// Command-line arguments; `ssh_args` puts the `ssh://` URL last.
  95    pub arguments: Vec<String>,
        /// Extra environment variables: `None` from the non-Windows `ssh_args`,
        /// `Some` (the askpass variables) from the Windows one.
  96    pub envs: Option<HashMap<String, String>>,
  97}
  98
  /// Formats `$fmt` with named arguments, shell-quoting every argument first.
  ///
  /// Each `name = value` pair is passed to `format!` as
  /// `name = shlex::try_quote(value).unwrap()`, so the format string refers to
  /// the values by name (`{name}`). At least one pair is required; a trailing
  /// comma is accepted.
  ///
  /// # Panics
  /// Panics if `shlex::try_quote` returns an error for any argument.
  99#[macro_export]
 100macro_rules! shell_script {
 101    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 102        format!(
 103            $fmt,
 104            $(
                    // Quote so the value is a single shell word.
 105                $name = shlex::try_quote($arg).unwrap()
 106            ),+
 107        )
 108    }};
 109}
 110
 111fn parse_port_number(port_str: &str) -> Result<u16> {
 112    port_str
 113        .parse()
 114        .with_context(|| format!("parsing port number: {port_str}"))
 115}
 116
 117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 118    let parts: Vec<&str> = spec.split(':').collect();
 119
 120    match parts.len() {
 121        4 => {
 122            let local_port = parse_port_number(parts[1])?;
 123            let remote_port = parse_port_number(parts[3])?;
 124
 125            Ok(SshPortForwardOption {
 126                local_host: Some(parts[0].to_string()),
 127                local_port,
 128                remote_host: Some(parts[2].to_string()),
 129                remote_port,
 130            })
 131        }
 132        3 => {
 133            let local_port = parse_port_number(parts[0])?;
 134            let remote_port = parse_port_number(parts[2])?;
 135
 136            Ok(SshPortForwardOption {
 137                local_host: None,
 138                local_port,
 139                remote_host: Some(parts[1].to_string()),
 140                remote_port,
 141            })
 142        }
 143        _ => anyhow::bail!("Invalid port forward format"),
 144    }
 145}
 146
 147impl SshConnectionOptions {
 148    pub fn parse_command_line(input: &str) -> Result<Self> {
 149        let input = input.trim_start_matches("ssh ");
 150        let mut hostname: Option<String> = None;
 151        let mut username: Option<String> = None;
 152        let mut port: Option<u16> = None;
 153        let mut args = Vec::new();
 154        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 155
 156        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 157        const ALLOWED_OPTS: &[&str] = &[
 158            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 159        ];
 160        const ALLOWED_ARGS: &[&str] = &[
 161            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 162            "-w",
 163        ];
 164
 165        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 166
 167        'outer: while let Some(arg) = tokens.next() {
 168            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 169                args.push(arg.to_string());
 170                continue;
 171            }
 172            if arg == "-p" {
 173                port = tokens.next().and_then(|arg| arg.parse().ok());
 174                continue;
 175            } else if let Some(p) = arg.strip_prefix("-p") {
 176                port = p.parse().ok();
 177                continue;
 178            }
 179            if arg == "-l" {
 180                username = tokens.next();
 181                continue;
 182            } else if let Some(l) = arg.strip_prefix("-l") {
 183                username = Some(l.to_string());
 184                continue;
 185            }
 186            if arg == "-L" || arg.starts_with("-L") {
 187                let forward_spec = if arg == "-L" {
 188                    tokens.next()
 189                } else {
 190                    Some(arg.strip_prefix("-L").unwrap().to_string())
 191                };
 192
 193                if let Some(spec) = forward_spec {
 194                    port_forwards.push(parse_port_forward_spec(&spec)?);
 195                } else {
 196                    anyhow::bail!("Missing port forward format");
 197                }
 198            }
 199
 200            for a in ALLOWED_ARGS {
 201                if arg == *a {
 202                    args.push(arg);
 203                    if let Some(next) = tokens.next() {
 204                        args.push(next);
 205                    }
 206                    continue 'outer;
 207                } else if arg.starts_with(a) {
 208                    args.push(arg);
 209                    continue 'outer;
 210                }
 211            }
 212            if arg.starts_with("-") || hostname.is_some() {
 213                anyhow::bail!("unsupported argument: {:?}", arg);
 214            }
 215            let mut input = &arg as &str;
 216            // Destination might be: username1@username2@ip2@ip1
 217            if let Some((u, rest)) = input.rsplit_once('@') {
 218                input = rest;
 219                username = Some(u.to_string());
 220            }
 221            if let Some((rest, p)) = input.split_once(':') {
 222                input = rest;
 223                port = p.parse().ok()
 224            }
 225            hostname = Some(input.to_string())
 226        }
 227
 228        let Some(hostname) = hostname else {
 229            anyhow::bail!("missing hostname");
 230        };
 231
 232        let port_forwards = match port_forwards.len() {
 233            0 => None,
 234            _ => Some(port_forwards),
 235        };
 236
 237        Ok(Self {
 238            host: hostname.to_string(),
 239            username: username.clone(),
 240            port,
 241            port_forwards,
 242            args: Some(args),
 243            password: None,
 244            nickname: None,
 245            upload_binary_over_ssh: false,
 246        })
 247    }
 248
 249    pub fn ssh_url(&self) -> String {
 250        let mut result = String::from("ssh://");
 251        if let Some(username) = &self.username {
 252            // Username might be: username1@username2@ip2
 253            let username = urlencoding::encode(username);
 254            result.push_str(&username);
 255            result.push('@');
 256        }
 257        result.push_str(&self.host);
 258        if let Some(port) = self.port {
 259            result.push(':');
 260            result.push_str(&port.to_string());
 261        }
 262        result
 263    }
 264
 265    pub fn additional_args(&self) -> Vec<String> {
 266        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 267
 268        if let Some(forwards) = &self.port_forwards {
 269            args.extend(forwards.iter().map(|pf| {
 270                let local_host = match &pf.local_host {
 271                    Some(host) => host,
 272                    None => "localhost",
 273                };
 274                let remote_host = match &pf.remote_host {
 275                    Some(host) => host,
 276                    None => "localhost",
 277                };
 278
 279                format!(
 280                    "-L{}:{}:{}:{}",
 281                    local_host, pf.local_port, remote_host, pf.remote_port
 282                )
 283            }));
 284        }
 285
 286        args
 287    }
 288
 289    fn scp_url(&self) -> String {
 290        if let Some(username) = &self.username {
 291            format!("{}@{}", username, self.host)
 292        } else {
 293            self.host.clone()
 294        }
 295    }
 296
 297    pub fn connection_string(&self) -> String {
 298        let host = if let Some(username) = &self.username {
 299            format!("{}@{}", username, self.host)
 300        } else {
 301            self.host.clone()
 302        };
 303        if let Some(port) = &self.port {
 304            format!("{}:{}", host, port)
 305        } else {
 306            host
 307        }
 308    }
 309}
 310
 /// Operating system and CPU architecture of the remote host.
 ///
 /// `SshSocket::platform` only ever produces `"macos"` or `"linux"` for `os`
 /// and `"aarch64"` or `"x86_64"` for `arch`.
 311#[derive(Copy, Clone, Debug)]
 312pub struct SshPlatform {
        /// Normalised OS name derived from `uname -s`.
 313    pub os: &'static str,
        /// Normalised architecture name derived from `uname -m`.
 314    pub arch: &'static str,
 315}
 316
 /// Callbacks the SSH client uses to involve its owner while connecting.
 ///
 /// Implementations must be `Send + Sync`; in this file the delegate is held
 /// as `Arc<dyn SshClientDelegate>`.
 317pub trait SshClientDelegate: Send + Sync {
        /// Called with a `prompt` and a one-shot sender for a `String` reply.
        /// NOTE(review): presumably the implementor sends the password the
        /// user entered through `tx` — the call site is not in this section.
 318    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
        /// Asynchronously yields an optional pair of strings for the given
        /// platform, release channel and version.
        /// NOTE(review): the meaning of the two strings (presumably a download
        /// URL and request body) is not visible here — confirm with implementors.
 319    fn get_download_params(
 320        &self,
 321        platform: SshPlatform,
 322        release_channel: ReleaseChannel,
 323        version: Option<SemanticVersion>,
 324        cx: &mut AsyncApp,
 325    ) -> Task<Result<Option<(String, String)>>>;
 326
        /// Asynchronously yields a local filesystem path for the given
        /// platform, release channel and version.
        /// NOTE(review): presumably the path of the downloaded server binary.
 327    fn download_server_binary_locally(
 328        &self,
 329        platform: SshPlatform,
 330        release_channel: ReleaseChannel,
 331        version: Option<SemanticVersion>,
 332        cx: &mut AsyncApp,
 333    ) -> Task<Result<PathBuf>>;
        /// Reports a status message, or `None`.
        /// NOTE(review): `None` presumably clears the status — confirm.
 334    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
 335}
 336
 337impl SshSocket {
 338    #[cfg(not(target_os = "windows"))]
 339    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 340        Ok(Self {
 341            connection_options: options,
 342            socket_path,
 343        })
 344    }
 345
 346    #[cfg(target_os = "windows")]
 347    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 348        let askpass_script = temp_dir.path().join("askpass.bat");
 349        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 350        let mut envs = HashMap::default();
 351        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 352        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 353        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 354        Ok(Self {
 355            connection_options: options,
 356            envs,
 357        })
 358    }
 359
 360    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 361    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 362    // and passes -l as an argument to sh, not to ls.
 363    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 364    // into a machine. You must use `cd` to get back to $HOME.
 365    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 366    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 367        let mut command = util::command::new_smol_command("ssh");
 368        let to_run = iter::once(&program)
 369            .chain(args.iter())
 370            .map(|token| {
 371                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 372                debug_assert!(
 373                    !token.contains('\n'),
 374                    "multiline arguments do not work in all shells"
 375                );
 376                shlex::try_quote(token).unwrap()
 377            })
 378            .join(" ");
 379        let to_run = format!("cd; {to_run}");
 380        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 381        self.ssh_options(&mut command)
 382            .arg(self.connection_options.ssh_url())
 383            .arg(to_run);
 384        command
 385    }
 386
 387    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 388        let output = self.ssh_command(program, args).output().await?;
 389        anyhow::ensure!(
 390            output.status.success(),
 391            "failed to run command: {}",
 392            String::from_utf8_lossy(&output.stderr)
 393        );
 394        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 395    }
 396
 397    #[cfg(not(target_os = "windows"))]
 398    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 399        command
 400            .stdin(Stdio::piped())
 401            .stdout(Stdio::piped())
 402            .stderr(Stdio::piped())
 403            .args(self.connection_options.additional_args())
 404            .args(["-o", "ControlMaster=no", "-o"])
 405            .arg(format!("ControlPath={}", self.socket_path.display()))
 406    }
 407
 408    #[cfg(target_os = "windows")]
 409    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 410        command
 411            .stdin(Stdio::piped())
 412            .stdout(Stdio::piped())
 413            .stderr(Stdio::piped())
 414            .args(self.connection_options.additional_args())
 415            .envs(self.envs.clone())
 416    }
 417
 418    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 419    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 420    #[cfg(not(target_os = "windows"))]
 421    fn ssh_args(&self) -> SshArgs {
 422        let mut arguments = self.connection_options.additional_args();
 423        arguments.extend(vec![
 424            "-o".to_string(),
 425            "ControlMaster=no".to_string(),
 426            "-o".to_string(),
 427            format!("ControlPath={}", self.socket_path.display()),
 428            self.connection_options.ssh_url(),
 429        ]);
 430        SshArgs {
 431            arguments,
 432            envs: None,
 433        }
 434    }
 435
 436    #[cfg(target_os = "windows")]
 437    fn ssh_args(&self) -> SshArgs {
 438        let mut arguments = self.connection_options.additional_args();
 439        arguments.push(self.connection_options.ssh_url());
 440        SshArgs {
 441            arguments,
 442            envs: Some(self.envs.clone()),
 443        }
 444    }
 445
 446    async fn platform(&self) -> Result<SshPlatform> {
 447        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 448        let Some((os, arch)) = uname.split_once(" ") else {
 449            anyhow::bail!("unknown uname: {uname:?}")
 450        };
 451
 452        let os = match os.trim() {
 453            "Darwin" => "macos",
 454            "Linux" => "linux",
 455            _ => anyhow::bail!(
 456                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 457            ),
 458        };
 459        // exclude armv5,6,7 as they are 32-bit.
 460        let arch = if arch.starts_with("armv8")
 461            || arch.starts_with("armv9")
 462            || arch.starts_with("arm64")
 463            || arch.starts_with("aarch64")
 464        {
 465            "aarch64"
 466        } else if arch.starts_with("x86") {
 467            "x86_64"
 468        } else {
 469            anyhow::bail!(
 470                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 471            )
 472        };
 473
 474        Ok(SshPlatform { os, arch })
 475    }
 476}
 477
 /// NOTE(review): presumably the number of consecutive missed heartbeats
 /// tolerated before reconnecting — the code that reads it is outside this section.
 478const MAX_MISSED_HEARTBEATS: usize = 5;
 /// NOTE(review): presumably the delay between heartbeat pings — the code that
 /// reads it is outside this section.
 479const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 /// Timeout passed to `client.ping(..)`; `SshRemoteClient::new` uses it for the
 /// initial ping that confirms the connection.
 480const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 481
 /// Upper bound on reconnect attempts: once the attempt counter exceeds this,
 /// `SshRemoteClient::reconnect` moves to `State::ReconnectExhausted`.
 482const MAX_RECONNECT_ATTEMPTS: usize = 3;
 483
 /// Internal connection state machine of `SshRemoteClient`.
 ///
 /// The variants that carry an `ssh_connection` (`Connected`,
 /// `HeartbeatMissed`, `ReconnectFailed`) are exactly the ones from which
 /// `State::can_reconnect` allows a reconnect.
 484enum State {
        /// Initial state set by `SshRemoteClient::new`, before the first ping.
 485    Connecting,
        /// The connection is up; set by `SshRemoteClient::new` after the first
        /// ping succeeds.
 486    Connected {
 487        ssh_connection: Arc<dyn RemoteConnection>,
 488        delegate: Arc<dyn SshClientDelegate>,
 489
            // Task returned by `SshRemoteClient::monitor`.
 490        multiplex_task: Task<Result<()>>,
            // Task returned by `SshRemoteClient::heartbeat`.
 491        heartbeat_task: Task<Result<()>>,
 492    },
        /// Entered through `State::heartbeat_missed`; returns to `Connected`
        /// through `State::heartbeat_recovered`.
 493    HeartbeatMissed {
            // Starts at 1 and grows by one per `heartbeat_missed` call.
 494        missed_heartbeats: usize,
 495
 496        ssh_connection: Arc<dyn RemoteConnection>,
 497        delegate: Arc<dyn SshClientDelegate>,
 498
 499        multiplex_task: Task<Result<()>>,
 500        heartbeat_task: Task<Result<()>>,
 501    },
        /// Set by `SshRemoteClient::reconnect` while an attempt is in progress.
 502    Reconnecting,
        /// A reconnect attempt failed; the connection and delegate are kept so
        /// that another attempt can be made.
 503    ReconnectFailed {
 504        ssh_connection: Arc<dyn RemoteConnection>,
 505        delegate: Arc<dyn SshClientDelegate>,
 506
            // Why the attempt failed.
 507        error: anyhow::Error,
            // Attempt counter; the next reconnect continues counting from here.
 508        attempts: usize,
 509    },
        /// Reached once the attempt counter exceeds `MAX_RECONNECT_ATTEMPTS`.
 510    ReconnectExhausted,
        /// NOTE(review): where this is set is not visible in this section; it is
        /// reported as `ConnectionState::Disconnected`.
 511    ServerNotRunning,
 512}
 513
 514impl fmt::Display for State {
 515    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 516        match self {
 517            Self::Connecting => write!(f, "connecting"),
 518            Self::Connected { .. } => write!(f, "connected"),
 519            Self::Reconnecting => write!(f, "reconnecting"),
 520            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 521            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 522            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 523            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 524        }
 525    }
 526}
 527
 528impl State {
 529    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 530        match self {
 531            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 532            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 533            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 534            _ => None,
 535        }
 536    }
 537
 538    fn can_reconnect(&self) -> bool {
 539        match self {
 540            Self::Connected { .. }
 541            | Self::HeartbeatMissed { .. }
 542            | Self::ReconnectFailed { .. } => true,
 543            State::Connecting
 544            | State::Reconnecting
 545            | State::ReconnectExhausted
 546            | State::ServerNotRunning => false,
 547        }
 548    }
 549
 550    fn is_reconnect_failed(&self) -> bool {
 551        matches!(self, Self::ReconnectFailed { .. })
 552    }
 553
 554    fn is_reconnect_exhausted(&self) -> bool {
 555        matches!(self, Self::ReconnectExhausted { .. })
 556    }
 557
 558    fn is_server_not_running(&self) -> bool {
 559        matches!(self, Self::ServerNotRunning)
 560    }
 561
 562    fn is_reconnecting(&self) -> bool {
 563        matches!(self, Self::Reconnecting { .. })
 564    }
 565
 566    fn heartbeat_recovered(self) -> Self {
 567        match self {
 568            Self::HeartbeatMissed {
 569                ssh_connection,
 570                delegate,
 571                multiplex_task,
 572                heartbeat_task,
 573                ..
 574            } => Self::Connected {
 575                ssh_connection,
 576                delegate,
 577                multiplex_task,
 578                heartbeat_task,
 579            },
 580            _ => self,
 581        }
 582    }
 583
 584    fn heartbeat_missed(self) -> Self {
 585        match self {
 586            Self::Connected {
 587                ssh_connection,
 588                delegate,
 589                multiplex_task,
 590                heartbeat_task,
 591            } => Self::HeartbeatMissed {
 592                missed_heartbeats: 1,
 593                ssh_connection,
 594                delegate,
 595                multiplex_task,
 596                heartbeat_task,
 597            },
 598            Self::HeartbeatMissed {
 599                missed_heartbeats,
 600                ssh_connection,
 601                delegate,
 602                multiplex_task,
 603                heartbeat_task,
 604            } => Self::HeartbeatMissed {
 605                missed_heartbeats: missed_heartbeats + 1,
 606                ssh_connection,
 607                delegate,
 608                multiplex_task,
 609                heartbeat_task,
 610            },
 611            _ => self,
 612        }
 613    }
 614}
 615
 /// Payload-free summary of the internal `State`, produced by the
 /// `From<&State>` impl.
 616/// The state of the ssh connection.
 617#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 618pub enum ConnectionState {
 619    Connecting,
 620    Connected,
 621    HeartbeatMissed,
        /// Reported for both `State::Reconnecting` and `State::ReconnectFailed`.
 622    Reconnecting,
        /// Reported for both `State::ReconnectExhausted` and
        /// `State::ServerNotRunning`.
 623    Disconnected,
 624}
 625
 626impl From<&State> for ConnectionState {
 627    fn from(value: &State) -> Self {
 628        match value {
 629            State::Connecting => Self::Connecting,
 630            State::Connected { .. } => Self::Connected,
 631            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 632            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 633            State::ReconnectExhausted => Self::Disconnected,
 634            State::ServerNotRunning => Self::Disconnected,
 635        }
 636    }
 637}
 638
 /// Client side of an SSH remote session; created by `SshRemoteClient::new`.
 639pub struct SshRemoteClient {
        /// Message channel client; used for the initial ping and to send the
        /// shutdown request.
 640    client: Arc<ChannelClient>,
        /// String form of the `ConnectionIdentifier`, handed to `start_proxy`
        /// on connect and again on reconnect.
 641    unique_identifier: String,
        /// Options this client was created with.
 642    connection_options: SshConnectionOptions,
        /// Path style reported by the connection (`ssh_connection.path_style()`).
 643    path_style: PathStyle,
        /// Current state. `None` while a transition has taken the value out,
        /// and after `shutdown_processes` has run.
 644    state: Arc<Mutex<Option<State>>>,
 645}
 646
 /// Events emitted by `SshRemoteClient`.
 647#[derive(Debug)]
 648pub enum SshRemoteEvent {
        /// NOTE(review): presumably emitted when the connection is given up
        /// on — the emitting code is outside this section.
 649    Disconnected,
 650}
 651
 // Marker impl: lets an `SshRemoteClient` entity emit `SshRemoteEvent`s.
 652impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 653
 654// Identifies the socket on the remote server so that reconnects
 655// can re-join the same project.
 656pub enum ConnectionIdentifier {
        /// Process-local id taken from the `NEXT_ID` counter (see `setup()`).
 657    Setup(u64),
        /// Identifier derived from a workspace id.
        /// NOTE(review): presumably the persisted workspace database id.
 658    Workspace(i64),
 659}
 660
 /// Counter behind `ConnectionIdentifier::setup`; starts at 1 and is
 /// incremented with `fetch_add(1, SeqCst)`.
 661static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 662
 663impl ConnectionIdentifier {
 664    pub fn setup() -> Self {
 665        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 666    }
 667    // This string gets used in a socket name, and so must be relatively short.
 668    // The total length of:
 669    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 670    // Must be less than about 100 characters
 671    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 672    // So our strings should be at most 20 characters or so.
 673    fn to_string(&self, cx: &App) -> String {
 674        let identifier_prefix = match ReleaseChannel::global(cx) {
 675            ReleaseChannel::Stable => "".to_string(),
 676            release_channel => format!("{}-", release_channel.dev_name()),
 677        };
 678        match self {
 679            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 680            Self::Workspace(workspace_id) => {
 681                format!("{identifier_prefix}workspace-{workspace_id}",)
 682            }
 683        }
 684    }
 685}
 686
 687impl SshRemoteClient {
 /// Opens an SSH remote session and returns the client entity for it.
 ///
 /// The returned task resolves to:
 /// * `Ok(Some(entity))` once the connection is up and the first ping has
 ///   been answered;
 /// * `Ok(None)` if `cancellation` completes first;
 /// * `Err(_)` if connecting, creating the entity, or the first ping fails.
 ///
 /// `delegate` is handed to the connection pool and to `start_proxy`, and is
 /// stored in the `Connected` state.
 688    pub fn new(
 689        unique_identifier: ConnectionIdentifier,
 690        connection_options: SshConnectionOptions,
 691        cancellation: oneshot::Receiver<()>,
 692        delegate: Arc<dyn SshClientDelegate>,
 693        cx: &mut App,
 694    ) -> Task<Result<Option<Entity<Self>>>> {
            // Render the identifier now, while the synchronous `App` is at hand.
 695        let unique_identifier = unique_identifier.to_string(cx);
 696        cx.spawn(async move |cx| {
 697            let success = Box::pin(async move {
                    // Unbounded channels carry envelopes in each direction; the
                    // capacity-1 channel feeds the heartbeat task.
 698                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 699                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 700                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 701
 702                let client =
 703                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
 704
                    // Obtain the connection from the global `ConnectionPool`
                    // (created with its default value if not yet set).
 705                let ssh_connection = cx
 706                    .update(|cx| {
 707                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
 708                            pool.connect(connection_options.clone(), &delegate, cx)
 709                        })
 710                    })?
 711                    .await
 712                    .map_err(|e| e.cloned())?;
 713
 714                let path_style = ssh_connection.path_style();
                    // The entity starts out in `Connecting`; it is only promoted
                    // to `Connected` after the ping below succeeds.
 715                let this = cx.new(|_| Self {
 716                    client: client.clone(),
 717                    unique_identifier: unique_identifier.clone(),
 718                    connection_options,
 719                    path_style,
 720                    state: Arc::new(Mutex::new(Some(State::Connecting))),
 721                })?;
 722
                    // NOTE(review): the `false` is presumably a "reconnect" flag —
                    // the reconnect path passes `true` in the same position.
 723                let io_task = ssh_connection.start_proxy(
 724                    unique_identifier,
 725                    false,
 726                    incoming_tx,
 727                    outgoing_rx,
 728                    connection_activity_tx,
 729                    delegate.clone(),
 730                    cx,
 731                );
 732
 733                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
 734
                    // Confirm the far end answers before reporting success.
 735                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 736                    log::error!("failed to establish connection: {}", error);
 737                    return Err(error);
 738                }
 739
 740                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
 741
 742                this.update(cx, |this, _| {
 743                    *this.state.lock() = Some(State::Connected {
 744                        ssh_connection,
 745                        delegate,
 746                        multiplex_task,
 747                        heartbeat_task,
 748                    });
 749                })?;
 750
 751                Ok(Some(this))
 752            });
 753
                // Race the setup against cancellation; whichever future completes
                // first decides the result.
 754            select! {
 755                _ = cancellation.fuse() => {
 756                    Ok(None)
 757                }
 758                result = success.fuse() =>  result
 759            }
 760        })
 761    }
 762
 763    pub fn shutdown_processes<T: RequestMessage>(
 764        &self,
 765        shutdown_request: Option<T>,
 766        executor: BackgroundExecutor,
 767    ) -> Option<impl Future<Output = ()> + use<T>> {
 768        let state = self.state.lock().take()?;
 769        log::info!("shutting down ssh processes");
 770
 771        let State::Connected {
 772            multiplex_task,
 773            heartbeat_task,
 774            ssh_connection,
 775            delegate,
 776        } = state
 777        else {
 778            return None;
 779        };
 780
 781        let client = self.client.clone();
 782
 783        Some(async move {
 784            if let Some(shutdown_request) = shutdown_request {
 785                client.send(shutdown_request).log_err();
 786                // We wait 50ms instead of waiting for a response, because
 787                // waiting for a response would require us to wait on the main thread
 788                // which we want to avoid in an `on_app_quit` callback.
 789                executor.timer(Duration::from_millis(50)).await;
 790            }
 791
 792            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 793            // child of master_process.
 794            drop(multiplex_task);
 795            // Now drop the rest of state, which kills master process.
 796            drop(heartbeat_task);
 797            drop(ssh_connection);
 798            drop(delegate);
 799        })
 800    }
 801
 802    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 803        let mut lock = self.state.lock();
 804
 805        let can_reconnect = lock
 806            .as_ref()
 807            .map(|state| state.can_reconnect())
 808            .unwrap_or(false);
 809        if !can_reconnect {
 810            log::info!("aborting reconnect, because not in state that allows reconnecting");
 811            let error = if let Some(state) = lock.as_ref() {
 812                format!("invalid state, cannot reconnect while in state {state}")
 813            } else {
 814                "no state set".to_string()
 815            };
 816            anyhow::bail!(error);
 817        }
 818
 819        let state = lock.take().unwrap();
 820        let (attempts, ssh_connection, delegate) = match state {
 821            State::Connected {
 822                ssh_connection,
 823                delegate,
 824                multiplex_task,
 825                heartbeat_task,
 826            }
 827            | State::HeartbeatMissed {
 828                ssh_connection,
 829                delegate,
 830                multiplex_task,
 831                heartbeat_task,
 832                ..
 833            } => {
 834                drop(multiplex_task);
 835                drop(heartbeat_task);
 836                (0, ssh_connection, delegate)
 837            }
 838            State::ReconnectFailed {
 839                attempts,
 840                ssh_connection,
 841                delegate,
 842                ..
 843            } => (attempts, ssh_connection, delegate),
 844            State::Connecting
 845            | State::Reconnecting
 846            | State::ReconnectExhausted
 847            | State::ServerNotRunning => unreachable!(),
 848        };
 849
 850        let attempts = attempts + 1;
 851        if attempts > MAX_RECONNECT_ATTEMPTS {
 852            log::error!(
 853                "Failed to reconnect to after {} attempts, giving up",
 854                MAX_RECONNECT_ATTEMPTS
 855            );
 856            drop(lock);
 857            self.set_state(State::ReconnectExhausted, cx);
 858            return Ok(());
 859        }
 860        drop(lock);
 861
 862        self.set_state(State::Reconnecting, cx);
 863
 864        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 865
 866        let unique_identifier = self.unique_identifier.clone();
 867        let client = self.client.clone();
 868        let reconnect_task = cx.spawn(async move |this, cx| {
 869            macro_rules! failed {
 870                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 871                    return State::ReconnectFailed {
 872                        error: anyhow!($error),
 873                        attempts: $attempts,
 874                        ssh_connection: $ssh_connection,
 875                        delegate: $delegate,
 876                    };
 877                };
 878            }
 879
 880            if let Err(error) = ssh_connection
 881                .kill()
 882                .await
 883                .context("Failed to kill ssh process")
 884            {
 885                failed!(error, attempts, ssh_connection, delegate);
 886            };
 887
 888            let connection_options = ssh_connection.connection_options();
 889
 890            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 891            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 892            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 893
 894            let (ssh_connection, io_task) = match async {
 895                let ssh_connection = cx
 896                    .update_global(|pool: &mut ConnectionPool, cx| {
 897                        pool.connect(connection_options, &delegate, cx)
 898                    })?
 899                    .await
 900                    .map_err(|error| error.cloned())?;
 901
 902                let io_task = ssh_connection.start_proxy(
 903                    unique_identifier,
 904                    true,
 905                    incoming_tx,
 906                    outgoing_rx,
 907                    connection_activity_tx,
 908                    delegate.clone(),
 909                    cx,
 910                );
 911                anyhow::Ok((ssh_connection, io_task))
 912            }
 913            .await
 914            {
 915                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 916                Err(error) => {
 917                    failed!(error, attempts, ssh_connection, delegate);
 918                }
 919            };
 920
 921            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 922            client.reconnect(incoming_rx, outgoing_tx, cx);
 923
 924            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 925                failed!(error, attempts, ssh_connection, delegate);
 926            };
 927
 928            State::Connected {
 929                ssh_connection,
 930                delegate,
 931                multiplex_task,
 932                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 933            }
 934        });
 935
 936        cx.spawn(async move |this, cx| {
 937            let new_state = reconnect_task.await;
 938            this.update(cx, |this, cx| {
 939                this.try_set_state(cx, |old_state| {
 940                    if old_state.is_reconnecting() {
 941                        match &new_state {
 942                            State::Connecting
 943                            | State::Reconnecting { .. }
 944                            | State::HeartbeatMissed { .. }
 945                            | State::ServerNotRunning => {}
 946                            State::Connected { .. } => {
 947                                log::info!("Successfully reconnected");
 948                            }
 949                            State::ReconnectFailed {
 950                                error, attempts, ..
 951                            } => {
 952                                log::error!(
 953                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 954                                    attempts,
 955                                    error
 956                                );
 957                            }
 958                            State::ReconnectExhausted => {
 959                                log::error!("Reconnect attempt failed and all attempts exhausted");
 960                            }
 961                        }
 962                        Some(new_state)
 963                    } else {
 964                        None
 965                    }
 966                });
 967
 968                if this.state_is(State::is_reconnect_failed) {
 969                    this.reconnect(cx)
 970                } else if this.state_is(State::is_reconnect_exhausted) {
 971                    Ok(())
 972                } else {
 973                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 974                    Ok(())
 975                }
 976            })
 977        })
 978        .detach_and_log_err(cx);
 979
 980        Ok(())
 981    }
 982
 983    fn heartbeat(
 984        this: WeakEntity<Self>,
 985        mut connection_activity_rx: mpsc::Receiver<()>,
 986        cx: &mut AsyncApp,
 987    ) -> Task<Result<()>> {
 988        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 989            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 990        };
 991
 992        cx.spawn(async move |cx| {
 993                let mut missed_heartbeats = 0;
 994
 995                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 996                futures::pin_mut!(keepalive_timer);
 997
 998                loop {
 999                    select_biased! {
1000                        result = connection_activity_rx.next().fuse() => {
1001                            if result.is_none() {
1002                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1003                                return Ok(());
1004                            }
1005
1006                            if missed_heartbeats != 0 {
1007                                missed_heartbeats = 0;
1008                                let _ =this.update(cx, |this, cx| {
1009                                    this.handle_heartbeat_result(missed_heartbeats, cx)
1010                                })?;
1011                            }
1012                        }
1013                        _ = keepalive_timer => {
1014                            log::debug!("Sending heartbeat to server...");
1015
1016                            let result = select_biased! {
1017                                _ = connection_activity_rx.next().fuse() => {
1018                                    Ok(())
1019                                }
1020                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1021                                    ping_result
1022                                }
1023                            };
1024
1025                            if result.is_err() {
1026                                missed_heartbeats += 1;
1027                                log::warn!(
1028                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1029                                    HEARTBEAT_TIMEOUT,
1030                                    missed_heartbeats,
1031                                    MAX_MISSED_HEARTBEATS
1032                                );
1033                            } else if missed_heartbeats != 0 {
1034                                missed_heartbeats = 0;
1035                            } else {
1036                                continue;
1037                            }
1038
1039                            let result = this.update(cx, |this, cx| {
1040                                this.handle_heartbeat_result(missed_heartbeats, cx)
1041                            })?;
1042                            if result.is_break() {
1043                                return Ok(());
1044                            }
1045                        }
1046                    }
1047
1048                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1049                }
1050
1051        })
1052    }
1053
1054    fn handle_heartbeat_result(
1055        &mut self,
1056        missed_heartbeats: usize,
1057        cx: &mut Context<Self>,
1058    ) -> ControlFlow<()> {
1059        let state = self.state.lock().take().unwrap();
1060        let next_state = if missed_heartbeats > 0 {
1061            state.heartbeat_missed()
1062        } else {
1063            state.heartbeat_recovered()
1064        };
1065
1066        self.set_state(next_state, cx);
1067
1068        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1069            log::error!(
1070                "Missed last {} heartbeats. Reconnecting...",
1071                missed_heartbeats
1072            );
1073
1074            self.reconnect(cx)
1075                .context("failed to start reconnect process after missing heartbeats")
1076                .log_err();
1077            ControlFlow::Break(())
1078        } else {
1079            ControlFlow::Continue(())
1080        }
1081    }
1082
1083    fn monitor(
1084        this: WeakEntity<Self>,
1085        io_task: Task<Result<i32>>,
1086        cx: &AsyncApp,
1087    ) -> Task<Result<()>> {
1088        cx.spawn(async move |cx| {
1089            let result = io_task.await;
1090
1091            match result {
1092                Ok(exit_code) => {
1093                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1094                        match error {
1095                            ProxyLaunchError::ServerNotRunning => {
1096                                log::error!("failed to reconnect because server is not running");
1097                                this.update(cx, |this, cx| {
1098                                    this.set_state(State::ServerNotRunning, cx);
1099                                })?;
1100                            }
1101                        }
1102                    } else if exit_code > 0 {
1103                        log::error!("proxy process terminated unexpectedly");
1104                        this.update(cx, |this, cx| {
1105                            this.reconnect(cx).ok();
1106                        })?;
1107                    }
1108                }
1109                Err(error) => {
1110                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1111                    this.update(cx, |this, cx| {
1112                        this.reconnect(cx).ok();
1113                    })?;
1114                }
1115            }
1116
1117            Ok(())
1118        })
1119    }
1120
1121    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1122        self.state.lock().as_ref().is_some_and(check)
1123    }
1124
1125    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1126        let mut lock = self.state.lock();
1127        let new_state = lock.as_ref().and_then(map);
1128
1129        if let Some(new_state) = new_state {
1130            lock.replace(new_state);
1131            cx.notify();
1132        }
1133    }
1134
1135    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1136        log::info!("setting state to '{}'", &state);
1137
1138        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1139        let is_server_not_running = state.is_server_not_running();
1140        self.state.lock().replace(state);
1141
1142        if is_reconnect_exhausted || is_server_not_running {
1143            cx.emit(SshRemoteEvent::Disconnected);
1144        }
1145        cx.notify();
1146    }
1147
1148    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
1149        self.client.subscribe_to_entity(remote_id, entity);
1150    }
1151
1152    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1153        self.state
1154            .lock()
1155            .as_ref()
1156            .and_then(|state| state.ssh_connection())
1157            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1158    }
1159
1160    pub fn upload_directory(
1161        &self,
1162        src_path: PathBuf,
1163        dest_path: RemotePathBuf,
1164        cx: &App,
1165    ) -> Task<Result<()>> {
1166        let state = self.state.lock();
1167        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1168            return Task::ready(Err(anyhow!("no ssh connection")));
1169        };
1170        connection.upload_directory(src_path, dest_path, cx)
1171    }
1172
1173    pub fn proto_client(&self) -> AnyProtoClient {
1174        self.client.clone().into()
1175    }
1176
1177    pub fn connection_string(&self) -> String {
1178        self.connection_options.connection_string()
1179    }
1180
1181    pub fn connection_options(&self) -> SshConnectionOptions {
1182        self.connection_options.clone()
1183    }
1184
1185    pub fn connection_state(&self) -> ConnectionState {
1186        self.state
1187            .lock()
1188            .as_ref()
1189            .map(ConnectionState::from)
1190            .unwrap_or(ConnectionState::Disconnected)
1191    }
1192
1193    pub fn is_disconnected(&self) -> bool {
1194        self.connection_state() == ConnectionState::Disconnected
1195    }
1196
1197    pub fn path_style(&self) -> PathStyle {
1198        self.path_style
1199    }
1200
1201    #[cfg(any(test, feature = "test-support"))]
1202    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1203        let opts = self.connection_options();
1204        client_cx.spawn(async move |cx| {
1205            let connection = cx
1206                .update_global(|c: &mut ConnectionPool, _| {
1207                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1208                        c.clone()
1209                    } else {
1210                        panic!("missing test connection")
1211                    }
1212                })
1213                .unwrap()
1214                .await
1215                .unwrap();
1216
1217            connection.simulate_disconnect(cx);
1218        })
1219    }
1220
1221    #[cfg(any(test, feature = "test-support"))]
1222    pub fn fake_server(
1223        client_cx: &mut gpui::TestAppContext,
1224        server_cx: &mut gpui::TestAppContext,
1225    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1226        let port = client_cx
1227            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1228        let opts = SshConnectionOptions {
1229            host: "<fake>".to_string(),
1230            port: Some(port),
1231            ..Default::default()
1232        };
1233        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1234        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1235        let server_client =
1236            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1237        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1238            connection_options: opts.clone(),
1239            server_cx: fake::SendableCx::new(server_cx),
1240            server_channel: server_client.clone(),
1241        });
1242
1243        client_cx.update(|cx| {
1244            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1245                c.connections.insert(
1246                    opts.clone(),
1247                    ConnectionPoolEntry::Connecting(
1248                        cx.background_spawn({
1249                            let connection = connection.clone();
1250                            async move { Ok(connection.clone()) }
1251                        })
1252                        .shared(),
1253                    ),
1254                );
1255            })
1256        });
1257
1258        (opts, server_client)
1259    }
1260
1261    #[cfg(any(test, feature = "test-support"))]
1262    pub async fn fake_client(
1263        opts: SshConnectionOptions,
1264        client_cx: &mut gpui::TestAppContext,
1265    ) -> Entity<Self> {
1266        let (_tx, rx) = oneshot::channel();
1267        client_cx
1268            .update(|cx| {
1269                Self::new(
1270                    ConnectionIdentifier::setup(),
1271                    opts,
1272                    rx,
1273                    Arc::new(fake::Delegate),
1274                    cx,
1275                )
1276            })
1277            .await
1278            .unwrap()
1279            .unwrap()
1280    }
1281}
1282
/// One slot of the [`ConnectionPool`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be cloned and awaited
    /// by every caller interested in the result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak handle is kept, so the pool does
    /// not keep the connection alive by itself.
    Connected(Weak<dyn RemoteConnection>),
}
1287
/// Registry of remote connections, keyed by the options used to open them.
#[derive(Default)]
struct ConnectionPool {
    /// Pending or established connection for each set of connection options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}
1292
// Marker impl: lets the pool be stored and accessed as a gpui global.
impl Global for ConnectionPool {}
1294
1295impl ConnectionPool {
1296    pub fn connect(
1297        &mut self,
1298        opts: SshConnectionOptions,
1299        delegate: &Arc<dyn SshClientDelegate>,
1300        cx: &mut App,
1301    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1302        let connection = self.connections.get(&opts);
1303        match connection {
1304            Some(ConnectionPoolEntry::Connecting(task)) => {
1305                let delegate = delegate.clone();
1306                cx.spawn(async move |cx| {
1307                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1308                })
1309                .detach();
1310                return task.clone();
1311            }
1312            Some(ConnectionPoolEntry::Connected(ssh)) => {
1313                if let Some(ssh) = ssh.upgrade()
1314                    && !ssh.has_been_killed()
1315                {
1316                    return Task::ready(Ok(ssh)).shared();
1317                }
1318                self.connections.remove(&opts);
1319            }
1320            None => {}
1321        }
1322
1323        let task = cx
1324            .spawn({
1325                let opts = opts.clone();
1326                let delegate = delegate.clone();
1327                async move |cx| {
1328                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1329                        .await
1330                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1331
1332                    cx.update_global(|pool: &mut Self, _| {
1333                        debug_assert!(matches!(
1334                            pool.connections.get(&opts),
1335                            Some(ConnectionPoolEntry::Connecting(_))
1336                        ));
1337                        match connection {
1338                            Ok(connection) => {
1339                                pool.connections.insert(
1340                                    opts.clone(),
1341                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1342                                );
1343                                Ok(connection)
1344                            }
1345                            Err(error) => {
1346                                pool.connections.remove(&opts);
1347                                Err(Arc::new(error))
1348                            }
1349                        }
1350                    })?
1351                }
1352            })
1353            .shared();
1354
1355        self.connections
1356            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1357        task
1358    }
1359}
1360
1361impl From<SshRemoteClient> for AnyProtoClient {
1362    fn from(client: SshRemoteClient) -> Self {
1363        AnyProtoClient::new(client.client.clone())
1364    }
1365}
1366
/// Abstraction over an established connection to a remote host.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy that carries protocol messages over this connection.
    ///
    /// `reconnect` tells the proxy whether this is a reconnect of an existing
    /// session. The returned task resolves to the proxy's exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Whether the connection has been killed; the connection pool does not hand
    /// out connections for which this returns `true`.
    fn has_been_killed(&self) -> bool;
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
    /// reuse the already-established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// The options this connection was opened with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// The path style of the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook; does nothing unless an implementation overrides it.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1396
/// A connection to a remote host backed by a long-lived `ssh` master process.
struct SshRemoteConnection {
    /// Handle used to run further ssh commands against the host.
    socket: SshSocket,
    /// The master ssh process; `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote; filled in at the end of `new`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform's OS.
    ssh_path_style: PathStyle,
    /// Temporary directory created for this session (on non-Windows it contains the
    /// control socket); kept here so it lives as long as the connection.
    _temp_dir: TempDir,
}
1405
1406#[async_trait(?Send)]
1407impl RemoteConnection for SshRemoteConnection {
1408    async fn kill(&self) -> Result<()> {
1409        let Some(mut process) = self.master_process.lock().take() else {
1410            return Ok(());
1411        };
1412        process.kill().ok();
1413        process.status().await?;
1414        Ok(())
1415    }
1416
1417    fn has_been_killed(&self) -> bool {
1418        self.master_process.lock().is_none()
1419    }
1420
1421    fn ssh_args(&self) -> SshArgs {
1422        self.socket.ssh_args()
1423    }
1424
1425    fn connection_options(&self) -> SshConnectionOptions {
1426        self.socket.connection_options.clone()
1427    }
1428
1429    fn upload_directory(
1430        &self,
1431        src_path: PathBuf,
1432        dest_path: RemotePathBuf,
1433        cx: &App,
1434    ) -> Task<Result<()>> {
1435        let mut command = util::command::new_smol_command("scp");
1436        let output = self
1437            .socket
1438            .ssh_options(&mut command)
1439            .args(
1440                self.socket
1441                    .connection_options
1442                    .port
1443                    .map(|port| vec!["-P".to_string(), port.to_string()])
1444                    .unwrap_or_default(),
1445            )
1446            .arg("-C")
1447            .arg("-r")
1448            .arg(&src_path)
1449            .arg(format!(
1450                "{}:{}",
1451                self.socket.connection_options.scp_url(),
1452                dest_path.to_string()
1453            ))
1454            .output();
1455
1456        cx.background_spawn(async move {
1457            let output = output.await?;
1458
1459            anyhow::ensure!(
1460                output.status.success(),
1461                "failed to upload directory {} -> {}: {}",
1462                src_path.display(),
1463                dest_path.to_string(),
1464                String::from_utf8_lossy(&output.stderr)
1465            );
1466
1467            Ok(())
1468        })
1469    }
1470
1471    fn start_proxy(
1472        &self,
1473        unique_identifier: String,
1474        reconnect: bool,
1475        incoming_tx: UnboundedSender<Envelope>,
1476        outgoing_rx: UnboundedReceiver<Envelope>,
1477        connection_activity_tx: Sender<()>,
1478        delegate: Arc<dyn SshClientDelegate>,
1479        cx: &mut AsyncApp,
1480    ) -> Task<Result<i32>> {
1481        delegate.set_status(Some("Starting proxy"), cx);
1482
1483        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1484            return Task::ready(Err(anyhow!("Remote binary path not set")));
1485        };
1486
1487        let mut start_proxy_command = shell_script!(
1488            "exec {binary_path} proxy --identifier {identifier}",
1489            binary_path = &remote_binary_path.to_string(),
1490            identifier = &unique_identifier,
1491        );
1492
1493        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1494            if let Some(value) = std::env::var(env_var).ok() {
1495                start_proxy_command = format!(
1496                    "{}={} {} ",
1497                    env_var,
1498                    shlex::try_quote(&value).unwrap(),
1499                    start_proxy_command,
1500                );
1501            }
1502        }
1503
1504        if reconnect {
1505            start_proxy_command.push_str(" --reconnect");
1506        }
1507
1508        let ssh_proxy_process = match self
1509            .socket
1510            .ssh_command("sh", &["-c", &start_proxy_command])
1511            // IMPORTANT: we kill this process when we drop the task that uses it.
1512            .kill_on_drop(true)
1513            .spawn()
1514        {
1515            Ok(process) => process,
1516            Err(error) => {
1517                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1518            }
1519        };
1520
1521        Self::multiplex(
1522            ssh_proxy_process,
1523            incoming_tx,
1524            outgoing_rx,
1525            connection_activity_tx,
1526            cx,
1527        )
1528    }
1529
1530    fn path_style(&self) -> PathStyle {
1531        self.ssh_path_style
1532    }
1533}
1534
1535impl SshRemoteConnection {
1536    async fn new(
1537        connection_options: SshConnectionOptions,
1538        delegate: Arc<dyn SshClientDelegate>,
1539        cx: &mut AsyncApp,
1540    ) -> Result<Self> {
1541        use askpass::AskPassResult;
1542
1543        delegate.set_status(Some("Connecting"), cx);
1544
1545        let url = connection_options.ssh_url();
1546
1547        let temp_dir = tempfile::Builder::new()
1548            .prefix("zed-ssh-session")
1549            .tempdir()?;
1550        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
1551            let delegate = delegate.clone();
1552            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
1553        });
1554
1555        let mut askpass =
1556            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
1557
1558        // Start the master SSH process, which does not do anything except for establish
1559        // the connection and keep it open, allowing other ssh commands to reuse it
1560        // via a control socket.
1561        #[cfg(not(target_os = "windows"))]
1562        let socket_path = temp_dir.path().join("ssh.sock");
1563
1564        let mut master_process = {
1565            #[cfg(not(target_os = "windows"))]
1566            let args = [
1567                "-N",
1568                "-o",
1569                "ControlPersist=no",
1570                "-o",
1571                "ControlMaster=yes",
1572                "-o",
1573            ];
1574            // On Windows, `ControlMaster` and `ControlPath` are not supported:
1575            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
1576            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
1577            #[cfg(target_os = "windows")]
1578            let args = ["-N"];
1579            let mut master_process = util::command::new_smol_command("ssh");
1580            master_process
1581                .kill_on_drop(true)
1582                .stdin(Stdio::null())
1583                .stdout(Stdio::piped())
1584                .stderr(Stdio::piped())
1585                .env("SSH_ASKPASS_REQUIRE", "force")
1586                .env("SSH_ASKPASS", askpass.script_path())
1587                .args(connection_options.additional_args())
1588                .args(args);
1589            #[cfg(not(target_os = "windows"))]
1590            master_process.arg(format!("ControlPath={}", socket_path.display()));
1591            master_process.arg(&url).spawn()?
1592        };
1593        // Wait for this ssh process to close its stdout, indicating that authentication
1594        // has completed.
1595        let mut stdout = master_process.stdout.take().unwrap();
1596        let mut output = Vec::new();
1597
1598        let result = select_biased! {
1599            result = askpass.run().fuse() => {
1600                match result {
1601                    AskPassResult::CancelledByUser => {
1602                        master_process.kill().ok();
1603                        anyhow::bail!("SSH connection canceled")
1604                    }
1605                    AskPassResult::Timedout => {
1606                        anyhow::bail!("connecting to host timed out")
1607                    }
1608                }
1609            }
1610            _ = stdout.read_to_end(&mut output).fuse() => {
1611                anyhow::Ok(())
1612            }
1613        };
1614
1615        if let Err(e) = result {
1616            return Err(e.context("Failed to connect to host"));
1617        }
1618
1619        if master_process.try_status()?.is_some() {
1620            output.clear();
1621            let mut stderr = master_process.stderr.take().unwrap();
1622            stderr.read_to_end(&mut output).await?;
1623
1624            let error_message = format!(
1625                "failed to connect: {}",
1626                String::from_utf8_lossy(&output).trim()
1627            );
1628            anyhow::bail!(error_message);
1629        }
1630
1631        #[cfg(not(target_os = "windows"))]
1632        let socket = SshSocket::new(connection_options, socket_path)?;
1633        #[cfg(target_os = "windows")]
1634        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
1635        drop(askpass);
1636
1637        let ssh_platform = socket.platform().await?;
1638        let ssh_path_style = match ssh_platform.os {
1639            "windows" => PathStyle::Windows,
1640            _ => PathStyle::Posix,
1641        };
1642
1643        let mut this = Self {
1644            socket,
1645            master_process: Mutex::new(Some(master_process)),
1646            _temp_dir: temp_dir,
1647            remote_binary_path: None,
1648            ssh_path_style,
1649            ssh_platform,
1650        };
1651
1652        let (release_channel, version, commit) = cx.update(|cx| {
1653            (
1654                ReleaseChannel::global(cx),
1655                AppVersion::global(cx),
1656                AppCommitSha::try_global(cx),
1657            )
1658        })?;
1659        this.remote_binary_path = Some(
1660            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
1661                .await?,
1662        );
1663
1664        Ok(this)
1665    }
1666
1667    fn multiplex(
1668        mut ssh_proxy_process: Child,
1669        incoming_tx: UnboundedSender<Envelope>,
1670        mut outgoing_rx: UnboundedReceiver<Envelope>,
1671        mut connection_activity_tx: Sender<()>,
1672        cx: &AsyncApp,
1673    ) -> Task<Result<i32>> {
1674        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1675        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1676        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1677
1678        let mut stdin_buffer = Vec::new();
1679        let mut stdout_buffer = Vec::new();
1680        let mut stderr_buffer = Vec::new();
1681        let mut stderr_offset = 0;
1682
1683        let stdin_task = cx.background_spawn(async move {
1684            while let Some(outgoing) = outgoing_rx.next().await {
1685                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1686            }
1687            anyhow::Ok(())
1688        });
1689
1690        let stdout_task = cx.background_spawn({
1691            let mut connection_activity_tx = connection_activity_tx.clone();
1692            async move {
1693                loop {
1694                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1695                    let len = child_stdout.read(&mut stdout_buffer).await?;
1696
1697                    if len == 0 {
1698                        return anyhow::Ok(());
1699                    }
1700
1701                    if len < MESSAGE_LEN_SIZE {
1702                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1703                    }
1704
1705                    let message_len = message_len_from_buffer(&stdout_buffer);
1706                    let envelope =
1707                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1708                            .await?;
1709                    connection_activity_tx.try_send(()).ok();
1710                    incoming_tx.unbounded_send(envelope).ok();
1711                }
1712            }
1713        });
1714
1715        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1716            loop {
1717                stderr_buffer.resize(stderr_offset + 1024, 0);
1718
1719                let len = child_stderr
1720                    .read(&mut stderr_buffer[stderr_offset..])
1721                    .await?;
1722                if len == 0 {
1723                    return anyhow::Ok(());
1724                }
1725
1726                stderr_offset += len;
1727                let mut start_ix = 0;
1728                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1729                    .iter()
1730                    .position(|b| b == &b'\n')
1731                {
1732                    let line_ix = start_ix + ix;
1733                    let content = &stderr_buffer[start_ix..line_ix];
1734                    start_ix = line_ix + 1;
1735                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1736                        record.log(log::logger())
1737                    } else {
1738                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1739                    }
1740                }
1741                stderr_buffer.drain(0..start_ix);
1742                stderr_offset -= start_ix;
1743
1744                connection_activity_tx.try_send(()).ok();
1745            }
1746        });
1747
1748        cx.background_spawn(async move {
1749            let result = futures::select! {
1750                result = stdin_task.fuse() => {
1751                    result.context("stdin")
1752                }
1753                result = stdout_task.fuse() => {
1754                    result.context("stdout")
1755                }
1756                result = stderr_task.fuse() => {
1757                    result.context("stderr")
1758                }
1759            };
1760
1761            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1762            match result {
1763                Ok(_) => Ok(status),
1764                Err(error) => Err(error),
1765            }
1766        })
1767    }
1768
1769    #[allow(unused)]
1770    async fn ensure_server_binary(
1771        &self,
1772        delegate: &Arc<dyn SshClientDelegate>,
1773        release_channel: ReleaseChannel,
1774        version: SemanticVersion,
1775        commit: Option<AppCommitSha>,
1776        cx: &mut AsyncApp,
1777    ) -> Result<RemotePathBuf> {
1778        let version_str = match release_channel {
1779            ReleaseChannel::Nightly => {
1780                let commit = commit.map(|s| s.full()).unwrap_or_default();
1781                format!("{}-{}", version, commit)
1782            }
1783            ReleaseChannel::Dev => "build".to_string(),
1784            _ => version.to_string(),
1785        };
1786        let binary_name = format!(
1787            "zed-remote-server-{}-{}",
1788            release_channel.dev_name(),
1789            version_str
1790        );
1791        let dst_path = RemotePathBuf::new(
1792            paths::remote_server_dir_relative().join(binary_name),
1793            self.ssh_path_style,
1794        );
1795
1796        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1797        #[cfg(debug_assertions)]
1798        if let Some(build_remote_server) = build_remote_server {
1799            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1800            let tmp_path = RemotePathBuf::new(
1801                paths::remote_server_dir_relative().join(format!(
1802                    "download-{}-{}",
1803                    std::process::id(),
1804                    src_path.file_name().unwrap().to_string_lossy()
1805                )),
1806                self.ssh_path_style,
1807            );
1808            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1809                .await?;
1810            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1811                .await?;
1812            return Ok(dst_path);
1813        }
1814
1815        if self
1816            .socket
1817            .run_command(&dst_path.to_string(), &["version"])
1818            .await
1819            .is_ok()
1820        {
1821            return Ok(dst_path);
1822        }
1823
1824        let wanted_version = cx.update(|cx| match release_channel {
1825            ReleaseChannel::Nightly => Ok(None),
1826            ReleaseChannel::Dev => {
1827                anyhow::bail!(
1828                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1829                    dst_path
1830                )
1831            }
1832            _ => Ok(Some(AppVersion::global(cx))),
1833        })??;
1834
1835        let tmp_path_gz = RemotePathBuf::new(
1836            PathBuf::from(format!(
1837                "{}-download-{}.gz",
1838                dst_path.to_string(),
1839                std::process::id()
1840            )),
1841            self.ssh_path_style,
1842        );
1843        if !self.socket.connection_options.upload_binary_over_ssh
1844            && let Some((url, body)) = delegate
1845                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1846                .await?
1847        {
1848            match self
1849                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1850                .await
1851            {
1852                Ok(_) => {
1853                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1854                        .await?;
1855                    return Ok(dst_path);
1856                }
1857                Err(e) => {
1858                    log::error!(
1859                        "Failed to download binary on server, attempting to upload server: {}",
1860                        e
1861                    )
1862                }
1863            }
1864        }
1865
1866        let src_path = delegate
1867            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1868            .await?;
1869        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1870            .await?;
1871        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1872            .await?;
1873        Ok(dst_path)
1874    }
1875
1876    async fn download_binary_on_server(
1877        &self,
1878        url: &str,
1879        body: &str,
1880        tmp_path_gz: &RemotePathBuf,
1881        delegate: &Arc<dyn SshClientDelegate>,
1882        cx: &mut AsyncApp,
1883    ) -> Result<()> {
1884        if let Some(parent) = tmp_path_gz.parent() {
1885            self.socket
1886                .run_command(
1887                    "sh",
1888                    &[
1889                        "-c",
1890                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1891                    ],
1892                )
1893                .await?;
1894        }
1895
1896        delegate.set_status(Some("Downloading remote development server on host"), cx);
1897
1898        match self
1899            .socket
1900            .run_command(
1901                "curl",
1902                &[
1903                    "-f",
1904                    "-L",
1905                    "-X",
1906                    "GET",
1907                    "-H",
1908                    "Content-Type: application/json",
1909                    "-d",
1910                    body,
1911                    url,
1912                    "-o",
1913                    &tmp_path_gz.to_string(),
1914                ],
1915            )
1916            .await
1917        {
1918            Ok(_) => {}
1919            Err(e) => {
1920                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1921                    return Err(e);
1922                }
1923
1924                match self
1925                    .socket
1926                    .run_command(
1927                        "wget",
1928                        &[
1929                            "--method=GET",
1930                            "--header=Content-Type: application/json",
1931                            "--body-data",
1932                            body,
1933                            url,
1934                            "-O",
1935                            &tmp_path_gz.to_string(),
1936                        ],
1937                    )
1938                    .await
1939                {
1940                    Ok(_) => {}
1941                    Err(e) => {
1942                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1943                            return Err(e);
1944                        } else {
1945                            anyhow::bail!("Neither curl nor wget is available");
1946                        }
1947                    }
1948                }
1949            }
1950        }
1951
1952        Ok(())
1953    }
1954
1955    async fn upload_local_server_binary(
1956        &self,
1957        src_path: &Path,
1958        tmp_path_gz: &RemotePathBuf,
1959        delegate: &Arc<dyn SshClientDelegate>,
1960        cx: &mut AsyncApp,
1961    ) -> Result<()> {
1962        if let Some(parent) = tmp_path_gz.parent() {
1963            self.socket
1964                .run_command(
1965                    "sh",
1966                    &[
1967                        "-c",
1968                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1969                    ],
1970                )
1971                .await?;
1972        }
1973
1974        let src_stat = fs::metadata(&src_path).await?;
1975        let size = src_stat.len();
1976
1977        let t0 = Instant::now();
1978        delegate.set_status(Some("Uploading remote development server"), cx);
1979        log::info!(
1980            "uploading remote development server to {:?} ({}kb)",
1981            tmp_path_gz,
1982            size / 1024
1983        );
1984        self.upload_file(src_path, tmp_path_gz)
1985            .await
1986            .context("failed to upload server binary")?;
1987        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1988        Ok(())
1989    }
1990
1991    async fn extract_server_binary(
1992        &self,
1993        dst_path: &RemotePathBuf,
1994        tmp_path: &RemotePathBuf,
1995        delegate: &Arc<dyn SshClientDelegate>,
1996        cx: &mut AsyncApp,
1997    ) -> Result<()> {
1998        delegate.set_status(Some("Extracting remote development server"), cx);
1999        let server_mode = 0o755;
2000
2001        let orig_tmp_path = tmp_path.to_string();
2002        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2003            shell_script!(
2004                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2005                server_mode = &format!("{:o}", server_mode),
2006                dst_path = &dst_path.to_string(),
2007            )
2008        } else {
2009            shell_script!(
2010                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2011                server_mode = &format!("{:o}", server_mode),
2012                dst_path = &dst_path.to_string()
2013            )
2014        };
2015        self.socket.run_command("sh", &["-c", &script]).await?;
2016        Ok(())
2017    }
2018
2019    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2020        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2021        let mut command = util::command::new_smol_command("scp");
2022        let output = self
2023            .socket
2024            .ssh_options(&mut command)
2025            .args(
2026                self.socket
2027                    .connection_options
2028                    .port
2029                    .map(|port| vec!["-P".to_string(), port.to_string()])
2030                    .unwrap_or_default(),
2031            )
2032            .arg(src_path)
2033            .arg(format!(
2034                "{}:{}",
2035                self.socket.connection_options.scp_url(),
2036                dest_path.to_string()
2037            ))
2038            .output()
2039            .await?;
2040
2041        anyhow::ensure!(
2042            output.status.success(),
2043            "failed to upload file {} -> {}: {}",
2044            src_path.display(),
2045            dest_path.to_string(),
2046            String::from_utf8_lossy(&output.stderr)
2047        );
2048        Ok(())
2049    }
2050
2051    #[cfg(debug_assertions)]
2052    async fn build_local(
2053        &self,
2054        build_remote_server: String,
2055        delegate: &Arc<dyn SshClientDelegate>,
2056        cx: &mut AsyncApp,
2057    ) -> Result<PathBuf> {
2058        use smol::process::{Command, Stdio};
2059        use std::env::VarError;
2060
2061        async fn run_cmd(command: &mut Command) -> Result<()> {
2062            let output = command
2063                .kill_on_drop(true)
2064                .stderr(Stdio::inherit())
2065                .output()
2066                .await?;
2067            anyhow::ensure!(
2068                output.status.success(),
2069                "Failed to run command: {command:?}"
2070            );
2071            Ok(())
2072        }
2073
2074        let use_musl = !build_remote_server.contains("nomusl");
2075        let triple = format!(
2076            "{}-{}",
2077            self.ssh_platform.arch,
2078            match self.ssh_platform.os {
2079                "linux" =>
2080                    if use_musl {
2081                        "unknown-linux-musl"
2082                    } else {
2083                        "unknown-linux-gnu"
2084                    },
2085                "macos" => "apple-darwin",
2086                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
2087            }
2088        );
2089        let mut rust_flags = match std::env::var("RUSTFLAGS") {
2090            Ok(val) => val,
2091            Err(VarError::NotPresent) => String::new(),
2092            Err(e) => {
2093                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
2094                String::new()
2095            }
2096        };
2097        if self.ssh_platform.os == "linux" && use_musl {
2098            rust_flags.push_str(" -C target-feature=+crt-static");
2099        }
2100        if build_remote_server.contains("mold") {
2101            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
2102        }
2103
2104        if self.ssh_platform.arch == std::env::consts::ARCH
2105            && self.ssh_platform.os == std::env::consts::OS
2106        {
2107            delegate.set_status(Some("Building remote server binary from source"), cx);
2108            log::info!("building remote server binary from source");
2109            run_cmd(
2110                Command::new("cargo")
2111                    .args([
2112                        "build",
2113                        "--package",
2114                        "remote_server",
2115                        "--features",
2116                        "debug-embed",
2117                        "--target-dir",
2118                        "target/remote_server",
2119                        "--target",
2120                        &triple,
2121                    ])
2122                    .env("RUSTFLAGS", &rust_flags),
2123            )
2124            .await?;
2125        } else {
2126            if build_remote_server.contains("cross") {
2127                #[cfg(target_os = "windows")]
2128                use util::paths::SanitizedPath;
2129
2130                delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
2131                log::info!("installing cross");
2132                run_cmd(Command::new("cargo").args([
2133                    "install",
2134                    "cross",
2135                    "--git",
2136                    "https://github.com/cross-rs/cross",
2137                ]))
2138                .await?;
2139
2140                delegate.set_status(
2141                    Some(&format!(
2142                        "Building remote server binary from source for {} with Docker",
2143                        &triple
2144                    )),
2145                    cx,
2146                );
2147                log::info!("building remote server binary from source for {}", &triple);
2148
2149                // On Windows, the binding needs to be set to the canonical path
2150                #[cfg(target_os = "windows")]
2151                let src =
2152                    SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
2153                #[cfg(not(target_os = "windows"))]
2154                let src = "./target";
2155                run_cmd(
2156                    Command::new("cross")
2157                        .args([
2158                            "build",
2159                            "--package",
2160                            "remote_server",
2161                            "--features",
2162                            "debug-embed",
2163                            "--target-dir",
2164                            "target/remote_server",
2165                            "--target",
2166                            &triple,
2167                        ])
2168                        .env(
2169                            "CROSS_CONTAINER_OPTS",
2170                            format!("--mount type=bind,src={src},dst=/app/target"),
2171                        )
2172                        .env("RUSTFLAGS", &rust_flags),
2173                )
2174                .await?;
2175            } else {
2176                let which = cx
2177                    .background_spawn(async move { which::which("zig") })
2178                    .await;
2179
2180                if which.is_err() {
2181                    #[cfg(not(target_os = "windows"))]
2182                    {
2183                        anyhow::bail!(
2184                            "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2185                        )
2186                    }
2187                    #[cfg(target_os = "windows")]
2188                    {
2189                        anyhow::bail!(
2190                            "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2191                        )
2192                    }
2193                }
2194
2195                delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
2196                log::info!("adding rustup target");
2197                run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;
2198
2199                delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
2200                log::info!("installing cargo-zigbuild");
2201                run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
2202                    .await?;
2203
2204                delegate.set_status(
2205                    Some(&format!(
2206                        "Building remote binary from source for {triple} with Zig"
2207                    )),
2208                    cx,
2209                );
2210                log::info!("building remote binary from source for {triple} with Zig");
2211                run_cmd(
2212                    Command::new("cargo")
2213                        .args([
2214                            "zigbuild",
2215                            "--package",
2216                            "remote_server",
2217                            "--features",
2218                            "debug-embed",
2219                            "--target-dir",
2220                            "target/remote_server",
2221                            "--target",
2222                            &triple,
2223                        ])
2224                        .env("RUSTFLAGS", &rust_flags),
2225                )
2226                .await?;
2227            }
2228        };
2229        let bin_path = Path::new("target")
2230            .join("remote_server")
2231            .join(&triple)
2232            .join("debug")
2233            .join("remote_server");
2234
2235        let path = if !build_remote_server.contains("nocompress") {
2236            delegate.set_status(Some("Compressing binary"), cx);
2237
2238            #[cfg(not(target_os = "windows"))]
2239            {
2240                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
2241            }
2242            #[cfg(target_os = "windows")]
2243            {
2244                // On Windows, we use 7z to compress the binary
2245                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
2246                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
2247                if smol::fs::metadata(&gz_path).await.is_ok() {
2248                    smol::fs::remove_file(&gz_path).await?;
2249                }
2250                run_cmd(Command::new(seven_zip).args([
2251                    "a",
2252                    "-tgzip",
2253                    &gz_path,
2254                    &bin_path.to_string_lossy(),
2255                ]))
2256                .await?;
2257            }
2258
2259            let mut archive_path = bin_path;
2260            archive_path.set_extension("gz");
2261            std::env::current_dir()?.join(archive_path)
2262        } else {
2263            bin_path
2264        };
2265
2266        Ok(path)
2267    }
2268}
2269
/// In-flight requests, keyed by the id of the request envelope.
///
/// Each entry holds the sender on which the response is delivered. A
/// delivery carries the response `Envelope` together with a second oneshot
/// sender; the message loop waits on the matching receiver before it handles
/// the next incoming message.
2270type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2271
/// RPC client that exchanges `Envelope`s over a pair of unbounded channels.
///
/// Incoming envelopes are processed by a background task (`task`); the
/// channels and that task can be swapped out with `reconnect`.
2272pub struct ChannelClient {
    /// Counter from which ids for outgoing envelopes are allocated.
2273    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced on `reconnect`.
2274    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes kept for re-sending. Entries are dropped once an incoming
    /// envelope acknowledges their id, and re-sent when a
    /// `FlushBufferedMessages` message arrives or after `resync`.
2275    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests that are still waiting for their response.
2276    response_channels: ResponseChannels,
    /// Handlers and entity subscriptions used to dispatch incoming messages
    /// that are not responses.
2277    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope (flush messages excluded).
2278    max_received: AtomicU32,
    /// Label prefixed to this client's log messages.
2279    name: &'static str,
    /// Task running the incoming-message loop; replaced on `reconnect`.
2280    task: Mutex<Task<Result<()>>>,
2281}
2282
2283impl ChannelClient {
2284    pub fn new(
2285        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2286        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2287        cx: &App,
2288        name: &'static str,
2289    ) -> Arc<Self> {
2290        Arc::new_cyclic(|this| Self {
2291            outgoing_tx: Mutex::new(outgoing_tx),
2292            next_message_id: AtomicU32::new(0),
2293            max_received: AtomicU32::new(0),
2294            response_channels: ResponseChannels::default(),
2295            message_handlers: Default::default(),
2296            buffer: Mutex::new(VecDeque::new()),
2297            name,
2298            task: Mutex::new(Self::start_handling_messages(
2299                this.clone(),
2300                incoming_rx,
2301                &cx.to_async(),
2302            )),
2303        })
2304    }
2305
2306    fn start_handling_messages(
2307        this: Weak<Self>,
2308        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2309        cx: &AsyncApp,
2310    ) -> Task<Result<()>> {
2311        cx.spawn(async move |cx| {
2312            let peer_id = PeerId { owner_id: 0, id: 0 };
2313            while let Some(incoming) = incoming_rx.next().await {
2314                let Some(this) = this.upgrade() else {
2315                    return anyhow::Ok(());
2316                };
2317                if let Some(ack_id) = incoming.ack_id {
2318                    let mut buffer = this.buffer.lock();
2319                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2320                        buffer.pop_front();
2321                    }
2322                }
2323                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2324                {
2325                    log::debug!(
2326                        "{}:ssh message received. name:FlushBufferedMessages",
2327                        this.name
2328                    );
2329                    {
2330                        let buffer = this.buffer.lock();
2331                        for envelope in buffer.iter() {
2332                            this.outgoing_tx
2333                                .lock()
2334                                .unbounded_send(envelope.clone())
2335                                .ok();
2336                        }
2337                    }
2338                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2339                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2340                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2341                    continue;
2342                }
2343
2344                this.max_received.store(incoming.id, SeqCst);
2345
2346                if let Some(request_id) = incoming.responding_to {
2347                    let request_id = MessageId(request_id);
2348                    let sender = this.response_channels.lock().remove(&request_id);
2349                    if let Some(sender) = sender {
2350                        let (tx, rx) = oneshot::channel();
2351                        if incoming.payload.is_some() {
2352                            sender.send((incoming, tx)).ok();
2353                        }
2354                        rx.await.ok();
2355                    }
2356                } else if let Some(envelope) =
2357                    build_typed_envelope(peer_id, Instant::now(), incoming)
2358                {
2359                    let type_name = envelope.payload_type_name();
2360                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2361                        &this.message_handlers,
2362                        envelope,
2363                        this.clone().into(),
2364                        cx.clone(),
2365                    ) {
2366                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2367                        cx.foreground_executor()
2368                            .spawn(async move {
2369                                match future.await {
2370                                    Ok(_) => {
2371                                        log::debug!(
2372                                            "{}:ssh message handled. name:{type_name}",
2373                                            this.name
2374                                        );
2375                                    }
2376                                    Err(error) => {
2377                                        log::error!(
2378                                            "{}:error handling message. type:{}, error:{}",
2379                                            this.name,
2380                                            type_name,
2381                                            format!("{error:#}").lines().fold(
2382                                                String::new(),
2383                                                |mut message, line| {
2384                                                    if !message.is_empty() {
2385                                                        message.push(' ');
2386                                                    }
2387                                                    message.push_str(line);
2388                                                    message
2389                                                }
2390                                            )
2391                                        );
2392                                    }
2393                                }
2394                            })
2395                            .detach()
2396                    } else {
2397                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2398                    }
2399                }
2400            }
2401            anyhow::Ok(())
2402        })
2403    }
2404
2405    pub fn reconnect(
2406        self: &Arc<Self>,
2407        incoming_rx: UnboundedReceiver<Envelope>,
2408        outgoing_tx: UnboundedSender<Envelope>,
2409        cx: &AsyncApp,
2410    ) {
2411        *self.outgoing_tx.lock() = outgoing_tx;
2412        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2413    }
2414
2415    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2416        let id = (TypeId::of::<E>(), remote_id);
2417
2418        let mut message_handlers = self.message_handlers.lock();
2419        if message_handlers
2420            .entities_by_type_and_remote_id
2421            .contains_key(&id)
2422        {
2423            panic!("already subscribed to entity");
2424        }
2425
2426        message_handlers.entities_by_type_and_remote_id.insert(
2427            id,
2428            EntityMessageSubscriber::Entity {
2429                handle: entity.downgrade().into(),
2430            },
2431        );
2432    }
2433
2434    pub fn request<T: RequestMessage>(
2435        &self,
2436        payload: T,
2437    ) -> impl 'static + Future<Output = Result<T::Response>> {
2438        self.request_internal(payload, true)
2439    }
2440
2441    fn request_internal<T: RequestMessage>(
2442        &self,
2443        payload: T,
2444        use_buffer: bool,
2445    ) -> impl 'static + Future<Output = Result<T::Response>> {
2446        log::debug!("ssh request start. name:{}", T::NAME);
2447        let response =
2448            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2449        async move {
2450            let response = response.await?;
2451            log::debug!("ssh request finish. name:{}", T::NAME);
2452            T::Response::from_envelope(response).context("received a response of the wrong type")
2453        }
2454    }
2455
2456    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2457        smol::future::or(
2458            async {
2459                self.request_internal(proto::FlushBufferedMessages {}, false)
2460                    .await?;
2461
2462                for envelope in self.buffer.lock().iter() {
2463                    self.outgoing_tx
2464                        .lock()
2465                        .unbounded_send(envelope.clone())
2466                        .ok();
2467                }
2468                Ok(())
2469            },
2470            async {
2471                smol::Timer::after(timeout).await;
2472                anyhow::bail!("Timed out resyncing remote client")
2473            },
2474        )
2475        .await
2476    }
2477
2478    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2479        smol::future::or(
2480            async {
2481                self.request(proto::Ping {}).await?;
2482                Ok(())
2483            },
2484            async {
2485                smol::Timer::after(timeout).await;
2486                anyhow::bail!("Timed out pinging remote client")
2487            },
2488        )
2489        .await
2490    }
2491
2492    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2493        log::debug!("ssh send name:{}", T::NAME);
2494        self.send_dynamic(payload.into_envelope(0, None, None))
2495    }
2496
2497    fn request_dynamic(
2498        &self,
2499        mut envelope: proto::Envelope,
2500        type_name: &'static str,
2501        use_buffer: bool,
2502    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2503        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2504        let (tx, rx) = oneshot::channel();
2505        let mut response_channels_lock = self.response_channels.lock();
2506        response_channels_lock.insert(MessageId(envelope.id), tx);
2507        drop(response_channels_lock);
2508
2509        let result = if use_buffer {
2510            self.send_buffered(envelope)
2511        } else {
2512            self.send_unbuffered(envelope)
2513        };
2514        async move {
2515            if let Err(error) = &result {
2516                log::error!("failed to send message: {error}");
2517                anyhow::bail!("failed to send message: {error}");
2518            }
2519
2520            let response = rx.await.context("connection lost")?.0;
2521            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2522                return Err(RpcError::from_proto(error, type_name));
2523            }
2524            Ok(response)
2525        }
2526    }
2527
2528    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2529        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2530        self.send_buffered(envelope)
2531    }
2532
2533    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2534        envelope.ack_id = Some(self.max_received.load(SeqCst));
2535        self.buffer.lock().push_back(envelope.clone());
2536        // ignore errors on send (happen while we're reconnecting)
2537        // assume that the global "disconnected" overlay is sufficient.
2538        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2539        Ok(())
2540    }
2541
2542    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2543        envelope.ack_id = Some(self.max_received.load(SeqCst));
2544        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2545        Ok(())
2546    }
2547}
2548
2549impl ProtoClient for ChannelClient {
2550    fn request(
2551        &self,
2552        envelope: proto::Envelope,
2553        request_type: &'static str,
2554    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2555        self.request_dynamic(envelope, request_type, true).boxed()
2556    }
2557
2558    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2559        self.send_dynamic(envelope)
2560    }
2561
2562    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2563        self.send_dynamic(envelope)
2564    }
2565
2566    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2567        &self.message_handlers
2568    }
2569
2570    fn is_via_collab(&self) -> bool {
2571        false
2572    }
2573}
2574
/// Test-only fakes: a `RemoteConnection` that forwards envelopes to a
/// server-side `ChannelClient`, plus a `SshClientDelegate` stub.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A `RemoteConnection` whose "remote" end is the `ChannelClient` held in
    /// `server_channel`.
    pub(super) struct FakeRemoteConnection {
        /// Options returned verbatim from `connection_options()`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side channel that gets re-wired on proxy start and on
        /// simulated disconnect.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Async context handed to `server_channel.reconnect`.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an `AsyncApp` be stored in a `Send + Sync` type.
    pub(super) struct SendableCx(AsyncApp);

    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            let Self(inner) = self;
            inner.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// No-op: there is nothing to kill in the fake.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake never reports itself as killed.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Returns empty arguments and no environment.
        fn ssh_args(&self) -> SshArgs {
            let arguments = Vec::new();
            SshArgs {
                arguments,
                envs: None,
            }
        }

        /// Not supported by the fake; calling it panics via `unreachable!()`.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        /// Returns a clone of the stored connection options.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Re-wires the server channel onto fresh channels whose opposite
        /// ends are dropped immediately.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            // `_` drops the unused half right away, so both channels are
            // already closed when they reach `reconnect`.
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &server_cx);
        }

        /// Connects the server channel to new channels and spawns a
        /// background task that forwards envelopes in both directions.
        ///
        /// The task resolves to `Ok(1)` as soon as either source stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);

            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        from_server = server_outgoing_rx.next().fuse() => {
                            match from_server {
                                Some(envelope) => {
                                    // Signal activity, then forward to the client.
                                    connection_activity_tx.try_send(()).ok();
                                    client_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                        from_client = client_outgoing_rx.next().fuse() => {
                            match from_client {
                                Some(envelope) => {
                                    server_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                    }
                }
            })
        }

        /// Uses the path style returned by `PathStyle::current()`.
        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// `SshClientDelegate` stub: every method except `set_status` panics via
    /// `unreachable!()`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}