ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  21    WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  27};
  28use smol::{
  29    fs,
  30    process::{self, Child, Stdio},
  31};
  32use std::{
  33    any::TypeId,
  34    ffi::OsStr,
  35    fmt,
  36    ops::ControlFlow,
  37    path::{Path, PathBuf},
  38    sync::{
  39        atomic::{AtomicU32, Ordering::SeqCst},
  40        Arc,
  41    },
  42    time::{Duration, Instant},
  43};
  44use tempfile::TempDir;
  45use util::ResultExt;
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  49)]
  50pub struct SshProjectId(pub u64);
  51
  52#[derive(Clone)]
  53pub struct SshSocket {
  54    connection_options: SshConnectionOptions,
  55    socket_path: PathBuf,
  56}
  57
  58#[derive(Debug, Default, Clone, PartialEq, Eq)]
  59pub struct SshConnectionOptions {
  60    pub host: String,
  61    pub username: Option<String>,
  62    pub port: Option<u16>,
  63    pub password: Option<String>,
  64}
  65
  66impl SshConnectionOptions {
  67    pub fn ssh_url(&self) -> String {
  68        let mut result = String::from("ssh://");
  69        if let Some(username) = &self.username {
  70            result.push_str(username);
  71            result.push('@');
  72        }
  73        result.push_str(&self.host);
  74        if let Some(port) = self.port {
  75            result.push(':');
  76            result.push_str(&port.to_string());
  77        }
  78        result
  79    }
  80
  81    fn scp_url(&self) -> String {
  82        if let Some(username) = &self.username {
  83            format!("{}@{}", username, self.host)
  84        } else {
  85            self.host.clone()
  86        }
  87    }
  88
  89    pub fn connection_string(&self) -> String {
  90        let host = if let Some(username) = &self.username {
  91            format!("{}@{}", username, self.host)
  92        } else {
  93            self.host.clone()
  94        };
  95        if let Some(port) = &self.port {
  96            format!("{}:{}", host, port)
  97        } else {
  98            host
  99        }
 100    }
 101
 102    // Uniquely identifies dev server projects on a remote host. Needs to be
 103    // stable for the same dev server project.
 104    pub fn dev_server_identifier(&self) -> String {
 105        let mut identifier = format!("dev-server-{:?}", self.host);
 106        if let Some(username) = self.username.as_ref() {
 107            identifier.push('-');
 108            identifier.push_str(&username);
 109        }
 110        identifier
 111    }
 112}
 113
 114#[derive(Copy, Clone, Debug)]
 115pub struct SshPlatform {
 116    pub os: &'static str,
 117    pub arch: &'static str,
 118}
 119
 120impl SshPlatform {
 121    pub fn triple(&self) -> Option<String> {
 122        Some(format!(
 123            "{}-{}",
 124            self.arch,
 125            match self.os {
 126                "linux" => "unknown-linux-gnu",
 127                "macos" => "apple-darwin",
 128                _ => return None,
 129            }
 130        ))
 131    }
 132}
 133
 134pub trait SshClientDelegate: Send + Sync {
 135    fn ask_password(
 136        &self,
 137        prompt: String,
 138        cx: &mut AsyncAppContext,
 139    ) -> oneshot::Receiver<Result<String>>;
 140    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 141    fn get_server_binary(
 142        &self,
 143        platform: SshPlatform,
 144        cx: &mut AsyncAppContext,
 145    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 146    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 147    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 148}
 149
 150impl SshSocket {
 151    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 152        let mut command = process::Command::new("ssh");
 153        self.ssh_options(&mut command)
 154            .arg(self.connection_options.ssh_url())
 155            .arg(program);
 156        command
 157    }
 158
 159    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 160        command
 161            .stdin(Stdio::piped())
 162            .stdout(Stdio::piped())
 163            .stderr(Stdio::piped())
 164            .args(["-o", "ControlMaster=no", "-o"])
 165            .arg(format!("ControlPath={}", self.socket_path.display()))
 166    }
 167
 168    fn ssh_args(&self) -> Vec<String> {
 169        vec![
 170            "-o".to_string(),
 171            "ControlMaster=no".to_string(),
 172            "-o".to_string(),
 173            format!("ControlPath={}", self.socket_path.display()),
 174            self.connection_options.ssh_url(),
 175        ]
 176    }
 177}
 178
 179async fn run_cmd(command: &mut process::Command) -> Result<String> {
 180    let output = command.output().await?;
 181    if output.status.success() {
 182        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 183    } else {
 184        Err(anyhow!(
 185            "failed to run command: {}",
 186            String::from_utf8_lossy(&output.stderr)
 187        ))
 188    }
 189}
 190
 191struct ChannelForwarder {
 192    quit_tx: UnboundedSender<()>,
 193    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 194}
 195
 196impl ChannelForwarder {
 197    fn new(
 198        mut incoming_tx: UnboundedSender<Envelope>,
 199        mut outgoing_rx: UnboundedReceiver<Envelope>,
 200        cx: &AsyncAppContext,
 201    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 202        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 203
 204        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 205        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 206
 207        let forwarding_task = cx.background_executor().spawn(async move {
 208            loop {
 209                select_biased! {
 210                    _ = quit_rx.next().fuse() => {
 211                        break;
 212                    },
 213                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 214                        if let Some(envelope) = incoming_envelope {
 215                            if incoming_tx.send(envelope).await.is_err() {
 216                                break;
 217                            }
 218                        } else {
 219                            break;
 220                        }
 221                    }
 222                    outgoing_envelope = outgoing_rx.next().fuse() => {
 223                        if let Some(envelope) = outgoing_envelope {
 224                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 225                                break;
 226                            }
 227                        } else {
 228                            break;
 229                        }
 230                    }
 231                }
 232            }
 233
 234            (incoming_tx, outgoing_rx)
 235        });
 236
 237        (
 238            Self {
 239                forwarding_task,
 240                quit_tx,
 241            },
 242            proxy_incoming_tx,
 243            proxy_outgoing_rx,
 244        )
 245    }
 246
 247    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 248        let _ = self.quit_tx.send(()).await;
 249        self.forwarding_task.await
 250    }
 251}
 252
 253const MAX_MISSED_HEARTBEATS: usize = 5;
 254const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 255const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 256
 257const MAX_RECONNECT_ATTEMPTS: usize = 3;
 258
 259enum State {
 260    Connecting,
 261    Connected {
 262        ssh_connection: SshRemoteConnection,
 263        delegate: Arc<dyn SshClientDelegate>,
 264        forwarder: ChannelForwarder,
 265
 266        multiplex_task: Task<Result<()>>,
 267        heartbeat_task: Task<Result<()>>,
 268    },
 269    HeartbeatMissed {
 270        missed_heartbeats: usize,
 271
 272        ssh_connection: SshRemoteConnection,
 273        delegate: Arc<dyn SshClientDelegate>,
 274        forwarder: ChannelForwarder,
 275
 276        multiplex_task: Task<Result<()>>,
 277        heartbeat_task: Task<Result<()>>,
 278    },
 279    Reconnecting,
 280    ReconnectFailed {
 281        ssh_connection: SshRemoteConnection,
 282        delegate: Arc<dyn SshClientDelegate>,
 283        forwarder: ChannelForwarder,
 284
 285        error: anyhow::Error,
 286        attempts: usize,
 287    },
 288    ReconnectExhausted,
 289    ServerNotRunning,
 290}
 291
 292impl fmt::Display for State {
 293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 294        match self {
 295            Self::Connecting => write!(f, "connecting"),
 296            Self::Connected { .. } => write!(f, "connected"),
 297            Self::Reconnecting => write!(f, "reconnecting"),
 298            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 299            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 300            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 301            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 302        }
 303    }
 304}
 305
 306impl State {
 307    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 308        match self {
 309            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 310            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 311            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 312            _ => None,
 313        }
 314    }
 315
 316    fn can_reconnect(&self) -> bool {
 317        match self {
 318            Self::Connected { .. }
 319            | Self::HeartbeatMissed { .. }
 320            | Self::ReconnectFailed { .. } => true,
 321            State::Connecting
 322            | State::Reconnecting
 323            | State::ReconnectExhausted
 324            | State::ServerNotRunning => false,
 325        }
 326    }
 327
 328    fn is_reconnect_failed(&self) -> bool {
 329        matches!(self, Self::ReconnectFailed { .. })
 330    }
 331
 332    fn is_reconnect_exhausted(&self) -> bool {
 333        matches!(self, Self::ReconnectExhausted { .. })
 334    }
 335
 336    fn is_reconnecting(&self) -> bool {
 337        matches!(self, Self::Reconnecting { .. })
 338    }
 339
 340    fn heartbeat_recovered(self) -> Self {
 341        match self {
 342            Self::HeartbeatMissed {
 343                ssh_connection,
 344                delegate,
 345                forwarder,
 346                multiplex_task,
 347                heartbeat_task,
 348                ..
 349            } => Self::Connected {
 350                ssh_connection,
 351                delegate,
 352                forwarder,
 353                multiplex_task,
 354                heartbeat_task,
 355            },
 356            _ => self,
 357        }
 358    }
 359
 360    fn heartbeat_missed(self) -> Self {
 361        match self {
 362            Self::Connected {
 363                ssh_connection,
 364                delegate,
 365                forwarder,
 366                multiplex_task,
 367                heartbeat_task,
 368            } => Self::HeartbeatMissed {
 369                missed_heartbeats: 1,
 370                ssh_connection,
 371                delegate,
 372                forwarder,
 373                multiplex_task,
 374                heartbeat_task,
 375            },
 376            Self::HeartbeatMissed {
 377                missed_heartbeats,
 378                ssh_connection,
 379                delegate,
 380                forwarder,
 381                multiplex_task,
 382                heartbeat_task,
 383            } => Self::HeartbeatMissed {
 384                missed_heartbeats: missed_heartbeats + 1,
 385                ssh_connection,
 386                delegate,
 387                forwarder,
 388                multiplex_task,
 389                heartbeat_task,
 390            },
 391            _ => self,
 392        }
 393    }
 394}
 395
 396/// The state of the ssh connection.
 397#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 398pub enum ConnectionState {
 399    Connecting,
 400    Connected,
 401    HeartbeatMissed,
 402    Reconnecting,
 403    Disconnected,
 404}
 405
 406impl From<&State> for ConnectionState {
 407    fn from(value: &State) -> Self {
 408        match value {
 409            State::Connecting => Self::Connecting,
 410            State::Connected { .. } => Self::Connected,
 411            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 412            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 413            State::ReconnectExhausted => Self::Disconnected,
 414            State::ServerNotRunning => Self::Disconnected,
 415        }
 416    }
 417}
 418
 419pub struct SshRemoteClient {
 420    client: Arc<ChannelClient>,
 421    unique_identifier: String,
 422    connection_options: SshConnectionOptions,
 423    state: Arc<Mutex<Option<State>>>,
 424}
 425
 426#[derive(Debug)]
 427pub enum SshRemoteEvent {
 428    Disconnected,
 429}
 430
 431impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 432
 433impl SshRemoteClient {
 434    pub fn new(
 435        unique_identifier: String,
 436        connection_options: SshConnectionOptions,
 437        delegate: Arc<dyn SshClientDelegate>,
 438        cx: &AppContext,
 439    ) -> Task<Result<Model<Self>>> {
 440        cx.spawn(|mut cx| async move {
 441            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 442            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 443            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 444
 445            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
 446            let this = cx.new_model(|_| Self {
 447                client: client.clone(),
 448                unique_identifier: unique_identifier.clone(),
 449                connection_options: connection_options.clone(),
 450                state: Arc::new(Mutex::new(Some(State::Connecting))),
 451            })?;
 452
 453            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 454                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 455
 456            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 457                unique_identifier,
 458                false,
 459                connection_options,
 460                delegate.clone(),
 461                &mut cx,
 462            )
 463            .await?;
 464
 465            let multiplex_task = Self::multiplex(
 466                this.downgrade(),
 467                ssh_proxy_process,
 468                proxy_incoming_tx,
 469                proxy_outgoing_rx,
 470                connection_activity_tx,
 471                &mut cx,
 472            );
 473
 474            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 475                log::error!("failed to establish connection: {}", error);
 476                delegate.set_error(error.to_string(), &mut cx);
 477                return Err(error);
 478            }
 479
 480            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 481
 482            this.update(&mut cx, |this, _| {
 483                *this.state.lock() = Some(State::Connected {
 484                    ssh_connection,
 485                    delegate,
 486                    forwarder: proxy,
 487                    multiplex_task,
 488                    heartbeat_task,
 489                });
 490            })?;
 491
 492            Ok(this)
 493        })
 494    }
 495
 496    pub fn shutdown_processes<T: RequestMessage>(
 497        &self,
 498        shutdown_request: Option<T>,
 499    ) -> Option<impl Future<Output = ()>> {
 500        let state = self.state.lock().take()?;
 501        log::info!("shutting down ssh processes");
 502
 503        let State::Connected {
 504            multiplex_task,
 505            heartbeat_task,
 506            ssh_connection,
 507            delegate,
 508            forwarder,
 509        } = state
 510        else {
 511            return None;
 512        };
 513
 514        let client = self.client.clone();
 515
 516        Some(async move {
 517            if let Some(shutdown_request) = shutdown_request {
 518                client.send(shutdown_request).log_err();
 519                // We wait 50ms instead of waiting for a response, because
 520                // waiting for a response would require us to wait on the main thread
 521                // which we want to avoid in an `on_app_quit` callback.
 522                smol::Timer::after(Duration::from_millis(50)).await;
 523            }
 524
 525            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 526            // child of master_process.
 527            drop(multiplex_task);
 528            // Now drop the rest of state, which kills master process.
 529            drop(heartbeat_task);
 530            drop(ssh_connection);
 531            drop(delegate);
 532            drop(forwarder);
 533        })
 534    }
 535
 536    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 537        let mut lock = self.state.lock();
 538
 539        let can_reconnect = lock
 540            .as_ref()
 541            .map(|state| state.can_reconnect())
 542            .unwrap_or(false);
 543        if !can_reconnect {
 544            let error = if let Some(state) = lock.as_ref() {
 545                format!("invalid state, cannot reconnect while in state {state}")
 546            } else {
 547                "no state set".to_string()
 548            };
 549            log::info!("aborting reconnect, because not in state that allows reconnecting");
 550            return Err(anyhow!(error));
 551        }
 552
 553        let state = lock.take().unwrap();
 554        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 555            State::Connected {
 556                ssh_connection,
 557                delegate,
 558                forwarder,
 559                multiplex_task,
 560                heartbeat_task,
 561            }
 562            | State::HeartbeatMissed {
 563                ssh_connection,
 564                delegate,
 565                forwarder,
 566                multiplex_task,
 567                heartbeat_task,
 568                ..
 569            } => {
 570                drop(multiplex_task);
 571                drop(heartbeat_task);
 572                (0, ssh_connection, delegate, forwarder)
 573            }
 574            State::ReconnectFailed {
 575                attempts,
 576                ssh_connection,
 577                delegate,
 578                forwarder,
 579                ..
 580            } => (attempts, ssh_connection, delegate, forwarder),
 581            State::Connecting
 582            | State::Reconnecting
 583            | State::ReconnectExhausted
 584            | State::ServerNotRunning => unreachable!(),
 585        };
 586
 587        let attempts = attempts + 1;
 588        if attempts > MAX_RECONNECT_ATTEMPTS {
 589            log::error!(
 590                "Failed to reconnect to after {} attempts, giving up",
 591                MAX_RECONNECT_ATTEMPTS
 592            );
 593            drop(lock);
 594            self.set_state(State::ReconnectExhausted, cx);
 595            return Ok(());
 596        }
 597        drop(lock);
 598
 599        self.set_state(State::Reconnecting, cx);
 600
 601        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 602
 603        let identifier = self.unique_identifier.clone();
 604        let client = self.client.clone();
 605        let reconnect_task = cx.spawn(|this, mut cx| async move {
 606            macro_rules! failed {
 607                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 608                    return State::ReconnectFailed {
 609                        error: anyhow!($error),
 610                        attempts: $attempts,
 611                        ssh_connection: $ssh_connection,
 612                        delegate: $delegate,
 613                        forwarder: $forwarder,
 614                    };
 615                };
 616            }
 617
 618            if let Err(error) = ssh_connection.master_process.kill() {
 619                failed!(error, attempts, ssh_connection, delegate, forwarder);
 620            };
 621
 622            if let Err(error) = ssh_connection
 623                .master_process
 624                .status()
 625                .await
 626                .context("Failed to kill ssh process")
 627            {
 628                failed!(error, attempts, ssh_connection, delegate, forwarder);
 629            }
 630
 631            let connection_options = ssh_connection.socket.connection_options.clone();
 632
 633            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 634            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 635                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 636            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 637
 638            let (ssh_connection, ssh_process) = match Self::establish_connection(
 639                identifier,
 640                true,
 641                connection_options,
 642                delegate.clone(),
 643                &mut cx,
 644            )
 645            .await
 646            {
 647                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 648                Err(error) => {
 649                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 650                }
 651            };
 652
 653            let multiplex_task = Self::multiplex(
 654                this.clone(),
 655                ssh_process,
 656                proxy_incoming_tx,
 657                proxy_outgoing_rx,
 658                connection_activity_tx,
 659                &mut cx,
 660            );
 661
 662            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 663                failed!(error, attempts, ssh_connection, delegate, forwarder);
 664            };
 665
 666            State::Connected {
 667                ssh_connection,
 668                delegate,
 669                forwarder,
 670                multiplex_task,
 671                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 672            }
 673        });
 674
 675        cx.spawn(|this, mut cx| async move {
 676            let new_state = reconnect_task.await;
 677            this.update(&mut cx, |this, cx| {
 678                this.try_set_state(cx, |old_state| {
 679                    if old_state.is_reconnecting() {
 680                        match &new_state {
 681                            State::Connecting
 682                            | State::Reconnecting { .. }
 683                            | State::HeartbeatMissed { .. }
 684                            | State::ServerNotRunning => {}
 685                            State::Connected { .. } => {
 686                                log::info!("Successfully reconnected");
 687                            }
 688                            State::ReconnectFailed {
 689                                error, attempts, ..
 690                            } => {
 691                                log::error!(
 692                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 693                                    attempts,
 694                                    error
 695                                );
 696                            }
 697                            State::ReconnectExhausted => {
 698                                log::error!("Reconnect attempt failed and all attempts exhausted");
 699                            }
 700                        }
 701                        Some(new_state)
 702                    } else {
 703                        None
 704                    }
 705                });
 706
 707                if this.state_is(State::is_reconnect_failed) {
 708                    this.reconnect(cx)
 709                } else if this.state_is(State::is_reconnect_exhausted) {
 710                    cx.emit(SshRemoteEvent::Disconnected);
 711                    Ok(())
 712                } else {
 713                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 714                    Ok(())
 715                }
 716            })
 717        })
 718        .detach_and_log_err(cx);
 719
 720        Ok(())
 721    }
 722
 723    fn heartbeat(
 724        this: WeakModel<Self>,
 725        mut connection_activity_rx: mpsc::Receiver<()>,
 726        cx: &mut AsyncAppContext,
 727    ) -> Task<Result<()>> {
 728        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 729            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 730        };
 731
 732        cx.spawn(|mut cx| {
 733            let this = this.clone();
 734            async move {
 735                let mut missed_heartbeats = 0;
 736
 737                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 738                futures::pin_mut!(keepalive_timer);
 739
 740                loop {
 741                    select_biased! {
 742                        _ = connection_activity_rx.next().fuse() => {
 743                            keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 744                        }
 745                        _ = keepalive_timer => {
 746                            log::debug!("Sending heartbeat to server...");
 747
 748                            let result = select_biased! {
 749                                _ = connection_activity_rx.next().fuse() => {
 750                                    Ok(())
 751                                }
 752                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 753                                    ping_result
 754                                }
 755                            };
 756                            if result.is_err() {
 757                                missed_heartbeats += 1;
 758                                log::warn!(
 759                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 760                                    HEARTBEAT_TIMEOUT,
 761                                    missed_heartbeats,
 762                                    MAX_MISSED_HEARTBEATS
 763                                );
 764                            } else if missed_heartbeats != 0 {
 765                                missed_heartbeats = 0;
 766                            } else {
 767                                continue;
 768                            }
 769
 770                            let result = this.update(&mut cx, |this, mut cx| {
 771                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 772                            })?;
 773                            if result.is_break() {
 774                                return Ok(());
 775                            }
 776                        }
 777                    }
 778                }
 779            }
 780        })
 781    }
 782
 783    fn handle_heartbeat_result(
 784        &mut self,
 785        missed_heartbeats: usize,
 786        cx: &mut ModelContext<Self>,
 787    ) -> ControlFlow<()> {
 788        let state = self.state.lock().take().unwrap();
 789        let next_state = if missed_heartbeats > 0 {
 790            state.heartbeat_missed()
 791        } else {
 792            state.heartbeat_recovered()
 793        };
 794
 795        self.set_state(next_state, cx);
 796
 797        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 798            log::error!(
 799                "Missed last {} heartbeats. Reconnecting...",
 800                missed_heartbeats
 801            );
 802
 803            self.reconnect(cx)
 804                .context("failed to start reconnect process after missing heartbeats")
 805                .log_err();
 806            ControlFlow::Break(())
 807        } else {
 808            ControlFlow::Continue(())
 809        }
 810    }
 811
 812    fn multiplex(
 813        this: WeakModel<Self>,
 814        mut ssh_proxy_process: Child,
 815        incoming_tx: UnboundedSender<Envelope>,
 816        mut outgoing_rx: UnboundedReceiver<Envelope>,
 817        mut connection_activity_tx: Sender<()>,
 818        cx: &AsyncAppContext,
 819    ) -> Task<Result<()>> {
 820        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 821        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 822        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 823
 824        let io_task = cx.background_executor().spawn(async move {
 825            let mut stdin_buffer = Vec::new();
 826            let mut stdout_buffer = Vec::new();
 827            let mut stderr_buffer = Vec::new();
 828            let mut stderr_offset = 0;
 829
 830            loop {
 831                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 832                stderr_buffer.resize(stderr_offset + 1024, 0);
 833
 834                select_biased! {
 835                    outgoing = outgoing_rx.next().fuse() => {
 836                        let Some(outgoing) = outgoing else {
 837                            return anyhow::Ok(None);
 838                        };
 839
 840                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 841                    }
 842
 843                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 844                        match result {
 845                            Ok(0) => {
 846                                child_stdin.close().await?;
 847                                outgoing_rx.close();
 848                                let status = ssh_proxy_process.status().await?;
 849                                return Ok(status.code());
 850                            }
 851                            Ok(len) => {
 852                                if len < stdout_buffer.len() {
 853                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 854                                }
 855
 856                                let message_len = message_len_from_buffer(&stdout_buffer);
 857                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 858                                    Ok(envelope) => {
 859                                        connection_activity_tx.try_send(()).ok();
 860                                        incoming_tx.unbounded_send(envelope).ok();
 861                                    }
 862                                    Err(error) => {
 863                                        log::error!("error decoding message {error:?}");
 864                                    }
 865                                }
 866                            }
 867                            Err(error) => {
 868                                Err(anyhow!("error reading stdout: {error:?}"))?;
 869                            }
 870                        }
 871                    }
 872
 873                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 874                        match result {
 875                            Ok(len) => {
 876                                stderr_offset += len;
 877                                let mut start_ix = 0;
 878                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 879                                    let line_ix = start_ix + ix;
 880                                    let content = &stderr_buffer[start_ix..line_ix];
 881                                    start_ix = line_ix + 1;
 882                                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 883                                        record.log(log::logger())
 884                                    } else {
 885                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 886                                    }
 887                                }
 888                                stderr_buffer.drain(0..start_ix);
 889                                stderr_offset -= start_ix;
 890
 891                                connection_activity_tx.try_send(()).ok();
 892                            }
 893                            Err(error) => {
 894                                Err(anyhow!("error reading stderr: {error:?}"))?;
 895                            }
 896                        }
 897                    }
 898                }
 899            }
 900        });
 901
 902        cx.spawn(|mut cx| async move {
 903            let result = io_task.await;
 904
 905            match result {
 906                Ok(Some(exit_code)) => {
 907                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 908                        match error {
 909                            ProxyLaunchError::ServerNotRunning => {
 910                                log::error!("failed to reconnect because server is not running");
 911                                this.update(&mut cx, |this, cx| {
 912                                    this.set_state(State::ServerNotRunning, cx);
 913                                    cx.emit(SshRemoteEvent::Disconnected);
 914                                })?;
 915                            }
 916                        }
 917                    } else if exit_code > 0 {
 918                        log::error!("proxy process terminated unexpectedly");
 919                        this.update(&mut cx, |this, cx| {
 920                            this.reconnect(cx).ok();
 921                        })?;
 922                    }
 923                }
 924                Ok(None) => {}
 925                Err(error) => {
 926                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 927                    this.update(&mut cx, |this, cx| {
 928                        this.reconnect(cx).ok();
 929                    })?;
 930                }
 931            }
 932            Ok(())
 933        })
 934    }
 935
 936    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 937        self.state.lock().as_ref().map_or(false, check)
 938    }
 939
 940    fn try_set_state(
 941        &self,
 942        cx: &mut ModelContext<Self>,
 943        map: impl FnOnce(&State) -> Option<State>,
 944    ) {
 945        let mut lock = self.state.lock();
 946        let new_state = lock.as_ref().and_then(map);
 947
 948        if let Some(new_state) = new_state {
 949            lock.replace(new_state);
 950            cx.notify();
 951        }
 952    }
 953
 954    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 955        log::info!("setting state to '{}'", &state);
 956        self.state.lock().replace(state);
 957        cx.notify();
 958    }
 959
 960    async fn establish_connection(
 961        unique_identifier: String,
 962        reconnect: bool,
 963        connection_options: SshConnectionOptions,
 964        delegate: Arc<dyn SshClientDelegate>,
 965        cx: &mut AsyncAppContext,
 966    ) -> Result<(SshRemoteConnection, Child)> {
 967        let ssh_connection =
 968            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 969
 970        let platform = ssh_connection.query_platform().await?;
 971        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 972        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 973        ssh_connection
 974            .ensure_server_binary(
 975                &delegate,
 976                &local_binary_path,
 977                &remote_binary_path,
 978                version,
 979                cx,
 980            )
 981            .await?;
 982
 983        let socket = ssh_connection.socket.clone();
 984        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 985
 986        delegate.set_status(Some("Starting proxy"), cx);
 987
 988        let mut start_proxy_command = format!(
 989            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 990            std::env::var("RUST_LOG").unwrap_or_default(),
 991            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 992            remote_binary_path,
 993            unique_identifier,
 994        );
 995        if reconnect {
 996            start_proxy_command.push_str(" --reconnect");
 997        }
 998
 999        let ssh_proxy_process = socket
1000            .ssh_command(start_proxy_command)
1001            // IMPORTANT: we kill this process when we drop the task that uses it.
1002            .kill_on_drop(true)
1003            .spawn()
1004            .context("failed to spawn remote server")?;
1005
1006        Ok((ssh_connection, ssh_proxy_process))
1007    }
1008
1009    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1010        self.client.subscribe_to_entity(remote_id, entity);
1011    }
1012
1013    pub fn ssh_args(&self) -> Option<Vec<String>> {
1014        self.state
1015            .lock()
1016            .as_ref()
1017            .and_then(|state| state.ssh_connection())
1018            .map(|ssh_connection| ssh_connection.socket.ssh_args())
1019    }
1020
1021    pub fn to_proto_client(&self) -> AnyProtoClient {
1022        self.client.clone().into()
1023    }
1024
1025    pub fn connection_string(&self) -> String {
1026        self.connection_options.connection_string()
1027    }
1028
1029    pub fn connection_options(&self) -> SshConnectionOptions {
1030        self.connection_options.clone()
1031    }
1032
1033    #[cfg(not(any(test, feature = "test-support")))]
1034    pub fn connection_state(&self) -> ConnectionState {
1035        self.state
1036            .lock()
1037            .as_ref()
1038            .map(ConnectionState::from)
1039            .unwrap_or(ConnectionState::Disconnected)
1040    }
1041
1042    #[cfg(any(test, feature = "test-support"))]
1043    pub fn connection_state(&self) -> ConnectionState {
1044        ConnectionState::Connected
1045    }
1046
1047    pub fn is_disconnected(&self) -> bool {
1048        self.connection_state() == ConnectionState::Disconnected
1049    }
1050
1051    #[cfg(any(test, feature = "test-support"))]
1052    pub fn fake(
1053        client_cx: &mut gpui::TestAppContext,
1054        server_cx: &mut gpui::TestAppContext,
1055    ) -> (Model<Self>, Arc<ChannelClient>) {
1056        use gpui::Context;
1057
1058        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1059        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1060
1061        (
1062            client_cx.update(|cx| {
1063                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1064                cx.new_model(|_| Self {
1065                    client,
1066                    unique_identifier: "fake".to_string(),
1067                    connection_options: SshConnectionOptions::default(),
1068                    state: Arc::new(Mutex::new(None)),
1069                })
1070            }),
1071            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1072        )
1073    }
1074}
1075
1076impl From<SshRemoteClient> for AnyProtoClient {
1077    fn from(client: SshRemoteClient) -> Self {
1078        AnyProtoClient::new(client.client.clone())
1079    }
1080}
1081
/// An established SSH control-master connection, plus the resources that keep it usable.
struct SshRemoteConnection {
    /// Socket descriptor used to run further `ssh`/`scp` commands over the shared
    /// connection.
    socket: SshSocket,
    /// The `ssh -N` control-master process. It is killed explicitly in `Drop`.
    master_process: process::Child,
    /// Temporary directory holding the control socket, the askpass socket and the
    /// askpass script. It is held only so it is not removed while the connection lives.
    _temp_dir: TempDir,
}
1087
1088impl Drop for SshRemoteConnection {
1089    fn drop(&mut self) {
1090        if let Err(error) = self.master_process.kill() {
1091            log::error!("failed to kill SSH master process: {}", error);
1092        }
1093    }
1094}
1095
1096impl SshRemoteConnection {
1097    #[cfg(not(unix))]
1098    async fn new(
1099        _connection_options: SshConnectionOptions,
1100        _delegate: Arc<dyn SshClientDelegate>,
1101        _cx: &mut AsyncAppContext,
1102    ) -> Result<Self> {
1103        Err(anyhow!("ssh is not supported on this platform"))
1104    }
1105
1106    #[cfg(unix)]
1107    async fn new(
1108        connection_options: SshConnectionOptions,
1109        delegate: Arc<dyn SshClientDelegate>,
1110        cx: &mut AsyncAppContext,
1111    ) -> Result<Self> {
1112        use futures::{io::BufReader, AsyncBufReadExt as _};
1113        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1114        use util::ResultExt as _;
1115
1116        delegate.set_status(Some("connecting"), cx);
1117
1118        let url = connection_options.ssh_url();
1119        let temp_dir = tempfile::Builder::new()
1120            .prefix("zed-ssh-session")
1121            .tempdir()?;
1122
1123        // Create a domain socket listener to handle requests from the askpass program.
1124        let askpass_socket = temp_dir.path().join("askpass.sock");
1125        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1126        let listener =
1127            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1128
1129        let askpass_task = cx.spawn({
1130            let delegate = delegate.clone();
1131            |mut cx| async move {
1132                let mut askpass_opened_tx = Some(askpass_opened_tx);
1133
1134                while let Ok((mut stream, _)) = listener.accept().await {
1135                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1136                        askpass_opened_tx.send(()).ok();
1137                    }
1138                    let mut buffer = Vec::new();
1139                    let mut reader = BufReader::new(&mut stream);
1140                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1141                        buffer.clear();
1142                    }
1143                    let password_prompt = String::from_utf8_lossy(&buffer);
1144                    if let Some(password) = delegate
1145                        .ask_password(password_prompt.to_string(), &mut cx)
1146                        .await
1147                        .context("failed to get ssh password")
1148                        .and_then(|p| p)
1149                        .log_err()
1150                    {
1151                        stream.write_all(password.as_bytes()).await.log_err();
1152                    }
1153                }
1154            }
1155        });
1156
1157        // Create an askpass script that communicates back to this process.
1158        let askpass_script = format!(
1159            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1160            askpass_socket = askpass_socket.display(),
1161            print_args = "printf '%s\\0' \"$@\"",
1162            shebang = "#!/bin/sh",
1163        );
1164        let askpass_script_path = temp_dir.path().join("askpass.sh");
1165        fs::write(&askpass_script_path, askpass_script).await?;
1166        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1167
1168        // Start the master SSH process, which does not do anything except for establish
1169        // the connection and keep it open, allowing other ssh commands to reuse it
1170        // via a control socket.
1171        let socket_path = temp_dir.path().join("ssh.sock");
1172        let mut master_process = process::Command::new("ssh")
1173            .stdin(Stdio::null())
1174            .stdout(Stdio::piped())
1175            .stderr(Stdio::piped())
1176            .env("SSH_ASKPASS_REQUIRE", "force")
1177            .env("SSH_ASKPASS", &askpass_script_path)
1178            .args([
1179                "-N",
1180                "-o",
1181                "ControlPersist=no",
1182                "-o",
1183                "ControlMaster=yes",
1184                "-o",
1185            ])
1186            .arg(format!("ControlPath={}", socket_path.display()))
1187            .arg(&url)
1188            .spawn()?;
1189
1190        // Wait for this ssh process to close its stdout, indicating that authentication
1191        // has completed.
1192        let stdout = master_process.stdout.as_mut().unwrap();
1193        let mut output = Vec::new();
1194        let connection_timeout = Duration::from_secs(10);
1195
1196        let result = select_biased! {
1197            _ = askpass_opened_rx.fuse() => {
1198                // If the askpass script has opened, that means the user is typing
1199                // their password, in which case we don't want to timeout anymore,
1200                // since we know a connection has been established.
1201                stdout.read_to_end(&mut output).await?;
1202                Ok(())
1203            }
1204            result = stdout.read_to_end(&mut output).fuse() => {
1205                result?;
1206                Ok(())
1207            }
1208            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1209                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1210            }
1211        };
1212
1213        if let Err(e) = result {
1214            let error_message = format!("Failed to connect to host: {}.", e);
1215            delegate.set_error(error_message, cx);
1216            return Err(e);
1217        }
1218
1219        drop(askpass_task);
1220
1221        if master_process.try_status()?.is_some() {
1222            output.clear();
1223            let mut stderr = master_process.stderr.take().unwrap();
1224            stderr.read_to_end(&mut output).await?;
1225
1226            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1227            delegate.set_error(error_message.clone(), cx);
1228            Err(anyhow!(error_message))?;
1229        }
1230
1231        Ok(Self {
1232            socket: SshSocket {
1233                connection_options,
1234                socket_path,
1235            },
1236            master_process,
1237            _temp_dir: temp_dir,
1238        })
1239    }
1240
1241    async fn ensure_server_binary(
1242        &self,
1243        delegate: &Arc<dyn SshClientDelegate>,
1244        src_path: &Path,
1245        dst_path: &Path,
1246        version: SemanticVersion,
1247        cx: &mut AsyncAppContext,
1248    ) -> Result<()> {
1249        let mut dst_path_gz = dst_path.to_path_buf();
1250        dst_path_gz.set_extension("gz");
1251
1252        if let Some(parent) = dst_path.parent() {
1253            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1254        }
1255
1256        let mut server_binary_exists = false;
1257        if cfg!(not(debug_assertions)) {
1258            if let Ok(installed_version) =
1259                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1260            {
1261                if installed_version.trim() == version.to_string() {
1262                    server_binary_exists = true;
1263                }
1264            }
1265        }
1266
1267        if server_binary_exists {
1268            log::info!("remote development server already present",);
1269            return Ok(());
1270        }
1271
1272        let src_stat = fs::metadata(src_path).await?;
1273        let size = src_stat.len();
1274        let server_mode = 0o755;
1275
1276        let t0 = Instant::now();
1277        delegate.set_status(Some("uploading remote development server"), cx);
1278        log::info!("uploading remote development server ({}kb)", size / 1024);
1279        self.upload_file(src_path, &dst_path_gz)
1280            .await
1281            .context("failed to upload server binary")?;
1282        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1283
1284        delegate.set_status(Some("extracting remote development server"), cx);
1285        run_cmd(
1286            self.socket
1287                .ssh_command("gunzip")
1288                .arg("--force")
1289                .arg(&dst_path_gz),
1290        )
1291        .await?;
1292
1293        delegate.set_status(Some("unzipping remote development server"), cx);
1294        run_cmd(
1295            self.socket
1296                .ssh_command("chmod")
1297                .arg(format!("{:o}", server_mode))
1298                .arg(dst_path),
1299        )
1300        .await?;
1301
1302        Ok(())
1303    }
1304
1305    async fn query_platform(&self) -> Result<SshPlatform> {
1306        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1307        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1308
1309        let os = match os.trim() {
1310            "Darwin" => "macos",
1311            "Linux" => "linux",
1312            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1313        };
1314        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1315            "aarch64"
1316        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1317            "x86_64"
1318        } else {
1319            Err(anyhow!("unknown uname architecture {arch:?}"))?
1320        };
1321
1322        Ok(SshPlatform { os, arch })
1323    }
1324
1325    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1326        let mut command = process::Command::new("scp");
1327        let output = self
1328            .socket
1329            .ssh_options(&mut command)
1330            .args(
1331                self.socket
1332                    .connection_options
1333                    .port
1334                    .map(|port| vec!["-P".to_string(), port.to_string()])
1335                    .unwrap_or_default(),
1336            )
1337            .arg(src_path)
1338            .arg(format!(
1339                "{}:{}",
1340                self.socket.connection_options.scp_url(),
1341                dest_path.display()
1342            ))
1343            .output()
1344            .await?;
1345
1346        if output.status.success() {
1347            Ok(())
1348        } else {
1349            Err(anyhow!(
1350                "failed to upload file {} -> {}: {}",
1351                src_path.display(),
1352                dest_path.display(),
1353                String::from_utf8_lossy(&output.stderr)
1354            ))
1355        }
1356    }
1357}
1358
/// Requests awaiting a response, keyed by the request's message id.
///
/// Each sender delivers the response envelope together with a one-shot sender. The
/// incoming-message loop waits for that inner sender to be dropped (or fired) before it
/// processes the next envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1360
/// Protobuf message client that talks to its peer through in-process `mpsc` channels of
/// [`Envelope`]s.
pub struct ChannelClient {
    /// Counter used to assign ids to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Queue of envelopes to be delivered to the peer.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests awaiting a response, guarded by a mutex.
    response_channels: ResponseChannels,
    /// Handlers for incoming envelopes that are not responses, guarded by a mutex.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1367
1368impl ChannelClient {
1369    pub fn new(
1370        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1371        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1372        cx: &AppContext,
1373    ) -> Arc<Self> {
1374        let this = Arc::new(Self {
1375            outgoing_tx,
1376            next_message_id: AtomicU32::new(0),
1377            response_channels: ResponseChannels::default(),
1378            message_handlers: Default::default(),
1379        });
1380
1381        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1382
1383        this
1384    }
1385
1386    fn start_handling_messages(
1387        this: Arc<Self>,
1388        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1389        cx: &AppContext,
1390    ) {
1391        cx.spawn(|cx| {
1392            let this = Arc::downgrade(&this);
1393            async move {
1394                let peer_id = PeerId { owner_id: 0, id: 0 };
1395                while let Some(incoming) = incoming_rx.next().await {
1396                    let Some(this) = this.upgrade() else {
1397                        return anyhow::Ok(());
1398                    };
1399
1400                    if let Some(request_id) = incoming.responding_to {
1401                        let request_id = MessageId(request_id);
1402                        let sender = this.response_channels.lock().remove(&request_id);
1403                        if let Some(sender) = sender {
1404                            let (tx, rx) = oneshot::channel();
1405                            if incoming.payload.is_some() {
1406                                sender.send((incoming, tx)).ok();
1407                            }
1408                            rx.await.ok();
1409                        }
1410                    } else if let Some(envelope) =
1411                        build_typed_envelope(peer_id, Instant::now(), incoming)
1412                    {
1413                        let type_name = envelope.payload_type_name();
1414                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1415                            &this.message_handlers,
1416                            envelope,
1417                            this.clone().into(),
1418                            cx.clone(),
1419                        ) {
1420                            log::debug!("ssh message received. name:{type_name}");
1421                            cx.foreground_executor().spawn(async move {
1422                                match future.await {
1423                                    Ok(_) => {
1424                                        log::debug!("ssh message handled. name:{type_name}");
1425                                    }
1426                                    Err(error) => {
1427                                        log::error!(
1428                                            "error handling message. type:{type_name}, error:{error}",
1429                                        );
1430                                    }
1431                                }
1432                            }).detach();
1433
1434                        } else {
1435                            log::error!("unhandled ssh message name:{type_name}");
1436                        }
1437                    }
1438                }
1439                anyhow::Ok(())
1440            }
1441        })
1442        .detach();
1443    }
1444
1445    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1446        let id = (TypeId::of::<E>(), remote_id);
1447
1448        let mut message_handlers = self.message_handlers.lock();
1449        if message_handlers
1450            .entities_by_type_and_remote_id
1451            .contains_key(&id)
1452        {
1453            panic!("already subscribed to entity");
1454        }
1455
1456        message_handlers.entities_by_type_and_remote_id.insert(
1457            id,
1458            EntityMessageSubscriber::Entity {
1459                handle: entity.downgrade().into(),
1460            },
1461        );
1462    }
1463
1464    pub fn request<T: RequestMessage>(
1465        &self,
1466        payload: T,
1467    ) -> impl 'static + Future<Output = Result<T::Response>> {
1468        log::debug!("ssh request start. name:{}", T::NAME);
1469        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1470        async move {
1471            let response = response.await?;
1472            log::debug!("ssh request finish. name:{}", T::NAME);
1473            T::Response::from_envelope(response)
1474                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1475        }
1476    }
1477
1478    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1479        smol::future::or(
1480            async {
1481                self.request(proto::Ping {}).await?;
1482                Ok(())
1483            },
1484            async {
1485                smol::Timer::after(timeout).await;
1486                Err(anyhow!("Timeout detected"))
1487            },
1488        )
1489        .await
1490    }
1491
1492    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1493        log::debug!("ssh send name:{}", T::NAME);
1494        self.send_dynamic(payload.into_envelope(0, None, None))
1495    }
1496
1497    pub fn request_dynamic(
1498        &self,
1499        mut envelope: proto::Envelope,
1500        type_name: &'static str,
1501    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1502        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1503        let (tx, rx) = oneshot::channel();
1504        let mut response_channels_lock = self.response_channels.lock();
1505        response_channels_lock.insert(MessageId(envelope.id), tx);
1506        drop(response_channels_lock);
1507        let result = self.outgoing_tx.unbounded_send(envelope);
1508        async move {
1509            if let Err(error) = &result {
1510                log::error!("failed to send message: {}", error);
1511                return Err(anyhow!("failed to send message: {}", error));
1512            }
1513
1514            let response = rx.await.context("connection lost")?.0;
1515            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1516                return Err(RpcError::from_proto(error, type_name));
1517            }
1518            Ok(response)
1519        }
1520    }
1521
1522    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1523        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1524        self.outgoing_tx.unbounded_send(envelope)?;
1525        Ok(())
1526    }
1527}
1528
1529impl ProtoClient for ChannelClient {
1530    fn request(
1531        &self,
1532        envelope: proto::Envelope,
1533        request_type: &'static str,
1534    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1535        self.request_dynamic(envelope, request_type).boxed()
1536    }
1537
1538    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1539        self.send_dynamic(envelope)
1540    }
1541
1542    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1543        self.send_dynamic(envelope)
1544    }
1545
1546    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1547        &self.message_handlers
1548    }
1549
1550    fn is_via_collab(&self) -> bool {
1551        false
1552    }
1553}