ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  32};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use smol::{
  36    fs,
  37    process::{self, Child, Stdio},
  38};
  39use std::{
  40    any::TypeId,
  41    collections::VecDeque,
  42    fmt, iter,
  43    ops::ControlFlow,
  44    path::{Path, PathBuf},
  45    sync::{
  46        Arc, Weak,
  47        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  48    },
  49    time::{Duration, Instant},
  50};
  51use tempfile::TempDir;
  52use util::{
  53    ResultExt,
  54    paths::{PathStyle, RemotePathBuf},
  55};
  56
/// Newtype around a `u64` used as the identifier of an SSH project.
///
/// NOTE(review): how ids are allocated and where they are consumed is not
/// visible in this part of the file — confirm against callers.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  61
/// Per-connection data used to build `ssh` invocations (see `impl SshSocket`).
///
/// On non-Windows targets it carries the path that is passed to ssh as
/// `ControlPath`; on Windows it carries the environment variables that are
/// applied to every ssh command instead.
#[derive(Clone)]
pub struct SshSocket {
    /// Destination description; supplies the `ssh://` URL and the extra ssh arguments.
    connection_options: SshConnectionOptions,
    /// Path handed to ssh via `-o ControlPath=...`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables (the askpass configuration) set on each ssh command.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  70
/// One local port forward, rendered by `SshConnectionOptions::additional_args`
/// as `-L{local_host}:{local_port}:{remote_host}:{remote_port}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind host; `additional_args` substitutes `"localhost"` when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port of the forward.
    pub local_port: u16,
    /// Remote target host; `additional_args` substitutes `"localhost"` when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Remote port of the forward.
    pub remote_port: u16,
}
  80
/// Everything needed to describe an SSH destination and how to reach it.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the destination.
    pub host: String,
    /// Login name; percent-encoded when rendered by `ssh_url`.
    pub username: Option<String>,
    /// Port of the destination, if one was specified.
    pub port: Option<u16>,
    /// NOTE(review): never populated by `parse_command_line`; where it is read
    /// is outside this part of the file.
    pub password: Option<String>,
    /// Extra ssh command-line arguments that are passed through verbatim.
    pub args: Option<Vec<String>>,
    /// Port forwards that are rendered as `-L...` arguments.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// NOTE(review): presumably a user-facing display name — not used in the
    /// visible code, confirm against callers.
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary from the
    /// local machine instead of downloading it on the remote — confirm.
    pub upload_binary_over_ssh: bool,
}
  93
/// Arguments and environment for an `ssh` invocation, as produced by
/// `SshSocket::ssh_args`.
pub struct SshArgs {
    /// Command-line arguments; `ssh_args` puts the `ssh://` URL last.
    pub arguments: Vec<String>,
    /// Environment variables to set; `None` on non-Windows targets.
    pub envs: Option<HashMap<String, String>>,
}
  98
/// Like `format!`, but every named argument is shell-quoted with
/// `shlex::try_quote` before it is interpolated into `$fmt`.
///
/// Only named arguments (`name = value`) are accepted, and at least one is
/// required. A trailing comma is allowed.
///
/// # Panics
/// Panics if `shlex::try_quote` returns an error for any argument (the result
/// is unwrapped).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
 110
 111fn parse_port_number(port_str: &str) -> Result<u16> {
 112    port_str
 113        .parse()
 114        .with_context(|| format!("parsing port number: {port_str}"))
 115}
 116
 117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 118    let parts: Vec<&str> = spec.split(':').collect();
 119
 120    match parts.len() {
 121        4 => {
 122            let local_port = parse_port_number(parts[1])?;
 123            let remote_port = parse_port_number(parts[3])?;
 124
 125            Ok(SshPortForwardOption {
 126                local_host: Some(parts[0].to_string()),
 127                local_port,
 128                remote_host: Some(parts[2].to_string()),
 129                remote_port,
 130            })
 131        }
 132        3 => {
 133            let local_port = parse_port_number(parts[0])?;
 134            let remote_port = parse_port_number(parts[2])?;
 135
 136            Ok(SshPortForwardOption {
 137                local_host: None,
 138                local_port,
 139                remote_host: Some(parts[1].to_string()),
 140                remote_port,
 141            })
 142        }
 143        _ => anyhow::bail!("Invalid port forward format"),
 144    }
 145}
 146
 147impl SshConnectionOptions {
 148    pub fn parse_command_line(input: &str) -> Result<Self> {
 149        let input = input.trim_start_matches("ssh ");
 150        let mut hostname: Option<String> = None;
 151        let mut username: Option<String> = None;
 152        let mut port: Option<u16> = None;
 153        let mut args = Vec::new();
 154        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 155
 156        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 157        const ALLOWED_OPTS: &[&str] = &[
 158            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 159        ];
 160        const ALLOWED_ARGS: &[&str] = &[
 161            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 162            "-w",
 163        ];
 164
 165        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 166
 167        'outer: while let Some(arg) = tokens.next() {
 168            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 169                args.push(arg.to_string());
 170                continue;
 171            }
 172            if arg == "-p" {
 173                port = tokens.next().and_then(|arg| arg.parse().ok());
 174                continue;
 175            } else if let Some(p) = arg.strip_prefix("-p") {
 176                port = p.parse().ok();
 177                continue;
 178            }
 179            if arg == "-l" {
 180                username = tokens.next();
 181                continue;
 182            } else if let Some(l) = arg.strip_prefix("-l") {
 183                username = Some(l.to_string());
 184                continue;
 185            }
 186            if arg == "-L" || arg.starts_with("-L") {
 187                let forward_spec = if arg == "-L" {
 188                    tokens.next()
 189                } else {
 190                    Some(arg.strip_prefix("-L").unwrap().to_string())
 191                };
 192
 193                if let Some(spec) = forward_spec {
 194                    port_forwards.push(parse_port_forward_spec(&spec)?);
 195                } else {
 196                    anyhow::bail!("Missing port forward format");
 197                }
 198            }
 199
 200            for a in ALLOWED_ARGS {
 201                if arg == *a {
 202                    args.push(arg);
 203                    if let Some(next) = tokens.next() {
 204                        args.push(next);
 205                    }
 206                    continue 'outer;
 207                } else if arg.starts_with(a) {
 208                    args.push(arg);
 209                    continue 'outer;
 210                }
 211            }
 212            if arg.starts_with("-") || hostname.is_some() {
 213                anyhow::bail!("unsupported argument: {:?}", arg);
 214            }
 215            let mut input = &arg as &str;
 216            // Destination might be: username1@username2@ip2@ip1
 217            if let Some((u, rest)) = input.rsplit_once('@') {
 218                input = rest;
 219                username = Some(u.to_string());
 220            }
 221            if let Some((rest, p)) = input.split_once(':') {
 222                input = rest;
 223                port = p.parse().ok()
 224            }
 225            hostname = Some(input.to_string())
 226        }
 227
 228        let Some(hostname) = hostname else {
 229            anyhow::bail!("missing hostname");
 230        };
 231
 232        let port_forwards = match port_forwards.len() {
 233            0 => None,
 234            _ => Some(port_forwards),
 235        };
 236
 237        Ok(Self {
 238            host: hostname.to_string(),
 239            username: username.clone(),
 240            port,
 241            port_forwards,
 242            args: Some(args),
 243            password: None,
 244            nickname: None,
 245            upload_binary_over_ssh: false,
 246        })
 247    }
 248
 249    pub fn ssh_url(&self) -> String {
 250        let mut result = String::from("ssh://");
 251        if let Some(username) = &self.username {
 252            // Username might be: username1@username2@ip2
 253            let username = urlencoding::encode(username);
 254            result.push_str(&username);
 255            result.push('@');
 256        }
 257        result.push_str(&self.host);
 258        if let Some(port) = self.port {
 259            result.push(':');
 260            result.push_str(&port.to_string());
 261        }
 262        result
 263    }
 264
 265    pub fn additional_args(&self) -> Vec<String> {
 266        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 267
 268        if let Some(forwards) = &self.port_forwards {
 269            args.extend(forwards.iter().map(|pf| {
 270                let local_host = match &pf.local_host {
 271                    Some(host) => host,
 272                    None => "localhost",
 273                };
 274                let remote_host = match &pf.remote_host {
 275                    Some(host) => host,
 276                    None => "localhost",
 277                };
 278
 279                format!(
 280                    "-L{}:{}:{}:{}",
 281                    local_host, pf.local_port, remote_host, pf.remote_port
 282                )
 283            }));
 284        }
 285
 286        args
 287    }
 288
 289    fn scp_url(&self) -> String {
 290        if let Some(username) = &self.username {
 291            format!("{}@{}", username, self.host)
 292        } else {
 293            self.host.clone()
 294        }
 295    }
 296
 297    pub fn connection_string(&self) -> String {
 298        let host = if let Some(username) = &self.username {
 299            format!("{}@{}", username, self.host)
 300        } else {
 301            self.host.clone()
 302        };
 303        if let Some(port) = &self.port {
 304            format!("{}:{}", host, port)
 305        } else {
 306            host
 307        }
 308    }
 309}
 310
/// Operating system and CPU architecture of the remote host, as detected by
/// `SshSocket::platform`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// `"macos"` or `"linux"` when produced by `SshSocket::platform`.
    pub os: &'static str,
    /// `"aarch64"` or `"x86_64"` when produced by `SshSocket::platform`.
    pub arch: &'static str,
}
 316
/// Callbacks the SSH client uses to interact with its host application.
///
/// NOTE(review): the call sites of these methods are outside the visible part
/// of the file, so the descriptions below are inferred from the signatures
/// only — confirm against the implementations.
pub trait SshClientDelegate: Send + Sync {
    /// Presumably asks the user for a password using `prompt` and delivers
    /// the answer through `tx`.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Presumably resolves where the server binary for `platform`,
    /// `release_channel` and `version` can be downloaded from. The meaning of
    /// the two strings in the returned tuple is not visible here.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Presumably downloads the server binary to the local machine and
    /// resolves to its path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Presumably reports a progress/status message (`None` to clear it).
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 336
 337impl SshSocket {
 338    #[cfg(not(target_os = "windows"))]
 339    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 340        Ok(Self {
 341            connection_options: options,
 342            socket_path,
 343        })
 344    }
 345
 346    #[cfg(target_os = "windows")]
 347    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 348        let askpass_script = temp_dir.path().join("askpass.bat");
 349        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 350        let mut envs = HashMap::default();
 351        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 352        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 353        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 354        Ok(Self {
 355            connection_options: options,
 356            envs,
 357        })
 358    }
 359
 360    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 361    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 362    // and passes -l as an argument to sh, not to ls.
 363    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 364    // into a machine. You must use `cd` to get back to $HOME.
 365    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 366    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 367        let mut command = util::command::new_smol_command("ssh");
 368        let to_run = iter::once(&program)
 369            .chain(args.iter())
 370            .map(|token| {
 371                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 372                debug_assert!(
 373                    !token.contains('\n'),
 374                    "multiline arguments do not work in all shells"
 375                );
 376                shlex::try_quote(token).unwrap()
 377            })
 378            .join(" ");
 379        let to_run = format!("cd; {to_run}");
 380        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 381        self.ssh_options(&mut command)
 382            .arg(self.connection_options.ssh_url())
 383            .arg(to_run);
 384        command
 385    }
 386
 387    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 388        let output = self.ssh_command(program, args).output().await?;
 389        anyhow::ensure!(
 390            output.status.success(),
 391            "failed to run command: {}",
 392            String::from_utf8_lossy(&output.stderr)
 393        );
 394        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 395    }
 396
 397    #[cfg(not(target_os = "windows"))]
 398    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 399        command
 400            .stdin(Stdio::piped())
 401            .stdout(Stdio::piped())
 402            .stderr(Stdio::piped())
 403            .args(self.connection_options.additional_args())
 404            .args(["-o", "ControlMaster=no", "-o"])
 405            .arg(format!("ControlPath={}", self.socket_path.display()))
 406    }
 407
 408    #[cfg(target_os = "windows")]
 409    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 410        command
 411            .stdin(Stdio::piped())
 412            .stdout(Stdio::piped())
 413            .stderr(Stdio::piped())
 414            .args(self.connection_options.additional_args())
 415            .envs(self.envs.clone())
 416    }
 417
 418    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 419    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 420    #[cfg(not(target_os = "windows"))]
 421    fn ssh_args(&self) -> SshArgs {
 422        let mut arguments = self.connection_options.additional_args();
 423        arguments.extend(vec![
 424            "-o".to_string(),
 425            "ControlMaster=no".to_string(),
 426            "-o".to_string(),
 427            format!("ControlPath={}", self.socket_path.display()),
 428            self.connection_options.ssh_url(),
 429        ]);
 430        SshArgs {
 431            arguments,
 432            envs: None,
 433        }
 434    }
 435
 436    #[cfg(target_os = "windows")]
 437    fn ssh_args(&self) -> SshArgs {
 438        let mut arguments = self.connection_options.additional_args();
 439        arguments.push(self.connection_options.ssh_url());
 440        SshArgs {
 441            arguments,
 442            envs: Some(self.envs.clone()),
 443        }
 444    }
 445
 446    async fn platform(&self) -> Result<SshPlatform> {
 447        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
 448        let Some((os, arch)) = uname.split_once(" ") else {
 449            anyhow::bail!("unknown uname: {uname:?}")
 450        };
 451
 452        let os = match os.trim() {
 453            "Darwin" => "macos",
 454            "Linux" => "linux",
 455            _ => anyhow::bail!(
 456                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 457            ),
 458        };
 459        // exclude armv5,6,7 as they are 32-bit.
 460        let arch = if arch.starts_with("armv8")
 461            || arch.starts_with("armv9")
 462            || arch.starts_with("arm64")
 463            || arch.starts_with("aarch64")
 464        {
 465            "aarch64"
 466        } else if arch.starts_with("x86") {
 467            "x86_64"
 468        } else {
 469            anyhow::bail!(
 470                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 471            )
 472        };
 473
 474        Ok(SshPlatform { os, arch })
 475    }
 476}
 477
// NOTE(review): the heartbeat loop that reads the first two constants is
// outside the visible part of the file; their descriptions are inferred from
// the names.
/// Presumably the number of consecutive missed heartbeats that is tolerated.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Presumably the delay between two heartbeats.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a ping; used for the initial ping in `SshRemoteClient::new`.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Once the reconnect attempt counter exceeds this value, `reconnect` gives up
/// and moves to `State::ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 483
/// Internal connection state machine of `SshRemoteClient`.
///
/// The public, data-free view of this enum is `ConnectionState`.
enum State {
    /// Initial state, set while the first connection is being established.
    Connecting,
    /// The connection is up; owns the connection and its background tasks.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but one or more heartbeats have been missed.
    HeartbeatMissed {
        /// Starts at 1 and is incremented by `heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The last reconnect attempt failed; another attempt is still allowed.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Why the attempt failed.
        error: anyhow::Error,
        /// Number of attempts made so far; carried into the next `reconnect`.
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts were made; no further reconnects.
    ReconnectExhausted,
    /// Terminal state from which no reconnect is possible.
    /// NOTE(review): where this state is entered is outside the visible code.
    ServerNotRunning,
}
 513
 514impl fmt::Display for State {
 515    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 516        match self {
 517            Self::Connecting => write!(f, "connecting"),
 518            Self::Connected { .. } => write!(f, "connected"),
 519            Self::Reconnecting => write!(f, "reconnecting"),
 520            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 521            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 522            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 523            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 524        }
 525    }
 526}
 527
 528impl State {
 529    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 530        match self {
 531            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 532            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 533            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 534            _ => None,
 535        }
 536    }
 537
 538    fn can_reconnect(&self) -> bool {
 539        match self {
 540            Self::Connected { .. }
 541            | Self::HeartbeatMissed { .. }
 542            | Self::ReconnectFailed { .. } => true,
 543            State::Connecting
 544            | State::Reconnecting
 545            | State::ReconnectExhausted
 546            | State::ServerNotRunning => false,
 547        }
 548    }
 549
 550    fn is_reconnect_failed(&self) -> bool {
 551        matches!(self, Self::ReconnectFailed { .. })
 552    }
 553
 554    fn is_reconnect_exhausted(&self) -> bool {
 555        matches!(self, Self::ReconnectExhausted { .. })
 556    }
 557
 558    fn is_server_not_running(&self) -> bool {
 559        matches!(self, Self::ServerNotRunning)
 560    }
 561
 562    fn is_reconnecting(&self) -> bool {
 563        matches!(self, Self::Reconnecting { .. })
 564    }
 565
 566    fn heartbeat_recovered(self) -> Self {
 567        match self {
 568            Self::HeartbeatMissed {
 569                ssh_connection,
 570                delegate,
 571                multiplex_task,
 572                heartbeat_task,
 573                ..
 574            } => Self::Connected {
 575                ssh_connection,
 576                delegate,
 577                multiplex_task,
 578                heartbeat_task,
 579            },
 580            _ => self,
 581        }
 582    }
 583
 584    fn heartbeat_missed(self) -> Self {
 585        match self {
 586            Self::Connected {
 587                ssh_connection,
 588                delegate,
 589                multiplex_task,
 590                heartbeat_task,
 591            } => Self::HeartbeatMissed {
 592                missed_heartbeats: 1,
 593                ssh_connection,
 594                delegate,
 595                multiplex_task,
 596                heartbeat_task,
 597            },
 598            Self::HeartbeatMissed {
 599                missed_heartbeats,
 600                ssh_connection,
 601                delegate,
 602                multiplex_task,
 603                heartbeat_task,
 604            } => Self::HeartbeatMissed {
 605                missed_heartbeats: missed_heartbeats + 1,
 606                ssh_connection,
 607                delegate,
 608                multiplex_task,
 609                heartbeat_task,
 610            },
 611            _ => self,
 612        }
 613    }
 614}
 615
/// The state of the ssh connection.
///
/// This is the public, data-free projection of the internal `State` enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is up, but at least one heartbeat was missed.
    HeartbeatMissed,
    /// A reconnect is in progress, or one failed and another attempt is still allowed.
    Reconnecting,
    /// Reconnect attempts are exhausted or the server is not running.
    Disconnected,
}
 625
 626impl From<&State> for ConnectionState {
 627    fn from(value: &State) -> Self {
 628        match value {
 629            State::Connecting => Self::Connecting,
 630            State::Connected { .. } => Self::Connected,
 631            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 632            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 633            State::ReconnectExhausted => Self::Disconnected,
 634            State::ServerNotRunning => Self::Disconnected,
 635        }
 636    }
 637}
 638
/// Client side of a connection to a remote server over SSH.
pub struct SshRemoteClient {
    /// Message channel to the remote; used e.g. for pings and the shutdown request.
    client: Arc<ChannelClient>,
    /// String form of the `ConnectionIdentifier`; passed to `start_proxy` on
    /// connect and again on reconnect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection at creation time.
    path_style: PathStyle,
    /// Current connection state; `None` once it has been taken, e.g. by
    /// `shutdown_processes`.
    state: Arc<Mutex<Option<State>>>,
}
 646
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): presumably emitted when the connection is lost for good —
    /// the emitting code is outside the visible part of the file.
    Disconnected,
}

// Lets `SshRemoteClient` entities emit `SshRemoteEvent`s.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 653
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Process-local id handed out by `ConnectionIdentifier::setup`.
    Setup(u64),
    /// Identifier derived from a workspace id.
    /// NOTE(review): where workspace ids originate is not visible here.
    Workspace(i64),
}
 660
/// Source of the ids handed out by `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 662
 663impl ConnectionIdentifier {
 664    pub fn setup() -> Self {
 665        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 666    }
 667    // This string gets used in a socket name, and so must be relatively short.
 668    // The total length of:
 669    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 670    // Must be less than about 100 characters
 671    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 672    // So our strings should be at most 20 characters or so.
 673    fn to_string(&self, cx: &App) -> String {
 674        let identifier_prefix = match ReleaseChannel::global(cx) {
 675            ReleaseChannel::Stable => "".to_string(),
 676            release_channel => format!("{}-", release_channel.dev_name()),
 677        };
 678        match self {
 679            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 680            Self::Workspace(workspace_id) => {
 681                format!("{identifier_prefix}workspace-{workspace_id}",)
 682            }
 683        }
 684    }
 685}
 686
 687impl SshRemoteClient {
    /// Establishes a connection and creates the `SshRemoteClient` entity for it.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(entity))` once the connection is up and the first ping succeeded;
    /// - `Ok(None)` if the `cancellation` receiver completes first;
    /// - `Err(_)` if connecting, starting the proxy or the first ping fails.
    ///
    /// `unique_identifier` is rendered to a string up front and passed to
    /// `start_proxy`; `delegate` is passed to the connection pool and the proxy.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Channels between the `ChannelClient` and the proxy task.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Handed to the proxy (sender) and the heartbeat task (receiver).
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Obtain the connection from the global pool, creating the
                // pool with its default value if it does not exist yet.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                // The entity starts out in `Connecting`; it is switched to
                // `Connected` only after the first ping succeeded.
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // NOTE(review): the `false` is presumably the "reconnect" flag —
                // `reconnect` passes `true` in the same position.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Verify the connection end-to-end before reporting success.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Race connection setup against cancellation; whichever completes
            // first determines the result.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 762
    /// Takes the current state and returns a future that shuts the connection down.
    ///
    /// Returns `None` when there is no state, or when the state is anything
    /// other than `State::Connected`.
    ///
    /// NOTE(review): the state is taken out of the mutex before it is
    /// inspected, so in the non-`Connected` case it is dropped here and the
    /// mutex is left holding `None` — confirm this is intended.
    ///
    /// The returned future optionally sends `shutdown_request`, waits 50ms,
    /// and then drops the tasks, the connection and the delegate in order.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                // Best effort: a send failure is only logged.
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
 801
 802    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 803        let mut lock = self.state.lock();
 804
 805        let can_reconnect = lock
 806            .as_ref()
 807            .map(|state| state.can_reconnect())
 808            .unwrap_or(false);
 809        if !can_reconnect {
 810            log::info!("aborting reconnect, because not in state that allows reconnecting");
 811            let error = if let Some(state) = lock.as_ref() {
 812                format!("invalid state, cannot reconnect while in state {state}")
 813            } else {
 814                "no state set".to_string()
 815            };
 816            anyhow::bail!(error);
 817        }
 818
 819        let state = lock.take().unwrap();
 820        let (attempts, ssh_connection, delegate) = match state {
 821            State::Connected {
 822                ssh_connection,
 823                delegate,
 824                multiplex_task,
 825                heartbeat_task,
 826            }
 827            | State::HeartbeatMissed {
 828                ssh_connection,
 829                delegate,
 830                multiplex_task,
 831                heartbeat_task,
 832                ..
 833            } => {
 834                drop(multiplex_task);
 835                drop(heartbeat_task);
 836                (0, ssh_connection, delegate)
 837            }
 838            State::ReconnectFailed {
 839                attempts,
 840                ssh_connection,
 841                delegate,
 842                ..
 843            } => (attempts, ssh_connection, delegate),
 844            State::Connecting
 845            | State::Reconnecting
 846            | State::ReconnectExhausted
 847            | State::ServerNotRunning => unreachable!(),
 848        };
 849
 850        let attempts = attempts + 1;
 851        if attempts > MAX_RECONNECT_ATTEMPTS {
 852            log::error!(
 853                "Failed to reconnect to after {} attempts, giving up",
 854                MAX_RECONNECT_ATTEMPTS
 855            );
 856            drop(lock);
 857            self.set_state(State::ReconnectExhausted, cx);
 858            return Ok(());
 859        }
 860        drop(lock);
 861
 862        self.set_state(State::Reconnecting, cx);
 863
 864        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 865
 866        let unique_identifier = self.unique_identifier.clone();
 867        let client = self.client.clone();
 868        let reconnect_task = cx.spawn(async move |this, cx| {
 869            macro_rules! failed {
 870                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 871                    return State::ReconnectFailed {
 872                        error: anyhow!($error),
 873                        attempts: $attempts,
 874                        ssh_connection: $ssh_connection,
 875                        delegate: $delegate,
 876                    };
 877                };
 878            }
 879
 880            if let Err(error) = ssh_connection
 881                .kill()
 882                .await
 883                .context("Failed to kill ssh process")
 884            {
 885                failed!(error, attempts, ssh_connection, delegate);
 886            };
 887
 888            let connection_options = ssh_connection.connection_options();
 889
 890            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 891            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 892            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 893
 894            let (ssh_connection, io_task) = match async {
 895                let ssh_connection = cx
 896                    .update_global(|pool: &mut ConnectionPool, cx| {
 897                        pool.connect(connection_options, &delegate, cx)
 898                    })?
 899                    .await
 900                    .map_err(|error| error.cloned())?;
 901
 902                let io_task = ssh_connection.start_proxy(
 903                    unique_identifier,
 904                    true,
 905                    incoming_tx,
 906                    outgoing_rx,
 907                    connection_activity_tx,
 908                    delegate.clone(),
 909                    cx,
 910                );
 911                anyhow::Ok((ssh_connection, io_task))
 912            }
 913            .await
 914            {
 915                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 916                Err(error) => {
 917                    failed!(error, attempts, ssh_connection, delegate);
 918                }
 919            };
 920
 921            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 922            client.reconnect(incoming_rx, outgoing_tx, &cx);
 923
 924            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 925                failed!(error, attempts, ssh_connection, delegate);
 926            };
 927
 928            State::Connected {
 929                ssh_connection,
 930                delegate,
 931                multiplex_task,
 932                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 933            }
 934        });
 935
 936        cx.spawn(async move |this, cx| {
 937            let new_state = reconnect_task.await;
 938            this.update(cx, |this, cx| {
 939                this.try_set_state(cx, |old_state| {
 940                    if old_state.is_reconnecting() {
 941                        match &new_state {
 942                            State::Connecting
 943                            | State::Reconnecting { .. }
 944                            | State::HeartbeatMissed { .. }
 945                            | State::ServerNotRunning => {}
 946                            State::Connected { .. } => {
 947                                log::info!("Successfully reconnected");
 948                            }
 949                            State::ReconnectFailed {
 950                                error, attempts, ..
 951                            } => {
 952                                log::error!(
 953                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 954                                    attempts,
 955                                    error
 956                                );
 957                            }
 958                            State::ReconnectExhausted => {
 959                                log::error!("Reconnect attempt failed and all attempts exhausted");
 960                            }
 961                        }
 962                        Some(new_state)
 963                    } else {
 964                        None
 965                    }
 966                });
 967
 968                if this.state_is(State::is_reconnect_failed) {
 969                    this.reconnect(cx)
 970                } else if this.state_is(State::is_reconnect_exhausted) {
 971                    Ok(())
 972                } else {
 973                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 974                    Ok(())
 975                }
 976            })
 977        })
 978        .detach_and_log_err(cx);
 979
 980        Ok(())
 981    }
 982
 983    fn heartbeat(
 984        this: WeakEntity<Self>,
 985        mut connection_activity_rx: mpsc::Receiver<()>,
 986        cx: &mut AsyncApp,
 987    ) -> Task<Result<()>> {
 988        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 989            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 990        };
 991
 992        cx.spawn(async move |cx| {
 993                let mut missed_heartbeats = 0;
 994
 995                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 996                futures::pin_mut!(keepalive_timer);
 997
 998                loop {
 999                    select_biased! {
1000                        result = connection_activity_rx.next().fuse() => {
1001                            if result.is_none() {
1002                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1003                                return Ok(());
1004                            }
1005
1006                            if missed_heartbeats != 0 {
1007                                missed_heartbeats = 0;
1008                                let _ =this.update(cx, |this, mut cx| {
1009                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1010                                })?;
1011                            }
1012                        }
1013                        _ = keepalive_timer => {
1014                            log::debug!("Sending heartbeat to server...");
1015
1016                            let result = select_biased! {
1017                                _ = connection_activity_rx.next().fuse() => {
1018                                    Ok(())
1019                                }
1020                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1021                                    ping_result
1022                                }
1023                            };
1024
1025                            if result.is_err() {
1026                                missed_heartbeats += 1;
1027                                log::warn!(
1028                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1029                                    HEARTBEAT_TIMEOUT,
1030                                    missed_heartbeats,
1031                                    MAX_MISSED_HEARTBEATS
1032                                );
1033                            } else if missed_heartbeats != 0 {
1034                                missed_heartbeats = 0;
1035                            } else {
1036                                continue;
1037                            }
1038
1039                            let result = this.update(cx, |this, mut cx| {
1040                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1041                            })?;
1042                            if result.is_break() {
1043                                return Ok(());
1044                            }
1045                        }
1046                    }
1047
1048                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1049                }
1050
1051        })
1052    }
1053
1054    fn handle_heartbeat_result(
1055        &mut self,
1056        missed_heartbeats: usize,
1057        cx: &mut Context<Self>,
1058    ) -> ControlFlow<()> {
1059        let state = self.state.lock().take().unwrap();
1060        let next_state = if missed_heartbeats > 0 {
1061            state.heartbeat_missed()
1062        } else {
1063            state.heartbeat_recovered()
1064        };
1065
1066        self.set_state(next_state, cx);
1067
1068        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1069            log::error!(
1070                "Missed last {} heartbeats. Reconnecting...",
1071                missed_heartbeats
1072            );
1073
1074            self.reconnect(cx)
1075                .context("failed to start reconnect process after missing heartbeats")
1076                .log_err();
1077            ControlFlow::Break(())
1078        } else {
1079            ControlFlow::Continue(())
1080        }
1081    }
1082
1083    fn monitor(
1084        this: WeakEntity<Self>,
1085        io_task: Task<Result<i32>>,
1086        cx: &AsyncApp,
1087    ) -> Task<Result<()>> {
1088        cx.spawn(async move |cx| {
1089            let result = io_task.await;
1090
1091            match result {
1092                Ok(exit_code) => {
1093                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1094                        match error {
1095                            ProxyLaunchError::ServerNotRunning => {
1096                                log::error!("failed to reconnect because server is not running");
1097                                this.update(cx, |this, cx| {
1098                                    this.set_state(State::ServerNotRunning, cx);
1099                                })?;
1100                            }
1101                        }
1102                    } else if exit_code > 0 {
1103                        log::error!("proxy process terminated unexpectedly");
1104                        this.update(cx, |this, cx| {
1105                            this.reconnect(cx).ok();
1106                        })?;
1107                    }
1108                }
1109                Err(error) => {
1110                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1111                    this.update(cx, |this, cx| {
1112                        this.reconnect(cx).ok();
1113                    })?;
1114                }
1115            }
1116
1117            Ok(())
1118        })
1119    }
1120
1121    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1122        self.state.lock().as_ref().map_or(false, check)
1123    }
1124
1125    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1126        let mut lock = self.state.lock();
1127        let new_state = lock.as_ref().and_then(map);
1128
1129        if let Some(new_state) = new_state {
1130            lock.replace(new_state);
1131            cx.notify();
1132        }
1133    }
1134
1135    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1136        log::info!("setting state to '{}'", &state);
1137
1138        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1139        let is_server_not_running = state.is_server_not_running();
1140        self.state.lock().replace(state);
1141
1142        if is_reconnect_exhausted || is_server_not_running {
1143            cx.emit(SshRemoteEvent::Disconnected);
1144        }
1145        cx.notify();
1146    }
1147
    /// Forwards to the underlying client's `subscribe_to_entity`, registering
    /// `entity` under `remote_id`.
    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
        self.client.subscribe_to_entity(remote_id, entity);
    }
1151
1152    pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1153        self.state
1154            .lock()
1155            .as_ref()
1156            .and_then(|state| state.ssh_connection())
1157            .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1158    }
1159
1160    pub fn upload_directory(
1161        &self,
1162        src_path: PathBuf,
1163        dest_path: RemotePathBuf,
1164        cx: &App,
1165    ) -> Task<Result<()>> {
1166        let state = self.state.lock();
1167        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1168            return Task::ready(Err(anyhow!("no ssh connection")));
1169        };
1170        connection.upload_directory(src_path, dest_path, cx)
1171    }
1172
    /// Returns a clone of the underlying client converted into an `AnyProtoClient`.
    pub fn proto_client(&self) -> AnyProtoClient {
        self.client.clone().into()
    }
1176
    /// Returns the connection string produced by this client's connection options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
1180
    /// Returns a copy of the options this client was configured with.
    pub fn connection_options(&self) -> SshConnectionOptions {
        self.connection_options.clone()
    }
1184
1185    pub fn connection_state(&self) -> ConnectionState {
1186        self.state
1187            .lock()
1188            .as_ref()
1189            .map(ConnectionState::from)
1190            .unwrap_or(ConnectionState::Disconnected)
1191    }
1192
1193    pub fn is_disconnected(&self) -> bool {
1194        self.connection_state() == ConnectionState::Disconnected
1195    }
1196
    /// Returns the path style stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
1200
1201    #[cfg(any(test, feature = "test-support"))]
1202    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1203        let opts = self.connection_options();
1204        client_cx.spawn(async move |cx| {
1205            let connection = cx
1206                .update_global(|c: &mut ConnectionPool, _| {
1207                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1208                        c.clone()
1209                    } else {
1210                        panic!("missing test connection")
1211                    }
1212                })
1213                .unwrap()
1214                .await
1215                .unwrap();
1216
1217            connection.simulate_disconnect(&cx);
1218        })
1219    }
1220
1221    #[cfg(any(test, feature = "test-support"))]
1222    pub fn fake_server(
1223        client_cx: &mut gpui::TestAppContext,
1224        server_cx: &mut gpui::TestAppContext,
1225    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1226        let port = client_cx
1227            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1228        let opts = SshConnectionOptions {
1229            host: "<fake>".to_string(),
1230            port: Some(port),
1231            ..Default::default()
1232        };
1233        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1234        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1235        let server_client =
1236            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1237        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1238            connection_options: opts.clone(),
1239            server_cx: fake::SendableCx::new(server_cx),
1240            server_channel: server_client.clone(),
1241        });
1242
1243        client_cx.update(|cx| {
1244            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1245                c.connections.insert(
1246                    opts.clone(),
1247                    ConnectionPoolEntry::Connecting(
1248                        cx.background_spawn({
1249                            let connection = connection.clone();
1250                            async move { Ok(connection.clone()) }
1251                        })
1252                        .shared(),
1253                    ),
1254                );
1255            })
1256        });
1257
1258        (opts, server_client)
1259    }
1260
1261    #[cfg(any(test, feature = "test-support"))]
1262    pub async fn fake_client(
1263        opts: SshConnectionOptions,
1264        client_cx: &mut gpui::TestAppContext,
1265    ) -> Entity<Self> {
1266        let (_tx, rx) = oneshot::channel();
1267        client_cx
1268            .update(|cx| {
1269                Self::new(
1270                    ConnectionIdentifier::setup(),
1271                    opts,
1272                    rx,
1273                    Arc::new(fake::Delegate),
1274                    cx,
1275                )
1276            })
1277            .await
1278            .unwrap()
1279            .unwrap()
1280    }
1281}
1282
/// One slot of the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by every
    /// caller that asks for the same options.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is kept, so the pool
    /// itself does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1287
/// Global registry of remote connections, keyed by their connection options.
#[derive(Default)]
struct ConnectionPool {
    /// In-flight and established connections per set of options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Allows the pool to be stored and looked up as a gpui global.
impl Global for ConnectionPool {}
1294
1295impl ConnectionPool {
1296    pub fn connect(
1297        &mut self,
1298        opts: SshConnectionOptions,
1299        delegate: &Arc<dyn SshClientDelegate>,
1300        cx: &mut App,
1301    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1302        let connection = self.connections.get(&opts);
1303        match connection {
1304            Some(ConnectionPoolEntry::Connecting(task)) => {
1305                let delegate = delegate.clone();
1306                cx.spawn(async move |cx| {
1307                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1308                })
1309                .detach();
1310                return task.clone();
1311            }
1312            Some(ConnectionPoolEntry::Connected(ssh)) => {
1313                if let Some(ssh) = ssh.upgrade() {
1314                    if !ssh.has_been_killed() {
1315                        return Task::ready(Ok(ssh)).shared();
1316                    }
1317                }
1318                self.connections.remove(&opts);
1319            }
1320            None => {}
1321        }
1322
1323        let task = cx
1324            .spawn({
1325                let opts = opts.clone();
1326                let delegate = delegate.clone();
1327                async move |cx| {
1328                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1329                        .await
1330                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1331
1332                    cx.update_global(|pool: &mut Self, _| {
1333                        debug_assert!(matches!(
1334                            pool.connections.get(&opts),
1335                            Some(ConnectionPoolEntry::Connecting(_))
1336                        ));
1337                        match connection {
1338                            Ok(connection) => {
1339                                pool.connections.insert(
1340                                    opts.clone(),
1341                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1342                                );
1343                                Ok(connection)
1344                            }
1345                            Err(error) => {
1346                                pool.connections.remove(&opts);
1347                                Err(Arc::new(error))
1348                            }
1349                        }
1350                    })?
1351                }
1352            })
1353            .shared();
1354
1355        self.connections
1356            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1357        task
1358    }
1359}
1360
1361impl From<SshRemoteClient> for AnyProtoClient {
1362    fn from(client: SshRemoteClient) -> Self {
1363        AnyProtoClient::new(client.client.clone())
1364    }
1365}
1366
/// Abstraction over a connection to a remote host, implemented by the real
/// ssh-backed connection and by a fake used in tests.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy and wires it to the given channels.
    ///
    /// `reconnect` indicates that the proxy should attach to an existing session.
    /// The returned task resolves to an `i32` that callers treat as the proxy's
    /// exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns `true` once the connection has been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further `ssh` commands over this connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
    /// reuse the already established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style used by the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1396
/// A connection to a remote host backed by a long-running master `ssh` process.
struct SshRemoteConnection {
    /// Used to build further ssh/scp commands for this connection.
    socket: SshSocket,
    /// The master `ssh` process; `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote host, set once the binary has been
    /// ensured during construction.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform.
    ssh_path_style: PathStyle,
    /// Temporary directory created for this session; kept here so it lives as long
    /// as the connection (on non-Windows hosts it contains the control socket).
    _temp_dir: TempDir,
}
1405
1406#[async_trait(?Send)]
1407impl RemoteConnection for SshRemoteConnection {
1408    async fn kill(&self) -> Result<()> {
1409        let Some(mut process) = self.master_process.lock().take() else {
1410            return Ok(());
1411        };
1412        process.kill().ok();
1413        process.status().await?;
1414        Ok(())
1415    }
1416
1417    fn has_been_killed(&self) -> bool {
1418        self.master_process.lock().is_none()
1419    }
1420
1421    fn ssh_args(&self) -> SshArgs {
1422        self.socket.ssh_args()
1423    }
1424
1425    fn connection_options(&self) -> SshConnectionOptions {
1426        self.socket.connection_options.clone()
1427    }
1428
1429    fn upload_directory(
1430        &self,
1431        src_path: PathBuf,
1432        dest_path: RemotePathBuf,
1433        cx: &App,
1434    ) -> Task<Result<()>> {
1435        let mut command = util::command::new_smol_command("scp");
1436        let output = self
1437            .socket
1438            .ssh_options(&mut command)
1439            .args(
1440                self.socket
1441                    .connection_options
1442                    .port
1443                    .map(|port| vec!["-P".to_string(), port.to_string()])
1444                    .unwrap_or_default(),
1445            )
1446            .arg("-C")
1447            .arg("-r")
1448            .arg(&src_path)
1449            .arg(format!(
1450                "{}:{}",
1451                self.socket.connection_options.scp_url(),
1452                dest_path.to_string()
1453            ))
1454            .output();
1455
1456        cx.background_spawn(async move {
1457            let output = output.await?;
1458
1459            anyhow::ensure!(
1460                output.status.success(),
1461                "failed to upload directory {} -> {}: {}",
1462                src_path.display(),
1463                dest_path.to_string(),
1464                String::from_utf8_lossy(&output.stderr)
1465            );
1466
1467            Ok(())
1468        })
1469    }
1470
1471    fn start_proxy(
1472        &self,
1473        unique_identifier: String,
1474        reconnect: bool,
1475        incoming_tx: UnboundedSender<Envelope>,
1476        outgoing_rx: UnboundedReceiver<Envelope>,
1477        connection_activity_tx: Sender<()>,
1478        delegate: Arc<dyn SshClientDelegate>,
1479        cx: &mut AsyncApp,
1480    ) -> Task<Result<i32>> {
1481        delegate.set_status(Some("Starting proxy"), cx);
1482
1483        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1484            return Task::ready(Err(anyhow!("Remote binary path not set")));
1485        };
1486
1487        let mut start_proxy_command = shell_script!(
1488            "exec {binary_path} proxy --identifier {identifier}",
1489            binary_path = &remote_binary_path.to_string(),
1490            identifier = &unique_identifier,
1491        );
1492
1493        if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1494            start_proxy_command = format!(
1495                "RUST_LOG={} {}",
1496                shlex::try_quote(&rust_log).unwrap(),
1497                start_proxy_command
1498            )
1499        }
1500        if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1501            start_proxy_command = format!(
1502                "RUST_BACKTRACE={} {}",
1503                shlex::try_quote(&rust_backtrace).unwrap(),
1504                start_proxy_command
1505            )
1506        }
1507        if reconnect {
1508            start_proxy_command.push_str(" --reconnect");
1509        }
1510
1511        let ssh_proxy_process = match self
1512            .socket
1513            .ssh_command("sh", &["-c", &start_proxy_command])
1514            // IMPORTANT: we kill this process when we drop the task that uses it.
1515            .kill_on_drop(true)
1516            .spawn()
1517        {
1518            Ok(process) => process,
1519            Err(error) => {
1520                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1521            }
1522        };
1523
1524        Self::multiplex(
1525            ssh_proxy_process,
1526            incoming_tx,
1527            outgoing_rx,
1528            connection_activity_tx,
1529            &cx,
1530        )
1531    }
1532
1533    fn path_style(&self) -> PathStyle {
1534        self.ssh_path_style
1535    }
1536}
1537
1538impl SshRemoteConnection {
    /// Establishes a new ssh connection described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. create a temporary directory and an askpass session that forwards
    ///    password prompts to `delegate`;
    /// 2. spawn the master `ssh` process (with control-socket options on
    ///    non-Windows hosts);
    /// 3. wait until either the askpass session ends (cancel/timeout) or the
    ///    master process closes its stdout;
    /// 4. fail with the process's stderr if it has already exited;
    /// 5. build the `SshSocket`, detect the remote platform and path style;
    /// 6. ensure the server binary exists on the remote and remember its path.
    ///
    /// # Errors
    ///
    /// Returns an error when any of the steps above fails, when the user cancels
    /// the password prompt, or when connecting times out.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts from ssh are routed to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument added below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // NOTE(review): the `bail!` calls below appear to return from `new` directly,
        // so `result` looks like it can only ever be `Ok(())` and the
        // "Failed to connect to host" context further down is never attached — confirm
        // whether that context was meant to wrap these errors.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            // The outcome of the read itself is ignored; only its completion matters.
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, report its stderr as the error.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        // Anything that is not reported as "windows" is treated as POSIX.
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        // `remote_binary_path` is filled in below, once the server binary is in place.
        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1669
1670    fn multiplex(
1671        mut ssh_proxy_process: Child,
1672        incoming_tx: UnboundedSender<Envelope>,
1673        mut outgoing_rx: UnboundedReceiver<Envelope>,
1674        mut connection_activity_tx: Sender<()>,
1675        cx: &AsyncApp,
1676    ) -> Task<Result<i32>> {
1677        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1678        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1679        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1680
1681        let mut stdin_buffer = Vec::new();
1682        let mut stdout_buffer = Vec::new();
1683        let mut stderr_buffer = Vec::new();
1684        let mut stderr_offset = 0;
1685
1686        let stdin_task = cx.background_spawn(async move {
1687            while let Some(outgoing) = outgoing_rx.next().await {
1688                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1689            }
1690            anyhow::Ok(())
1691        });
1692
1693        let stdout_task = cx.background_spawn({
1694            let mut connection_activity_tx = connection_activity_tx.clone();
1695            async move {
1696                loop {
1697                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1698                    let len = child_stdout.read(&mut stdout_buffer).await?;
1699
1700                    if len == 0 {
1701                        return anyhow::Ok(());
1702                    }
1703
1704                    if len < MESSAGE_LEN_SIZE {
1705                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1706                    }
1707
1708                    let message_len = message_len_from_buffer(&stdout_buffer);
1709                    let envelope =
1710                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1711                            .await?;
1712                    connection_activity_tx.try_send(()).ok();
1713                    incoming_tx.unbounded_send(envelope).ok();
1714                }
1715            }
1716        });
1717
1718        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1719            loop {
1720                stderr_buffer.resize(stderr_offset + 1024, 0);
1721
1722                let len = child_stderr
1723                    .read(&mut stderr_buffer[stderr_offset..])
1724                    .await?;
1725                if len == 0 {
1726                    return anyhow::Ok(());
1727                }
1728
1729                stderr_offset += len;
1730                let mut start_ix = 0;
1731                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1732                    .iter()
1733                    .position(|b| b == &b'\n')
1734                {
1735                    let line_ix = start_ix + ix;
1736                    let content = &stderr_buffer[start_ix..line_ix];
1737                    start_ix = line_ix + 1;
1738                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1739                        record.log(log::logger())
1740                    } else {
1741                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1742                    }
1743                }
1744                stderr_buffer.drain(0..start_ix);
1745                stderr_offset -= start_ix;
1746
1747                connection_activity_tx.try_send(()).ok();
1748            }
1749        });
1750
1751        cx.background_spawn(async move {
1752            let result = futures::select! {
1753                result = stdin_task.fuse() => {
1754                    result.context("stdin")
1755                }
1756                result = stdout_task.fuse() => {
1757                    result.context("stdout")
1758                }
1759                result = stderr_task.fuse() => {
1760                    result.context("stderr")
1761                }
1762            };
1763
1764            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1765            match result {
1766                Ok(_) => Ok(status),
1767                Err(error) => Err(error),
1768            }
1769        })
1770    }
1771
1772    #[allow(unused)]
1773    async fn ensure_server_binary(
1774        &self,
1775        delegate: &Arc<dyn SshClientDelegate>,
1776        release_channel: ReleaseChannel,
1777        version: SemanticVersion,
1778        commit: Option<AppCommitSha>,
1779        cx: &mut AsyncApp,
1780    ) -> Result<RemotePathBuf> {
1781        let version_str = match release_channel {
1782            ReleaseChannel::Nightly => {
1783                let commit = commit.map(|s| s.full()).unwrap_or_default();
1784                format!("{}-{}", version, commit)
1785            }
1786            ReleaseChannel::Dev => "build".to_string(),
1787            _ => version.to_string(),
1788        };
1789        let binary_name = format!(
1790            "zed-remote-server-{}-{}",
1791            release_channel.dev_name(),
1792            version_str
1793        );
1794        let dst_path = RemotePathBuf::new(
1795            paths::remote_server_dir_relative().join(binary_name),
1796            self.ssh_path_style,
1797        );
1798
1799        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1800        #[cfg(debug_assertions)]
1801        if let Some(build_remote_server) = build_remote_server {
1802            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1803            let tmp_path = RemotePathBuf::new(
1804                paths::remote_server_dir_relative().join(format!(
1805                    "download-{}-{}",
1806                    std::process::id(),
1807                    src_path.file_name().unwrap().to_string_lossy()
1808                )),
1809                self.ssh_path_style,
1810            );
1811            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1812                .await?;
1813            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1814                .await?;
1815            return Ok(dst_path);
1816        }
1817
1818        if self
1819            .socket
1820            .run_command(&dst_path.to_string(), &["version"])
1821            .await
1822            .is_ok()
1823        {
1824            return Ok(dst_path);
1825        }
1826
1827        let wanted_version = cx.update(|cx| match release_channel {
1828            ReleaseChannel::Nightly => Ok(None),
1829            ReleaseChannel::Dev => {
1830                anyhow::bail!(
1831                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1832                    dst_path
1833                )
1834            }
1835            _ => Ok(Some(AppVersion::global(cx))),
1836        })??;
1837
1838        let tmp_path_gz = RemotePathBuf::new(
1839            PathBuf::from(format!(
1840                "{}-download-{}.gz",
1841                dst_path.to_string(),
1842                std::process::id()
1843            )),
1844            self.ssh_path_style,
1845        );
1846        if !self.socket.connection_options.upload_binary_over_ssh {
1847            if let Some((url, body)) = delegate
1848                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1849                .await?
1850            {
1851                match self
1852                    .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1853                    .await
1854                {
1855                    Ok(_) => {
1856                        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1857                            .await?;
1858                        return Ok(dst_path);
1859                    }
1860                    Err(e) => {
1861                        log::error!(
1862                            "Failed to download binary on server, attempting to upload server: {}",
1863                            e
1864                        )
1865                    }
1866                }
1867            }
1868        }
1869
1870        let src_path = delegate
1871            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1872            .await?;
1873        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1874            .await?;
1875        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1876            .await?;
1877        return Ok(dst_path);
1878    }
1879
1880    async fn download_binary_on_server(
1881        &self,
1882        url: &str,
1883        body: &str,
1884        tmp_path_gz: &RemotePathBuf,
1885        delegate: &Arc<dyn SshClientDelegate>,
1886        cx: &mut AsyncApp,
1887    ) -> Result<()> {
1888        if let Some(parent) = tmp_path_gz.parent() {
1889            self.socket
1890                .run_command(
1891                    "sh",
1892                    &[
1893                        "-c",
1894                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1895                    ],
1896                )
1897                .await?;
1898        }
1899
1900        delegate.set_status(Some("Downloading remote development server on host"), cx);
1901
1902        match self
1903            .socket
1904            .run_command(
1905                "curl",
1906                &[
1907                    "-f",
1908                    "-L",
1909                    "-X",
1910                    "GET",
1911                    "-H",
1912                    "Content-Type: application/json",
1913                    "-d",
1914                    &body,
1915                    &url,
1916                    "-o",
1917                    &tmp_path_gz.to_string(),
1918                ],
1919            )
1920            .await
1921        {
1922            Ok(_) => {}
1923            Err(e) => {
1924                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1925                    return Err(e);
1926                }
1927
1928                match self
1929                    .socket
1930                    .run_command(
1931                        "wget",
1932                        &[
1933                            "--method=GET",
1934                            "--header=Content-Type: application/json",
1935                            "--body-data",
1936                            &body,
1937                            &url,
1938                            "-O",
1939                            &tmp_path_gz.to_string(),
1940                        ],
1941                    )
1942                    .await
1943                {
1944                    Ok(_) => {}
1945                    Err(e) => {
1946                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1947                            return Err(e);
1948                        } else {
1949                            anyhow::bail!("Neither curl nor wget is available");
1950                        }
1951                    }
1952                }
1953            }
1954        }
1955
1956        Ok(())
1957    }
1958
1959    async fn upload_local_server_binary(
1960        &self,
1961        src_path: &Path,
1962        tmp_path_gz: &RemotePathBuf,
1963        delegate: &Arc<dyn SshClientDelegate>,
1964        cx: &mut AsyncApp,
1965    ) -> Result<()> {
1966        if let Some(parent) = tmp_path_gz.parent() {
1967            self.socket
1968                .run_command(
1969                    "sh",
1970                    &[
1971                        "-c",
1972                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1973                    ],
1974                )
1975                .await?;
1976        }
1977
1978        let src_stat = fs::metadata(&src_path).await?;
1979        let size = src_stat.len();
1980
1981        let t0 = Instant::now();
1982        delegate.set_status(Some("Uploading remote development server"), cx);
1983        log::info!(
1984            "uploading remote development server to {:?} ({}kb)",
1985            tmp_path_gz,
1986            size / 1024
1987        );
1988        self.upload_file(&src_path, &tmp_path_gz)
1989            .await
1990            .context("failed to upload server binary")?;
1991        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1992        Ok(())
1993    }
1994
1995    async fn extract_server_binary(
1996        &self,
1997        dst_path: &RemotePathBuf,
1998        tmp_path: &RemotePathBuf,
1999        delegate: &Arc<dyn SshClientDelegate>,
2000        cx: &mut AsyncApp,
2001    ) -> Result<()> {
2002        delegate.set_status(Some("Extracting remote development server"), cx);
2003        let server_mode = 0o755;
2004
2005        let orig_tmp_path = tmp_path.to_string();
2006        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2007            shell_script!(
2008                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2009                server_mode = &format!("{:o}", server_mode),
2010                dst_path = &dst_path.to_string(),
2011            )
2012        } else {
2013            shell_script!(
2014                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2015                server_mode = &format!("{:o}", server_mode),
2016                dst_path = &dst_path.to_string()
2017            )
2018        };
2019        self.socket.run_command("sh", &["-c", &script]).await?;
2020        Ok(())
2021    }
2022
2023    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2024        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2025        let mut command = util::command::new_smol_command("scp");
2026        let output = self
2027            .socket
2028            .ssh_options(&mut command)
2029            .args(
2030                self.socket
2031                    .connection_options
2032                    .port
2033                    .map(|port| vec!["-P".to_string(), port.to_string()])
2034                    .unwrap_or_default(),
2035            )
2036            .arg(src_path)
2037            .arg(format!(
2038                "{}:{}",
2039                self.socket.connection_options.scp_url(),
2040                dest_path.to_string()
2041            ))
2042            .output()
2043            .await?;
2044
2045        anyhow::ensure!(
2046            output.status.success(),
2047            "failed to upload file {} -> {}: {}",
2048            src_path.display(),
2049            dest_path.to_string(),
2050            String::from_utf8_lossy(&output.stderr)
2051        );
2052        Ok(())
2053    }
2054
    /// Builds the `remote_server` package on this machine for the remote platform
    /// and returns the path of the resulting artifact (debug builds only).
    ///
    /// `build_remote_server` is scanned for substrings that act as options:
    /// * `nomusl`: on Linux targets, use the `unknown-linux-gnu` triple instead of
    ///   `unknown-linux-musl` (and do not add `+crt-static`).
    /// * `mold`: append `-C link-arg=-fuse-ld=mold` to `RUSTFLAGS`.
    /// * `cross`: when the target differs from the host, build with `cross`
    ///   instead of `cargo zigbuild`.
    /// * `nocompress`: return the raw binary instead of a gzip archive.
    ///
    /// When the remote arch and OS equal the host's, a plain `cargo build` is used.
    ///
    /// The returned path is `target/remote_server/<triple>/debug/remote_server`.
    /// With compression it gets a `.gz` extension and is joined onto the current
    /// directory; without compression it is returned as the relative path.
    ///
    /// # Errors
    ///
    /// Fails for remote operating systems other than `linux` and `macos`, when a
    /// required tool (`zig`, or `7z.exe` on Windows) is missing, or when any of
    /// the spawned commands exits unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion with stderr inherited from this process and
        // turns an unsuccessful exit status into an error.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // musl is the default for Linux targets; "nomusl" opts out.
        let use_musl = !build_remote_server.contains("nomusl");
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS (empty when unset or unreadable) and
        // append the flags required by the selected options.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        // Same arch and OS as the host: a regular cargo build is enough.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Different target: use `cross` when requested, `cargo zigbuild` otherwise.
            if build_remote_server.contains("cross") {
                #[cfg(target_os = "windows")]
                use util::paths::SanitizedPath;

                delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
                log::info!("installing cross");
                run_cmd(Command::new("cargo").args([
                    "install",
                    "cross",
                    "--git",
                    "https://github.com/cross-rs/cross",
                ]))
                .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote server binary from source for {} with Docker",
                        &triple
                    )),
                    cx,
                );
                log::info!("building remote server binary from source for {}", &triple);

                // On Windows, the binding needs to be set to the canonical path
                #[cfg(target_os = "windows")]
                let src =
                    SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
                #[cfg(not(target_os = "windows"))]
                let src = "./target";
                run_cmd(
                    Command::new("cross")
                        .args([
                            "build",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env(
                            "CROSS_CONTAINER_OPTS",
                            format!("--mount type=bind,src={src},dst=/app/target"),
                        )
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            } else {
                // Look up `zig` on a background thread; the error text below differs per host OS.
                let which = cx
                    .background_spawn(async move { which::which("zig") })
                    .await;

                if which.is_err() {
                    #[cfg(not(target_os = "windows"))]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                    #[cfg(target_os = "windows")]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                }

                delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
                log::info!("adding rustup target");
                run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

                delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
                log::info!("installing cargo-zigbuild");
                run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
                    .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote binary from source for {triple} with Zig"
                    )),
                    cx,
                );
                log::info!("building remote binary from source for {triple} with Zig");
                run_cmd(
                    Command::new("cargo")
                        .args([
                            "zigbuild",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            }
        };
        // All three build variants write to the same `--target-dir`.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-9", "-f", &bin_path.to_string_lossy()]))
                    .await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove an archive left over from an earlier build before creating a new one.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            // The archive sits next to the binary, with a `.gz` extension.
            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2273}
2274
/// Requests that are still waiting for a response, keyed by the request's message id.
///
/// Each entry holds the sender used to deliver the response envelope. The
/// envelope is sent together with a second oneshot sender; the message loop
/// waits on the matching receiver before it processes the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2276
/// Client that exchanges `Envelope`s with a peer over a pair of unbounded channels.
pub struct ChannelClient {
    /// Counter from which ids for outgoing envelopes are taken (`fetch_add`).
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes kept for re-sending: entries are dropped once an incoming
    /// `ack_id` covers them, and are replayed on `FlushBufferedMessages` and in
    /// `resync`.
    // NOTE(review): presumably filled by the buffered send path, which is not
    // part of this section; confirm.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests that are still waiting for their response.
    response_channels: ResponseChannels,
    /// Handlers that incoming envelopes which are not responses are dispatched to.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope, not counting
    /// `FlushBufferedMessages` envelopes.
    max_received: AtomicU32,
    /// Label used as a prefix in this client's log messages.
    name: &'static str,
    /// The task running the incoming-message loop; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
2287
2288impl ChannelClient {
2289    pub fn new(
2290        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2291        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2292        cx: &App,
2293        name: &'static str,
2294    ) -> Arc<Self> {
2295        Arc::new_cyclic(|this| Self {
2296            outgoing_tx: Mutex::new(outgoing_tx),
2297            next_message_id: AtomicU32::new(0),
2298            max_received: AtomicU32::new(0),
2299            response_channels: ResponseChannels::default(),
2300            message_handlers: Default::default(),
2301            buffer: Mutex::new(VecDeque::new()),
2302            name,
2303            task: Mutex::new(Self::start_handling_messages(
2304                this.clone(),
2305                incoming_rx,
2306                &cx.to_async(),
2307            )),
2308        })
2309    }
2310
2311    fn start_handling_messages(
2312        this: Weak<Self>,
2313        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2314        cx: &AsyncApp,
2315    ) -> Task<Result<()>> {
2316        cx.spawn(async move |cx| {
2317            let peer_id = PeerId { owner_id: 0, id: 0 };
2318            while let Some(incoming) = incoming_rx.next().await {
2319                let Some(this) = this.upgrade() else {
2320                    return anyhow::Ok(());
2321                };
2322                if let Some(ack_id) = incoming.ack_id {
2323                    let mut buffer = this.buffer.lock();
2324                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2325                        buffer.pop_front();
2326                    }
2327                }
2328                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2329                {
2330                    log::debug!(
2331                        "{}:ssh message received. name:FlushBufferedMessages",
2332                        this.name
2333                    );
2334                    {
2335                        let buffer = this.buffer.lock();
2336                        for envelope in buffer.iter() {
2337                            this.outgoing_tx
2338                                .lock()
2339                                .unbounded_send(envelope.clone())
2340                                .ok();
2341                        }
2342                    }
2343                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2344                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2345                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2346                    continue;
2347                }
2348
2349                this.max_received.store(incoming.id, SeqCst);
2350
2351                if let Some(request_id) = incoming.responding_to {
2352                    let request_id = MessageId(request_id);
2353                    let sender = this.response_channels.lock().remove(&request_id);
2354                    if let Some(sender) = sender {
2355                        let (tx, rx) = oneshot::channel();
2356                        if incoming.payload.is_some() {
2357                            sender.send((incoming, tx)).ok();
2358                        }
2359                        rx.await.ok();
2360                    }
2361                } else if let Some(envelope) =
2362                    build_typed_envelope(peer_id, Instant::now(), incoming)
2363                {
2364                    let type_name = envelope.payload_type_name();
2365                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2366                        &this.message_handlers,
2367                        envelope,
2368                        this.clone().into(),
2369                        cx.clone(),
2370                    ) {
2371                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2372                        cx.foreground_executor()
2373                            .spawn(async move {
2374                                match future.await {
2375                                    Ok(_) => {
2376                                        log::debug!(
2377                                            "{}:ssh message handled. name:{type_name}",
2378                                            this.name
2379                                        );
2380                                    }
2381                                    Err(error) => {
2382                                        log::error!(
2383                                            "{}:error handling message. type:{}, error:{}",
2384                                            this.name,
2385                                            type_name,
2386                                            format!("{error:#}").lines().fold(
2387                                                String::new(),
2388                                                |mut message, line| {
2389                                                    if !message.is_empty() {
2390                                                        message.push(' ');
2391                                                    }
2392                                                    message.push_str(line);
2393                                                    message
2394                                                }
2395                                            )
2396                                        );
2397                                    }
2398                                }
2399                            })
2400                            .detach()
2401                    } else {
2402                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2403                    }
2404                }
2405            }
2406            anyhow::Ok(())
2407        })
2408    }
2409
2410    pub fn reconnect(
2411        self: &Arc<Self>,
2412        incoming_rx: UnboundedReceiver<Envelope>,
2413        outgoing_tx: UnboundedSender<Envelope>,
2414        cx: &AsyncApp,
2415    ) {
2416        *self.outgoing_tx.lock() = outgoing_tx;
2417        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2418    }
2419
2420    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2421        let id = (TypeId::of::<E>(), remote_id);
2422
2423        let mut message_handlers = self.message_handlers.lock();
2424        if message_handlers
2425            .entities_by_type_and_remote_id
2426            .contains_key(&id)
2427        {
2428            panic!("already subscribed to entity");
2429        }
2430
2431        message_handlers.entities_by_type_and_remote_id.insert(
2432            id,
2433            EntityMessageSubscriber::Entity {
2434                handle: entity.downgrade().into(),
2435            },
2436        );
2437    }
2438
    /// Sends `payload` as a request and returns a future that resolves to the
    /// typed response.
    ///
    /// Delegates to `request_internal` with `use_buffer` set to `true`.
    pub fn request<T: RequestMessage>(
        &self,
        payload: T,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        self.request_internal(payload, true)
    }
2445
2446    fn request_internal<T: RequestMessage>(
2447        &self,
2448        payload: T,
2449        use_buffer: bool,
2450    ) -> impl 'static + Future<Output = Result<T::Response>> {
2451        log::debug!("ssh request start. name:{}", T::NAME);
2452        let response =
2453            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2454        async move {
2455            let response = response.await?;
2456            log::debug!("ssh request finish. name:{}", T::NAME);
2457            T::Response::from_envelope(response).context("received a response of the wrong type")
2458        }
2459    }
2460
2461    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2462        smol::future::or(
2463            async {
2464                self.request_internal(proto::FlushBufferedMessages {}, false)
2465                    .await?;
2466
2467                for envelope in self.buffer.lock().iter() {
2468                    self.outgoing_tx
2469                        .lock()
2470                        .unbounded_send(envelope.clone())
2471                        .ok();
2472                }
2473                Ok(())
2474            },
2475            async {
2476                smol::Timer::after(timeout).await;
2477                anyhow::bail!("Timeout detected")
2478            },
2479        )
2480        .await
2481    }
2482
2483    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2484        smol::future::or(
2485            async {
2486                self.request(proto::Ping {}).await?;
2487                Ok(())
2488            },
2489            async {
2490                smol::Timer::after(timeout).await;
2491                anyhow::bail!("Timeout detected")
2492            },
2493        )
2494        .await
2495    }
2496
2497    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2498        log::debug!("ssh send name:{}", T::NAME);
2499        self.send_dynamic(payload.into_envelope(0, None, None))
2500    }
2501
2502    fn request_dynamic(
2503        &self,
2504        mut envelope: proto::Envelope,
2505        type_name: &'static str,
2506        use_buffer: bool,
2507    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2508        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2509        let (tx, rx) = oneshot::channel();
2510        let mut response_channels_lock = self.response_channels.lock();
2511        response_channels_lock.insert(MessageId(envelope.id), tx);
2512        drop(response_channels_lock);
2513
2514        let result = if use_buffer {
2515            self.send_buffered(envelope)
2516        } else {
2517            self.send_unbuffered(envelope)
2518        };
2519        async move {
2520            if let Err(error) = &result {
2521                log::error!("failed to send message: {error}");
2522                anyhow::bail!("failed to send message: {error}");
2523            }
2524
2525            let response = rx.await.context("connection lost")?.0;
2526            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2527                return Err(RpcError::from_proto(error, type_name));
2528            }
2529            Ok(response)
2530        }
2531    }
2532
2533    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2534        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2535        self.send_buffered(envelope)
2536    }
2537
2538    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2539        envelope.ack_id = Some(self.max_received.load(SeqCst));
2540        self.buffer.lock().push_back(envelope.clone());
2541        // ignore errors on send (happen while we're reconnecting)
2542        // assume that the global "disconnected" overlay is sufficient.
2543        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2544        Ok(())
2545    }
2546
2547    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2548        envelope.ack_id = Some(self.max_received.load(SeqCst));
2549        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2550        Ok(())
2551    }
2552}
2553
2554impl ProtoClient for ChannelClient {
2555    fn request(
2556        &self,
2557        envelope: proto::Envelope,
2558        request_type: &'static str,
2559    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2560        self.request_dynamic(envelope, request_type, true).boxed()
2561    }
2562
2563    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2564        self.send_dynamic(envelope)
2565    }
2566
2567    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2568        self.send_dynamic(envelope)
2569    }
2570
2571    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2572        &self.message_handlers
2573    }
2574
2575    fn is_via_collab(&self) -> bool {
2576        false
2577    }
2578}
2579
2580#[cfg(any(test, feature = "test-support"))]
2581mod fake {
2582    use std::{path::PathBuf, sync::Arc};
2583
2584    use anyhow::Result;
2585    use async_trait::async_trait;
2586    use futures::{
2587        FutureExt, SinkExt, StreamExt,
2588        channel::{
2589            mpsc::{self, Sender},
2590            oneshot,
2591        },
2592        select_biased,
2593    };
2594    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
2595    use release_channel::ReleaseChannel;
2596    use rpc::proto::Envelope;
2597    use util::paths::{PathStyle, RemotePathBuf};
2598
2599    use super::{
2600        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
2601        SshPlatform,
2602    };
2603
2604    pub(super) struct FakeRemoteConnection {
2605        pub(super) connection_options: SshConnectionOptions,
2606        pub(super) server_channel: Arc<ChannelClient>,
2607        pub(super) server_cx: SendableCx,
2608    }
2609
2610    pub(super) struct SendableCx(AsyncApp);
2611    impl SendableCx {
2612        // SAFETY: When run in test mode, GPUI is always single threaded.
2613        pub(super) fn new(cx: &TestAppContext) -> Self {
2614            Self(cx.to_async())
2615        }
2616
2617        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
2618        fn get(&self, _: &AsyncApp) -> AsyncApp {
2619            self.0.clone()
2620        }
2621    }
2622
2623    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
2624    unsafe impl Send for SendableCx {}
2625    unsafe impl Sync for SendableCx {}
2626
2627    #[async_trait(?Send)]
2628    impl RemoteConnection for FakeRemoteConnection {
2629        async fn kill(&self) -> Result<()> {
2630            Ok(())
2631        }
2632
2633        fn has_been_killed(&self) -> bool {
2634            false
2635        }
2636
2637        fn ssh_args(&self) -> SshArgs {
2638            SshArgs {
2639                arguments: Vec::new(),
2640                envs: None,
2641            }
2642        }
2643
2644        fn upload_directory(
2645            &self,
2646            _src_path: PathBuf,
2647            _dest_path: RemotePathBuf,
2648            _cx: &App,
2649        ) -> Task<Result<()>> {
2650            unreachable!()
2651        }
2652
2653        fn connection_options(&self) -> SshConnectionOptions {
2654            self.connection_options.clone()
2655        }
2656
2657        fn simulate_disconnect(&self, cx: &AsyncApp) {
2658            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
2659            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
2660            self.server_channel
2661                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
2662        }
2663
2664        fn start_proxy(
2665            &self,
2666            _unique_identifier: String,
2667            _reconnect: bool,
2668            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
2669            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
2670            mut connection_activity_tx: Sender<()>,
2671            _delegate: Arc<dyn SshClientDelegate>,
2672            cx: &mut AsyncApp,
2673        ) -> Task<Result<i32>> {
2674            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
2675            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
2676
2677            self.server_channel.reconnect(
2678                server_incoming_rx,
2679                server_outgoing_tx,
2680                &self.server_cx.get(cx),
2681            );
2682
2683            cx.background_spawn(async move {
2684                loop {
2685                    select_biased! {
2686                        server_to_client = server_outgoing_rx.next().fuse() => {
2687                            let Some(server_to_client) = server_to_client else {
2688                                return Ok(1)
2689                            };
2690                            connection_activity_tx.try_send(()).ok();
2691                            client_incoming_tx.send(server_to_client).await.ok();
2692                        }
2693                        client_to_server = client_outgoing_rx.next().fuse() => {
2694                            let Some(client_to_server) = client_to_server else {
2695                                return Ok(1)
2696                            };
2697                            server_incoming_tx.send(client_to_server).await.ok();
2698                        }
2699                    }
2700                }
2701            })
2702        }
2703
2704        fn path_style(&self) -> PathStyle {
2705            PathStyle::current()
2706        }
2707    }
2708
2709    pub(super) struct Delegate;
2710
2711    impl SshClientDelegate for Delegate {
2712        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
2713            unreachable!()
2714        }
2715
2716        fn download_server_binary_locally(
2717            &self,
2718            _: SshPlatform,
2719            _: ReleaseChannel,
2720            _: Option<SemanticVersion>,
2721            _: &mut AsyncApp,
2722        ) -> Task<Result<PathBuf>> {
2723            unreachable!()
2724        }
2725
2726        fn get_download_params(
2727            &self,
2728            _platform: SshPlatform,
2729            _release_channel: ReleaseChannel,
2730            _version: Option<SemanticVersion>,
2731            _cx: &mut AsyncApp,
2732        ) -> Task<Result<Option<(String, String)>>> {
2733            unreachable!()
2734        }
2735
2736        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
2737    }
2738}