ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  21    WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  27};
  28use smol::{
  29    fs,
  30    process::{self, Child, Stdio},
  31};
  32use std::{
  33    any::TypeId,
  34    ffi::OsStr,
  35    fmt,
  36    ops::ControlFlow,
  37    path::{Path, PathBuf},
  38    sync::{
  39        atomic::{AtomicU32, Ordering::SeqCst},
  40        Arc,
  41    },
  42    time::{Duration, Instant},
  43};
  44use tempfile::TempDir;
  45use util::ResultExt;
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  49)]
  50pub struct SshProjectId(pub u64);
  51
  52#[derive(Clone)]
  53pub struct SshSocket {
  54    connection_options: SshConnectionOptions,
  55    socket_path: PathBuf,
  56}
  57
  58#[derive(Debug, Default, Clone, PartialEq, Eq)]
  59pub struct SshConnectionOptions {
  60    pub host: String,
  61    pub username: Option<String>,
  62    pub port: Option<u16>,
  63    pub password: Option<String>,
  64}
  65
  66impl SshConnectionOptions {
  67    pub fn ssh_url(&self) -> String {
  68        let mut result = String::from("ssh://");
  69        if let Some(username) = &self.username {
  70            result.push_str(username);
  71            result.push('@');
  72        }
  73        result.push_str(&self.host);
  74        if let Some(port) = self.port {
  75            result.push(':');
  76            result.push_str(&port.to_string());
  77        }
  78        result
  79    }
  80
  81    fn scp_url(&self) -> String {
  82        if let Some(username) = &self.username {
  83            format!("{}@{}", username, self.host)
  84        } else {
  85            self.host.clone()
  86        }
  87    }
  88
  89    pub fn connection_string(&self) -> String {
  90        let host = if let Some(username) = &self.username {
  91            format!("{}@{}", username, self.host)
  92        } else {
  93            self.host.clone()
  94        };
  95        if let Some(port) = &self.port {
  96            format!("{}:{}", host, port)
  97        } else {
  98            host
  99        }
 100    }
 101
 102    // Uniquely identifies dev server projects on a remote host. Needs to be
 103    // stable for the same dev server project.
 104    pub fn dev_server_identifier(&self) -> String {
 105        let mut identifier = format!("dev-server-{:?}", self.host);
 106        if let Some(username) = self.username.as_ref() {
 107            identifier.push('-');
 108            identifier.push_str(&username);
 109        }
 110        identifier
 111    }
 112}
 113
 114#[derive(Copy, Clone, Debug)]
 115pub struct SshPlatform {
 116    pub os: &'static str,
 117    pub arch: &'static str,
 118}
 119
 120impl SshPlatform {
 121    pub fn triple(&self) -> Option<String> {
 122        Some(format!(
 123            "{}-{}",
 124            self.arch,
 125            match self.os {
 126                "linux" => "unknown-linux-gnu",
 127                "macos" => "apple-darwin",
 128                _ => return None,
 129            }
 130        ))
 131    }
 132}
 133
 134pub trait SshClientDelegate: Send + Sync {
 135    fn ask_password(
 136        &self,
 137        prompt: String,
 138        cx: &mut AsyncAppContext,
 139    ) -> oneshot::Receiver<Result<String>>;
 140    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 141    fn get_server_binary(
 142        &self,
 143        platform: SshPlatform,
 144        cx: &mut AsyncAppContext,
 145    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 146    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 147    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 148}
 149
 150impl SshSocket {
 151    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 152        let mut command = process::Command::new("ssh");
 153        self.ssh_options(&mut command)
 154            .arg(self.connection_options.ssh_url())
 155            .arg(program);
 156        command
 157    }
 158
 159    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 160        command
 161            .stdin(Stdio::piped())
 162            .stdout(Stdio::piped())
 163            .stderr(Stdio::piped())
 164            .args(["-o", "ControlMaster=no", "-o"])
 165            .arg(format!("ControlPath={}", self.socket_path.display()))
 166    }
 167
 168    fn ssh_args(&self) -> Vec<String> {
 169        vec![
 170            "-o".to_string(),
 171            "ControlMaster=no".to_string(),
 172            "-o".to_string(),
 173            format!("ControlPath={}", self.socket_path.display()),
 174            self.connection_options.ssh_url(),
 175        ]
 176    }
 177}
 178
 179async fn run_cmd(command: &mut process::Command) -> Result<String> {
 180    let output = command.output().await?;
 181    if output.status.success() {
 182        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 183    } else {
 184        Err(anyhow!(
 185            "failed to run command: {}",
 186            String::from_utf8_lossy(&output.stderr)
 187        ))
 188    }
 189}
 190
 191struct ChannelForwarder {
 192    quit_tx: UnboundedSender<()>,
 193    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 194}
 195
 196impl ChannelForwarder {
 197    fn new(
 198        mut incoming_tx: UnboundedSender<Envelope>,
 199        mut outgoing_rx: UnboundedReceiver<Envelope>,
 200        cx: &AsyncAppContext,
 201    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 202        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 203
 204        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 205        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 206
 207        let forwarding_task = cx.background_executor().spawn(async move {
 208            loop {
 209                select_biased! {
 210                    _ = quit_rx.next().fuse() => {
 211                        break;
 212                    },
 213                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 214                        if let Some(envelope) = incoming_envelope {
 215                            if incoming_tx.send(envelope).await.is_err() {
 216                                break;
 217                            }
 218                        } else {
 219                            break;
 220                        }
 221                    }
 222                    outgoing_envelope = outgoing_rx.next().fuse() => {
 223                        if let Some(envelope) = outgoing_envelope {
 224                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 225                                break;
 226                            }
 227                        } else {
 228                            break;
 229                        }
 230                    }
 231                }
 232            }
 233
 234            (incoming_tx, outgoing_rx)
 235        });
 236
 237        (
 238            Self {
 239                forwarding_task,
 240                quit_tx,
 241            },
 242            proxy_incoming_tx,
 243            proxy_outgoing_rx,
 244        )
 245    }
 246
 247    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 248        let _ = self.quit_tx.send(()).await;
 249        self.forwarding_task.await
 250    }
 251}
 252
 253const MAX_MISSED_HEARTBEATS: usize = 5;
 254const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 255const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 256
 257const MAX_RECONNECT_ATTEMPTS: usize = 3;
 258
 259enum State {
 260    Connecting,
 261    Connected {
 262        ssh_connection: SshRemoteConnection,
 263        delegate: Arc<dyn SshClientDelegate>,
 264        forwarder: ChannelForwarder,
 265
 266        multiplex_task: Task<Result<()>>,
 267        heartbeat_task: Task<Result<()>>,
 268    },
 269    HeartbeatMissed {
 270        missed_heartbeats: usize,
 271
 272        ssh_connection: SshRemoteConnection,
 273        delegate: Arc<dyn SshClientDelegate>,
 274        forwarder: ChannelForwarder,
 275
 276        multiplex_task: Task<Result<()>>,
 277        heartbeat_task: Task<Result<()>>,
 278    },
 279    Reconnecting,
 280    ReconnectFailed {
 281        ssh_connection: SshRemoteConnection,
 282        delegate: Arc<dyn SshClientDelegate>,
 283        forwarder: ChannelForwarder,
 284
 285        error: anyhow::Error,
 286        attempts: usize,
 287    },
 288    ReconnectExhausted,
 289    ServerNotRunning,
 290}
 291
 292impl fmt::Display for State {
 293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 294        match self {
 295            Self::Connecting => write!(f, "connecting"),
 296            Self::Connected { .. } => write!(f, "connected"),
 297            Self::Reconnecting => write!(f, "reconnecting"),
 298            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 299            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 300            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 301            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 302        }
 303    }
 304}
 305
 306impl State {
 307    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 308        match self {
 309            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 310            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 311            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 312            _ => None,
 313        }
 314    }
 315
 316    fn can_reconnect(&self) -> bool {
 317        match self {
 318            Self::Connected { .. }
 319            | Self::HeartbeatMissed { .. }
 320            | Self::ReconnectFailed { .. } => true,
 321            State::Connecting
 322            | State::Reconnecting
 323            | State::ReconnectExhausted
 324            | State::ServerNotRunning => false,
 325        }
 326    }
 327
 328    fn is_reconnect_failed(&self) -> bool {
 329        matches!(self, Self::ReconnectFailed { .. })
 330    }
 331
 332    fn is_reconnect_exhausted(&self) -> bool {
 333        matches!(self, Self::ReconnectExhausted { .. })
 334    }
 335
 336    fn is_reconnecting(&self) -> bool {
 337        matches!(self, Self::Reconnecting { .. })
 338    }
 339
 340    fn heartbeat_recovered(self) -> Self {
 341        match self {
 342            Self::HeartbeatMissed {
 343                ssh_connection,
 344                delegate,
 345                forwarder,
 346                multiplex_task,
 347                heartbeat_task,
 348                ..
 349            } => Self::Connected {
 350                ssh_connection,
 351                delegate,
 352                forwarder,
 353                multiplex_task,
 354                heartbeat_task,
 355            },
 356            _ => self,
 357        }
 358    }
 359
 360    fn heartbeat_missed(self) -> Self {
 361        match self {
 362            Self::Connected {
 363                ssh_connection,
 364                delegate,
 365                forwarder,
 366                multiplex_task,
 367                heartbeat_task,
 368            } => Self::HeartbeatMissed {
 369                missed_heartbeats: 1,
 370                ssh_connection,
 371                delegate,
 372                forwarder,
 373                multiplex_task,
 374                heartbeat_task,
 375            },
 376            Self::HeartbeatMissed {
 377                missed_heartbeats,
 378                ssh_connection,
 379                delegate,
 380                forwarder,
 381                multiplex_task,
 382                heartbeat_task,
 383            } => Self::HeartbeatMissed {
 384                missed_heartbeats: missed_heartbeats + 1,
 385                ssh_connection,
 386                delegate,
 387                forwarder,
 388                multiplex_task,
 389                heartbeat_task,
 390            },
 391            _ => self,
 392        }
 393    }
 394}
 395
 396/// The state of the ssh connection.
 397#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 398pub enum ConnectionState {
 399    Connecting,
 400    Connected,
 401    HeartbeatMissed,
 402    Reconnecting,
 403    Disconnected,
 404}
 405
 406impl From<&State> for ConnectionState {
 407    fn from(value: &State) -> Self {
 408        match value {
 409            State::Connecting => Self::Connecting,
 410            State::Connected { .. } => Self::Connected,
 411            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 412            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 413            State::ReconnectExhausted => Self::Disconnected,
 414            State::ServerNotRunning => Self::Disconnected,
 415        }
 416    }
 417}
 418
 419pub struct SshRemoteClient {
 420    client: Arc<ChannelClient>,
 421    unique_identifier: String,
 422    connection_options: SshConnectionOptions,
 423    state: Arc<Mutex<Option<State>>>,
 424}
 425
 426#[derive(Debug)]
 427pub enum SshRemoteEvent {
 428    Disconnected,
 429}
 430
 431impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 432
 433impl SshRemoteClient {
 434    pub fn new(
 435        unique_identifier: String,
 436        connection_options: SshConnectionOptions,
 437        delegate: Arc<dyn SshClientDelegate>,
 438        cx: &AppContext,
 439    ) -> Task<Result<Model<Self>>> {
 440        cx.spawn(|mut cx| async move {
 441            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 442            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 443            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 444
 445            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
 446            let this = cx.new_model(|_| Self {
 447                client: client.clone(),
 448                unique_identifier: unique_identifier.clone(),
 449                connection_options: connection_options.clone(),
 450                state: Arc::new(Mutex::new(Some(State::Connecting))),
 451            })?;
 452
 453            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 454                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 455
 456            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 457                unique_identifier,
 458                false,
 459                connection_options,
 460                delegate.clone(),
 461                &mut cx,
 462            )
 463            .await?;
 464
 465            let multiplex_task = Self::multiplex(
 466                this.downgrade(),
 467                ssh_proxy_process,
 468                proxy_incoming_tx,
 469                proxy_outgoing_rx,
 470                connection_activity_tx,
 471                &mut cx,
 472            );
 473
 474            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 475                log::error!("failed to establish connection: {}", error);
 476                delegate.set_error(error.to_string(), &mut cx);
 477                return Err(error);
 478            }
 479
 480            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 481
 482            this.update(&mut cx, |this, _| {
 483                *this.state.lock() = Some(State::Connected {
 484                    ssh_connection,
 485                    delegate,
 486                    forwarder: proxy,
 487                    multiplex_task,
 488                    heartbeat_task,
 489                });
 490            })?;
 491
 492            Ok(this)
 493        })
 494    }
 495
 496    pub fn shutdown_processes<T: RequestMessage>(
 497        &self,
 498        shutdown_request: Option<T>,
 499    ) -> Option<impl Future<Output = ()>> {
 500        let state = self.state.lock().take()?;
 501        log::info!("shutting down ssh processes");
 502
 503        let State::Connected {
 504            multiplex_task,
 505            heartbeat_task,
 506            ssh_connection,
 507            delegate,
 508            forwarder,
 509        } = state
 510        else {
 511            return None;
 512        };
 513
 514        let client = self.client.clone();
 515
 516        Some(async move {
 517            if let Some(shutdown_request) = shutdown_request {
 518                client.send(shutdown_request).log_err();
 519                // We wait 50ms instead of waiting for a response, because
 520                // waiting for a response would require us to wait on the main thread
 521                // which we want to avoid in an `on_app_quit` callback.
 522                smol::Timer::after(Duration::from_millis(50)).await;
 523            }
 524
 525            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 526            // child of master_process.
 527            drop(multiplex_task);
 528            // Now drop the rest of state, which kills master process.
 529            drop(heartbeat_task);
 530            drop(ssh_connection);
 531            drop(delegate);
 532            drop(forwarder);
 533        })
 534    }
 535
 536    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 537        let mut lock = self.state.lock();
 538
 539        let can_reconnect = lock
 540            .as_ref()
 541            .map(|state| state.can_reconnect())
 542            .unwrap_or(false);
 543        if !can_reconnect {
 544            let error = if let Some(state) = lock.as_ref() {
 545                format!("invalid state, cannot reconnect while in state {state}")
 546            } else {
 547                "no state set".to_string()
 548            };
 549            log::info!("aborting reconnect, because not in state that allows reconnecting");
 550            return Err(anyhow!(error));
 551        }
 552
 553        let state = lock.take().unwrap();
 554        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 555            State::Connected {
 556                ssh_connection,
 557                delegate,
 558                forwarder,
 559                multiplex_task,
 560                heartbeat_task,
 561            }
 562            | State::HeartbeatMissed {
 563                ssh_connection,
 564                delegate,
 565                forwarder,
 566                multiplex_task,
 567                heartbeat_task,
 568                ..
 569            } => {
 570                drop(multiplex_task);
 571                drop(heartbeat_task);
 572                (0, ssh_connection, delegate, forwarder)
 573            }
 574            State::ReconnectFailed {
 575                attempts,
 576                ssh_connection,
 577                delegate,
 578                forwarder,
 579                ..
 580            } => (attempts, ssh_connection, delegate, forwarder),
 581            State::Connecting
 582            | State::Reconnecting
 583            | State::ReconnectExhausted
 584            | State::ServerNotRunning => unreachable!(),
 585        };
 586
 587        let attempts = attempts + 1;
 588        if attempts > MAX_RECONNECT_ATTEMPTS {
 589            log::error!(
 590                "Failed to reconnect to after {} attempts, giving up",
 591                MAX_RECONNECT_ATTEMPTS
 592            );
 593            drop(lock);
 594            self.set_state(State::ReconnectExhausted, cx);
 595            return Ok(());
 596        }
 597        drop(lock);
 598
 599        self.set_state(State::Reconnecting, cx);
 600
 601        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 602
 603        let identifier = self.unique_identifier.clone();
 604        let client = self.client.clone();
 605        let reconnect_task = cx.spawn(|this, mut cx| async move {
 606            macro_rules! failed {
 607                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 608                    return State::ReconnectFailed {
 609                        error: anyhow!($error),
 610                        attempts: $attempts,
 611                        ssh_connection: $ssh_connection,
 612                        delegate: $delegate,
 613                        forwarder: $forwarder,
 614                    };
 615                };
 616            }
 617
 618            if let Err(error) = ssh_connection.master_process.kill() {
 619                failed!(error, attempts, ssh_connection, delegate, forwarder);
 620            };
 621
 622            if let Err(error) = ssh_connection
 623                .master_process
 624                .status()
 625                .await
 626                .context("Failed to kill ssh process")
 627            {
 628                failed!(error, attempts, ssh_connection, delegate, forwarder);
 629            }
 630
 631            let connection_options = ssh_connection.socket.connection_options.clone();
 632
 633            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 634            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 635                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 636            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 637
 638            let (ssh_connection, ssh_process) = match Self::establish_connection(
 639                identifier,
 640                true,
 641                connection_options,
 642                delegate.clone(),
 643                &mut cx,
 644            )
 645            .await
 646            {
 647                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 648                Err(error) => {
 649                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 650                }
 651            };
 652
 653            let multiplex_task = Self::multiplex(
 654                this.clone(),
 655                ssh_process,
 656                proxy_incoming_tx,
 657                proxy_outgoing_rx,
 658                connection_activity_tx,
 659                &mut cx,
 660            );
 661
 662            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 663                failed!(error, attempts, ssh_connection, delegate, forwarder);
 664            };
 665
 666            State::Connected {
 667                ssh_connection,
 668                delegate,
 669                forwarder,
 670                multiplex_task,
 671                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 672            }
 673        });
 674
 675        cx.spawn(|this, mut cx| async move {
 676            let new_state = reconnect_task.await;
 677            this.update(&mut cx, |this, cx| {
 678                this.try_set_state(cx, |old_state| {
 679                    if old_state.is_reconnecting() {
 680                        match &new_state {
 681                            State::Connecting
 682                            | State::Reconnecting { .. }
 683                            | State::HeartbeatMissed { .. }
 684                            | State::ServerNotRunning => {}
 685                            State::Connected { .. } => {
 686                                log::info!("Successfully reconnected");
 687                            }
 688                            State::ReconnectFailed {
 689                                error, attempts, ..
 690                            } => {
 691                                log::error!(
 692                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 693                                    attempts,
 694                                    error
 695                                );
 696                            }
 697                            State::ReconnectExhausted => {
 698                                log::error!("Reconnect attempt failed and all attempts exhausted");
 699                            }
 700                        }
 701                        Some(new_state)
 702                    } else {
 703                        None
 704                    }
 705                });
 706
 707                if this.state_is(State::is_reconnect_failed) {
 708                    this.reconnect(cx)
 709                } else if this.state_is(State::is_reconnect_exhausted) {
 710                    cx.emit(SshRemoteEvent::Disconnected);
 711                    Ok(())
 712                } else {
 713                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 714                    Ok(())
 715                }
 716            })
 717        })
 718        .detach_and_log_err(cx);
 719
 720        Ok(())
 721    }
 722
 723    fn heartbeat(
 724        this: WeakModel<Self>,
 725        mut connection_activity_rx: mpsc::Receiver<()>,
 726        cx: &mut AsyncAppContext,
 727    ) -> Task<Result<()>> {
 728        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 729            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 730        };
 731
 732        cx.spawn(|mut cx| {
 733            let this = this.clone();
 734            async move {
 735                let mut missed_heartbeats = 0;
 736
 737                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 738                futures::pin_mut!(keepalive_timer);
 739
 740                loop {
 741                    select_biased! {
 742                        _ = connection_activity_rx.next().fuse() => {
 743                            keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 744                        }
 745                        _ = keepalive_timer => {
 746                            log::debug!("Sending heartbeat to server...");
 747
 748                            let result = select_biased! {
 749                                _ = connection_activity_rx.next().fuse() => {
 750                                    Ok(())
 751                                }
 752                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 753                                    ping_result
 754                                }
 755                            };
 756                            if result.is_err() {
 757                                missed_heartbeats += 1;
 758                                log::warn!(
 759                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 760                                    HEARTBEAT_TIMEOUT,
 761                                    missed_heartbeats,
 762                                    MAX_MISSED_HEARTBEATS
 763                                );
 764                            } else if missed_heartbeats != 0 {
 765                                missed_heartbeats = 0;
 766                            } else {
 767                                continue;
 768                            }
 769
 770                            let result = this.update(&mut cx, |this, mut cx| {
 771                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 772                            })?;
 773                            if result.is_break() {
 774                                return Ok(());
 775                            }
 776                        }
 777                    }
 778                }
 779            }
 780        })
 781    }
 782
 783    fn handle_heartbeat_result(
 784        &mut self,
 785        missed_heartbeats: usize,
 786        cx: &mut ModelContext<Self>,
 787    ) -> ControlFlow<()> {
 788        let state = self.state.lock().take().unwrap();
 789        let next_state = if missed_heartbeats > 0 {
 790            state.heartbeat_missed()
 791        } else {
 792            state.heartbeat_recovered()
 793        };
 794
 795        self.set_state(next_state, cx);
 796
 797        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 798            log::error!(
 799                "Missed last {} heartbeats. Reconnecting...",
 800                missed_heartbeats
 801            );
 802
 803            self.reconnect(cx)
 804                .context("failed to start reconnect process after missing heartbeats")
 805                .log_err();
 806            ControlFlow::Break(())
 807        } else {
 808            ControlFlow::Continue(())
 809        }
 810    }
 811
 812    fn multiplex(
 813        this: WeakModel<Self>,
 814        mut ssh_proxy_process: Child,
 815        incoming_tx: UnboundedSender<Envelope>,
 816        mut outgoing_rx: UnboundedReceiver<Envelope>,
 817        mut connection_activity_tx: Sender<()>,
 818        cx: &AsyncAppContext,
 819    ) -> Task<Result<()>> {
 820        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 821        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 822        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 823
 824        let io_task = cx.background_executor().spawn(async move {
 825            let mut stdin_buffer = Vec::new();
 826            let mut stdout_buffer = Vec::new();
 827            let mut stderr_buffer = Vec::new();
 828            let mut stderr_offset = 0;
 829
 830            loop {
 831                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 832                stderr_buffer.resize(stderr_offset + 1024, 0);
 833
 834                select_biased! {
 835                    outgoing = outgoing_rx.next().fuse() => {
 836                        let Some(outgoing) = outgoing else {
 837                            return anyhow::Ok(None);
 838                        };
 839
 840                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 841                    }
 842
 843                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 844                        match result {
 845                            Ok(0) => {
 846                                child_stdin.close().await?;
 847                                outgoing_rx.close();
 848                                let status = ssh_proxy_process.status().await?;
 849                                // If we don't have a code, we assume process
 850                                // has been killed and treat it as non-zero exit
 851                                // code
 852                                return Ok(status.code().or_else(|| Some(1)));
 853                            }
 854                            Ok(len) => {
 855                                if len < stdout_buffer.len() {
 856                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 857                                }
 858
 859                                let message_len = message_len_from_buffer(&stdout_buffer);
 860                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 861                                    Ok(envelope) => {
 862                                        connection_activity_tx.try_send(()).ok();
 863                                        incoming_tx.unbounded_send(envelope).ok();
 864                                    }
 865                                    Err(error) => {
 866                                        log::error!("error decoding message {error:?}");
 867                                    }
 868                                }
 869                            }
 870                            Err(error) => {
 871                                Err(anyhow!("error reading stdout: {error:?}"))?;
 872                            }
 873                        }
 874                    }
 875
 876                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 877                        match result {
 878                            Ok(len) => {
 879                                stderr_offset += len;
 880                                let mut start_ix = 0;
 881                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 882                                    let line_ix = start_ix + ix;
 883                                    let content = &stderr_buffer[start_ix..line_ix];
 884                                    start_ix = line_ix + 1;
 885                                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 886                                        record.log(log::logger())
 887                                    } else {
 888                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 889                                    }
 890                                }
 891                                stderr_buffer.drain(0..start_ix);
 892                                stderr_offset -= start_ix;
 893
 894                                connection_activity_tx.try_send(()).ok();
 895                            }
 896                            Err(error) => {
 897                                Err(anyhow!("error reading stderr: {error:?}"))?;
 898                            }
 899                        }
 900                    }
 901                }
 902            }
 903        });
 904
 905        cx.spawn(|mut cx| async move {
 906            let result = io_task.await;
 907
 908            match result {
 909                Ok(Some(exit_code)) => {
 910                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 911                        match error {
 912                            ProxyLaunchError::ServerNotRunning => {
 913                                log::error!("failed to reconnect because server is not running");
 914                                this.update(&mut cx, |this, cx| {
 915                                    this.set_state(State::ServerNotRunning, cx);
 916                                    cx.emit(SshRemoteEvent::Disconnected);
 917                                })?;
 918                            }
 919                        }
 920                    } else if exit_code > 0 {
 921                        log::error!("proxy process terminated unexpectedly");
 922                        this.update(&mut cx, |this, cx| {
 923                            this.reconnect(cx).ok();
 924                        })?;
 925                    }
 926                }
 927                Ok(None) => {}
 928                Err(error) => {
 929                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 930                    this.update(&mut cx, |this, cx| {
 931                        this.reconnect(cx).ok();
 932                    })?;
 933                }
 934            }
 935            Ok(())
 936        })
 937    }
 938
 939    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 940        self.state.lock().as_ref().map_or(false, check)
 941    }
 942
 943    fn try_set_state(
 944        &self,
 945        cx: &mut ModelContext<Self>,
 946        map: impl FnOnce(&State) -> Option<State>,
 947    ) {
 948        let mut lock = self.state.lock();
 949        let new_state = lock.as_ref().and_then(map);
 950
 951        if let Some(new_state) = new_state {
 952            lock.replace(new_state);
 953            cx.notify();
 954        }
 955    }
 956
 957    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 958        log::info!("setting state to '{}'", &state);
 959        self.state.lock().replace(state);
 960        cx.notify();
 961    }
 962
 963    async fn establish_connection(
 964        unique_identifier: String,
 965        reconnect: bool,
 966        connection_options: SshConnectionOptions,
 967        delegate: Arc<dyn SshClientDelegate>,
 968        cx: &mut AsyncAppContext,
 969    ) -> Result<(SshRemoteConnection, Child)> {
 970        let ssh_connection =
 971            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 972
 973        let platform = ssh_connection.query_platform().await?;
 974        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 975        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 976        ssh_connection
 977            .ensure_server_binary(
 978                &delegate,
 979                &local_binary_path,
 980                &remote_binary_path,
 981                version,
 982                cx,
 983            )
 984            .await?;
 985
 986        let socket = ssh_connection.socket.clone();
 987        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 988
 989        delegate.set_status(Some("Starting proxy"), cx);
 990
 991        let mut start_proxy_command = format!(
 992            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 993            std::env::var("RUST_LOG").unwrap_or_default(),
 994            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 995            remote_binary_path,
 996            unique_identifier,
 997        );
 998        if reconnect {
 999            start_proxy_command.push_str(" --reconnect");
1000        }
1001
1002        let ssh_proxy_process = socket
1003            .ssh_command(start_proxy_command)
1004            // IMPORTANT: we kill this process when we drop the task that uses it.
1005            .kill_on_drop(true)
1006            .spawn()
1007            .context("failed to spawn remote server")?;
1008
1009        Ok((ssh_connection, ssh_proxy_process))
1010    }
1011
1012    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1013        self.client.subscribe_to_entity(remote_id, entity);
1014    }
1015
1016    pub fn ssh_args(&self) -> Option<Vec<String>> {
1017        self.state
1018            .lock()
1019            .as_ref()
1020            .and_then(|state| state.ssh_connection())
1021            .map(|ssh_connection| ssh_connection.socket.ssh_args())
1022    }
1023
1024    pub fn to_proto_client(&self) -> AnyProtoClient {
1025        self.client.clone().into()
1026    }
1027
1028    pub fn connection_string(&self) -> String {
1029        self.connection_options.connection_string()
1030    }
1031
1032    pub fn connection_options(&self) -> SshConnectionOptions {
1033        self.connection_options.clone()
1034    }
1035
1036    #[cfg(not(any(test, feature = "test-support")))]
1037    pub fn connection_state(&self) -> ConnectionState {
1038        self.state
1039            .lock()
1040            .as_ref()
1041            .map(ConnectionState::from)
1042            .unwrap_or(ConnectionState::Disconnected)
1043    }
1044
1045    #[cfg(any(test, feature = "test-support"))]
1046    pub fn connection_state(&self) -> ConnectionState {
1047        ConnectionState::Connected
1048    }
1049
1050    pub fn is_disconnected(&self) -> bool {
1051        self.connection_state() == ConnectionState::Disconnected
1052    }
1053
1054    #[cfg(any(test, feature = "test-support"))]
1055    pub fn fake(
1056        client_cx: &mut gpui::TestAppContext,
1057        server_cx: &mut gpui::TestAppContext,
1058    ) -> (Model<Self>, Arc<ChannelClient>) {
1059        use gpui::Context;
1060
1061        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1062        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1063
1064        (
1065            client_cx.update(|cx| {
1066                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1067                cx.new_model(|_| Self {
1068                    client,
1069                    unique_identifier: "fake".to_string(),
1070                    connection_options: SshConnectionOptions::default(),
1071                    state: Arc::new(Mutex::new(None)),
1072                })
1073            }),
1074            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1075        )
1076    }
1077}
1078
1079impl From<SshRemoteClient> for AnyProtoClient {
1080    fn from(client: SshRemoteClient) -> Self {
1081        AnyProtoClient::new(client.client.clone())
1082    }
1083}
1084
/// An authenticated SSH connection kept open by a control-master `ssh` process.
///
/// Other `ssh`/`scp` invocations reuse the connection through `socket`.
struct SshRemoteConnection {
    // Control-socket path plus connection options, used to build further commands.
    socket: SshSocket,
    // The `ssh -N -o ControlMaster=yes` process; killed in `Drop`.
    master_process: process::Child,
    // Directory holding the control socket and the askpass socket/script. It is
    // removed when dropped. Fields drop in declaration order, so this goes after
    // the process handle.
    _temp_dir: TempDir,
}
1090
1091impl Drop for SshRemoteConnection {
1092    fn drop(&mut self) {
1093        if let Err(error) = self.master_process.kill() {
1094            log::error!("failed to kill SSH master process: {}", error);
1095        }
1096    }
1097}
1098
1099impl SshRemoteConnection {
1100    #[cfg(not(unix))]
1101    async fn new(
1102        _connection_options: SshConnectionOptions,
1103        _delegate: Arc<dyn SshClientDelegate>,
1104        _cx: &mut AsyncAppContext,
1105    ) -> Result<Self> {
1106        Err(anyhow!("ssh is not supported on this platform"))
1107    }
1108
1109    #[cfg(unix)]
1110    async fn new(
1111        connection_options: SshConnectionOptions,
1112        delegate: Arc<dyn SshClientDelegate>,
1113        cx: &mut AsyncAppContext,
1114    ) -> Result<Self> {
1115        use futures::{io::BufReader, AsyncBufReadExt as _};
1116        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1117        use util::ResultExt as _;
1118
1119        delegate.set_status(Some("connecting"), cx);
1120
1121        let url = connection_options.ssh_url();
1122        let temp_dir = tempfile::Builder::new()
1123            .prefix("zed-ssh-session")
1124            .tempdir()?;
1125
1126        // Create a domain socket listener to handle requests from the askpass program.
1127        let askpass_socket = temp_dir.path().join("askpass.sock");
1128        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1129        let listener =
1130            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1131
1132        let askpass_task = cx.spawn({
1133            let delegate = delegate.clone();
1134            |mut cx| async move {
1135                let mut askpass_opened_tx = Some(askpass_opened_tx);
1136
1137                while let Ok((mut stream, _)) = listener.accept().await {
1138                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1139                        askpass_opened_tx.send(()).ok();
1140                    }
1141                    let mut buffer = Vec::new();
1142                    let mut reader = BufReader::new(&mut stream);
1143                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1144                        buffer.clear();
1145                    }
1146                    let password_prompt = String::from_utf8_lossy(&buffer);
1147                    if let Some(password) = delegate
1148                        .ask_password(password_prompt.to_string(), &mut cx)
1149                        .await
1150                        .context("failed to get ssh password")
1151                        .and_then(|p| p)
1152                        .log_err()
1153                    {
1154                        stream.write_all(password.as_bytes()).await.log_err();
1155                    }
1156                }
1157            }
1158        });
1159
1160        // Create an askpass script that communicates back to this process.
1161        let askpass_script = format!(
1162            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1163            askpass_socket = askpass_socket.display(),
1164            print_args = "printf '%s\\0' \"$@\"",
1165            shebang = "#!/bin/sh",
1166        );
1167        let askpass_script_path = temp_dir.path().join("askpass.sh");
1168        fs::write(&askpass_script_path, askpass_script).await?;
1169        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1170
1171        // Start the master SSH process, which does not do anything except for establish
1172        // the connection and keep it open, allowing other ssh commands to reuse it
1173        // via a control socket.
1174        let socket_path = temp_dir.path().join("ssh.sock");
1175        let mut master_process = process::Command::new("ssh")
1176            .stdin(Stdio::null())
1177            .stdout(Stdio::piped())
1178            .stderr(Stdio::piped())
1179            .env("SSH_ASKPASS_REQUIRE", "force")
1180            .env("SSH_ASKPASS", &askpass_script_path)
1181            .args(["-N", "-o", "ControlMaster=yes", "-o"])
1182            .arg(format!("ControlPath={}", socket_path.display()))
1183            .arg(&url)
1184            .spawn()?;
1185
1186        // Wait for this ssh process to close its stdout, indicating that authentication
1187        // has completed.
1188        let stdout = master_process.stdout.as_mut().unwrap();
1189        let mut output = Vec::new();
1190        let connection_timeout = Duration::from_secs(10);
1191
1192        let result = select_biased! {
1193            _ = askpass_opened_rx.fuse() => {
1194                // If the askpass script has opened, that means the user is typing
1195                // their password, in which case we don't want to timeout anymore,
1196                // since we know a connection has been established.
1197                stdout.read_to_end(&mut output).await?;
1198                Ok(())
1199            }
1200            result = stdout.read_to_end(&mut output).fuse() => {
1201                result?;
1202                Ok(())
1203            }
1204            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1205                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1206            }
1207        };
1208
1209        if let Err(e) = result {
1210            let error_message = format!("Failed to connect to host: {}.", e);
1211            delegate.set_error(error_message, cx);
1212            return Err(e);
1213        }
1214
1215        drop(askpass_task);
1216
1217        if master_process.try_status()?.is_some() {
1218            output.clear();
1219            let mut stderr = master_process.stderr.take().unwrap();
1220            stderr.read_to_end(&mut output).await?;
1221
1222            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1223            delegate.set_error(error_message.clone(), cx);
1224            Err(anyhow!(error_message))?;
1225        }
1226
1227        Ok(Self {
1228            socket: SshSocket {
1229                connection_options,
1230                socket_path,
1231            },
1232            master_process,
1233            _temp_dir: temp_dir,
1234        })
1235    }
1236
1237    async fn ensure_server_binary(
1238        &self,
1239        delegate: &Arc<dyn SshClientDelegate>,
1240        src_path: &Path,
1241        dst_path: &Path,
1242        version: SemanticVersion,
1243        cx: &mut AsyncAppContext,
1244    ) -> Result<()> {
1245        let mut dst_path_gz = dst_path.to_path_buf();
1246        dst_path_gz.set_extension("gz");
1247
1248        if let Some(parent) = dst_path.parent() {
1249            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1250        }
1251
1252        let mut server_binary_exists = false;
1253        if cfg!(not(debug_assertions)) {
1254            if let Ok(installed_version) =
1255                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1256            {
1257                if installed_version.trim() == version.to_string() {
1258                    server_binary_exists = true;
1259                }
1260            }
1261        }
1262
1263        if server_binary_exists {
1264            log::info!("remote development server already present",);
1265            return Ok(());
1266        }
1267
1268        let src_stat = fs::metadata(src_path).await?;
1269        let size = src_stat.len();
1270        let server_mode = 0o755;
1271
1272        let t0 = Instant::now();
1273        delegate.set_status(Some("uploading remote development server"), cx);
1274        log::info!("uploading remote development server ({}kb)", size / 1024);
1275        self.upload_file(src_path, &dst_path_gz)
1276            .await
1277            .context("failed to upload server binary")?;
1278        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1279
1280        delegate.set_status(Some("extracting remote development server"), cx);
1281        run_cmd(
1282            self.socket
1283                .ssh_command("gunzip")
1284                .arg("--force")
1285                .arg(&dst_path_gz),
1286        )
1287        .await?;
1288
1289        delegate.set_status(Some("unzipping remote development server"), cx);
1290        run_cmd(
1291            self.socket
1292                .ssh_command("chmod")
1293                .arg(format!("{:o}", server_mode))
1294                .arg(dst_path),
1295        )
1296        .await?;
1297
1298        Ok(())
1299    }
1300
1301    async fn query_platform(&self) -> Result<SshPlatform> {
1302        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1303        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1304
1305        let os = match os.trim() {
1306            "Darwin" => "macos",
1307            "Linux" => "linux",
1308            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1309        };
1310        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1311            "aarch64"
1312        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1313            "x86_64"
1314        } else {
1315            Err(anyhow!("unknown uname architecture {arch:?}"))?
1316        };
1317
1318        Ok(SshPlatform { os, arch })
1319    }
1320
1321    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1322        let mut command = process::Command::new("scp");
1323        let output = self
1324            .socket
1325            .ssh_options(&mut command)
1326            .args(
1327                self.socket
1328                    .connection_options
1329                    .port
1330                    .map(|port| vec!["-P".to_string(), port.to_string()])
1331                    .unwrap_or_default(),
1332            )
1333            .arg(src_path)
1334            .arg(format!(
1335                "{}:{}",
1336                self.socket.connection_options.scp_url(),
1337                dest_path.display()
1338            ))
1339            .output()
1340            .await?;
1341
1342        if output.status.success() {
1343            Ok(())
1344        } else {
1345            Err(anyhow!(
1346                "failed to upload file {} -> {}: {}",
1347                src_path.display(),
1348                dest_path.display(),
1349                String::from_utf8_lossy(&output.stderr)
1350            ))
1351        }
1352    }
1353}
1354
/// Pending requests keyed by the id of the request message.
///
/// Each entry is completed with the response envelope plus a
/// `oneshot::Sender<()>`; the incoming-message loop waits until that sender is
/// released before processing the next envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1356
/// RPC client that carries requests, responses and one-way messages over a
/// pair of unbounded envelope channels.
pub struct ChannelClient {
    // Source of ids assigned to every outgoing envelope.
    next_message_id: AtomicU32,
    // Sink for outgoing envelopes.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    response_channels: ResponseChannels,             // Lock: requests awaiting a response
    message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock: handlers for non-response messages
}
1363
1364impl ChannelClient {
1365    pub fn new(
1366        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1367        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1368        cx: &AppContext,
1369    ) -> Arc<Self> {
1370        let this = Arc::new(Self {
1371            outgoing_tx,
1372            next_message_id: AtomicU32::new(0),
1373            response_channels: ResponseChannels::default(),
1374            message_handlers: Default::default(),
1375        });
1376
1377        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1378
1379        this
1380    }
1381
1382    fn start_handling_messages(
1383        this: Arc<Self>,
1384        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1385        cx: &AppContext,
1386    ) {
1387        cx.spawn(|cx| {
1388            let this = Arc::downgrade(&this);
1389            async move {
1390                let peer_id = PeerId { owner_id: 0, id: 0 };
1391                while let Some(incoming) = incoming_rx.next().await {
1392                    let Some(this) = this.upgrade() else {
1393                        return anyhow::Ok(());
1394                    };
1395
1396                    if let Some(request_id) = incoming.responding_to {
1397                        let request_id = MessageId(request_id);
1398                        let sender = this.response_channels.lock().remove(&request_id);
1399                        if let Some(sender) = sender {
1400                            let (tx, rx) = oneshot::channel();
1401                            if incoming.payload.is_some() {
1402                                sender.send((incoming, tx)).ok();
1403                            }
1404                            rx.await.ok();
1405                        }
1406                    } else if let Some(envelope) =
1407                        build_typed_envelope(peer_id, Instant::now(), incoming)
1408                    {
1409                        let type_name = envelope.payload_type_name();
1410                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1411                            &this.message_handlers,
1412                            envelope,
1413                            this.clone().into(),
1414                            cx.clone(),
1415                        ) {
1416                            log::debug!("ssh message received. name:{type_name}");
1417                            cx.foreground_executor().spawn(async move {
1418                                match future.await {
1419                                    Ok(_) => {
1420                                        log::debug!("ssh message handled. name:{type_name}");
1421                                    }
1422                                    Err(error) => {
1423                                        log::error!(
1424                                            "error handling message. type:{type_name}, error:{error}",
1425                                        );
1426                                    }
1427                                }
1428                            }).detach();
1429
1430                        } else {
1431                            log::error!("unhandled ssh message name:{type_name}");
1432                        }
1433                    }
1434                }
1435                anyhow::Ok(())
1436            }
1437        })
1438        .detach();
1439    }
1440
1441    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1442        let id = (TypeId::of::<E>(), remote_id);
1443
1444        let mut message_handlers = self.message_handlers.lock();
1445        if message_handlers
1446            .entities_by_type_and_remote_id
1447            .contains_key(&id)
1448        {
1449            panic!("already subscribed to entity");
1450        }
1451
1452        message_handlers.entities_by_type_and_remote_id.insert(
1453            id,
1454            EntityMessageSubscriber::Entity {
1455                handle: entity.downgrade().into(),
1456            },
1457        );
1458    }
1459
1460    pub fn request<T: RequestMessage>(
1461        &self,
1462        payload: T,
1463    ) -> impl 'static + Future<Output = Result<T::Response>> {
1464        log::debug!("ssh request start. name:{}", T::NAME);
1465        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1466        async move {
1467            let response = response.await?;
1468            log::debug!("ssh request finish. name:{}", T::NAME);
1469            T::Response::from_envelope(response)
1470                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1471        }
1472    }
1473
1474    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1475        smol::future::or(
1476            async {
1477                self.request(proto::Ping {}).await?;
1478                Ok(())
1479            },
1480            async {
1481                smol::Timer::after(timeout).await;
1482                Err(anyhow!("Timeout detected"))
1483            },
1484        )
1485        .await
1486    }
1487
1488    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1489        log::debug!("ssh send name:{}", T::NAME);
1490        self.send_dynamic(payload.into_envelope(0, None, None))
1491    }
1492
1493    pub fn request_dynamic(
1494        &self,
1495        mut envelope: proto::Envelope,
1496        type_name: &'static str,
1497    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1498        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1499        let (tx, rx) = oneshot::channel();
1500        let mut response_channels_lock = self.response_channels.lock();
1501        response_channels_lock.insert(MessageId(envelope.id), tx);
1502        drop(response_channels_lock);
1503        let result = self.outgoing_tx.unbounded_send(envelope);
1504        async move {
1505            if let Err(error) = &result {
1506                log::error!("failed to send message: {}", error);
1507                return Err(anyhow!("failed to send message: {}", error));
1508            }
1509
1510            let response = rx.await.context("connection lost")?.0;
1511            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1512                return Err(RpcError::from_proto(error, type_name));
1513            }
1514            Ok(response)
1515        }
1516    }
1517
1518    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1519        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1520        self.outgoing_tx.unbounded_send(envelope)?;
1521        Ok(())
1522    }
1523}
1524
1525impl ProtoClient for ChannelClient {
1526    fn request(
1527        &self,
1528        envelope: proto::Envelope,
1529        request_type: &'static str,
1530    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1531        self.request_dynamic(envelope, request_type).boxed()
1532    }
1533
1534    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1535        self.send_dynamic(envelope)
1536    }
1537
1538    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1539        self.send_dynamic(envelope)
1540    }
1541
1542    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1543        &self.message_handlers
1544    }
1545
1546    fn is_via_collab(&self) -> bool {
1547        false
1548    }
1549}