ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  21    WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  27};
  28use smol::{
  29    fs,
  30    process::{self, Child, Stdio},
  31    Timer,
  32};
  33use std::{
  34    any::TypeId,
  35    ffi::OsStr,
  36    fmt,
  37    ops::ControlFlow,
  38    path::{Path, PathBuf},
  39    sync::{
  40        atomic::{AtomicU32, Ordering::SeqCst},
  41        Arc,
  42    },
  43    time::{Duration, Instant},
  44};
  45use tempfile::TempDir;
  46use util::ResultExt;
  47
  48#[derive(
  49    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  50)]
  51pub struct SshProjectId(pub u64);
  52
  53#[derive(Clone)]
  54pub struct SshSocket {
  55    connection_options: SshConnectionOptions,
  56    socket_path: PathBuf,
  57}
  58
  59#[derive(Debug, Default, Clone, PartialEq, Eq)]
  60pub struct SshConnectionOptions {
  61    pub host: String,
  62    pub username: Option<String>,
  63    pub port: Option<u16>,
  64    pub password: Option<String>,
  65}
  66
  67impl SshConnectionOptions {
  68    pub fn ssh_url(&self) -> String {
  69        let mut result = String::from("ssh://");
  70        if let Some(username) = &self.username {
  71            result.push_str(username);
  72            result.push('@');
  73        }
  74        result.push_str(&self.host);
  75        if let Some(port) = self.port {
  76            result.push(':');
  77            result.push_str(&port.to_string());
  78        }
  79        result
  80    }
  81
  82    fn scp_url(&self) -> String {
  83        if let Some(username) = &self.username {
  84            format!("{}@{}", username, self.host)
  85        } else {
  86            self.host.clone()
  87        }
  88    }
  89
  90    pub fn connection_string(&self) -> String {
  91        let host = if let Some(username) = &self.username {
  92            format!("{}@{}", username, self.host)
  93        } else {
  94            self.host.clone()
  95        };
  96        if let Some(port) = &self.port {
  97            format!("{}:{}", host, port)
  98        } else {
  99            host
 100        }
 101    }
 102
 103    // Uniquely identifies dev server projects on a remote host. Needs to be
 104    // stable for the same dev server project.
 105    pub fn dev_server_identifier(&self) -> String {
 106        let mut identifier = format!("dev-server-{:?}", self.host);
 107        if let Some(username) = self.username.as_ref() {
 108            identifier.push('-');
 109            identifier.push_str(&username);
 110        }
 111        identifier
 112    }
 113}
 114
 115#[derive(Copy, Clone, Debug)]
 116pub struct SshPlatform {
 117    pub os: &'static str,
 118    pub arch: &'static str,
 119}
 120
 121pub trait SshClientDelegate: Send + Sync {
 122    fn ask_password(
 123        &self,
 124        prompt: String,
 125        cx: &mut AsyncAppContext,
 126    ) -> oneshot::Receiver<Result<String>>;
 127    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 128    fn get_server_binary(
 129        &self,
 130        platform: SshPlatform,
 131        cx: &mut AsyncAppContext,
 132    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 133    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 134    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 135}
 136
 137impl SshSocket {
 138    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 139        let mut command = process::Command::new("ssh");
 140        self.ssh_options(&mut command)
 141            .arg(self.connection_options.ssh_url())
 142            .arg(program);
 143        command
 144    }
 145
 146    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 147        command
 148            .stdin(Stdio::piped())
 149            .stdout(Stdio::piped())
 150            .stderr(Stdio::piped())
 151            .args(["-o", "ControlMaster=no", "-o"])
 152            .arg(format!("ControlPath={}", self.socket_path.display()))
 153    }
 154
 155    fn ssh_args(&self) -> Vec<String> {
 156        vec![
 157            "-o".to_string(),
 158            "ControlMaster=no".to_string(),
 159            "-o".to_string(),
 160            format!("ControlPath={}", self.socket_path.display()),
 161            self.connection_options.ssh_url(),
 162        ]
 163    }
 164}
 165
 166async fn run_cmd(command: &mut process::Command) -> Result<String> {
 167    let output = command.output().await?;
 168    if output.status.success() {
 169        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 170    } else {
 171        Err(anyhow!(
 172            "failed to run command: {}",
 173            String::from_utf8_lossy(&output.stderr)
 174        ))
 175    }
 176}
 177
 178struct ChannelForwarder {
 179    quit_tx: UnboundedSender<()>,
 180    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 181}
 182
 183impl ChannelForwarder {
 184    fn new(
 185        mut incoming_tx: UnboundedSender<Envelope>,
 186        mut outgoing_rx: UnboundedReceiver<Envelope>,
 187        cx: &AsyncAppContext,
 188    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 189        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 190
 191        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 192        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 193
 194        let forwarding_task = cx.background_executor().spawn(async move {
 195            loop {
 196                select_biased! {
 197                    _ = quit_rx.next().fuse() => {
 198                        break;
 199                    },
 200                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 201                        if let Some(envelope) = incoming_envelope {
 202                            if incoming_tx.send(envelope).await.is_err() {
 203                                break;
 204                            }
 205                        } else {
 206                            break;
 207                        }
 208                    }
 209                    outgoing_envelope = outgoing_rx.next().fuse() => {
 210                        if let Some(envelope) = outgoing_envelope {
 211                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 212                                break;
 213                            }
 214                        } else {
 215                            break;
 216                        }
 217                    }
 218                }
 219            }
 220
 221            (incoming_tx, outgoing_rx)
 222        });
 223
 224        (
 225            Self {
 226                forwarding_task,
 227                quit_tx,
 228            },
 229            proxy_incoming_tx,
 230            proxy_outgoing_rx,
 231        )
 232    }
 233
 234    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 235        let _ = self.quit_tx.send(()).await;
 236        self.forwarding_task.await
 237    }
 238}
 239
 240const MAX_MISSED_HEARTBEATS: usize = 5;
 241const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 242const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 243
 244const MAX_RECONNECT_ATTEMPTS: usize = 3;
 245
 246enum State {
 247    Connecting,
 248    Connected {
 249        ssh_connection: SshRemoteConnection,
 250        delegate: Arc<dyn SshClientDelegate>,
 251        forwarder: ChannelForwarder,
 252
 253        multiplex_task: Task<Result<()>>,
 254        heartbeat_task: Task<Result<()>>,
 255    },
 256    HeartbeatMissed {
 257        missed_heartbeats: usize,
 258
 259        ssh_connection: SshRemoteConnection,
 260        delegate: Arc<dyn SshClientDelegate>,
 261        forwarder: ChannelForwarder,
 262
 263        multiplex_task: Task<Result<()>>,
 264        heartbeat_task: Task<Result<()>>,
 265    },
 266    Reconnecting,
 267    ReconnectFailed {
 268        ssh_connection: SshRemoteConnection,
 269        delegate: Arc<dyn SshClientDelegate>,
 270        forwarder: ChannelForwarder,
 271
 272        error: anyhow::Error,
 273        attempts: usize,
 274    },
 275    ReconnectExhausted,
 276    ServerNotRunning,
 277}
 278
 279impl fmt::Display for State {
 280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 281        match self {
 282            Self::Connecting => write!(f, "connecting"),
 283            Self::Connected { .. } => write!(f, "connected"),
 284            Self::Reconnecting => write!(f, "reconnecting"),
 285            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 286            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 287            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 288            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 289        }
 290    }
 291}
 292
 293impl State {
 294    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 295        match self {
 296            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 297            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 298            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 299            _ => None,
 300        }
 301    }
 302
 303    fn can_reconnect(&self) -> bool {
 304        match self {
 305            Self::Connected { .. }
 306            | Self::HeartbeatMissed { .. }
 307            | Self::ReconnectFailed { .. } => true,
 308            State::Connecting
 309            | State::Reconnecting
 310            | State::ReconnectExhausted
 311            | State::ServerNotRunning => false,
 312        }
 313    }
 314
 315    fn is_reconnect_failed(&self) -> bool {
 316        matches!(self, Self::ReconnectFailed { .. })
 317    }
 318
 319    fn is_reconnect_exhausted(&self) -> bool {
 320        matches!(self, Self::ReconnectExhausted { .. })
 321    }
 322
 323    fn is_reconnecting(&self) -> bool {
 324        matches!(self, Self::Reconnecting { .. })
 325    }
 326
 327    fn heartbeat_recovered(self) -> Self {
 328        match self {
 329            Self::HeartbeatMissed {
 330                ssh_connection,
 331                delegate,
 332                forwarder,
 333                multiplex_task,
 334                heartbeat_task,
 335                ..
 336            } => Self::Connected {
 337                ssh_connection,
 338                delegate,
 339                forwarder,
 340                multiplex_task,
 341                heartbeat_task,
 342            },
 343            _ => self,
 344        }
 345    }
 346
 347    fn heartbeat_missed(self) -> Self {
 348        match self {
 349            Self::Connected {
 350                ssh_connection,
 351                delegate,
 352                forwarder,
 353                multiplex_task,
 354                heartbeat_task,
 355            } => Self::HeartbeatMissed {
 356                missed_heartbeats: 1,
 357                ssh_connection,
 358                delegate,
 359                forwarder,
 360                multiplex_task,
 361                heartbeat_task,
 362            },
 363            Self::HeartbeatMissed {
 364                missed_heartbeats,
 365                ssh_connection,
 366                delegate,
 367                forwarder,
 368                multiplex_task,
 369                heartbeat_task,
 370            } => Self::HeartbeatMissed {
 371                missed_heartbeats: missed_heartbeats + 1,
 372                ssh_connection,
 373                delegate,
 374                forwarder,
 375                multiplex_task,
 376                heartbeat_task,
 377            },
 378            _ => self,
 379        }
 380    }
 381}
 382
 383/// The state of the ssh connection.
 384#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 385pub enum ConnectionState {
 386    Connecting,
 387    Connected,
 388    HeartbeatMissed,
 389    Reconnecting,
 390    Disconnected,
 391}
 392
 393impl From<&State> for ConnectionState {
 394    fn from(value: &State) -> Self {
 395        match value {
 396            State::Connecting => Self::Connecting,
 397            State::Connected { .. } => Self::Connected,
 398            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 399            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 400            State::ReconnectExhausted => Self::Disconnected,
 401            State::ServerNotRunning => Self::Disconnected,
 402        }
 403    }
 404}
 405
 406pub struct SshRemoteClient {
 407    client: Arc<ChannelClient>,
 408    unique_identifier: String,
 409    connection_options: SshConnectionOptions,
 410    state: Arc<Mutex<Option<State>>>,
 411}
 412
 413impl Drop for SshRemoteClient {
 414    fn drop(&mut self) {
 415        self.shutdown_processes();
 416    }
 417}
 418
 419#[derive(Debug)]
 420pub enum SshRemoteEvent {
 421    Disconnected,
 422}
 423
 424impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 425
 426impl SshRemoteClient {
 427    pub fn new(
 428        unique_identifier: String,
 429        connection_options: SshConnectionOptions,
 430        delegate: Arc<dyn SshClientDelegate>,
 431        cx: &AppContext,
 432    ) -> Task<Result<Model<Self>>> {
 433        cx.spawn(|mut cx| async move {
 434            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 435            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 436
 437            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
 438            let this = cx.new_model(|cx| {
 439                cx.on_app_quit(|this: &mut Self, _| {
 440                    this.shutdown_processes();
 441                    futures::future::ready(())
 442                })
 443                .detach();
 444
 445                Self {
 446                    client: client.clone(),
 447                    unique_identifier: unique_identifier.clone(),
 448                    connection_options: connection_options.clone(),
 449                    state: Arc::new(Mutex::new(Some(State::Connecting))),
 450                }
 451            })?;
 452
 453            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 454                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 455
 456            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 457                unique_identifier,
 458                false,
 459                connection_options,
 460                delegate.clone(),
 461                &mut cx,
 462            )
 463            .await?;
 464
 465            let multiplex_task = Self::multiplex(
 466                this.downgrade(),
 467                ssh_proxy_process,
 468                proxy_incoming_tx,
 469                proxy_outgoing_rx,
 470                &mut cx,
 471            );
 472
 473            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 474                log::error!("failed to establish connection: {}", error);
 475                delegate.set_error(error.to_string(), &mut cx);
 476                return Err(error);
 477            }
 478
 479            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
 480
 481            this.update(&mut cx, |this, _| {
 482                *this.state.lock() = Some(State::Connected {
 483                    ssh_connection,
 484                    delegate,
 485                    forwarder: proxy,
 486                    multiplex_task,
 487                    heartbeat_task,
 488                });
 489            })?;
 490
 491            Ok(this)
 492        })
 493    }
 494
 495    fn shutdown_processes(&self) {
 496        let Some(state) = self.state.lock().take() else {
 497            return;
 498        };
 499        log::info!("shutting down ssh processes");
 500
 501        let State::Connected {
 502            multiplex_task,
 503            heartbeat_task,
 504            ..
 505        } = state
 506        else {
 507            return;
 508        };
 509        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 510        // child of master_process.
 511        drop(multiplex_task);
 512        // Now drop the rest of state, which kills master process.
 513        drop(heartbeat_task);
 514    }
 515
 516    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 517        let mut lock = self.state.lock();
 518
 519        let can_reconnect = lock
 520            .as_ref()
 521            .map(|state| state.can_reconnect())
 522            .unwrap_or(false);
 523        if !can_reconnect {
 524            let error = if let Some(state) = lock.as_ref() {
 525                format!("invalid state, cannot reconnect while in state {state}")
 526            } else {
 527                "no state set".to_string()
 528            };
 529            log::info!("aborting reconnect, because not in state that allows reconnecting");
 530            return Err(anyhow!(error));
 531        }
 532
 533        let state = lock.take().unwrap();
 534        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 535            State::Connected {
 536                ssh_connection,
 537                delegate,
 538                forwarder,
 539                multiplex_task,
 540                heartbeat_task,
 541            }
 542            | State::HeartbeatMissed {
 543                ssh_connection,
 544                delegate,
 545                forwarder,
 546                multiplex_task,
 547                heartbeat_task,
 548                ..
 549            } => {
 550                drop(multiplex_task);
 551                drop(heartbeat_task);
 552                (0, ssh_connection, delegate, forwarder)
 553            }
 554            State::ReconnectFailed {
 555                attempts,
 556                ssh_connection,
 557                delegate,
 558                forwarder,
 559                ..
 560            } => (attempts, ssh_connection, delegate, forwarder),
 561            State::Connecting
 562            | State::Reconnecting
 563            | State::ReconnectExhausted
 564            | State::ServerNotRunning => unreachable!(),
 565        };
 566
 567        let attempts = attempts + 1;
 568        if attempts > MAX_RECONNECT_ATTEMPTS {
 569            log::error!(
 570                "Failed to reconnect to after {} attempts, giving up",
 571                MAX_RECONNECT_ATTEMPTS
 572            );
 573            drop(lock);
 574            self.set_state(State::ReconnectExhausted, cx);
 575            return Ok(());
 576        }
 577        drop(lock);
 578
 579        self.set_state(State::Reconnecting, cx);
 580
 581        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 582
 583        let identifier = self.unique_identifier.clone();
 584        let client = self.client.clone();
 585        let reconnect_task = cx.spawn(|this, mut cx| async move {
 586            macro_rules! failed {
 587                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 588                    return State::ReconnectFailed {
 589                        error: anyhow!($error),
 590                        attempts: $attempts,
 591                        ssh_connection: $ssh_connection,
 592                        delegate: $delegate,
 593                        forwarder: $forwarder,
 594                    };
 595                };
 596            }
 597
 598            if let Err(error) = ssh_connection.master_process.kill() {
 599                failed!(error, attempts, ssh_connection, delegate, forwarder);
 600            };
 601
 602            if let Err(error) = ssh_connection
 603                .master_process
 604                .status()
 605                .await
 606                .context("Failed to kill ssh process")
 607            {
 608                failed!(error, attempts, ssh_connection, delegate, forwarder);
 609            }
 610
 611            let connection_options = ssh_connection.socket.connection_options.clone();
 612
 613            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 614            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 615                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 616
 617            let (ssh_connection, ssh_process) = match Self::establish_connection(
 618                identifier,
 619                true,
 620                connection_options,
 621                delegate.clone(),
 622                &mut cx,
 623            )
 624            .await
 625            {
 626                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 627                Err(error) => {
 628                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 629                }
 630            };
 631
 632            let multiplex_task = Self::multiplex(
 633                this.clone(),
 634                ssh_process,
 635                proxy_incoming_tx,
 636                proxy_outgoing_rx,
 637                &mut cx,
 638            );
 639
 640            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 641                failed!(error, attempts, ssh_connection, delegate, forwarder);
 642            };
 643
 644            State::Connected {
 645                ssh_connection,
 646                delegate,
 647                forwarder,
 648                multiplex_task,
 649                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
 650            }
 651        });
 652
 653        cx.spawn(|this, mut cx| async move {
 654            let new_state = reconnect_task.await;
 655            this.update(&mut cx, |this, cx| {
 656                this.try_set_state(cx, |old_state| {
 657                    if old_state.is_reconnecting() {
 658                        match &new_state {
 659                            State::Connecting
 660                            | State::Reconnecting { .. }
 661                            | State::HeartbeatMissed { .. }
 662                            | State::ServerNotRunning => {}
 663                            State::Connected { .. } => {
 664                                log::info!("Successfully reconnected");
 665                            }
 666                            State::ReconnectFailed {
 667                                error, attempts, ..
 668                            } => {
 669                                log::error!(
 670                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 671                                    attempts,
 672                                    error
 673                                );
 674                            }
 675                            State::ReconnectExhausted => {
 676                                log::error!("Reconnect attempt failed and all attempts exhausted");
 677                            }
 678                        }
 679                        Some(new_state)
 680                    } else {
 681                        None
 682                    }
 683                });
 684
 685                if this.state_is(State::is_reconnect_failed) {
 686                    this.reconnect(cx)
 687                } else if this.state_is(State::is_reconnect_exhausted) {
 688                    cx.emit(SshRemoteEvent::Disconnected);
 689                    Ok(())
 690                } else {
 691                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 692                    Ok(())
 693                }
 694            })
 695        })
 696        .detach_and_log_err(cx);
 697
 698        Ok(())
 699    }
 700
 701    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
 702        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 703            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 704        };
 705        cx.spawn(|mut cx| {
 706            let this = this.clone();
 707            async move {
 708                let mut missed_heartbeats = 0;
 709
 710                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
 711                loop {
 712                    timer.next().await;
 713
 714                    log::debug!("Sending heartbeat to server...");
 715
 716                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
 717                    if result.is_err() {
 718                        missed_heartbeats += 1;
 719                        log::warn!(
 720                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 721                            HEARTBEAT_TIMEOUT,
 722                            missed_heartbeats,
 723                            MAX_MISSED_HEARTBEATS
 724                        );
 725                    } else if missed_heartbeats != 0 {
 726                        missed_heartbeats = 0;
 727                    } else {
 728                        continue;
 729                    }
 730
 731                    let result = this.update(&mut cx, |this, mut cx| {
 732                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 733                    })?;
 734                    if result.is_break() {
 735                        return Ok(());
 736                    }
 737                }
 738            }
 739        })
 740    }
 741
 742    fn handle_heartbeat_result(
 743        &mut self,
 744        missed_heartbeats: usize,
 745        cx: &mut ModelContext<Self>,
 746    ) -> ControlFlow<()> {
 747        let state = self.state.lock().take().unwrap();
 748        let next_state = if missed_heartbeats > 0 {
 749            state.heartbeat_missed()
 750        } else {
 751            state.heartbeat_recovered()
 752        };
 753
 754        self.set_state(next_state, cx);
 755
 756        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 757            log::error!(
 758                "Missed last {} heartbeats. Reconnecting...",
 759                missed_heartbeats
 760            );
 761
 762            self.reconnect(cx)
 763                .context("failed to start reconnect process after missing heartbeats")
 764                .log_err();
 765            ControlFlow::Break(())
 766        } else {
 767            ControlFlow::Continue(())
 768        }
 769    }
 770
 771    fn multiplex(
 772        this: WeakModel<Self>,
 773        mut ssh_proxy_process: Child,
 774        incoming_tx: UnboundedSender<Envelope>,
 775        mut outgoing_rx: UnboundedReceiver<Envelope>,
 776        cx: &AsyncAppContext,
 777    ) -> Task<Result<()>> {
 778        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 779        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 780        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 781
 782        let io_task = cx.background_executor().spawn(async move {
 783            let mut stdin_buffer = Vec::new();
 784            let mut stdout_buffer = Vec::new();
 785            let mut stderr_buffer = Vec::new();
 786            let mut stderr_offset = 0;
 787
 788            loop {
 789                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 790                stderr_buffer.resize(stderr_offset + 1024, 0);
 791
 792                select_biased! {
 793                    outgoing = outgoing_rx.next().fuse() => {
 794                        let Some(outgoing) = outgoing else {
 795                            return anyhow::Ok(None);
 796                        };
 797
 798                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 799                    }
 800
 801                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 802                        match result {
 803                            Ok(0) => {
 804                                child_stdin.close().await?;
 805                                outgoing_rx.close();
 806                                let status = ssh_proxy_process.status().await?;
 807                                return Ok(status.code());
 808                            }
 809                            Ok(len) => {
 810                                if len < stdout_buffer.len() {
 811                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 812                                }
 813
 814                                let message_len = message_len_from_buffer(&stdout_buffer);
 815                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 816                                    Ok(envelope) => {
 817                                        incoming_tx.unbounded_send(envelope).ok();
 818                                    }
 819                                    Err(error) => {
 820                                        log::error!("error decoding message {error:?}");
 821                                    }
 822                                }
 823                            }
 824                            Err(error) => {
 825                                Err(anyhow!("error reading stdout: {error:?}"))?;
 826                            }
 827                        }
 828                    }
 829
 830                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 831                        match result {
 832                            Ok(len) => {
 833                                stderr_offset += len;
 834                                let mut start_ix = 0;
 835                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 836                                    let line_ix = start_ix + ix;
 837                                    let content = &stderr_buffer[start_ix..line_ix];
 838                                    start_ix = line_ix + 1;
 839                                    if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
 840                                        record.message = format!("(remote) {}", record.message);
 841                                        record.log(log::logger())
 842                                    } else {
 843                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 844                                    }
 845                                }
 846                                stderr_buffer.drain(0..start_ix);
 847                                stderr_offset -= start_ix;
 848                            }
 849                            Err(error) => {
 850                                Err(anyhow!("error reading stderr: {error:?}"))?;
 851                            }
 852                        }
 853                    }
 854                }
 855            }
 856        });
 857
 858        cx.spawn(|mut cx| async move {
 859            let result = io_task.await;
 860
 861            match result {
 862                Ok(Some(exit_code)) => {
 863                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 864                        match error {
 865                            ProxyLaunchError::ServerNotRunning => {
 866                                log::error!("failed to reconnect because server is not running");
 867                                this.update(&mut cx, |this, cx| {
 868                                    this.set_state(State::ServerNotRunning, cx);
 869                                    cx.emit(SshRemoteEvent::Disconnected);
 870                                })?;
 871                            }
 872                        }
 873                    } else if exit_code > 0 {
 874                        log::error!("proxy process terminated unexpectedly");
 875                        this.update(&mut cx, |this, cx| {
 876                            this.reconnect(cx).ok();
 877                        })?;
 878                    }
 879                }
 880                Ok(None) => {}
 881                Err(error) => {
 882                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 883                    this.update(&mut cx, |this, cx| {
 884                        this.reconnect(cx).ok();
 885                    })?;
 886                }
 887            }
 888            Ok(())
 889        })
 890    }
 891
 892    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 893        self.state.lock().as_ref().map_or(false, check)
 894    }
 895
 896    fn try_set_state(
 897        &self,
 898        cx: &mut ModelContext<Self>,
 899        map: impl FnOnce(&State) -> Option<State>,
 900    ) {
 901        let mut lock = self.state.lock();
 902        let new_state = lock.as_ref().and_then(map);
 903
 904        if let Some(new_state) = new_state {
 905            lock.replace(new_state);
 906            cx.notify();
 907        }
 908    }
 909
 910    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 911        log::info!("setting state to '{}'", &state);
 912        self.state.lock().replace(state);
 913        cx.notify();
 914    }
 915
 916    async fn establish_connection(
 917        unique_identifier: String,
 918        reconnect: bool,
 919        connection_options: SshConnectionOptions,
 920        delegate: Arc<dyn SshClientDelegate>,
 921        cx: &mut AsyncAppContext,
 922    ) -> Result<(SshRemoteConnection, Child)> {
 923        let ssh_connection =
 924            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 925
 926        let platform = ssh_connection.query_platform().await?;
 927        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 928        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 929        ssh_connection
 930            .ensure_server_binary(
 931                &delegate,
 932                &local_binary_path,
 933                &remote_binary_path,
 934                version,
 935                cx,
 936            )
 937            .await?;
 938
 939        let socket = ssh_connection.socket.clone();
 940        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 941
 942        delegate.set_status(Some("Starting proxy"), cx);
 943
 944        let mut start_proxy_command = format!(
 945            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 946            std::env::var("RUST_LOG").unwrap_or_default(),
 947            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 948            remote_binary_path,
 949            unique_identifier,
 950        );
 951        if reconnect {
 952            start_proxy_command.push_str(" --reconnect");
 953        }
 954
 955        let ssh_proxy_process = socket
 956            .ssh_command(start_proxy_command)
 957            // IMPORTANT: we kill this process when we drop the task that uses it.
 958            .kill_on_drop(true)
 959            .spawn()
 960            .context("failed to spawn remote server")?;
 961
 962        Ok((ssh_connection, ssh_proxy_process))
 963    }
 964
 965    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 966        self.client.subscribe_to_entity(remote_id, entity);
 967    }
 968
 969    pub fn ssh_args(&self) -> Option<Vec<String>> {
 970        self.state
 971            .lock()
 972            .as_ref()
 973            .and_then(|state| state.ssh_connection())
 974            .map(|ssh_connection| ssh_connection.socket.ssh_args())
 975    }
 976
 977    pub fn to_proto_client(&self) -> AnyProtoClient {
 978        self.client.clone().into()
 979    }
 980
 981    pub fn connection_string(&self) -> String {
 982        self.connection_options.connection_string()
 983    }
 984
 985    pub fn connection_options(&self) -> SshConnectionOptions {
 986        self.connection_options.clone()
 987    }
 988
 989    #[cfg(not(any(test, feature = "test-support")))]
 990    pub fn connection_state(&self) -> ConnectionState {
 991        self.state
 992            .lock()
 993            .as_ref()
 994            .map(ConnectionState::from)
 995            .unwrap_or(ConnectionState::Disconnected)
 996    }
 997
 998    #[cfg(any(test, feature = "test-support"))]
 999    pub fn connection_state(&self) -> ConnectionState {
1000        ConnectionState::Connected
1001    }
1002
1003    pub fn is_disconnected(&self) -> bool {
1004        self.connection_state() == ConnectionState::Disconnected
1005    }
1006
1007    #[cfg(any(test, feature = "test-support"))]
1008    pub fn fake(
1009        client_cx: &mut gpui::TestAppContext,
1010        server_cx: &mut gpui::TestAppContext,
1011    ) -> (Model<Self>, Arc<ChannelClient>) {
1012        use gpui::Context;
1013
1014        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1015        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1016
1017        (
1018            client_cx.update(|cx| {
1019                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1020                cx.new_model(|_| Self {
1021                    client,
1022                    unique_identifier: "fake".to_string(),
1023                    connection_options: SshConnectionOptions::default(),
1024                    state: Arc::new(Mutex::new(None)),
1025                })
1026            }),
1027            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1028        )
1029    }
1030}
1031
1032impl From<SshRemoteClient> for AnyProtoClient {
1033    fn from(client: SshRemoteClient) -> Self {
1034        AnyProtoClient::new(client.client.clone())
1035    }
1036}
1037
1038struct SshRemoteConnection {
1039    socket: SshSocket,
1040    master_process: process::Child,
1041    _temp_dir: TempDir,
1042}
1043
1044impl Drop for SshRemoteConnection {
1045    fn drop(&mut self) {
1046        if let Err(error) = self.master_process.kill() {
1047            log::error!("failed to kill SSH master process: {}", error);
1048        }
1049    }
1050}
1051
1052impl SshRemoteConnection {
1053    #[cfg(not(unix))]
1054    async fn new(
1055        _connection_options: SshConnectionOptions,
1056        _delegate: Arc<dyn SshClientDelegate>,
1057        _cx: &mut AsyncAppContext,
1058    ) -> Result<Self> {
1059        Err(anyhow!("ssh is not supported on this platform"))
1060    }
1061
1062    #[cfg(unix)]
1063    async fn new(
1064        connection_options: SshConnectionOptions,
1065        delegate: Arc<dyn SshClientDelegate>,
1066        cx: &mut AsyncAppContext,
1067    ) -> Result<Self> {
1068        use futures::{io::BufReader, AsyncBufReadExt as _};
1069        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1070        use util::ResultExt as _;
1071
1072        delegate.set_status(Some("connecting"), cx);
1073
1074        let url = connection_options.ssh_url();
1075        let temp_dir = tempfile::Builder::new()
1076            .prefix("zed-ssh-session")
1077            .tempdir()?;
1078
1079        // Create a domain socket listener to handle requests from the askpass program.
1080        let askpass_socket = temp_dir.path().join("askpass.sock");
1081        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1082        let listener =
1083            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1084
1085        let askpass_task = cx.spawn({
1086            let delegate = delegate.clone();
1087            |mut cx| async move {
1088                let mut askpass_opened_tx = Some(askpass_opened_tx);
1089
1090                while let Ok((mut stream, _)) = listener.accept().await {
1091                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1092                        askpass_opened_tx.send(()).ok();
1093                    }
1094                    let mut buffer = Vec::new();
1095                    let mut reader = BufReader::new(&mut stream);
1096                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1097                        buffer.clear();
1098                    }
1099                    let password_prompt = String::from_utf8_lossy(&buffer);
1100                    if let Some(password) = delegate
1101                        .ask_password(password_prompt.to_string(), &mut cx)
1102                        .await
1103                        .context("failed to get ssh password")
1104                        .and_then(|p| p)
1105                        .log_err()
1106                    {
1107                        stream.write_all(password.as_bytes()).await.log_err();
1108                    }
1109                }
1110            }
1111        });
1112
1113        // Create an askpass script that communicates back to this process.
1114        let askpass_script = format!(
1115            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1116            askpass_socket = askpass_socket.display(),
1117            print_args = "printf '%s\\0' \"$@\"",
1118            shebang = "#!/bin/sh",
1119        );
1120        let askpass_script_path = temp_dir.path().join("askpass.sh");
1121        fs::write(&askpass_script_path, askpass_script).await?;
1122        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1123
1124        // Start the master SSH process, which does not do anything except for establish
1125        // the connection and keep it open, allowing other ssh commands to reuse it
1126        // via a control socket.
1127        let socket_path = temp_dir.path().join("ssh.sock");
1128        let mut master_process = process::Command::new("ssh")
1129            .stdin(Stdio::null())
1130            .stdout(Stdio::piped())
1131            .stderr(Stdio::piped())
1132            .env("SSH_ASKPASS_REQUIRE", "force")
1133            .env("SSH_ASKPASS", &askpass_script_path)
1134            .args(["-N", "-o", "ControlMaster=yes", "-o"])
1135            .arg(format!("ControlPath={}", socket_path.display()))
1136            .arg(&url)
1137            .spawn()?;
1138
1139        // Wait for this ssh process to close its stdout, indicating that authentication
1140        // has completed.
1141        let stdout = master_process.stdout.as_mut().unwrap();
1142        let mut output = Vec::new();
1143        let connection_timeout = Duration::from_secs(10);
1144
1145        let result = select_biased! {
1146            _ = askpass_opened_rx.fuse() => {
1147                // If the askpass script has opened, that means the user is typing
1148                // their password, in which case we don't want to timeout anymore,
1149                // since we know a connection has been established.
1150                stdout.read_to_end(&mut output).await?;
1151                Ok(())
1152            }
1153            result = stdout.read_to_end(&mut output).fuse() => {
1154                result?;
1155                Ok(())
1156            }
1157            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1158                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1159            }
1160        };
1161
1162        if let Err(e) = result {
1163            let error_message = format!("Failed to connect to host: {}.", e);
1164            delegate.set_error(error_message, cx);
1165            return Err(e);
1166        }
1167
1168        drop(askpass_task);
1169
1170        if master_process.try_status()?.is_some() {
1171            output.clear();
1172            let mut stderr = master_process.stderr.take().unwrap();
1173            stderr.read_to_end(&mut output).await?;
1174
1175            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1176            delegate.set_error(error_message.clone(), cx);
1177            Err(anyhow!(error_message))?;
1178        }
1179
1180        Ok(Self {
1181            socket: SshSocket {
1182                connection_options,
1183                socket_path,
1184            },
1185            master_process,
1186            _temp_dir: temp_dir,
1187        })
1188    }
1189
1190    async fn ensure_server_binary(
1191        &self,
1192        delegate: &Arc<dyn SshClientDelegate>,
1193        src_path: &Path,
1194        dst_path: &Path,
1195        version: SemanticVersion,
1196        cx: &mut AsyncAppContext,
1197    ) -> Result<()> {
1198        let mut dst_path_gz = dst_path.to_path_buf();
1199        dst_path_gz.set_extension("gz");
1200
1201        if let Some(parent) = dst_path.parent() {
1202            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1203        }
1204
1205        let mut server_binary_exists = false;
1206        if cfg!(not(debug_assertions)) {
1207            if let Ok(installed_version) =
1208                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1209            {
1210                if installed_version.trim() == version.to_string() {
1211                    server_binary_exists = true;
1212                }
1213            }
1214        }
1215
1216        if server_binary_exists {
1217            log::info!("remote development server already present",);
1218            return Ok(());
1219        }
1220
1221        let src_stat = fs::metadata(src_path).await?;
1222        let size = src_stat.len();
1223        let server_mode = 0o755;
1224
1225        let t0 = Instant::now();
1226        delegate.set_status(Some("uploading remote development server"), cx);
1227        log::info!("uploading remote development server ({}kb)", size / 1024);
1228        self.upload_file(src_path, &dst_path_gz)
1229            .await
1230            .context("failed to upload server binary")?;
1231        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1232
1233        delegate.set_status(Some("extracting remote development server"), cx);
1234        run_cmd(
1235            self.socket
1236                .ssh_command("gunzip")
1237                .arg("--force")
1238                .arg(&dst_path_gz),
1239        )
1240        .await?;
1241
1242        delegate.set_status(Some("unzipping remote development server"), cx);
1243        run_cmd(
1244            self.socket
1245                .ssh_command("chmod")
1246                .arg(format!("{:o}", server_mode))
1247                .arg(dst_path),
1248        )
1249        .await?;
1250
1251        Ok(())
1252    }
1253
1254    async fn query_platform(&self) -> Result<SshPlatform> {
1255        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1256        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1257
1258        let os = match os.trim() {
1259            "Darwin" => "macos",
1260            "Linux" => "linux",
1261            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1262        };
1263        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1264            "aarch64"
1265        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1266            "x86_64"
1267        } else {
1268            Err(anyhow!("unknown uname architecture {arch:?}"))?
1269        };
1270
1271        Ok(SshPlatform { os, arch })
1272    }
1273
1274    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1275        let mut command = process::Command::new("scp");
1276        let output = self
1277            .socket
1278            .ssh_options(&mut command)
1279            .args(
1280                self.socket
1281                    .connection_options
1282                    .port
1283                    .map(|port| vec!["-P".to_string(), port.to_string()])
1284                    .unwrap_or_default(),
1285            )
1286            .arg(src_path)
1287            .arg(format!(
1288                "{}:{}",
1289                self.socket.connection_options.scp_url(),
1290                dest_path.display()
1291            ))
1292            .output()
1293            .await?;
1294
1295        if output.status.success() {
1296            Ok(())
1297        } else {
1298            Err(anyhow!(
1299                "failed to upload file {} -> {}: {}",
1300                src_path.display(),
1301                dest_path.display(),
1302                String::from_utf8_lossy(&output.stderr)
1303            ))
1304        }
1305    }
1306}
1307
1308type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1309
1310pub struct ChannelClient {
1311    next_message_id: AtomicU32,
1312    outgoing_tx: mpsc::UnboundedSender<Envelope>,
1313    response_channels: ResponseChannels,             // Lock
1314    message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
1315}
1316
1317impl ChannelClient {
1318    pub fn new(
1319        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1320        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1321        cx: &AppContext,
1322    ) -> Arc<Self> {
1323        let this = Arc::new(Self {
1324            outgoing_tx,
1325            next_message_id: AtomicU32::new(0),
1326            response_channels: ResponseChannels::default(),
1327            message_handlers: Default::default(),
1328        });
1329
1330        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1331
1332        this
1333    }
1334
1335    fn start_handling_messages(
1336        this: Arc<Self>,
1337        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1338        cx: &AppContext,
1339    ) {
1340        cx.spawn(|cx| {
1341            let this = Arc::downgrade(&this);
1342            async move {
1343                let peer_id = PeerId { owner_id: 0, id: 0 };
1344                while let Some(incoming) = incoming_rx.next().await {
1345                    let Some(this) = this.upgrade() else {
1346                        return anyhow::Ok(());
1347                    };
1348
1349                    if let Some(request_id) = incoming.responding_to {
1350                        let request_id = MessageId(request_id);
1351                        let sender = this.response_channels.lock().remove(&request_id);
1352                        if let Some(sender) = sender {
1353                            let (tx, rx) = oneshot::channel();
1354                            if incoming.payload.is_some() {
1355                                sender.send((incoming, tx)).ok();
1356                            }
1357                            rx.await.ok();
1358                        }
1359                    } else if let Some(envelope) =
1360                        build_typed_envelope(peer_id, Instant::now(), incoming)
1361                    {
1362                        let type_name = envelope.payload_type_name();
1363                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1364                            &this.message_handlers,
1365                            envelope,
1366                            this.clone().into(),
1367                            cx.clone(),
1368                        ) {
1369                            log::debug!("ssh message received. name:{type_name}");
1370                            match future.await {
1371                                Ok(_) => {
1372                                    log::debug!("ssh message handled. name:{type_name}");
1373                                }
1374                                Err(error) => {
1375                                    log::error!(
1376                                        "error handling message. type:{type_name}, error:{error}",
1377                                    );
1378                                }
1379                            }
1380                        } else {
1381                            log::error!("unhandled ssh message name:{type_name}");
1382                        }
1383                    }
1384                }
1385                anyhow::Ok(())
1386            }
1387        })
1388        .detach();
1389    }
1390
1391    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1392        let id = (TypeId::of::<E>(), remote_id);
1393
1394        let mut message_handlers = self.message_handlers.lock();
1395        if message_handlers
1396            .entities_by_type_and_remote_id
1397            .contains_key(&id)
1398        {
1399            panic!("already subscribed to entity");
1400        }
1401
1402        message_handlers.entities_by_type_and_remote_id.insert(
1403            id,
1404            EntityMessageSubscriber::Entity {
1405                handle: entity.downgrade().into(),
1406            },
1407        );
1408    }
1409
1410    pub fn request<T: RequestMessage>(
1411        &self,
1412        payload: T,
1413    ) -> impl 'static + Future<Output = Result<T::Response>> {
1414        log::debug!("ssh request start. name:{}", T::NAME);
1415        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1416        async move {
1417            let response = response.await?;
1418            log::debug!("ssh request finish. name:{}", T::NAME);
1419            T::Response::from_envelope(response)
1420                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1421        }
1422    }
1423
1424    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1425        smol::future::or(
1426            async {
1427                self.request(proto::Ping {}).await?;
1428                Ok(())
1429            },
1430            async {
1431                smol::Timer::after(timeout).await;
1432                Err(anyhow!("Timeout detected"))
1433            },
1434        )
1435        .await
1436    }
1437
1438    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1439        log::debug!("ssh send name:{}", T::NAME);
1440        self.send_dynamic(payload.into_envelope(0, None, None))
1441    }
1442
1443    pub fn request_dynamic(
1444        &self,
1445        mut envelope: proto::Envelope,
1446        type_name: &'static str,
1447    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1448        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1449        let (tx, rx) = oneshot::channel();
1450        let mut response_channels_lock = self.response_channels.lock();
1451        response_channels_lock.insert(MessageId(envelope.id), tx);
1452        drop(response_channels_lock);
1453        let result = self.outgoing_tx.unbounded_send(envelope);
1454        async move {
1455            if let Err(error) = &result {
1456                log::error!("failed to send message: {}", error);
1457                return Err(anyhow!("failed to send message: {}", error));
1458            }
1459
1460            let response = rx.await.context("connection lost")?.0;
1461            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1462                return Err(RpcError::from_proto(error, type_name));
1463            }
1464            Ok(response)
1465        }
1466    }
1467
1468    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1469        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1470        self.outgoing_tx.unbounded_send(envelope)?;
1471        Ok(())
1472    }
1473}
1474
1475impl ProtoClient for ChannelClient {
1476    fn request(
1477        &self,
1478        envelope: proto::Envelope,
1479        request_type: &'static str,
1480    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1481        self.request_dynamic(envelope, request_type).boxed()
1482    }
1483
1484    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1485        self.send_dynamic(envelope)
1486    }
1487
1488    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1489        self.send_dynamic(envelope)
1490    }
1491
1492    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1493        &self.message_handlers
1494    }
1495
1496    fn is_via_collab(&self) -> bool {
1497        false
1498    }
1499}