ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
  21};
  22use parking_lot::Mutex;
  23use rpc::{
  24    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  25    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  26};
  27use smol::{
  28    fs,
  29    process::{self, Child, Stdio},
  30    Timer,
  31};
  32use std::{
  33    any::TypeId,
  34    ffi::OsStr,
  35    fmt,
  36    ops::ControlFlow,
  37    path::{Path, PathBuf},
  38    sync::{
  39        atomic::{AtomicU32, Ordering::SeqCst},
  40        Arc,
  41    },
  42    time::{Duration, Instant},
  43};
  44use tempfile::TempDir;
  45use util::ResultExt;
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  49)]
  50pub struct SshProjectId(pub u64);
  51
  52#[derive(Clone)]
  53pub struct SshSocket {
  54    connection_options: SshConnectionOptions,
  55    socket_path: PathBuf,
  56}
  57
  58#[derive(Debug, Default, Clone, PartialEq, Eq)]
  59pub struct SshConnectionOptions {
  60    pub host: String,
  61    pub username: Option<String>,
  62    pub port: Option<u16>,
  63    pub password: Option<String>,
  64}
  65
  66impl SshConnectionOptions {
  67    pub fn ssh_url(&self) -> String {
  68        let mut result = String::from("ssh://");
  69        if let Some(username) = &self.username {
  70            result.push_str(username);
  71            result.push('@');
  72        }
  73        result.push_str(&self.host);
  74        if let Some(port) = self.port {
  75            result.push(':');
  76            result.push_str(&port.to_string());
  77        }
  78        result
  79    }
  80
  81    fn scp_url(&self) -> String {
  82        if let Some(username) = &self.username {
  83            format!("{}@{}", username, self.host)
  84        } else {
  85            self.host.clone()
  86        }
  87    }
  88
  89    pub fn connection_string(&self) -> String {
  90        let host = if let Some(username) = &self.username {
  91            format!("{}@{}", username, self.host)
  92        } else {
  93            self.host.clone()
  94        };
  95        if let Some(port) = &self.port {
  96            format!("{}:{}", host, port)
  97        } else {
  98            host
  99        }
 100    }
 101
 102    // Uniquely identifies dev server projects on a remote host. Needs to be
 103    // stable for the same dev server project.
 104    pub fn dev_server_identifier(&self) -> String {
 105        let mut identifier = format!("dev-server-{:?}", self.host);
 106        if let Some(username) = self.username.as_ref() {
 107            identifier.push('-');
 108            identifier.push_str(&username);
 109        }
 110        identifier
 111    }
 112}
 113
 114#[derive(Copy, Clone, Debug)]
 115pub struct SshPlatform {
 116    pub os: &'static str,
 117    pub arch: &'static str,
 118}
 119
/// Callbacks through which the SSH client interacts with its embedder (UI, prompts,
/// binary provisioning). Implementations must be `Send + Sync`.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password in response to `prompt`; the answer (or an error) is delivered
    /// through the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns a path for the remote server binary.
    /// NOTE(review): presumably the path on the remote host; confirm against implementors.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Provides a server binary (path and version) for the given remote `platform`; the
    /// result arrives through the returned receiver.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports a status message; `None` presumably clears it (TODO confirm).
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message, e.g. when the initial ping to the server fails.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
 135
 136impl SshSocket {
 137    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 138        let mut command = process::Command::new("ssh");
 139        self.ssh_options(&mut command)
 140            .arg(self.connection_options.ssh_url())
 141            .arg(program);
 142        command
 143    }
 144
 145    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 146        command
 147            .stdin(Stdio::piped())
 148            .stdout(Stdio::piped())
 149            .stderr(Stdio::piped())
 150            .args(["-o", "ControlMaster=no", "-o"])
 151            .arg(format!("ControlPath={}", self.socket_path.display()))
 152    }
 153
 154    fn ssh_args(&self) -> Vec<String> {
 155        vec![
 156            "-o".to_string(),
 157            "ControlMaster=no".to_string(),
 158            "-o".to_string(),
 159            format!("ControlPath={}", self.socket_path.display()),
 160            self.connection_options.ssh_url(),
 161        ]
 162    }
 163}
 164
 165async fn run_cmd(command: &mut process::Command) -> Result<String> {
 166    let output = command.output().await?;
 167    if output.status.success() {
 168        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 169    } else {
 170        Err(anyhow!(
 171            "failed to run command: {}",
 172            String::from_utf8_lossy(&output.stderr)
 173        ))
 174    }
 175}
 176
 177struct ChannelForwarder {
 178    quit_tx: UnboundedSender<()>,
 179    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 180}
 181
impl ChannelForwarder {
    /// Spawns a background task that relays envelopes between the given channel ends and a
    /// freshly created proxy pair.
    ///
    /// * Envelopes sent into the returned `proxy_incoming_tx` are forwarded to `incoming_tx`.
    /// * Envelopes read from `outgoing_rx` are forwarded to the returned `proxy_outgoing_rx`.
    ///
    /// Returns the forwarder together with `(proxy_incoming_tx, proxy_outgoing_rx)`. Because
    /// `into_channels` gives back `incoming_tx`/`outgoing_rx`, a new forwarder (and a new
    /// transport behind it) can be created from the same ends, as `reconnect` does.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // Biased select: a pending quit signal wins over pending envelopes.
                // The loop also ends when any source channel closes or a send fails.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // NOTE(review): envelopes still queued in `proxy_incoming_rx` are dropped with
            // it here; only the original ends survive.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops the forwarding loop and waits for it to return the original
    /// `(incoming_tx, outgoing_rx)` channel ends.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // The send fails only if the loop already exited; the task result is still valid.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
 238
 239const MAX_MISSED_HEARTBEATS: usize = 5;
 240const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 241const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 242
 243const MAX_RECONNECT_ATTEMPTS: usize = 3;
 244
 245enum State {
 246    Connecting,
 247    Connected {
 248        ssh_connection: SshRemoteConnection,
 249        delegate: Arc<dyn SshClientDelegate>,
 250        forwarder: ChannelForwarder,
 251
 252        multiplex_task: Task<Result<()>>,
 253        heartbeat_task: Task<Result<()>>,
 254    },
 255    HeartbeatMissed {
 256        missed_heartbeats: usize,
 257
 258        ssh_connection: SshRemoteConnection,
 259        delegate: Arc<dyn SshClientDelegate>,
 260        forwarder: ChannelForwarder,
 261
 262        multiplex_task: Task<Result<()>>,
 263        heartbeat_task: Task<Result<()>>,
 264    },
 265    Reconnecting,
 266    ReconnectFailed {
 267        ssh_connection: SshRemoteConnection,
 268        delegate: Arc<dyn SshClientDelegate>,
 269        forwarder: ChannelForwarder,
 270
 271        error: anyhow::Error,
 272        attempts: usize,
 273    },
 274    ReconnectExhausted,
 275    ServerNotRunning,
 276}
 277
 278impl fmt::Display for State {
 279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 280        match self {
 281            Self::Connecting => write!(f, "connecting"),
 282            Self::Connected { .. } => write!(f, "connected"),
 283            Self::Reconnecting => write!(f, "reconnecting"),
 284            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 285            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 286            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 287            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 288        }
 289    }
 290}
 291
 292impl State {
 293    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 294        match self {
 295            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 296            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 297            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 298            _ => None,
 299        }
 300    }
 301
 302    fn can_reconnect(&self) -> bool {
 303        match self {
 304            Self::Connected { .. }
 305            | Self::HeartbeatMissed { .. }
 306            | Self::ReconnectFailed { .. } => true,
 307            State::Connecting
 308            | State::Reconnecting
 309            | State::ReconnectExhausted
 310            | State::ServerNotRunning => false,
 311        }
 312    }
 313
 314    fn is_reconnect_failed(&self) -> bool {
 315        matches!(self, Self::ReconnectFailed { .. })
 316    }
 317
 318    fn is_reconnecting(&self) -> bool {
 319        matches!(self, Self::Reconnecting { .. })
 320    }
 321
 322    fn heartbeat_recovered(self) -> Self {
 323        match self {
 324            Self::HeartbeatMissed {
 325                ssh_connection,
 326                delegate,
 327                forwarder,
 328                multiplex_task,
 329                heartbeat_task,
 330                ..
 331            } => Self::Connected {
 332                ssh_connection,
 333                delegate,
 334                forwarder,
 335                multiplex_task,
 336                heartbeat_task,
 337            },
 338            _ => self,
 339        }
 340    }
 341
 342    fn heartbeat_missed(self) -> Self {
 343        match self {
 344            Self::Connected {
 345                ssh_connection,
 346                delegate,
 347                forwarder,
 348                multiplex_task,
 349                heartbeat_task,
 350            } => Self::HeartbeatMissed {
 351                missed_heartbeats: 1,
 352                ssh_connection,
 353                delegate,
 354                forwarder,
 355                multiplex_task,
 356                heartbeat_task,
 357            },
 358            Self::HeartbeatMissed {
 359                missed_heartbeats,
 360                ssh_connection,
 361                delegate,
 362                forwarder,
 363                multiplex_task,
 364                heartbeat_task,
 365            } => Self::HeartbeatMissed {
 366                missed_heartbeats: missed_heartbeats + 1,
 367                ssh_connection,
 368                delegate,
 369                forwarder,
 370                multiplex_task,
 371                heartbeat_task,
 372            },
 373            _ => self,
 374        }
 375    }
 376}
 377
 378/// The state of the ssh connection.
 379#[derive(Clone, Copy, Debug)]
 380pub enum ConnectionState {
 381    Connecting,
 382    Connected,
 383    HeartbeatMissed,
 384    Reconnecting,
 385    Disconnected,
 386}
 387
 388impl From<&State> for ConnectionState {
 389    fn from(value: &State) -> Self {
 390        match value {
 391            State::Connecting => Self::Connecting,
 392            State::Connected { .. } => Self::Connected,
 393            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 394            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 395            State::ReconnectExhausted => Self::Disconnected,
 396            State::ServerNotRunning => Self::Disconnected,
 397        }
 398    }
 399}
 400
 401pub struct SshRemoteClient {
 402    client: Arc<ChannelClient>,
 403    unique_identifier: String,
 404    connection_options: SshConnectionOptions,
 405    state: Arc<Mutex<Option<State>>>,
 406}
 407
 408impl Drop for SshRemoteClient {
 409    fn drop(&mut self) {
 410        self.shutdown_processes();
 411    }
 412}
 413
 414impl SshRemoteClient {
    /// Connects to the host described by `connection_options` and resolves to the client
    /// model once the remote server has answered a ping.
    ///
    /// The model starts in `State::Connecting` and registers an app-quit hook that calls
    /// `shutdown_processes`. A `ChannelForwarder` is placed between the RPC client's
    /// channels and the ssh proxy process so that `reconnect` can later swap the proxy.
    ///
    /// # Errors
    /// Fails if the model cannot be created, if `establish_connection` fails, or if the
    /// initial ping fails. In the ping case the error is also reported through
    /// `delegate.set_error`.
    pub fn new(
        unique_identifier: String,
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &AppContext,
    ) -> Task<Result<Model<Self>>> {
        cx.spawn(|mut cx| async move {
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();

            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
            let this = cx.new_model(|cx| {
                // Make sure the ssh processes are torn down when the app quits.
                cx.on_app_quit(|this: &mut Self, _| {
                    this.shutdown_processes();
                    futures::future::ready(())
                })
                .detach();

                Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                }
            })?;

            // The forwarder keeps the client's channel ends; the proxy pair goes to the
            // multiplex task.
            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

            // NOTE(review): the `false` argument is presumably a "reconnecting" flag;
            // `reconnect` passes `true` here. Confirm against `establish_connection`.
            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
                unique_identifier,
                false,
                connection_options,
                delegate.clone(),
                &mut cx,
            )
            .await?;

            let multiplex_task = Self::multiplex(
                this.downgrade(),
                ssh_proxy_process,
                proxy_incoming_tx,
                proxy_outgoing_rx,
                &mut cx,
            );

            // Only report success once the server actually answers.
            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                log::error!("failed to establish connection: {}", error);
                delegate.set_error(error.to_string(), &mut cx);
                return Err(error);
            }

            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);

            // NOTE(review): writes the state directly instead of via `set_state`, so no
            // `cx.notify()` or state log happens for this transition.
            this.update(&mut cx, |this, _| {
                *this.state.lock() = Some(State::Connected {
                    ssh_connection,
                    delegate,
                    forwarder: proxy,
                    multiplex_task,
                    heartbeat_task,
                });
            })?;

            Ok(this)
        })
    }
 482
 483    fn shutdown_processes(&self) {
 484        let Some(state) = self.state.lock().take() else {
 485            return;
 486        };
 487        log::info!("shutting down ssh processes");
 488
 489        let State::Connected {
 490            multiplex_task,
 491            heartbeat_task,
 492            ..
 493        } = state
 494        else {
 495            return;
 496        };
 497        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 498        // child of master_process.
 499        drop(multiplex_task);
 500        // Now drop the rest of state, which kills master process.
 501        drop(heartbeat_task);
 502    }
 503
 504    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 505        let mut lock = self.state.lock();
 506
 507        let can_reconnect = lock
 508            .as_ref()
 509            .map(|state| state.can_reconnect())
 510            .unwrap_or(false);
 511        if !can_reconnect {
 512            let error = if let Some(state) = lock.as_ref() {
 513                format!("invalid state, cannot reconnect while in state {state}")
 514            } else {
 515                "no state set".to_string()
 516            };
 517            log::info!("aborting reconnect, because not in state that allows reconnecting");
 518            return Err(anyhow!(error));
 519        }
 520
 521        let state = lock.take().unwrap();
 522        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 523            State::Connected {
 524                ssh_connection,
 525                delegate,
 526                forwarder,
 527                multiplex_task,
 528                heartbeat_task,
 529            }
 530            | State::HeartbeatMissed {
 531                ssh_connection,
 532                delegate,
 533                forwarder,
 534                multiplex_task,
 535                heartbeat_task,
 536                ..
 537            } => {
 538                drop(multiplex_task);
 539                drop(heartbeat_task);
 540                (0, ssh_connection, delegate, forwarder)
 541            }
 542            State::ReconnectFailed {
 543                attempts,
 544                ssh_connection,
 545                delegate,
 546                forwarder,
 547                ..
 548            } => (attempts, ssh_connection, delegate, forwarder),
 549            State::Connecting
 550            | State::Reconnecting
 551            | State::ReconnectExhausted
 552            | State::ServerNotRunning => unreachable!(),
 553        };
 554
 555        let attempts = attempts + 1;
 556        if attempts > MAX_RECONNECT_ATTEMPTS {
 557            log::error!(
 558                "Failed to reconnect to after {} attempts, giving up",
 559                MAX_RECONNECT_ATTEMPTS
 560            );
 561            drop(lock);
 562            self.set_state(State::ReconnectExhausted, cx);
 563            return Ok(());
 564        }
 565        drop(lock);
 566
 567        self.set_state(State::Reconnecting, cx);
 568
 569        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 570
 571        let identifier = self.unique_identifier.clone();
 572        let client = self.client.clone();
 573        let reconnect_task = cx.spawn(|this, mut cx| async move {
 574            macro_rules! failed {
 575                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 576                    return State::ReconnectFailed {
 577                        error: anyhow!($error),
 578                        attempts: $attempts,
 579                        ssh_connection: $ssh_connection,
 580                        delegate: $delegate,
 581                        forwarder: $forwarder,
 582                    };
 583                };
 584            }
 585
 586            if let Err(error) = ssh_connection.master_process.kill() {
 587                failed!(error, attempts, ssh_connection, delegate, forwarder);
 588            };
 589
 590            if let Err(error) = ssh_connection
 591                .master_process
 592                .status()
 593                .await
 594                .context("Failed to kill ssh process")
 595            {
 596                failed!(error, attempts, ssh_connection, delegate, forwarder);
 597            }
 598
 599            let connection_options = ssh_connection.socket.connection_options.clone();
 600
 601            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 602            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 603                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 604
 605            let (ssh_connection, ssh_process) = match Self::establish_connection(
 606                identifier,
 607                true,
 608                connection_options,
 609                delegate.clone(),
 610                &mut cx,
 611            )
 612            .await
 613            {
 614                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 615                Err(error) => {
 616                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 617                }
 618            };
 619
 620            let multiplex_task = Self::multiplex(
 621                this.clone(),
 622                ssh_process,
 623                proxy_incoming_tx,
 624                proxy_outgoing_rx,
 625                &mut cx,
 626            );
 627
 628            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 629                failed!(error, attempts, ssh_connection, delegate, forwarder);
 630            };
 631
 632            State::Connected {
 633                ssh_connection,
 634                delegate,
 635                forwarder,
 636                multiplex_task,
 637                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
 638            }
 639        });
 640
 641        cx.spawn(|this, mut cx| async move {
 642            let new_state = reconnect_task.await;
 643            this.update(&mut cx, |this, cx| {
 644                this.try_set_state(cx, |old_state| {
 645                    if old_state.is_reconnecting() {
 646                        match &new_state {
 647                            State::Connecting
 648                            | State::Reconnecting { .. }
 649                            | State::HeartbeatMissed { .. }
 650                            | State::ServerNotRunning => {}
 651                            State::Connected { .. } => {
 652                                log::info!("Successfully reconnected");
 653                            }
 654                            State::ReconnectFailed {
 655                                error, attempts, ..
 656                            } => {
 657                                log::error!(
 658                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 659                                    attempts,
 660                                    error
 661                                );
 662                            }
 663                            State::ReconnectExhausted => {
 664                                log::error!("Reconnect attempt failed and all attempts exhausted");
 665                            }
 666                        }
 667                        Some(new_state)
 668                    } else {
 669                        None
 670                    }
 671                });
 672
 673                if this.state_is(State::is_reconnect_failed) {
 674                    this.reconnect(cx)
 675                } else {
 676                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 677                    Ok(())
 678                }
 679            })
 680        })
 681        .detach_and_log_err(cx);
 682
 683        Ok(())
 684    }
 685
 686    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
 687        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 688            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 689        };
 690        cx.spawn(|mut cx| {
 691            let this = this.clone();
 692            async move {
 693                let mut missed_heartbeats = 0;
 694
 695                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
 696                loop {
 697                    timer.next().await;
 698
 699                    log::debug!("Sending heartbeat to server...");
 700
 701                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
 702                    if result.is_err() {
 703                        missed_heartbeats += 1;
 704                        log::warn!(
 705                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 706                            HEARTBEAT_TIMEOUT,
 707                            missed_heartbeats,
 708                            MAX_MISSED_HEARTBEATS
 709                        );
 710                    } else if missed_heartbeats != 0 {
 711                        missed_heartbeats = 0;
 712                    } else {
 713                        continue;
 714                    }
 715
 716                    let result = this.update(&mut cx, |this, mut cx| {
 717                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 718                    })?;
 719                    if result.is_break() {
 720                        return Ok(());
 721                    }
 722                }
 723            }
 724        })
 725    }
 726
 727    fn handle_heartbeat_result(
 728        &mut self,
 729        missed_heartbeats: usize,
 730        cx: &mut ModelContext<Self>,
 731    ) -> ControlFlow<()> {
 732        let state = self.state.lock().take().unwrap();
 733        let next_state = if missed_heartbeats > 0 {
 734            state.heartbeat_missed()
 735        } else {
 736            state.heartbeat_recovered()
 737        };
 738
 739        self.set_state(next_state, cx);
 740
 741        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 742            log::error!(
 743                "Missed last {} heartbeats. Reconnecting...",
 744                missed_heartbeats
 745            );
 746
 747            self.reconnect(cx)
 748                .context("failed to start reconnect process after missing heartbeats")
 749                .log_err();
 750            ControlFlow::Break(())
 751        } else {
 752            ControlFlow::Continue(())
 753        }
 754    }
 755
 756    fn multiplex(
 757        this: WeakModel<Self>,
 758        mut ssh_proxy_process: Child,
 759        incoming_tx: UnboundedSender<Envelope>,
 760        mut outgoing_rx: UnboundedReceiver<Envelope>,
 761        cx: &AsyncAppContext,
 762    ) -> Task<Result<()>> {
 763        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 764        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 765        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 766
 767        let io_task = cx.background_executor().spawn(async move {
 768            let mut stdin_buffer = Vec::new();
 769            let mut stdout_buffer = Vec::new();
 770            let mut stderr_buffer = Vec::new();
 771            let mut stderr_offset = 0;
 772
 773            loop {
 774                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 775                stderr_buffer.resize(stderr_offset + 1024, 0);
 776
 777                select_biased! {
 778                    outgoing = outgoing_rx.next().fuse() => {
 779                        let Some(outgoing) = outgoing else {
 780                            return anyhow::Ok(None);
 781                        };
 782
 783                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 784                    }
 785
 786                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 787                        match result {
 788                            Ok(0) => {
 789                                child_stdin.close().await?;
 790                                outgoing_rx.close();
 791                                let status = ssh_proxy_process.status().await?;
 792                                return Ok(status.code());
 793                            }
 794                            Ok(len) => {
 795                                if len < stdout_buffer.len() {
 796                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 797                                }
 798
 799                                let message_len = message_len_from_buffer(&stdout_buffer);
 800                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 801                                    Ok(envelope) => {
 802                                        incoming_tx.unbounded_send(envelope).ok();
 803                                    }
 804                                    Err(error) => {
 805                                        log::error!("error decoding message {error:?}");
 806                                    }
 807                                }
 808                            }
 809                            Err(error) => {
 810                                Err(anyhow!("error reading stdout: {error:?}"))?;
 811                            }
 812                        }
 813                    }
 814
 815                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 816                        match result {
 817                            Ok(len) => {
 818                                stderr_offset += len;
 819                                let mut start_ix = 0;
 820                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 821                                    let line_ix = start_ix + ix;
 822                                    let content = &stderr_buffer[start_ix..line_ix];
 823                                    start_ix = line_ix + 1;
 824                                    if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
 825                                        record.message = format!("(remote) {}", record.message);
 826                                        record.log(log::logger())
 827                                    } else {
 828                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 829                                    }
 830                                }
 831                                stderr_buffer.drain(0..start_ix);
 832                                stderr_offset -= start_ix;
 833                            }
 834                            Err(error) => {
 835                                Err(anyhow!("error reading stderr: {error:?}"))?;
 836                            }
 837                        }
 838                    }
 839                }
 840            }
 841        });
 842
 843        cx.spawn(|mut cx| async move {
 844            let result = io_task.await;
 845
 846            match result {
 847                Ok(Some(exit_code)) => {
 848                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 849                        match error {
 850                            ProxyLaunchError::ServerNotRunning => {
 851                                log::error!("failed to reconnect because server is not running");
 852                                this.update(&mut cx, |this, cx| {
 853                                    this.set_state(State::ServerNotRunning, cx);
 854                                })?;
 855                            }
 856                        }
 857                    } else if exit_code > 0 {
 858                        log::error!("proxy process terminated unexpectedly");
 859                    }
 860                }
 861                Ok(None) => {}
 862                Err(error) => {
 863                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 864                    this.update(&mut cx, |this, cx| {
 865                        this.reconnect(cx).ok();
 866                    })?;
 867                }
 868            }
 869            Ok(())
 870        })
 871    }
 872
 873    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 874        self.state.lock().as_ref().map_or(false, check)
 875    }
 876
 877    fn try_set_state(
 878        &self,
 879        cx: &mut ModelContext<Self>,
 880        map: impl FnOnce(&State) -> Option<State>,
 881    ) {
 882        let mut lock = self.state.lock();
 883        let new_state = lock.as_ref().and_then(map);
 884
 885        if let Some(new_state) = new_state {
 886            lock.replace(new_state);
 887            cx.notify();
 888        }
 889    }
 890
 891    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 892        log::info!("setting state to '{}'", &state);
 893        self.state.lock().replace(state);
 894        cx.notify();
 895    }
 896
 897    async fn establish_connection(
 898        unique_identifier: String,
 899        reconnect: bool,
 900        connection_options: SshConnectionOptions,
 901        delegate: Arc<dyn SshClientDelegate>,
 902        cx: &mut AsyncAppContext,
 903    ) -> Result<(SshRemoteConnection, Child)> {
 904        let ssh_connection =
 905            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 906
 907        let platform = ssh_connection.query_platform().await?;
 908        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 909        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 910        ssh_connection
 911            .ensure_server_binary(
 912                &delegate,
 913                &local_binary_path,
 914                &remote_binary_path,
 915                version,
 916                cx,
 917            )
 918            .await?;
 919
 920        let socket = ssh_connection.socket.clone();
 921        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 922
 923        delegate.set_status(Some("Starting proxy"), cx);
 924
 925        let mut start_proxy_command = format!(
 926            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 927            std::env::var("RUST_LOG").unwrap_or_default(),
 928            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 929            remote_binary_path,
 930            unique_identifier,
 931        );
 932        if reconnect {
 933            start_proxy_command.push_str(" --reconnect");
 934        }
 935
 936        let ssh_proxy_process = socket
 937            .ssh_command(start_proxy_command)
 938            // IMPORTANT: we kill this process when we drop the task that uses it.
 939            .kill_on_drop(true)
 940            .spawn()
 941            .context("failed to spawn remote server")?;
 942
 943        Ok((ssh_connection, ssh_proxy_process))
 944    }
 945
 946    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 947        self.client.subscribe_to_entity(remote_id, entity);
 948    }
 949
 950    pub fn ssh_args(&self) -> Option<Vec<String>> {
 951        self.state
 952            .lock()
 953            .as_ref()
 954            .and_then(|state| state.ssh_connection())
 955            .map(|ssh_connection| ssh_connection.socket.ssh_args())
 956    }
 957
 958    pub fn to_proto_client(&self) -> AnyProtoClient {
 959        self.client.clone().into()
 960    }
 961
 962    pub fn connection_string(&self) -> String {
 963        self.connection_options.connection_string()
 964    }
 965
 966    pub fn connection_state(&self) -> ConnectionState {
 967        self.state
 968            .lock()
 969            .as_ref()
 970            .map(ConnectionState::from)
 971            .unwrap_or(ConnectionState::Disconnected)
 972    }
 973
 974    #[cfg(any(test, feature = "test-support"))]
 975    pub fn fake(
 976        client_cx: &mut gpui::TestAppContext,
 977        server_cx: &mut gpui::TestAppContext,
 978    ) -> (Model<Self>, Arc<ChannelClient>) {
 979        use gpui::Context;
 980
 981        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
 982        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
 983
 984        (
 985            client_cx.update(|cx| {
 986                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
 987                cx.new_model(|_| Self {
 988                    client,
 989                    unique_identifier: "fake".to_string(),
 990                    connection_options: SshConnectionOptions::default(),
 991                    state: Arc::new(Mutex::new(None)),
 992                })
 993            }),
 994            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
 995        )
 996    }
 997}
 998
 999impl From<SshRemoteClient> for AnyProtoClient {
1000    fn from(client: SshRemoteClient) -> Self {
1001        AnyProtoClient::new(client.client.clone())
1002    }
1003}
1004
/// An authenticated SSH ControlMaster connection to a remote host.
///
/// Further `ssh`/`scp` commands are built from `socket`, which points at the
/// master's control socket. The master process is killed by this type's `Drop` impl.
struct SshRemoteConnection {
    /// Control-socket path plus connection options used to build follow-up commands.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process keeping the connection open.
    master_process: process::Child,
    /// Temporary directory holding the control socket and the askpass socket/script.
    /// Held only to keep it from being deleted while the connection lives.
    _temp_dir: TempDir,
}
1010
1011impl Drop for SshRemoteConnection {
1012    fn drop(&mut self) {
1013        if let Err(error) = self.master_process.kill() {
1014            log::error!("failed to kill SSH master process: {}", error);
1015        }
1016    }
1017}
1018
1019impl SshRemoteConnection {
1020    #[cfg(not(unix))]
1021    async fn new(
1022        _connection_options: SshConnectionOptions,
1023        _delegate: Arc<dyn SshClientDelegate>,
1024        _cx: &mut AsyncAppContext,
1025    ) -> Result<Self> {
1026        Err(anyhow!("ssh is not supported on this platform"))
1027    }
1028
1029    #[cfg(unix)]
1030    async fn new(
1031        connection_options: SshConnectionOptions,
1032        delegate: Arc<dyn SshClientDelegate>,
1033        cx: &mut AsyncAppContext,
1034    ) -> Result<Self> {
1035        use futures::{io::BufReader, AsyncBufReadExt as _};
1036        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1037        use util::ResultExt as _;
1038
1039        delegate.set_status(Some("connecting"), cx);
1040
1041        let url = connection_options.ssh_url();
1042        let temp_dir = tempfile::Builder::new()
1043            .prefix("zed-ssh-session")
1044            .tempdir()?;
1045
1046        // Create a domain socket listener to handle requests from the askpass program.
1047        let askpass_socket = temp_dir.path().join("askpass.sock");
1048        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1049        let listener =
1050            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1051
1052        let askpass_task = cx.spawn({
1053            let delegate = delegate.clone();
1054            |mut cx| async move {
1055                let mut askpass_opened_tx = Some(askpass_opened_tx);
1056
1057                while let Ok((mut stream, _)) = listener.accept().await {
1058                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1059                        askpass_opened_tx.send(()).ok();
1060                    }
1061                    let mut buffer = Vec::new();
1062                    let mut reader = BufReader::new(&mut stream);
1063                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1064                        buffer.clear();
1065                    }
1066                    let password_prompt = String::from_utf8_lossy(&buffer);
1067                    if let Some(password) = delegate
1068                        .ask_password(password_prompt.to_string(), &mut cx)
1069                        .await
1070                        .context("failed to get ssh password")
1071                        .and_then(|p| p)
1072                        .log_err()
1073                    {
1074                        stream.write_all(password.as_bytes()).await.log_err();
1075                    }
1076                }
1077            }
1078        });
1079
1080        // Create an askpass script that communicates back to this process.
1081        let askpass_script = format!(
1082            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1083            askpass_socket = askpass_socket.display(),
1084            print_args = "printf '%s\\0' \"$@\"",
1085            shebang = "#!/bin/sh",
1086        );
1087        let askpass_script_path = temp_dir.path().join("askpass.sh");
1088        fs::write(&askpass_script_path, askpass_script).await?;
1089        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1090
1091        // Start the master SSH process, which does not do anything except for establish
1092        // the connection and keep it open, allowing other ssh commands to reuse it
1093        // via a control socket.
1094        let socket_path = temp_dir.path().join("ssh.sock");
1095        let mut master_process = process::Command::new("ssh")
1096            .stdin(Stdio::null())
1097            .stdout(Stdio::piped())
1098            .stderr(Stdio::piped())
1099            .env("SSH_ASKPASS_REQUIRE", "force")
1100            .env("SSH_ASKPASS", &askpass_script_path)
1101            .args(["-N", "-o", "ControlMaster=yes", "-o"])
1102            .arg(format!("ControlPath={}", socket_path.display()))
1103            .arg(&url)
1104            .spawn()?;
1105
1106        // Wait for this ssh process to close its stdout, indicating that authentication
1107        // has completed.
1108        let stdout = master_process.stdout.as_mut().unwrap();
1109        let mut output = Vec::new();
1110        let connection_timeout = Duration::from_secs(10);
1111
1112        let result = select_biased! {
1113            _ = askpass_opened_rx.fuse() => {
1114                // If the askpass script has opened, that means the user is typing
1115                // their password, in which case we don't want to timeout anymore,
1116                // since we know a connection has been established.
1117                stdout.read_to_end(&mut output).await?;
1118                Ok(())
1119            }
1120            result = stdout.read_to_end(&mut output).fuse() => {
1121                result?;
1122                Ok(())
1123            }
1124            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1125                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1126            }
1127        };
1128
1129        if let Err(e) = result {
1130            let error_message = format!("Failed to connect to host: {}.", e);
1131            delegate.set_error(error_message, cx);
1132            return Err(e);
1133        }
1134
1135        drop(askpass_task);
1136
1137        if master_process.try_status()?.is_some() {
1138            output.clear();
1139            let mut stderr = master_process.stderr.take().unwrap();
1140            stderr.read_to_end(&mut output).await?;
1141
1142            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1143            delegate.set_error(error_message.clone(), cx);
1144            Err(anyhow!(error_message))?;
1145        }
1146
1147        Ok(Self {
1148            socket: SshSocket {
1149                connection_options,
1150                socket_path,
1151            },
1152            master_process,
1153            _temp_dir: temp_dir,
1154        })
1155    }
1156
1157    async fn ensure_server_binary(
1158        &self,
1159        delegate: &Arc<dyn SshClientDelegate>,
1160        src_path: &Path,
1161        dst_path: &Path,
1162        version: SemanticVersion,
1163        cx: &mut AsyncAppContext,
1164    ) -> Result<()> {
1165        let mut dst_path_gz = dst_path.to_path_buf();
1166        dst_path_gz.set_extension("gz");
1167
1168        if let Some(parent) = dst_path.parent() {
1169            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1170        }
1171
1172        let mut server_binary_exists = false;
1173        if cfg!(not(debug_assertions)) {
1174            if let Ok(installed_version) =
1175                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1176            {
1177                if installed_version.trim() == version.to_string() {
1178                    server_binary_exists = true;
1179                }
1180            }
1181        }
1182
1183        if server_binary_exists {
1184            log::info!("remote development server already present",);
1185            return Ok(());
1186        }
1187
1188        let src_stat = fs::metadata(src_path).await?;
1189        let size = src_stat.len();
1190        let server_mode = 0o755;
1191
1192        let t0 = Instant::now();
1193        delegate.set_status(Some("uploading remote development server"), cx);
1194        log::info!("uploading remote development server ({}kb)", size / 1024);
1195        self.upload_file(src_path, &dst_path_gz)
1196            .await
1197            .context("failed to upload server binary")?;
1198        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1199
1200        delegate.set_status(Some("extracting remote development server"), cx);
1201        run_cmd(
1202            self.socket
1203                .ssh_command("gunzip")
1204                .arg("--force")
1205                .arg(&dst_path_gz),
1206        )
1207        .await?;
1208
1209        delegate.set_status(Some("unzipping remote development server"), cx);
1210        run_cmd(
1211            self.socket
1212                .ssh_command("chmod")
1213                .arg(format!("{:o}", server_mode))
1214                .arg(dst_path),
1215        )
1216        .await?;
1217
1218        Ok(())
1219    }
1220
1221    async fn query_platform(&self) -> Result<SshPlatform> {
1222        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1223        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1224
1225        let os = match os.trim() {
1226            "Darwin" => "macos",
1227            "Linux" => "linux",
1228            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1229        };
1230        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1231            "aarch64"
1232        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1233            "x86_64"
1234        } else {
1235            Err(anyhow!("unknown uname architecture {arch:?}"))?
1236        };
1237
1238        Ok(SshPlatform { os, arch })
1239    }
1240
1241    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1242        let mut command = process::Command::new("scp");
1243        let output = self
1244            .socket
1245            .ssh_options(&mut command)
1246            .args(
1247                self.socket
1248                    .connection_options
1249                    .port
1250                    .map(|port| vec!["-P".to_string(), port.to_string()])
1251                    .unwrap_or_default(),
1252            )
1253            .arg(src_path)
1254            .arg(format!(
1255                "{}:{}",
1256                self.socket.connection_options.scp_url(),
1257                dest_path.display()
1258            ))
1259            .output()
1260            .await?;
1261
1262        if output.status.success() {
1263            Ok(())
1264        } else {
1265            Err(anyhow!(
1266                "failed to upload file {} -> {}: {}",
1267                src_path.display(),
1268                dest_path.display(),
1269                String::from_utf8_lossy(&output.stderr)
1270            ))
1271        }
1272    }
1273}
1274
/// Pending requests, keyed by the id of the request envelope.
///
/// The incoming-message loop sends the response envelope to the request's
/// sender, paired with a `oneshot::Sender<()>`. The loop waits until the
/// requester has dropped (or fired) that second sender before it reads the
/// next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1276
/// A protobuf RPC client over a pair of unbounded mpsc channels.
///
/// Outgoing envelopes are written to `outgoing_tx`. Incoming envelopes are read
/// from the receiver passed to `ChannelClient::new`.
pub struct ChannelClient {
    /// Next envelope id. Requests and one-way sends share this counter.
    next_message_id: AtomicU32,
    /// Destination of every outgoing envelope.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests awaiting a response (mutex-guarded).
    response_channels: ResponseChannels,
    /// Registered message handlers and entity subscriptions (mutex-guarded).
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1283
1284impl ChannelClient {
1285    pub fn new(
1286        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1287        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1288        cx: &AppContext,
1289    ) -> Arc<Self> {
1290        let this = Arc::new(Self {
1291            outgoing_tx,
1292            next_message_id: AtomicU32::new(0),
1293            response_channels: ResponseChannels::default(),
1294            message_handlers: Default::default(),
1295        });
1296
1297        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1298
1299        this
1300    }
1301
1302    fn start_handling_messages(
1303        this: Arc<Self>,
1304        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1305        cx: &AppContext,
1306    ) {
1307        cx.spawn(|cx| {
1308            let this = Arc::downgrade(&this);
1309            async move {
1310                let peer_id = PeerId { owner_id: 0, id: 0 };
1311                while let Some(incoming) = incoming_rx.next().await {
1312                    let Some(this) = this.upgrade() else {
1313                        return anyhow::Ok(());
1314                    };
1315
1316                    if let Some(request_id) = incoming.responding_to {
1317                        let request_id = MessageId(request_id);
1318                        let sender = this.response_channels.lock().remove(&request_id);
1319                        if let Some(sender) = sender {
1320                            let (tx, rx) = oneshot::channel();
1321                            if incoming.payload.is_some() {
1322                                sender.send((incoming, tx)).ok();
1323                            }
1324                            rx.await.ok();
1325                        }
1326                    } else if let Some(envelope) =
1327                        build_typed_envelope(peer_id, Instant::now(), incoming)
1328                    {
1329                        let type_name = envelope.payload_type_name();
1330                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1331                            &this.message_handlers,
1332                            envelope,
1333                            this.clone().into(),
1334                            cx.clone(),
1335                        ) {
1336                            log::debug!("ssh message received. name:{type_name}");
1337                            match future.await {
1338                                Ok(_) => {
1339                                    log::debug!("ssh message handled. name:{type_name}");
1340                                }
1341                                Err(error) => {
1342                                    log::error!(
1343                                        "error handling message. type:{type_name}, error:{error}",
1344                                    );
1345                                }
1346                            }
1347                        } else {
1348                            log::error!("unhandled ssh message name:{type_name}");
1349                        }
1350                    }
1351                }
1352                anyhow::Ok(())
1353            }
1354        })
1355        .detach();
1356    }
1357
1358    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1359        let id = (TypeId::of::<E>(), remote_id);
1360
1361        let mut message_handlers = self.message_handlers.lock();
1362        if message_handlers
1363            .entities_by_type_and_remote_id
1364            .contains_key(&id)
1365        {
1366            panic!("already subscribed to entity");
1367        }
1368
1369        message_handlers.entities_by_type_and_remote_id.insert(
1370            id,
1371            EntityMessageSubscriber::Entity {
1372                handle: entity.downgrade().into(),
1373            },
1374        );
1375    }
1376
1377    pub fn request<T: RequestMessage>(
1378        &self,
1379        payload: T,
1380    ) -> impl 'static + Future<Output = Result<T::Response>> {
1381        log::debug!("ssh request start. name:{}", T::NAME);
1382        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1383        async move {
1384            let response = response.await?;
1385            log::debug!("ssh request finish. name:{}", T::NAME);
1386            T::Response::from_envelope(response)
1387                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1388        }
1389    }
1390
1391    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1392        smol::future::or(
1393            async {
1394                self.request(proto::Ping {}).await?;
1395                Ok(())
1396            },
1397            async {
1398                smol::Timer::after(timeout).await;
1399                Err(anyhow!("Timeout detected"))
1400            },
1401        )
1402        .await
1403    }
1404
1405    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1406        log::debug!("ssh send name:{}", T::NAME);
1407        self.send_dynamic(payload.into_envelope(0, None, None))
1408    }
1409
1410    pub fn request_dynamic(
1411        &self,
1412        mut envelope: proto::Envelope,
1413        type_name: &'static str,
1414    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1415        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1416        let (tx, rx) = oneshot::channel();
1417        let mut response_channels_lock = self.response_channels.lock();
1418        response_channels_lock.insert(MessageId(envelope.id), tx);
1419        drop(response_channels_lock);
1420        let result = self.outgoing_tx.unbounded_send(envelope);
1421        async move {
1422            if let Err(error) = &result {
1423                log::error!("failed to send message: {}", error);
1424                return Err(anyhow!("failed to send message: {}", error));
1425            }
1426
1427            let response = rx.await.context("connection lost")?.0;
1428            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1429                return Err(RpcError::from_proto(error, type_name));
1430            }
1431            Ok(response)
1432        }
1433    }
1434
1435    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1436        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1437        self.outgoing_tx.unbounded_send(envelope)?;
1438        Ok(())
1439    }
1440}
1441
1442impl ProtoClient for ChannelClient {
1443    fn request(
1444        &self,
1445        envelope: proto::Envelope,
1446        request_type: &'static str,
1447    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1448        self.request_dynamic(envelope, request_type).boxed()
1449    }
1450
1451    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1452        self.send_dynamic(envelope)
1453    }
1454
1455    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1456        self.send_dynamic(envelope)
1457    }
1458
1459    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1460        &self.message_handlers
1461    }
1462
1463    fn is_via_collab(&self) -> bool {
1464        false
1465    }
1466}