remote_client.rs

   1use crate::{
   2    SshConnectionOptions,
   3    protocol::MessageId,
   4    proxy::ProxyLaunchError,
   5    transport::{
   6        ssh::SshRemoteConnection,
   7        wsl::{WslConnectionOptions, WslRemoteConnection},
   8    },
   9};
  10use anyhow::{Context as _, Result, anyhow};
  11use askpass::EncryptedPassword;
  12use async_trait::async_trait;
  13use collections::HashMap;
  14use futures::{
  15    Future, FutureExt as _, StreamExt as _,
  16    channel::{
  17        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  18        oneshot,
  19    },
  20    future::{BoxFuture, Shared},
  21    select, select_biased,
  22};
  23use gpui::{
  24    App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
  25    EventEmitter, FutureExt, Global, SemanticVersion, Task, WeakEntity,
  26};
  27use parking_lot::Mutex;
  28
  29use release_channel::ReleaseChannel;
  30use rpc::{
  31    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
  32    proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
  33};
  34use std::{
  35    collections::VecDeque,
  36    fmt,
  37    ops::ControlFlow,
  38    path::PathBuf,
  39    sync::{
  40        Arc, Weak,
  41        atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
  42    },
  43    time::{Duration, Instant},
  44};
  45use util::{
  46    ResultExt,
  47    paths::{PathStyle, RemotePathBuf},
  48};
  49
  50#[derive(Copy, Clone, Debug)]
  51pub struct RemotePlatform {
  52    pub os: &'static str,
  53    pub arch: &'static str,
  54}
  55
  56#[derive(Clone, Debug)]
  57pub struct CommandTemplate {
  58    pub program: String,
  59    pub args: Vec<String>,
  60    pub env: HashMap<String, String>,
  61}
  62
  63pub trait RemoteClientDelegate: Send + Sync {
  64    fn ask_password(
  65        &self,
  66        prompt: String,
  67        tx: oneshot::Sender<EncryptedPassword>,
  68        cx: &mut AsyncApp,
  69    );
  70    fn get_download_params(
  71        &self,
  72        platform: RemotePlatform,
  73        release_channel: ReleaseChannel,
  74        version: Option<SemanticVersion>,
  75        cx: &mut AsyncApp,
  76    ) -> Task<Result<Option<(String, String)>>>;
  77    fn download_server_binary_locally(
  78        &self,
  79        platform: RemotePlatform,
  80        release_channel: ReleaseChannel,
  81        version: Option<SemanticVersion>,
  82        cx: &mut AsyncApp,
  83    ) -> Task<Result<PathBuf>>;
  84    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
  85}
  86
  87const MAX_MISSED_HEARTBEATS: usize = 5;
  88const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
  89const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
  90
  91const MAX_RECONNECT_ATTEMPTS: usize = 3;
  92
  93enum State {
  94    Connecting,
  95    Connected {
  96        remote_connection: Arc<dyn RemoteConnection>,
  97        delegate: Arc<dyn RemoteClientDelegate>,
  98
  99        multiplex_task: Task<Result<()>>,
 100        heartbeat_task: Task<Result<()>>,
 101    },
 102    HeartbeatMissed {
 103        missed_heartbeats: usize,
 104
 105        ssh_connection: Arc<dyn RemoteConnection>,
 106        delegate: Arc<dyn RemoteClientDelegate>,
 107
 108        multiplex_task: Task<Result<()>>,
 109        heartbeat_task: Task<Result<()>>,
 110    },
 111    Reconnecting,
 112    ReconnectFailed {
 113        ssh_connection: Arc<dyn RemoteConnection>,
 114        delegate: Arc<dyn RemoteClientDelegate>,
 115
 116        error: anyhow::Error,
 117        attempts: usize,
 118    },
 119    ReconnectExhausted,
 120    ServerNotRunning,
 121}
 122
 123impl fmt::Display for State {
 124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 125        match self {
 126            Self::Connecting => write!(f, "connecting"),
 127            Self::Connected { .. } => write!(f, "connected"),
 128            Self::Reconnecting => write!(f, "reconnecting"),
 129            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 130            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 131            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 132            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 133        }
 134    }
 135}
 136
 137impl State {
 138    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 139        match self {
 140            Self::Connected {
 141                remote_connection: ssh_connection,
 142                ..
 143            } => Some(ssh_connection.clone()),
 144            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
 145            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
 146            _ => None,
 147        }
 148    }
 149
 150    fn can_reconnect(&self) -> bool {
 151        match self {
 152            Self::Connected { .. }
 153            | Self::HeartbeatMissed { .. }
 154            | Self::ReconnectFailed { .. } => true,
 155            State::Connecting
 156            | State::Reconnecting
 157            | State::ReconnectExhausted
 158            | State::ServerNotRunning => false,
 159        }
 160    }
 161
 162    fn is_reconnect_failed(&self) -> bool {
 163        matches!(self, Self::ReconnectFailed { .. })
 164    }
 165
 166    fn is_reconnect_exhausted(&self) -> bool {
 167        matches!(self, Self::ReconnectExhausted { .. })
 168    }
 169
 170    fn is_server_not_running(&self) -> bool {
 171        matches!(self, Self::ServerNotRunning)
 172    }
 173
 174    fn is_reconnecting(&self) -> bool {
 175        matches!(self, Self::Reconnecting { .. })
 176    }
 177
 178    fn heartbeat_recovered(self) -> Self {
 179        match self {
 180            Self::HeartbeatMissed {
 181                ssh_connection,
 182                delegate,
 183                multiplex_task,
 184                heartbeat_task,
 185                ..
 186            } => Self::Connected {
 187                remote_connection: ssh_connection,
 188                delegate,
 189                multiplex_task,
 190                heartbeat_task,
 191            },
 192            _ => self,
 193        }
 194    }
 195
 196    fn heartbeat_missed(self) -> Self {
 197        match self {
 198            Self::Connected {
 199                remote_connection: ssh_connection,
 200                delegate,
 201                multiplex_task,
 202                heartbeat_task,
 203            } => Self::HeartbeatMissed {
 204                missed_heartbeats: 1,
 205                ssh_connection,
 206                delegate,
 207                multiplex_task,
 208                heartbeat_task,
 209            },
 210            Self::HeartbeatMissed {
 211                missed_heartbeats,
 212                ssh_connection,
 213                delegate,
 214                multiplex_task,
 215                heartbeat_task,
 216            } => Self::HeartbeatMissed {
 217                missed_heartbeats: missed_heartbeats + 1,
 218                ssh_connection,
 219                delegate,
 220                multiplex_task,
 221                heartbeat_task,
 222            },
 223            _ => self,
 224        }
 225    }
 226}
 227
 228/// The state of the ssh connection.
 229#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 230pub enum ConnectionState {
 231    Connecting,
 232    Connected,
 233    HeartbeatMissed,
 234    Reconnecting,
 235    Disconnected,
 236}
 237
 238impl From<&State> for ConnectionState {
 239    fn from(value: &State) -> Self {
 240        match value {
 241            State::Connecting => Self::Connecting,
 242            State::Connected { .. } => Self::Connected,
 243            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 244            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 245            State::ReconnectExhausted => Self::Disconnected,
 246            State::ServerNotRunning => Self::Disconnected,
 247        }
 248    }
 249}
 250
 251pub struct RemoteClient {
 252    client: Arc<ChannelClient>,
 253    unique_identifier: String,
 254    connection_options: RemoteConnectionOptions,
 255    path_style: PathStyle,
 256    state: Option<State>,
 257}
 258
 259#[derive(Debug)]
 260pub enum RemoteClientEvent {
 261    Disconnected,
 262}
 263
 264impl EventEmitter<RemoteClientEvent> for RemoteClient {}
 265
 266/// Identifies the socket on the remote server so that reconnects
 267/// can re-join the same project.
 268pub enum ConnectionIdentifier {
 269    Setup(u64),
 270    Workspace(i64),
 271}
 272
 273static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 274
 275impl ConnectionIdentifier {
 276    pub fn setup() -> Self {
 277        Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
 278    }
 279
 280    // This string gets used in a socket name, and so must be relatively short.
 281    // The total length of:
 282    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 283    // Must be less than about 100 characters
 284    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 285    // So our strings should be at most 20 characters or so.
 286    fn to_string(&self, cx: &App) -> String {
 287        let identifier_prefix = match ReleaseChannel::global(cx) {
 288            ReleaseChannel::Stable => "".to_string(),
 289            release_channel => format!("{}-", release_channel.dev_name()),
 290        };
 291        match self {
 292            Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
 293            Self::Workspace(workspace_id) => {
 294                format!("{identifier_prefix}workspace-{workspace_id}",)
 295            }
 296        }
 297    }
 298}
 299
 300pub async fn connect(
 301    connection_options: RemoteConnectionOptions,
 302    delegate: Arc<dyn RemoteClientDelegate>,
 303    cx: &mut AsyncApp,
 304) -> Result<Arc<dyn RemoteConnection>> {
 305    cx.update(|cx| {
 306        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
 307            pool.connect(connection_options.clone(), delegate.clone(), cx)
 308        })
 309    })?
 310    .await
 311    .map_err(|e| e.cloned())
 312}
 313
 314impl RemoteClient {
 315    pub fn new(
 316        unique_identifier: ConnectionIdentifier,
 317        remote_connection: Arc<dyn RemoteConnection>,
 318        cancellation: oneshot::Receiver<()>,
 319        delegate: Arc<dyn RemoteClientDelegate>,
 320        cx: &mut App,
 321    ) -> Task<Result<Option<Entity<Self>>>> {
 322        let unique_identifier = unique_identifier.to_string(cx);
 323        cx.spawn(async move |cx| {
 324            let success = Box::pin(async move {
 325                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 326                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 327                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 328
 329                let client =
 330                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
 331
 332                let path_style = remote_connection.path_style();
 333                let this = cx.new(|_| Self {
 334                    client: client.clone(),
 335                    unique_identifier: unique_identifier.clone(),
 336                    connection_options: remote_connection.connection_options(),
 337                    path_style,
 338                    state: Some(State::Connecting),
 339                })?;
 340
 341                let io_task = remote_connection.start_proxy(
 342                    unique_identifier,
 343                    false,
 344                    incoming_tx,
 345                    outgoing_rx,
 346                    connection_activity_tx,
 347                    delegate.clone(),
 348                    cx,
 349                );
 350
 351                let ready = client
 352                    .wait_for_remote_started()
 353                    .with_timeout(HEARTBEAT_TIMEOUT, cx.background_executor())
 354                    .await;
 355                match ready {
 356                    Ok(Some(_)) => {}
 357                    Ok(None) => {
 358                        let mut error = "remote client exited before becoming ready".to_owned();
 359                        if let Some(status) = io_task.now_or_never() {
 360                            match status {
 361                                Ok(exit_code) => {
 362                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 363                                }
 364                                Err(e) => error.push_str(&format!(", error={e:?}")),
 365                            }
 366                        }
 367                        let error = anyhow::anyhow!("{error}");
 368                        log::error!("failed to establish connection: {}", error);
 369                        return Err(error);
 370                    }
 371                    Err(_) => {
 372                        let mut error =
 373                            "remote client did not become ready within the timeout".to_owned();
 374                        if let Some(status) = io_task.now_or_never() {
 375                            match status {
 376                                Ok(exit_code) => {
 377                                    error.push_str(&format!(", exit_code={exit_code:?}"))
 378                                }
 379                                Err(e) => error.push_str(&format!(", error={e:?}")),
 380                            }
 381                        }
 382                        let error = anyhow::anyhow!("{error}");
 383                        log::error!("failed to establish connection: {}", error);
 384                        return Err(error);
 385                    }
 386                }
 387                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
 388                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 389                    log::error!("failed to establish connection: {}", error);
 390                    return Err(error);
 391                }
 392
 393                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
 394
 395                this.update(cx, |this, _| {
 396                    this.state = Some(State::Connected {
 397                        remote_connection,
 398                        delegate,
 399                        multiplex_task,
 400                        heartbeat_task,
 401                    });
 402                })?;
 403
 404                Ok(Some(this))
 405            });
 406
 407            select! {
 408                _ = cancellation.fuse() => {
 409                    Ok(None)
 410                }
 411                result = success.fuse() =>  result
 412            }
 413        })
 414    }
 415
 416    pub fn proto_client_from_channels(
 417        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 418        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 419        cx: &App,
 420        name: &'static str,
 421    ) -> AnyProtoClient {
 422        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
 423    }
 424
 425    pub fn shutdown_processes<T: RequestMessage>(
 426        &mut self,
 427        shutdown_request: Option<T>,
 428        executor: BackgroundExecutor,
 429    ) -> Option<impl Future<Output = ()> + use<T>> {
 430        let state = self.state.take()?;
 431        log::info!("shutting down ssh processes");
 432
 433        let State::Connected {
 434            multiplex_task,
 435            heartbeat_task,
 436            remote_connection: ssh_connection,
 437            delegate,
 438        } = state
 439        else {
 440            return None;
 441        };
 442
 443        let client = self.client.clone();
 444
 445        Some(async move {
 446            if let Some(shutdown_request) = shutdown_request {
 447                client.send(shutdown_request).log_err();
 448                // We wait 50ms instead of waiting for a response, because
 449                // waiting for a response would require us to wait on the main thread
 450                // which we want to avoid in an `on_app_quit` callback.
 451                executor.timer(Duration::from_millis(50)).await;
 452            }
 453
 454            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 455            // child of master_process.
 456            drop(multiplex_task);
 457            // Now drop the rest of state, which kills master process.
 458            drop(heartbeat_task);
 459            drop(ssh_connection);
 460            drop(delegate);
 461        })
 462    }
 463
 464    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
 465        let can_reconnect = self
 466            .state
 467            .as_ref()
 468            .map(|state| state.can_reconnect())
 469            .unwrap_or(false);
 470        if !can_reconnect {
 471            log::info!("aborting reconnect, because not in state that allows reconnecting");
 472            let error = if let Some(state) = self.state.as_ref() {
 473                format!("invalid state, cannot reconnect while in state {state}")
 474            } else {
 475                "no state set".to_string()
 476            };
 477            anyhow::bail!(error);
 478        }
 479
 480        let state = self.state.take().unwrap();
 481        let (attempts, remote_connection, delegate) = match state {
 482            State::Connected {
 483                remote_connection: ssh_connection,
 484                delegate,
 485                multiplex_task,
 486                heartbeat_task,
 487            }
 488            | State::HeartbeatMissed {
 489                ssh_connection,
 490                delegate,
 491                multiplex_task,
 492                heartbeat_task,
 493                ..
 494            } => {
 495                drop(multiplex_task);
 496                drop(heartbeat_task);
 497                (0, ssh_connection, delegate)
 498            }
 499            State::ReconnectFailed {
 500                attempts,
 501                ssh_connection,
 502                delegate,
 503                ..
 504            } => (attempts, ssh_connection, delegate),
 505            State::Connecting
 506            | State::Reconnecting
 507            | State::ReconnectExhausted
 508            | State::ServerNotRunning => unreachable!(),
 509        };
 510
 511        let attempts = attempts + 1;
 512        if attempts > MAX_RECONNECT_ATTEMPTS {
 513            log::error!(
 514                "Failed to reconnect to after {} attempts, giving up",
 515                MAX_RECONNECT_ATTEMPTS
 516            );
 517            self.set_state(State::ReconnectExhausted, cx);
 518            return Ok(());
 519        }
 520
 521        self.set_state(State::Reconnecting, cx);
 522
 523        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 524
 525        let unique_identifier = self.unique_identifier.clone();
 526        let client = self.client.clone();
 527        let reconnect_task = cx.spawn(async move |this, cx| {
 528            macro_rules! failed {
 529                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 530                    return State::ReconnectFailed {
 531                        error: anyhow!($error),
 532                        attempts: $attempts,
 533                        ssh_connection: $ssh_connection,
 534                        delegate: $delegate,
 535                    };
 536                };
 537            }
 538
 539            if let Err(error) = remote_connection
 540                .kill()
 541                .await
 542                .context("Failed to kill ssh process")
 543            {
 544                failed!(error, attempts, remote_connection, delegate);
 545            };
 546
 547            let connection_options = remote_connection.connection_options();
 548
 549            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 550            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 551            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 552
 553            let (ssh_connection, io_task) = match async {
 554                let ssh_connection = cx
 555                    .update_global(|pool: &mut ConnectionPool, cx| {
 556                        pool.connect(connection_options, delegate.clone(), cx)
 557                    })?
 558                    .await
 559                    .map_err(|error| error.cloned())?;
 560
 561                let io_task = ssh_connection.start_proxy(
 562                    unique_identifier,
 563                    true,
 564                    incoming_tx,
 565                    outgoing_rx,
 566                    connection_activity_tx,
 567                    delegate.clone(),
 568                    cx,
 569                );
 570                anyhow::Ok((ssh_connection, io_task))
 571            }
 572            .await
 573            {
 574                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 575                Err(error) => {
 576                    failed!(error, attempts, remote_connection, delegate);
 577                }
 578            };
 579
 580            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
 581            client.reconnect(incoming_rx, outgoing_tx, cx);
 582
 583            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 584                failed!(error, attempts, ssh_connection, delegate);
 585            };
 586
 587            State::Connected {
 588                remote_connection: ssh_connection,
 589                delegate,
 590                multiplex_task,
 591                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
 592            }
 593        });
 594
 595        cx.spawn(async move |this, cx| {
 596            let new_state = reconnect_task.await;
 597            this.update(cx, |this, cx| {
 598                this.try_set_state(cx, |old_state| {
 599                    if old_state.is_reconnecting() {
 600                        match &new_state {
 601                            State::Connecting
 602                            | State::Reconnecting
 603                            | State::HeartbeatMissed { .. }
 604                            | State::ServerNotRunning => {}
 605                            State::Connected { .. } => {
 606                                log::info!("Successfully reconnected");
 607                            }
 608                            State::ReconnectFailed {
 609                                error, attempts, ..
 610                            } => {
 611                                log::error!(
 612                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 613                                    attempts,
 614                                    error
 615                                );
 616                            }
 617                            State::ReconnectExhausted => {
 618                                log::error!("Reconnect attempt failed and all attempts exhausted");
 619                            }
 620                        }
 621                        Some(new_state)
 622                    } else {
 623                        None
 624                    }
 625                });
 626
 627                if this.state_is(State::is_reconnect_failed) {
 628                    this.reconnect(cx)
 629                } else if this.state_is(State::is_reconnect_exhausted) {
 630                    Ok(())
 631                } else {
 632                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 633                    Ok(())
 634                }
 635            })
 636        })
 637        .detach_and_log_err(cx);
 638
 639        Ok(())
 640    }
 641
 642    fn heartbeat(
 643        this: WeakEntity<Self>,
 644        mut connection_activity_rx: mpsc::Receiver<()>,
 645        cx: &mut AsyncApp,
 646    ) -> Task<Result<()>> {
 647        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
 648            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 649        };
 650
 651        cx.spawn(async move |cx| {
 652            let mut missed_heartbeats = 0;
 653
 654            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 655            futures::pin_mut!(keepalive_timer);
 656
 657            loop {
 658                select_biased! {
 659                    result = connection_activity_rx.next().fuse() => {
 660                        if result.is_none() {
 661                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 662                            return Ok(());
 663                        }
 664
 665                        if missed_heartbeats != 0 {
 666                            missed_heartbeats = 0;
 667                            let _ =this.update(cx, |this, cx| {
 668                                this.handle_heartbeat_result(missed_heartbeats, cx)
 669                            })?;
 670                        }
 671                    }
 672                    _ = keepalive_timer => {
 673                        log::debug!("Sending heartbeat to server...");
 674
 675                        let result = select_biased! {
 676                            _ = connection_activity_rx.next().fuse() => {
 677                                Ok(())
 678                            }
 679                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 680                                ping_result
 681                            }
 682                        };
 683
 684                        if result.is_err() {
 685                            missed_heartbeats += 1;
 686                            log::warn!(
 687                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 688                                HEARTBEAT_TIMEOUT,
 689                                missed_heartbeats,
 690                                MAX_MISSED_HEARTBEATS
 691                            );
 692                        } else if missed_heartbeats != 0 {
 693                            missed_heartbeats = 0;
 694                        } else {
 695                            continue;
 696                        }
 697
 698                        let result = this.update(cx, |this, cx| {
 699                            this.handle_heartbeat_result(missed_heartbeats, cx)
 700                        })?;
 701                        if result.is_break() {
 702                            return Ok(());
 703                        }
 704                    }
 705                }
 706
 707                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 708            }
 709        })
 710    }
 711
 712    fn handle_heartbeat_result(
 713        &mut self,
 714        missed_heartbeats: usize,
 715        cx: &mut Context<Self>,
 716    ) -> ControlFlow<()> {
 717        let state = self.state.take().unwrap();
 718        let next_state = if missed_heartbeats > 0 {
 719            state.heartbeat_missed()
 720        } else {
 721            state.heartbeat_recovered()
 722        };
 723
 724        self.set_state(next_state, cx);
 725
 726        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 727            log::error!(
 728                "Missed last {} heartbeats. Reconnecting...",
 729                missed_heartbeats
 730            );
 731
 732            self.reconnect(cx)
 733                .context("failed to start reconnect process after missing heartbeats")
 734                .log_err();
 735            ControlFlow::Break(())
 736        } else {
 737            ControlFlow::Continue(())
 738        }
 739    }
 740
 741    fn monitor(
 742        this: WeakEntity<Self>,
 743        io_task: Task<Result<i32>>,
 744        cx: &AsyncApp,
 745    ) -> Task<Result<()>> {
 746        cx.spawn(async move |cx| {
 747            let result = io_task.await;
 748
 749            match result {
 750                Ok(exit_code) => {
 751                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 752                        match error {
 753                            ProxyLaunchError::ServerNotRunning => {
 754                                log::error!("failed to reconnect because server is not running");
 755                                this.update(cx, |this, cx| {
 756                                    this.set_state(State::ServerNotRunning, cx);
 757                                })?;
 758                            }
 759                        }
 760                    } else if exit_code > 0 {
 761                        log::error!("proxy process terminated unexpectedly");
 762                        this.update(cx, |this, cx| {
 763                            this.reconnect(cx).ok();
 764                        })?;
 765                    }
 766                }
 767                Err(error) => {
 768                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 769                    this.update(cx, |this, cx| {
 770                        this.reconnect(cx).ok();
 771                    })?;
 772                }
 773            }
 774
 775            Ok(())
 776        })
 777    }
 778
 779    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 780        self.state.as_ref().is_some_and(check)
 781    }
 782
 783    fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
 784        let new_state = self.state.as_ref().and_then(map);
 785        if let Some(new_state) = new_state {
 786            self.state.replace(new_state);
 787            cx.notify();
 788        }
 789    }
 790
 791    fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
 792        log::info!("setting state to '{}'", &state);
 793
 794        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 795        let is_server_not_running = state.is_server_not_running();
 796        self.state.replace(state);
 797
 798        if is_reconnect_exhausted || is_server_not_running {
 799            cx.emit(RemoteClientEvent::Disconnected);
 800        }
 801        cx.notify();
 802    }
 803
 804    pub fn shell(&self) -> Option<String> {
 805        Some(self.remote_connection()?.shell())
 806    }
 807
 808    pub fn default_system_shell(&self) -> Option<String> {
 809        Some(self.remote_connection()?.default_system_shell())
 810    }
 811
 812    pub fn shares_network_interface(&self) -> bool {
 813        self.remote_connection()
 814            .map_or(false, |connection| connection.shares_network_interface())
 815    }
 816
 817    pub fn build_command(
 818        &self,
 819        program: Option<String>,
 820        args: &[String],
 821        env: &HashMap<String, String>,
 822        working_dir: Option<String>,
 823        port_forward: Option<(u16, String, u16)>,
 824    ) -> Result<CommandTemplate> {
 825        let Some(connection) = self.remote_connection() else {
 826            return Err(anyhow!("no ssh connection"));
 827        };
 828        connection.build_command(program, args, env, working_dir, port_forward)
 829    }
 830
 831    pub fn build_forward_ports_command(
 832        &self,
 833        forwards: Vec<(u16, String, u16)>,
 834    ) -> Result<CommandTemplate> {
 835        let Some(connection) = self.remote_connection() else {
 836            return Err(anyhow!("no ssh connection"));
 837        };
 838        connection.build_forward_ports_command(forwards)
 839    }
 840
 841    pub fn upload_directory(
 842        &self,
 843        src_path: PathBuf,
 844        dest_path: RemotePathBuf,
 845        cx: &App,
 846    ) -> Task<Result<()>> {
 847        let Some(connection) = self.remote_connection() else {
 848            return Task::ready(Err(anyhow!("no ssh connection")));
 849        };
 850        connection.upload_directory(src_path, dest_path, cx)
 851    }
 852
 853    pub fn proto_client(&self) -> AnyProtoClient {
 854        self.client.clone().into()
 855    }
 856
 857    pub fn connection_options(&self) -> RemoteConnectionOptions {
 858        self.connection_options.clone()
 859    }
 860
 861    pub fn connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 862        if let State::Connected {
 863            remote_connection, ..
 864        } = self.state.as_ref()?
 865        {
 866            Some(remote_connection.clone())
 867        } else {
 868            None
 869        }
 870    }
 871
 872    pub fn connection_state(&self) -> ConnectionState {
 873        self.state
 874            .as_ref()
 875            .map(ConnectionState::from)
 876            .unwrap_or(ConnectionState::Disconnected)
 877    }
 878
 879    pub fn is_disconnected(&self) -> bool {
 880        self.connection_state() == ConnectionState::Disconnected
 881    }
 882
    /// Returns the [`PathStyle`] stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
 886
 887    #[cfg(any(test, feature = "test-support"))]
 888    pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
 889        let opts = self.connection_options();
 890        client_cx.spawn(async move |cx| {
 891            let connection = cx
 892                .update_global(|c: &mut ConnectionPool, _| {
 893                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
 894                        c.clone()
 895                    } else {
 896                        panic!("missing test connection")
 897                    }
 898                })
 899                .unwrap()
 900                .await
 901                .unwrap();
 902
 903            connection.simulate_disconnect(cx);
 904        })
 905    }
 906
 907    #[cfg(any(test, feature = "test-support"))]
 908    pub fn fake_server(
 909        client_cx: &mut gpui::TestAppContext,
 910        server_cx: &mut gpui::TestAppContext,
 911    ) -> (RemoteConnectionOptions, AnyProtoClient) {
 912        let port = client_cx
 913            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
 914        let opts = RemoteConnectionOptions::Ssh(SshConnectionOptions {
 915            host: "<fake>".to_string(),
 916            port: Some(port),
 917            ..Default::default()
 918        });
 919        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
 920        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
 921        let server_client =
 922            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
 923        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
 924            connection_options: opts.clone(),
 925            server_cx: fake::SendableCx::new(server_cx),
 926            server_channel: server_client.clone(),
 927        });
 928
 929        client_cx.update(|cx| {
 930            cx.update_default_global(|c: &mut ConnectionPool, cx| {
 931                c.connections.insert(
 932                    opts.clone(),
 933                    ConnectionPoolEntry::Connecting(
 934                        cx.background_spawn({
 935                            let connection = connection.clone();
 936                            async move { Ok(connection.clone()) }
 937                        })
 938                        .shared(),
 939                    ),
 940                );
 941            })
 942        });
 943
 944        (opts, server_client.into())
 945    }
 946
 947    #[cfg(any(test, feature = "test-support"))]
 948    pub async fn fake_client(
 949        opts: RemoteConnectionOptions,
 950        client_cx: &mut gpui::TestAppContext,
 951    ) -> Entity<Self> {
 952        let (_tx, rx) = oneshot::channel();
 953        let mut cx = client_cx.to_async();
 954        let connection = connect(opts, Arc::new(fake::Delegate), &mut cx)
 955            .await
 956            .unwrap();
 957        client_cx
 958            .update(|cx| {
 959                Self::new(
 960                    ConnectionIdentifier::setup(),
 961                    connection,
 962                    rx,
 963                    Arc::new(fake::Delegate),
 964                    cx,
 965                )
 966            })
 967            .await
 968            .unwrap()
 969            .unwrap()
 970    }
 971
 972    fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
 973        self.state
 974            .as_ref()
 975            .and_then(|state| state.remote_connection())
 976    }
 977}
 978
/// One entry of `ConnectionPool`, keyed by the connection options it was created for.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; every caller shares its eventual result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool alone does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
 983
/// App-global registry that lets callers asking for the same connection options share
/// one connection, or one in-flight connection attempt (see `ConnectionPool::connect`).
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<RemoteConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
 990
 991impl ConnectionPool {
 992    pub fn connect(
 993        &mut self,
 994        opts: RemoteConnectionOptions,
 995        delegate: Arc<dyn RemoteClientDelegate>,
 996        cx: &mut App,
 997    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
 998        let connection = self.connections.get(&opts);
 999        match connection {
1000            Some(ConnectionPoolEntry::Connecting(task)) => {
1001                let delegate = delegate.clone();
1002                cx.spawn(async move |cx| {
1003                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1004                })
1005                .detach();
1006                return task.clone();
1007            }
1008            Some(ConnectionPoolEntry::Connected(ssh)) => {
1009                if let Some(ssh) = ssh.upgrade()
1010                    && !ssh.has_been_killed()
1011                {
1012                    return Task::ready(Ok(ssh)).shared();
1013                }
1014                self.connections.remove(&opts);
1015            }
1016            None => {}
1017        }
1018
1019        let task = cx
1020            .spawn({
1021                let opts = opts.clone();
1022                let delegate = delegate.clone();
1023                async move |cx| {
1024                    let connection = match opts.clone() {
1025                        RemoteConnectionOptions::Ssh(opts) => {
1026                            SshRemoteConnection::new(opts, delegate, cx)
1027                                .await
1028                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1029                        }
1030                        RemoteConnectionOptions::Wsl(opts) => {
1031                            WslRemoteConnection::new(opts, delegate, cx)
1032                                .await
1033                                .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1034                        }
1035                    };
1036
1037                    cx.update_global(|pool: &mut Self, _| {
1038                        debug_assert!(matches!(
1039                            pool.connections.get(&opts),
1040                            Some(ConnectionPoolEntry::Connecting(_))
1041                        ));
1042                        match connection {
1043                            Ok(connection) => {
1044                                pool.connections.insert(
1045                                    opts.clone(),
1046                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1047                                );
1048                                Ok(connection)
1049                            }
1050                            Err(error) => {
1051                                pool.connections.remove(&opts);
1052                                Err(Arc::new(error))
1053                            }
1054                        }
1055                    })?
1056                }
1057            })
1058            .shared();
1059
1060        self.connections
1061            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1062        task
1063    }
1064}
1065
/// How to reach a remote environment: over SSH, or into a WSL distribution.
///
/// Also serves as the key of the connection pool, hence the `Eq` + `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteConnectionOptions {
    Ssh(SshConnectionOptions),
    Wsl(WslConnectionOptions),
}
1071
1072impl RemoteConnectionOptions {
1073    pub fn display_name(&self) -> String {
1074        match self {
1075            RemoteConnectionOptions::Ssh(opts) => opts.host.clone(),
1076            RemoteConnectionOptions::Wsl(opts) => opts.distro_name.clone(),
1077        }
1078    }
1079}
1080
1081impl From<SshConnectionOptions> for RemoteConnectionOptions {
1082    fn from(opts: SshConnectionOptions) -> Self {
1083        RemoteConnectionOptions::Ssh(opts)
1084    }
1085}
1086
1087impl From<WslConnectionOptions> for RemoteConnectionOptions {
1088    fn from(opts: WslConnectionOptions) -> Self {
1089        RemoteConnectionOptions::Wsl(opts)
1090    }
1091}
1092
#[cfg(target_os = "windows")]
/// Open a wsl path (\\wsl.localhost\<distro>\path)
// NOTE(review): the `///` line above is left verbatim because the `gpui::Action` derive
// may expose it as the action's documentation at runtime. rustdoc will likely flag
// `<distro>` as an unclosed HTML tag; consider wrapping the path in backticks.
#[derive(Debug, Clone, PartialEq, Eq, gpui::Action)]
#[action(namespace = workspace, no_json, no_register)]
pub struct OpenWslPath {
    // Connection options identifying the WSL distribution.
    pub distro: WslConnectionOptions,
    // Paths to open — presumably paths inside `distro`; confirm against the handler.
    pub paths: Vec<PathBuf>,
}
1101
/// A live transport to a remote environment. The SSH and WSL implementations come from
/// `crate::transport`; the connection pool hands them out as `Arc<dyn RemoteConnection>`.
#[async_trait(?Send)]
pub trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy for `unique_identifier` and relays envelopes between it and
    /// the local channels: messages from the remote are sent on `incoming_tx`, and messages
    /// read from `outgoing_rx` are forwarded to the remote. `connection_activity_tx` is
    /// presumably signalled on incoming traffic (the fake implementation does so for every
    /// message from its server). `reconnect` presumably marks a reconnection attempt.
    ///
    /// NOTE(review): the meaning of the `i32` result (presumably the proxy's exit status)
    /// is defined by the implementations — confirm there.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote side.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Kills the connection (see [`RemoteConnection::has_been_killed`]).
    async fn kill(&self) -> Result<()>;
    /// Whether this connection has been killed. The connection pool will not hand out a
    /// pooled connection for which this returns `true`.
    fn has_been_killed(&self) -> bool;
    /// Whether the remote shares this machine's network interface (reading based on the
    /// name — confirm in implementations). Defaults to `false`.
    fn shares_network_interface(&self) -> bool {
        false
    }
    /// Builds a [`CommandTemplate`] that runs `program` with `args` and `env` on the
    /// remote, optionally in `working_dir` and with a `(local_port, host, remote_port)`
    /// port forward.
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Builds a [`CommandTemplate`] that forwards every `(local_port, host, remote_port)`
    /// entry in `forwards`.
    fn build_forward_ports_command(
        &self,
        forwards: Vec<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// The options this connection was established with.
    fn connection_options(&self) -> RemoteConnectionOptions;
    /// Path convention used by the remote side.
    fn path_style(&self) -> PathStyle;
    /// The shell to use on the remote side.
    fn shell(&self) -> String;
    /// The remote system's default shell.
    fn default_system_shell(&self) -> String;

    /// Test-only hook that simulates a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1145
/// Pending requests keyed by the outgoing request's message id. Each sender delivers the
/// response envelope plus a `()` sender. The requester drops that `()` sender once it has
/// taken the response, which lets the receive loop wait for the requester to catch up.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1147
/// A value that can be set once and awaited by any number of waiters.
struct Signal<T> {
    // Taken by the first `set`; later calls find `None` and do nothing.
    tx: Mutex<Option<oneshot::Sender<T>>>,
    // Shared task resolving to `Some(value)`, or `None` if `tx` is dropped unset.
    rx: Shared<Task<Option<T>>>,
}
1152
1153impl<T: Send + Clone + 'static> Signal<T> {
1154    pub fn new(cx: &App) -> Self {
1155        let (tx, rx) = oneshot::channel();
1156
1157        let task = cx
1158            .background_executor()
1159            .spawn(async move { rx.await.ok() })
1160            .shared();
1161
1162        Self {
1163            tx: Mutex::new(Some(tx)),
1164            rx: task,
1165        }
1166    }
1167
1168    fn set(&self, value: T) {
1169        if let Some(tx) = self.tx.lock().take() {
1170            let _ = tx.send(value);
1171        }
1172    }
1173
1174    fn wait(&self) -> Shared<Task<Option<T>>> {
1175        self.rx.clone()
1176    }
1177}
1178
/// One end of an envelope channel to a peer. It sends messages, matches responses to
/// pending requests, and dispatches other incoming messages to registered handlers.
/// Buffered outgoing messages are kept until acknowledged, so they can be resent.
struct ChannelClient {
    // Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    // Current outgoing channel; swapped by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Buffered envelopes not yet acknowledged by the peer. Pruned by incoming `ack_id`s
    // and resent on `FlushBufferedMessages` / `resync`.
    buffer: Mutex<VecDeque<Envelope>>,
    response_channels: ResponseChannels,
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Despite the name, this holds the id of the most recently processed non-control
    // message (a plain `store`, not a max). It is echoed back as `ack_id` on sends.
    max_received: AtomicU32,
    // Prefix for this client's log lines.
    name: &'static str,
    // The receive-loop task; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
    // Set when the peer's `RemoteStarted` message arrives.
    remote_started: Signal<()>,
}
1190
1191impl ChannelClient {
1192    fn new(
1193        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1194        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1195        cx: &App,
1196        name: &'static str,
1197    ) -> Arc<Self> {
1198        Arc::new_cyclic(|this| Self {
1199            outgoing_tx: Mutex::new(outgoing_tx),
1200            next_message_id: AtomicU32::new(0),
1201            max_received: AtomicU32::new(0),
1202            response_channels: ResponseChannels::default(),
1203            message_handlers: Default::default(),
1204            buffer: Mutex::new(VecDeque::new()),
1205            name,
1206            task: Mutex::new(Self::start_handling_messages(
1207                this.clone(),
1208                incoming_rx,
1209                &cx.to_async(),
1210            )),
1211            remote_started: Signal::new(cx),
1212        })
1213    }
1214
1215    fn wait_for_remote_started(&self) -> Shared<Task<Option<()>>> {
1216        self.remote_started.wait()
1217    }
1218
1219    fn start_handling_messages(
1220        this: Weak<Self>,
1221        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1222        cx: &AsyncApp,
1223    ) -> Task<Result<()>> {
1224        cx.spawn(async move |cx| {
1225            if let Some(this) = this.upgrade() {
1226                let envelope = proto::RemoteStarted {}.into_envelope(0, None, None);
1227                this.outgoing_tx.lock().unbounded_send(envelope).ok();
1228            };
1229
1230            let peer_id = PeerId { owner_id: 0, id: 0 };
1231            while let Some(incoming) = incoming_rx.next().await {
1232                let Some(this) = this.upgrade() else {
1233                    return anyhow::Ok(());
1234                };
1235                if let Some(ack_id) = incoming.ack_id {
1236                    let mut buffer = this.buffer.lock();
1237                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1238                        buffer.pop_front();
1239                    }
1240                }
1241                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
1242                {
1243                    log::debug!(
1244                        "{}:ssh message received. name:FlushBufferedMessages",
1245                        this.name
1246                    );
1247                    {
1248                        let buffer = this.buffer.lock();
1249                        for envelope in buffer.iter() {
1250                            this.outgoing_tx
1251                                .lock()
1252                                .unbounded_send(envelope.clone())
1253                                .ok();
1254                        }
1255                    }
1256                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1257                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1258                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
1259                    continue;
1260                }
1261
1262                if let Some(proto::envelope::Payload::RemoteStarted(_)) = &incoming.payload {
1263                    this.remote_started.set(());
1264                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1265                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1266                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
1267                    continue;
1268                }
1269
1270                this.max_received.store(incoming.id, SeqCst);
1271
1272                if let Some(request_id) = incoming.responding_to {
1273                    let request_id = MessageId(request_id);
1274                    let sender = this.response_channels.lock().remove(&request_id);
1275                    if let Some(sender) = sender {
1276                        let (tx, rx) = oneshot::channel();
1277                        if incoming.payload.is_some() {
1278                            sender.send((incoming, tx)).ok();
1279                        }
1280                        rx.await.ok();
1281                    }
1282                } else if let Some(envelope) =
1283                    build_typed_envelope(peer_id, Instant::now(), incoming)
1284                {
1285                    let type_name = envelope.payload_type_name();
1286                    let message_id = envelope.message_id();
1287                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
1288                        &this.message_handlers,
1289                        envelope,
1290                        this.clone().into(),
1291                        cx.clone(),
1292                    ) {
1293                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
1294                        cx.foreground_executor()
1295                            .spawn(async move {
1296                                match future.await {
1297                                    Ok(_) => {
1298                                        log::debug!(
1299                                            "{}:ssh message handled. name:{type_name}",
1300                                            this.name
1301                                        );
1302                                    }
1303                                    Err(error) => {
1304                                        log::error!(
1305                                            "{}:error handling message. type:{}, error:{:#}",
1306                                            this.name,
1307                                            type_name,
1308                                            format!("{error:#}").lines().fold(
1309                                                String::new(),
1310                                                |mut message, line| {
1311                                                    if !message.is_empty() {
1312                                                        message.push(' ');
1313                                                    }
1314                                                    message.push_str(line);
1315                                                    message
1316                                                }
1317                                            )
1318                                        );
1319                                    }
1320                                }
1321                            })
1322                            .detach()
1323                    } else {
1324                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
1325                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
1326                            message_id,
1327                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
1328                        ) {
1329                            log::error!(
1330                                "{}:error sending error response for {type_name}:{e:#}",
1331                                this.name
1332                            );
1333                        }
1334                    }
1335                }
1336            }
1337            anyhow::Ok(())
1338        })
1339    }
1340
1341    fn reconnect(
1342        self: &Arc<Self>,
1343        incoming_rx: UnboundedReceiver<Envelope>,
1344        outgoing_tx: UnboundedSender<Envelope>,
1345        cx: &AsyncApp,
1346    ) {
1347        *self.outgoing_tx.lock() = outgoing_tx;
1348        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1349    }
1350
1351    fn request<T: RequestMessage>(
1352        &self,
1353        payload: T,
1354    ) -> impl 'static + Future<Output = Result<T::Response>> {
1355        self.request_internal(payload, true)
1356    }
1357
1358    fn request_internal<T: RequestMessage>(
1359        &self,
1360        payload: T,
1361        use_buffer: bool,
1362    ) -> impl 'static + Future<Output = Result<T::Response>> {
1363        log::debug!("ssh request start. name:{}", T::NAME);
1364        let response =
1365            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
1366        async move {
1367            let response = response.await?;
1368            log::debug!("ssh request finish. name:{}", T::NAME);
1369            T::Response::from_envelope(response).context("received a response of the wrong type")
1370        }
1371    }
1372
1373    async fn resync(&self, timeout: Duration) -> Result<()> {
1374        smol::future::or(
1375            async {
1376                self.request_internal(proto::FlushBufferedMessages {}, false)
1377                    .await?;
1378
1379                for envelope in self.buffer.lock().iter() {
1380                    self.outgoing_tx
1381                        .lock()
1382                        .unbounded_send(envelope.clone())
1383                        .ok();
1384                }
1385                Ok(())
1386            },
1387            async {
1388                smol::Timer::after(timeout).await;
1389                anyhow::bail!("Timed out resyncing remote client")
1390            },
1391        )
1392        .await
1393    }
1394
1395    async fn ping(&self, timeout: Duration) -> Result<()> {
1396        smol::future::or(
1397            async {
1398                self.request(proto::Ping {}).await?;
1399                Ok(())
1400            },
1401            async {
1402                smol::Timer::after(timeout).await;
1403                anyhow::bail!("Timed out pinging remote client")
1404            },
1405        )
1406        .await
1407    }
1408
1409    fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1410        log::debug!("ssh send name:{}", T::NAME);
1411        self.send_dynamic(payload.into_envelope(0, None, None))
1412    }
1413
1414    fn request_dynamic(
1415        &self,
1416        mut envelope: proto::Envelope,
1417        type_name: &'static str,
1418        use_buffer: bool,
1419    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1420        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1421        let (tx, rx) = oneshot::channel();
1422        let mut response_channels_lock = self.response_channels.lock();
1423        response_channels_lock.insert(MessageId(envelope.id), tx);
1424        drop(response_channels_lock);
1425
1426        let result = if use_buffer {
1427            self.send_buffered(envelope)
1428        } else {
1429            self.send_unbuffered(envelope)
1430        };
1431        async move {
1432            if let Err(error) = &result {
1433                log::error!("failed to send message: {error}");
1434                anyhow::bail!("failed to send message: {error}");
1435            }
1436
1437            let response = rx.await.context("connection lost")?.0;
1438            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1439                return Err(RpcError::from_proto(error, type_name));
1440            }
1441            Ok(response)
1442        }
1443    }
1444
1445    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1446        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1447        self.send_buffered(envelope)
1448    }
1449
1450    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1451        envelope.ack_id = Some(self.max_received.load(SeqCst));
1452        self.buffer.lock().push_back(envelope.clone());
1453        // ignore errors on send (happen while we're reconnecting)
1454        // assume that the global "disconnected" overlay is sufficient.
1455        self.outgoing_tx.lock().unbounded_send(envelope).ok();
1456        Ok(())
1457    }
1458
1459    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1460        envelope.ack_id = Some(self.max_received.load(SeqCst));
1461        self.outgoing_tx.lock().unbounded_send(envelope).ok();
1462        Ok(())
1463    }
1464}
1465
1466impl ProtoClient for ChannelClient {
1467    fn request(
1468        &self,
1469        envelope: proto::Envelope,
1470        request_type: &'static str,
1471    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1472        self.request_dynamic(envelope, request_type, true).boxed()
1473    }
1474
1475    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1476        self.send_dynamic(envelope)
1477    }
1478
1479    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1480        self.send_dynamic(envelope)
1481    }
1482
1483    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1484        &self.message_handlers
1485    }
1486
1487    fn is_via_collab(&self) -> bool {
1488        false
1489    }
1490}
1491
/// Test doubles for driving a remote client against an in-process fake server.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
    use crate::remote_client::{CommandTemplate, RemoteConnectionOptions};
    use anyhow::Result;
    use askpass::EncryptedPassword;
    use async_trait::async_trait;
    use collections::HashMap;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use std::{path::PathBuf, sync::Arc};
    use util::paths::{PathStyle, RemotePathBuf};

    /// A [`RemoteConnection`] whose "remote" side is a [`ChannelClient`] running in a
    /// separate test app context (`server_cx`) rather than a real process.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: RemoteConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wraps the server's `AsyncApp` so it can be stored inside a `Send + Sync`
    /// connection.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        // Never killed, so the pool always reuses a live fake connection.
        fn has_been_killed(&self) -> bool {
            false
        }

        // Produces `ssh <program or "sh"> <args...>`; the working directory and port
        // forward are ignored.
        fn build_command(
            &self,
            program: Option<String>,
            args: &[String],
            env: &HashMap<String, String>,
            _: Option<String>,
            _: Option<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let ssh_program = program.unwrap_or_else(|| "sh".to_string());
            let mut ssh_args = Vec::new();
            ssh_args.push(ssh_program);
            ssh_args.extend(args.iter().cloned());
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: ssh_args,
                env: env.clone(),
            })
        }

        // Produces `ssh -N local:host:remote ...`, one argument per forward.
        fn build_forward_ports_command(
            &self,
            forwards: Vec<(u16, String, u16)>,
        ) -> anyhow::Result<CommandTemplate> {
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: std::iter::once("-N".to_owned())
                    .chain(forwards.into_iter().map(|(local_port, host, remote_port)| {
                        format!("{local_port}:{host}:{remote_port}")
                    }))
                    .collect(),
                env: Default::default(),
            })
        }

        // Uploads are not supported by the fake; calling this panics.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> RemoteConnectionOptions {
            self.connection_options.clone()
        }

        // Reconnects the server channel to channels whose other halves are dropped
        // immediately, cutting it off from the client.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        // Points the fake server at fresh in-memory channels, then relays envelopes
        // between them and the client's channels.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn RemoteClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            // `select_biased!` polls server-to-client first. The relay ends with `Ok(1)` as
            // soon as either direction's stream closes.
            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            // Every message from the server counts as connection activity.
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::local()
        }

        fn shell(&self) -> String {
            "sh".to_owned()
        }

        fn default_system_shell(&self) -> String {
            "sh".to_owned()
        }
    }

    /// A delegate for tests that never prompts or downloads. Those methods are
    /// `unreachable!()`, and status updates are ignored.
    pub(super) struct Delegate;

    impl RemoteClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<EncryptedPassword>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: RemotePlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}