ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  21    WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  27};
  28use smol::{
  29    fs,
  30    process::{self, Child, Stdio},
  31    Timer,
  32};
  33use std::{
  34    any::TypeId,
  35    ffi::OsStr,
  36    fmt,
  37    ops::ControlFlow,
  38    path::{Path, PathBuf},
  39    sync::{
  40        atomic::{AtomicU32, Ordering::SeqCst},
  41        Arc,
  42    },
  43    time::{Duration, Instant},
  44};
  45use tempfile::TempDir;
  46use util::ResultExt;
  47
  48#[derive(
  49    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  50)]
  51pub struct SshProjectId(pub u64);
  52
  53#[derive(Clone)]
  54pub struct SshSocket {
  55    connection_options: SshConnectionOptions,
  56    socket_path: PathBuf,
  57}
  58
  59#[derive(Debug, Default, Clone, PartialEq, Eq)]
  60pub struct SshConnectionOptions {
  61    pub host: String,
  62    pub username: Option<String>,
  63    pub port: Option<u16>,
  64    pub password: Option<String>,
  65}
  66
  67impl SshConnectionOptions {
  68    pub fn ssh_url(&self) -> String {
  69        let mut result = String::from("ssh://");
  70        if let Some(username) = &self.username {
  71            result.push_str(username);
  72            result.push('@');
  73        }
  74        result.push_str(&self.host);
  75        if let Some(port) = self.port {
  76            result.push(':');
  77            result.push_str(&port.to_string());
  78        }
  79        result
  80    }
  81
  82    fn scp_url(&self) -> String {
  83        if let Some(username) = &self.username {
  84            format!("{}@{}", username, self.host)
  85        } else {
  86            self.host.clone()
  87        }
  88    }
  89
  90    pub fn connection_string(&self) -> String {
  91        let host = if let Some(username) = &self.username {
  92            format!("{}@{}", username, self.host)
  93        } else {
  94            self.host.clone()
  95        };
  96        if let Some(port) = &self.port {
  97            format!("{}:{}", host, port)
  98        } else {
  99            host
 100        }
 101    }
 102
 103    // Uniquely identifies dev server projects on a remote host. Needs to be
 104    // stable for the same dev server project.
 105    pub fn dev_server_identifier(&self) -> String {
 106        let mut identifier = format!("dev-server-{:?}", self.host);
 107        if let Some(username) = self.username.as_ref() {
 108            identifier.push('-');
 109            identifier.push_str(&username);
 110        }
 111        identifier
 112    }
 113}
 114
 115#[derive(Copy, Clone, Debug)]
 116pub struct SshPlatform {
 117    pub os: &'static str,
 118    pub arch: &'static str,
 119}
 120
/// Callbacks through which the SSH client requests passwords, obtains the server binary, and
/// reports status and errors.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password in response to `prompt`; the answer (or an error) is delivered on
    /// the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path to use for the remote server binary.
    /// NOTE(review): presumably a path on the remote host — confirm with implementors.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Obtains a server binary for `platform`; the receiver yields the binary's path together
    /// with its version.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports a status message. NOTE(review): presumably `None` clears it — confirm.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message (for example when the initial ping after connecting fails).
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
 136
 137impl SshSocket {
 138    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 139        let mut command = process::Command::new("ssh");
 140        self.ssh_options(&mut command)
 141            .arg(self.connection_options.ssh_url())
 142            .arg(program);
 143        command
 144    }
 145
 146    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 147        command
 148            .stdin(Stdio::piped())
 149            .stdout(Stdio::piped())
 150            .stderr(Stdio::piped())
 151            .args(["-o", "ControlMaster=no", "-o"])
 152            .arg(format!("ControlPath={}", self.socket_path.display()))
 153    }
 154
 155    fn ssh_args(&self) -> Vec<String> {
 156        vec![
 157            "-o".to_string(),
 158            "ControlMaster=no".to_string(),
 159            "-o".to_string(),
 160            format!("ControlPath={}", self.socket_path.display()),
 161            self.connection_options.ssh_url(),
 162        ]
 163    }
 164}
 165
 166async fn run_cmd(command: &mut process::Command) -> Result<String> {
 167    let output = command.output().await?;
 168    if output.status.success() {
 169        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 170    } else {
 171        Err(anyhow!(
 172            "failed to run command: {}",
 173            String::from_utf8_lossy(&output.stderr)
 174        ))
 175    }
 176}
 177
 178struct ChannelForwarder {
 179    quit_tx: UnboundedSender<()>,
 180    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 181}
 182
 183impl ChannelForwarder {
 184    fn new(
 185        mut incoming_tx: UnboundedSender<Envelope>,
 186        mut outgoing_rx: UnboundedReceiver<Envelope>,
 187        cx: &AsyncAppContext,
 188    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 189        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 190
 191        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 192        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 193
 194        let forwarding_task = cx.background_executor().spawn(async move {
 195            loop {
 196                select_biased! {
 197                    _ = quit_rx.next().fuse() => {
 198                        break;
 199                    },
 200                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 201                        if let Some(envelope) = incoming_envelope {
 202                            if incoming_tx.send(envelope).await.is_err() {
 203                                break;
 204                            }
 205                        } else {
 206                            break;
 207                        }
 208                    }
 209                    outgoing_envelope = outgoing_rx.next().fuse() => {
 210                        if let Some(envelope) = outgoing_envelope {
 211                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 212                                break;
 213                            }
 214                        } else {
 215                            break;
 216                        }
 217                    }
 218                }
 219            }
 220
 221            (incoming_tx, outgoing_rx)
 222        });
 223
 224        (
 225            Self {
 226                forwarding_task,
 227                quit_tx,
 228            },
 229            proxy_incoming_tx,
 230            proxy_outgoing_rx,
 231        )
 232    }
 233
 234    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 235        let _ = self.quit_tx.send(()).await;
 236        self.forwarding_task.await
 237    }
 238}
 239
/// Consecutive failed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Period between heartbeat pings.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// How long to wait for a ping reply (also used for the ping that verifies a new connection).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts allowed before the client gives up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 245
/// Internal connection state of an `SshRemoteClient`.
enum State {
    /// Set when the client is created, before the first connection has been verified by a ping.
    Connecting,
    /// The connection is up: `multiplex_task` pumps messages to and from the ssh proxy process
    /// and `heartbeat_task` periodically pings the server.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Same resources as `Connected`, but the most recent `missed_heartbeats` heartbeats failed
    /// in a row.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is running in the background.
    Reconnecting,
    /// The last reconnect attempt failed with `error`; `attempts` counts attempts made so far.
    /// The previous connection and forwarder are kept so the next attempt can take them over.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Reconnecting was abandoned after `MAX_RECONNECT_ATTEMPTS` attempts.
    ReconnectExhausted,
    /// The proxy exited with `ProxyLaunchError::ServerNotRunning`.
    ServerNotRunning,
}
 278
 279impl fmt::Display for State {
 280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 281        match self {
 282            Self::Connecting => write!(f, "connecting"),
 283            Self::Connected { .. } => write!(f, "connected"),
 284            Self::Reconnecting => write!(f, "reconnecting"),
 285            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 286            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 287            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 288            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 289        }
 290    }
 291}
 292
 293impl State {
 294    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 295        match self {
 296            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 297            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 298            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 299            _ => None,
 300        }
 301    }
 302
 303    fn can_reconnect(&self) -> bool {
 304        match self {
 305            Self::Connected { .. }
 306            | Self::HeartbeatMissed { .. }
 307            | Self::ReconnectFailed { .. } => true,
 308            State::Connecting
 309            | State::Reconnecting
 310            | State::ReconnectExhausted
 311            | State::ServerNotRunning => false,
 312        }
 313    }
 314
 315    fn is_reconnect_failed(&self) -> bool {
 316        matches!(self, Self::ReconnectFailed { .. })
 317    }
 318
 319    fn is_reconnect_exhausted(&self) -> bool {
 320        matches!(self, Self::ReconnectExhausted { .. })
 321    }
 322
 323    fn is_reconnecting(&self) -> bool {
 324        matches!(self, Self::Reconnecting { .. })
 325    }
 326
 327    fn heartbeat_recovered(self) -> Self {
 328        match self {
 329            Self::HeartbeatMissed {
 330                ssh_connection,
 331                delegate,
 332                forwarder,
 333                multiplex_task,
 334                heartbeat_task,
 335                ..
 336            } => Self::Connected {
 337                ssh_connection,
 338                delegate,
 339                forwarder,
 340                multiplex_task,
 341                heartbeat_task,
 342            },
 343            _ => self,
 344        }
 345    }
 346
 347    fn heartbeat_missed(self) -> Self {
 348        match self {
 349            Self::Connected {
 350                ssh_connection,
 351                delegate,
 352                forwarder,
 353                multiplex_task,
 354                heartbeat_task,
 355            } => Self::HeartbeatMissed {
 356                missed_heartbeats: 1,
 357                ssh_connection,
 358                delegate,
 359                forwarder,
 360                multiplex_task,
 361                heartbeat_task,
 362            },
 363            Self::HeartbeatMissed {
 364                missed_heartbeats,
 365                ssh_connection,
 366                delegate,
 367                forwarder,
 368                multiplex_task,
 369                heartbeat_task,
 370            } => Self::HeartbeatMissed {
 371                missed_heartbeats: missed_heartbeats + 1,
 372                ssh_connection,
 373                delegate,
 374                forwarder,
 375                multiplex_task,
 376                heartbeat_task,
 377            },
 378            _ => self,
 379        }
 380    }
 381}
 382
 383/// The state of the ssh connection.
 384#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 385pub enum ConnectionState {
 386    Connecting,
 387    Connected,
 388    HeartbeatMissed,
 389    Reconnecting,
 390    Disconnected,
 391}
 392
 393impl From<&State> for ConnectionState {
 394    fn from(value: &State) -> Self {
 395        match value {
 396            State::Connecting => Self::Connecting,
 397            State::Connected { .. } => Self::Connected,
 398            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 399            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 400            State::ReconnectExhausted => Self::Disconnected,
 401            State::ServerNotRunning => Self::Disconnected,
 402        }
 403    }
 404}
 405
/// Client side of an SSH connection to a remote server, with heartbeat monitoring and
/// automatic reconnection.
pub struct SshRemoteClient {
    /// RPC client used to talk to the server. It is kept across reconnects; only the
    /// forwarder relaying its channels is replaced.
    client: Arc<ChannelClient>,
    /// Identifier passed to `establish_connection`.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}

impl Drop for SshRemoteClient {
    /// Tears down the ssh processes when the client is dropped.
    fn drop(&mut self) {
        self.shutdown_processes();
    }
}
 418
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The client stopped trying to stay connected: reconnect attempts were exhausted, or the
    /// proxy reported that the server is not running.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 425
 426impl SshRemoteClient {
    /// Connects to the remote host and returns the client model once the server answers a ping.
    ///
    /// The model starts in `State::Connecting`. After the ssh connection is established,
    /// multiplexing is started and the server is pinged (up to `HEARTBEAT_TIMEOUT`). Only then
    /// is the heartbeat started and the state switched to `State::Connected`. If establishing
    /// the connection or the ping fails, the error is returned (a ping failure is also reported
    /// through `delegate.set_error`).
    pub fn new(
        unique_identifier: String,
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &AppContext,
    ) -> Task<Result<Model<Self>>> {
        cx.spawn(|mut cx| async move {
            // Client-side channel pair: `ChannelClient` gets `incoming_rx` and `outgoing_tx`;
            // the other ends are handed to the forwarder below.
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();

            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
            let this = cx.new_model(|cx| {
                // Make sure the ssh processes are torn down when the app quits.
                cx.on_app_quit(|this: &mut Self, _| {
                    this.shutdown_processes();
                    futures::future::ready(())
                })
                .detach();

                Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                }
            })?;

            // The forwarder sits between the client's channels and the proxy process so that
            // the proxy side can be swapped out on reconnect.
            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

            // NOTE(review): `false` is presumably the "is reconnect" flag (`reconnect` passes
            // `true`) — confirm against `establish_connection`.
            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
                unique_identifier,
                false,
                connection_options,
                delegate.clone(),
                &mut cx,
            )
            .await?;

            let multiplex_task = Self::multiplex(
                this.downgrade(),
                ssh_proxy_process,
                proxy_incoming_tx,
                proxy_outgoing_rx,
                &mut cx,
            );

            // Verify the server actually responds before declaring the connection usable.
            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                log::error!("failed to establish connection: {}", error);
                delegate.set_error(error.to_string(), &mut cx);
                return Err(error);
            }

            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);

            this.update(&mut cx, |this, _| {
                *this.state.lock() = Some(State::Connected {
                    ssh_connection,
                    delegate,
                    forwarder: proxy,
                    multiplex_task,
                    heartbeat_task,
                });
            })?;

            Ok(this)
        })
    }
 494
 495    fn shutdown_processes(&self) {
 496        let Some(state) = self.state.lock().take() else {
 497            return;
 498        };
 499        log::info!("shutting down ssh processes");
 500
 501        let State::Connected {
 502            multiplex_task,
 503            heartbeat_task,
 504            ..
 505        } = state
 506        else {
 507            return;
 508        };
 509        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 510        // child of master_process.
 511        drop(multiplex_task);
 512        // Now drop the rest of state, which kills master process.
 513        drop(heartbeat_task);
 514    }
 515
 516    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 517        let mut lock = self.state.lock();
 518
 519        let can_reconnect = lock
 520            .as_ref()
 521            .map(|state| state.can_reconnect())
 522            .unwrap_or(false);
 523        if !can_reconnect {
 524            let error = if let Some(state) = lock.as_ref() {
 525                format!("invalid state, cannot reconnect while in state {state}")
 526            } else {
 527                "no state set".to_string()
 528            };
 529            log::info!("aborting reconnect, because not in state that allows reconnecting");
 530            return Err(anyhow!(error));
 531        }
 532
 533        let state = lock.take().unwrap();
 534        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 535            State::Connected {
 536                ssh_connection,
 537                delegate,
 538                forwarder,
 539                multiplex_task,
 540                heartbeat_task,
 541            }
 542            | State::HeartbeatMissed {
 543                ssh_connection,
 544                delegate,
 545                forwarder,
 546                multiplex_task,
 547                heartbeat_task,
 548                ..
 549            } => {
 550                drop(multiplex_task);
 551                drop(heartbeat_task);
 552                (0, ssh_connection, delegate, forwarder)
 553            }
 554            State::ReconnectFailed {
 555                attempts,
 556                ssh_connection,
 557                delegate,
 558                forwarder,
 559                ..
 560            } => (attempts, ssh_connection, delegate, forwarder),
 561            State::Connecting
 562            | State::Reconnecting
 563            | State::ReconnectExhausted
 564            | State::ServerNotRunning => unreachable!(),
 565        };
 566
 567        let attempts = attempts + 1;
 568        if attempts > MAX_RECONNECT_ATTEMPTS {
 569            log::error!(
 570                "Failed to reconnect to after {} attempts, giving up",
 571                MAX_RECONNECT_ATTEMPTS
 572            );
 573            drop(lock);
 574            self.set_state(State::ReconnectExhausted, cx);
 575            return Ok(());
 576        }
 577        drop(lock);
 578
 579        self.set_state(State::Reconnecting, cx);
 580
 581        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 582
 583        let identifier = self.unique_identifier.clone();
 584        let client = self.client.clone();
 585        let reconnect_task = cx.spawn(|this, mut cx| async move {
 586            macro_rules! failed {
 587                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 588                    return State::ReconnectFailed {
 589                        error: anyhow!($error),
 590                        attempts: $attempts,
 591                        ssh_connection: $ssh_connection,
 592                        delegate: $delegate,
 593                        forwarder: $forwarder,
 594                    };
 595                };
 596            }
 597
 598            if let Err(error) = ssh_connection.master_process.kill() {
 599                failed!(error, attempts, ssh_connection, delegate, forwarder);
 600            };
 601
 602            if let Err(error) = ssh_connection
 603                .master_process
 604                .status()
 605                .await
 606                .context("Failed to kill ssh process")
 607            {
 608                failed!(error, attempts, ssh_connection, delegate, forwarder);
 609            }
 610
 611            let connection_options = ssh_connection.socket.connection_options.clone();
 612
 613            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 614            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 615                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 616
 617            let (ssh_connection, ssh_process) = match Self::establish_connection(
 618                identifier,
 619                true,
 620                connection_options,
 621                delegate.clone(),
 622                &mut cx,
 623            )
 624            .await
 625            {
 626                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 627                Err(error) => {
 628                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 629                }
 630            };
 631
 632            let multiplex_task = Self::multiplex(
 633                this.clone(),
 634                ssh_process,
 635                proxy_incoming_tx,
 636                proxy_outgoing_rx,
 637                &mut cx,
 638            );
 639
 640            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 641                failed!(error, attempts, ssh_connection, delegate, forwarder);
 642            };
 643
 644            State::Connected {
 645                ssh_connection,
 646                delegate,
 647                forwarder,
 648                multiplex_task,
 649                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
 650            }
 651        });
 652
 653        cx.spawn(|this, mut cx| async move {
 654            let new_state = reconnect_task.await;
 655            this.update(&mut cx, |this, cx| {
 656                this.try_set_state(cx, |old_state| {
 657                    if old_state.is_reconnecting() {
 658                        match &new_state {
 659                            State::Connecting
 660                            | State::Reconnecting { .. }
 661                            | State::HeartbeatMissed { .. }
 662                            | State::ServerNotRunning => {}
 663                            State::Connected { .. } => {
 664                                log::info!("Successfully reconnected");
 665                            }
 666                            State::ReconnectFailed {
 667                                error, attempts, ..
 668                            } => {
 669                                log::error!(
 670                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 671                                    attempts,
 672                                    error
 673                                );
 674                            }
 675                            State::ReconnectExhausted => {
 676                                log::error!("Reconnect attempt failed and all attempts exhausted");
 677                            }
 678                        }
 679                        Some(new_state)
 680                    } else {
 681                        None
 682                    }
 683                });
 684
 685                if this.state_is(State::is_reconnect_failed) {
 686                    this.reconnect(cx)
 687                } else if this.state_is(State::is_reconnect_exhausted) {
 688                    cx.emit(SshRemoteEvent::Disconnected);
 689                    Ok(())
 690                } else {
 691                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 692                    Ok(())
 693                }
 694            })
 695        })
 696        .detach_and_log_err(cx);
 697
 698        Ok(())
 699    }
 700
 701    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
 702        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 703            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 704        };
 705        cx.spawn(|mut cx| {
 706            let this = this.clone();
 707            async move {
 708                let mut missed_heartbeats = 0;
 709
 710                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
 711                loop {
 712                    timer.next().await;
 713
 714                    log::debug!("Sending heartbeat to server...");
 715
 716                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
 717                    if result.is_err() {
 718                        missed_heartbeats += 1;
 719                        log::warn!(
 720                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 721                            HEARTBEAT_TIMEOUT,
 722                            missed_heartbeats,
 723                            MAX_MISSED_HEARTBEATS
 724                        );
 725                    } else if missed_heartbeats != 0 {
 726                        missed_heartbeats = 0;
 727                    } else {
 728                        continue;
 729                    }
 730
 731                    let result = this.update(&mut cx, |this, mut cx| {
 732                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 733                    })?;
 734                    if result.is_break() {
 735                        return Ok(());
 736                    }
 737                }
 738            }
 739        })
 740    }
 741
 742    fn handle_heartbeat_result(
 743        &mut self,
 744        missed_heartbeats: usize,
 745        cx: &mut ModelContext<Self>,
 746    ) -> ControlFlow<()> {
 747        let state = self.state.lock().take().unwrap();
 748        let next_state = if missed_heartbeats > 0 {
 749            state.heartbeat_missed()
 750        } else {
 751            state.heartbeat_recovered()
 752        };
 753
 754        self.set_state(next_state, cx);
 755
 756        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 757            log::error!(
 758                "Missed last {} heartbeats. Reconnecting...",
 759                missed_heartbeats
 760            );
 761
 762            self.reconnect(cx)
 763                .context("failed to start reconnect process after missing heartbeats")
 764                .log_err();
 765            ControlFlow::Break(())
 766        } else {
 767            ControlFlow::Continue(())
 768        }
 769    }
 770
 771    fn multiplex(
 772        this: WeakModel<Self>,
 773        mut ssh_proxy_process: Child,
 774        incoming_tx: UnboundedSender<Envelope>,
 775        mut outgoing_rx: UnboundedReceiver<Envelope>,
 776        cx: &AsyncAppContext,
 777    ) -> Task<Result<()>> {
 778        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 779        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 780        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 781
 782        let io_task = cx.background_executor().spawn(async move {
 783            let mut stdin_buffer = Vec::new();
 784            let mut stdout_buffer = Vec::new();
 785            let mut stderr_buffer = Vec::new();
 786            let mut stderr_offset = 0;
 787
 788            loop {
 789                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 790                stderr_buffer.resize(stderr_offset + 1024, 0);
 791
 792                select_biased! {
 793                    outgoing = outgoing_rx.next().fuse() => {
 794                        let Some(outgoing) = outgoing else {
 795                            return anyhow::Ok(None);
 796                        };
 797
 798                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 799                    }
 800
 801                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 802                        match result {
 803                            Ok(0) => {
 804                                child_stdin.close().await?;
 805                                outgoing_rx.close();
 806                                let status = ssh_proxy_process.status().await?;
 807                                return Ok(status.code());
 808                            }
 809                            Ok(len) => {
 810                                if len < stdout_buffer.len() {
 811                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 812                                }
 813
 814                                let message_len = message_len_from_buffer(&stdout_buffer);
 815                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 816                                    Ok(envelope) => {
 817                                        incoming_tx.unbounded_send(envelope).ok();
 818                                    }
 819                                    Err(error) => {
 820                                        log::error!("error decoding message {error:?}");
 821                                    }
 822                                }
 823                            }
 824                            Err(error) => {
 825                                Err(anyhow!("error reading stdout: {error:?}"))?;
 826                            }
 827                        }
 828                    }
 829
 830                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 831                        match result {
 832                            Ok(len) => {
 833                                stderr_offset += len;
 834                                let mut start_ix = 0;
 835                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 836                                    let line_ix = start_ix + ix;
 837                                    let content = &stderr_buffer[start_ix..line_ix];
 838                                    start_ix = line_ix + 1;
 839                                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 840                                        record.log(log::logger())
 841                                    } else {
 842                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 843                                    }
 844                                }
 845                                stderr_buffer.drain(0..start_ix);
 846                                stderr_offset -= start_ix;
 847                            }
 848                            Err(error) => {
 849                                Err(anyhow!("error reading stderr: {error:?}"))?;
 850                            }
 851                        }
 852                    }
 853                }
 854            }
 855        });
 856
 857        cx.spawn(|mut cx| async move {
 858            let result = io_task.await;
 859
 860            match result {
 861                Ok(Some(exit_code)) => {
 862                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 863                        match error {
 864                            ProxyLaunchError::ServerNotRunning => {
 865                                log::error!("failed to reconnect because server is not running");
 866                                this.update(&mut cx, |this, cx| {
 867                                    this.set_state(State::ServerNotRunning, cx);
 868                                    cx.emit(SshRemoteEvent::Disconnected);
 869                                })?;
 870                            }
 871                        }
 872                    } else if exit_code > 0 {
 873                        log::error!("proxy process terminated unexpectedly");
 874                        this.update(&mut cx, |this, cx| {
 875                            this.reconnect(cx).ok();
 876                        })?;
 877                    }
 878                }
 879                Ok(None) => {}
 880                Err(error) => {
 881                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 882                    this.update(&mut cx, |this, cx| {
 883                        this.reconnect(cx).ok();
 884                    })?;
 885                }
 886            }
 887            Ok(())
 888        })
 889    }
 890
 891    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 892        self.state.lock().as_ref().map_or(false, check)
 893    }
 894
 895    fn try_set_state(
 896        &self,
 897        cx: &mut ModelContext<Self>,
 898        map: impl FnOnce(&State) -> Option<State>,
 899    ) {
 900        let mut lock = self.state.lock();
 901        let new_state = lock.as_ref().and_then(map);
 902
 903        if let Some(new_state) = new_state {
 904            lock.replace(new_state);
 905            cx.notify();
 906        }
 907    }
 908
 909    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 910        log::info!("setting state to '{}'", &state);
 911        self.state.lock().replace(state);
 912        cx.notify();
 913    }
 914
 915    async fn establish_connection(
 916        unique_identifier: String,
 917        reconnect: bool,
 918        connection_options: SshConnectionOptions,
 919        delegate: Arc<dyn SshClientDelegate>,
 920        cx: &mut AsyncAppContext,
 921    ) -> Result<(SshRemoteConnection, Child)> {
 922        let ssh_connection =
 923            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 924
 925        let platform = ssh_connection.query_platform().await?;
 926        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 927        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 928        ssh_connection
 929            .ensure_server_binary(
 930                &delegate,
 931                &local_binary_path,
 932                &remote_binary_path,
 933                version,
 934                cx,
 935            )
 936            .await?;
 937
 938        let socket = ssh_connection.socket.clone();
 939        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 940
 941        delegate.set_status(Some("Starting proxy"), cx);
 942
 943        let mut start_proxy_command = format!(
 944            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 945            std::env::var("RUST_LOG").unwrap_or_default(),
 946            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 947            remote_binary_path,
 948            unique_identifier,
 949        );
 950        if reconnect {
 951            start_proxy_command.push_str(" --reconnect");
 952        }
 953
 954        let ssh_proxy_process = socket
 955            .ssh_command(start_proxy_command)
 956            // IMPORTANT: we kill this process when we drop the task that uses it.
 957            .kill_on_drop(true)
 958            .spawn()
 959            .context("failed to spawn remote server")?;
 960
 961        Ok((ssh_connection, ssh_proxy_process))
 962    }
 963
 964    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 965        self.client.subscribe_to_entity(remote_id, entity);
 966    }
 967
 968    pub fn ssh_args(&self) -> Option<Vec<String>> {
 969        self.state
 970            .lock()
 971            .as_ref()
 972            .and_then(|state| state.ssh_connection())
 973            .map(|ssh_connection| ssh_connection.socket.ssh_args())
 974    }
 975
 976    pub fn to_proto_client(&self) -> AnyProtoClient {
 977        self.client.clone().into()
 978    }
 979
 980    pub fn connection_string(&self) -> String {
 981        self.connection_options.connection_string()
 982    }
 983
 984    pub fn connection_options(&self) -> SshConnectionOptions {
 985        self.connection_options.clone()
 986    }
 987
 988    #[cfg(not(any(test, feature = "test-support")))]
 989    pub fn connection_state(&self) -> ConnectionState {
 990        self.state
 991            .lock()
 992            .as_ref()
 993            .map(ConnectionState::from)
 994            .unwrap_or(ConnectionState::Disconnected)
 995    }
 996
 997    #[cfg(any(test, feature = "test-support"))]
 998    pub fn connection_state(&self) -> ConnectionState {
 999        ConnectionState::Connected
1000    }
1001
1002    pub fn is_disconnected(&self) -> bool {
1003        self.connection_state() == ConnectionState::Disconnected
1004    }
1005
1006    #[cfg(any(test, feature = "test-support"))]
1007    pub fn fake(
1008        client_cx: &mut gpui::TestAppContext,
1009        server_cx: &mut gpui::TestAppContext,
1010    ) -> (Model<Self>, Arc<ChannelClient>) {
1011        use gpui::Context;
1012
1013        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1014        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1015
1016        (
1017            client_cx.update(|cx| {
1018                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1019                cx.new_model(|_| Self {
1020                    client,
1021                    unique_identifier: "fake".to_string(),
1022                    connection_options: SshConnectionOptions::default(),
1023                    state: Arc::new(Mutex::new(None)),
1024                })
1025            }),
1026            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1027        )
1028    }
1029}
1030
1031impl From<SshRemoteClient> for AnyProtoClient {
1032    fn from(client: SshRemoteClient) -> Self {
1033        AnyProtoClient::new(client.client.clone())
1034    }
1035}
1036
/// An ssh connection kept open by a `ControlMaster` process, over which further
/// commands are run through the control socket.
///
/// Fields drop in declaration order, so `_temp_dir` (which holds the control
/// socket and askpass files) is removed last.
struct SshRemoteConnection {
    /// Connection options plus the control socket path; used to build
    /// commands that reuse the master connection.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process; killed in `Drop`.
    master_process: process::Child,
    /// Temporary directory containing `ssh.sock`, `askpass.sock` and
    /// `askpass.sh`; deleted when dropped.
    _temp_dir: TempDir,
}
1042
1043impl Drop for SshRemoteConnection {
1044    fn drop(&mut self) {
1045        if let Err(error) = self.master_process.kill() {
1046            log::error!("failed to kill SSH master process: {}", error);
1047        }
1048    }
1049}
1050
1051impl SshRemoteConnection {
1052    #[cfg(not(unix))]
1053    async fn new(
1054        _connection_options: SshConnectionOptions,
1055        _delegate: Arc<dyn SshClientDelegate>,
1056        _cx: &mut AsyncAppContext,
1057    ) -> Result<Self> {
1058        Err(anyhow!("ssh is not supported on this platform"))
1059    }
1060
1061    #[cfg(unix)]
1062    async fn new(
1063        connection_options: SshConnectionOptions,
1064        delegate: Arc<dyn SshClientDelegate>,
1065        cx: &mut AsyncAppContext,
1066    ) -> Result<Self> {
1067        use futures::{io::BufReader, AsyncBufReadExt as _};
1068        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1069        use util::ResultExt as _;
1070
1071        delegate.set_status(Some("connecting"), cx);
1072
1073        let url = connection_options.ssh_url();
1074        let temp_dir = tempfile::Builder::new()
1075            .prefix("zed-ssh-session")
1076            .tempdir()?;
1077
1078        // Create a domain socket listener to handle requests from the askpass program.
1079        let askpass_socket = temp_dir.path().join("askpass.sock");
1080        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1081        let listener =
1082            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1083
1084        let askpass_task = cx.spawn({
1085            let delegate = delegate.clone();
1086            |mut cx| async move {
1087                let mut askpass_opened_tx = Some(askpass_opened_tx);
1088
1089                while let Ok((mut stream, _)) = listener.accept().await {
1090                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1091                        askpass_opened_tx.send(()).ok();
1092                    }
1093                    let mut buffer = Vec::new();
1094                    let mut reader = BufReader::new(&mut stream);
1095                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1096                        buffer.clear();
1097                    }
1098                    let password_prompt = String::from_utf8_lossy(&buffer);
1099                    if let Some(password) = delegate
1100                        .ask_password(password_prompt.to_string(), &mut cx)
1101                        .await
1102                        .context("failed to get ssh password")
1103                        .and_then(|p| p)
1104                        .log_err()
1105                    {
1106                        stream.write_all(password.as_bytes()).await.log_err();
1107                    }
1108                }
1109            }
1110        });
1111
1112        // Create an askpass script that communicates back to this process.
1113        let askpass_script = format!(
1114            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1115            askpass_socket = askpass_socket.display(),
1116            print_args = "printf '%s\\0' \"$@\"",
1117            shebang = "#!/bin/sh",
1118        );
1119        let askpass_script_path = temp_dir.path().join("askpass.sh");
1120        fs::write(&askpass_script_path, askpass_script).await?;
1121        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1122
1123        // Start the master SSH process, which does not do anything except for establish
1124        // the connection and keep it open, allowing other ssh commands to reuse it
1125        // via a control socket.
1126        let socket_path = temp_dir.path().join("ssh.sock");
1127        let mut master_process = process::Command::new("ssh")
1128            .stdin(Stdio::null())
1129            .stdout(Stdio::piped())
1130            .stderr(Stdio::piped())
1131            .env("SSH_ASKPASS_REQUIRE", "force")
1132            .env("SSH_ASKPASS", &askpass_script_path)
1133            .args(["-N", "-o", "ControlMaster=yes", "-o"])
1134            .arg(format!("ControlPath={}", socket_path.display()))
1135            .arg(&url)
1136            .spawn()?;
1137
1138        // Wait for this ssh process to close its stdout, indicating that authentication
1139        // has completed.
1140        let stdout = master_process.stdout.as_mut().unwrap();
1141        let mut output = Vec::new();
1142        let connection_timeout = Duration::from_secs(10);
1143
1144        let result = select_biased! {
1145            _ = askpass_opened_rx.fuse() => {
1146                // If the askpass script has opened, that means the user is typing
1147                // their password, in which case we don't want to timeout anymore,
1148                // since we know a connection has been established.
1149                stdout.read_to_end(&mut output).await?;
1150                Ok(())
1151            }
1152            result = stdout.read_to_end(&mut output).fuse() => {
1153                result?;
1154                Ok(())
1155            }
1156            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1157                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1158            }
1159        };
1160
1161        if let Err(e) = result {
1162            let error_message = format!("Failed to connect to host: {}.", e);
1163            delegate.set_error(error_message, cx);
1164            return Err(e);
1165        }
1166
1167        drop(askpass_task);
1168
1169        if master_process.try_status()?.is_some() {
1170            output.clear();
1171            let mut stderr = master_process.stderr.take().unwrap();
1172            stderr.read_to_end(&mut output).await?;
1173
1174            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1175            delegate.set_error(error_message.clone(), cx);
1176            Err(anyhow!(error_message))?;
1177        }
1178
1179        Ok(Self {
1180            socket: SshSocket {
1181                connection_options,
1182                socket_path,
1183            },
1184            master_process,
1185            _temp_dir: temp_dir,
1186        })
1187    }
1188
1189    async fn ensure_server_binary(
1190        &self,
1191        delegate: &Arc<dyn SshClientDelegate>,
1192        src_path: &Path,
1193        dst_path: &Path,
1194        version: SemanticVersion,
1195        cx: &mut AsyncAppContext,
1196    ) -> Result<()> {
1197        let mut dst_path_gz = dst_path.to_path_buf();
1198        dst_path_gz.set_extension("gz");
1199
1200        if let Some(parent) = dst_path.parent() {
1201            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1202        }
1203
1204        let mut server_binary_exists = false;
1205        if cfg!(not(debug_assertions)) {
1206            if let Ok(installed_version) =
1207                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1208            {
1209                if installed_version.trim() == version.to_string() {
1210                    server_binary_exists = true;
1211                }
1212            }
1213        }
1214
1215        if server_binary_exists {
1216            log::info!("remote development server already present",);
1217            return Ok(());
1218        }
1219
1220        let src_stat = fs::metadata(src_path).await?;
1221        let size = src_stat.len();
1222        let server_mode = 0o755;
1223
1224        let t0 = Instant::now();
1225        delegate.set_status(Some("uploading remote development server"), cx);
1226        log::info!("uploading remote development server ({}kb)", size / 1024);
1227        self.upload_file(src_path, &dst_path_gz)
1228            .await
1229            .context("failed to upload server binary")?;
1230        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1231
1232        delegate.set_status(Some("extracting remote development server"), cx);
1233        run_cmd(
1234            self.socket
1235                .ssh_command("gunzip")
1236                .arg("--force")
1237                .arg(&dst_path_gz),
1238        )
1239        .await?;
1240
1241        delegate.set_status(Some("unzipping remote development server"), cx);
1242        run_cmd(
1243            self.socket
1244                .ssh_command("chmod")
1245                .arg(format!("{:o}", server_mode))
1246                .arg(dst_path),
1247        )
1248        .await?;
1249
1250        Ok(())
1251    }
1252
1253    async fn query_platform(&self) -> Result<SshPlatform> {
1254        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1255        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1256
1257        let os = match os.trim() {
1258            "Darwin" => "macos",
1259            "Linux" => "linux",
1260            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1261        };
1262        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1263            "aarch64"
1264        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1265            "x86_64"
1266        } else {
1267            Err(anyhow!("unknown uname architecture {arch:?}"))?
1268        };
1269
1270        Ok(SshPlatform { os, arch })
1271    }
1272
1273    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1274        let mut command = process::Command::new("scp");
1275        let output = self
1276            .socket
1277            .ssh_options(&mut command)
1278            .args(
1279                self.socket
1280                    .connection_options
1281                    .port
1282                    .map(|port| vec!["-P".to_string(), port.to_string()])
1283                    .unwrap_or_default(),
1284            )
1285            .arg(src_path)
1286            .arg(format!(
1287                "{}:{}",
1288                self.socket.connection_options.scp_url(),
1289                dest_path.display()
1290            ))
1291            .output()
1292            .await?;
1293
1294        if output.status.success() {
1295            Ok(())
1296        } else {
1297            Err(anyhow!(
1298                "failed to upload file {} -> {}: {}",
1299                src_path.display(),
1300                dest_path.display(),
1301                String::from_utf8_lossy(&output.stderr)
1302            ))
1303        }
1304    }
1305}
1306
/// Pending requests awaiting a response, keyed by the id of the outgoing request.
///
/// Each sender receives the response envelope paired with a oneshot sender. The
/// requester drops that sender once it has taken the envelope, and the incoming
/// message loop waits for the drop before handling its next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1308
/// A [`ProtoClient`] that exchanges [`Envelope`]s over `mpsc` channels.
///
/// Outgoing envelopes are pushed onto `outgoing_tx`; incoming envelopes are
/// dispatched by a task spawned in `ChannelClient::new`.
pub struct ChannelClient {
    /// Source of ids assigned to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Sink for every outgoing envelope (requests, messages and responses).
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Mutex-guarded map of requests still waiting for their response.
    response_channels: ResponseChannels,
    /// Mutex-guarded handlers and entity subscriptions for incoming messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1315
1316impl ChannelClient {
1317    pub fn new(
1318        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1319        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1320        cx: &AppContext,
1321    ) -> Arc<Self> {
1322        let this = Arc::new(Self {
1323            outgoing_tx,
1324            next_message_id: AtomicU32::new(0),
1325            response_channels: ResponseChannels::default(),
1326            message_handlers: Default::default(),
1327        });
1328
1329        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1330
1331        this
1332    }
1333
1334    fn start_handling_messages(
1335        this: Arc<Self>,
1336        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1337        cx: &AppContext,
1338    ) {
1339        cx.spawn(|cx| {
1340            let this = Arc::downgrade(&this);
1341            async move {
1342                let peer_id = PeerId { owner_id: 0, id: 0 };
1343                while let Some(incoming) = incoming_rx.next().await {
1344                    let Some(this) = this.upgrade() else {
1345                        return anyhow::Ok(());
1346                    };
1347
1348                    if let Some(request_id) = incoming.responding_to {
1349                        let request_id = MessageId(request_id);
1350                        let sender = this.response_channels.lock().remove(&request_id);
1351                        if let Some(sender) = sender {
1352                            let (tx, rx) = oneshot::channel();
1353                            if incoming.payload.is_some() {
1354                                sender.send((incoming, tx)).ok();
1355                            }
1356                            rx.await.ok();
1357                        }
1358                    } else if let Some(envelope) =
1359                        build_typed_envelope(peer_id, Instant::now(), incoming)
1360                    {
1361                        let type_name = envelope.payload_type_name();
1362                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1363                            &this.message_handlers,
1364                            envelope,
1365                            this.clone().into(),
1366                            cx.clone(),
1367                        ) {
1368                            log::debug!("ssh message received. name:{type_name}");
1369                            match future.await {
1370                                Ok(_) => {
1371                                    log::debug!("ssh message handled. name:{type_name}");
1372                                }
1373                                Err(error) => {
1374                                    log::error!(
1375                                        "error handling message. type:{type_name}, error:{error}",
1376                                    );
1377                                }
1378                            }
1379                        } else {
1380                            log::error!("unhandled ssh message name:{type_name}");
1381                        }
1382                    }
1383                }
1384                anyhow::Ok(())
1385            }
1386        })
1387        .detach();
1388    }
1389
1390    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1391        let id = (TypeId::of::<E>(), remote_id);
1392
1393        let mut message_handlers = self.message_handlers.lock();
1394        if message_handlers
1395            .entities_by_type_and_remote_id
1396            .contains_key(&id)
1397        {
1398            panic!("already subscribed to entity");
1399        }
1400
1401        message_handlers.entities_by_type_and_remote_id.insert(
1402            id,
1403            EntityMessageSubscriber::Entity {
1404                handle: entity.downgrade().into(),
1405            },
1406        );
1407    }
1408
1409    pub fn request<T: RequestMessage>(
1410        &self,
1411        payload: T,
1412    ) -> impl 'static + Future<Output = Result<T::Response>> {
1413        log::debug!("ssh request start. name:{}", T::NAME);
1414        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1415        async move {
1416            let response = response.await?;
1417            log::debug!("ssh request finish. name:{}", T::NAME);
1418            T::Response::from_envelope(response)
1419                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1420        }
1421    }
1422
1423    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1424        smol::future::or(
1425            async {
1426                self.request(proto::Ping {}).await?;
1427                Ok(())
1428            },
1429            async {
1430                smol::Timer::after(timeout).await;
1431                Err(anyhow!("Timeout detected"))
1432            },
1433        )
1434        .await
1435    }
1436
1437    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1438        log::debug!("ssh send name:{}", T::NAME);
1439        self.send_dynamic(payload.into_envelope(0, None, None))
1440    }
1441
1442    pub fn request_dynamic(
1443        &self,
1444        mut envelope: proto::Envelope,
1445        type_name: &'static str,
1446    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1447        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1448        let (tx, rx) = oneshot::channel();
1449        let mut response_channels_lock = self.response_channels.lock();
1450        response_channels_lock.insert(MessageId(envelope.id), tx);
1451        drop(response_channels_lock);
1452        let result = self.outgoing_tx.unbounded_send(envelope);
1453        async move {
1454            if let Err(error) = &result {
1455                log::error!("failed to send message: {}", error);
1456                return Err(anyhow!("failed to send message: {}", error));
1457            }
1458
1459            let response = rx.await.context("connection lost")?.0;
1460            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1461                return Err(RpcError::from_proto(error, type_name));
1462            }
1463            Ok(response)
1464        }
1465    }
1466
1467    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1468        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1469        self.outgoing_tx.unbounded_send(envelope)?;
1470        Ok(())
1471    }
1472}
1473
1474impl ProtoClient for ChannelClient {
1475    fn request(
1476        &self,
1477        envelope: proto::Envelope,
1478        request_type: &'static str,
1479    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1480        self.request_dynamic(envelope, request_type).boxed()
1481    }
1482
1483    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1484        self.send_dynamic(envelope)
1485    }
1486
1487    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1488        self.send_dynamic(envelope)
1489    }
1490
1491    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1492        &self.message_handlers
1493    }
1494
1495    fn is_via_collab(&self) -> bool {
1496        false
1497    }
1498}