remote_client.rs

   1use crate::{
   2    SshConnectionOptions,
   3    protocol::MessageId,
   4    proxy::ProxyLaunchError,
   5    transport::{
   6        ssh::SshRemoteConnection,
   7        wsl::{WslConnectionOptions, WslRemoteConnection},
   8    },
   9};
  10use anyhow::{Context as _, Result, anyhow};
  11use askpass::EncryptedPassword;
  12use async_trait::async_trait;
  13use collections::HashMap;
  14use futures::{
  15    Future, FutureExt as _, StreamExt as _,
  16    channel::{
  17        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  18        oneshot,
  19    },
  20    future::{BoxFuture, Shared},
  21    select, select_biased,
  22};
  23use gpui::{
  24    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  25    EventEmitter, FutureExt, Global, SemanticVersion, Task, WeakEntity,
  26};
  27use parking_lot::Mutex;
  28
  29use release_channel::ReleaseChannel;
  30use rpc::{
  31    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  32    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  33};
  34use std::{
  35    collections::VecDeque,
  36    fmt,
  37    ops::ControlFlow,
  38    path::PathBuf,
  39    sync::{
  40        Arc, Weak,
  41        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  42    },
  43    time::{Duration, Instant},
  44};
  45use util::{
  46    ResultExt,
  47    paths::{PathStyle, RemotePathBuf},
  48};
  49
  50#[derive(Copy, Clone, Debug)]
  51pub struct RemotePlatform {
  52    pub os: &'static str,
  53    pub arch: &'static str,
  54}
  55
  56#[derive(Clone, Debug)]
  57pub struct CommandTemplate {
  58    pub program: String,
  59    pub args: Vec<String>,
  60    pub env: HashMap<String, String>,
  61}
  62
  63pub trait RemoteClientDelegate: Send + Sync {
  64    fn ask_password(
  65        &self,
  66        prompt: String,
  67        tx: oneshot::Sender<EncryptedPassword>,
  68        cx: &mut AsyncApp,
  69    );
  70    fn get_download_params(
  71        &self,
  72        platform: RemotePlatform,
  73        release_channel: ReleaseChannel,
  74        version: Option<SemanticVersion>,
  75        cx: &mut AsyncApp,
  76    ) -> Task<Result<Option<(String, String)>>>;
  77    fn download_server_binary_locally(
  78        &self,
  79        platform: RemotePlatform,
  80        release_channel: ReleaseChannel,
  81        version: Option<SemanticVersion>,
  82        cx: &mut AsyncApp,
  83    ) -> Task<Result<PathBuf>>;
  84    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
  85}
  86
  87const MAX_MISSED_HEARTBEATS: usize = 5;
  88const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
  89const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
  90
  91const MAX_RECONNECT_ATTEMPTS: usize = 3;
  92
  93enum State {
  94    Connecting,
  95    Connected {
  96        ssh_connection: Arc<dyn RemoteConnection>,
  97        delegate: Arc<dyn RemoteClientDelegate>,
  98
  99        multiplex_task: Task<Result<()>>,
 100        heartbeat_task: Task<Result<()>>,
 101    },
 102    HeartbeatMissed {
 103        missed_heartbeats: usize,
 104
 105        ssh_connection: Arc<dyn RemoteConnection>,
 106        delegate: Arc<dyn RemoteClientDelegate>,
 107
 108        multiplex_task: Task<Result<()>>,
 109        heartbeat_task: Task<Result<()>>,
 110    },
 111    Reconnecting,
 112    ReconnectFailed {
 113        ssh_connection: Arc<dyn RemoteConnection>,
 114        delegate: Arc<dyn RemoteClientDelegate>,
 115
 116        error: anyhow::Error,
 117        attempts: usize,
 118    },
 119    ReconnectExhausted,
 120    ServerNotRunning,
 121}
 122
 123impl fmt::Display for State {
 124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 125        match self {
 126            Self::Connecting => write!(f, "connecting"),
 127            Self::Connected { .. } => write!(f, "connected"),
 128            Self::Reconnecting => write!(f, "reconnecting"),
 129            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 130            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 131            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 132            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 133        }
 134    }
 135}
 136
 137impl State {
 138    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 139        match self {
 140            Self::Connected { ssh_connection, .. } => Some(ssh_connection.clone()),
 141            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
 142            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
 143            _ => None,
 144        }
 145    }
 146
 147    fn can_reconnect(&self) -> bool {
 148        match self {
 149            Self::Connected { .. }
 150            | Self::HeartbeatMissed { .. }
 151            | Self::ReconnectFailed { .. } => true,
 152            State::Connecting
 153            | State::Reconnecting
 154            | State::ReconnectExhausted
 155            | State::ServerNotRunning => false,
 156        }
 157    }
 158
 159    fn is_reconnect_failed(&self) -> bool {
 160        matches!(self, Self::ReconnectFailed { .. })
 161    }
 162
 163    fn is_reconnect_exhausted(&self) -> bool {
 164        matches!(self, Self::ReconnectExhausted { .. })
 165    }
 166
 167    fn is_server_not_running(&self) -> bool {
 168        matches!(self, Self::ServerNotRunning)
 169    }
 170
 171    fn is_reconnecting(&self) -> bool {
 172        matches!(self, Self::Reconnecting { .. })
 173    }
 174
 175    fn heartbeat_recovered(self) -> Self {
 176        match self {
 177            Self::HeartbeatMissed {
 178                ssh_connection,
 179                delegate,
 180                multiplex_task,
 181                heartbeat_task,
 182                ..
 183            } => Self::Connected {
 184                ssh_connection,
 185                delegate,
 186                multiplex_task,
 187                heartbeat_task,
 188            },
 189            _ => self,
 190        }
 191    }
 192
 193    fn heartbeat_missed(self) -> Self {
 194        match self {
 195            Self::Connected {
 196                ssh_connection,
 197                delegate,
 198                multiplex_task,
 199                heartbeat_task,
 200            } => Self::HeartbeatMissed {
 201                missed_heartbeats: 1,
 202                ssh_connection,
 203                delegate,
 204                multiplex_task,
 205                heartbeat_task,
 206            },
 207            Self::HeartbeatMissed {
 208                missed_heartbeats,
 209                ssh_connection,
 210                delegate,
 211                multiplex_task,
 212                heartbeat_task,
 213            } => Self::HeartbeatMissed {
 214                missed_heartbeats: missed_heartbeats + 1,
 215                ssh_connection,
 216                delegate,
 217                multiplex_task,
 218                heartbeat_task,
 219            },
 220            _ => self,
 221        }
 222    }
 223}
 224
 225/// The state of the ssh connection.
 226#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 227pub enum ConnectionState {
 228    Connecting,
 229    Connected,
 230    HeartbeatMissed,
 231    Reconnecting,
 232    Disconnected,
 233}
 234
 235impl From<&State> for ConnectionState {
 236    fn from(value: &State) -> Self {
 237        match value {
 238            State::Connecting => Self::Connecting,
 239            State::Connected { .. } => Self::Connected,
 240            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 241            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 242            State::ReconnectExhausted => Self::Disconnected,
 243            State::ServerNotRunning => Self::Disconnected,
 244        }
 245    }
 246}
 247
 248pub struct RemoteClient {
 249    client: Arc<ChannelClient>,
 250    unique_identifier: String,
 251    connection_options: RemoteConnectionOptions,
 252    path_style: PathStyle,
 253    state: Option<State>,
 254}
 255
 256#[derive(Debug)]
 257pub enum RemoteClientEvent {
 258    Disconnected,
 259}
 260
 261impl EventEmitter<RemoteClientEvent> for RemoteClient {}
 262
 263// Identifies the socket on the remote server so that reconnects
 264// can re-join the same project.
 265pub enum ConnectionIdentifier {
 266    Setup(u64),
 267    Workspace(i64),
 268}
 269
 270static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 271
 272impl ConnectionIdentifier {
 273    pub fn setup() -> Self {
 274        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 275    }
 276
 277    // This string gets used in a socket name, and so must be relatively short.
 278    // The total length of:
 279    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 280    // Must be less than about 100 characters
 281    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 282    // So our strings should be at most 20 characters or so.
 283    fn to_string(&self, cx: &App) -> String {
 284        let identifier_prefix = match ReleaseChannel::global(cx) {
 285            ReleaseChannel::Stable => "".to_string(),
 286            release_channel => format!("{}-", release_channel.dev_name()),
 287        };
 288        match self {
 289            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 290            Self::Workspace(workspace_id) => {
 291                format!("{identifier_prefix}workspace-{workspace_id}",)
 292            }
 293        }
 294    }
 295}
 296
 297impl RemoteClient {
 298    pub fn ssh(
 299        unique_identifier: ConnectionIdentifier,
 300        connection_options: SshConnectionOptions,
 301        cancellation: oneshot::Receiver<()>,
 302        delegate: Arc<dyn RemoteClientDelegate>,
 303        cx: &mut App,
 304    ) -> Task<Result<Option<Entity<Self>>>> {
 305        Self::new(
 306            unique_identifier,
 307            RemoteConnectionOptions::Ssh(connection_options),
 308            cancellation,
 309            delegate,
 310            cx,
 311        )
 312    }
 313
 314    pub fn new(
 315        unique_identifier: ConnectionIdentifier,
 316        connection_options: RemoteConnectionOptions,
 317        cancellation: oneshot::Receiver<()>,
 318        delegate: Arc<dyn RemoteClientDelegate>,
 319        cx: &mut App,
 320    ) -> Task<Result<Option<Entity<Self>>>> {
 321        let unique_identifier = unique_identifier.to_string(cx);
 322        cx.spawn(async move |cx| {
 323            let success = Box::pin(async move {
 324                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 325                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 326                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 327
 328                let client =
 329                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
 330
 331                let ssh_connection = cx
 332                    .update(|cx| {
 333                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
 334                            pool.connect(connection_options.clone(), &delegate, cx)
 335                        })
 336                    })?
 337                    .await
 338                    .map_err(|e| e.cloned())?;
 339
 340                let path_style = ssh_connection.path_style();
 341                let this = cx.new(|_| Self {
 342                    client: client.clone(),
 343                    unique_identifier: unique_identifier.clone(),
 344                    connection_options,
 345                    path_style,
 346                    state: Some(State::Connecting),
 347                })?;
 348
 349                let io_task = ssh_connection.start_proxy(
 350                    unique_identifier,
 351                    false,
 352                    incoming_tx,
 353                    outgoing_rx,
 354                    connection_activity_tx,
 355                    delegate.clone(),
 356                    cx,
 357                );
 358
 359                let ready = client
 360                    .wait_for_remote_started()
 361                    .with_timeout(HEARTBEAT_TIMEOUT, cx.background_executor())
 362                    .await;
 363                match ready {
 364                    Ok(Some(_)) => {}
 365                    Ok(None) => {
 366                        let mut error = "remote client exited before becoming ready".to_owned();
 367                        if let Some(status) = io_task.now_or_never() {
 368                            match status {
 369                                Ok(exit_code) => {
 370                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 371                                }
 372                                Err(e) => error.push_str(&format!(", error={e:?}")),
 373                            }
 374                        }
 375                        let error = anyhow::anyhow!("{error}");
 376                        log::error!("failed to establish connection: {}", error);
 377                        return Err(error);
 378                    }
 379                    Err(_) => {
 380                        let mut error =
 381                            "remote client did not become ready within the timeout".to_owned();
 382                        if let Some(status) = io_task.now_or_never() {
 383                            match status {
 384                                Ok(exit_code) => {
 385                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 386                                }
 387                                Err(e) => error.push_str(&format!(", error={e:?}")),
 388                            }
 389                        }
 390                        let error = anyhow::anyhow!("{error}");
 391                        log::error!("failed to establish connection: {}", error);
 392                        return Err(error);
 393                    }
 394                }
 395                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
 396                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 397                    log::error!("failed to establish connection: {}", error);
 398                    return Err(error);
 399                }
 400
 401                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
 402
 403                this.update(cx, |this, _| {
 404                    this.state = Some(State::Connected {
 405                        ssh_connection,
 406                        delegate,
 407                        multiplex_task,
 408                        heartbeat_task,
 409                    });
 410                })?;
 411
 412                Ok(Some(this))
 413            });
 414
 415            select! {
 416                _ = cancellation.fuse() => {
 417                    Ok(None)
 418                }
 419                result = success.fuse() =>  result
 420            }
 421        })
 422    }
 423
 424    pub fn proto_client_from_channels(
 425        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 426        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 427        cx: &App,
 428        name: &'static str,
 429    ) -> AnyProtoClient {
 430        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 431    }
 432
 433    pub fn shutdown_processes<T: RequestMessage>(
 434        &mut self,
 435        shutdown_request: Option<T>,
 436        executor: BackgroundExecutor,
 437    ) -> Option<impl Future<Output = ()> + use<T>> {
 438        let state = self.state.take()?;
 439        log::info!("shutting down ssh processes");
 440
 441        let State::Connected {
 442            multiplex_task,
 443            heartbeat_task,
 444            ssh_connection,
 445            delegate,
 446        } = state
 447        else {
 448            return None;
 449        };
 450
 451        let client = self.client.clone();
 452
 453        Some(async move {
 454            if let Some(shutdown_request) = shutdown_request {
 455                client.send(shutdown_request).log_err();
 456                // We wait 50ms instead of waiting for a response, because
 457                // waiting for a response would require us to wait on the main thread
 458                // which we want to avoid in an `on_app_quit` callback.
 459                executor.timer(Duration::from_millis(50)).await;
 460            }
 461
 462            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 463            // child of master_process.
 464            drop(multiplex_task);
 465            // Now drop the rest of state, which kills master process.
 466            drop(heartbeat_task);
 467            drop(ssh_connection);
 468            drop(delegate);
 469        })
 470    }
 471
 472    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 473        let can_reconnect = self
 474            .state
 475            .as_ref()
 476            .map(|state| state.can_reconnect())
 477            .unwrap_or(false);
 478        if !can_reconnect {
 479            log::info!("aborting reconnect, because not in state that allows reconnecting");
 480            let error = if let Some(state) = self.state.as_ref() {
 481                format!("invalid state, cannot reconnect while in state {state}")
 482            } else {
 483                "no state set".to_string()
 484            };
 485            anyhow::bail!(error);
 486        }
 487
 488        let state = self.state.take().unwrap();
 489        let (attempts, remote_connection, delegate) = match state {
 490            State::Connected {
 491                ssh_connection,
 492                delegate,
 493                multiplex_task,
 494                heartbeat_task,
 495            }
 496            | State::HeartbeatMissed {
 497                ssh_connection,
 498                delegate,
 499                multiplex_task,
 500                heartbeat_task,
 501                ..
 502            } => {
 503                drop(multiplex_task);
 504                drop(heartbeat_task);
 505                (0, ssh_connection, delegate)
 506            }
 507            State::ReconnectFailed {
 508                attempts,
 509                ssh_connection,
 510                delegate,
 511                ..
 512            } => (attempts, ssh_connection, delegate),
 513            State::Connecting
 514            | State::Reconnecting
 515            | State::ReconnectExhausted
 516            | State::ServerNotRunning => unreachable!(),
 517        };
 518
 519        let attempts = attempts + 1;
 520        if attempts > MAX_RECONNECT_ATTEMPTS {
 521            log::error!(
 522                "Failed to reconnect to after {} attempts, giving up",
 523                MAX_RECONNECT_ATTEMPTS
 524            );
 525            self.set_state(State::ReconnectExhausted, cx);
 526            return Ok(());
 527        }
 528
 529        self.set_state(State::Reconnecting, cx);
 530
 531        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 532
 533        let unique_identifier = self.unique_identifier.clone();
 534        let client = self.client.clone();
 535        let reconnect_task = cx.spawn(async move |this, cx| {
 536            macro_rules! failed {
 537                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 538                    return State::ReconnectFailed {
 539                        error: anyhow!($error),
 540                        attempts: $attempts,
 541                        ssh_connection: $ssh_connection,
 542                        delegate: $delegate,
 543                    };
 544                };
 545            }
 546
 547            if let Err(error) = remote_connection
 548                .kill()
 549                .await
 550                .context("Failed to kill ssh process")
 551            {
 552                failed!(error, attempts, remote_connection, delegate);
 553            };
 554
 555            let connection_options = remote_connection.connection_options();
 556
 557            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 558            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 559            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 560
 561            let (ssh_connection, io_task) = match async {
 562                let ssh_connection = cx
 563                    .update_global(|pool: &mut ConnectionPool, cx| {
 564                        pool.connect(connection_options, &delegate, cx)
 565                    })?
 566                    .await
 567                    .map_err(|error| error.cloned())?;
 568
 569                let io_task = ssh_connection.start_proxy(
 570                    unique_identifier,
 571                    true,
 572                    incoming_tx,
 573                    outgoing_rx,
 574                    connection_activity_tx,
 575                    delegate.clone(),
 576                    cx,
 577                );
 578                anyhow::Ok((ssh_connection, io_task))
 579            }
 580            .await
 581            {
 582                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 583                Err(error) => {
 584                    failed!(error, attempts, remote_connection, delegate);
 585                }
 586            };
 587
 588            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 589            client.reconnect(incoming_rx, outgoing_tx, cx);
 590
 591            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 592                failed!(error, attempts, ssh_connection, delegate);
 593            };
 594
 595            State::Connected {
 596                ssh_connection,
 597                delegate,
 598                multiplex_task,
 599                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 600            }
 601        });
 602
 603        cx.spawn(async move |this, cx| {
 604            let new_state = reconnect_task.await;
 605            this.update(cx, |this, cx| {
 606                this.try_set_state(cx, |old_state| {
 607                    if old_state.is_reconnecting() {
 608                        match &new_state {
 609                            State::Connecting
 610                            | State::Reconnecting
 611                            | State::HeartbeatMissed { .. }
 612                            | State::ServerNotRunning => {}
 613                            State::Connected { .. } => {
 614                                log::info!("Successfully reconnected");
 615                            }
 616                            State::ReconnectFailed {
 617                                error, attempts, ..
 618                            } => {
 619                                log::error!(
 620                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 621                                    attempts,
 622                                    error
 623                                );
 624                            }
 625                            State::ReconnectExhausted => {
 626                                log::error!("Reconnect attempt failed and all attempts exhausted");
 627                            }
 628                        }
 629                        Some(new_state)
 630                    } else {
 631                        None
 632                    }
 633                });
 634
 635                if this.state_is(State::is_reconnect_failed) {
 636                    this.reconnect(cx)
 637                } else if this.state_is(State::is_reconnect_exhausted) {
 638                    Ok(())
 639                } else {
 640                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 641                    Ok(())
 642                }
 643            })
 644        })
 645        .detach_and_log_err(cx);
 646
 647        Ok(())
 648    }
 649
 650    fn heartbeat(
 651        this: WeakEntity<Self>,
 652        mut connection_activity_rx: mpsc::Receiver<()>,
 653        cx: &mut AsyncApp,
 654    ) -> Task<Result<()>> {
 655        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 656            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 657        };
 658
 659        cx.spawn(async move |cx| {
 660            let mut missed_heartbeats = 0;
 661
 662            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 663            futures::pin_mut!(keepalive_timer);
 664
 665            loop {
 666                select_biased! {
 667                    result = connection_activity_rx.next().fuse() => {
 668                        if result.is_none() {
 669                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 670                            return Ok(());
 671                        }
 672
 673                        if missed_heartbeats != 0 {
 674                            missed_heartbeats = 0;
 675                            let _ =this.update(cx, |this, cx| {
 676                                this.handle_heartbeat_result(missed_heartbeats, cx)
 677                            })?;
 678                        }
 679                    }
 680                    _ = keepalive_timer => {
 681                        log::debug!("Sending heartbeat to server...");
 682
 683                        let result = select_biased! {
 684                            _ = connection_activity_rx.next().fuse() => {
 685                                Ok(())
 686                            }
 687                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 688                                ping_result
 689                            }
 690                        };
 691
 692                        if result.is_err() {
 693                            missed_heartbeats += 1;
 694                            log::warn!(
 695                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 696                                HEARTBEAT_TIMEOUT,
 697                                missed_heartbeats,
 698                                MAX_MISSED_HEARTBEATS
 699                            );
 700                        } else if missed_heartbeats != 0 {
 701                            missed_heartbeats = 0;
 702                        } else {
 703                            continue;
 704                        }
 705
 706                        let result = this.update(cx, |this, cx| {
 707                            this.handle_heartbeat_result(missed_heartbeats, cx)
 708                        })?;
 709                        if result.is_break() {
 710                            return Ok(());
 711                        }
 712                    }
 713                }
 714
 715                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 716            }
 717        })
 718    }
 719
 720    fn handle_heartbeat_result(
 721        &mut self,
 722        missed_heartbeats: usize,
 723        cx: &mut Context<Self>,
 724    ) -> ControlFlow<()> {
 725        let state = self.state.take().unwrap();
 726        let next_state = if missed_heartbeats > 0 {
 727            state.heartbeat_missed()
 728        } else {
 729            state.heartbeat_recovered()
 730        };
 731
 732        self.set_state(next_state, cx);
 733
 734        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 735            log::error!(
 736                "Missed last {} heartbeats. Reconnecting...",
 737                missed_heartbeats
 738            );
 739
 740            self.reconnect(cx)
 741                .context("failed to start reconnect process after missing heartbeats")
 742                .log_err();
 743            ControlFlow::Break(())
 744        } else {
 745            ControlFlow::Continue(())
 746        }
 747    }
 748
 749    fn monitor(
 750        this: WeakEntity<Self>,
 751        io_task: Task<Result<i32>>,
 752        cx: &AsyncApp,
 753    ) -> Task<Result<()>> {
 754        cx.spawn(async move |cx| {
 755            let result = io_task.await;
 756
 757            match result {
 758                Ok(exit_code) => {
 759                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 760                        match error {
 761                            ProxyLaunchError::ServerNotRunning => {
 762                                log::error!("failed to reconnect because server is not running");
 763                                this.update(cx, |this, cx| {
 764                                    this.set_state(State::ServerNotRunning, cx);
 765                                })?;
 766                            }
 767                        }
 768                    } else if exit_code > 0 {
 769                        log::error!("proxy process terminated unexpectedly");
 770                        this.update(cx, |this, cx| {
 771                            this.reconnect(cx).ok();
 772                        })?;
 773                    }
 774                }
 775                Err(error) => {
 776                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 777                    this.update(cx, |this, cx| {
 778                        this.reconnect(cx).ok();
 779                    })?;
 780                }
 781            }
 782
 783            Ok(())
 784        })
 785    }
 786
 787    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 788        self.state.as_ref().is_some_and(check)
 789    }
 790
 791    fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
 792        let new_state = self.state.as_ref().and_then(map);
 793        if let Some(new_state) = new_state {
 794            self.state.replace(new_state);
 795            cx.notify();
 796        }
 797    }
 798
 799    fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
 800        log::info!("setting state to '{}'", &state);
 801
 802        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 803        let is_server_not_running = state.is_server_not_running();
 804        self.state.replace(state);
 805
 806        if is_reconnect_exhausted || is_server_not_running {
 807            cx.emit(RemoteClientEvent::Disconnected);
 808        }
 809        cx.notify();
 810    }
 811
 812    pub fn shell(&self) -> Option<String> {
 813        Some(self.remote_connection()?.shell())
 814    }
 815
 816    pub fn default_system_shell(&self) -> Option<String> {
 817        Some(self.remote_connection()?.default_system_shell())
 818    }
 819
 820    pub fn shares_network_interface(&self) -> bool {
 821        self.remote_connection()
 822            .map_or(false, |connection| connection.shares_network_interface())
 823    }
 824
 825    pub fn build_command(
 826        &self,
 827        program: Option<String>,
 828        args: &[String],
 829        env: &HashMap<String, String>,
 830        working_dir: Option<String>,
 831        port_forward: Option<(u16, String, u16)>,
 832    ) -> Result<CommandTemplate> {
 833        let Some(connection) = self.remote_connection() else {
 834            return Err(anyhow!("no ssh connection"));
 835        };
 836        connection.build_command(program, args, env, working_dir, port_forward)
 837    }
 838
 839    pub fn build_forward_port_command(
 840        &self,
 841        local_port: u16,
 842        host: String,
 843        remote_port: u16,
 844    ) -> Result<CommandTemplate> {
 845        let Some(connection) = self.remote_connection() else {
 846            return Err(anyhow!("no ssh connection"));
 847        };
 848        connection.build_forward_port_command(local_port, host, remote_port)
 849    }
 850
 851    pub fn upload_directory(
 852        &self,
 853        src_path: PathBuf,
 854        dest_path: RemotePathBuf,
 855        cx: &App,
 856    ) -> Task<Result<()>> {
 857        let Some(connection) = self.remote_connection() else {
 858            return Task::ready(Err(anyhow!("no ssh connection")));
 859        };
 860        connection.upload_directory(src_path, dest_path, cx)
 861    }
 862
 863    pub fn proto_client(&self) -> AnyProtoClient {
 864        self.client.clone().into()
 865    }
 866
 867    pub fn connection_options(&self) -> RemoteConnectionOptions {
 868        self.connection_options.clone()
 869    }
 870
 871    pub fn connection_state(&self) -> ConnectionState {
 872        self.state
 873            .as_ref()
 874            .map(ConnectionState::from)
 875            .unwrap_or(ConnectionState::Disconnected)
 876    }
 877
 878    pub fn is_disconnected(&self) -> bool {
 879        self.connection_state() == ConnectionState::Disconnected
 880    }
 881
 882    pub fn path_style(&self) -> PathStyle {
 883        self.path_style
 884    }
 885
 886    #[cfg(any(test, feature = "test-support"))]
 887    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
 888        let opts = self.connection_options();
 889        client_cx.spawn(async move |cx| {
 890            let connection = cx
 891                .update_global(|c: &mut ConnectionPool, _| {
 892                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
 893                        c.clone()
 894                    } else {
 895                        panic!("missing test connection")
 896                    }
 897                })
 898                .unwrap()
 899                .await
 900                .unwrap();
 901
 902            connection.simulate_disconnect(cx);
 903        })
 904    }
 905
 906    #[cfg(any(test, feature = "test-support"))]
 907    pub fn fake_server(
 908        client_cx: &mut gpui::TestAppContext,
 909        server_cx: &mut gpui::TestAppContext,
 910    ) -> (RemoteConnectionOptions, AnyProtoClient) {
 911        let port = client_cx
 912            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
 913        let opts = RemoteConnectionOptions::Ssh(SshConnectionOptions {
 914            host: "<fake>".to_string(),
 915            port: Some(port),
 916            ..Default::default()
 917        });
 918        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
 919        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
 920        let server_client =
 921            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
 922        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
 923            connection_options: opts.clone(),
 924            server_cx: fake::SendableCx::new(server_cx),
 925            server_channel: server_client.clone(),
 926        });
 927
 928        client_cx.update(|cx| {
 929            cx.update_default_global(|c: &mut ConnectionPool, cx| {
 930                c.connections.insert(
 931                    opts.clone(),
 932                    ConnectionPoolEntry::Connecting(
 933                        cx.background_spawn({
 934                            let connection = connection.clone();
 935                            async move { Ok(connection.clone()) }
 936                        })
 937                        .shared(),
 938                    ),
 939                );
 940            })
 941        });
 942
 943        (opts, server_client.into())
 944    }
 945
 946    #[cfg(any(test, feature = "test-support"))]
 947    pub async fn fake_client(
 948        opts: RemoteConnectionOptions,
 949        client_cx: &mut gpui::TestAppContext,
 950    ) -> Entity<Self> {
 951        let (_tx, rx) = oneshot::channel();
 952        client_cx
 953            .update(|cx| {
 954                Self::new(
 955                    ConnectionIdentifier::setup(),
 956                    opts,
 957                    rx,
 958                    Arc::new(fake::Delegate),
 959                    cx,
 960                )
 961            })
 962            .await
 963            .unwrap()
 964            .unwrap()
 965    }
 966
 967    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 968        self.state
 969            .as_ref()
 970            .and_then(|state| state.remote_connection())
 971    }
 972}
 973
/// One entry in the global `ConnectionPool`, keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. Every caller shares its result via
    /// the `Shared` task; the error is `Arc`-wrapped so the result is cloneable.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
 978
/// App-global registry that lets clients with equal `RemoteConnectionOptions`
/// share one connection, or one in-flight connection attempt.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<RemoteConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
 985
 986impl ConnectionPool {
 987    pub fn connect(
 988        &mut self,
 989        opts: RemoteConnectionOptions,
 990        delegate: &Arc<dyn RemoteClientDelegate>,
 991        cx: &mut App,
 992    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
 993        let connection = self.connections.get(&opts);
 994        match connection {
 995            Some(ConnectionPoolEntry::Connecting(task)) => {
 996                let delegate = delegate.clone();
 997                cx.spawn(async move |cx| {
 998                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
 999                })
1000                .detach();
1001                return task.clone();
1002            }
1003            Some(ConnectionPoolEntry::Connected(ssh)) => {
1004                if let Some(ssh) = ssh.upgrade()
1005                    && !ssh.has_been_killed()
1006                {
1007                    return Task::ready(Ok(ssh)).shared();
1008                }
1009                self.connections.remove(&opts);
1010            }
1011            None => {}
1012        }
1013
1014        let task = cx
1015            .spawn({
1016                let opts = opts.clone();
1017                let delegate = delegate.clone();
1018                async move |cx| {
1019                    let connection = match opts.clone() {
1020                        RemoteConnectionOptions::Ssh(opts) => {
1021                            SshRemoteConnection::new(opts, delegate, cx)
1022                                .await
1023                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1024                        }
1025                        RemoteConnectionOptions::Wsl(opts) => {
1026                            WslRemoteConnection::new(opts, delegate, cx)
1027                                .await
1028                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1029                        }
1030                    };
1031
1032                    cx.update_global(|pool: &mut Self, _| {
1033                        debug_assert!(matches!(
1034                            pool.connections.get(&opts),
1035                            Some(ConnectionPoolEntry::Connecting(_))
1036                        ));
1037                        match connection {
1038                            Ok(connection) => {
1039                                pool.connections.insert(
1040                                    opts.clone(),
1041                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1042                                );
1043                                Ok(connection)
1044                            }
1045                            Err(error) => {
1046                                pool.connections.remove(&opts);
1047                                Err(Arc::new(error))
1048                            }
1049                        }
1050                    })?
1051                }
1052            })
1053            .shared();
1054
1055        self.connections
1056            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1057        task
1058    }
1059}
1060
/// How to reach a remote host. Also used as the key under which connections
/// are pooled.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteConnectionOptions {
    /// Connect over SSH.
    Ssh(SshConnectionOptions),
    /// Connect to a WSL distribution (identified by its distro name).
    Wsl(WslConnectionOptions),
}
1066
1067impl RemoteConnectionOptions {
1068    pub fn display_name(&self) -> String {
1069        match self {
1070            RemoteConnectionOptions::Ssh(opts) => opts.host.clone(),
1071            RemoteConnectionOptions::Wsl(opts) => opts.distro_name.clone(),
1072        }
1073    }
1074}
1075
1076impl From<SshConnectionOptions> for RemoteConnectionOptions {
1077    fn from(opts: SshConnectionOptions) -> Self {
1078        RemoteConnectionOptions::Ssh(opts)
1079    }
1080}
1081
1082impl From<WslConnectionOptions> for RemoteConnectionOptions {
1083    fn from(opts: WslConnectionOptions) -> Self {
1084        RemoteConnectionOptions::Wsl(opts)
1085    }
1086}
1087
/// A transport-level connection to a remote host. In this file it is
/// implemented by `SshRemoteConnection`, `WslRemoteConnection` and, in tests,
/// `fake::FakeRemoteConnection`.
#[async_trait(?Send)]
pub(crate) trait RemoteConnection: Send + Sync {
    /// Starts relaying envelopes between the client channels and the remote
    /// end. Messages from the remote go to `incoming_tx`, messages read from
    /// `outgoing_rx` go to the remote, and `connection_activity_tx` is signaled
    /// on traffic (this is what the fake implementation does).
    /// NOTE(review): the `i32` is presumably a proxy exit status; the fake
    /// returns `Ok(1)` when either stream ends. Confirm against SSH/WSL impls.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Tears the connection down (semantics are implementation-defined).
    async fn kill(&self) -> Result<()>;
    /// Whether the connection has been killed. The pool discards killed
    /// connections instead of reusing them.
    fn has_been_killed(&self) -> bool;
    /// Whether the remote shares a network interface with the local machine.
    /// Defaults to `false`.
    fn shares_network_interface(&self) -> bool {
        false
    }
    /// Builds the local command that runs `program` (with `args`, `env`, an
    /// optional working directory and an optional port forward) on the remote.
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Builds a command forwarding `local_port` to `remote:remote_port`. The
    /// fake builds `ssh -N -L local:host:remote`.
    fn build_forward_port_command(
        &self,
        local_port: u16,
        remote: String,
        remote_port: u16,
    ) -> Result<CommandTemplate>;
    /// The options this connection was established with (its pool key).
    fn connection_options(&self) -> RemoteConnectionOptions;
    /// Path conventions used on the remote side.
    fn path_style(&self) -> PathStyle;
    /// The shell to use on the remote.
    fn shell(&self) -> String;
    /// The remote's default system shell.
    fn default_system_shell(&self) -> String;

    /// Test hook that severs the connection. Defaults to a no-op.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1133
/// Pending requests keyed by message id. Each sender receives the response
/// envelope plus a oneshot sender; the message loop waits until that sender is
/// dropped before reading the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1135
/// A fire-once signal with any number of waiters. The first `set` value is
/// delivered to every clone of the `wait` future, and later `set`s are ignored.
struct Signal<T> {
    /// Taken on the first `set`, which makes later sets no-ops.
    tx: Mutex<Option<oneshot::Sender<T>>>,
    /// Resolves to `Some(value)` once set, or `None` if the sender is dropped unset.
    rx: Shared<Task<Option<T>>>,
}
1140
1141impl<T: Send + Clone + 'static> Signal<T> {
1142    pub fn new(cx: &App) -> Self {
1143        let (tx, rx) = oneshot::channel();
1144
1145        let task = cx
1146            .background_executor()
1147            .spawn(async move { rx.await.ok() })
1148            .shared();
1149
1150        Self {
1151            tx: Mutex::new(Some(tx)),
1152            rx: task,
1153        }
1154    }
1155
1156    fn set(&self, value: T) {
1157        if let Some(tx) = self.tx.lock().take() {
1158            let _ = tx.send(value);
1159        }
1160    }
1161
1162    fn wait(&self) -> Shared<Task<Option<T>>> {
1163        self.rx.clone()
1164    }
1165}
1166
/// A bidirectional proto-message channel with request/response correlation.
/// It keeps a replay buffer of sent messages that persists across `reconnect`.
struct ChannelClient {
    /// Source of ids for outgoing envelopes.
    next_message_id: AtomicU32,
    /// Current sink for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Buffered outgoing envelopes. They are resent on `FlushBufferedMessages`
    /// or `resync`, and pruned once an incoming `ack_id` covers them.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Senders waiting for responses, keyed by the request's message id.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the latest incoming envelope recorded by the message loop.
    /// It is stamped as `ack_id` on outgoing envelopes.
    max_received: AtomicU32,
    /// Label used as a prefix in log messages.
    name: &'static str,
    /// The running incoming-message loop.
    task: Mutex<Task<Result<()>>>,
    /// Fired when the peer's `RemoteStarted` message arrives.
    remote_started: Signal<()>,
}
1178
impl ChannelClient {
    /// Creates a client and immediately starts its incoming-message loop over
    /// `incoming_rx`. `Arc::new_cyclic` lets the loop hold only a `Weak`
    /// reference, so the loop does not keep the client alive.
    fn new(
        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        outgoing_tx: mpsc::UnboundedSender<Envelope>,
        cx: &App,
        name: &'static str,
    ) -> Arc<Self> {
        Arc::new_cyclic(|this| Self {
            outgoing_tx: Mutex::new(outgoing_tx),
            next_message_id: AtomicU32::new(0),
            max_received: AtomicU32::new(0),
            response_channels: ResponseChannels::default(),
            message_handlers: Default::default(),
            buffer: Mutex::new(VecDeque::new()),
            name,
            task: Mutex::new(Self::start_handling_messages(
                this.clone(),
                incoming_rx,
                &cx.to_async(),
            )),
            remote_started: Signal::new(cx),
        })
    }

    /// Resolves once the peer's `RemoteStarted` message has been received.
    fn wait_for_remote_started(&self) -> Shared<Task<Option<()>>> {
        self.remote_started.wait()
    }

    /// Spawns the loop that processes incoming envelopes.
    ///
    /// On start it announces `RemoteStarted` to the peer. For each incoming
    /// envelope it prunes acknowledged messages from the replay buffer, then
    /// handles the envelope in one of these ways:
    /// - answers `FlushBufferedMessages` / `RemoteStarted` with an `Ack`;
    /// - routes responses to the waiting request;
    /// - dispatches other messages to the registered handlers on the
    ///   foreground executor, or replies with an error when no handler exists.
    ///
    /// The loop ends when `incoming_rx` closes or the client has been dropped.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncApp,
    ) -> Task<Result<()>> {
        cx.spawn(async move |cx| {
            // Announce to the peer that this side's loop is running.
            if let Some(this) = this.upgrade() {
                let envelope = proto::RemoteStarted {}.into_envelope(0, None, None);
                this.outgoing_tx.lock().unbounded_send(envelope).ok();
            };

            let peer_id = PeerId { owner_id: 0, id: 0 };
            while let Some(incoming) = incoming_rx.next().await {
                // Stop once the client itself has been dropped.
                let Some(this) = this.upgrade() else {
                    return anyhow::Ok(());
                };
                // The peer has seen everything up to `ack_id`: drop it from the replay buffer.
                if let Some(ack_id) = incoming.ack_id {
                    let mut buffer = this.buffer.lock();
                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                        buffer.pop_front();
                    }
                }
                // Peer requested a replay: resend everything still buffered, then ack.
                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
                {
                    log::debug!(
                        "{}:ssh message received. name:FlushBufferedMessages",
                        this.name
                    );
                    {
                        let buffer = this.buffer.lock();
                        for envelope in buffer.iter() {
                            this.outgoing_tx
                                .lock()
                                .unbounded_send(envelope.clone())
                                .ok();
                        }
                    }
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    continue;
                }

                // Peer's loop has started: fire the signal and ack.
                if let Some(proto::envelope::Payload::RemoteStarted(_)) = &incoming.payload {
                    this.remote_started.set(());
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    continue;
                }

                // Recorded here (after the control messages above) and sent as
                // `ack_id` on subsequent outgoing envelopes.
                this.max_received.store(incoming.id, SeqCst);

                if let Some(request_id) = incoming.responding_to {
                    // A response: hand it to the waiting request (if still
                    // registered), then wait until the requester drops `tx`.
                    // A response without a payload is not delivered. Its
                    // sender is dropped, so the request fails with "connection lost".
                    let request_id = MessageId(request_id);
                    let sender = this.response_channels.lock().remove(&request_id);
                    if let Some(sender) = sender {
                        let (tx, rx) = oneshot::channel();
                        if incoming.payload.is_some() {
                            sender.send((incoming, tx)).ok();
                        }
                        rx.await.ok();
                    }
                } else if let Some(envelope) =
                    build_typed_envelope(peer_id, Instant::now(), incoming)
                {
                    let type_name = envelope.payload_type_name();
                    let message_id = envelope.message_id();
                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
                        &this.message_handlers,
                        envelope,
                        this.clone().into(),
                        cx.clone(),
                    ) {
                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
                        // Run the handler detached so the loop keeps reading messages.
                        cx.foreground_executor()
                            .spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!(
                                            "{}:ssh message handled. name:{type_name}",
                                            this.name
                                        );
                                    }
                                    Err(error) => {
                                        // The fold joins the multi-line error chain into a single log line.
                                        log::error!(
                                            "{}:error handling message. type:{}, error:{:#}",
                                            this.name,
                                            type_name,
                                            format!("{error:#}").lines().fold(
                                                String::new(),
                                                |mut message, line| {
                                                    if !message.is_empty() {
                                                        message.push(' ');
                                                    }
                                                    message.push_str(line);
                                                    message
                                                }
                                            )
                                        );
                                    }
                                }
                            })
                            .detach()
                    } else {
                        // No handler registered: reply to the peer with an error response.
                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
                            message_id,
                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
                        ) {
                            log::error!(
                                "{}:error sending error response for {type_name}:{e:#}",
                                this.name
                            );
                        }
                    }
                }
            }
            anyhow::Ok(())
        })
    }

    /// Points the client at a new pair of channels. The outgoing sink is
    /// swapped and the message loop is restarted on `incoming_rx`, which drops
    /// the previous loop task. The replay buffer is kept.
    fn reconnect(
        self: &Arc<Self>,
        incoming_rx: UnboundedReceiver<Envelope>,
        outgoing_tx: UnboundedSender<Envelope>,
        cx: &AsyncApp,
    ) {
        *self.outgoing_tx.lock() = outgoing_tx;
        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
    }

    /// Sends a typed request through the replay buffer and awaits its typed response.
    fn request<T: RequestMessage>(
        &self,
        payload: T,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        self.request_internal(payload, true)
    }

    /// Like `request`, but `use_buffer` selects buffered or unbuffered sending.
    /// Fails if the response does not decode as `T::Response`.
    fn request_internal<T: RequestMessage>(
        &self,
        payload: T,
        use_buffer: bool,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        log::debug!("ssh request start. name:{}", T::NAME);
        let response =
            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
        async move {
            let response = response.await?;
            log::debug!("ssh request finish. name:{}", T::NAME);
            T::Response::from_envelope(response).context("received a response of the wrong type")
        }
    }

    /// Asks the peer to flush (sent unbuffered), then resends everything still
    /// in the replay buffer. Fails if this does not finish within `timeout`.
    async fn resync(&self, timeout: Duration) -> Result<()> {
        smol::future::or(
            async {
                self.request_internal(proto::FlushBufferedMessages {}, false)
                    .await?;

                for envelope in self.buffer.lock().iter() {
                    self.outgoing_tx
                        .lock()
                        .unbounded_send(envelope.clone())
                        .ok();
                }
                Ok(())
            },
            async {
                smol::Timer::after(timeout).await;
                anyhow::bail!("Timed out resyncing remote client")
            },
        )
        .await
    }

    /// Round-trips a `Ping`. Fails if no reply arrives within `timeout`.
    async fn ping(&self, timeout: Duration) -> Result<()> {
        smol::future::or(
            async {
                self.request(proto::Ping {}).await?;
                Ok(())
            },
            async {
                smol::Timer::after(timeout).await;
                anyhow::bail!("Timed out pinging remote client")
            },
        )
        .await
    }

    /// Sends a typed one-way message through the replay buffer.
    fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
        log::debug!("ssh send name:{}", T::NAME);
        self.send_dynamic(payload.into_envelope(0, None, None))
    }

    /// Assigns an id to `envelope`, registers a response channel for it and
    /// sends it (buffered or not). The returned future resolves to the
    /// response envelope. An `Error` payload becomes an `RpcError`, and the
    /// future fails with "connection lost" if the response sender is dropped.
    fn request_dynamic(
        &self,
        mut envelope: proto::Envelope,
        type_name: &'static str,
        use_buffer: bool,
    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
        let (tx, rx) = oneshot::channel();
        let mut response_channels_lock = self.response_channels.lock();
        response_channels_lock.insert(MessageId(envelope.id), tx);
        drop(response_channels_lock);

        // NOTE(review): both send paths currently always return `Ok`, so the
        // error branch below is unreachable. If that changes, the entry just
        // inserted into `response_channels` would be left behind on failure.
        let result = if use_buffer {
            self.send_buffered(envelope)
        } else {
            self.send_unbuffered(envelope)
        };
        async move {
            if let Err(error) = &result {
                log::error!("failed to send message: {error}");
                anyhow::bail!("failed to send message: {error}");
            }

            let response = rx.await.context("connection lost")?.0;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                return Err(RpcError::from_proto(error, type_name));
            }
            Ok(response)
        }
    }

    /// Assigns a fresh id to `envelope` and sends it through the replay buffer.
    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
        self.send_buffered(envelope)
    }

    /// Stamps `ack_id`, keeps a copy in the replay buffer and sends.
    /// Always returns `Ok(())`.
    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.ack_id = Some(self.max_received.load(SeqCst));
        self.buffer.lock().push_back(envelope.clone());
        // ignore errors on send (happen while we're reconnecting)
        // assume that the global "disconnected" overlay is sufficient.
        self.outgoing_tx.lock().unbounded_send(envelope).ok();
        Ok(())
    }

    /// Stamps `ack_id` and sends without buffering, so the envelope is never
    /// replayed. Send errors are ignored and `Ok(())` is always returned.
    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.ack_id = Some(self.max_received.load(SeqCst));
        self.outgoing_tx.lock().unbounded_send(envelope).ok();
        Ok(())
    }
}
1453
1454impl ProtoClient for ChannelClient {
1455    fn request(
1456        &self,
1457        envelope: proto::Envelope,
1458        request_type: &'static str,
1459    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1460        self.request_dynamic(envelope, request_type, true).boxed()
1461    }
1462
1463    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1464        self.send_dynamic(envelope)
1465    }
1466
1467    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1468        self.send_dynamic(envelope)
1469    }
1470
1471    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1472        &self.message_handlers
1473    }
1474
1475    fn is_via_collab(&self) -> bool {
1476        false
1477    }
1478}
1479
/// In-process test doubles. `FakeRemoteConnection` talks directly to a
/// server-side `ChannelClient`, and `Delegate` treats password prompts and
/// downloads as unreachable.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
    use crate::remote_client::{CommandTemplate, RemoteConnectionOptions};
    use anyhow::Result;
    use askpass::EncryptedPassword;
    use async_trait::async_trait;
    use collections::HashMap;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use std::{path::PathBuf, sync::Arc};
    use util::paths::{PathStyle, RemotePathBuf};

    /// A fake connection whose "remote" is a `ChannelClient` running in a
    /// separate test app context.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: RemoteConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wraps the server's `AsyncApp` so it can be stored in the `Send + Sync`
    /// `FakeRemoteConnection`.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to tear down; always succeeds.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake is never considered killed, so the pool always reuses it.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Mimics an ssh invocation: `ssh <program or "sh"> <args...>` with `env`.
        /// The working directory and port forward are ignored.
        fn build_command(
            &self,
            program: Option<String>,
            args: &[String],
            env: &HashMap<String, String>,
            _: Option<String>,
            _: Option<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let ssh_program = program.unwrap_or_else(|| "sh".to_string());
            let mut ssh_args = Vec::new();
            ssh_args.push(ssh_program);
            ssh_args.extend(args.iter().cloned());
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: ssh_args,
                env: env.clone(),
            })
        }

        /// Builds `ssh -N -L <local_port>:<host>:<remote_port>` with an empty env.
        fn build_forward_port_command(
            &self,
            local_port: u16,
            host: String,
            remote_port: u16,
        ) -> anyhow::Result<CommandTemplate> {
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: vec![
                    "-N".into(),
                    "-L".into(),
                    format!("{local_port}:{host}:{remote_port}"),
                ],
                env: Default::default(),
            })
        }

        /// Not supported by the fake; tests must not call it.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> RemoteConnectionOptions {
            self.connection_options.clone()
        }

        /// Rewires the server channel to fresh channels whose other halves are
        /// dropped immediately, which cuts it off from the client.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Connects the server channel to new in-memory channels and spawns a
        /// background pump between them and the client channels. The task ends
        /// with `Ok(1)` when either side's stream closes.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn RemoteClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    // Biased: server-to-client traffic is polled first. Only
                    // that direction signals `connection_activity_tx`.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::local()
        }

        fn shell(&self) -> String {
            "sh".to_owned()
        }

        fn default_system_shell(&self) -> String {
            "sh".to_owned()
        }
    }

    /// Delegate for tests. Any password prompt or binary download is a test
    /// bug (`unreachable!`), and status updates are ignored.
    pub(super) struct Delegate;

    impl RemoteClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<EncryptedPassword>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: RemotePlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}