remote_client.rs

   1use crate::{
   2    SshConnectionOptions,
   3    protocol::MessageId,
   4    proxy::ProxyLaunchError,
   5    transport::{
   6        ssh::SshRemoteConnection,
   7        wsl::{WslConnectionOptions, WslRemoteConnection},
   8    },
   9};
  10use anyhow::{Context as _, Result, anyhow};
  11use askpass::EncryptedPassword;
  12use async_trait::async_trait;
  13use collections::HashMap;
  14use futures::{
  15    Future, FutureExt as _, StreamExt as _,
  16    channel::{
  17        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  18        oneshot,
  19    },
  20    future::{BoxFuture, Shared},
  21    select, select_biased,
  22};
  23use gpui::{
  24    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  25    EventEmitter, FutureExt, Global, SemanticVersion, Task, WeakEntity,
  26};
  27use parking_lot::Mutex;
  28
  29use release_channel::ReleaseChannel;
  30use rpc::{
  31    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  32    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  33};
  34use std::{
  35    collections::VecDeque,
  36    fmt,
  37    ops::ControlFlow,
  38    path::PathBuf,
  39    sync::{
  40        Arc, Weak,
  41        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  42    },
  43    time::{Duration, Instant},
  44};
  45use util::{
  46    ResultExt,
  47    paths::{PathStyle, RemotePathBuf},
  48};
  49
  50#[derive(Copy, Clone, Debug)]
  51pub struct RemotePlatform {
  52    pub os: &'static str,
  53    pub arch: &'static str,
  54}
  55
  56#[derive(Clone, Debug)]
  57pub struct CommandTemplate {
  58    pub program: String,
  59    pub args: Vec<String>,
  60    pub env: HashMap<String, String>,
  61}
  62
  63pub trait RemoteClientDelegate: Send + Sync {
  64    fn ask_password(
  65        &self,
  66        prompt: String,
  67        tx: oneshot::Sender<EncryptedPassword>,
  68        cx: &mut AsyncApp,
  69    );
  70    fn get_download_params(
  71        &self,
  72        platform: RemotePlatform,
  73        release_channel: ReleaseChannel,
  74        version: Option<SemanticVersion>,
  75        cx: &mut AsyncApp,
  76    ) -> Task<Result<Option<(String, String)>>>;
  77    fn download_server_binary_locally(
  78        &self,
  79        platform: RemotePlatform,
  80        release_channel: ReleaseChannel,
  81        version: Option<SemanticVersion>,
  82        cx: &mut AsyncApp,
  83    ) -> Task<Result<PathBuf>>;
  84    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
  85}
  86
  87const MAX_MISSED_HEARTBEATS: usize = 5;
  88const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
  89const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
  90const INITIAL_CONNECTION_TIMEOUT: Duration = Duration::from_secs(60);
  91
  92const MAX_RECONNECT_ATTEMPTS: usize = 3;
  93
  94enum State {
  95    Connecting,
  96    Connected {
  97        remote_connection: Arc<dyn RemoteConnection>,
  98        delegate: Arc<dyn RemoteClientDelegate>,
  99
 100        multiplex_task: Task<Result<()>>,
 101        heartbeat_task: Task<Result<()>>,
 102    },
 103    HeartbeatMissed {
 104        missed_heartbeats: usize,
 105
 106        ssh_connection: Arc<dyn RemoteConnection>,
 107        delegate: Arc<dyn RemoteClientDelegate>,
 108
 109        multiplex_task: Task<Result<()>>,
 110        heartbeat_task: Task<Result<()>>,
 111    },
 112    Reconnecting,
 113    ReconnectFailed {
 114        ssh_connection: Arc<dyn RemoteConnection>,
 115        delegate: Arc<dyn RemoteClientDelegate>,
 116
 117        error: anyhow::Error,
 118        attempts: usize,
 119    },
 120    ReconnectExhausted,
 121    ServerNotRunning,
 122}
 123
 124impl fmt::Display for State {
 125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 126        match self {
 127            Self::Connecting => write!(f, "connecting"),
 128            Self::Connected { .. } => write!(f, "connected"),
 129            Self::Reconnecting => write!(f, "reconnecting"),
 130            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 131            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 132            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 133            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 134        }
 135    }
 136}
 137
 138impl State {
 139    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 140        match self {
 141            Self::Connected {
 142                remote_connection: ssh_connection,
 143                ..
 144            } => Some(ssh_connection.clone()),
 145            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
 146            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
 147            _ => None,
 148        }
 149    }
 150
 151    fn can_reconnect(&self) -> bool {
 152        match self {
 153            Self::Connected { .. }
 154            | Self::HeartbeatMissed { .. }
 155            | Self::ReconnectFailed { .. } => true,
 156            State::Connecting
 157            | State::Reconnecting
 158            | State::ReconnectExhausted
 159            | State::ServerNotRunning => false,
 160        }
 161    }
 162
 163    fn is_reconnect_failed(&self) -> bool {
 164        matches!(self, Self::ReconnectFailed { .. })
 165    }
 166
 167    fn is_reconnect_exhausted(&self) -> bool {
 168        matches!(self, Self::ReconnectExhausted { .. })
 169    }
 170
 171    fn is_server_not_running(&self) -> bool {
 172        matches!(self, Self::ServerNotRunning)
 173    }
 174
 175    fn is_reconnecting(&self) -> bool {
 176        matches!(self, Self::Reconnecting { .. })
 177    }
 178
 179    fn heartbeat_recovered(self) -> Self {
 180        match self {
 181            Self::HeartbeatMissed {
 182                ssh_connection,
 183                delegate,
 184                multiplex_task,
 185                heartbeat_task,
 186                ..
 187            } => Self::Connected {
 188                remote_connection: ssh_connection,
 189                delegate,
 190                multiplex_task,
 191                heartbeat_task,
 192            },
 193            _ => self,
 194        }
 195    }
 196
 197    fn heartbeat_missed(self) -> Self {
 198        match self {
 199            Self::Connected {
 200                remote_connection: ssh_connection,
 201                delegate,
 202                multiplex_task,
 203                heartbeat_task,
 204            } => Self::HeartbeatMissed {
 205                missed_heartbeats: 1,
 206                ssh_connection,
 207                delegate,
 208                multiplex_task,
 209                heartbeat_task,
 210            },
 211            Self::HeartbeatMissed {
 212                missed_heartbeats,
 213                ssh_connection,
 214                delegate,
 215                multiplex_task,
 216                heartbeat_task,
 217            } => Self::HeartbeatMissed {
 218                missed_heartbeats: missed_heartbeats + 1,
 219                ssh_connection,
 220                delegate,
 221                multiplex_task,
 222                heartbeat_task,
 223            },
 224            _ => self,
 225        }
 226    }
 227}
 228
 229/// The state of the ssh connection.
 230#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 231pub enum ConnectionState {
 232    Connecting,
 233    Connected,
 234    HeartbeatMissed,
 235    Reconnecting,
 236    Disconnected,
 237}
 238
 239impl From<&State> for ConnectionState {
 240    fn from(value: &State) -> Self {
 241        match value {
 242            State::Connecting => Self::Connecting,
 243            State::Connected { .. } => Self::Connected,
 244            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 245            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 246            State::ReconnectExhausted => Self::Disconnected,
 247            State::ServerNotRunning => Self::Disconnected,
 248        }
 249    }
 250}
 251
 252pub struct RemoteClient {
 253    client: Arc<ChannelClient>,
 254    unique_identifier: String,
 255    connection_options: RemoteConnectionOptions,
 256    path_style: PathStyle,
 257    state: Option<State>,
 258}
 259
 260#[derive(Debug)]
 261pub enum RemoteClientEvent {
 262    Disconnected,
 263}
 264
 265impl EventEmitter<RemoteClientEvent> for RemoteClient {}
 266
 267/// Identifies the socket on the remote server so that reconnects
 268/// can re-join the same project.
 269pub enum ConnectionIdentifier {
 270    Setup(u64),
 271    Workspace(i64),
 272}
 273
 274static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 275
 276impl ConnectionIdentifier {
 277    pub fn setup() -> Self {
 278        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 279    }
 280
 281    // This string gets used in a socket name, and so must be relatively short.
 282    // The total length of:
 283    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 284    // Must be less than about 100 characters
 285    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 286    // So our strings should be at most 20 characters or so.
 287    fn to_string(&self, cx: &App) -> String {
 288        let identifier_prefix = match ReleaseChannel::global(cx) {
 289            ReleaseChannel::Stable => "".to_string(),
 290            release_channel => format!("{}-", release_channel.dev_name()),
 291        };
 292        match self {
 293            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 294            Self::Workspace(workspace_id) => {
 295                format!("{identifier_prefix}workspace-{workspace_id}",)
 296            }
 297        }
 298    }
 299}
 300
 301pub async fn connect(
 302    connection_options: RemoteConnectionOptions,
 303    delegate: Arc<dyn RemoteClientDelegate>,
 304    cx: &mut AsyncApp,
 305) -> Result<Arc<dyn RemoteConnection>> {
 306    cx.update(|cx| {
 307        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
 308            pool.connect(connection_options.clone(), delegate.clone(), cx)
 309        })
 310    })?
 311    .await
 312    .map_err(|e| e.cloned())
 313}
 314
 315impl RemoteClient {
 316    pub fn new(
 317        unique_identifier: ConnectionIdentifier,
 318        remote_connection: Arc<dyn RemoteConnection>,
 319        cancellation: oneshot::Receiver<()>,
 320        delegate: Arc<dyn RemoteClientDelegate>,
 321        cx: &mut App,
 322    ) -> Task<Result<Option<Entity<Self>>>> {
 323        let unique_identifier = unique_identifier.to_string(cx);
 324        cx.spawn(async move |cx| {
 325            let success = Box::pin(async move {
 326                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 327                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 328                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 329
 330                let client =
 331                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
 332
 333                let path_style = remote_connection.path_style();
 334                let this = cx.new(|_| Self {
 335                    client: client.clone(),
 336                    unique_identifier: unique_identifier.clone(),
 337                    connection_options: remote_connection.connection_options(),
 338                    path_style,
 339                    state: Some(State::Connecting),
 340                })?;
 341
 342                let io_task = remote_connection.start_proxy(
 343                    unique_identifier,
 344                    false,
 345                    incoming_tx,
 346                    outgoing_rx,
 347                    connection_activity_tx,
 348                    delegate.clone(),
 349                    cx,
 350                );
 351
 352                let ready = client
 353                    .wait_for_remote_started()
 354                    .with_timeout(INITIAL_CONNECTION_TIMEOUT, cx.background_executor())
 355                    .await;
 356                match ready {
 357                    Ok(Some(_)) => {}
 358                    Ok(None) => {
 359                        let mut error = "remote client exited before becoming ready".to_owned();
 360                        if let Some(status) = io_task.now_or_never() {
 361                            match status {
 362                                Ok(exit_code) => {
 363                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 364                                }
 365                                Err(e) => error.push_str(&format!(", error={e:?}")),
 366                            }
 367                        }
 368                        let error = anyhow::anyhow!("{error}");
 369                        log::error!("failed to establish connection: {}", error);
 370                        return Err(error);
 371                    }
 372                    Err(_) => {
 373                        let mut error =
 374                            "remote client did not become ready within the timeout".to_owned();
 375                        if let Some(status) = io_task.now_or_never() {
 376                            match status {
 377                                Ok(exit_code) => {
 378                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 379                                }
 380                                Err(e) => error.push_str(&format!(", error={e:?}")),
 381                            }
 382                        }
 383                        let error = anyhow::anyhow!("{error}");
 384                        log::error!("failed to establish connection: {}", error);
 385                        return Err(error);
 386                    }
 387                }
 388                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
 389                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 390                    log::error!("failed to establish connection: {}", error);
 391                    return Err(error);
 392                }
 393
 394                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
 395
 396                this.update(cx, |this, _| {
 397                    this.state = Some(State::Connected {
 398                        remote_connection,
 399                        delegate,
 400                        multiplex_task,
 401                        heartbeat_task,
 402                    });
 403                })?;
 404
 405                Ok(Some(this))
 406            });
 407
 408            select! {
 409                _ = cancellation.fuse() => {
 410                    Ok(None)
 411                }
 412                result = success.fuse() =>  result
 413            }
 414        })
 415    }
 416
 417    pub fn proto_client_from_channels(
 418        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 419        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 420        cx: &App,
 421        name: &'static str,
 422    ) -> AnyProtoClient {
 423        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 424    }
 425
 426    pub fn shutdown_processes<T: RequestMessage>(
 427        &mut self,
 428        shutdown_request: Option<T>,
 429        executor: BackgroundExecutor,
 430    ) -> Option<impl Future<Output = ()> + use<T>> {
 431        let state = self.state.take()?;
 432        log::info!("shutting down ssh processes");
 433
 434        let State::Connected {
 435            multiplex_task,
 436            heartbeat_task,
 437            remote_connection: ssh_connection,
 438            delegate,
 439        } = state
 440        else {
 441            return None;
 442        };
 443
 444        let client = self.client.clone();
 445
 446        Some(async move {
 447            if let Some(shutdown_request) = shutdown_request {
 448                client.send(shutdown_request).log_err();
 449                // We wait 50ms instead of waiting for a response, because
 450                // waiting for a response would require us to wait on the main thread
 451                // which we want to avoid in an `on_app_quit` callback.
 452                executor.timer(Duration::from_millis(50)).await;
 453            }
 454
 455            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 456            // child of master_process.
 457            drop(multiplex_task);
 458            // Now drop the rest of state, which kills master process.
 459            drop(heartbeat_task);
 460            drop(ssh_connection);
 461            drop(delegate);
 462        })
 463    }
 464
 465    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 466        let can_reconnect = self
 467            .state
 468            .as_ref()
 469            .map(|state| state.can_reconnect())
 470            .unwrap_or(false);
 471        if !can_reconnect {
 472            log::info!("aborting reconnect, because not in state that allows reconnecting");
 473            let error = if let Some(state) = self.state.as_ref() {
 474                format!("invalid state, cannot reconnect while in state {state}")
 475            } else {
 476                "no state set".to_string()
 477            };
 478            anyhow::bail!(error);
 479        }
 480
 481        let state = self.state.take().unwrap();
 482        let (attempts, remote_connection, delegate) = match state {
 483            State::Connected {
 484                remote_connection: ssh_connection,
 485                delegate,
 486                multiplex_task,
 487                heartbeat_task,
 488            }
 489            | State::HeartbeatMissed {
 490                ssh_connection,
 491                delegate,
 492                multiplex_task,
 493                heartbeat_task,
 494                ..
 495            } => {
 496                drop(multiplex_task);
 497                drop(heartbeat_task);
 498                (0, ssh_connection, delegate)
 499            }
 500            State::ReconnectFailed {
 501                attempts,
 502                ssh_connection,
 503                delegate,
 504                ..
 505            } => (attempts, ssh_connection, delegate),
 506            State::Connecting
 507            | State::Reconnecting
 508            | State::ReconnectExhausted
 509            | State::ServerNotRunning => unreachable!(),
 510        };
 511
 512        let attempts = attempts + 1;
 513        if attempts > MAX_RECONNECT_ATTEMPTS {
 514            log::error!(
 515                "Failed to reconnect to after {} attempts, giving up",
 516                MAX_RECONNECT_ATTEMPTS
 517            );
 518            self.set_state(State::ReconnectExhausted, cx);
 519            return Ok(());
 520        }
 521
 522        self.set_state(State::Reconnecting, cx);
 523
 524        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 525
 526        let unique_identifier = self.unique_identifier.clone();
 527        let client = self.client.clone();
 528        let reconnect_task = cx.spawn(async move |this, cx| {
 529            macro_rules! failed {
 530                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 531                    delegate.set_status(Some(&format!("{error:#}", error = $error)), cx);
 532                    return State::ReconnectFailed {
 533                        error: anyhow!($error),
 534                        attempts: $attempts,
 535                        ssh_connection: $ssh_connection,
 536                        delegate: $delegate,
 537                    };
 538                };
 539            }
 540
 541            if let Err(error) = remote_connection
 542                .kill()
 543                .await
 544                .context("Failed to kill ssh process")
 545            {
 546                failed!(error, attempts, remote_connection, delegate);
 547            };
 548
 549            let connection_options = remote_connection.connection_options();
 550
 551            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 552            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 553            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 554
 555            let (ssh_connection, io_task) = match async {
 556                let ssh_connection = cx
 557                    .update_global(|pool: &mut ConnectionPool, cx| {
 558                        pool.connect(connection_options, delegate.clone(), cx)
 559                    })?
 560                    .await
 561                    .map_err(|error| error.cloned())?;
 562
 563                let io_task = ssh_connection.start_proxy(
 564                    unique_identifier,
 565                    true,
 566                    incoming_tx,
 567                    outgoing_rx,
 568                    connection_activity_tx,
 569                    delegate.clone(),
 570                    cx,
 571                );
 572                anyhow::Ok((ssh_connection, io_task))
 573            }
 574            .await
 575            {
 576                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 577                Err(error) => {
 578                    failed!(error, attempts, remote_connection, delegate);
 579                }
 580            };
 581
 582            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 583            client.reconnect(incoming_rx, outgoing_tx, cx);
 584
 585            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 586                failed!(error, attempts, ssh_connection, delegate);
 587            };
 588
 589            State::Connected {
 590                remote_connection: ssh_connection,
 591                delegate,
 592                multiplex_task,
 593                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 594            }
 595        });
 596
 597        cx.spawn(async move |this, cx| {
 598            let new_state = reconnect_task.await;
 599            this.update(cx, |this, cx| {
 600                this.try_set_state(cx, |old_state| {
 601                    if old_state.is_reconnecting() {
 602                        match &new_state {
 603                            State::Connecting
 604                            | State::Reconnecting
 605                            | State::HeartbeatMissed { .. }
 606                            | State::ServerNotRunning => {}
 607                            State::Connected { .. } => {
 608                                log::info!("Successfully reconnected");
 609                            }
 610                            State::ReconnectFailed {
 611                                error, attempts, ..
 612                            } => {
 613                                log::error!(
 614                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 615                                    attempts,
 616                                    error
 617                                );
 618                            }
 619                            State::ReconnectExhausted => {
 620                                log::error!("Reconnect attempt failed and all attempts exhausted");
 621                            }
 622                        }
 623                        Some(new_state)
 624                    } else {
 625                        None
 626                    }
 627                });
 628
 629                if this.state_is(State::is_reconnect_failed) {
 630                    this.reconnect(cx)
 631                } else if this.state_is(State::is_reconnect_exhausted) {
 632                    Ok(())
 633                } else {
 634                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 635                    Ok(())
 636                }
 637            })
 638        })
 639        .detach_and_log_err(cx);
 640
 641        Ok(())
 642    }
 643
 644    fn heartbeat(
 645        this: WeakEntity<Self>,
 646        mut connection_activity_rx: mpsc::Receiver<()>,
 647        cx: &mut AsyncApp,
 648    ) -> Task<Result<()>> {
 649        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 650            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 651        };
 652
 653        cx.spawn(async move |cx| {
 654            let mut missed_heartbeats = 0;
 655
 656            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 657            futures::pin_mut!(keepalive_timer);
 658
 659            loop {
 660                select_biased! {
 661                    result = connection_activity_rx.next().fuse() => {
 662                        if result.is_none() {
 663                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 664                            return Ok(());
 665                        }
 666
 667                        if missed_heartbeats != 0 {
 668                            missed_heartbeats = 0;
 669                            let _ =this.update(cx, |this, cx| {
 670                                this.handle_heartbeat_result(missed_heartbeats, cx)
 671                            })?;
 672                        }
 673                    }
 674                    _ = keepalive_timer => {
 675                        log::debug!("Sending heartbeat to server...");
 676
 677                        let result = select_biased! {
 678                            _ = connection_activity_rx.next().fuse() => {
 679                                Ok(())
 680                            }
 681                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 682                                ping_result
 683                            }
 684                        };
 685
 686                        if result.is_err() {
 687                            missed_heartbeats += 1;
 688                            log::warn!(
 689                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 690                                HEARTBEAT_TIMEOUT,
 691                                missed_heartbeats,
 692                                MAX_MISSED_HEARTBEATS
 693                            );
 694                        } else if missed_heartbeats != 0 {
 695                            missed_heartbeats = 0;
 696                        } else {
 697                            continue;
 698                        }
 699
 700                        let result = this.update(cx, |this, cx| {
 701                            this.handle_heartbeat_result(missed_heartbeats, cx)
 702                        })?;
 703                        if result.is_break() {
 704                            return Ok(());
 705                        }
 706                    }
 707                }
 708
 709                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 710            }
 711        })
 712    }
 713
 714    fn handle_heartbeat_result(
 715        &mut self,
 716        missed_heartbeats: usize,
 717        cx: &mut Context<Self>,
 718    ) -> ControlFlow<()> {
 719        let state = self.state.take().unwrap();
 720        let next_state = if missed_heartbeats > 0 {
 721            state.heartbeat_missed()
 722        } else {
 723            state.heartbeat_recovered()
 724        };
 725
 726        self.set_state(next_state, cx);
 727
 728        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 729            log::error!(
 730                "Missed last {} heartbeats. Reconnecting...",
 731                missed_heartbeats
 732            );
 733
 734            self.reconnect(cx)
 735                .context("failed to start reconnect process after missing heartbeats")
 736                .log_err();
 737            ControlFlow::Break(())
 738        } else {
 739            ControlFlow::Continue(())
 740        }
 741    }
 742
 743    fn monitor(
 744        this: WeakEntity<Self>,
 745        io_task: Task<Result<i32>>,
 746        cx: &AsyncApp,
 747    ) -> Task<Result<()>> {
 748        cx.spawn(async move |cx| {
 749            let result = io_task.await;
 750
 751            match result {
 752                Ok(exit_code) => {
 753                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 754                        match error {
 755                            ProxyLaunchError::ServerNotRunning => {
 756                                log::error!("failed to reconnect because server is not running");
 757                                this.update(cx, |this, cx| {
 758                                    this.set_state(State::ServerNotRunning, cx);
 759                                })?;
 760                            }
 761                        }
 762                    } else if exit_code > 0 {
 763                        log::error!("proxy process terminated unexpectedly");
 764                        this.update(cx, |this, cx| {
 765                            this.reconnect(cx).ok();
 766                        })?;
 767                    }
 768                }
 769                Err(error) => {
 770                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 771                    this.update(cx, |this, cx| {
 772                        this.reconnect(cx).ok();
 773                    })?;
 774                }
 775            }
 776
 777            Ok(())
 778        })
 779    }
 780
 781    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 782        self.state.as_ref().is_some_and(check)
 783    }
 784
 785    fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
 786        let new_state = self.state.as_ref().and_then(map);
 787        if let Some(new_state) = new_state {
 788            self.state.replace(new_state);
 789            cx.notify();
 790        }
 791    }
 792
 793    fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
 794        log::info!("setting state to '{}'", &state);
 795
 796        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 797        let is_server_not_running = state.is_server_not_running();
 798        self.state.replace(state);
 799
 800        if is_reconnect_exhausted || is_server_not_running {
 801            cx.emit(RemoteClientEvent::Disconnected);
 802        }
 803        cx.notify();
 804    }
 805
 806    pub fn shell(&self) -> Option<String> {
 807        Some(self.remote_connection()?.shell())
 808    }
 809
 810    pub fn default_system_shell(&self) -> Option<String> {
 811        Some(self.remote_connection()?.default_system_shell())
 812    }
 813
 814    pub fn shares_network_interface(&self) -> bool {
 815        self.remote_connection()
 816            .map_or(false, |connection| connection.shares_network_interface())
 817    }
 818
 819    pub fn build_command(
 820        &self,
 821        program: Option<String>,
 822        args: &[String],
 823        env: &HashMap<String, String>,
 824        working_dir: Option<String>,
 825        port_forward: Option<(u16, String, u16)>,
 826    ) -> Result<CommandTemplate> {
 827        let Some(connection) = self.remote_connection() else {
 828            return Err(anyhow!("no ssh connection"));
 829        };
 830        connection.build_command(program, args, env, working_dir, port_forward)
 831    }
 832
 833    pub fn build_forward_ports_command(
 834        &self,
 835        forwards: Vec<(u16, String, u16)>,
 836    ) -> Result<CommandTemplate> {
 837        let Some(connection) = self.remote_connection() else {
 838            return Err(anyhow!("no ssh connection"));
 839        };
 840        connection.build_forward_ports_command(forwards)
 841    }
 842
 843    pub fn upload_directory(
 844        &self,
 845        src_path: PathBuf,
 846        dest_path: RemotePathBuf,
 847        cx: &App,
 848    ) -> Task<Result<()>> {
 849        let Some(connection) = self.remote_connection() else {
 850            return Task::ready(Err(anyhow!("no ssh connection")));
 851        };
 852        connection.upload_directory(src_path, dest_path, cx)
 853    }
 854
 855    pub fn proto_client(&self) -> AnyProtoClient {
 856        self.client.clone().into()
 857    }
 858
 859    pub fn connection_options(&self) -> RemoteConnectionOptions {
 860        self.connection_options.clone()
 861    }
 862
 863    pub fn connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 864        if let State::Connected {
 865            remote_connection, ..
 866        } = self.state.as_ref()?
 867        {
 868            Some(remote_connection.clone())
 869        } else {
 870            None
 871        }
 872    }
 873
 874    pub fn connection_state(&self) -> ConnectionState {
 875        self.state
 876            .as_ref()
 877            .map(ConnectionState::from)
 878            .unwrap_or(ConnectionState::Disconnected)
 879    }
 880
 881    pub fn is_disconnected(&self) -> bool {
 882        self.connection_state() == ConnectionState::Disconnected
 883    }
 884
 885    pub fn path_style(&self) -> PathStyle {
 886        self.path_style
 887    }
 888
 889    #[cfg(any(test, feature = "test-support"))]
 890    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
 891        let opts = self.connection_options();
 892        client_cx.spawn(async move |cx| {
 893            let connection = cx
 894                .update_global(|c: &mut ConnectionPool, _| {
 895                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
 896                        c.clone()
 897                    } else {
 898                        panic!("missing test connection")
 899                    }
 900                })
 901                .unwrap()
 902                .await
 903                .unwrap();
 904
 905            connection.simulate_disconnect(cx);
 906        })
 907    }
 908
 909    #[cfg(any(test, feature = "test-support"))]
 910    pub fn fake_server(
 911        client_cx: &mut gpui::TestAppContext,
 912        server_cx: &mut gpui::TestAppContext,
 913    ) -> (RemoteConnectionOptions, AnyProtoClient) {
 914        let port = client_cx
 915            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
 916        let opts = RemoteConnectionOptions::Ssh(SshConnectionOptions {
 917            host: "<fake>".to_string(),
 918            port: Some(port),
 919            ..Default::default()
 920        });
 921        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
 922        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
 923        let server_client =
 924            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
 925        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
 926            connection_options: opts.clone(),
 927            server_cx: fake::SendableCx::new(server_cx),
 928            server_channel: server_client.clone(),
 929        });
 930
 931        client_cx.update(|cx| {
 932            cx.update_default_global(|c: &mut ConnectionPool, cx| {
 933                c.connections.insert(
 934                    opts.clone(),
 935                    ConnectionPoolEntry::Connecting(
 936                        cx.background_spawn({
 937                            let connection = connection.clone();
 938                            async move { Ok(connection.clone()) }
 939                        })
 940                        .shared(),
 941                    ),
 942                );
 943            })
 944        });
 945
 946        (opts, server_client.into())
 947    }
 948
 949    #[cfg(any(test, feature = "test-support"))]
 950    pub async fn fake_client(
 951        opts: RemoteConnectionOptions,
 952        client_cx: &mut gpui::TestAppContext,
 953    ) -> Entity<Self> {
 954        let (_tx, rx) = oneshot::channel();
 955        let mut cx = client_cx.to_async();
 956        let connection = connect(opts, Arc::new(fake::Delegate), &mut cx)
 957            .await
 958            .unwrap();
 959        client_cx
 960            .update(|cx| {
 961                Self::new(
 962                    ConnectionIdentifier::setup(),
 963                    connection,
 964                    rx,
 965                    Arc::new(fake::Delegate),
 966                    cx,
 967                )
 968            })
 969            .await
 970            .unwrap()
 971            .unwrap()
 972    }
 973
 974    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 975        self.state
 976            .as_ref()
 977            .and_then(|state| state.remote_connection())
 978    }
 979}
 980
/// An entry in the [`ConnectionPool`], keyed by [`RemoteConnectionOptions`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; later callers receive a clone of this
    /// shared task instead of starting a second attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it
    /// alive. `ConnectionPool::connect` discards it if it can no longer be
    /// upgraded or has been killed.
    Connected(Weak<dyn RemoteConnection>),
}
 985
/// App-global registry that deduplicates remote connections per
/// [`RemoteConnectionOptions`].
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<RemoteConnectionOptions, ConnectionPoolEntry>,
}

// Stored as a gpui global; accessed via `update_global` / `default_global`.
impl Global for ConnectionPool {}
 992
 993impl ConnectionPool {
 994    pub fn connect(
 995        &mut self,
 996        opts: RemoteConnectionOptions,
 997        delegate: Arc<dyn RemoteClientDelegate>,
 998        cx: &mut App,
 999    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1000        let connection = self.connections.get(&opts);
1001        match connection {
1002            Some(ConnectionPoolEntry::Connecting(task)) => {
1003                delegate.set_status(
1004                    Some("Waiting for existing connection attempt"),
1005                    &mut cx.to_async(),
1006                );
1007                return task.clone();
1008            }
1009            Some(ConnectionPoolEntry::Connected(ssh)) => {
1010                if let Some(ssh) = ssh.upgrade()
1011                    && !ssh.has_been_killed()
1012                {
1013                    return Task::ready(Ok(ssh)).shared();
1014                }
1015                self.connections.remove(&opts);
1016            }
1017            None => {}
1018        }
1019
1020        let task = cx
1021            .spawn({
1022                let opts = opts.clone();
1023                let delegate = delegate.clone();
1024                async move |cx| {
1025                    let connection = match opts.clone() {
1026                        RemoteConnectionOptions::Ssh(opts) => {
1027                            SshRemoteConnection::new(opts, delegate, cx)
1028                                .await
1029                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1030                        }
1031                        RemoteConnectionOptions::Wsl(opts) => {
1032                            WslRemoteConnection::new(opts, delegate, cx)
1033                                .await
1034                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1035                        }
1036                    };
1037
1038                    cx.update_global(|pool: &mut Self, _| {
1039                        debug_assert!(matches!(
1040                            pool.connections.get(&opts),
1041                            Some(ConnectionPoolEntry::Connecting(_))
1042                        ));
1043                        match connection {
1044                            Ok(connection) => {
1045                                pool.connections.insert(
1046                                    opts.clone(),
1047                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1048                                );
1049                                Ok(connection)
1050                            }
1051                            Err(error) => {
1052                                pool.connections.remove(&opts);
1053                                Err(Arc::new(error))
1054                            }
1055                        }
1056                    })?
1057                }
1058            })
1059            .shared();
1060
1061        self.connections
1062            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1063        task
1064    }
1065}
1066
/// Describes how to reach a remote host. Also used as the key under which
/// connections are pooled in `ConnectionPool`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteConnectionOptions {
    /// Connect over SSH.
    Ssh(SshConnectionOptions),
    /// Connect to a WSL distribution.
    Wsl(WslConnectionOptions),
}
1072
1073impl RemoteConnectionOptions {
1074    pub fn display_name(&self) -> String {
1075        match self {
1076            RemoteConnectionOptions::Ssh(opts) => opts.host.clone(),
1077            RemoteConnectionOptions::Wsl(opts) => opts.distro_name.clone(),
1078        }
1079    }
1080
1081    pub fn is_wsl(&self) -> bool {
1082        matches!(self, RemoteConnectionOptions::Wsl(_))
1083    }
1084}
1085
1086impl From<SshConnectionOptions> for RemoteConnectionOptions {
1087    fn from(opts: SshConnectionOptions) -> Self {
1088        RemoteConnectionOptions::Ssh(opts)
1089    }
1090}
1091
1092impl From<WslConnectionOptions> for RemoteConnectionOptions {
1093    fn from(opts: WslConnectionOptions) -> Self {
1094        RemoteConnectionOptions::Wsl(opts)
1095    }
1096}
1097
// Windows-only action payload naming a WSL distribution (`distro`) and a list
// of paths (`paths`), presumably paths to open inside that distribution.
// NOTE(review): the `///` line below is left verbatim because the
// `gpui::Action` derive may surface doc comments as action documentation at
// runtime — TODO confirm before editing it (e.g. to wrap the path in backticks
// so rustdoc does not treat `<distro>` as an HTML tag).
#[cfg(target_os = "windows")]
/// Open a wsl path (\\wsl.localhost\<distro>\path)
#[derive(Debug, Clone, PartialEq, Eq, gpui::Action)]
#[action(namespace = workspace, no_json, no_register)]
pub struct OpenWslPath {
    pub distro: WslConnectionOptions,
    pub paths: Vec<PathBuf>,
}
1106
/// A transport-level connection to a remote host.
///
/// In this file it is implemented by `SshRemoteConnection`,
/// `WslRemoteConnection` (both created in `ConnectionPool::connect`), and the
/// test-only `fake::FakeRemoteConnection`.
#[async_trait(?Send)]
pub trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy and relays envelopes: those received on
    /// `outgoing_rx` go to the remote side, and those coming from the remote
    /// side are pushed into `incoming_tx`.
    ///
    /// `connection_activity_tx` is signalled when traffic arrives (the fake
    /// signals on every remote-to-client envelope); presumably consumed by a
    /// liveness check — TODO confirm. `reconnect` presumably marks a
    /// reconnection attempt — TODO confirm. The task's `i32` result is
    /// implementation-defined (the fake returns `1` once either channel closes).
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote side.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Whether the connection has been killed; `ConnectionPool::connect` will
    /// not hand out a pooled connection for which this returns `true`.
    fn has_been_killed(&self) -> bool;
    /// Whether the remote side shares this machine's network interface.
    /// Defaults to `false`.
    fn shares_network_interface(&self) -> bool {
        false
    }
    /// Builds a [`CommandTemplate`] that runs `program` with `args` and `env`
    /// through this connection. When `program` is `None` an
    /// implementation-chosen default is used (the fake uses `sh`).
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Builds a command forwarding each `(local_port, host, remote_port)` tuple.
    fn build_forward_ports_command(
        &self,
        forwards: Vec<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Options this connection was established with.
    fn connection_options(&self) -> RemoteConnectionOptions;
    /// Path conventions used on the remote side.
    fn path_style(&self) -> PathStyle;
    /// Shell to use on the remote side.
    fn shell(&self) -> String;
    /// The remote system's default shell.
    fn default_system_shell(&self) -> String;

    /// Test hook for simulating a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1150
/// Pending requests awaiting a reply, keyed by the request's message id.
///
/// The message loop delivers the reply envelope together with a oneshot
/// sender, then waits on the paired receiver before handling further messages.
/// The requester releases it by dropping that sender once it has taken the
/// envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1152
1153struct Signal<T> {
1154    tx: Mutex<Option<oneshot::Sender<T>>>,
1155    rx: Shared<Task<Option<T>>>,
1156}
1157
1158impl<T: Send + Clone + 'static> Signal<T> {
1159    pub fn new(cx: &App) -> Self {
1160        let (tx, rx) = oneshot::channel();
1161
1162        let task = cx
1163            .background_executor()
1164            .spawn(async move { rx.await.ok() })
1165            .shared();
1166
1167        Self {
1168            tx: Mutex::new(Some(tx)),
1169            rx: task,
1170        }
1171    }
1172
1173    fn set(&self, value: T) {
1174        if let Some(tx) = self.tx.lock().take() {
1175            let _ = tx.send(value);
1176        }
1177    }
1178
1179    fn wait(&self) -> Shared<Task<Option<T>>> {
1180        self.rx.clone()
1181    }
1182}
1183
/// Protocol client that exchanges envelopes with a peer over a pair of mpsc
/// channels.
///
/// Buffered outgoing envelopes are kept until the peer acknowledges them, so
/// they can be re-sent (see `resync` and the `FlushBufferedMessages` handling).
struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Sent-but-unacknowledged buffered envelopes, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests waiting for a reply.
    response_channels: ResponseChannels,
    /// Handlers for incoming non-response messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently recorded incoming message, sent back as `ack_id`.
    max_received: AtomicU32,
    /// Label used in log output.
    name: &'static str,
    /// Task draining incoming envelopes; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
    /// Resolved once the peer sends `RemoteStarted`.
    remote_started: Signal<()>,
}
1195
1196impl ChannelClient {
1197    fn new(
1198        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1199        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1200        cx: &App,
1201        name: &'static str,
1202    ) -> Arc<Self> {
1203        Arc::new_cyclic(|this| Self {
1204            outgoing_tx: Mutex::new(outgoing_tx),
1205            next_message_id: AtomicU32::new(0),
1206            max_received: AtomicU32::new(0),
1207            response_channels: ResponseChannels::default(),
1208            message_handlers: Default::default(),
1209            buffer: Mutex::new(VecDeque::new()),
1210            name,
1211            task: Mutex::new(Self::start_handling_messages(
1212                this.clone(),
1213                incoming_rx,
1214                &cx.to_async(),
1215            )),
1216            remote_started: Signal::new(cx),
1217        })
1218    }
1219
1220    fn wait_for_remote_started(&self) -> Shared<Task<Option<()>>> {
1221        self.remote_started.wait()
1222    }
1223
1224    fn start_handling_messages(
1225        this: Weak<Self>,
1226        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1227        cx: &AsyncApp,
1228    ) -> Task<Result<()>> {
1229        cx.spawn(async move |cx| {
1230            if let Some(this) = this.upgrade() {
1231                let envelope = proto::RemoteStarted {}.into_envelope(0, None, None);
1232                this.outgoing_tx.lock().unbounded_send(envelope).ok();
1233            };
1234
1235            let peer_id = PeerId { owner_id: 0, id: 0 };
1236            while let Some(incoming) = incoming_rx.next().await {
1237                let Some(this) = this.upgrade() else {
1238                    return anyhow::Ok(());
1239                };
1240                if let Some(ack_id) = incoming.ack_id {
1241                    let mut buffer = this.buffer.lock();
1242                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1243                        buffer.pop_front();
1244                    }
1245                }
1246                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
1247                {
1248                    log::debug!(
1249                        "{}:ssh message received. name:FlushBufferedMessages",
1250                        this.name
1251                    );
1252                    {
1253                        let buffer = this.buffer.lock();
1254                        for envelope in buffer.iter() {
1255                            this.outgoing_tx
1256                                .lock()
1257                                .unbounded_send(envelope.clone())
1258                                .ok();
1259                        }
1260                    }
1261                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1262                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1263                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
1264                    continue;
1265                }
1266
1267                if let Some(proto::envelope::Payload::RemoteStarted(_)) = &incoming.payload {
1268                    this.remote_started.set(());
1269                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1270                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1271                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
1272                    continue;
1273                }
1274
1275                this.max_received.store(incoming.id, SeqCst);
1276
1277                if let Some(request_id) = incoming.responding_to {
1278                    let request_id = MessageId(request_id);
1279                    let sender = this.response_channels.lock().remove(&request_id);
1280                    if let Some(sender) = sender {
1281                        let (tx, rx) = oneshot::channel();
1282                        if incoming.payload.is_some() {
1283                            sender.send((incoming, tx)).ok();
1284                        }
1285                        rx.await.ok();
1286                    }
1287                } else if let Some(envelope) =
1288                    build_typed_envelope(peer_id, Instant::now(), incoming)
1289                {
1290                    let type_name = envelope.payload_type_name();
1291                    let message_id = envelope.message_id();
1292                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
1293                        &this.message_handlers,
1294                        envelope,
1295                        this.clone().into(),
1296                        cx.clone(),
1297                    ) {
1298                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
1299                        cx.foreground_executor()
1300                            .spawn(async move {
1301                                match future.await {
1302                                    Ok(_) => {
1303                                        log::debug!(
1304                                            "{}:ssh message handled. name:{type_name}",
1305                                            this.name
1306                                        );
1307                                    }
1308                                    Err(error) => {
1309                                        log::error!(
1310                                            "{}:error handling message. type:{}, error:{:#}",
1311                                            this.name,
1312                                            type_name,
1313                                            format!("{error:#}").lines().fold(
1314                                                String::new(),
1315                                                |mut message, line| {
1316                                                    if !message.is_empty() {
1317                                                        message.push(' ');
1318                                                    }
1319                                                    message.push_str(line);
1320                                                    message
1321                                                }
1322                                            )
1323                                        );
1324                                    }
1325                                }
1326                            })
1327                            .detach()
1328                    } else {
1329                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
1330                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
1331                            message_id,
1332                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
1333                        ) {
1334                            log::error!(
1335                                "{}:error sending error response for {type_name}:{e:#}",
1336                                this.name
1337                            );
1338                        }
1339                    }
1340                }
1341            }
1342            anyhow::Ok(())
1343        })
1344    }
1345
1346    fn reconnect(
1347        self: &Arc<Self>,
1348        incoming_rx: UnboundedReceiver<Envelope>,
1349        outgoing_tx: UnboundedSender<Envelope>,
1350        cx: &AsyncApp,
1351    ) {
1352        *self.outgoing_tx.lock() = outgoing_tx;
1353        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1354    }
1355
1356    fn request<T: RequestMessage>(
1357        &self,
1358        payload: T,
1359    ) -> impl 'static + Future<Output = Result<T::Response>> {
1360        self.request_internal(payload, true)
1361    }
1362
1363    fn request_internal<T: RequestMessage>(
1364        &self,
1365        payload: T,
1366        use_buffer: bool,
1367    ) -> impl 'static + Future<Output = Result<T::Response>> {
1368        log::debug!("ssh request start. name:{}", T::NAME);
1369        let response =
1370            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
1371        async move {
1372            let response = response.await?;
1373            log::debug!("ssh request finish. name:{}", T::NAME);
1374            T::Response::from_envelope(response).context("received a response of the wrong type")
1375        }
1376    }
1377
1378    async fn resync(&self, timeout: Duration) -> Result<()> {
1379        smol::future::or(
1380            async {
1381                self.request_internal(proto::FlushBufferedMessages {}, false)
1382                    .await?;
1383
1384                for envelope in self.buffer.lock().iter() {
1385                    self.outgoing_tx
1386                        .lock()
1387                        .unbounded_send(envelope.clone())
1388                        .ok();
1389                }
1390                Ok(())
1391            },
1392            async {
1393                smol::Timer::after(timeout).await;
1394                anyhow::bail!("Timed out resyncing remote client")
1395            },
1396        )
1397        .await
1398    }
1399
1400    async fn ping(&self, timeout: Duration) -> Result<()> {
1401        smol::future::or(
1402            async {
1403                self.request(proto::Ping {}).await?;
1404                Ok(())
1405            },
1406            async {
1407                smol::Timer::after(timeout).await;
1408                anyhow::bail!("Timed out pinging remote client")
1409            },
1410        )
1411        .await
1412    }
1413
1414    fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1415        log::debug!("ssh send name:{}", T::NAME);
1416        self.send_dynamic(payload.into_envelope(0, None, None))
1417    }
1418
1419    fn request_dynamic(
1420        &self,
1421        mut envelope: proto::Envelope,
1422        type_name: &'static str,
1423        use_buffer: bool,
1424    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1425        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1426        let (tx, rx) = oneshot::channel();
1427        let mut response_channels_lock = self.response_channels.lock();
1428        response_channels_lock.insert(MessageId(envelope.id), tx);
1429        drop(response_channels_lock);
1430
1431        let result = if use_buffer {
1432            self.send_buffered(envelope)
1433        } else {
1434            self.send_unbuffered(envelope)
1435        };
1436        async move {
1437            if let Err(error) = &result {
1438                log::error!("failed to send message: {error}");
1439                anyhow::bail!("failed to send message: {error}");
1440            }
1441
1442            let response = rx.await.context("connection lost")?.0;
1443            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1444                return Err(RpcError::from_proto(error, type_name));
1445            }
1446            Ok(response)
1447        }
1448    }
1449
1450    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1451        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1452        self.send_buffered(envelope)
1453    }
1454
1455    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1456        envelope.ack_id = Some(self.max_received.load(SeqCst));
1457        self.buffer.lock().push_back(envelope.clone());
1458        // ignore errors on send (happen while we're reconnecting)
1459        // assume that the global "disconnected" overlay is sufficient.
1460        self.outgoing_tx.lock().unbounded_send(envelope).ok();
1461        Ok(())
1462    }
1463
1464    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1465        envelope.ack_id = Some(self.max_received.load(SeqCst));
1466        self.outgoing_tx.lock().unbounded_send(envelope).ok();
1467        Ok(())
1468    }
1469}
1470
1471impl ProtoClient for ChannelClient {
1472    fn request(
1473        &self,
1474        envelope: proto::Envelope,
1475        request_type: &'static str,
1476    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1477        self.request_dynamic(envelope, request_type, true).boxed()
1478    }
1479
1480    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1481        self.send_dynamic(envelope)
1482    }
1483
1484    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1485        self.send_dynamic(envelope)
1486    }
1487
1488    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1489        &self.message_handlers
1490    }
1491
1492    fn is_via_collab(&self) -> bool {
1493        false
1494    }
1495}
1496
/// Test-only fakes: a [`RemoteConnection`] whose remote side is an in-process
/// [`ChannelClient`] running in a separate test app, and a delegate whose
/// interactive hooks must never be reached.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
    use crate::remote_client::{CommandTemplate, RemoteConnectionOptions};
    use anyhow::Result;
    use askpass::EncryptedPassword;
    use async_trait::async_trait;
    use collections::HashMap;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use std::{path::PathBuf, sync::Arc};
    use util::paths::{PathStyle, RemotePathBuf};

    /// Connection whose "remote" side is `server_channel`, driven by the server
    /// test app captured in `server_cx`.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: RemoteConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets the server test app's [`AsyncApp`] be stored inside a
    /// `Send + Sync` connection object.
    pub(super) struct SendableCx(AsyncApp);

    impl SendableCx {
        // SAFETY: GPUI is always single threaded when run in test mode, so the
        // captured context is never touched from another thread.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            let async_cx = cx.to_async();
            Self(async_cx)
        }

        // SAFETY: requiring a valid `AsyncApp` from the caller enforces that we
        // are on the main thread.
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            let Self(inner) = self;
            inner.clone()
        }
    }

    // SAFETY: a `SendableCx` can only be reached through `SendableCx::new` and
    // `SendableCx::get`, neither of which permits access from another thread.
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Re-points the server channel at fresh channels and pumps envelopes
        /// between client and server until either side's channel closes, then
        /// resolves to `Ok(1)`.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn RemoteClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut to_server_tx, to_server_rx) = mpsc::unbounded::<Envelope>();
            let (from_server_tx, mut from_server_rx) = mpsc::unbounded::<Envelope>();

            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(to_server_rx, from_server_tx, &server_cx);

            cx.background_spawn(async move {
                loop {
                    // Biased: server-to-client traffic is polled first.
                    select_biased! {
                        message = from_server_rx.next().fuse() => {
                            let Some(message) = message else {
                                return Ok(1);
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(message).await.ok();
                        }
                        message = client_outgoing_rx.next().fuse() => {
                            let Some(message) = message else {
                                return Ok(1);
                            };
                            to_server_tx.send(message).await.ok();
                        }
                    }
                }
            })
        }

        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// Mimics `ssh <program> <args...>`, defaulting the program to `sh`.
        /// The working directory and port forward are ignored.
        fn build_command(
            &self,
            program: Option<String>,
            args: &[String],
            env: &HashMap<String, String>,
            _: Option<String>,
            _: Option<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let remote_program = program.unwrap_or_else(|| "sh".to_string());
            let ssh_args: Vec<String> = std::iter::once(remote_program)
                .chain(args.iter().cloned())
                .collect();
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: ssh_args,
                env: env.clone(),
            })
        }

        /// Mimics `ssh -N local:host:remote ...`.
        fn build_forward_ports_command(
            &self,
            forwards: Vec<(u16, String, u16)>,
        ) -> anyhow::Result<CommandTemplate> {
            let mut args = vec!["-N".to_owned()];
            args.extend(
                forwards.into_iter().map(|(local_port, host, remote_port)| {
                    format!("{local_port}:{host}:{remote_port}")
                }),
            );
            Ok(CommandTemplate {
                program: "ssh".into(),
                args,
                env: Default::default(),
            })
        }

        fn connection_options(&self) -> RemoteConnectionOptions {
            RemoteConnectionOptions::clone(&self.connection_options)
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::local()
        }

        fn shell(&self) -> String {
            String::from("sh")
        }

        fn default_system_shell(&self) -> String {
            String::from("sh")
        }

        /// Re-points the server channel at channels whose other ends are already
        /// dropped, cutting the server off from the client.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (dead_outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, dead_incoming_rx) = mpsc::unbounded::<Envelope>();
            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(dead_incoming_rx, dead_outgoing_tx, &server_cx);
        }
    }

    /// Delegate for tests; every interactive hook is unreachable and status
    /// updates are ignored.
    pub(super) struct Delegate;

    impl RemoteClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<EncryptedPassword>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
    }
}