ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    channel::{
  13        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  14        oneshot,
  15    },
  16    future::{BoxFuture, Shared},
  17    select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  18};
  19use gpui::{
  20    App, AppContext, AsyncApp, BorrowAppContext, Context, Entity, EventEmitter, Global,
  21    SemanticVersion, Task, WeakEntity,
  22};
  23use itertools::Itertools;
  24use parking_lot::Mutex;
  25use paths;
  26use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  27use rpc::{
  28    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  29    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  30    RpcError,
  31};
  32use smol::{
  33    fs,
  34    process::{self, Child, Stdio},
  35};
  36use std::{
  37    any::TypeId,
  38    collections::VecDeque,
  39    fmt, iter,
  40    ops::ControlFlow,
  41    path::{Path, PathBuf},
  42    sync::{
  43        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  44        Arc, Weak,
  45    },
  46    time::{Duration, Instant},
  47};
  48use tempfile::TempDir;
  49use util::ResultExt;
  50
  51#[derive(
  52    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  53)]
  54pub struct SshProjectId(pub u64);
  55
  56#[derive(Clone)]
  57pub struct SshSocket {
  58    connection_options: SshConnectionOptions,
  59    socket_path: PathBuf,
  60}
  61
  62#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  63pub struct SshConnectionOptions {
  64    pub host: String,
  65    pub username: Option<String>,
  66    pub port: Option<u16>,
  67    pub password: Option<String>,
  68    pub args: Option<Vec<String>>,
  69
  70    pub nickname: Option<String>,
  71    pub upload_binary_over_ssh: bool,
  72}
  73
  74#[macro_export]
  75macro_rules! shell_script {
  76    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
  77        format!(
  78            $fmt,
  79            $(
  80                $name = shlex::try_quote($arg).unwrap()
  81            ),+
  82        )
  83    }};
  84}
  85
  86impl SshConnectionOptions {
  87    pub fn parse_command_line(input: &str) -> Result<Self> {
  88        let input = input.trim_start_matches("ssh ");
  89        let mut hostname: Option<String> = None;
  90        let mut username: Option<String> = None;
  91        let mut port: Option<u16> = None;
  92        let mut args = Vec::new();
  93
  94        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  95        const ALLOWED_OPTS: &[&str] = &[
  96            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  97        ];
  98        const ALLOWED_ARGS: &[&str] = &[
  99            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
 100            "-w",
 101        ];
 102
 103        let mut tokens = shlex::split(input)
 104            .ok_or_else(|| anyhow!("invalid input"))?
 105            .into_iter();
 106
 107        'outer: while let Some(arg) = tokens.next() {
 108            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 109                args.push(arg.to_string());
 110                continue;
 111            }
 112            if arg == "-p" {
 113                port = tokens.next().and_then(|arg| arg.parse().ok());
 114                continue;
 115            } else if let Some(p) = arg.strip_prefix("-p") {
 116                port = p.parse().ok();
 117                continue;
 118            }
 119            if arg == "-l" {
 120                username = tokens.next();
 121                continue;
 122            } else if let Some(l) = arg.strip_prefix("-l") {
 123                username = Some(l.to_string());
 124                continue;
 125            }
 126            for a in ALLOWED_ARGS {
 127                if arg == *a {
 128                    args.push(arg);
 129                    if let Some(next) = tokens.next() {
 130                        args.push(next);
 131                    }
 132                    continue 'outer;
 133                } else if arg.starts_with(a) {
 134                    args.push(arg);
 135                    continue 'outer;
 136                }
 137            }
 138            if arg.starts_with("-") || hostname.is_some() {
 139                anyhow::bail!("unsupported argument: {:?}", arg);
 140            }
 141            let mut input = &arg as &str;
 142            if let Some((u, rest)) = input.split_once('@') {
 143                input = rest;
 144                username = Some(u.to_string());
 145            }
 146            if let Some((rest, p)) = input.split_once(':') {
 147                input = rest;
 148                port = p.parse().ok()
 149            }
 150            hostname = Some(input.to_string())
 151        }
 152
 153        let Some(hostname) = hostname else {
 154            anyhow::bail!("missing hostname");
 155        };
 156
 157        Ok(Self {
 158            host: hostname.to_string(),
 159            username: username.clone(),
 160            port,
 161            args: Some(args),
 162            password: None,
 163            nickname: None,
 164            upload_binary_over_ssh: false,
 165        })
 166    }
 167
 168    pub fn ssh_url(&self) -> String {
 169        let mut result = String::from("ssh://");
 170        if let Some(username) = &self.username {
 171            result.push_str(username);
 172            result.push('@');
 173        }
 174        result.push_str(&self.host);
 175        if let Some(port) = self.port {
 176            result.push(':');
 177            result.push_str(&port.to_string());
 178        }
 179        result
 180    }
 181
 182    pub fn additional_args(&self) -> Option<&Vec<String>> {
 183        self.args.as_ref()
 184    }
 185
 186    fn scp_url(&self) -> String {
 187        if let Some(username) = &self.username {
 188            format!("{}@{}", username, self.host)
 189        } else {
 190            self.host.clone()
 191        }
 192    }
 193
 194    pub fn connection_string(&self) -> String {
 195        let host = if let Some(username) = &self.username {
 196            format!("{}@{}", username, self.host)
 197        } else {
 198            self.host.clone()
 199        };
 200        if let Some(port) = &self.port {
 201            format!("{}:{}", host, port)
 202        } else {
 203            host
 204        }
 205    }
 206}
 207
 208#[derive(Copy, Clone, Debug)]
 209pub struct SshPlatform {
 210    pub os: &'static str,
 211    pub arch: &'static str,
 212}
 213
 214impl SshPlatform {
 215    pub fn triple(&self) -> Option<String> {
 216        Some(format!(
 217            "{}-{}",
 218            self.arch,
 219            match self.os {
 220                "linux" => "unknown-linux-gnu",
 221                "macos" => "apple-darwin",
 222                _ => return None,
 223            }
 224        ))
 225    }
 226}
 227
 228pub trait SshClientDelegate: Send + Sync {
 229    fn ask_password(&self, prompt: String, cx: &mut AsyncApp) -> oneshot::Receiver<Result<String>>;
 230    fn get_download_params(
 231        &self,
 232        platform: SshPlatform,
 233        release_channel: ReleaseChannel,
 234        version: Option<SemanticVersion>,
 235        cx: &mut AsyncApp,
 236    ) -> Task<Result<Option<(String, String)>>>;
 237
 238    fn download_server_binary_locally(
 239        &self,
 240        platform: SshPlatform,
 241        release_channel: ReleaseChannel,
 242        version: Option<SemanticVersion>,
 243        cx: &mut AsyncApp,
 244    ) -> Task<Result<PathBuf>>;
 245    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
 246}
 247
 248impl SshSocket {
 249    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 250    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 251    // and passes -l as an argument to sh, not to ls.
 252    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
 253    // into a machine. You must use `cd` to get back to $HOME.
 254    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
 255    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 256        let mut command = util::command::new_smol_command("ssh");
 257        let to_run = iter::once(&program)
 258            .chain(args.iter())
 259            .map(|token| {
 260                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
 261                debug_assert!(
 262                    !token.contains('\n'),
 263                    "multiline arguments do not work in all shells"
 264                );
 265                shlex::try_quote(token).unwrap()
 266            })
 267            .join(" ");
 268        let to_run = format!("cd; {to_run}");
 269        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
 270        self.ssh_options(&mut command)
 271            .arg(self.connection_options.ssh_url())
 272            .arg(to_run);
 273        command
 274    }
 275
 276    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
 277        let output = self.ssh_command(program, args).output().await?;
 278        if output.status.success() {
 279            Ok(String::from_utf8_lossy(&output.stdout).to_string())
 280        } else {
 281            Err(anyhow!(
 282                "failed to run command: {}",
 283                String::from_utf8_lossy(&output.stderr)
 284            ))
 285        }
 286    }
 287
 288    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 289        command
 290            .stdin(Stdio::piped())
 291            .stdout(Stdio::piped())
 292            .stderr(Stdio::piped())
 293            .args(["-o", "ControlMaster=no", "-o"])
 294            .arg(format!("ControlPath={}", self.socket_path.display()))
 295    }
 296
 297    fn ssh_args(&self) -> Vec<String> {
 298        vec![
 299            "-o".to_string(),
 300            "ControlMaster=no".to_string(),
 301            "-o".to_string(),
 302            format!("ControlPath={}", self.socket_path.display()),
 303            self.connection_options.ssh_url(),
 304        ]
 305    }
 306}
 307
 308const MAX_MISSED_HEARTBEATS: usize = 5;
 309const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 310const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 311
 312const MAX_RECONNECT_ATTEMPTS: usize = 3;
 313
 314enum State {
 315    Connecting,
 316    Connected {
 317        ssh_connection: Arc<dyn RemoteConnection>,
 318        delegate: Arc<dyn SshClientDelegate>,
 319
 320        multiplex_task: Task<Result<()>>,
 321        heartbeat_task: Task<Result<()>>,
 322    },
 323    HeartbeatMissed {
 324        missed_heartbeats: usize,
 325
 326        ssh_connection: Arc<dyn RemoteConnection>,
 327        delegate: Arc<dyn SshClientDelegate>,
 328
 329        multiplex_task: Task<Result<()>>,
 330        heartbeat_task: Task<Result<()>>,
 331    },
 332    Reconnecting,
 333    ReconnectFailed {
 334        ssh_connection: Arc<dyn RemoteConnection>,
 335        delegate: Arc<dyn SshClientDelegate>,
 336
 337        error: anyhow::Error,
 338        attempts: usize,
 339    },
 340    ReconnectExhausted,
 341    ServerNotRunning,
 342}
 343
 344impl fmt::Display for State {
 345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 346        match self {
 347            Self::Connecting => write!(f, "connecting"),
 348            Self::Connected { .. } => write!(f, "connected"),
 349            Self::Reconnecting => write!(f, "reconnecting"),
 350            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 351            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 352            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 353            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 354        }
 355    }
 356}
 357
 358impl State {
 359    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 360        match self {
 361            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 362            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 363            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 364            _ => None,
 365        }
 366    }
 367
 368    fn can_reconnect(&self) -> bool {
 369        match self {
 370            Self::Connected { .. }
 371            | Self::HeartbeatMissed { .. }
 372            | Self::ReconnectFailed { .. } => true,
 373            State::Connecting
 374            | State::Reconnecting
 375            | State::ReconnectExhausted
 376            | State::ServerNotRunning => false,
 377        }
 378    }
 379
 380    fn is_reconnect_failed(&self) -> bool {
 381        matches!(self, Self::ReconnectFailed { .. })
 382    }
 383
 384    fn is_reconnect_exhausted(&self) -> bool {
 385        matches!(self, Self::ReconnectExhausted { .. })
 386    }
 387
 388    fn is_server_not_running(&self) -> bool {
 389        matches!(self, Self::ServerNotRunning)
 390    }
 391
 392    fn is_reconnecting(&self) -> bool {
 393        matches!(self, Self::Reconnecting { .. })
 394    }
 395
 396    fn heartbeat_recovered(self) -> Self {
 397        match self {
 398            Self::HeartbeatMissed {
 399                ssh_connection,
 400                delegate,
 401                multiplex_task,
 402                heartbeat_task,
 403                ..
 404            } => Self::Connected {
 405                ssh_connection,
 406                delegate,
 407                multiplex_task,
 408                heartbeat_task,
 409            },
 410            _ => self,
 411        }
 412    }
 413
 414    fn heartbeat_missed(self) -> Self {
 415        match self {
 416            Self::Connected {
 417                ssh_connection,
 418                delegate,
 419                multiplex_task,
 420                heartbeat_task,
 421            } => Self::HeartbeatMissed {
 422                missed_heartbeats: 1,
 423                ssh_connection,
 424                delegate,
 425                multiplex_task,
 426                heartbeat_task,
 427            },
 428            Self::HeartbeatMissed {
 429                missed_heartbeats,
 430                ssh_connection,
 431                delegate,
 432                multiplex_task,
 433                heartbeat_task,
 434            } => Self::HeartbeatMissed {
 435                missed_heartbeats: missed_heartbeats + 1,
 436                ssh_connection,
 437                delegate,
 438                multiplex_task,
 439                heartbeat_task,
 440            },
 441            _ => self,
 442        }
 443    }
 444}
 445
 446/// The state of the ssh connection.
 447#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 448pub enum ConnectionState {
 449    Connecting,
 450    Connected,
 451    HeartbeatMissed,
 452    Reconnecting,
 453    Disconnected,
 454}
 455
 456impl From<&State> for ConnectionState {
 457    fn from(value: &State) -> Self {
 458        match value {
 459            State::Connecting => Self::Connecting,
 460            State::Connected { .. } => Self::Connected,
 461            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 462            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 463            State::ReconnectExhausted => Self::Disconnected,
 464            State::ServerNotRunning => Self::Disconnected,
 465        }
 466    }
 467}
 468
 469pub struct SshRemoteClient {
 470    client: Arc<ChannelClient>,
 471    unique_identifier: String,
 472    connection_options: SshConnectionOptions,
 473    state: Arc<Mutex<Option<State>>>,
 474}
 475
 476#[derive(Debug)]
 477pub enum SshRemoteEvent {
 478    Disconnected,
 479}
 480
 481impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 482
 483// Identifies the socket on the remote server so that reconnects
 484// can re-join the same project.
 485pub enum ConnectionIdentifier {
 486    Setup(u64),
 487    Workspace(i64),
 488}
 489
 490static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 491
 492impl ConnectionIdentifier {
 493    pub fn setup() -> Self {
 494        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 495    }
 496    // This string gets used in a socket name, and so must be relatively short.
 497    // The total length of:
 498    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 499    // Must be less than about 100 characters
 500    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 501    // So our strings should be at most 20 characters or so.
 502    fn to_string(&self, cx: &App) -> String {
 503        let identifier_prefix = match ReleaseChannel::global(cx) {
 504            ReleaseChannel::Stable => "".to_string(),
 505            release_channel => format!("{}-", release_channel.dev_name()),
 506        };
 507        match self {
 508            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 509            Self::Workspace(workspace_id) => {
 510                format!("{identifier_prefix}workspace-{workspace_id}",)
 511            }
 512        }
 513    }
 514}
 515
 516impl SshRemoteClient {
    /// Creates an `SshRemoteClient` entity and establishes its connection.
    ///
    /// The steps run in this order:
    /// 1. Obtain a connection for `connection_options` from the global
    ///    `ConnectionPool`.
    /// 2. Start the remote proxy under `unique_identifier`.
    /// 3. Ping the server once, with `HEARTBEAT_TIMEOUT`.
    /// 4. Switch the entity's state from `Connecting` to `Connected` and start
    ///    the heartbeat task.
    ///
    /// Resolves to `Ok(None)` if `cancellation` resolves first. For a `futures`
    /// oneshot, that also happens when the sender is dropped without sending.
    /// Resolves to an error if the app context is gone, the pool connection
    /// fails, or the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(|mut cx| async move {
            // Boxed (and therefore `Unpin`) so it can be raced against `cancellation` in `select!`.
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Only signals "something arrived"; consumed by the heartbeat task.
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options, &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                // `false` here vs. `true` in `reconnect`; presumably a "reconnecting" flag — confirm.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    &mut cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Make sure the server actually answers before reporting the client as connected.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task =
                    Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);

                this.update(&mut cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 589
 590    pub fn shutdown_processes<T: RequestMessage>(
 591        &self,
 592        shutdown_request: Option<T>,
 593    ) -> Option<impl Future<Output = ()>> {
 594        let state = self.state.lock().take()?;
 595        log::info!("shutting down ssh processes");
 596
 597        let State::Connected {
 598            multiplex_task,
 599            heartbeat_task,
 600            ssh_connection,
 601            delegate,
 602        } = state
 603        else {
 604            return None;
 605        };
 606
 607        let client = self.client.clone();
 608
 609        Some(async move {
 610            if let Some(shutdown_request) = shutdown_request {
 611                client.send(shutdown_request).log_err();
 612                // We wait 50ms instead of waiting for a response, because
 613                // waiting for a response would require us to wait on the main thread
 614                // which we want to avoid in an `on_app_quit` callback.
 615                smol::Timer::after(Duration::from_millis(50)).await;
 616            }
 617
 618            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 619            // child of master_process.
 620            drop(multiplex_task);
 621            // Now drop the rest of state, which kills master process.
 622            drop(heartbeat_task);
 623            drop(ssh_connection);
 624            drop(delegate);
 625        })
 626    }
 627
 628    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 629        let mut lock = self.state.lock();
 630
 631        let can_reconnect = lock
 632            .as_ref()
 633            .map(|state| state.can_reconnect())
 634            .unwrap_or(false);
 635        if !can_reconnect {
 636            let error = if let Some(state) = lock.as_ref() {
 637                format!("invalid state, cannot reconnect while in state {state}")
 638            } else {
 639                "no state set".to_string()
 640            };
 641            log::info!("aborting reconnect, because not in state that allows reconnecting");
 642            return Err(anyhow!(error));
 643        }
 644
 645        let state = lock.take().unwrap();
 646        let (attempts, ssh_connection, delegate) = match state {
 647            State::Connected {
 648                ssh_connection,
 649                delegate,
 650                multiplex_task,
 651                heartbeat_task,
 652            }
 653            | State::HeartbeatMissed {
 654                ssh_connection,
 655                delegate,
 656                multiplex_task,
 657                heartbeat_task,
 658                ..
 659            } => {
 660                drop(multiplex_task);
 661                drop(heartbeat_task);
 662                (0, ssh_connection, delegate)
 663            }
 664            State::ReconnectFailed {
 665                attempts,
 666                ssh_connection,
 667                delegate,
 668                ..
 669            } => (attempts, ssh_connection, delegate),
 670            State::Connecting
 671            | State::Reconnecting
 672            | State::ReconnectExhausted
 673            | State::ServerNotRunning => unreachable!(),
 674        };
 675
 676        let attempts = attempts + 1;
 677        if attempts > MAX_RECONNECT_ATTEMPTS {
 678            log::error!(
 679                "Failed to reconnect to after {} attempts, giving up",
 680                MAX_RECONNECT_ATTEMPTS
 681            );
 682            drop(lock);
 683            self.set_state(State::ReconnectExhausted, cx);
 684            return Ok(());
 685        }
 686        drop(lock);
 687
 688        self.set_state(State::Reconnecting, cx);
 689
 690        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 691
 692        let unique_identifier = self.unique_identifier.clone();
 693        let client = self.client.clone();
 694        let reconnect_task = cx.spawn(|this, mut cx| async move {
 695            macro_rules! failed {
 696                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 697                    return State::ReconnectFailed {
 698                        error: anyhow!($error),
 699                        attempts: $attempts,
 700                        ssh_connection: $ssh_connection,
 701                        delegate: $delegate,
 702                    };
 703                };
 704            }
 705
 706            if let Err(error) = ssh_connection
 707                .kill()
 708                .await
 709                .context("Failed to kill ssh process")
 710            {
 711                failed!(error, attempts, ssh_connection, delegate);
 712            };
 713
 714            let connection_options = ssh_connection.connection_options();
 715
 716            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 717            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 718            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 719
 720            let (ssh_connection, io_task) = match async {
 721                let ssh_connection = cx
 722                    .update_global(|pool: &mut ConnectionPool, cx| {
 723                        pool.connect(connection_options, &delegate, cx)
 724                    })?
 725                    .await
 726                    .map_err(|error| error.cloned())?;
 727
 728                let io_task = ssh_connection.start_proxy(
 729                    unique_identifier,
 730                    true,
 731                    incoming_tx,
 732                    outgoing_rx,
 733                    connection_activity_tx,
 734                    delegate.clone(),
 735                    &mut cx,
 736                );
 737                anyhow::Ok((ssh_connection, io_task))
 738            }
 739            .await
 740            {
 741                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 742                Err(error) => {
 743                    failed!(error, attempts, ssh_connection, delegate);
 744                }
 745            };
 746
 747            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 748            client.reconnect(incoming_rx, outgoing_tx, &cx);
 749
 750            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 751                failed!(error, attempts, ssh_connection, delegate);
 752            };
 753
 754            State::Connected {
 755                ssh_connection,
 756                delegate,
 757                multiplex_task,
 758                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 759            }
 760        });
 761
 762        cx.spawn(|this, mut cx| async move {
 763            let new_state = reconnect_task.await;
 764            this.update(&mut cx, |this, cx| {
 765                this.try_set_state(cx, |old_state| {
 766                    if old_state.is_reconnecting() {
 767                        match &new_state {
 768                            State::Connecting
 769                            | State::Reconnecting { .. }
 770                            | State::HeartbeatMissed { .. }
 771                            | State::ServerNotRunning => {}
 772                            State::Connected { .. } => {
 773                                log::info!("Successfully reconnected");
 774                            }
 775                            State::ReconnectFailed {
 776                                error, attempts, ..
 777                            } => {
 778                                log::error!(
 779                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 780                                    attempts,
 781                                    error
 782                                );
 783                            }
 784                            State::ReconnectExhausted => {
 785                                log::error!("Reconnect attempt failed and all attempts exhausted");
 786                            }
 787                        }
 788                        Some(new_state)
 789                    } else {
 790                        None
 791                    }
 792                });
 793
 794                if this.state_is(State::is_reconnect_failed) {
 795                    this.reconnect(cx)
 796                } else if this.state_is(State::is_reconnect_exhausted) {
 797                    Ok(())
 798                } else {
 799                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 800                    Ok(())
 801                }
 802            })
 803        })
 804        .detach_and_log_err(cx);
 805
 806        Ok(())
 807    }
 808
 809    fn heartbeat(
 810        this: WeakEntity<Self>,
 811        mut connection_activity_rx: mpsc::Receiver<()>,
 812        cx: &mut AsyncApp,
 813    ) -> Task<Result<()>> {
 814        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 815            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 816        };
 817
 818        cx.spawn(|mut cx| {
 819            let this = this.clone();
 820            async move {
 821                let mut missed_heartbeats = 0;
 822
 823                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 824                futures::pin_mut!(keepalive_timer);
 825
 826                loop {
 827                    select_biased! {
 828                        result = connection_activity_rx.next().fuse() => {
 829                            if result.is_none() {
 830                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 831                                return Ok(());
 832                            }
 833
 834                            if missed_heartbeats != 0 {
 835                                missed_heartbeats = 0;
 836                                this.update(&mut cx, |this, mut cx| {
 837                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 838                                })?;
 839                            }
 840                        }
 841                        _ = keepalive_timer => {
 842                            log::debug!("Sending heartbeat to server...");
 843
 844                            let result = select_biased! {
 845                                _ = connection_activity_rx.next().fuse() => {
 846                                    Ok(())
 847                                }
 848                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 849                                    ping_result
 850                                }
 851                            };
 852
 853                            if result.is_err() {
 854                                missed_heartbeats += 1;
 855                                log::warn!(
 856                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 857                                    HEARTBEAT_TIMEOUT,
 858                                    missed_heartbeats,
 859                                    MAX_MISSED_HEARTBEATS
 860                                );
 861                            } else if missed_heartbeats != 0 {
 862                                missed_heartbeats = 0;
 863                            } else {
 864                                continue;
 865                            }
 866
 867                            let result = this.update(&mut cx, |this, mut cx| {
 868                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 869                            })?;
 870                            if result.is_break() {
 871                                return Ok(());
 872                            }
 873                        }
 874                    }
 875
 876                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 877                }
 878            }
 879        })
 880    }
 881
 882    fn handle_heartbeat_result(
 883        &mut self,
 884        missed_heartbeats: usize,
 885        cx: &mut Context<Self>,
 886    ) -> ControlFlow<()> {
 887        let state = self.state.lock().take().unwrap();
 888        let next_state = if missed_heartbeats > 0 {
 889            state.heartbeat_missed()
 890        } else {
 891            state.heartbeat_recovered()
 892        };
 893
 894        self.set_state(next_state, cx);
 895
 896        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 897            log::error!(
 898                "Missed last {} heartbeats. Reconnecting...",
 899                missed_heartbeats
 900            );
 901
 902            self.reconnect(cx)
 903                .context("failed to start reconnect process after missing heartbeats")
 904                .log_err();
 905            ControlFlow::Break(())
 906        } else {
 907            ControlFlow::Continue(())
 908        }
 909    }
 910
 911    fn monitor(
 912        this: WeakEntity<Self>,
 913        io_task: Task<Result<i32>>,
 914        cx: &AsyncApp,
 915    ) -> Task<Result<()>> {
 916        cx.spawn(|mut cx| async move {
 917            let result = io_task.await;
 918
 919            match result {
 920                Ok(exit_code) => {
 921                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 922                        match error {
 923                            ProxyLaunchError::ServerNotRunning => {
 924                                log::error!("failed to reconnect because server is not running");
 925                                this.update(&mut cx, |this, cx| {
 926                                    this.set_state(State::ServerNotRunning, cx);
 927                                })?;
 928                            }
 929                        }
 930                    } else if exit_code > 0 {
 931                        log::error!("proxy process terminated unexpectedly");
 932                        this.update(&mut cx, |this, cx| {
 933                            this.reconnect(cx).ok();
 934                        })?;
 935                    }
 936                }
 937                Err(error) => {
 938                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 939                    this.update(&mut cx, |this, cx| {
 940                        this.reconnect(cx).ok();
 941                    })?;
 942                }
 943            }
 944
 945            Ok(())
 946        })
 947    }
 948
 949    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 950        self.state.lock().as_ref().map_or(false, check)
 951    }
 952
 953    fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
 954        let mut lock = self.state.lock();
 955        let new_state = lock.as_ref().and_then(map);
 956
 957        if let Some(new_state) = new_state {
 958            lock.replace(new_state);
 959            cx.notify();
 960        }
 961    }
 962
 963    fn set_state(&self, state: State, cx: &mut Context<Self>) {
 964        log::info!("setting state to '{}'", &state);
 965
 966        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 967        let is_server_not_running = state.is_server_not_running();
 968        self.state.lock().replace(state);
 969
 970        if is_reconnect_exhausted || is_server_not_running {
 971            cx.emit(SshRemoteEvent::Disconnected);
 972        }
 973        cx.notify();
 974    }
 975
 976    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
 977        self.client.subscribe_to_entity(remote_id, entity);
 978    }
 979
 980    pub fn ssh_args(&self) -> Option<Vec<String>> {
 981        self.state
 982            .lock()
 983            .as_ref()
 984            .and_then(|state| state.ssh_connection())
 985            .map(|ssh_connection| ssh_connection.ssh_args())
 986    }
 987
 988    pub fn upload_directory(
 989        &self,
 990        src_path: PathBuf,
 991        dest_path: PathBuf,
 992        cx: &App,
 993    ) -> Task<Result<()>> {
 994        let state = self.state.lock();
 995        let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
 996            return Task::ready(Err(anyhow!("no ssh connection")));
 997        };
 998        connection.upload_directory(src_path, dest_path, cx)
 999    }
1000
1001    pub fn proto_client(&self) -> AnyProtoClient {
1002        self.client.clone().into()
1003    }
1004
1005    pub fn connection_string(&self) -> String {
1006        self.connection_options.connection_string()
1007    }
1008
1009    pub fn connection_options(&self) -> SshConnectionOptions {
1010        self.connection_options.clone()
1011    }
1012
1013    pub fn connection_state(&self) -> ConnectionState {
1014        self.state
1015            .lock()
1016            .as_ref()
1017            .map(ConnectionState::from)
1018            .unwrap_or(ConnectionState::Disconnected)
1019    }
1020
1021    pub fn is_disconnected(&self) -> bool {
1022        self.connection_state() == ConnectionState::Disconnected
1023    }
1024
1025    #[cfg(any(test, feature = "test-support"))]
1026    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1027        let opts = self.connection_options();
1028        client_cx.spawn(|cx| async move {
1029            let connection = cx
1030                .update_global(|c: &mut ConnectionPool, _| {
1031                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1032                        c.clone()
1033                    } else {
1034                        panic!("missing test connection")
1035                    }
1036                })
1037                .unwrap()
1038                .await
1039                .unwrap();
1040
1041            connection.simulate_disconnect(&cx);
1042        })
1043    }
1044
1045    #[cfg(any(test, feature = "test-support"))]
1046    pub fn fake_server(
1047        client_cx: &mut gpui::TestAppContext,
1048        server_cx: &mut gpui::TestAppContext,
1049    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1050        let port = client_cx
1051            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1052        let opts = SshConnectionOptions {
1053            host: "<fake>".to_string(),
1054            port: Some(port),
1055            ..Default::default()
1056        };
1057        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1058        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1059        let server_client =
1060            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1061        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1062            connection_options: opts.clone(),
1063            server_cx: fake::SendableCx::new(server_cx),
1064            server_channel: server_client.clone(),
1065        });
1066
1067        client_cx.update(|cx| {
1068            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1069                c.connections.insert(
1070                    opts.clone(),
1071                    ConnectionPoolEntry::Connecting(
1072                        cx.background_executor()
1073                            .spawn({
1074                                let connection = connection.clone();
1075                                async move { Ok(connection.clone()) }
1076                            })
1077                            .shared(),
1078                    ),
1079                );
1080            })
1081        });
1082
1083        (opts, server_client)
1084    }
1085
1086    #[cfg(any(test, feature = "test-support"))]
1087    pub async fn fake_client(
1088        opts: SshConnectionOptions,
1089        client_cx: &mut gpui::TestAppContext,
1090    ) -> Entity<Self> {
1091        let (_tx, rx) = oneshot::channel();
1092        client_cx
1093            .update(|cx| {
1094                Self::new(
1095                    ConnectionIdentifier::setup(),
1096                    opts,
1097                    rx,
1098                    Arc::new(fake::Delegate),
1099                    cx,
1100                )
1101            })
1102            .await
1103            .unwrap()
1104            .unwrap()
1105    }
1106}
1107
/// One slot of the global `ConnectionPool`, keyed by `SshConnectionOptions`.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. The task is shared, so concurrent callers of
    /// `ConnectionPool::connect` await the same attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool alone does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
1112
/// App-global registry that lets clients with identical `SshConnectionOptions` share one
/// connection, or one in-flight connection attempt.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Registered as a gpui global so it can be reached via `update_global`/`default_global`.
impl Global for ConnectionPool {}
1119
1120impl ConnectionPool {
1121    pub fn connect(
1122        &mut self,
1123        opts: SshConnectionOptions,
1124        delegate: &Arc<dyn SshClientDelegate>,
1125        cx: &mut App,
1126    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1127        let connection = self.connections.get(&opts);
1128        match connection {
1129            Some(ConnectionPoolEntry::Connecting(task)) => {
1130                let delegate = delegate.clone();
1131                cx.spawn(|mut cx| async move {
1132                    delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
1133                })
1134                .detach();
1135                return task.clone();
1136            }
1137            Some(ConnectionPoolEntry::Connected(ssh)) => {
1138                if let Some(ssh) = ssh.upgrade() {
1139                    if !ssh.has_been_killed() {
1140                        return Task::ready(Ok(ssh)).shared();
1141                    }
1142                }
1143                self.connections.remove(&opts);
1144            }
1145            None => {}
1146        }
1147
1148        let task = cx
1149            .spawn({
1150                let opts = opts.clone();
1151                let delegate = delegate.clone();
1152                |mut cx| async move {
1153                    let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
1154                        .await
1155                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1156
1157                    cx.update_global(|pool: &mut Self, _| {
1158                        debug_assert!(matches!(
1159                            pool.connections.get(&opts),
1160                            Some(ConnectionPoolEntry::Connecting(_))
1161                        ));
1162                        match connection {
1163                            Ok(connection) => {
1164                                pool.connections.insert(
1165                                    opts.clone(),
1166                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1167                                );
1168                                Ok(connection)
1169                            }
1170                            Err(error) => {
1171                                pool.connections.remove(&opts);
1172                                Err(Arc::new(error))
1173                            }
1174                        }
1175                    })?
1176                }
1177            })
1178            .shared();
1179
1180        self.connections
1181            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1182        task
1183    }
1184}
1185
1186impl From<SshRemoteClient> for AnyProtoClient {
1187    fn from(client: SshRemoteClient) -> Self {
1188        AnyProtoClient::new(client.client.clone())
1189    }
1190}
1191
/// A connection to a remote host that can run proxy sessions; shared between clients through
/// `ConnectionPool`.
///
/// Implemented by `SshRemoteConnection` and, under test-support, by `fake::FakeRemoteConnection`.
/// `async_trait(?Send)` means the futures returned by the async methods need not be `Send`.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy for `unique_identifier` and pumps messages for it.
    ///
    /// Envelopes received from the remote go to `incoming_tx`, and envelopes from `outgoing_rx`
    /// are sent to the remote. `connection_activity_tx` is signalled on traffic. The ssh
    /// implementation passes `--reconnect` to the proxy when `reconnect` is true; presumably
    /// this resumes an existing session (confirm). The task resolves to the proxy's exit code.
    #[allow(clippy::too_many_arguments)]
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(&self, src_path: PathBuf, dest_path: PathBuf, cx: &App)
        -> Task<Result<()>>;
    /// Tears the connection down. Afterwards `has_been_killed` is expected to return `true`.
    async fn kill(&self) -> Result<()>;
    /// Whether the connection has been killed. `ConnectionPool::connect` never reuses a killed
    /// connection.
    fn has_been_killed(&self) -> bool;
    /// The ssh command-line arguments associated with this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was opened with.
    fn connection_options(&self) -> SshConnectionOptions;

    /// Test hook for simulating a dropped connection. It is a no-op unless overridden.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1215
/// `RemoteConnection` backed by an OpenSSH control-master process (`ControlMaster=yes`).
///
/// NOTE: field order determines drop order, because Rust drops fields in declaration order.
/// `master_process` is spawned with `kill_on_drop(true)` in `new`, and it is dropped before
/// `_temp_dir`, which holds the control socket and the askpass files. This looks intentional;
/// do not reorder the fields without checking that.
struct SshRemoteConnection {
    /// Control socket path and connection options, used to run further ssh/scp commands.
    socket: SshSocket,
    /// The `ssh -N` master process. `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Remote path of the server binary. Set by `new` from `ensure_server_binary`.
    remote_binary_path: Option<PathBuf>,
    /// Owns the temporary directory so it is deleted when the connection is dropped.
    _temp_dir: TempDir,
}
1222
1223#[async_trait(?Send)]
1224impl RemoteConnection for SshRemoteConnection {
1225    async fn kill(&self) -> Result<()> {
1226        let Some(mut process) = self.master_process.lock().take() else {
1227            return Ok(());
1228        };
1229        process.kill().ok();
1230        process.status().await?;
1231        Ok(())
1232    }
1233
1234    fn has_been_killed(&self) -> bool {
1235        self.master_process.lock().is_none()
1236    }
1237
1238    fn ssh_args(&self) -> Vec<String> {
1239        self.socket.ssh_args()
1240    }
1241
1242    fn connection_options(&self) -> SshConnectionOptions {
1243        self.socket.connection_options.clone()
1244    }
1245
1246    fn upload_directory(
1247        &self,
1248        src_path: PathBuf,
1249        dest_path: PathBuf,
1250        cx: &App,
1251    ) -> Task<Result<()>> {
1252        let mut command = util::command::new_smol_command("scp");
1253        let output = self
1254            .socket
1255            .ssh_options(&mut command)
1256            .args(
1257                self.socket
1258                    .connection_options
1259                    .port
1260                    .map(|port| vec!["-P".to_string(), port.to_string()])
1261                    .unwrap_or_default(),
1262            )
1263            .arg("-C")
1264            .arg("-r")
1265            .arg(&src_path)
1266            .arg(format!(
1267                "{}:{}",
1268                self.socket.connection_options.scp_url(),
1269                dest_path.display()
1270            ))
1271            .output();
1272
1273        cx.background_executor().spawn(async move {
1274            let output = output.await?;
1275
1276            if !output.status.success() {
1277                return Err(anyhow!(
1278                    "failed to upload directory {} -> {}: {}",
1279                    src_path.display(),
1280                    dest_path.display(),
1281                    String::from_utf8_lossy(&output.stderr)
1282                ));
1283            }
1284
1285            Ok(())
1286        })
1287    }
1288
1289    fn start_proxy(
1290        &self,
1291        unique_identifier: String,
1292        reconnect: bool,
1293        incoming_tx: UnboundedSender<Envelope>,
1294        outgoing_rx: UnboundedReceiver<Envelope>,
1295        connection_activity_tx: Sender<()>,
1296        delegate: Arc<dyn SshClientDelegate>,
1297        cx: &mut AsyncApp,
1298    ) -> Task<Result<i32>> {
1299        delegate.set_status(Some("Starting proxy"), cx);
1300
1301        let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1302            return Task::ready(Err(anyhow!("Remote binary path not set")));
1303        };
1304
1305        let mut start_proxy_command = shell_script!(
1306            "exec {binary_path} proxy --identifier {identifier}",
1307            binary_path = &remote_binary_path.to_string_lossy(),
1308            identifier = &unique_identifier,
1309        );
1310
1311        if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1312            start_proxy_command = format!(
1313                "RUST_LOG={} {}",
1314                shlex::try_quote(&rust_log).unwrap(),
1315                start_proxy_command
1316            )
1317        }
1318        if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1319            start_proxy_command = format!(
1320                "RUST_BACKTRACE={} {}",
1321                shlex::try_quote(&rust_backtrace).unwrap(),
1322                start_proxy_command
1323            )
1324        }
1325        if reconnect {
1326            start_proxy_command.push_str(" --reconnect");
1327        }
1328
1329        let ssh_proxy_process = match self
1330            .socket
1331            .ssh_command("sh", &["-c", &start_proxy_command])
1332            // IMPORTANT: we kill this process when we drop the task that uses it.
1333            .kill_on_drop(true)
1334            .spawn()
1335        {
1336            Ok(process) => process,
1337            Err(error) => {
1338                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
1339            }
1340        };
1341
1342        Self::multiplex(
1343            ssh_proxy_process,
1344            incoming_tx,
1345            outgoing_rx,
1346            connection_activity_tx,
1347            &cx,
1348        )
1349    }
1350}
1351
1352impl SshRemoteConnection {
1353    #[cfg(not(unix))]
1354    async fn new(
1355        _connection_options: SshConnectionOptions,
1356        _delegate: Arc<dyn SshClientDelegate>,
1357        _cx: &mut AsyncApp,
1358    ) -> Result<Self> {
1359        Err(anyhow!("ssh is not supported on this platform"))
1360    }
1361
1362    #[cfg(unix)]
1363    async fn new(
1364        connection_options: SshConnectionOptions,
1365        delegate: Arc<dyn SshClientDelegate>,
1366        cx: &mut AsyncApp,
1367    ) -> Result<Self> {
1368        use futures::AsyncWriteExt as _;
1369        use futures::{io::BufReader, AsyncBufReadExt as _};
1370        use smol::net::unix::UnixStream;
1371        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1372        use util::ResultExt as _;
1373
1374        delegate.set_status(Some("Connecting"), cx);
1375
1376        let url = connection_options.ssh_url();
1377        let temp_dir = tempfile::Builder::new()
1378            .prefix("zed-ssh-session")
1379            .tempdir()?;
1380
1381        // Create a domain socket listener to handle requests from the askpass program.
1382        let askpass_socket = temp_dir.path().join("askpass.sock");
1383        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1384        let listener =
1385            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1386
1387        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<UnixStream>();
1388        let mut kill_tx = Some(askpass_kill_master_tx);
1389
1390        let askpass_task = cx.spawn({
1391            let delegate = delegate.clone();
1392            |mut cx| async move {
1393                let mut askpass_opened_tx = Some(askpass_opened_tx);
1394
1395                while let Ok((mut stream, _)) = listener.accept().await {
1396                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1397                        askpass_opened_tx.send(()).ok();
1398                    }
1399                    let mut buffer = Vec::new();
1400                    let mut reader = BufReader::new(&mut stream);
1401                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1402                        buffer.clear();
1403                    }
1404                    let password_prompt = String::from_utf8_lossy(&buffer);
1405                    if let Some(password) = delegate
1406                        .ask_password(password_prompt.to_string(), &mut cx)
1407                        .await
1408                        .context("failed to get ssh password")
1409                        .and_then(|p| p)
1410                        .log_err()
1411                    {
1412                        stream.write_all(password.as_bytes()).await.log_err();
1413                    } else {
1414                        if let Some(kill_tx) = kill_tx.take() {
1415                            kill_tx.send(stream).log_err();
1416                            break;
1417                        }
1418                    }
1419                }
1420            }
1421        });
1422
1423        anyhow::ensure!(
1424            which::which("nc").is_ok(),
1425            "Cannot find `nc` command (netcat), which is required to connect over SSH."
1426        );
1427
1428        // Create an askpass script that communicates back to this process.
1429        let askpass_script = format!(
1430            "{shebang}\n{print_args} | {nc} -U {askpass_socket} 2> /dev/null \n",
1431            // on macOS `brew install netcat` provides the GNU netcat implementation
1432            // which does not support -U.
1433            nc = if cfg!(target_os = "macos") {
1434                "/usr/bin/nc"
1435            } else {
1436                "nc"
1437            },
1438            askpass_socket = askpass_socket.display(),
1439            print_args = "printf '%s\\0' \"$@\"",
1440            shebang = "#!/bin/sh",
1441        );
1442        let askpass_script_path = temp_dir.path().join("askpass.sh");
1443        fs::write(&askpass_script_path, askpass_script).await?;
1444        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1445
1446        // Start the master SSH process, which does not do anything except for establish
1447        // the connection and keep it open, allowing other ssh commands to reuse it
1448        // via a control socket.
1449        let socket_path = temp_dir.path().join("ssh.sock");
1450
1451        let mut master_process = process::Command::new("ssh")
1452            .stdin(Stdio::null())
1453            .stdout(Stdio::piped())
1454            .stderr(Stdio::piped())
1455            .env("SSH_ASKPASS_REQUIRE", "force")
1456            .env("SSH_ASKPASS", &askpass_script_path)
1457            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1458            .args([
1459                "-N",
1460                "-o",
1461                "ControlPersist=no",
1462                "-o",
1463                "ControlMaster=yes",
1464                "-o",
1465            ])
1466            .arg(format!("ControlPath={}", socket_path.display()))
1467            .arg(&url)
1468            .kill_on_drop(true)
1469            .spawn()?;
1470
1471        // Wait for this ssh process to close its stdout, indicating that authentication
1472        // has completed.
1473        let mut stdout = master_process.stdout.take().unwrap();
1474        let mut output = Vec::new();
1475        let connection_timeout = Duration::from_secs(10);
1476
1477        let result = select_biased! {
1478            _ = askpass_opened_rx.fuse() => {
1479                select_biased! {
1480                    stream = askpass_kill_master_rx.fuse() => {
1481                        master_process.kill().ok();
1482                        drop(stream);
1483                        Err(anyhow!("SSH connection canceled"))
1484                    }
1485                    // If the askpass script has opened, that means the user is typing
1486                    // their password, in which case we don't want to timeout anymore,
1487                    // since we know a connection has been established.
1488                    result = stdout.read_to_end(&mut output).fuse() => {
1489                        result?;
1490                        Ok(())
1491                    }
1492                }
1493            }
1494            _ = stdout.read_to_end(&mut output).fuse() => {
1495                Ok(())
1496            }
1497            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1498                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1499            }
1500        };
1501
1502        if let Err(e) = result {
1503            return Err(e.context("Failed to connect to host"));
1504        }
1505
1506        drop(askpass_task);
1507
1508        if master_process.try_status()?.is_some() {
1509            output.clear();
1510            let mut stderr = master_process.stderr.take().unwrap();
1511            stderr.read_to_end(&mut output).await?;
1512
1513            let error_message = format!(
1514                "failed to connect: {}",
1515                String::from_utf8_lossy(&output).trim()
1516            );
1517            Err(anyhow!(error_message))?;
1518        }
1519
1520        let socket = SshSocket {
1521            connection_options,
1522            socket_path,
1523        };
1524
1525        let mut this = Self {
1526            socket,
1527            master_process: Mutex::new(Some(master_process)),
1528            _temp_dir: temp_dir,
1529            remote_binary_path: None,
1530        };
1531
1532        let (release_channel, version, commit) = cx.update(|cx| {
1533            (
1534                ReleaseChannel::global(cx),
1535                AppVersion::global(cx),
1536                AppCommitSha::try_global(cx),
1537            )
1538        })?;
1539        this.remote_binary_path = Some(
1540            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
1541                .await?,
1542        );
1543
1544        Ok(this)
1545    }
1546
1547    async fn platform(&self) -> Result<SshPlatform> {
1548        let uname = self.socket.run_command("uname", &["-sm"]).await?;
1549        let Some((os, arch)) = uname.split_once(" ") else {
1550            Err(anyhow!("unknown uname: {uname:?}"))?
1551        };
1552
1553        let os = match os.trim() {
1554            "Darwin" => "macos",
1555            "Linux" => "linux",
1556            _ => Err(anyhow!(
1557                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
1558            ))?,
1559        };
1560        // exclude armv5,6,7 as they are 32-bit.
1561        let arch = if arch.starts_with("armv8")
1562            || arch.starts_with("armv9")
1563            || arch.starts_with("arm64")
1564            || arch.starts_with("aarch64")
1565        {
1566            "aarch64"
1567        } else if arch.starts_with("x86") {
1568            "x86_64"
1569        } else {
1570            Err(anyhow!(
1571                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
1572            ))?
1573        };
1574
1575        Ok(SshPlatform { os, arch })
1576    }
1577
1578    fn multiplex(
1579        mut ssh_proxy_process: Child,
1580        incoming_tx: UnboundedSender<Envelope>,
1581        mut outgoing_rx: UnboundedReceiver<Envelope>,
1582        mut connection_activity_tx: Sender<()>,
1583        cx: &AsyncApp,
1584    ) -> Task<Result<i32>> {
1585        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1586        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1587        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1588
1589        let mut stdin_buffer = Vec::new();
1590        let mut stdout_buffer = Vec::new();
1591        let mut stderr_buffer = Vec::new();
1592        let mut stderr_offset = 0;
1593
1594        let stdin_task = cx.background_executor().spawn(async move {
1595            while let Some(outgoing) = outgoing_rx.next().await {
1596                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1597            }
1598            anyhow::Ok(())
1599        });
1600
1601        let stdout_task = cx.background_executor().spawn({
1602            let mut connection_activity_tx = connection_activity_tx.clone();
1603            async move {
1604                loop {
1605                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1606                    let len = child_stdout.read(&mut stdout_buffer).await?;
1607
1608                    if len == 0 {
1609                        return anyhow::Ok(());
1610                    }
1611
1612                    if len < MESSAGE_LEN_SIZE {
1613                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1614                    }
1615
1616                    let message_len = message_len_from_buffer(&stdout_buffer);
1617                    let envelope =
1618                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1619                            .await?;
1620                    connection_activity_tx.try_send(()).ok();
1621                    incoming_tx.unbounded_send(envelope).ok();
1622                }
1623            }
1624        });
1625
1626        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
1627            loop {
1628                stderr_buffer.resize(stderr_offset + 1024, 0);
1629
1630                let len = child_stderr
1631                    .read(&mut stderr_buffer[stderr_offset..])
1632                    .await?;
1633                if len == 0 {
1634                    return anyhow::Ok(());
1635                }
1636
1637                stderr_offset += len;
1638                let mut start_ix = 0;
1639                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1640                    .iter()
1641                    .position(|b| b == &b'\n')
1642                {
1643                    let line_ix = start_ix + ix;
1644                    let content = &stderr_buffer[start_ix..line_ix];
1645                    start_ix = line_ix + 1;
1646                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1647                        record.log(log::logger())
1648                    } else {
1649                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1650                    }
1651                }
1652                stderr_buffer.drain(0..start_ix);
1653                stderr_offset -= start_ix;
1654
1655                connection_activity_tx.try_send(()).ok();
1656            }
1657        });
1658
1659        cx.spawn(|_| async move {
1660            let result = futures::select! {
1661                result = stdin_task.fuse() => {
1662                    result.context("stdin")
1663                }
1664                result = stdout_task.fuse() => {
1665                    result.context("stdout")
1666                }
1667                result = stderr_task.fuse() => {
1668                    result.context("stderr")
1669                }
1670            };
1671
1672            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1673            match result {
1674                Ok(_) => Ok(status),
1675                Err(error) => Err(error),
1676            }
1677        })
1678    }
1679
1680    #[allow(unused)]
1681    async fn ensure_server_binary(
1682        &self,
1683        delegate: &Arc<dyn SshClientDelegate>,
1684        release_channel: ReleaseChannel,
1685        version: SemanticVersion,
1686        commit: Option<AppCommitSha>,
1687        cx: &mut AsyncApp,
1688    ) -> Result<PathBuf> {
1689        let version_str = match release_channel {
1690            ReleaseChannel::Nightly => {
1691                let commit = commit.map(|s| s.0.to_string()).unwrap_or_default();
1692
1693                format!("{}-{}", version, commit)
1694            }
1695            ReleaseChannel::Dev => "build".to_string(),
1696            _ => version.to_string(),
1697        };
1698        let binary_name = format!(
1699            "zed-remote-server-{}-{}",
1700            release_channel.dev_name(),
1701            version_str
1702        );
1703        let dst_path = paths::remote_server_dir_relative().join(binary_name);
1704        let tmp_path_gz = PathBuf::from(format!(
1705            "{}-download-{}.gz",
1706            dst_path.to_string_lossy(),
1707            std::process::id()
1708        ));
1709
1710        #[cfg(debug_assertions)]
1711        if std::env::var("ZED_BUILD_REMOTE_SERVER").is_ok() {
1712            let src_path = self
1713                .build_local(self.platform().await?, delegate, cx)
1714                .await?;
1715            self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1716                .await?;
1717            self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1718                .await?;
1719            return Ok(dst_path);
1720        }
1721
1722        if self
1723            .socket
1724            .run_command(&dst_path.to_string_lossy(), &["version"])
1725            .await
1726            .is_ok()
1727        {
1728            return Ok(dst_path);
1729        }
1730
1731        let wanted_version = cx.update(|cx| match release_channel {
1732            ReleaseChannel::Nightly => Ok(None),
1733            ReleaseChannel::Dev => {
1734                anyhow::bail!(
1735                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1736                    dst_path
1737                )
1738            }
1739            _ => Ok(Some(AppVersion::global(cx))),
1740        })??;
1741
1742        let platform = self.platform().await?;
1743
1744        if !self.socket.connection_options.upload_binary_over_ssh {
1745            if let Some((url, body)) = delegate
1746                .get_download_params(platform, release_channel, wanted_version, cx)
1747                .await?
1748            {
1749                match self
1750                    .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1751                    .await
1752                {
1753                    Ok(_) => {
1754                        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1755                            .await?;
1756                        return Ok(dst_path);
1757                    }
1758                    Err(e) => {
1759                        log::error!(
1760                            "Failed to download binary on server, attempting to upload server: {}",
1761                            e
1762                        )
1763                    }
1764                }
1765            }
1766        }
1767
1768        let src_path = delegate
1769            .download_server_binary_locally(platform, release_channel, wanted_version, cx)
1770            .await?;
1771        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1772            .await?;
1773        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1774            .await?;
1775        return Ok(dst_path);
1776    }
1777
1778    async fn download_binary_on_server(
1779        &self,
1780        url: &str,
1781        body: &str,
1782        tmp_path_gz: &Path,
1783        delegate: &Arc<dyn SshClientDelegate>,
1784        cx: &mut AsyncApp,
1785    ) -> Result<()> {
1786        if let Some(parent) = tmp_path_gz.parent() {
1787            self.socket
1788                .run_command("mkdir", &["-p", &parent.to_string_lossy()])
1789                .await?;
1790        }
1791
1792        delegate.set_status(Some("Downloading remote development server on host"), cx);
1793
1794        match self
1795            .socket
1796            .run_command(
1797                "curl",
1798                &[
1799                    "-f",
1800                    "-L",
1801                    "-X",
1802                    "GET",
1803                    "-H",
1804                    "Content-Type: application/json",
1805                    "-d",
1806                    &body,
1807                    &url,
1808                    "-o",
1809                    &tmp_path_gz.to_string_lossy(),
1810                ],
1811            )
1812            .await
1813        {
1814            Ok(_) => {}
1815            Err(e) => {
1816                if self.socket.run_command("which", &["curl"]).await.is_ok() {
1817                    return Err(e);
1818                }
1819
1820                match self
1821                    .socket
1822                    .run_command(
1823                        "wget",
1824                        &[
1825                            "--max-redirect=5",
1826                            "--method=GET",
1827                            "--header=Content-Type: application/json",
1828                            "--body-data",
1829                            &body,
1830                            &url,
1831                            "-O",
1832                            &tmp_path_gz.to_string_lossy(),
1833                        ],
1834                    )
1835                    .await
1836                {
1837                    Ok(_) => {}
1838                    Err(e) => {
1839                        if self.socket.run_command("which", &["wget"]).await.is_ok() {
1840                            return Err(e);
1841                        } else {
1842                            anyhow::bail!("Neither curl nor wget is available");
1843                        }
1844                    }
1845                }
1846            }
1847        }
1848
1849        Ok(())
1850    }
1851
1852    async fn upload_local_server_binary(
1853        &self,
1854        src_path: &Path,
1855        tmp_path_gz: &Path,
1856        delegate: &Arc<dyn SshClientDelegate>,
1857        cx: &mut AsyncApp,
1858    ) -> Result<()> {
1859        if let Some(parent) = tmp_path_gz.parent() {
1860            self.socket
1861                .run_command("mkdir", &["-p", &parent.to_string_lossy()])
1862                .await?;
1863        }
1864
1865        let src_stat = fs::metadata(&src_path).await?;
1866        let size = src_stat.len();
1867
1868        let t0 = Instant::now();
1869        delegate.set_status(Some("Uploading remote development server"), cx);
1870        log::info!(
1871            "uploading remote development server to {:?} ({}kb)",
1872            tmp_path_gz,
1873            size / 1024
1874        );
1875        self.upload_file(&src_path, &tmp_path_gz)
1876            .await
1877            .context("failed to upload server binary")?;
1878        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1879        Ok(())
1880    }
1881
1882    async fn extract_server_binary(
1883        &self,
1884        dst_path: &Path,
1885        tmp_path_gz: &Path,
1886        delegate: &Arc<dyn SshClientDelegate>,
1887        cx: &mut AsyncApp,
1888    ) -> Result<()> {
1889        delegate.set_status(Some("Extracting remote development server"), cx);
1890        let server_mode = 0o755;
1891
1892        let script = shell_script!(
1893            "gunzip -f {tmp_path_gz} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
1894            tmp_path_gz = &tmp_path_gz.to_string_lossy(),
1895            tmp_path = &tmp_path_gz.to_string_lossy().strip_suffix(".gz").unwrap(),
1896            server_mode = &format!("{:o}", server_mode),
1897            dst_path = &dst_path.to_string_lossy()
1898        );
1899        self.socket.run_command("sh", &["-c", &script]).await?;
1900        Ok(())
1901    }
1902
1903    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1904        log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
1905        let mut command = util::command::new_smol_command("scp");
1906        let output = self
1907            .socket
1908            .ssh_options(&mut command)
1909            .args(
1910                self.socket
1911                    .connection_options
1912                    .port
1913                    .map(|port| vec!["-P".to_string(), port.to_string()])
1914                    .unwrap_or_default(),
1915            )
1916            .arg(src_path)
1917            .arg(format!(
1918                "{}:{}",
1919                self.socket.connection_options.scp_url(),
1920                dest_path.display()
1921            ))
1922            .output()
1923            .await?;
1924
1925        if output.status.success() {
1926            Ok(())
1927        } else {
1928            Err(anyhow!(
1929                "failed to upload file {} -> {}: {}",
1930                src_path.display(),
1931                dest_path.display(),
1932                String::from_utf8_lossy(&output.stderr)
1933            ))
1934        }
1935    }
1936
1937    #[cfg(debug_assertions)]
1938    async fn build_local(
1939        &self,
1940        platform: SshPlatform,
1941        delegate: &Arc<dyn SshClientDelegate>,
1942        cx: &mut AsyncApp,
1943    ) -> Result<PathBuf> {
1944        use smol::process::{Command, Stdio};
1945
1946        async fn run_cmd(command: &mut Command) -> Result<()> {
1947            let output = command
1948                .kill_on_drop(true)
1949                .stderr(Stdio::inherit())
1950                .output()
1951                .await?;
1952            if !output.status.success() {
1953                Err(anyhow!("Failed to run command: {:?}", command))?;
1954            }
1955            Ok(())
1956        }
1957
1958        if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
1959            delegate.set_status(Some("Building remote server binary from source"), cx);
1960            log::info!("building remote server binary from source");
1961            run_cmd(Command::new("cargo").args([
1962                "build",
1963                "--package",
1964                "remote_server",
1965                "--features",
1966                "debug-embed",
1967                "--target-dir",
1968                "target/remote_server",
1969            ]))
1970            .await?;
1971
1972            delegate.set_status(Some("Compressing binary"), cx);
1973
1974            run_cmd(Command::new("gzip").args([
1975                "-9",
1976                "-f",
1977                "target/remote_server/debug/remote_server",
1978            ]))
1979            .await?;
1980
1981            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
1982            return Ok(path);
1983        }
1984        let Some(triple) = platform.triple() else {
1985            anyhow::bail!("can't cross compile for: {:?}", platform);
1986        };
1987        smol::fs::create_dir_all("target/remote_server").await?;
1988
1989        delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
1990        log::info!("installing cross");
1991        run_cmd(Command::new("cargo").args([
1992            "install",
1993            "cross",
1994            "--git",
1995            "https://github.com/cross-rs/cross",
1996        ]))
1997        .await?;
1998
1999        delegate.set_status(
2000            Some(&format!(
2001                "Building remote server binary from source for {} with Docker",
2002                &triple
2003            )),
2004            cx,
2005        );
2006        log::info!("building remote server binary from source for {}", &triple);
2007        run_cmd(
2008            Command::new("cross")
2009                .args([
2010                    "build",
2011                    "--package",
2012                    "remote_server",
2013                    "--features",
2014                    "debug-embed",
2015                    "--target-dir",
2016                    "target/remote_server",
2017                    "--target",
2018                    &triple,
2019                ])
2020                .env(
2021                    "CROSS_CONTAINER_OPTS",
2022                    "--mount type=bind,src=./target,dst=/app/target",
2023                ),
2024        )
2025        .await?;
2026
2027        delegate.set_status(Some("Compressing binary"), cx);
2028
2029        run_cmd(Command::new("gzip").args([
2030            "-9",
2031            "-f",
2032            &format!("target/remote_server/{}/debug/remote_server", triple),
2033        ]))
2034        .await?;
2035
2036        let path = std::env::current_dir()?.join(format!(
2037            "target/remote_server/{}/debug/remote_server.gz",
2038            triple
2039        ));
2040
2041        return Ok(path);
2042    }
2043}
2044
/// In-flight requests, keyed by the id of the request envelope. Each sender is handed the
/// response envelope plus a `oneshot::Sender<()>`. The requester drops that sender once it has
/// taken the response, and the message loop waits for that before reading the next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// Protobuf RPC client that exchanges [`Envelope`]s over a pair of in-process channels.
/// Sent messages are buffered until the peer acknowledges them, so they can be replayed
/// (see `resync` and the `FlushBufferedMessages` handling).
///
/// NOTE: fields are dropped in declaration order; `task` is intentionally last.
pub struct ChannelClient {
    // Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    // Current outgoing channel; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Buffered outgoing envelopes, pruned once an incoming `ack_id` covers their id.
    buffer: Mutex<VecDeque<Envelope>>,
    // Response senders for requests still awaiting a reply.
    response_channels: ResponseChannels,
    // Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Id of the most recently received (non-flush) envelope; echoed as `ack_id` when sending.
    max_received: AtomicU32,
    // Prefix used in log messages.
    name: &'static str,
    // Incoming-message loop; `reconnect` replaces (and so drops) the previous one.
    task: Mutex<Task<Result<()>>>,
}
2057
2058impl ChannelClient {
2059    pub fn new(
2060        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2061        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2062        cx: &App,
2063        name: &'static str,
2064    ) -> Arc<Self> {
2065        Arc::new_cyclic(|this| Self {
2066            outgoing_tx: Mutex::new(outgoing_tx),
2067            next_message_id: AtomicU32::new(0),
2068            max_received: AtomicU32::new(0),
2069            response_channels: ResponseChannels::default(),
2070            message_handlers: Default::default(),
2071            buffer: Mutex::new(VecDeque::new()),
2072            name,
2073            task: Mutex::new(Self::start_handling_messages(
2074                this.clone(),
2075                incoming_rx,
2076                &cx.to_async(),
2077            )),
2078        })
2079    }
2080
2081    fn start_handling_messages(
2082        this: Weak<Self>,
2083        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2084        cx: &AsyncApp,
2085    ) -> Task<Result<()>> {
2086        cx.spawn(|cx| async move {
2087            let peer_id = PeerId { owner_id: 0, id: 0 };
2088            while let Some(incoming) = incoming_rx.next().await {
2089                let Some(this) = this.upgrade() else {
2090                    return anyhow::Ok(());
2091                };
2092                if let Some(ack_id) = incoming.ack_id {
2093                    let mut buffer = this.buffer.lock();
2094                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2095                        buffer.pop_front();
2096                    }
2097                }
2098                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2099                {
2100                    log::debug!(
2101                        "{}:ssh message received. name:FlushBufferedMessages",
2102                        this.name
2103                    );
2104                    {
2105                        let buffer = this.buffer.lock();
2106                        for envelope in buffer.iter() {
2107                            this.outgoing_tx
2108                                .lock()
2109                                .unbounded_send(envelope.clone())
2110                                .ok();
2111                        }
2112                    }
2113                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2114                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2115                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2116                    continue;
2117                }
2118
2119                this.max_received.store(incoming.id, SeqCst);
2120
2121                if let Some(request_id) = incoming.responding_to {
2122                    let request_id = MessageId(request_id);
2123                    let sender = this.response_channels.lock().remove(&request_id);
2124                    if let Some(sender) = sender {
2125                        let (tx, rx) = oneshot::channel();
2126                        if incoming.payload.is_some() {
2127                            sender.send((incoming, tx)).ok();
2128                        }
2129                        rx.await.ok();
2130                    }
2131                } else if let Some(envelope) =
2132                    build_typed_envelope(peer_id, Instant::now(), incoming)
2133                {
2134                    let type_name = envelope.payload_type_name();
2135                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2136                        &this.message_handlers,
2137                        envelope,
2138                        this.clone().into(),
2139                        cx.clone(),
2140                    ) {
2141                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2142                        cx.foreground_executor()
2143                            .spawn(async move {
2144                                match future.await {
2145                                    Ok(_) => {
2146                                        log::debug!(
2147                                            "{}:ssh message handled. name:{type_name}",
2148                                            this.name
2149                                        );
2150                                    }
2151                                    Err(error) => {
2152                                        log::error!(
2153                                            "{}:error handling message. type:{}, error:{}",
2154                                            this.name,
2155                                            type_name,
2156                                            format!("{error:#}").lines().fold(
2157                                                String::new(),
2158                                                |mut message, line| {
2159                                                    if !message.is_empty() {
2160                                                        message.push(' ');
2161                                                    }
2162                                                    message.push_str(line);
2163                                                    message
2164                                                }
2165                                            )
2166                                        );
2167                                    }
2168                                }
2169                            })
2170                            .detach()
2171                    } else {
2172                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2173                    }
2174                }
2175            }
2176            anyhow::Ok(())
2177        })
2178    }
2179
2180    pub fn reconnect(
2181        self: &Arc<Self>,
2182        incoming_rx: UnboundedReceiver<Envelope>,
2183        outgoing_tx: UnboundedSender<Envelope>,
2184        cx: &AsyncApp,
2185    ) {
2186        *self.outgoing_tx.lock() = outgoing_tx;
2187        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2188    }
2189
2190    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2191        let id = (TypeId::of::<E>(), remote_id);
2192
2193        let mut message_handlers = self.message_handlers.lock();
2194        if message_handlers
2195            .entities_by_type_and_remote_id
2196            .contains_key(&id)
2197        {
2198            panic!("already subscribed to entity");
2199        }
2200
2201        message_handlers.entities_by_type_and_remote_id.insert(
2202            id,
2203            EntityMessageSubscriber::Entity {
2204                handle: entity.downgrade().into(),
2205            },
2206        );
2207    }
2208
2209    pub fn request<T: RequestMessage>(
2210        &self,
2211        payload: T,
2212    ) -> impl 'static + Future<Output = Result<T::Response>> {
2213        self.request_internal(payload, true)
2214    }
2215
2216    fn request_internal<T: RequestMessage>(
2217        &self,
2218        payload: T,
2219        use_buffer: bool,
2220    ) -> impl 'static + Future<Output = Result<T::Response>> {
2221        log::debug!("ssh request start. name:{}", T::NAME);
2222        let response =
2223            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2224        async move {
2225            let response = response.await?;
2226            log::debug!("ssh request finish. name:{}", T::NAME);
2227            T::Response::from_envelope(response)
2228                .ok_or_else(|| anyhow!("received a response of the wrong type"))
2229        }
2230    }
2231
2232    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2233        smol::future::or(
2234            async {
2235                self.request_internal(proto::FlushBufferedMessages {}, false)
2236                    .await?;
2237
2238                for envelope in self.buffer.lock().iter() {
2239                    self.outgoing_tx
2240                        .lock()
2241                        .unbounded_send(envelope.clone())
2242                        .ok();
2243                }
2244                Ok(())
2245            },
2246            async {
2247                smol::Timer::after(timeout).await;
2248                Err(anyhow!("Timeout detected"))
2249            },
2250        )
2251        .await
2252    }
2253
2254    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2255        smol::future::or(
2256            async {
2257                self.request(proto::Ping {}).await?;
2258                Ok(())
2259            },
2260            async {
2261                smol::Timer::after(timeout).await;
2262                Err(anyhow!("Timeout detected"))
2263            },
2264        )
2265        .await
2266    }
2267
2268    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2269        log::debug!("ssh send name:{}", T::NAME);
2270        self.send_dynamic(payload.into_envelope(0, None, None))
2271    }
2272
2273    fn request_dynamic(
2274        &self,
2275        mut envelope: proto::Envelope,
2276        type_name: &'static str,
2277        use_buffer: bool,
2278    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2279        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2280        let (tx, rx) = oneshot::channel();
2281        let mut response_channels_lock = self.response_channels.lock();
2282        response_channels_lock.insert(MessageId(envelope.id), tx);
2283        drop(response_channels_lock);
2284
2285        let result = if use_buffer {
2286            self.send_buffered(envelope)
2287        } else {
2288            self.send_unbuffered(envelope)
2289        };
2290        async move {
2291            if let Err(error) = &result {
2292                log::error!("failed to send message: {}", error);
2293                return Err(anyhow!("failed to send message: {}", error));
2294            }
2295
2296            let response = rx.await.context("connection lost")?.0;
2297            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2298                return Err(RpcError::from_proto(error, type_name));
2299            }
2300            Ok(response)
2301        }
2302    }
2303
2304    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2305        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2306        self.send_buffered(envelope)
2307    }
2308
2309    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2310        envelope.ack_id = Some(self.max_received.load(SeqCst));
2311        self.buffer.lock().push_back(envelope.clone());
2312        // ignore errors on send (happen while we're reconnecting)
2313        // assume that the global "disconnected" overlay is sufficient.
2314        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2315        Ok(())
2316    }
2317
2318    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2319        envelope.ack_id = Some(self.max_received.load(SeqCst));
2320        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2321        Ok(())
2322    }
2323}
2324
2325impl ProtoClient for ChannelClient {
2326    fn request(
2327        &self,
2328        envelope: proto::Envelope,
2329        request_type: &'static str,
2330    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2331        self.request_dynamic(envelope, request_type, true).boxed()
2332    }
2333
2334    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2335        self.send_dynamic(envelope)
2336    }
2337
2338    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2339        self.send_dynamic(envelope)
2340    }
2341
2342    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2343        &self.message_handlers
2344    }
2345
2346    fn is_via_collab(&self) -> bool {
2347        false
2348    }
2349}
2350
/// Test-only, in-process stand-ins for the SSH connection and the client delegate.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{App, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, RemoteConnection, SshClientDelegate, SshConnectionOptions, SshPlatform,
    };

    /// A `RemoteConnection` that pipes the client's channels straight into an in-process
    /// server-side `ChannelClient` instead of an SSH process.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        // The "server" end; re-pointed at fresh channels by `start_proxy`/`simulate_disconnect`.
        pub(super) server_channel: Arc<ChannelClient>,
        // Context used when restarting the server channel's message loop.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an `AsyncApp` be stored in `FakeRemoteConnection` while claiming
    /// `Send + Sync` (see the unsafe impls below).
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        // NOTE(review): the argument is ignored; the guarantee rests on the caller only
        // being able to obtain an `AsyncApp` on the main thread — confirm against gpui.
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: PathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        // Reconnects the server channel to channels whose opposite ends are dropped right
        // away: its incoming stream ends immediately and its outgoing sends go nowhere.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
        }

        // Wires fresh server-side channels into `server_channel`, then spawns a task that
        // forwards envelopes in both directions. The task resolves to `Ok(1)` as soon as
        // either direction's stream ends.
        fn start_proxy(
            &self,

            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            // `select_biased!` polls server-to-client first. Only that direction signals
            // `connection_activity_tx`; send failures are ignored.
            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// Delegate for tests: `set_status` is a no-op and every other method is `unreachable!`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: &mut AsyncApp) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}