ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  32};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use smol::{
  36    fs,
  37    process::{self, Child, Stdio},
  38};
  39use std::{
  40    any::TypeId,
  41    collections::VecDeque,
  42    fmt, iter,
  43    ops::ControlFlow,
  44    path::{Path, PathBuf},
  45    sync::{
  46        Arc, Weak,
  47        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  48    },
  49    time::{Duration, Instant},
  50};
  51use tempfile::TempDir;
  52use util::{
  53    ResultExt,
  54    paths::{PathStyle, RemotePathBuf},
  55};
  56
  57#[derive(
  58    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  59)]
  60pub struct SshProjectId(pub u64);
  61
  62#[derive(Clone)]
  63pub struct SshSocket {
  64    connection_options: SshConnectionOptions,
  65    #[cfg(not(target_os = "windows"))]
  66    socket_path: PathBuf,
  67    #[cfg(target_os = "windows")]
  68    envs: HashMap<String, String>,
  69}
  70
  71#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
  72pub struct SshPortForwardOption {
  73    #[serde(skip_serializing_if = "Option::is_none")]
  74    pub local_host: Option<String>,
  75    pub local_port: u16,
  76    #[serde(skip_serializing_if = "Option::is_none")]
  77    pub remote_host: Option<String>,
  78    pub remote_port: u16,
  79}
  80
  81#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  82pub struct SshConnectionOptions {
  83    pub host: String,
  84    pub username: Option<String>,
  85    pub port: Option<u16>,
  86    pub password: Option<String>,
  87    pub args: Option<Vec<String>>,
  88    pub port_forwards: Option<Vec<SshPortForwardOption>>,
  89
  90    pub nickname: Option<String>,
  91    pub upload_binary_over_ssh: bool,
  92}
  93
  94pub struct SshArgs {
  95    pub arguments: Vec<String>,
  96    pub envs: Option<HashMap<String, String>>,
  97}
  98
  99#[macro_export]
 100macro_rules! shell_script {
 101    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
 102        format!(
 103            $fmt,
 104            $(
 105                $name = shlex::try_quote($arg).unwrap()
 106            ),+
 107        )
 108    }};
 109}
 110
 111fn parse_port_number(port_str: &str) -> Result<u16> {
 112    port_str
 113        .parse()
 114        .with_context(|| format!("parsing port number: {port_str}"))
 115}
 116
 117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 118    let parts: Vec<&str> = spec.split(':').collect();
 119
 120    match parts.len() {
 121        4 => {
 122            let local_port = parse_port_number(parts[1])?;
 123            let remote_port = parse_port_number(parts[3])?;
 124
 125            Ok(SshPortForwardOption {
 126                local_host: Some(parts[0].to_string()),
 127                local_port,
 128                remote_host: Some(parts[2].to_string()),
 129                remote_port,
 130            })
 131        }
 132        3 => {
 133            let local_port = parse_port_number(parts[0])?;
 134            let remote_port = parse_port_number(parts[2])?;
 135
 136            Ok(SshPortForwardOption {
 137                local_host: None,
 138                local_port,
 139                remote_host: Some(parts[1].to_string()),
 140                remote_port,
 141            })
 142        }
 143        _ => anyhow::bail!("Invalid port forward format"),
 144    }
 145}
 146
 147impl SshConnectionOptions {
 148    pub fn parse_command_line(input: &str) -> Result<Self> {
 149        let input = input.trim_start_matches("ssh ");
 150        let mut hostname: Option<String> = None;
 151        let mut username: Option<String> = None;
 152        let mut port: Option<u16> = None;
 153        let mut args = Vec::new();
 154        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 155
 156        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 157        const ALLOWED_OPTS: &[&str] = &[
 158            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 159        ];
 160        const ALLOWED_ARGS: &[&str] = &[
 161            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 162            "-w",
 163        ];
 164
 165        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 166
 167        'outer: while let Some(arg) = tokens.next() {
 168            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 169                args.push(arg.to_string());
 170                continue;
 171            }
 172            if arg == "-p" {
 173                port = tokens.next().and_then(|arg| arg.parse().ok());
 174                continue;
 175            } else if let Some(p) = arg.strip_prefix("-p") {
 176                port = p.parse().ok();
 177                continue;
 178            }
 179            if arg == "-l" {
 180                username = tokens.next();
 181                continue;
 182            } else if let Some(l) = arg.strip_prefix("-l") {
 183                username = Some(l.to_string());
 184                continue;
 185            }
 186            if arg == "-L" || arg.starts_with("-L") {
 187                let forward_spec = if arg == "-L" {
 188                    tokens.next()
 189                } else {
 190                    Some(arg.strip_prefix("-L").unwrap().to_string())
 191                };
 192
 193                if let Some(spec) = forward_spec {
 194                    port_forwards.push(parse_port_forward_spec(&spec)?);
 195                } else {
 196                    anyhow::bail!("Missing port forward format");
 197                }
 198            }
 199
 200            for a in ALLOWED_ARGS {
 201                if arg == *a {
 202                    args.push(arg);
 203                    if let Some(next) = tokens.next() {
 204                        args.push(next);
 205                    }
 206                    continue 'outer;
 207                } else if arg.starts_with(a) {
 208                    args.push(arg);
 209                    continue 'outer;
 210                }
 211            }
 212            if arg.starts_with("-") || hostname.is_some() {
 213                anyhow::bail!("unsupported argument: {:?}", arg);
 214            }
 215            let mut input = &arg as &str;
 216            // Destination might be: username1@username2@ip2@ip1
 217            if let Some((u, rest)) = input.rsplit_once('@') {
 218                input = rest;
 219                username = Some(u.to_string());
 220            }
 221            if let Some((rest, p)) = input.split_once(':') {
 222                input = rest;
 223                port = p.parse().ok()
 224            }
 225            hostname = Some(input.to_string())
 226        }
 227
 228        let Some(hostname) = hostname else {
 229            anyhow::bail!("missing hostname");
 230        };
 231
 232        let port_forwards = match port_forwards.len() {
 233            0 => None,
 234            _ => Some(port_forwards),
 235        };
 236
 237        Ok(Self {
 238            host: hostname.to_string(),
 239            username: username.clone(),
 240            port,
 241            port_forwards,
 242            args: Some(args),
 243            password: None,
 244            nickname: None,
 245            upload_binary_over_ssh: false,
 246        })
 247    }
 248
 249    pub fn ssh_url(&self) -> String {
 250        let mut result = String::from("ssh://");
 251        if let Some(username) = &self.username {
 252            // Username might be: username1@username2@ip2
 253            let username = urlencoding::encode(username);
 254            result.push_str(&username);
 255            result.push('@');
 256        }
 257        result.push_str(&self.host);
 258        if let Some(port) = self.port {
 259            result.push(':');
 260            result.push_str(&port.to_string());
 261        }
 262        result
 263    }
 264
 265    pub fn additional_args(&self) -> Vec<String> {
 266        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 267
 268        if let Some(forwards) = &self.port_forwards {
 269            args.extend(forwards.iter().map(|pf| {
 270                let local_host = match &pf.local_host {
 271                    Some(host) => host,
 272                    None => "localhost",
 273                };
 274                let remote_host = match &pf.remote_host {
 275                    Some(host) => host,
 276                    None => "localhost",
 277                };
 278
 279                format!(
 280                    "-L{}:{}:{}:{}",
 281                    local_host, pf.local_port, remote_host, pf.remote_port
 282                )
 283            }));
 284        }
 285
 286        args
 287    }
 288
 289    fn scp_url(&self) -> String {
 290        if let Some(username) = &self.username {
 291            format!("{}@{}", username, self.host)
 292        } else {
 293            self.host.clone()
 294        }
 295    }
 296
 297    pub fn connection_string(&self) -> String {
 298        let host = if let Some(username) = &self.username {
 299            format!("{}@{}", username, self.host)
 300        } else {
 301            self.host.clone()
 302        };
 303        if let Some(port) = &self.port {
 304            format!("{}:{}", host, port)
 305        } else {
 306            host
 307        }
 308    }
 309}
 310
 311#[derive(Copy, Clone, Debug)]
 312pub struct SshPlatform {
 313    pub os: &'static str,
 314    pub arch: &'static str,
 315}
 316
 317impl SshPlatform {
 318    pub fn triple(&self) -> Option<String> {
 319        Some(format!(
 320            "{}-{}",
 321            self.arch,
 322            match self.os {
 323                "linux" => "unknown-linux-gnu",
 324                "macos" => "apple-darwin",
 325                _ => return None,
 326            }
 327        ))
 328    }
 329}
 330
 331pub trait SshClientDelegate: Send + Sync {
 332    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
 333    fn get_download_params(
 334        &self,
 335        platform: SshPlatform,
 336        release_channel: ReleaseChannel,
 337        version: Option<SemanticVersion>,
 338        cx: &mut AsyncApp,
 339    ) -> Task<Result<Option<(String, String)>>>;
 340
 341    fn download_server_binary_locally(
 342        &self,
 343        platform: SshPlatform,
 344        release_channel: ReleaseChannel,
 345        version: Option<SemanticVersion>,
 346        cx: &mut AsyncApp,
 347    ) -> Task<Result<PathBuf>>;
 348    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
 349}
 350
 351impl SshSocket {
 352    #[cfg(not(target_os = "windows"))]
 353    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 354        Ok(Self {
 355            connection_options: options,
 356            socket_path,
 357        })
 358    }
 359
 360    #[cfg(target_os = "windows")]
 361    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 362        let askpass_script = temp_dir.path().join("askpass.bat");
 363        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 364        let mut envs = HashMap::default();
 365        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 366        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 367        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 368        Ok(Self {
 369            connection_options: options,
 370            envs,
 371        })
 372    }
 373
 374    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 375    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 376    // and passes -l as an argument to sh, not to ls.
 377    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 378    // into a machine. You must use `cd` to get back to $HOME.
 379    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 380    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 381        let mut command = util::command::new_smol_command("ssh");
 382        let to_run = iter::once(&program)
 383            .chain(args.iter())
 384            .map(|token| {
 385                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 386                debug_assert!(
 387                    !token.contains('\n'),
 388                    "multiline arguments do not work in all shells"
 389                );
 390                shlex::try_quote(token).unwrap()
 391            })
 392            .join(" ");
 393        let to_run = format!("cd; {to_run}");
 394        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 395        self.ssh_options(&mut command)
 396            .arg(self.connection_options.ssh_url())
 397            .arg(to_run);
 398        command
 399    }
 400
 401    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 402        let output = self.ssh_command(program, args).output().await?;
 403        anyhow::ensure!(
 404            output.status.success(),
 405            "failed to run command: {}",
 406            String::from_utf8_lossy(&output.stderr)
 407        );
 408        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 409    }
 410
 411    #[cfg(not(target_os = "windows"))]
 412    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 413        command
 414            .stdin(Stdio::piped())
 415            .stdout(Stdio::piped())
 416            .stderr(Stdio::piped())
 417            .args(["-o", "ControlMaster=no", "-o"])
 418            .arg(format!("ControlPath={}", self.socket_path.display()))
 419    }
 420
 421    #[cfg(target_os = "windows")]
 422    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 423        command
 424            .stdin(Stdio::piped())
 425            .stdout(Stdio::piped())
 426            .stderr(Stdio::piped())
 427            .envs(self.envs.clone())
 428    }
 429
 430    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 431    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 432    #[cfg(not(target_os = "windows"))]
 433    fn ssh_args(&self) -> SshArgs {
 434        SshArgs {
 435            arguments: vec![
 436                "-o".to_string(),
 437                "ControlMaster=no".to_string(),
 438                "-o".to_string(),
 439                format!("ControlPath={}", self.socket_path.display()),
 440                self.connection_options.ssh_url(),
 441            ],
 442            envs: None,
 443        }
 444    }
 445
 446    #[cfg(target_os = "windows")]
 447    fn ssh_args(&self) -> SshArgs {
 448        SshArgs {
 449            arguments: vec![self.connection_options.ssh_url()],
 450            envs: Some(self.envs.clone()),
 451        }
 452    }
 453
 454    async fn platform(&self) -> Result<SshPlatform> {
 455        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 456        let Some((os, arch)) = uname.split_once(" ") else {
 457            anyhow::bail!("unknown uname: {uname:?}")
 458        };
 459
 460        let os = match os.trim() {
 461            "Darwin" => "macos",
 462            "Linux" => "linux",
 463            _ => anyhow::bail!(
 464                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 465            ),
 466        };
 467        // exclude armv5,6,7 as they are 32-bit.
 468        let arch = if arch.starts_with("armv8")
 469            || arch.starts_with("armv9")
 470            || arch.starts_with("arm64")
 471            || arch.starts_with("aarch64")
 472        {
 473            "aarch64"
 474        } else if arch.starts_with("x86") {
 475            "x86_64"
 476        } else {
 477            anyhow::bail!(
 478                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 479            )
 480        };
 481
 482        Ok(SshPlatform { os, arch })
 483    }
 484}
 485
 486const MAX_MISSED_HEARTBEATS: usize = 5;
 487const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 488const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 489
 490const MAX_RECONNECT_ATTEMPTS: usize = 3;
 491
 492enum State {
 493    Connecting,
 494    Connected {
 495        ssh_connection: Arc<dyn RemoteConnection>,
 496        delegate: Arc<dyn SshClientDelegate>,
 497
 498        multiplex_task: Task<Result<()>>,
 499        heartbeat_task: Task<Result<()>>,
 500    },
 501    HeartbeatMissed {
 502        missed_heartbeats: usize,
 503
 504        ssh_connection: Arc<dyn RemoteConnection>,
 505        delegate: Arc<dyn SshClientDelegate>,
 506
 507        multiplex_task: Task<Result<()>>,
 508        heartbeat_task: Task<Result<()>>,
 509    },
 510    Reconnecting,
 511    ReconnectFailed {
 512        ssh_connection: Arc<dyn RemoteConnection>,
 513        delegate: Arc<dyn SshClientDelegate>,
 514
 515        error: anyhow::Error,
 516        attempts: usize,
 517    },
 518    ReconnectExhausted,
 519    ServerNotRunning,
 520}
 521
 522impl fmt::Display for State {
 523    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 524        match self {
 525            Self::Connecting => write!(f, "connecting"),
 526            Self::Connected { .. } => write!(f, "connected"),
 527            Self::Reconnecting => write!(f, "reconnecting"),
 528            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 529            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 530            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 531            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 532        }
 533    }
 534}
 535
 536impl State {
 537    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 538        match self {
 539            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 540            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 541            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 542            _ => None,
 543        }
 544    }
 545
 546    fn can_reconnect(&self) -> bool {
 547        match self {
 548            Self::Connected { .. }
 549            | Self::HeartbeatMissed { .. }
 550            | Self::ReconnectFailed { .. } => true,
 551            State::Connecting
 552            | State::Reconnecting
 553            | State::ReconnectExhausted
 554            | State::ServerNotRunning => false,
 555        }
 556    }
 557
 558    fn is_reconnect_failed(&self) -> bool {
 559        matches!(self, Self::ReconnectFailed { .. })
 560    }
 561
 562    fn is_reconnect_exhausted(&self) -> bool {
 563        matches!(self, Self::ReconnectExhausted { .. })
 564    }
 565
 566    fn is_server_not_running(&self) -> bool {
 567        matches!(self, Self::ServerNotRunning)
 568    }
 569
 570    fn is_reconnecting(&self) -> bool {
 571        matches!(self, Self::Reconnecting { .. })
 572    }
 573
 574    fn heartbeat_recovered(self) -> Self {
 575        match self {
 576            Self::HeartbeatMissed {
 577                ssh_connection,
 578                delegate,
 579                multiplex_task,
 580                heartbeat_task,
 581                ..
 582            } => Self::Connected {
 583                ssh_connection,
 584                delegate,
 585                multiplex_task,
 586                heartbeat_task,
 587            },
 588            _ => self,
 589        }
 590    }
 591
 592    fn heartbeat_missed(self) -> Self {
 593        match self {
 594            Self::Connected {
 595                ssh_connection,
 596                delegate,
 597                multiplex_task,
 598                heartbeat_task,
 599            } => Self::HeartbeatMissed {
 600                missed_heartbeats: 1,
 601                ssh_connection,
 602                delegate,
 603                multiplex_task,
 604                heartbeat_task,
 605            },
 606            Self::HeartbeatMissed {
 607                missed_heartbeats,
 608                ssh_connection,
 609                delegate,
 610                multiplex_task,
 611                heartbeat_task,
 612            } => Self::HeartbeatMissed {
 613                missed_heartbeats: missed_heartbeats + 1,
 614                ssh_connection,
 615                delegate,
 616                multiplex_task,
 617                heartbeat_task,
 618            },
 619            _ => self,
 620        }
 621    }
 622}
 623
 624/// The state of the ssh connection.
 625#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 626pub enum ConnectionState {
 627    Connecting,
 628    Connected,
 629    HeartbeatMissed,
 630    Reconnecting,
 631    Disconnected,
 632}
 633
 634impl From<&State> for ConnectionState {
 635    fn from(value: &State) -> Self {
 636        match value {
 637            State::Connecting => Self::Connecting,
 638            State::Connected { .. } => Self::Connected,
 639            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 640            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 641            State::ReconnectExhausted => Self::Disconnected,
 642            State::ServerNotRunning => Self::Disconnected,
 643        }
 644    }
 645}
 646
 647pub struct SshRemoteClient {
 648    client: Arc<ChannelClient>,
 649    unique_identifier: String,
 650    connection_options: SshConnectionOptions,
 651    path_style: PathStyle,
 652    state: Arc<Mutex<Option<State>>>,
 653}
 654
 655#[derive(Debug)]
 656pub enum SshRemoteEvent {
 657    Disconnected,
 658}
 659
 660impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 661
 662// Identifies the socket on the remote server so that reconnects
 663// can re-join the same project.
 664pub enum ConnectionIdentifier {
 665    Setup(u64),
 666    Workspace(i64),
 667}
 668
 669static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 670
 671impl ConnectionIdentifier {
 672    pub fn setup() -> Self {
 673        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 674    }
 675    // This string gets used in a socket name, and so must be relatively short.
 676    // The total length of:
 677    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 678    // Must be less than about 100 characters
 679    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 680    // So our strings should be at most 20 characters or so.
 681    fn to_string(&self, cx: &App) -> String {
 682        let identifier_prefix = match ReleaseChannel::global(cx) {
 683            ReleaseChannel::Stable => "".to_string(),
 684            release_channel => format!("{}-", release_channel.dev_name()),
 685        };
 686        match self {
 687            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 688            Self::Workspace(workspace_id) => {
 689                format!("{identifier_prefix}workspace-{workspace_id}",)
 690            }
 691        }
 692    }
 693}
 694
 695impl SshRemoteClient {
    /// Connects to the host described by `connection_options` and returns the
    /// new client entity.
    ///
    /// The connection is set up in this order:
    /// 1. Obtain an ssh connection from the global `ConnectionPool`.
    /// 2. Create the entity in `State::Connecting`.
    /// 3. Start the remote proxy and the I/O monitor.
    /// 4. Ping the server.
    /// 5. Start the heartbeat task and switch to `State::Connected`.
    ///
    /// Resolves to `Ok(None)` if `cancellation` completes first. That includes
    /// the sender being dropped, because the oneshot receiver then resolves with
    /// `Canceled`, which the `_` pattern below also matches.
    ///
    /// # Errors
    ///
    /// Resolves to an error if the pooled connect, entity creation/update, or
    /// the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            // Boxed and pinned so `select!` below can poll it (it needs `Unpin` futures).
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // NOTE(review): `map_err(|e| e.cloned())` suggests the pool hands out a
                // shared connect future whose error must be cloned out — confirm
                // against `ConnectionPool::connect`.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // `false` here vs. `true` in `reconnect` — presumably the "is reconnect" flag.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Fail fast if the server does not answer within HEARTBEAT_TIMEOUT.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Whichever finishes first wins; losing `success` drops it mid-flight.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 770
 771    pub fn shutdown_processes<T: RequestMessage>(
 772        &self,
 773        shutdown_request: Option<T>,
 774        executor: BackgroundExecutor,
 775    ) -> Option<impl Future<Output = ()> + use<T>> {
 776        let state = self.state.lock().take()?;
 777        log::info!("shutting down ssh processes");
 778
 779        let State::Connected {
 780            multiplex_task,
 781            heartbeat_task,
 782            ssh_connection,
 783            delegate,
 784        } = state
 785        else {
 786            return None;
 787        };
 788
 789        let client = self.client.clone();
 790
 791        Some(async move {
 792            if let Some(shutdown_request) = shutdown_request {
 793                client.send(shutdown_request).log_err();
 794                // We wait 50ms instead of waiting for a response, because
 795                // waiting for a response would require us to wait on the main thread
 796                // which we want to avoid in an `on_app_quit` callback.
 797                executor.timer(Duration::from_millis(50)).await;
 798            }
 799
 800            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 801            // child of master_process.
 802            drop(multiplex_task);
 803            // Now drop the rest of state, which kills master process.
 804            drop(heartbeat_task);
 805            drop(ssh_connection);
 806            drop(delegate);
 807        })
 808    }
 809
 810    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 811        let mut lock = self.state.lock();
 812
 813        let can_reconnect = lock
 814            .as_ref()
 815            .map(|state| state.can_reconnect())
 816            .unwrap_or(false);
 817        if !can_reconnect {
 818            log::info!("aborting reconnect, because not in state that allows reconnecting");
 819            let error = if let Some(state) = lock.as_ref() {
 820                format!("invalid state, cannot reconnect while in state {state}")
 821            } else {
 822                "no state set".to_string()
 823            };
 824            anyhow::bail!(error);
 825        }
 826
 827        let state = lock.take().unwrap();
 828        let (attempts, ssh_connection, delegate) = match state {
 829            State::Connected {
 830                ssh_connection,
 831                delegate,
 832                multiplex_task,
 833                heartbeat_task,
 834            }
 835            | State::HeartbeatMissed {
 836                ssh_connection,
 837                delegate,
 838                multiplex_task,
 839                heartbeat_task,
 840                ..
 841            } => {
 842                drop(multiplex_task);
 843                drop(heartbeat_task);
 844                (0, ssh_connection, delegate)
 845            }
 846            State::ReconnectFailed {
 847                attempts,
 848                ssh_connection,
 849                delegate,
 850                ..
 851            } => (attempts, ssh_connection, delegate),
 852            State::Connecting
 853            | State::Reconnecting
 854            | State::ReconnectExhausted
 855            | State::ServerNotRunning => unreachable!(),
 856        };
 857
 858        let attempts = attempts + 1;
 859        if attempts > MAX_RECONNECT_ATTEMPTS {
 860            log::error!(
 861                "Failed to reconnect to after {} attempts, giving up",
 862                MAX_RECONNECT_ATTEMPTS
 863            );
 864            drop(lock);
 865            self.set_state(State::ReconnectExhausted, cx);
 866            return Ok(());
 867        }
 868        drop(lock);
 869
 870        self.set_state(State::Reconnecting, cx);
 871
 872        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 873
 874        let unique_identifier = self.unique_identifier.clone();
 875        let client = self.client.clone();
 876        let reconnect_task = cx.spawn(async move |this, cx| {
 877            macro_rules! failed {
 878                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 879                    return State::ReconnectFailed {
 880                        error: anyhow!($error),
 881                        attempts: $attempts,
 882                        ssh_connection: $ssh_connection,
 883                        delegate: $delegate,
 884                    };
 885                };
 886            }
 887
 888            if let Err(error) = ssh_connection
 889                .kill()
 890                .await
 891                .context("Failed to kill ssh process")
 892            {
 893                failed!(error, attempts, ssh_connection, delegate);
 894            };
 895
 896            let connection_options = ssh_connection.connection_options();
 897
 898            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 899            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 900            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 901
 902            let (ssh_connection, io_task) = match async {
 903                let ssh_connection = cx
 904                    .update_global(|pool: &mut ConnectionPool, cx| {
 905                        pool.connect(connection_options, &delegate, cx)
 906                    })?
 907                    .await
 908                    .map_err(|error| error.cloned())?;
 909
 910                let io_task = ssh_connection.start_proxy(
 911                    unique_identifier,
 912                    true,
 913                    incoming_tx,
 914                    outgoing_rx,
 915                    connection_activity_tx,
 916                    delegate.clone(),
 917                    cx,
 918                );
 919                anyhow::Ok((ssh_connection, io_task))
 920            }
 921            .await
 922            {
 923                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 924                Err(error) => {
 925                    failed!(error, attempts, ssh_connection, delegate);
 926                }
 927            };
 928
 929            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 930            client.reconnect(incoming_rx, outgoing_tx, &cx);
 931
 932            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 933                failed!(error, attempts, ssh_connection, delegate);
 934            };
 935
 936            State::Connected {
 937                ssh_connection,
 938                delegate,
 939                multiplex_task,
 940                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 941            }
 942        });
 943
 944        cx.spawn(async move |this, cx| {
 945            let new_state = reconnect_task.await;
 946            this.update(cx, |this, cx| {
 947                this.try_set_state(cx, |old_state| {
 948                    if old_state.is_reconnecting() {
 949                        match &new_state {
 950                            State::Connecting
 951                            | State::Reconnecting { .. }
 952                            | State::HeartbeatMissed { .. }
 953                            | State::ServerNotRunning => {}
 954                            State::Connected { .. } => {
 955                                log::info!("Successfully reconnected");
 956                            }
 957                            State::ReconnectFailed {
 958                                error, attempts, ..
 959                            } => {
 960                                log::error!(
 961                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 962                                    attempts,
 963                                    error
 964                                );
 965                            }
 966                            State::ReconnectExhausted => {
 967                                log::error!("Reconnect attempt failed and all attempts exhausted");
 968                            }
 969                        }
 970                        Some(new_state)
 971                    } else {
 972                        None
 973                    }
 974                });
 975
 976                if this.state_is(State::is_reconnect_failed) {
 977                    this.reconnect(cx)
 978                } else if this.state_is(State::is_reconnect_exhausted) {
 979                    Ok(())
 980                } else {
 981                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 982                    Ok(())
 983                }
 984            })
 985        })
 986        .detach_and_log_err(cx);
 987
 988        Ok(())
 989    }
 990
 991    fn heartbeat(
 992        this: WeakEntity<Self>,
 993        mut connection_activity_rx: mpsc::Receiver<()>,
 994        cx: &mut AsyncApp,
 995    ) -> Task<Result<()>> {
 996        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 997            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 998        };
 999
1000        cx.spawn(async move |cx| {
1001                let mut missed_heartbeats = 0;
1002
1003                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1004                futures::pin_mut!(keepalive_timer);
1005
1006                loop {
1007                    select_biased! {
1008                        result = connection_activity_rx.next().fuse() => {
1009                            if result.is_none() {
1010                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1011                                return Ok(());
1012                            }
1013
1014                            if missed_heartbeats != 0 {
1015                                missed_heartbeats = 0;
1016                                let _ =this.update(cx, |this, mut cx| {
1017                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1018                                })?;
1019                            }
1020                        }
1021                        _ = keepalive_timer => {
1022                            log::debug!("Sending heartbeat to server...");
1023
1024                            let result = select_biased! {
1025                                _ = connection_activity_rx.next().fuse() => {
1026                                    Ok(())
1027                                }
1028                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1029                                    ping_result
1030                                }
1031                            };
1032
1033                            if result.is_err() {
1034                                missed_heartbeats += 1;
1035                                log::warn!(
1036                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1037                                    HEARTBEAT_TIMEOUT,
1038                                    missed_heartbeats,
1039                                    MAX_MISSED_HEARTBEATS
1040                                );
1041                            } else if missed_heartbeats != 0 {
1042                                missed_heartbeats = 0;
1043                            } else {
1044                                continue;
1045                            }
1046
1047                            let result = this.update(cx, |this, mut cx| {
1048                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1049                            })?;
1050                            if result.is_break() {
1051                                return Ok(());
1052                            }
1053                        }
1054                    }
1055
1056                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1057                }
1058
1059        })
1060    }
1061
1062    fn handle_heartbeat_result(
1063        &mut self,
1064        missed_heartbeats: usize,
1065        cx: &mut Context<Self>,
1066    ) -> ControlFlow<()> {
1067        let state = self.state.lock().take().unwrap();
1068        let next_state = if missed_heartbeats > 0 {
1069            state.heartbeat_missed()
1070        } else {
1071            state.heartbeat_recovered()
1072        };
1073
1074        self.set_state(next_state, cx);
1075
1076        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1077            log::error!(
1078                "Missed last {} heartbeats. Reconnecting...",
1079                missed_heartbeats
1080            );
1081
1082            self.reconnect(cx)
1083                .context("failed to start reconnect process after missing heartbeats")
1084                .log_err();
1085            ControlFlow::Break(())
1086        } else {
1087            ControlFlow::Continue(())
1088        }
1089    }
1090
1091    fn monitor(
1092        this: WeakEntity<Self>,
1093        io_task: Task<Result<i32>>,
1094        cx: &AsyncApp,
1095    ) -> Task<Result<()>> {
1096        cx.spawn(async move |cx| {
1097            let result = io_task.await;
1098
1099            match result {
1100                Ok(exit_code) => {
1101                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1102                        match error {
1103                            ProxyLaunchError::ServerNotRunning => {
1104                                log::error!("failed to reconnect because server is not running");
1105                                this.update(cx, |this, cx| {
1106                                    this.set_state(State::ServerNotRunning, cx);
1107                                })?;
1108                            }
1109                        }
1110                    } else if exit_code > 0 {
1111                        log::error!("proxy process terminated unexpectedly");
1112                        this.update(cx, |this, cx| {
1113                            this.reconnect(cx).ok();
1114                        })?;
1115                    }
1116                }
1117                Err(error) => {
1118                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1119                    this.update(cx, |this, cx| {
1120                        this.reconnect(cx).ok();
1121                    })?;
1122                }
1123            }
1124
1125            Ok(())
1126        })
1127    }
1128
1129    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1130        self.state.lock().as_ref().map_or(false, check)
1131    }
1132
1133    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1134        let mut lock = self.state.lock();
1135        let new_state = lock.as_ref().and_then(map);
1136
1137        if let Some(new_state) = new_state {
1138            lock.replace(new_state);
1139            cx.notify();
1140        }
1141    }
1142
1143    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1144        log::info!("setting state to '{}'", &state);
1145
1146        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1147        let is_server_not_running = state.is_server_not_running();
1148        self.state.lock().replace(state);
1149
1150        if is_reconnect_exhausted || is_server_not_running {
1151            cx.emit(SshRemoteEvent::Disconnected);
1152        }
1153        cx.notify();
1154    }
1155
1156    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
1157        self.client.subscribe_to_entity(remote_id, entity);
1158    }
1159
1160    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1161        self.state
1162            .lock()
1163            .as_ref()
1164            .and_then(|state| state.ssh_connection())
1165            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1166    }
1167
1168    pub fn upload_directory(
1169        &self,
1170        src_path: PathBuf,
1171        dest_path: RemotePathBuf,
1172        cx: &App,
1173    ) -> Task<Result<()>> {
1174        let state = self.state.lock();
1175        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1176            return Task::ready(Err(anyhow!("no ssh connection")));
1177        };
1178        connection.upload_directory(src_path, dest_path, cx)
1179    }
1180
1181    pub fn proto_client(&self) -> AnyProtoClient {
1182        self.client.clone().into()
1183    }
1184
1185    pub fn connection_string(&self) -> String {
1186        self.connection_options.connection_string()
1187    }
1188
1189    pub fn connection_options(&self) -> SshConnectionOptions {
1190        self.connection_options.clone()
1191    }
1192
1193    pub fn connection_state(&self) -> ConnectionState {
1194        self.state
1195            .lock()
1196            .as_ref()
1197            .map(ConnectionState::from)
1198            .unwrap_or(ConnectionState::Disconnected)
1199    }
1200
1201    pub fn is_disconnected(&self) -> bool {
1202        self.connection_state() == ConnectionState::Disconnected
1203    }
1204
1205    pub fn path_style(&self) -> PathStyle {
1206        self.path_style
1207    }
1208
1209    #[cfg(any(test, feature = "test-support"))]
1210    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1211        let opts = self.connection_options();
1212        client_cx.spawn(async move |cx| {
1213            let connection = cx
1214                .update_global(|c: &mut ConnectionPool, _| {
1215                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1216                        c.clone()
1217                    } else {
1218                        panic!("missing test connection")
1219                    }
1220                })
1221                .unwrap()
1222                .await
1223                .unwrap();
1224
1225            connection.simulate_disconnect(&cx);
1226        })
1227    }
1228
1229    #[cfg(any(test, feature = "test-support"))]
1230    pub fn fake_server(
1231        client_cx: &mut gpui::TestAppContext,
1232        server_cx: &mut gpui::TestAppContext,
1233    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1234        let port = client_cx
1235            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1236        let opts = SshConnectionOptions {
1237            host: "<fake>".to_string(),
1238            port: Some(port),
1239            ..Default::default()
1240        };
1241        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1242        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1243        let server_client =
1244            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1245        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1246            connection_options: opts.clone(),
1247            server_cx: fake::SendableCx::new(server_cx),
1248            server_channel: server_client.clone(),
1249        });
1250
1251        client_cx.update(|cx| {
1252            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1253                c.connections.insert(
1254                    opts.clone(),
1255                    ConnectionPoolEntry::Connecting(
1256                        cx.background_spawn({
1257                            let connection = connection.clone();
1258                            async move { Ok(connection.clone()) }
1259                        })
1260                        .shared(),
1261                    ),
1262                );
1263            })
1264        });
1265
1266        (opts, server_client)
1267    }
1268
1269    #[cfg(any(test, feature = "test-support"))]
1270    pub async fn fake_client(
1271        opts: SshConnectionOptions,
1272        client_cx: &mut gpui::TestAppContext,
1273    ) -> Entity<Self> {
1274        let (_tx, rx) = oneshot::channel();
1275        client_cx
1276            .update(|cx| {
1277                Self::new(
1278                    ConnectionIdentifier::setup(),
1279                    opts,
1280                    rx,
1281                    Arc::new(fake::Delegate),
1282                    cx,
1283                )
1284            })
1285            .await
1286            .unwrap()
1287            .unwrap()
1288    }
1289}
1290
/// An entry in the global `ConnectionPool`, keyed by `SshConnectionOptions`.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. Every caller asking for the same options
    /// shares this task's result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it alive
    /// once all strong references are dropped.
    Connected(Weak<dyn RemoteConnection>),
}
1295
/// App-global registry of SSH connections. It de-duplicates concurrent connection
/// attempts and reuses live connections that share the same options.
#[derive(Default)]
struct ConnectionPool {
    /// One entry per distinct set of connection options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
1302
impl ConnectionPool {
    /// Returns a shared task that resolves to a connection for `opts`.
    ///
    /// Resolution happens in this order:
    /// - If an attempt for `opts` is already in flight, its task is returned and the
    ///   delegate's status is set to "Waiting for existing connection attempt".
    /// - If a previously established connection is still alive and not killed, it is
    ///   returned immediately. A dead or killed entry is removed.
    /// - Otherwise a new `SshRemoteConnection` is started and recorded as
    ///   `Connecting`. On success the entry becomes `Connected` with a weak
    ///   reference. On failure the entry is removed and the error is shared as
    ///   `Arc<anyhow::Error>`.
    pub fn connect(
        &mut self,
        opts: SshConnectionOptions,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
        let connection = self.connections.get(&opts);
        match connection {
            Some(ConnectionPoolEntry::Connecting(task)) => {
                // Piggyback on the in-flight attempt instead of opening a second connection.
                let delegate = delegate.clone();
                cx.spawn(async move |cx| {
                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
                })
                .detach();
                return task.clone();
            }
            Some(ConnectionPoolEntry::Connected(ssh)) => {
                if let Some(ssh) = ssh.upgrade() {
                    if !ssh.has_been_killed() {
                        return Task::ready(Ok(ssh)).shared();
                    }
                }
                // The cached connection was dropped or killed; fall through and reconnect.
                self.connections.remove(&opts);
            }
            None => {}
        }

        let task = cx
            .spawn({
                let opts = opts.clone();
                let delegate = delegate.clone();
                async move |cx| {
                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
                        .await
                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);

                    cx.update_global(|pool: &mut Self, _| {
                        // The `Connecting` entry is inserted right after this task is
                        // spawned (below). Presumably the task cannot finish first —
                        // NOTE(review): confirm with executor semantics.
                        debug_assert!(matches!(
                            pool.connections.get(&opts),
                            Some(ConnectionPoolEntry::Connecting(_))
                        ));
                        match connection {
                            Ok(connection) => {
                                pool.connections.insert(
                                    opts.clone(),
                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
                                );
                                Ok(connection)
                            }
                            Err(error) => {
                                pool.connections.remove(&opts);
                                Err(Arc::new(error))
                            }
                        }
                    })?
                }
            })
            .shared();

        self.connections
            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
        task
    }
}
1368
1369impl From<SshRemoteClient> for AnyProtoClient {
1370    fn from(client: SshRemoteClient) -> Self {
1371        AnyProtoClient::new(client.client.clone())
1372    }
1373}
1374
/// A transport to a remote host that can run the remote server's proxy and
/// transfer files.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy and shuttles protocol messages.
    ///
    /// Envelopes from `outgoing_rx` go to the remote. Envelopes received from the
    /// remote go to `incoming_tx`. `connection_activity_tx` is signaled on activity.
    /// `reconnect` marks a reconnection attempt. The returned task resolves to the
    /// proxy's exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Tears down the connection. Afterwards `has_been_killed` should report `true`.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has already torn down this connection.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run additional `ssh` commands against this host.
    ///
    /// On Windows, where `ControlMaster`/`ControlPath` are not supported, we need to use
    /// `SSH_ASKPASS` to provide the password to ssh.
    /// On all other platforms, we use the `ControlPath` option to point ssh at the
    /// control socket created by the master process, so new commands reuse the
    /// already-established connection.
    fn ssh_args(&self) -> SshArgs;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// The path style (POSIX or Windows) used by the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test-only hook that simulates the connection dropping. It is a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1404
/// An SSH connection backed by a long-lived "master" `ssh` process.
struct SshRemoteConnection {
    /// Used to run further ssh/scp commands against the host.
    socket: SshSocket,
    /// The master `ssh` process. It is `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Path to the server binary on the remote host. `new` sets it to the value
    /// returned by `ensure_server_binary`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform information reported by the remote host.
    ssh_platform: SshPlatform,
    /// `PathStyle::Windows` when `ssh_platform.os` is "windows", otherwise `Posix`.
    ssh_path_style: PathStyle,
    /// Per-session temporary directory (it holds the control socket on non-Windows).
    /// It is kept alive so it is not deleted while the connection exists.
    _temp_dir: TempDir,
}
1413
1414#[async_trait(?Send)]
1415impl RemoteConnection for SshRemoteConnection {
1416    async fn kill(&self) -> Result<()> {
1417        let Some(mut process) = self.master_process.lock().take() else {
1418            return Ok(());
1419        };
1420        process.kill().ok();
1421        process.status().await?;
1422        Ok(())
1423    }
1424
1425    fn has_been_killed(&self) -> bool {
1426        self.master_process.lock().is_none()
1427    }
1428
1429    fn ssh_args(&self) -> SshArgs {
1430        self.socket.ssh_args()
1431    }
1432
1433    fn connection_options(&self) -> SshConnectionOptions {
1434        self.socket.connection_options.clone()
1435    }
1436
1437    fn upload_directory(
1438        &self,
1439        src_path: PathBuf,
1440        dest_path: RemotePathBuf,
1441        cx: &App,
1442    ) -> Task<Result<()>> {
1443        let mut command = util::command::new_smol_command("scp");
1444        let output = self
1445            .socket
1446            .ssh_options(&mut command)
1447            .args(
1448                self.socket
1449                    .connection_options
1450                    .port
1451                    .map(|port| vec!["-P".to_string(), port.to_string()])
1452                    .unwrap_or_default(),
1453            )
1454            .arg("-C")
1455            .arg("-r")
1456            .arg(&src_path)
1457            .arg(format!(
1458                "{}:{}",
1459                self.socket.connection_options.scp_url(),
1460                dest_path.to_string()
1461            ))
1462            .output();
1463
1464        cx.background_spawn(async move {
1465            let output = output.await?;
1466
1467            anyhow::ensure!(
1468                output.status.success(),
1469                "failed to upload directory {} -> {}: {}",
1470                src_path.display(),
1471                dest_path.to_string(),
1472                String::from_utf8_lossy(&output.stderr)
1473            );
1474
1475            Ok(())
1476        })
1477    }
1478
1479    fn start_proxy(
1480        &self,
1481        unique_identifier: String,
1482        reconnect: bool,
1483        incoming_tx: UnboundedSender<Envelope>,
1484        outgoing_rx: UnboundedReceiver<Envelope>,
1485        connection_activity_tx: Sender<()>,
1486        delegate: Arc<dyn SshClientDelegate>,
1487        cx: &mut AsyncApp,
1488    ) -> Task<Result<i32>> {
1489        delegate.set_status(Some("Starting proxy"), cx);
1490
1491        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1492            return Task::ready(Err(anyhow!("Remote binary path not set")));
1493        };
1494
1495        let mut start_proxy_command = shell_script!(
1496            "exec {binary_path} proxy --identifier {identifier}",
1497            binary_path = &remote_binary_path.to_string(),
1498            identifier = &unique_identifier,
1499        );
1500
1501        if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1502            start_proxy_command = format!(
1503                "RUST_LOG={} {}",
1504                shlex::try_quote(&rust_log).unwrap(),
1505                start_proxy_command
1506            )
1507        }
1508        if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1509            start_proxy_command = format!(
1510                "RUST_BACKTRACE={} {}",
1511                shlex::try_quote(&rust_backtrace).unwrap(),
1512                start_proxy_command
1513            )
1514        }
1515        if reconnect {
1516            start_proxy_command.push_str(" --reconnect");
1517        }
1518
1519        let ssh_proxy_process = match self
1520            .socket
1521            .ssh_command("sh", &["-c", &start_proxy_command])
1522            // IMPORTANT: we kill this process when we drop the task that uses it.
1523            .kill_on_drop(true)
1524            .spawn()
1525        {
1526            Ok(process) => process,
1527            Err(error) => {
1528                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1529            }
1530        };
1531
1532        Self::multiplex(
1533            ssh_proxy_process,
1534            incoming_tx,
1535            outgoing_rx,
1536            connection_activity_tx,
1537            &cx,
1538        )
1539    }
1540
1541    fn path_style(&self) -> PathStyle {
1542        self.ssh_path_style
1543    }
1544}
1545
1546impl SshRemoteConnection {
    /// Establishes a new SSH connection to the host described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. Spawn a "master" `ssh -N` process. Password prompts are answered through an
    ///    askpass session that forwards them to `delegate.ask_password`.
    /// 2. Wait for that process to close stdout (authentication finished). Fail if the
    ///    user cancels the prompt or askpass times out.
    /// 3. If the master process has already exited, fail with its stderr output.
    /// 4. Build the `SshSocket`, query the remote platform, and store the remote server
    ///    binary path obtained from `ensure_server_binary`.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts from ssh are routed to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" pairs with the `ControlPath=...` argument appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = process::Command::new("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                // Force ssh to use our askpass script for prompts.
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Race the askpass session (prompt cancelled / timed out) against ssh closing stdout.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        // NOTE(review): `anyhow::bail!` in the arms above expands to `return Err(..)`,
        // which returns from `new` directly. `result` is therefore always `Ok` here,
        // and this "Failed to connect to host" context is never attached. Confirm
        // which behavior is intended.
        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process already exited, its stderr explains why.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1677
1678    fn multiplex(
1679        mut ssh_proxy_process: Child,
1680        incoming_tx: UnboundedSender<Envelope>,
1681        mut outgoing_rx: UnboundedReceiver<Envelope>,
1682        mut connection_activity_tx: Sender<()>,
1683        cx: &AsyncApp,
1684    ) -> Task<Result<i32>> {
1685        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1686        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1687        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1688
1689        let mut stdin_buffer = Vec::new();
1690        let mut stdout_buffer = Vec::new();
1691        let mut stderr_buffer = Vec::new();
1692        let mut stderr_offset = 0;
1693
1694        let stdin_task = cx.background_spawn(async move {
1695            while let Some(outgoing) = outgoing_rx.next().await {
1696                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1697            }
1698            anyhow::Ok(())
1699        });
1700
1701        let stdout_task = cx.background_spawn({
1702            let mut connection_activity_tx = connection_activity_tx.clone();
1703            async move {
1704                loop {
1705                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1706                    let len = child_stdout.read(&mut stdout_buffer).await?;
1707
1708                    if len == 0 {
1709                        return anyhow::Ok(());
1710                    }
1711
1712                    if len < MESSAGE_LEN_SIZE {
1713                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1714                    }
1715
1716                    let message_len = message_len_from_buffer(&stdout_buffer);
1717                    let envelope =
1718                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1719                            .await?;
1720                    connection_activity_tx.try_send(()).ok();
1721                    incoming_tx.unbounded_send(envelope).ok();
1722                }
1723            }
1724        });
1725
1726        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1727            loop {
1728                stderr_buffer.resize(stderr_offset + 1024, 0);
1729
1730                let len = child_stderr
1731                    .read(&mut stderr_buffer[stderr_offset..])
1732                    .await?;
1733                if len == 0 {
1734                    return anyhow::Ok(());
1735                }
1736
1737                stderr_offset += len;
1738                let mut start_ix = 0;
1739                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1740                    .iter()
1741                    .position(|b| b == &b'\n')
1742                {
1743                    let line_ix = start_ix + ix;
1744                    let content = &stderr_buffer[start_ix..line_ix];
1745                    start_ix = line_ix + 1;
1746                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1747                        record.log(log::logger())
1748                    } else {
1749                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1750                    }
1751                }
1752                stderr_buffer.drain(0..start_ix);
1753                stderr_offset -= start_ix;
1754
1755                connection_activity_tx.try_send(()).ok();
1756            }
1757        });
1758
1759        cx.spawn(async move |_| {
1760            let result = futures::select! {
1761                result = stdin_task.fuse() => {
1762                    result.context("stdin")
1763                }
1764                result = stdout_task.fuse() => {
1765                    result.context("stdout")
1766                }
1767                result = stderr_task.fuse() => {
1768                    result.context("stderr")
1769                }
1770            };
1771
1772            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1773            match result {
1774                Ok(_) => Ok(status),
1775                Err(error) => Err(error),
1776            }
1777        })
1778    }
1779
1780    #[allow(unused)]
1781    async fn ensure_server_binary(
1782        &self,
1783        delegate: &Arc<dyn SshClientDelegate>,
1784        release_channel: ReleaseChannel,
1785        version: SemanticVersion,
1786        commit: Option<AppCommitSha>,
1787        cx: &mut AsyncApp,
1788    ) -> Result<RemotePathBuf> {
1789        let version_str = match release_channel {
1790            ReleaseChannel::Nightly => {
1791                let commit = commit.map(|s| s.full()).unwrap_or_default();
1792                format!("{}-{}", version, commit)
1793            }
1794            ReleaseChannel::Dev => "build".to_string(),
1795            _ => version.to_string(),
1796        };
1797        let binary_name = format!(
1798            "zed-remote-server-{}-{}",
1799            release_channel.dev_name(),
1800            version_str
1801        );
1802        let dst_path = RemotePathBuf::new(
1803            paths::remote_server_dir_relative().join(binary_name),
1804            self.ssh_path_style,
1805        );
1806
1807        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1808        #[cfg(debug_assertions)]
1809        if let Some(build_remote_server) = build_remote_server {
1810            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1811            let tmp_path = RemotePathBuf::new(
1812                paths::remote_server_dir_relative().join(format!(
1813                    "download-{}-{}",
1814                    std::process::id(),
1815                    src_path.file_name().unwrap().to_string_lossy()
1816                )),
1817                self.ssh_path_style,
1818            );
1819            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1820                .await?;
1821            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1822                .await?;
1823            return Ok(dst_path);
1824        }
1825
1826        if self
1827            .socket
1828            .run_command(&dst_path.to_string(), &["version"])
1829            .await
1830            .is_ok()
1831        {
1832            return Ok(dst_path);
1833        }
1834
1835        let wanted_version = cx.update(|cx| match release_channel {
1836            ReleaseChannel::Nightly => Ok(None),
1837            ReleaseChannel::Dev => {
1838                anyhow::bail!(
1839                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1840                    dst_path
1841                )
1842            }
1843            _ => Ok(Some(AppVersion::global(cx))),
1844        })??;
1845
1846        let tmp_path_gz = RemotePathBuf::new(
1847            PathBuf::from(format!(
1848                "{}-download-{}.gz",
1849                dst_path.to_string(),
1850                std::process::id()
1851            )),
1852            self.ssh_path_style,
1853        );
1854        if !self.socket.connection_options.upload_binary_over_ssh {
1855            if let Some((url, body)) = delegate
1856                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1857                .await?
1858            {
1859                match self
1860                    .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1861                    .await
1862                {
1863                    Ok(_) => {
1864                        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1865                            .await?;
1866                        return Ok(dst_path);
1867                    }
1868                    Err(e) => {
1869                        log::error!(
1870                            "Failed to download binary on server, attempting to upload server: {}",
1871                            e
1872                        )
1873                    }
1874                }
1875            }
1876        }
1877
1878        let src_path = delegate
1879            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1880            .await?;
1881        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1882            .await?;
1883        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1884            .await?;
1885        return Ok(dst_path);
1886    }
1887
1888    async fn download_binary_on_server(
1889        &self,
1890        url: &str,
1891        body: &str,
1892        tmp_path_gz: &RemotePathBuf,
1893        delegate: &Arc<dyn SshClientDelegate>,
1894        cx: &mut AsyncApp,
1895    ) -> Result<()> {
1896        if let Some(parent) = tmp_path_gz.parent() {
1897            self.socket
1898                .run_command(
1899                    "sh",
1900                    &[
1901                        "-c",
1902                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1903                    ],
1904                )
1905                .await?;
1906        }
1907
1908        delegate.set_status(Some("Downloading remote development server on host"), cx);
1909
1910        match self
1911            .socket
1912            .run_command(
1913                "curl",
1914                &[
1915                    "-f",
1916                    "-L",
1917                    "-X",
1918                    "GET",
1919                    "-H",
1920                    "Content-Type: application/json",
1921                    "-d",
1922                    &body,
1923                    &url,
1924                    "-o",
1925                    &tmp_path_gz.to_string(),
1926                ],
1927            )
1928            .await
1929        {
1930            Ok(_) => {}
1931            Err(e) => {
1932                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1933                    return Err(e);
1934                }
1935
1936                match self
1937                    .socket
1938                    .run_command(
1939                        "wget",
1940                        &[
1941                            "--method=GET",
1942                            "--header=Content-Type: application/json",
1943                            "--body-data",
1944                            &body,
1945                            &url,
1946                            "-O",
1947                            &tmp_path_gz.to_string(),
1948                        ],
1949                    )
1950                    .await
1951                {
1952                    Ok(_) => {}
1953                    Err(e) => {
1954                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1955                            return Err(e);
1956                        } else {
1957                            anyhow::bail!("Neither curl nor wget is available");
1958                        }
1959                    }
1960                }
1961            }
1962        }
1963
1964        Ok(())
1965    }
1966
1967    async fn upload_local_server_binary(
1968        &self,
1969        src_path: &Path,
1970        tmp_path_gz: &RemotePathBuf,
1971        delegate: &Arc<dyn SshClientDelegate>,
1972        cx: &mut AsyncApp,
1973    ) -> Result<()> {
1974        if let Some(parent) = tmp_path_gz.parent() {
1975            self.socket
1976                .run_command(
1977                    "sh",
1978                    &[
1979                        "-c",
1980                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1981                    ],
1982                )
1983                .await?;
1984        }
1985
1986        let src_stat = fs::metadata(&src_path).await?;
1987        let size = src_stat.len();
1988
1989        let t0 = Instant::now();
1990        delegate.set_status(Some("Uploading remote development server"), cx);
1991        log::info!(
1992            "uploading remote development server to {:?} ({}kb)",
1993            tmp_path_gz,
1994            size / 1024
1995        );
1996        self.upload_file(&src_path, &tmp_path_gz)
1997            .await
1998            .context("failed to upload server binary")?;
1999        log::info!("uploaded remote development server in {:?}", t0.elapsed());
2000        Ok(())
2001    }
2002
2003    async fn extract_server_binary(
2004        &self,
2005        dst_path: &RemotePathBuf,
2006        tmp_path: &RemotePathBuf,
2007        delegate: &Arc<dyn SshClientDelegate>,
2008        cx: &mut AsyncApp,
2009    ) -> Result<()> {
2010        delegate.set_status(Some("Extracting remote development server"), cx);
2011        let server_mode = 0o755;
2012
2013        let orig_tmp_path = tmp_path.to_string();
2014        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2015            shell_script!(
2016                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2017                server_mode = &format!("{:o}", server_mode),
2018                dst_path = &dst_path.to_string(),
2019            )
2020        } else {
2021            shell_script!(
2022                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2023                server_mode = &format!("{:o}", server_mode),
2024                dst_path = &dst_path.to_string()
2025            )
2026        };
2027        self.socket.run_command("sh", &["-c", &script]).await?;
2028        Ok(())
2029    }
2030
2031    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2032        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2033        let mut command = util::command::new_smol_command("scp");
2034        let output = self
2035            .socket
2036            .ssh_options(&mut command)
2037            .args(
2038                self.socket
2039                    .connection_options
2040                    .port
2041                    .map(|port| vec!["-P".to_string(), port.to_string()])
2042                    .unwrap_or_default(),
2043            )
2044            .arg(src_path)
2045            .arg(format!(
2046                "{}:{}",
2047                self.socket.connection_options.scp_url(),
2048                dest_path.to_string()
2049            ))
2050            .output()
2051            .await?;
2052
2053        anyhow::ensure!(
2054            output.status.success(),
2055            "failed to upload file {} -> {}: {}",
2056            src_path.display(),
2057            dest_path.to_string(),
2058            String::from_utf8_lossy(&output.stderr)
2059        );
2060        Ok(())
2061    }
2062
    /// Debug builds only: builds the remote server from this checkout and returns the path of
    /// the artifact to upload.
    ///
    /// `build_remote_server` is the value of `ZED_BUILD_REMOTE_SERVER`:
    /// - If the remote platform's arch and OS match the host, a plain `cargo build` is run
    ///   and the result is always gzipped. Returns an absolute path to `remote_server.gz`.
    /// - Otherwise the target triple is cross-compiled. This uses `cross` (Docker) when the
    ///   value contains "cross", and `cargo zigbuild` otherwise.
    /// - In the cross-compiled case, compression is skipped when the value contains
    ///   "nocompress". The returned path is then the relative, uncompressed binary path.
    ///
    /// All build paths are relative to the current working directory. NOTE(review): this
    /// presumably expects Zed to have been launched from the repository root; confirm.
    ///
    /// # Errors
    /// Fails if a build or compression command exits unsuccessfully, or if the remote
    /// platform has no known target triple. When `cross` is not requested, it also fails if
    /// `zig` is not on `$PATH`.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};

        // Runs `command` to completion with stderr inherited (so build output is visible)
        // and `kill_on_drop(true)`; fails if the command exits unsuccessfully.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // Same architecture and OS as the host: no cross-compilation needed.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(Command::new("cargo").args([
                "build",
                "--package",
                "remote_server",
                "--features",
                "debug-embed",
                "--target-dir",
                "target/remote_server",
            ]))
            .await?;

            delegate.set_status(Some("Compressing binary"), cx);

            // Note: this native path always compresses; "nocompress" is only honored below.
            run_cmd(Command::new("gzip").args([
                "-9",
                "-f",
                "target/remote_server/debug/remote_server",
            ]))
            .await?;

            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
            return Ok(path);
        }
        let Some(triple) = self.ssh_platform.triple() else {
            anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform);
        };
        smol::fs::create_dir_all("target/remote_server").await?;

        if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            // Bind-mount the local target directory into the cross container at /app/target.
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    ),
            )
            .await?;
        } else {
            // Zig provides the cross linker for `cargo zigbuild`; look it up off the
            // foreground thread.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(Command::new("cargo").args([
                "zigbuild",
                "--package",
                "remote_server",
                "--features",
                "debug-embed",
                "--target-dir",
                "target/remote_server",
                "--target",
                &triple,
            ]))
            .await?;
        };

        // Relative path to the uncompressed binary; replaced by an absolute `.gz` path below
        // unless "nocompress" was requested.
        let mut path = format!("target/remote_server/{triple}/debug/remote_server").into();
        if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args([
                    "-9",
                    "-f",
                    &format!("target/remote_server/{}/debug/remote_server", triple),
                ]))
                .await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove a stale archive first so 7z writes a fresh one.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &format!("target/remote_server/{}/debug/remote_server", triple),
                ]))
                .await?;
            }

            path = std::env::current_dir()?.join(format!(
                "target/remote_server/{triple}/debug/remote_server.gz"
            ));
        }

        return Ok(path);
    }
2252}
2253
/// Pending requests, keyed by the id of the request message.
///
/// Each sender delivers the response envelope together with a `oneshot::Sender<()>`. The
/// message loop waits on that sender's receiver, which resolves once the sender is used or
/// dropped, so a response is handed over before the next incoming message is processed.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2255
/// Protocol client that exchanges `Envelope`s over a pair of unbounded channels.
///
/// Outgoing buffered messages are kept until the peer acknowledges them. This lets them be
/// replayed after `reconnect`/`resync` or when the peer asks for a flush.
pub struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Buffered outgoing envelopes not yet acknowledged (via an incoming `ack_id`).
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests awaiting a response, keyed by request id.
    response_channels: ResponseChannels,
    /// Handlers and entity subscriptions for incoming non-response messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently processed incoming message, sent back as `ack_id`.
    max_received: AtomicU32,
    /// Label used in log output.
    name: &'static str,
    /// The task running the incoming-message loop; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
2266
2267impl ChannelClient {
2268    pub fn new(
2269        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2270        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2271        cx: &App,
2272        name: &'static str,
2273    ) -> Arc<Self> {
2274        Arc::new_cyclic(|this| Self {
2275            outgoing_tx: Mutex::new(outgoing_tx),
2276            next_message_id: AtomicU32::new(0),
2277            max_received: AtomicU32::new(0),
2278            response_channels: ResponseChannels::default(),
2279            message_handlers: Default::default(),
2280            buffer: Mutex::new(VecDeque::new()),
2281            name,
2282            task: Mutex::new(Self::start_handling_messages(
2283                this.clone(),
2284                incoming_rx,
2285                &cx.to_async(),
2286            )),
2287        })
2288    }
2289
2290    fn start_handling_messages(
2291        this: Weak<Self>,
2292        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2293        cx: &AsyncApp,
2294    ) -> Task<Result<()>> {
2295        cx.spawn(async move |cx| {
2296            let peer_id = PeerId { owner_id: 0, id: 0 };
2297            while let Some(incoming) = incoming_rx.next().await {
2298                let Some(this) = this.upgrade() else {
2299                    return anyhow::Ok(());
2300                };
2301                if let Some(ack_id) = incoming.ack_id {
2302                    let mut buffer = this.buffer.lock();
2303                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2304                        buffer.pop_front();
2305                    }
2306                }
2307                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2308                {
2309                    log::debug!(
2310                        "{}:ssh message received. name:FlushBufferedMessages",
2311                        this.name
2312                    );
2313                    {
2314                        let buffer = this.buffer.lock();
2315                        for envelope in buffer.iter() {
2316                            this.outgoing_tx
2317                                .lock()
2318                                .unbounded_send(envelope.clone())
2319                                .ok();
2320                        }
2321                    }
2322                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2323                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2324                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2325                    continue;
2326                }
2327
2328                this.max_received.store(incoming.id, SeqCst);
2329
2330                if let Some(request_id) = incoming.responding_to {
2331                    let request_id = MessageId(request_id);
2332                    let sender = this.response_channels.lock().remove(&request_id);
2333                    if let Some(sender) = sender {
2334                        let (tx, rx) = oneshot::channel();
2335                        if incoming.payload.is_some() {
2336                            sender.send((incoming, tx)).ok();
2337                        }
2338                        rx.await.ok();
2339                    }
2340                } else if let Some(envelope) =
2341                    build_typed_envelope(peer_id, Instant::now(), incoming)
2342                {
2343                    let type_name = envelope.payload_type_name();
2344                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2345                        &this.message_handlers,
2346                        envelope,
2347                        this.clone().into(),
2348                        cx.clone(),
2349                    ) {
2350                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2351                        cx.foreground_executor()
2352                            .spawn(async move {
2353                                match future.await {
2354                                    Ok(_) => {
2355                                        log::debug!(
2356                                            "{}:ssh message handled. name:{type_name}",
2357                                            this.name
2358                                        );
2359                                    }
2360                                    Err(error) => {
2361                                        log::error!(
2362                                            "{}:error handling message. type:{}, error:{}",
2363                                            this.name,
2364                                            type_name,
2365                                            format!("{error:#}").lines().fold(
2366                                                String::new(),
2367                                                |mut message, line| {
2368                                                    if !message.is_empty() {
2369                                                        message.push(' ');
2370                                                    }
2371                                                    message.push_str(line);
2372                                                    message
2373                                                }
2374                                            )
2375                                        );
2376                                    }
2377                                }
2378                            })
2379                            .detach()
2380                    } else {
2381                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2382                    }
2383                }
2384            }
2385            anyhow::Ok(())
2386        })
2387    }
2388
2389    pub fn reconnect(
2390        self: &Arc<Self>,
2391        incoming_rx: UnboundedReceiver<Envelope>,
2392        outgoing_tx: UnboundedSender<Envelope>,
2393        cx: &AsyncApp,
2394    ) {
2395        *self.outgoing_tx.lock() = outgoing_tx;
2396        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2397    }
2398
2399    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2400        let id = (TypeId::of::<E>(), remote_id);
2401
2402        let mut message_handlers = self.message_handlers.lock();
2403        if message_handlers
2404            .entities_by_type_and_remote_id
2405            .contains_key(&id)
2406        {
2407            panic!("already subscribed to entity");
2408        }
2409
2410        message_handlers.entities_by_type_and_remote_id.insert(
2411            id,
2412            EntityMessageSubscriber::Entity {
2413                handle: entity.downgrade().into(),
2414            },
2415        );
2416    }
2417
2418    pub fn request<T: RequestMessage>(
2419        &self,
2420        payload: T,
2421    ) -> impl 'static + Future<Output = Result<T::Response>> {
2422        self.request_internal(payload, true)
2423    }
2424
2425    fn request_internal<T: RequestMessage>(
2426        &self,
2427        payload: T,
2428        use_buffer: bool,
2429    ) -> impl 'static + Future<Output = Result<T::Response>> {
2430        log::debug!("ssh request start. name:{}", T::NAME);
2431        let response =
2432            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2433        async move {
2434            let response = response.await?;
2435            log::debug!("ssh request finish. name:{}", T::NAME);
2436            T::Response::from_envelope(response).context("received a response of the wrong type")
2437        }
2438    }
2439
2440    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2441        smol::future::or(
2442            async {
2443                self.request_internal(proto::FlushBufferedMessages {}, false)
2444                    .await?;
2445
2446                for envelope in self.buffer.lock().iter() {
2447                    self.outgoing_tx
2448                        .lock()
2449                        .unbounded_send(envelope.clone())
2450                        .ok();
2451                }
2452                Ok(())
2453            },
2454            async {
2455                smol::Timer::after(timeout).await;
2456                anyhow::bail!("Timeout detected")
2457            },
2458        )
2459        .await
2460    }
2461
2462    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2463        smol::future::or(
2464            async {
2465                self.request(proto::Ping {}).await?;
2466                Ok(())
2467            },
2468            async {
2469                smol::Timer::after(timeout).await;
2470                anyhow::bail!("Timeout detected")
2471            },
2472        )
2473        .await
2474    }
2475
2476    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2477        log::debug!("ssh send name:{}", T::NAME);
2478        self.send_dynamic(payload.into_envelope(0, None, None))
2479    }
2480
2481    fn request_dynamic(
2482        &self,
2483        mut envelope: proto::Envelope,
2484        type_name: &'static str,
2485        use_buffer: bool,
2486    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2487        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2488        let (tx, rx) = oneshot::channel();
2489        let mut response_channels_lock = self.response_channels.lock();
2490        response_channels_lock.insert(MessageId(envelope.id), tx);
2491        drop(response_channels_lock);
2492
2493        let result = if use_buffer {
2494            self.send_buffered(envelope)
2495        } else {
2496            self.send_unbuffered(envelope)
2497        };
2498        async move {
2499            if let Err(error) = &result {
2500                log::error!("failed to send message: {error}");
2501                anyhow::bail!("failed to send message: {error}");
2502            }
2503
2504            let response = rx.await.context("connection lost")?.0;
2505            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2506                return Err(RpcError::from_proto(error, type_name));
2507            }
2508            Ok(response)
2509        }
2510    }
2511
2512    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2513        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2514        self.send_buffered(envelope)
2515    }
2516
2517    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2518        envelope.ack_id = Some(self.max_received.load(SeqCst));
2519        self.buffer.lock().push_back(envelope.clone());
2520        // ignore errors on send (happen while we're reconnecting)
2521        // assume that the global "disconnected" overlay is sufficient.
2522        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2523        Ok(())
2524    }
2525
2526    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2527        envelope.ack_id = Some(self.max_received.load(SeqCst));
2528        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2529        Ok(())
2530    }
2531}
2532
2533impl ProtoClient for ChannelClient {
2534    fn request(
2535        &self,
2536        envelope: proto::Envelope,
2537        request_type: &'static str,
2538    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2539        self.request_dynamic(envelope, request_type, true).boxed()
2540    }
2541
2542    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2543        self.send_dynamic(envelope)
2544    }
2545
2546    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2547        self.send_dynamic(envelope)
2548    }
2549
2550    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2551        &self.message_handlers
2552    }
2553
2554    fn is_via_collab(&self) -> bool {
2555        false
2556    }
2557}
2558
#[cfg(any(test, feature = "test-support"))]
mod fake {
    //! Test-only [`RemoteConnection`] that forwards envelopes in-process between
    //! the client's channels and a server-side [`ChannelClient`].

    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A fake connection whose "remote server" is a [`ChannelClient`] living in
    /// the same process.
    pub(super) struct FakeRemoteConnection {
        /// Returned as-is (cloned) by `connection_options`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server side of the fake link. `start_proxy` and `simulate_disconnect`
        /// reconnect it to fresh channels.
        pub(super) server_channel: Arc<ChannelClient>,
        /// The server's app context, handed to `ChannelClient::reconnect`.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an [`AsyncApp`] be stored in a `Send + Sync` struct.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    // SAFETY: Same reasoning as the `Send` impl directly above.
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// There is no process to kill; always succeeds.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// No real ssh invocation exists, so no arguments or environment.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        /// Not supported by the fake; panics if called.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server channel to channels whose peer halves are
        /// dropped immediately (the `_` patterns), so nothing can flow between
        /// the client and the server anymore.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Connects the server channel to fresh channels, then spawns a
        /// background task that forwards envelopes in both directions.
        ///
        /// Only server-to-client traffic pings `connection_activity_tx`. The
        /// task resolves to `Ok(1)` as soon as either side's stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Delegate for tests: every interactive or download hook panics if
    /// reached, and status updates are ignored.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}