ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  13    channel::{
  14        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  15        oneshot,
  16    },
  17    future::{BoxFuture, Shared},
  18    select, select_biased,
  19};
  20use gpui::{
  21    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  22    EventEmitter, Global, SemanticVersion, Task, WeakEntity,
  23};
  24use itertools::Itertools;
  25use parking_lot::Mutex;
  26
  27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  28use rpc::{
  29    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  30    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  31};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use smol::{
  35    fs,
  36    process::{self, Child, Stdio},
  37};
  38use std::{
  39    collections::VecDeque,
  40    fmt, iter,
  41    ops::ControlFlow,
  42    path::{Path, PathBuf},
  43    sync::{
  44        Arc, Weak,
  45        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  46    },
  47    time::{Duration, Instant},
  48};
  49use tempfile::TempDir;
  50use util::{
  51    ResultExt,
  52    paths::{PathStyle, RemotePathBuf},
  53};
  54
/// Everything needed to spawn `ssh` processes for a single connection.
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    /// Control socket handed to ssh as `ControlPath` on non-Windows platforms.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment (askpass configuration) applied to every ssh process on Windows.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
  63
// A single local port forward, equivalent to ssh's `-L` option.
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// `///` doc comments would alter the generated schema descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    // Local bind address; rendered as `localhost` when unset
    // (see `SshConnectionOptions::additional_args`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    pub local_port: u16,
    // Destination host as reached from the remote side; rendered as `localhost`
    // when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    pub remote_port: u16,
}
  73
/// Everything needed to reach a host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    /// Password, if known. `parse_command_line` never fills this in.
    pub password: Option<String>,
    /// Extra ssh flags passed through verbatim (see `additional_args`).
    pub args: Option<Vec<String>>,
    /// Local port forwards, rendered as `-L` flags by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// NOTE(review): presumably a user-chosen display name for the host — confirm.
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary from the local
    /// machine instead of downloading it on the remote — confirm at the use site.
    pub upload_binary_over_ssh: bool,
}
  86
/// Command-line arguments and optional environment for spawning an ssh process,
/// as produced by `SshSocket::ssh_args`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshArgs {
    pub arguments: Vec<String>,
    /// Extra environment variables; `None` when none are required.
    pub envs: Option<HashMap<String, String>>,
}
  92
/// ssh invocation details bundled with facts about the remote host.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshInfo {
    pub args: SshArgs,
    /// Path conventions used by the remote host.
    pub path_style: PathStyle,
    /// The remote user's shell (presumably as reported by `SshSocket::shell` — confirm).
    pub shell: String,
}
  99
/// Formats a shell snippet, shell-quoting every named argument with
/// `shlex::try_quote` before substitution.
///
/// Example: `shell_script!("cd {dir} && ls {file}", dir = &dir, file = &file)`.
///
/// # Panics
///
/// Panics if `shlex::try_quote` rejects an argument (e.g. one containing a NUL byte).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
 111
 112fn parse_port_number(port_str: &str) -> Result<u16> {
 113    port_str
 114        .parse()
 115        .with_context(|| format!("parsing port number: {port_str}"))
 116}
 117
 118fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
 119    let parts: Vec<&str> = spec.split(':').collect();
 120
 121    match parts.len() {
 122        4 => {
 123            let local_port = parse_port_number(parts[1])?;
 124            let remote_port = parse_port_number(parts[3])?;
 125
 126            Ok(SshPortForwardOption {
 127                local_host: Some(parts[0].to_string()),
 128                local_port,
 129                remote_host: Some(parts[2].to_string()),
 130                remote_port,
 131            })
 132        }
 133        3 => {
 134            let local_port = parse_port_number(parts[0])?;
 135            let remote_port = parse_port_number(parts[2])?;
 136
 137            Ok(SshPortForwardOption {
 138                local_host: None,
 139                local_port,
 140                remote_host: Some(parts[1].to_string()),
 141                remote_port,
 142            })
 143        }
 144        _ => anyhow::bail!("Invalid port forward format"),
 145    }
 146}
 147
 148impl SshConnectionOptions {
 149    pub fn parse_command_line(input: &str) -> Result<Self> {
 150        let input = input.trim_start_matches("ssh ");
 151        let mut hostname: Option<String> = None;
 152        let mut username: Option<String> = None;
 153        let mut port: Option<u16> = None;
 154        let mut args = Vec::new();
 155        let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
 156
 157        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
 158        const ALLOWED_OPTS: &[&str] = &[
 159            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
 160        ];
 161        const ALLOWED_ARGS: &[&str] = &[
 162            "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
 163            "-w",
 164        ];
 165
 166        let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
 167
 168        'outer: while let Some(arg) = tokens.next() {
 169            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 170                args.push(arg.to_string());
 171                continue;
 172            }
 173            if arg == "-p" {
 174                port = tokens.next().and_then(|arg| arg.parse().ok());
 175                continue;
 176            } else if let Some(p) = arg.strip_prefix("-p") {
 177                port = p.parse().ok();
 178                continue;
 179            }
 180            if arg == "-l" {
 181                username = tokens.next();
 182                continue;
 183            } else if let Some(l) = arg.strip_prefix("-l") {
 184                username = Some(l.to_string());
 185                continue;
 186            }
 187            if arg == "-L" || arg.starts_with("-L") {
 188                let forward_spec = if arg == "-L" {
 189                    tokens.next()
 190                } else {
 191                    Some(arg.strip_prefix("-L").unwrap().to_string())
 192                };
 193
 194                if let Some(spec) = forward_spec {
 195                    port_forwards.push(parse_port_forward_spec(&spec)?);
 196                } else {
 197                    anyhow::bail!("Missing port forward format");
 198                }
 199            }
 200
 201            for a in ALLOWED_ARGS {
 202                if arg == *a {
 203                    args.push(arg);
 204                    if let Some(next) = tokens.next() {
 205                        args.push(next);
 206                    }
 207                    continue 'outer;
 208                } else if arg.starts_with(a) {
 209                    args.push(arg);
 210                    continue 'outer;
 211                }
 212            }
 213            if arg.starts_with("-") || hostname.is_some() {
 214                anyhow::bail!("unsupported argument: {:?}", arg);
 215            }
 216            let mut input = &arg as &str;
 217            // Destination might be: username1@username2@ip2@ip1
 218            if let Some((u, rest)) = input.rsplit_once('@') {
 219                input = rest;
 220                username = Some(u.to_string());
 221            }
 222            if let Some((rest, p)) = input.split_once(':') {
 223                input = rest;
 224                port = p.parse().ok()
 225            }
 226            hostname = Some(input.to_string())
 227        }
 228
 229        let Some(hostname) = hostname else {
 230            anyhow::bail!("missing hostname");
 231        };
 232
 233        let port_forwards = match port_forwards.len() {
 234            0 => None,
 235            _ => Some(port_forwards),
 236        };
 237
 238        Ok(Self {
 239            host: hostname,
 240            username,
 241            port,
 242            port_forwards,
 243            args: Some(args),
 244            password: None,
 245            nickname: None,
 246            upload_binary_over_ssh: false,
 247        })
 248    }
 249
 250    pub fn ssh_url(&self) -> String {
 251        let mut result = String::from("ssh://");
 252        if let Some(username) = &self.username {
 253            // Username might be: username1@username2@ip2
 254            let username = urlencoding::encode(username);
 255            result.push_str(&username);
 256            result.push('@');
 257        }
 258        result.push_str(&self.host);
 259        if let Some(port) = self.port {
 260            result.push(':');
 261            result.push_str(&port.to_string());
 262        }
 263        result
 264    }
 265
 266    pub fn additional_args(&self) -> Vec<String> {
 267        let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
 268
 269        if let Some(forwards) = &self.port_forwards {
 270            args.extend(forwards.iter().map(|pf| {
 271                let local_host = match &pf.local_host {
 272                    Some(host) => host,
 273                    None => "localhost",
 274                };
 275                let remote_host = match &pf.remote_host {
 276                    Some(host) => host,
 277                    None => "localhost",
 278                };
 279
 280                format!(
 281                    "-L{}:{}:{}:{}",
 282                    local_host, pf.local_port, remote_host, pf.remote_port
 283                )
 284            }));
 285        }
 286
 287        args
 288    }
 289
 290    fn scp_url(&self) -> String {
 291        if let Some(username) = &self.username {
 292            format!("{}@{}", username, self.host)
 293        } else {
 294            self.host.clone()
 295        }
 296    }
 297
 298    pub fn connection_string(&self) -> String {
 299        let host = if let Some(username) = &self.username {
 300            format!("{}@{}", username, self.host)
 301        } else {
 302            self.host.clone()
 303        };
 304        if let Some(port) = &self.port {
 305            format!("{}:{}", host, port)
 306        } else {
 307            host
 308        }
 309    }
 310}
 311
/// Remote operating system and CPU architecture, named after the prebuilt server
/// binaries. `SshSocket::platform` produces `"macos"`/`"linux"` for `os` and
/// `"aarch64"`/`"x86_64"` for `arch`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
 317
/// Callbacks through which the SSH connection code asks the embedding application
/// for user input, server binaries, and status reporting.
pub trait SshClientDelegate: Send + Sync {
    /// Asks the user for a password in response to `prompt`. The answer is
    /// presumably delivered by sending it on `tx` — confirm with implementors.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for the remote server binary matching
    /// `platform`, `release_channel` and (optionally) `version`.
    ///
    /// NOTE(review): the meaning of the two returned strings and of `None` is
    /// defined by implementors — confirm before relying on it.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Downloads the server binary for the given target to the local machine and
    /// resolves to its local path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a connection status message. `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
 337
 338impl SshSocket {
 339    #[cfg(not(target_os = "windows"))]
 340    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
 341        Ok(Self {
 342            connection_options: options,
 343            socket_path,
 344        })
 345    }
 346
 347    #[cfg(target_os = "windows")]
 348    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
 349        let askpass_script = temp_dir.path().join("askpass.bat");
 350        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
 351        let mut envs = HashMap::default();
 352        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
 353        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
 354        envs.insert("ZED_SSH_ASKPASS".into(), secret);
 355        Ok(Self {
 356            connection_options: options,
 357            envs,
 358        })
 359    }
 360
 361    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 362    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 363    // and passes -l as an argument to sh, not to ls.
 364    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 365    // into a machine. You must use `cd` to get back to $HOME.
 366    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 367    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 368        let mut command = util::command::new_smol_command("ssh");
 369        let to_run = iter::once(&program)
 370            .chain(args.iter())
 371            .map(|token| {
 372                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 373                debug_assert!(
 374                    !token.contains('\n'),
 375                    "multiline arguments do not work in all shells"
 376                );
 377                shlex::try_quote(token).unwrap()
 378            })
 379            .join(" ");
 380        let to_run = format!("cd; {to_run}");
 381        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 382        self.ssh_options(&mut command)
 383            .arg(self.connection_options.ssh_url())
 384            .arg(to_run);
 385        command
 386    }
 387
 388    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 389        let output = self.ssh_command(program, args).output().await?;
 390        anyhow::ensure!(
 391            output.status.success(),
 392            "failed to run command: {}",
 393            String::from_utf8_lossy(&output.stderr)
 394        );
 395        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 396    }
 397
 398    #[cfg(not(target_os = "windows"))]
 399    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 400        command
 401            .stdin(Stdio::piped())
 402            .stdout(Stdio::piped())
 403            .stderr(Stdio::piped())
 404            .args(self.connection_options.additional_args())
 405            .args(["-o", "ControlMaster=no", "-o"])
 406            .arg(format!("ControlPath={}", self.socket_path.display()))
 407    }
 408
 409    #[cfg(target_os = "windows")]
 410    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 411        command
 412            .stdin(Stdio::piped())
 413            .stdout(Stdio::piped())
 414            .stderr(Stdio::piped())
 415            .args(self.connection_options.additional_args())
 416            .envs(self.envs.clone())
 417    }
 418
 419    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
 420    // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
 421    #[cfg(not(target_os = "windows"))]
 422    fn ssh_args(&self) -> SshArgs {
 423        let mut arguments = self.connection_options.additional_args();
 424        arguments.extend(vec![
 425            "-o".to_string(),
 426            "ControlMaster=no".to_string(),
 427            "-o".to_string(),
 428            format!("ControlPath={}", self.socket_path.display()),
 429            self.connection_options.ssh_url(),
 430        ]);
 431        SshArgs {
 432            arguments,
 433            envs: None,
 434        }
 435    }
 436
 437    #[cfg(target_os = "windows")]
 438    fn ssh_args(&self) -> SshArgs {
 439        let mut arguments = self.connection_options.additional_args();
 440        arguments.push(self.connection_options.ssh_url());
 441        SshArgs {
 442            arguments,
 443            envs: Some(self.envs.clone()),
 444        }
 445    }
 446
 447    async fn platform(&self) -> Result<SshPlatform> {
 448        let uname = self.run_command("sh", &["-lc", "uname -sm"]).await?;
 449        let Some((os, arch)) = uname.split_once(" ") else {
 450            anyhow::bail!("unknown uname: {uname:?}")
 451        };
 452
 453        let os = match os.trim() {
 454            "Darwin" => "macos",
 455            "Linux" => "linux",
 456            _ => anyhow::bail!(
 457                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 458            ),
 459        };
 460        // exclude armv5,6,7 as they are 32-bit.
 461        let arch = if arch.starts_with("armv8")
 462            || arch.starts_with("armv9")
 463            || arch.starts_with("arm64")
 464            || arch.starts_with("aarch64")
 465        {
 466            "aarch64"
 467        } else if arch.starts_with("x86") {
 468            "x86_64"
 469        } else {
 470            anyhow::bail!(
 471                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 472            )
 473        };
 474
 475        Ok(SshPlatform { os, arch })
 476    }
 477
 478    async fn shell(&self) -> String {
 479        match self.run_command("sh", &["-lc", "echo $SHELL"]).await {
 480            Ok(shell) => shell.trim().to_owned(),
 481            Err(e) => {
 482                log::error!("Failed to get shell: {e}");
 483                "sh".to_owned()
 484            }
 485        }
 486    }
 487}
 488
 489const MAX_MISSED_HEARTBEATS: usize = 5;
 490const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 491const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 492
 493const MAX_RECONNECT_ATTEMPTS: usize = 3;
 494
/// Internal lifecycle of an `SshRemoteClient`'s connection.
enum State {
    /// Initial state while `SshRemoteClient::new` is setting up the connection.
    Connecting,
    /// The connection is up; its I/O (multiplex) and heartbeat tasks are running.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but `missed_heartbeats` heartbeats have been missed in a
    /// row. The count is incremented by `heartbeat_missed`; the state returns to
    /// `Connected` via `heartbeat_recovered`.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The most recent reconnect attempt failed with `error`. `attempts` counts
    /// the attempts made so far, and the old connection is kept so `reconnect`
    /// can retry from it.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// `reconnect` gave up after exceeding `MAX_RECONNECT_ATTEMPTS`.
    ReconnectExhausted,
    /// NOTE(review): presumably set when the remote server is detected as not
    /// running — confirm at the call site.
    ServerNotRunning,
}
 524
 525impl fmt::Display for State {
 526    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 527        match self {
 528            Self::Connecting => write!(f, "connecting"),
 529            Self::Connected { .. } => write!(f, "connected"),
 530            Self::Reconnecting => write!(f, "reconnecting"),
 531            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 532            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 533            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 534            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 535        }
 536    }
 537}
 538
 539impl State {
 540    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 541        match self {
 542            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 543            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 544            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 545            _ => None,
 546        }
 547    }
 548
 549    fn can_reconnect(&self) -> bool {
 550        match self {
 551            Self::Connected { .. }
 552            | Self::HeartbeatMissed { .. }
 553            | Self::ReconnectFailed { .. } => true,
 554            State::Connecting
 555            | State::Reconnecting
 556            | State::ReconnectExhausted
 557            | State::ServerNotRunning => false,
 558        }
 559    }
 560
 561    fn is_reconnect_failed(&self) -> bool {
 562        matches!(self, Self::ReconnectFailed { .. })
 563    }
 564
 565    fn is_reconnect_exhausted(&self) -> bool {
 566        matches!(self, Self::ReconnectExhausted { .. })
 567    }
 568
 569    fn is_server_not_running(&self) -> bool {
 570        matches!(self, Self::ServerNotRunning)
 571    }
 572
 573    fn is_reconnecting(&self) -> bool {
 574        matches!(self, Self::Reconnecting { .. })
 575    }
 576
 577    fn heartbeat_recovered(self) -> Self {
 578        match self {
 579            Self::HeartbeatMissed {
 580                ssh_connection,
 581                delegate,
 582                multiplex_task,
 583                heartbeat_task,
 584                ..
 585            } => Self::Connected {
 586                ssh_connection,
 587                delegate,
 588                multiplex_task,
 589                heartbeat_task,
 590            },
 591            _ => self,
 592        }
 593    }
 594
 595    fn heartbeat_missed(self) -> Self {
 596        match self {
 597            Self::Connected {
 598                ssh_connection,
 599                delegate,
 600                multiplex_task,
 601                heartbeat_task,
 602            } => Self::HeartbeatMissed {
 603                missed_heartbeats: 1,
 604                ssh_connection,
 605                delegate,
 606                multiplex_task,
 607                heartbeat_task,
 608            },
 609            Self::HeartbeatMissed {
 610                missed_heartbeats,
 611                ssh_connection,
 612                delegate,
 613                multiplex_task,
 614                heartbeat_task,
 615            } => Self::HeartbeatMissed {
 616                missed_heartbeats: missed_heartbeats + 1,
 617                ssh_connection,
 618                delegate,
 619                multiplex_task,
 620                heartbeat_task,
 621            },
 622            _ => self,
 623        }
 624    }
 625}
 626
/// The state of the ssh connection.
///
/// This is a coarser, public view of the internal `State`:
/// - a failed reconnect attempt still reports `Reconnecting`;
/// - `Disconnected` covers both exhausted reconnects and a server that is not
///   running.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    Reconnecting,
    Disconnected,
}
 636
 637impl From<&State> for ConnectionState {
 638    fn from(value: &State) -> Self {
 639        match value {
 640            State::Connecting => Self::Connecting,
 641            State::Connected { .. } => Self::Connected,
 642            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 643            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 644            State::ReconnectExhausted => Self::Disconnected,
 645            State::ServerNotRunning => Self::Disconnected,
 646        }
 647    }
 648}
 649
/// A client for a remote server reached over SSH.
pub struct SshRemoteClient {
    /// Message channel to the remote server.
    client: Arc<ChannelClient>,
    /// Identifier passed to `start_proxy`, built by `ConnectionIdentifier::to_string`.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Path conventions reported by the connection.
    path_style: PathStyle,
    /// Current lifecycle state. `None` after it has been taken (e.g. by
    /// `shutdown_processes`).
    state: Arc<Mutex<Option<State>>>,
}
 657
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): presumably emitted when the connection is lost for good —
    /// confirm at the emit site.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 664
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Identifier drawn from a process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    /// Identifier tied to a workspace id (presumably the persisted workspace id —
    /// confirm with callers).
    Workspace(i64),
}
 671
 672static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 673
 674impl ConnectionIdentifier {
 675    pub fn setup() -> Self {
 676        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 677    }
 678
 679    // This string gets used in a socket name, and so must be relatively short.
 680    // The total length of:
 681    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 682    // Must be less than about 100 characters
 683    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 684    // So our strings should be at most 20 characters or so.
 685    fn to_string(&self, cx: &App) -> String {
 686        let identifier_prefix = match ReleaseChannel::global(cx) {
 687            ReleaseChannel::Stable => "".to_string(),
 688            release_channel => format!("{}-", release_channel.dev_name()),
 689        };
 690        match self {
 691            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 692            Self::Workspace(workspace_id) => {
 693                format!("{identifier_prefix}workspace-{workspace_id}",)
 694            }
 695        }
 696    }
 697}
 698
 699impl SshRemoteClient {
    /// Establishes a new SSH-backed remote client.
    ///
    /// The steps are:
    /// 1. Obtain a connection from the global `ConnectionPool`.
    /// 2. Create the client entity in the `Connecting` state and start the proxy
    ///    over the connection.
    /// 3. Ping the remote server (bounded by `HEARTBEAT_TIMEOUT`).
    /// 4. Start the heartbeat task and move the state to `Connected`.
    ///
    /// Resolves to `Ok(None)` if `cancellation` resolves first (a value is sent or
    /// its sender is dropped). Otherwise it resolves to `Ok(Some(entity))` on
    /// success.
    ///
    /// # Errors
    ///
    /// Fails if updating the app fails, if connecting through the pool fails, or if
    /// the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Activity signal from the proxy to the heartbeat task (capacity 1).
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                // Verify the server answers before declaring the connection usable.
                // On failure, `multiplex_task` and `ssh_connection` are dropped here.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Race setup against cancellation; cancelling yields `Ok(None)`.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 774
 775    pub fn proto_client_from_channels(
 776        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 777        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 778        cx: &App,
 779        name: &'static str,
 780    ) -> AnyProtoClient {
 781        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 782    }
 783
 784    pub fn shutdown_processes<T: RequestMessage>(
 785        &self,
 786        shutdown_request: Option<T>,
 787        executor: BackgroundExecutor,
 788    ) -> Option<impl Future<Output = ()> + use<T>> {
 789        let state = self.state.lock().take()?;
 790        log::info!("shutting down ssh processes");
 791
 792        let State::Connected {
 793            multiplex_task,
 794            heartbeat_task,
 795            ssh_connection,
 796            delegate,
 797        } = state
 798        else {
 799            return None;
 800        };
 801
 802        let client = self.client.clone();
 803
 804        Some(async move {
 805            if let Some(shutdown_request) = shutdown_request {
 806                client.send(shutdown_request).log_err();
 807                // We wait 50ms instead of waiting for a response, because
 808                // waiting for a response would require us to wait on the main thread
 809                // which we want to avoid in an `on_app_quit` callback.
 810                executor.timer(Duration::from_millis(50)).await;
 811            }
 812
 813            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 814            // child of master_process.
 815            drop(multiplex_task);
 816            // Now drop the rest of state, which kills master process.
 817            drop(heartbeat_task);
 818            drop(ssh_connection);
 819            drop(delegate);
 820        })
 821    }
 822
 823    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 824        let mut lock = self.state.lock();
 825
 826        let can_reconnect = lock
 827            .as_ref()
 828            .map(|state| state.can_reconnect())
 829            .unwrap_or(false);
 830        if !can_reconnect {
 831            log::info!("aborting reconnect, because not in state that allows reconnecting");
 832            let error = if let Some(state) = lock.as_ref() {
 833                format!("invalid state, cannot reconnect while in state {state}")
 834            } else {
 835                "no state set".to_string()
 836            };
 837            anyhow::bail!(error);
 838        }
 839
 840        let state = lock.take().unwrap();
 841        let (attempts, ssh_connection, delegate) = match state {
 842            State::Connected {
 843                ssh_connection,
 844                delegate,
 845                multiplex_task,
 846                heartbeat_task,
 847            }
 848            | State::HeartbeatMissed {
 849                ssh_connection,
 850                delegate,
 851                multiplex_task,
 852                heartbeat_task,
 853                ..
 854            } => {
 855                drop(multiplex_task);
 856                drop(heartbeat_task);
 857                (0, ssh_connection, delegate)
 858            }
 859            State::ReconnectFailed {
 860                attempts,
 861                ssh_connection,
 862                delegate,
 863                ..
 864            } => (attempts, ssh_connection, delegate),
 865            State::Connecting
 866            | State::Reconnecting
 867            | State::ReconnectExhausted
 868            | State::ServerNotRunning => unreachable!(),
 869        };
 870
 871        let attempts = attempts + 1;
 872        if attempts > MAX_RECONNECT_ATTEMPTS {
 873            log::error!(
 874                "Failed to reconnect to after {} attempts, giving up",
 875                MAX_RECONNECT_ATTEMPTS
 876            );
 877            drop(lock);
 878            self.set_state(State::ReconnectExhausted, cx);
 879            return Ok(());
 880        }
 881        drop(lock);
 882
 883        self.set_state(State::Reconnecting, cx);
 884
 885        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 886
 887        let unique_identifier = self.unique_identifier.clone();
 888        let client = self.client.clone();
 889        let reconnect_task = cx.spawn(async move |this, cx| {
 890            macro_rules! failed {
 891                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 892                    return State::ReconnectFailed {
 893                        error: anyhow!($error),
 894                        attempts: $attempts,
 895                        ssh_connection: $ssh_connection,
 896                        delegate: $delegate,
 897                    };
 898                };
 899            }
 900
 901            if let Err(error) = ssh_connection
 902                .kill()
 903                .await
 904                .context("Failed to kill ssh process")
 905            {
 906                failed!(error, attempts, ssh_connection, delegate);
 907            };
 908
 909            let connection_options = ssh_connection.connection_options();
 910
 911            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 912            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 913            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 914
 915            let (ssh_connection, io_task) = match async {
 916                let ssh_connection = cx
 917                    .update_global(|pool: &mut ConnectionPool, cx| {
 918                        pool.connect(connection_options, &delegate, cx)
 919                    })?
 920                    .await
 921                    .map_err(|error| error.cloned())?;
 922
 923                let io_task = ssh_connection.start_proxy(
 924                    unique_identifier,
 925                    true,
 926                    incoming_tx,
 927                    outgoing_rx,
 928                    connection_activity_tx,
 929                    delegate.clone(),
 930                    cx,
 931                );
 932                anyhow::Ok((ssh_connection, io_task))
 933            }
 934            .await
 935            {
 936                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 937                Err(error) => {
 938                    failed!(error, attempts, ssh_connection, delegate);
 939                }
 940            };
 941
 942            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 943            client.reconnect(incoming_rx, outgoing_tx, cx);
 944
 945            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 946                failed!(error, attempts, ssh_connection, delegate);
 947            };
 948
 949            State::Connected {
 950                ssh_connection,
 951                delegate,
 952                multiplex_task,
 953                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 954            }
 955        });
 956
 957        cx.spawn(async move |this, cx| {
 958            let new_state = reconnect_task.await;
 959            this.update(cx, |this, cx| {
 960                this.try_set_state(cx, |old_state| {
 961                    if old_state.is_reconnecting() {
 962                        match &new_state {
 963                            State::Connecting
 964                            | State::Reconnecting
 965                            | State::HeartbeatMissed { .. }
 966                            | State::ServerNotRunning => {}
 967                            State::Connected { .. } => {
 968                                log::info!("Successfully reconnected");
 969                            }
 970                            State::ReconnectFailed {
 971                                error, attempts, ..
 972                            } => {
 973                                log::error!(
 974                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 975                                    attempts,
 976                                    error
 977                                );
 978                            }
 979                            State::ReconnectExhausted => {
 980                                log::error!("Reconnect attempt failed and all attempts exhausted");
 981                            }
 982                        }
 983                        Some(new_state)
 984                    } else {
 985                        None
 986                    }
 987                });
 988
 989                if this.state_is(State::is_reconnect_failed) {
 990                    this.reconnect(cx)
 991                } else if this.state_is(State::is_reconnect_exhausted) {
 992                    Ok(())
 993                } else {
 994                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 995                    Ok(())
 996                }
 997            })
 998        })
 999        .detach_and_log_err(cx);
1000
1001        Ok(())
1002    }
1003
1004    fn heartbeat(
1005        this: WeakEntity<Self>,
1006        mut connection_activity_rx: mpsc::Receiver<()>,
1007        cx: &mut AsyncApp,
1008    ) -> Task<Result<()>> {
1009        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
1010            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
1011        };
1012
1013        cx.spawn(async move |cx| {
1014            let mut missed_heartbeats = 0;
1015
1016            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1017            futures::pin_mut!(keepalive_timer);
1018
1019            loop {
1020                select_biased! {
1021                    result = connection_activity_rx.next().fuse() => {
1022                        if result.is_none() {
1023                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1024                            return Ok(());
1025                        }
1026
1027                        if missed_heartbeats != 0 {
1028                            missed_heartbeats = 0;
1029                            let _ =this.update(cx, |this, cx| {
1030                                this.handle_heartbeat_result(missed_heartbeats, cx)
1031                            })?;
1032                        }
1033                    }
1034                    _ = keepalive_timer => {
1035                        log::debug!("Sending heartbeat to server...");
1036
1037                        let result = select_biased! {
1038                            _ = connection_activity_rx.next().fuse() => {
1039                                Ok(())
1040                            }
1041                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1042                                ping_result
1043                            }
1044                        };
1045
1046                        if result.is_err() {
1047                            missed_heartbeats += 1;
1048                            log::warn!(
1049                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1050                                HEARTBEAT_TIMEOUT,
1051                                missed_heartbeats,
1052                                MAX_MISSED_HEARTBEATS
1053                            );
1054                        } else if missed_heartbeats != 0 {
1055                            missed_heartbeats = 0;
1056                        } else {
1057                            continue;
1058                        }
1059
1060                        let result = this.update(cx, |this, cx| {
1061                            this.handle_heartbeat_result(missed_heartbeats, cx)
1062                        })?;
1063                        if result.is_break() {
1064                            return Ok(());
1065                        }
1066                    }
1067                }
1068
1069                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1070            }
1071        })
1072    }
1073
1074    fn handle_heartbeat_result(
1075        &mut self,
1076        missed_heartbeats: usize,
1077        cx: &mut Context<Self>,
1078    ) -> ControlFlow<()> {
1079        let state = self.state.lock().take().unwrap();
1080        let next_state = if missed_heartbeats > 0 {
1081            state.heartbeat_missed()
1082        } else {
1083            state.heartbeat_recovered()
1084        };
1085
1086        self.set_state(next_state, cx);
1087
1088        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1089            log::error!(
1090                "Missed last {} heartbeats. Reconnecting...",
1091                missed_heartbeats
1092            );
1093
1094            self.reconnect(cx)
1095                .context("failed to start reconnect process after missing heartbeats")
1096                .log_err();
1097            ControlFlow::Break(())
1098        } else {
1099            ControlFlow::Continue(())
1100        }
1101    }
1102
1103    fn monitor(
1104        this: WeakEntity<Self>,
1105        io_task: Task<Result<i32>>,
1106        cx: &AsyncApp,
1107    ) -> Task<Result<()>> {
1108        cx.spawn(async move |cx| {
1109            let result = io_task.await;
1110
1111            match result {
1112                Ok(exit_code) => {
1113                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1114                        match error {
1115                            ProxyLaunchError::ServerNotRunning => {
1116                                log::error!("failed to reconnect because server is not running");
1117                                this.update(cx, |this, cx| {
1118                                    this.set_state(State::ServerNotRunning, cx);
1119                                })?;
1120                            }
1121                        }
1122                    } else if exit_code > 0 {
1123                        log::error!("proxy process terminated unexpectedly");
1124                        this.update(cx, |this, cx| {
1125                            this.reconnect(cx).ok();
1126                        })?;
1127                    }
1128                }
1129                Err(error) => {
1130                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1131                    this.update(cx, |this, cx| {
1132                        this.reconnect(cx).ok();
1133                    })?;
1134                }
1135            }
1136
1137            Ok(())
1138        })
1139    }
1140
1141    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1142        self.state.lock().as_ref().is_some_and(check)
1143    }
1144
1145    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1146        let mut lock = self.state.lock();
1147        let new_state = lock.as_ref().and_then(map);
1148
1149        if let Some(new_state) = new_state {
1150            lock.replace(new_state);
1151            cx.notify();
1152        }
1153    }
1154
1155    fn set_state(&self, state: State, cx: &mut Context<Self>) {
1156        log::info!("setting state to '{}'", &state);
1157
1158        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1159        let is_server_not_running = state.is_server_not_running();
1160        self.state.lock().replace(state);
1161
1162        if is_reconnect_exhausted || is_server_not_running {
1163            cx.emit(SshRemoteEvent::Disconnected);
1164        }
1165        cx.notify();
1166    }
1167
1168    pub fn ssh_info(&self) -> Option<SshInfo> {
1169        self.state
1170            .lock()
1171            .as_ref()
1172            .and_then(|state| state.ssh_connection())
1173            .map(|ssh_connection| SshInfo {
1174                args: ssh_connection.ssh_args(),
1175                path_style: ssh_connection.path_style(),
1176                shell: ssh_connection.shell(),
1177            })
1178    }
1179
1180    pub fn upload_directory(
1181        &self,
1182        src_path: PathBuf,
1183        dest_path: RemotePathBuf,
1184        cx: &App,
1185    ) -> Task<Result<()>> {
1186        let state = self.state.lock();
1187        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1188            return Task::ready(Err(anyhow!("no ssh connection")));
1189        };
1190        connection.upload_directory(src_path, dest_path, cx)
1191    }
1192
1193    pub fn proto_client(&self) -> AnyProtoClient {
1194        self.client.clone().into()
1195    }
1196
1197    pub fn connection_string(&self) -> String {
1198        self.connection_options.connection_string()
1199    }
1200
1201    pub fn connection_options(&self) -> SshConnectionOptions {
1202        self.connection_options.clone()
1203    }
1204
1205    pub fn connection_state(&self) -> ConnectionState {
1206        self.state
1207            .lock()
1208            .as_ref()
1209            .map(ConnectionState::from)
1210            .unwrap_or(ConnectionState::Disconnected)
1211    }
1212
1213    pub fn is_disconnected(&self) -> bool {
1214        self.connection_state() == ConnectionState::Disconnected
1215    }
1216
1217    pub fn path_style(&self) -> PathStyle {
1218        self.path_style
1219    }
1220
1221    #[cfg(any(test, feature = "test-support"))]
1222    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1223        let opts = self.connection_options();
1224        client_cx.spawn(async move |cx| {
1225            let connection = cx
1226                .update_global(|c: &mut ConnectionPool, _| {
1227                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1228                        c.clone()
1229                    } else {
1230                        panic!("missing test connection")
1231                    }
1232                })
1233                .unwrap()
1234                .await
1235                .unwrap();
1236
1237            connection.simulate_disconnect(cx);
1238        })
1239    }
1240
1241    #[cfg(any(test, feature = "test-support"))]
1242    pub fn fake_server(
1243        client_cx: &mut gpui::TestAppContext,
1244        server_cx: &mut gpui::TestAppContext,
1245    ) -> (SshConnectionOptions, AnyProtoClient) {
1246        let port = client_cx
1247            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1248        let opts = SshConnectionOptions {
1249            host: "<fake>".to_string(),
1250            port: Some(port),
1251            ..Default::default()
1252        };
1253        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1254        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1255        let server_client =
1256            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1257        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1258            connection_options: opts.clone(),
1259            server_cx: fake::SendableCx::new(server_cx),
1260            server_channel: server_client.clone(),
1261        });
1262
1263        client_cx.update(|cx| {
1264            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1265                c.connections.insert(
1266                    opts.clone(),
1267                    ConnectionPoolEntry::Connecting(
1268                        cx.background_spawn({
1269                            let connection = connection.clone();
1270                            async move { Ok(connection.clone()) }
1271                        })
1272                        .shared(),
1273                    ),
1274                );
1275            })
1276        });
1277
1278        (opts, server_client.into())
1279    }
1280
1281    #[cfg(any(test, feature = "test-support"))]
1282    pub async fn fake_client(
1283        opts: SshConnectionOptions,
1284        client_cx: &mut gpui::TestAppContext,
1285    ) -> Entity<Self> {
1286        let (_tx, rx) = oneshot::channel();
1287        client_cx
1288            .update(|cx| {
1289                Self::new(
1290                    ConnectionIdentifier::setup(),
1291                    opts,
1292                    rx,
1293                    Arc::new(fake::Delegate),
1294                    cx,
1295                )
1296            })
1297            .await
1298            .unwrap()
1299            .unwrap()
1300    }
1301}
1302
/// A [`ConnectionPool`] entry for one set of [`SshConnectionOptions`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. Concurrent callers of `ConnectionPool::connect` get a
    /// clone of this shared task instead of starting a second attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held only weakly so the pool never keeps it alive by itself.
    /// Stale or killed entries are discarded on the next `connect`.
    Connected(Weak<dyn RemoteConnection>),
}
1307
/// App-global registry that lets callers share one ssh connection per set of
/// [`SshConnectionOptions`].
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Stored as a gpui global; accessed through `update_global` / `default_global`.
impl Global for ConnectionPool {}
1314
1315impl ConnectionPool {
1316    pub fn connect(
1317        &mut self,
1318        opts: SshConnectionOptions,
1319        delegate: &Arc<dyn SshClientDelegate>,
1320        cx: &mut App,
1321    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1322        let connection = self.connections.get(&opts);
1323        match connection {
1324            Some(ConnectionPoolEntry::Connecting(task)) => {
1325                let delegate = delegate.clone();
1326                cx.spawn(async move |cx| {
1327                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1328                })
1329                .detach();
1330                return task.clone();
1331            }
1332            Some(ConnectionPoolEntry::Connected(ssh)) => {
1333                if let Some(ssh) = ssh.upgrade()
1334                    && !ssh.has_been_killed()
1335                {
1336                    return Task::ready(Ok(ssh)).shared();
1337                }
1338                self.connections.remove(&opts);
1339            }
1340            None => {}
1341        }
1342
1343        let task = cx
1344            .spawn({
1345                let opts = opts.clone();
1346                let delegate = delegate.clone();
1347                async move |cx| {
1348                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1349                        .await
1350                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1351
1352                    cx.update_global(|pool: &mut Self, _| {
1353                        debug_assert!(matches!(
1354                            pool.connections.get(&opts),
1355                            Some(ConnectionPoolEntry::Connecting(_))
1356                        ));
1357                        match connection {
1358                            Ok(connection) => {
1359                                pool.connections.insert(
1360                                    opts.clone(),
1361                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1362                                );
1363                                Ok(connection)
1364                            }
1365                            Err(error) => {
1366                                pool.connections.remove(&opts);
1367                                Err(Arc::new(error))
1368                            }
1369                        }
1370                    })?
1371                }
1372            })
1373            .shared();
1374
1375        self.connections
1376            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1377        task
1378    }
1379}
1380
1381impl From<SshRemoteClient> for AnyProtoClient {
1382    fn from(client: SshRemoteClient) -> Self {
1383        AnyProtoClient::new(client.client)
1384    }
1385}
1386
/// A transport to a remote host over which the remote server proxy can be started.
///
/// `SshRemoteConnection` implements it for real ssh connections; test builds also use a fake
/// implementation (`fake::FakeRemoteConnection`).
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy and relays envelopes between it and the local channels.
    ///
    /// Messages from `outgoing_rx` are sent to the proxy. Messages received from it are
    /// forwarded on `incoming_tx`, and `connection_activity_tx` is used to signal incoming
    /// traffic. When `reconnect` is true, the ssh implementation passes `--reconnect` to the
    /// proxy. The returned task resolves to the proxy's exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Recursively copies a local directory to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the underlying connection; afterwards `has_been_killed` returns `true`.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has been called on this connection.
    fn has_been_killed(&self) -> bool;
    /// Arguments for running further `ssh` commands against this connection.
    ///
    /// On Windows, ssh does not support `ControlMaster`/`ControlPath`, so `SSH_ASKPASS` is
    /// presumably used to supply the password to each new ssh invocation. On other platforms,
    /// the `ControlPath` option points ssh at the master process's control socket so that new
    /// commands reuse the already-authenticated connection.
    fn ssh_args(&self) -> SshArgs;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Path style of the remote host.
    fn path_style(&self) -> PathStyle;
    /// The remote shell, as detected when the connection was established.
    fn shell(&self) -> String;

    /// Test-only hook for simulating a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1417
/// A live ssh connection backed by a long-running "master" `ssh -N` process.
struct SshRemoteConnection {
    /// Connection options plus whatever is needed to issue further ssh commands (a control
    /// socket path on non-Windows, environment variables on Windows).
    socket: SshSocket,
    /// The master ssh process; `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Location of the server binary on the remote host, filled in by `ensure_server_binary`
    /// during construction. `start_proxy` fails while this is unset.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from `ssh_platform.os` (`Windows` for "windows", otherwise `Posix`).
    ssh_path_style: PathStyle,
    /// Remote shell detected via the socket.
    ssh_shell: String,
    /// Kept alive for the connection's lifetime: on non-Windows it contains the control socket.
    _temp_dir: TempDir,
}
1427
1428#[async_trait(?Send)]
1429impl RemoteConnection for SshRemoteConnection {
1430    async fn kill(&self) -> Result<()> {
1431        let Some(mut process) = self.master_process.lock().take() else {
1432            return Ok(());
1433        };
1434        process.kill().ok();
1435        process.status().await?;
1436        Ok(())
1437    }
1438
1439    fn has_been_killed(&self) -> bool {
1440        self.master_process.lock().is_none()
1441    }
1442
1443    fn ssh_args(&self) -> SshArgs {
1444        self.socket.ssh_args()
1445    }
1446
1447    fn connection_options(&self) -> SshConnectionOptions {
1448        self.socket.connection_options.clone()
1449    }
1450
1451    fn shell(&self) -> String {
1452        self.ssh_shell.clone()
1453    }
1454
1455    fn upload_directory(
1456        &self,
1457        src_path: PathBuf,
1458        dest_path: RemotePathBuf,
1459        cx: &App,
1460    ) -> Task<Result<()>> {
1461        let mut command = util::command::new_smol_command("scp");
1462        let output = self
1463            .socket
1464            .ssh_options(&mut command)
1465            .args(
1466                self.socket
1467                    .connection_options
1468                    .port
1469                    .map(|port| vec!["-P".to_string(), port.to_string()])
1470                    .unwrap_or_default(),
1471            )
1472            .arg("-C")
1473            .arg("-r")
1474            .arg(&src_path)
1475            .arg(format!(
1476                "{}:{}",
1477                self.socket.connection_options.scp_url(),
1478                dest_path
1479            ))
1480            .output();
1481
1482        cx.background_spawn(async move {
1483            let output = output.await?;
1484
1485            anyhow::ensure!(
1486                output.status.success(),
1487                "failed to upload directory {} -> {}: {}",
1488                src_path.display(),
1489                dest_path.to_string(),
1490                String::from_utf8_lossy(&output.stderr)
1491            );
1492
1493            Ok(())
1494        })
1495    }
1496
1497    fn start_proxy(
1498        &self,
1499        unique_identifier: String,
1500        reconnect: bool,
1501        incoming_tx: UnboundedSender<Envelope>,
1502        outgoing_rx: UnboundedReceiver<Envelope>,
1503        connection_activity_tx: Sender<()>,
1504        delegate: Arc<dyn SshClientDelegate>,
1505        cx: &mut AsyncApp,
1506    ) -> Task<Result<i32>> {
1507        delegate.set_status(Some("Starting proxy"), cx);
1508
1509        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1510            return Task::ready(Err(anyhow!("Remote binary path not set")));
1511        };
1512
1513        let mut start_proxy_command = shell_script!(
1514            "exec {binary_path} proxy --identifier {identifier}",
1515            binary_path = &remote_binary_path.to_string(),
1516            identifier = &unique_identifier,
1517        );
1518
1519        for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1520            if let Some(value) = std::env::var(env_var).ok() {
1521                start_proxy_command = format!(
1522                    "{}={} {} ",
1523                    env_var,
1524                    shlex::try_quote(&value).unwrap(),
1525                    start_proxy_command,
1526                );
1527            }
1528        }
1529
1530        if reconnect {
1531            start_proxy_command.push_str(" --reconnect");
1532        }
1533
1534        let ssh_proxy_process = match self
1535            .socket
1536            .ssh_command("sh", &["-lc", &start_proxy_command])
1537            // IMPORTANT: we kill this process when we drop the task that uses it.
1538            .kill_on_drop(true)
1539            .spawn()
1540        {
1541            Ok(process) => process,
1542            Err(error) => {
1543                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1544            }
1545        };
1546
1547        Self::multiplex(
1548            ssh_proxy_process,
1549            incoming_tx,
1550            outgoing_rx,
1551            connection_activity_tx,
1552            cx,
1553        )
1554    }
1555
1556    fn path_style(&self) -> PathStyle {
1557        self.ssh_path_style
1558    }
1559}
1560
1561impl SshRemoteConnection {
    /// Establishes a new ssh connection to the host described by `connection_options`.
    ///
    /// Steps:
    /// 1. Spawn a master `ssh -N` process. Password prompts are routed to `delegate` through an
    ///    askpass session.
    /// 2. Wait until authentication completes (the master closes stdout).
    /// 3. Probe the remote platform and shell.
    /// 4. Ensure the matching server binary exists on the host via `ensure_server_binary`.
    ///
    /// # Errors
    /// Fails if the user cancels the password prompt, askpass times out, the master process
    /// exits early (its stderr is included in the error), or any later probe/setup step fails.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        // Holds the control socket (non-Windows); kept alive by storing it in `Self`.
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Any password prompt raised by ssh is forwarded to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" pairs with the `ControlPath=...` argument appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Whichever finishes first wins: the askpass session ending (cancel/timeout, both of
        // which return an error from this function) or the master closing its stdout.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    // No explicit kill here: `kill_on_drop(true)` terminates the master
                    // process when it is dropped on return.
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master has already exited, the connection failed; surface its stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        // The askpass session is not needed once the master connection is up.
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };
        let ssh_shell = socket.shell().await;

        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
            ssh_shell,
        };

        // The server binary must match this client's release channel, version and commit.
        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1694
1695    fn multiplex(
1696        mut ssh_proxy_process: Child,
1697        incoming_tx: UnboundedSender<Envelope>,
1698        mut outgoing_rx: UnboundedReceiver<Envelope>,
1699        mut connection_activity_tx: Sender<()>,
1700        cx: &AsyncApp,
1701    ) -> Task<Result<i32>> {
1702        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1703        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1704        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1705
1706        let mut stdin_buffer = Vec::new();
1707        let mut stdout_buffer = Vec::new();
1708        let mut stderr_buffer = Vec::new();
1709        let mut stderr_offset = 0;
1710
1711        let stdin_task = cx.background_spawn(async move {
1712            while let Some(outgoing) = outgoing_rx.next().await {
1713                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1714            }
1715            anyhow::Ok(())
1716        });
1717
1718        let stdout_task = cx.background_spawn({
1719            let mut connection_activity_tx = connection_activity_tx.clone();
1720            async move {
1721                loop {
1722                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1723                    let len = child_stdout.read(&mut stdout_buffer).await?;
1724
1725                    if len == 0 {
1726                        return anyhow::Ok(());
1727                    }
1728
1729                    if len < MESSAGE_LEN_SIZE {
1730                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1731                    }
1732
1733                    let message_len = message_len_from_buffer(&stdout_buffer);
1734                    let envelope =
1735                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1736                            .await?;
1737                    connection_activity_tx.try_send(()).ok();
1738                    incoming_tx.unbounded_send(envelope).ok();
1739                }
1740            }
1741        });
1742
1743        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1744            loop {
1745                stderr_buffer.resize(stderr_offset + 1024, 0);
1746
1747                let len = child_stderr
1748                    .read(&mut stderr_buffer[stderr_offset..])
1749                    .await?;
1750                if len == 0 {
1751                    return anyhow::Ok(());
1752                }
1753
1754                stderr_offset += len;
1755                let mut start_ix = 0;
1756                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1757                    .iter()
1758                    .position(|b| b == &b'\n')
1759                {
1760                    let line_ix = start_ix + ix;
1761                    let content = &stderr_buffer[start_ix..line_ix];
1762                    start_ix = line_ix + 1;
1763                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1764                        record.log(log::logger())
1765                    } else {
1766                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1767                    }
1768                }
1769                stderr_buffer.drain(0..start_ix);
1770                stderr_offset -= start_ix;
1771
1772                connection_activity_tx.try_send(()).ok();
1773            }
1774        });
1775
1776        cx.background_spawn(async move {
1777            let result = futures::select! {
1778                result = stdin_task.fuse() => {
1779                    result.context("stdin")
1780                }
1781                result = stdout_task.fuse() => {
1782                    result.context("stdout")
1783                }
1784                result = stderr_task.fuse() => {
1785                    result.context("stderr")
1786                }
1787            };
1788
1789            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1790            match result {
1791                Ok(_) => Ok(status),
1792                Err(error) => Err(error),
1793            }
1794        })
1795    }
1796
1797    #[allow(unused)]
1798    async fn ensure_server_binary(
1799        &self,
1800        delegate: &Arc<dyn SshClientDelegate>,
1801        release_channel: ReleaseChannel,
1802        version: SemanticVersion,
1803        commit: Option<AppCommitSha>,
1804        cx: &mut AsyncApp,
1805    ) -> Result<RemotePathBuf> {
1806        let version_str = match release_channel {
1807            ReleaseChannel::Nightly => {
1808                let commit = commit.map(|s| s.full()).unwrap_or_default();
1809                format!("{}-{}", version, commit)
1810            }
1811            ReleaseChannel::Dev => "build".to_string(),
1812            _ => version.to_string(),
1813        };
1814        let binary_name = format!(
1815            "zed-remote-server-{}-{}",
1816            release_channel.dev_name(),
1817            version_str
1818        );
1819        let dst_path = RemotePathBuf::new(
1820            paths::remote_server_dir_relative().join(binary_name),
1821            self.ssh_path_style,
1822        );
1823
1824        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1825        #[cfg(debug_assertions)]
1826        if let Some(build_remote_server) = build_remote_server {
1827            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1828            let tmp_path = RemotePathBuf::new(
1829                paths::remote_server_dir_relative().join(format!(
1830                    "download-{}-{}",
1831                    std::process::id(),
1832                    src_path.file_name().unwrap().to_string_lossy()
1833                )),
1834                self.ssh_path_style,
1835            );
1836            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1837                .await?;
1838            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1839                .await?;
1840            return Ok(dst_path);
1841        }
1842
1843        if self
1844            .socket
1845            .run_command(&dst_path.to_string(), &["version"])
1846            .await
1847            .is_ok()
1848        {
1849            return Ok(dst_path);
1850        }
1851
1852        let wanted_version = cx.update(|cx| match release_channel {
1853            ReleaseChannel::Nightly => Ok(None),
1854            ReleaseChannel::Dev => {
1855                anyhow::bail!(
1856                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1857                    dst_path
1858                )
1859            }
1860            _ => Ok(Some(AppVersion::global(cx))),
1861        })??;
1862
1863        let tmp_path_gz = RemotePathBuf::new(
1864            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
1865            self.ssh_path_style,
1866        );
1867        if !self.socket.connection_options.upload_binary_over_ssh
1868            && let Some((url, body)) = delegate
1869                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1870                .await?
1871        {
1872            match self
1873                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1874                .await
1875            {
1876                Ok(_) => {
1877                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1878                        .await?;
1879                    return Ok(dst_path);
1880                }
1881                Err(e) => {
1882                    log::error!(
1883                        "Failed to download binary on server, attempting to upload server: {}",
1884                        e
1885                    )
1886                }
1887            }
1888        }
1889
1890        let src_path = delegate
1891            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1892            .await?;
1893        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1894            .await?;
1895        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1896            .await?;
1897        Ok(dst_path)
1898    }
1899
1900    async fn download_binary_on_server(
1901        &self,
1902        url: &str,
1903        body: &str,
1904        tmp_path_gz: &RemotePathBuf,
1905        delegate: &Arc<dyn SshClientDelegate>,
1906        cx: &mut AsyncApp,
1907    ) -> Result<()> {
1908        if let Some(parent) = tmp_path_gz.parent() {
1909            self.socket
1910                .run_command(
1911                    "sh",
1912                    &[
1913                        "-lc",
1914                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1915                    ],
1916                )
1917                .await?;
1918        }
1919
1920        delegate.set_status(Some("Downloading remote development server on host"), cx);
1921
1922        match self
1923            .socket
1924            .run_command(
1925                "curl",
1926                &[
1927                    "-f",
1928                    "-L",
1929                    "-X",
1930                    "GET",
1931                    "-H",
1932                    "Content-Type: application/json",
1933                    "-d",
1934                    body,
1935                    url,
1936                    "-o",
1937                    &tmp_path_gz.to_string(),
1938                ],
1939            )
1940            .await
1941        {
1942            Ok(_) => {}
1943            Err(e) => {
1944                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1945                    return Err(e);
1946                }
1947
1948                match self
1949                    .socket
1950                    .run_command(
1951                        "wget",
1952                        &[
1953                            "--method=GET",
1954                            "--header=Content-Type: application/json",
1955                            "--body-data",
1956                            body,
1957                            url,
1958                            "-O",
1959                            &tmp_path_gz.to_string(),
1960                        ],
1961                    )
1962                    .await
1963                {
1964                    Ok(_) => {}
1965                    Err(e) => {
1966                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1967                            return Err(e);
1968                        } else {
1969                            anyhow::bail!("Neither curl nor wget is available");
1970                        }
1971                    }
1972                }
1973            }
1974        }
1975
1976        Ok(())
1977    }
1978
1979    async fn upload_local_server_binary(
1980        &self,
1981        src_path: &Path,
1982        tmp_path_gz: &RemotePathBuf,
1983        delegate: &Arc<dyn SshClientDelegate>,
1984        cx: &mut AsyncApp,
1985    ) -> Result<()> {
1986        if let Some(parent) = tmp_path_gz.parent() {
1987            self.socket
1988                .run_command(
1989                    "sh",
1990                    &[
1991                        "-lc",
1992                        &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1993                    ],
1994                )
1995                .await?;
1996        }
1997
1998        let src_stat = fs::metadata(&src_path).await?;
1999        let size = src_stat.len();
2000
2001        let t0 = Instant::now();
2002        delegate.set_status(Some("Uploading remote development server"), cx);
2003        log::info!(
2004            "uploading remote development server to {:?} ({}kb)",
2005            tmp_path_gz,
2006            size / 1024
2007        );
2008        self.upload_file(src_path, tmp_path_gz)
2009            .await
2010            .context("failed to upload server binary")?;
2011        log::info!("uploaded remote development server in {:?}", t0.elapsed());
2012        Ok(())
2013    }
2014
2015    async fn extract_server_binary(
2016        &self,
2017        dst_path: &RemotePathBuf,
2018        tmp_path: &RemotePathBuf,
2019        delegate: &Arc<dyn SshClientDelegate>,
2020        cx: &mut AsyncApp,
2021    ) -> Result<()> {
2022        delegate.set_status(Some("Extracting remote development server"), cx);
2023        let server_mode = 0o755;
2024
2025        let orig_tmp_path = tmp_path.to_string();
2026        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2027            shell_script!(
2028                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2029                server_mode = &format!("{:o}", server_mode),
2030                dst_path = &dst_path.to_string(),
2031            )
2032        } else {
2033            shell_script!(
2034                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2035                server_mode = &format!("{:o}", server_mode),
2036                dst_path = &dst_path.to_string()
2037            )
2038        };
2039        self.socket.run_command("sh", &["-lc", &script]).await?;
2040        Ok(())
2041    }
2042
2043    async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2044        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2045        let mut command = util::command::new_smol_command("scp");
2046        let output = self
2047            .socket
2048            .ssh_options(&mut command)
2049            .args(
2050                self.socket
2051                    .connection_options
2052                    .port
2053                    .map(|port| vec!["-P".to_string(), port.to_string()])
2054                    .unwrap_or_default(),
2055            )
2056            .arg(src_path)
2057            .arg(format!(
2058                "{}:{}",
2059                self.socket.connection_options.scp_url(),
2060                dest_path
2061            ))
2062            .output()
2063            .await?;
2064
2065        anyhow::ensure!(
2066            output.status.success(),
2067            "failed to upload file {} -> {}: {}",
2068            src_path.display(),
2069            dest_path.to_string(),
2070            String::from_utf8_lossy(&output.stderr)
2071        );
2072        Ok(())
2073    }
2074
    /// Debug-only: builds the `remote_server` crate from the local checkout for the
    /// remote host's platform, and returns the path of the artifact to upload.
    ///
    /// `build_remote_server` is the value of `ZED_BUILD_REMOTE_SERVER`. Substrings
    /// in it select build options:
    /// - `nomusl`: on Linux hosts, target `<arch>-unknown-linux-gnu` instead of
    ///   `<arch>-unknown-linux-musl`. Musl builds also get
    ///   `-C target-feature=+crt-static`.
    /// - `mold`: link with mold (`-C link-arg=-fuse-ld=mold`).
    /// - `cross`: when the remote platform differs from the local one, build with
    ///   cross.rs (Docker) instead of cargo-zigbuild.
    /// - `nocompress`: skip compressing the result.
    ///
    /// When the remote arch/OS match the local ones, a plain `cargo build` is
    /// used. Otherwise the build goes through cross (which is `cargo install`ed
    /// first) or cargo-zigbuild. cargo-zigbuild needs `zig` on `$PATH`; its rustup
    /// target and the tool itself are installed first. Any existing `RUSTFLAGS`
    /// are kept and extended. All builds write to `target/remote_server`, relative
    /// to the current directory.
    ///
    /// Returns an absolute path to `remote_server.gz` when compressing (gzip, or
    /// 7z on Windows). With `nocompress`, it returns the relative path
    /// `target/remote_server/<triple>/debug/remote_server`.
    ///
    /// # Errors
    /// Fails for remote OSes other than `linux`/`macos`, when `zig` (or, on
    /// Windows, `7z.exe`) is missing, or when any spawned command exits
    /// unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        /// Runs `command` to completion and fails unless it exits successfully.
        /// stderr is inherited so build output stays visible; stdout is captured
        /// and discarded. The child is killed if this future is dropped.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        let use_musl = !build_remote_server.contains("nomusl");
        // Rust target triple for the remote host, e.g. `x86_64-unknown-linux-musl`.
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS; a non-unicode value is logged and ignored.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            // Same arch/OS as this machine: a regular cargo build is enough.
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            // cross is (re)installed from git on every call.
            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            // Bind-mount the local `target` dir into the container so the build
            // output lands where the path computed below expects it.
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    )
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Default cross-compilation path: cargo-zigbuild, which requires zig.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "zigbuild",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        };
        // Relative to the current directory, matching `--target-dir` above.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove any previous archive so 7z writes a fresh one.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            // Compressed artifact is returned as an absolute path.
            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            // NOTE(review): unlike the compressed case, this path stays relative
            // to the current directory.
            bin_path
        };

        Ok(path)
    }
2289}
2290
/// Requests awaiting a response, keyed by the id of the request envelope.
///
/// Each sender delivers the response envelope together with a second
/// `oneshot::Sender<()>`. The incoming-message loop waits on the paired receiver
/// until the requester has finished with the response.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2292
/// Proto message channel used over the SSH connection.
///
/// Outgoing envelopes go to `outgoing_tx`; incoming ones are processed by `task`.
/// Both are swapped out by `reconnect`.
// Field order is also drop order; keep it stable.
struct ChannelClient {
    /// Id assigned to the next outgoing envelope (incremented atomically).
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced on reconnect.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Sent envelopes kept for replay. An envelope is dropped once the peer's
    /// `ack_id` covers it, and the buffer is re-sent on flush/resync.
    /// NOTE(review): presumably filled by `send_buffered` — confirm.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests awaiting a response, keyed by message id.
    response_channels: ResponseChannels,
    /// Handlers for incoming envelopes that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently processed incoming envelope. Despite the name, it
    /// is overwritten with each new id rather than compared against a maximum.
    max_received: AtomicU32,
    /// Label prefixed to this client's log messages.
    name: &'static str,
    /// Task running the incoming-message loop; replaced on reconnect.
    task: Mutex<Task<Result<()>>>,
}
2303
2304impl ChannelClient {
2305    fn new(
2306        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2307        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2308        cx: &App,
2309        name: &'static str,
2310    ) -> Arc<Self> {
2311        Arc::new_cyclic(|this| Self {
2312            outgoing_tx: Mutex::new(outgoing_tx),
2313            next_message_id: AtomicU32::new(0),
2314            max_received: AtomicU32::new(0),
2315            response_channels: ResponseChannels::default(),
2316            message_handlers: Default::default(),
2317            buffer: Mutex::new(VecDeque::new()),
2318            name,
2319            task: Mutex::new(Self::start_handling_messages(
2320                this.clone(),
2321                incoming_rx,
2322                &cx.to_async(),
2323            )),
2324        })
2325    }
2326
2327    fn start_handling_messages(
2328        this: Weak<Self>,
2329        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2330        cx: &AsyncApp,
2331    ) -> Task<Result<()>> {
2332        cx.spawn(async move |cx| {
2333            let peer_id = PeerId { owner_id: 0, id: 0 };
2334            while let Some(incoming) = incoming_rx.next().await {
2335                let Some(this) = this.upgrade() else {
2336                    return anyhow::Ok(());
2337                };
2338                if let Some(ack_id) = incoming.ack_id {
2339                    let mut buffer = this.buffer.lock();
2340                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2341                        buffer.pop_front();
2342                    }
2343                }
2344                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2345                {
2346                    log::debug!(
2347                        "{}:ssh message received. name:FlushBufferedMessages",
2348                        this.name
2349                    );
2350                    {
2351                        let buffer = this.buffer.lock();
2352                        for envelope in buffer.iter() {
2353                            this.outgoing_tx
2354                                .lock()
2355                                .unbounded_send(envelope.clone())
2356                                .ok();
2357                        }
2358                    }
2359                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2360                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2361                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2362                    continue;
2363                }
2364
2365                this.max_received.store(incoming.id, SeqCst);
2366
2367                if let Some(request_id) = incoming.responding_to {
2368                    let request_id = MessageId(request_id);
2369                    let sender = this.response_channels.lock().remove(&request_id);
2370                    if let Some(sender) = sender {
2371                        let (tx, rx) = oneshot::channel();
2372                        if incoming.payload.is_some() {
2373                            sender.send((incoming, tx)).ok();
2374                        }
2375                        rx.await.ok();
2376                    }
2377                } else if let Some(envelope) =
2378                    build_typed_envelope(peer_id, Instant::now(), incoming)
2379                {
2380                    let type_name = envelope.payload_type_name();
2381                    let message_id = envelope.message_id();
2382                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2383                        &this.message_handlers,
2384                        envelope,
2385                        this.clone().into(),
2386                        cx.clone(),
2387                    ) {
2388                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2389                        cx.foreground_executor()
2390                            .spawn(async move {
2391                                match future.await {
2392                                    Ok(_) => {
2393                                        log::debug!(
2394                                            "{}:ssh message handled. name:{type_name}",
2395                                            this.name
2396                                        );
2397                                    }
2398                                    Err(error) => {
2399                                        log::error!(
2400                                            "{}:error handling message. type:{}, error:{}",
2401                                            this.name,
2402                                            type_name,
2403                                            format!("{error:#}").lines().fold(
2404                                                String::new(),
2405                                                |mut message, line| {
2406                                                    if !message.is_empty() {
2407                                                        message.push(' ');
2408                                                    }
2409                                                    message.push_str(line);
2410                                                    message
2411                                                }
2412                                            )
2413                                        );
2414                                    }
2415                                }
2416                            })
2417                            .detach()
2418                    } else {
2419                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2420                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
2421                            message_id,
2422                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
2423                        ) {
2424                            log::error!(
2425                                "{}:error sending error response for {type_name}:{e:#}",
2426                                this.name
2427                            );
2428                        }
2429                    }
2430                }
2431            }
2432            anyhow::Ok(())
2433        })
2434    }
2435
2436    fn reconnect(
2437        self: &Arc<Self>,
2438        incoming_rx: UnboundedReceiver<Envelope>,
2439        outgoing_tx: UnboundedSender<Envelope>,
2440        cx: &AsyncApp,
2441    ) {
2442        *self.outgoing_tx.lock() = outgoing_tx;
2443        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2444    }
2445
2446    fn request<T: RequestMessage>(
2447        &self,
2448        payload: T,
2449    ) -> impl 'static + Future<Output = Result<T::Response>> {
2450        self.request_internal(payload, true)
2451    }
2452
2453    fn request_internal<T: RequestMessage>(
2454        &self,
2455        payload: T,
2456        use_buffer: bool,
2457    ) -> impl 'static + Future<Output = Result<T::Response>> {
2458        log::debug!("ssh request start. name:{}", T::NAME);
2459        let response =
2460            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2461        async move {
2462            let response = response.await?;
2463            log::debug!("ssh request finish. name:{}", T::NAME);
2464            T::Response::from_envelope(response).context("received a response of the wrong type")
2465        }
2466    }
2467
2468    async fn resync(&self, timeout: Duration) -> Result<()> {
2469        smol::future::or(
2470            async {
2471                self.request_internal(proto::FlushBufferedMessages {}, false)
2472                    .await?;
2473
2474                for envelope in self.buffer.lock().iter() {
2475                    self.outgoing_tx
2476                        .lock()
2477                        .unbounded_send(envelope.clone())
2478                        .ok();
2479                }
2480                Ok(())
2481            },
2482            async {
2483                smol::Timer::after(timeout).await;
2484                anyhow::bail!("Timed out resyncing remote client")
2485            },
2486        )
2487        .await
2488    }
2489
2490    async fn ping(&self, timeout: Duration) -> Result<()> {
2491        smol::future::or(
2492            async {
2493                self.request(proto::Ping {}).await?;
2494                Ok(())
2495            },
2496            async {
2497                smol::Timer::after(timeout).await;
2498                anyhow::bail!("Timed out pinging remote client")
2499            },
2500        )
2501        .await
2502    }
2503
2504    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2505        log::debug!("ssh send name:{}", T::NAME);
2506        self.send_dynamic(payload.into_envelope(0, None, None))
2507    }
2508
2509    fn request_dynamic(
2510        &self,
2511        mut envelope: proto::Envelope,
2512        type_name: &'static str,
2513        use_buffer: bool,
2514    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2515        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2516        let (tx, rx) = oneshot::channel();
2517        let mut response_channels_lock = self.response_channels.lock();
2518        response_channels_lock.insert(MessageId(envelope.id), tx);
2519        drop(response_channels_lock);
2520
2521        let result = if use_buffer {
2522            self.send_buffered(envelope)
2523        } else {
2524            self.send_unbuffered(envelope)
2525        };
2526        async move {
2527            if let Err(error) = &result {
2528                log::error!("failed to send message: {error}");
2529                anyhow::bail!("failed to send message: {error}");
2530            }
2531
2532            let response = rx.await.context("connection lost")?.0;
2533            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2534                return Err(RpcError::from_proto(error, type_name));
2535            }
2536            Ok(response)
2537        }
2538    }
2539
2540    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2541        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2542        self.send_buffered(envelope)
2543    }
2544
2545    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2546        envelope.ack_id = Some(self.max_received.load(SeqCst));
2547        self.buffer.lock().push_back(envelope.clone());
2548        // ignore errors on send (happen while we're reconnecting)
2549        // assume that the global "disconnected" overlay is sufficient.
2550        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2551        Ok(())
2552    }
2553
2554    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2555        envelope.ack_id = Some(self.max_received.load(SeqCst));
2556        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2557        Ok(())
2558    }
2559}
2560
2561impl ProtoClient for ChannelClient {
2562    fn request(
2563        &self,
2564        envelope: proto::Envelope,
2565        request_type: &'static str,
2566    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2567        self.request_dynamic(envelope, request_type, true).boxed()
2568    }
2569
2570    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2571        self.send_dynamic(envelope)
2572    }
2573
2574    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2575        self.send_dynamic(envelope)
2576    }
2577
2578    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2579        &self.message_handlers
2580    }
2581
2582    fn is_via_collab(&self) -> bool {
2583        false
2584    }
2585}
2586
// Test-only in-process fake of a remote connection. No ssh process is
// involved: `FakeRemoteConnection` forwards envelopes between the client's
// channels and a server-side `ChannelClient` living in the same process.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// Fake connection whose "remote" end is `server_channel`, a `ChannelClient`
    /// in this process. Its message handling is (re)started with the context
    /// stored in `server_cx`.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Holds the server-side test `AsyncApp` so it can live inside a struct
    /// that must be `Send + Sync` (see the `unsafe impl`s below).
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        // NOTE(review): the argument is not inspected; the guarantee relies on
        // callers only holding an `AsyncApp` on the main thread — confirm.
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    // SAFETY: same reasoning as the `Send` impl directly above.
    unsafe impl Sync for SendableCx {}

    // Only the methods tests rely on do real work; kill/ssh_args are inert and
    // upload_directory panics.
    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// No-op: there is no process to kill.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// Always `false`, since `kill` does nothing.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Empty argument list and no environment.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        /// Not supported by the fake; panics if called.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Cuts the server side off by reconnecting it to channels whose
        /// opposite halves are dropped right away. Its outgoing sends fail, and
        /// its incoming stream yields nothing further.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// "Connects" by handing the server side a fresh pair of channels, then
        /// spawning a background task that shuttles envelopes between client
        /// and server.
        ///
        /// `_unique_identifier`, `_reconnect` and `_delegate` are ignored. The
        /// task resolves to `Ok(1)` as soon as either direction's stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    // `select_biased!` polls arms in order, so when both sides are
                    // ready, server->client traffic is forwarded first.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            // Only server->client traffic counts as connection activity.
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        /// Uses the host's own path style.
        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }

        fn shell(&self) -> String {
            "sh".to_owned()
        }
    }

    /// Test delegate: `set_status` is a no-op, and every other callback panics
    /// via `unreachable!()`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}