ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  17    StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
  21};
  22use parking_lot::Mutex;
  23use rpc::{
  24    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  25    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  26};
  27use smol::{
  28    fs,
  29    process::{self, Child, Stdio},
  30    Timer,
  31};
  32use std::{
  33    any::TypeId,
  34    ffi::OsStr,
  35    fmt,
  36    ops::ControlFlow,
  37    path::{Path, PathBuf},
  38    sync::{
  39        atomic::{AtomicU32, Ordering::SeqCst},
  40        Arc,
  41    },
  42    time::{Duration, Instant},
  43};
  44use tempfile::TempDir;
  45use util::ResultExt;
  46
/// Newtype around a `u64` that, judging by its name, identifies a project
/// hosted on an SSH remote.
///
/// NOTE(review): how these ids are allocated is not visible in this file — confirm.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  51
/// Handle for running `ssh` commands through an SSH control socket.
///
/// Every command built from it passes `-o ControlMaster=no` and
/// `-o ControlPath=<socket_path>`, so (per OpenSSH's ControlPath semantics)
/// it attaches to the master connection listening at `socket_path` instead
/// of becoming a master itself.
#[derive(Clone)]
pub struct SshSocket {
    /// Host/user/port used to build the `ssh://` URL for each command.
    connection_options: SshConnectionOptions,
    /// Path of the control socket passed as `ControlPath`.
    socket_path: PathBuf,
}
  57
  58#[derive(Debug, Default, Clone, PartialEq, Eq)]
  59pub struct SshConnectionOptions {
  60    pub host: String,
  61    pub username: Option<String>,
  62    pub port: Option<u16>,
  63    pub password: Option<String>,
  64}
  65
  66impl SshConnectionOptions {
  67    pub fn ssh_url(&self) -> String {
  68        let mut result = String::from("ssh://");
  69        if let Some(username) = &self.username {
  70            result.push_str(username);
  71            result.push('@');
  72        }
  73        result.push_str(&self.host);
  74        if let Some(port) = self.port {
  75            result.push(':');
  76            result.push_str(&port.to_string());
  77        }
  78        result
  79    }
  80
  81    fn scp_url(&self) -> String {
  82        if let Some(username) = &self.username {
  83            format!("{}@{}", username, self.host)
  84        } else {
  85            self.host.clone()
  86        }
  87    }
  88
  89    pub fn connection_string(&self) -> String {
  90        let host = if let Some(username) = &self.username {
  91            format!("{}@{}", username, self.host)
  92        } else {
  93            self.host.clone()
  94        };
  95        if let Some(port) = &self.port {
  96            format!("{}:{}", host, port)
  97        } else {
  98            host
  99        }
 100    }
 101
 102    // Uniquely identifies dev server projects on a remote host. Needs to be
 103    // stable for the same dev server project.
 104    pub fn dev_server_identifier(&self) -> String {
 105        let mut identifier = format!("dev-server-{:?}", self.host);
 106        if let Some(username) = self.username.as_ref() {
 107            identifier.push('-');
 108            identifier.push_str(&username);
 109        }
 110        identifier
 111    }
 112}
 113
/// Operating system and CPU architecture of a remote host, handed to
/// [`SshClientDelegate::get_server_binary`].
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// OS name of the remote host. NOTE(review): exact spelling (e.g. uname output) not visible here.
    pub os: &'static str,
    /// CPU architecture name of the remote host.
    pub arch: &'static str,
}
 119
/// Hooks through which the SSH client reaches back into the application:
/// requesting credentials, reporting status/errors, and locating the server
/// binary.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password (presumably from the user), showing `prompt`; the
    /// answer or an error is delivered on the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path of the server binary on the remote host.
    /// NOTE(review): inferred from the name; no caller is visible in this chunk.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Obtains a server binary for `platform`, resolving to a path and its
    /// version. NOTE(review): presumably a local path to upload — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports a status message; `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message. `SshRemoteClient::new` calls this when the
    /// initial ping fails.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
 135
 136impl SshSocket {
 137    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 138        let mut command = process::Command::new("ssh");
 139        self.ssh_options(&mut command)
 140            .arg(self.connection_options.ssh_url())
 141            .arg(program);
 142        command
 143    }
 144
 145    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 146        command
 147            .stdin(Stdio::piped())
 148            .stdout(Stdio::piped())
 149            .stderr(Stdio::piped())
 150            .args(["-o", "ControlMaster=no", "-o"])
 151            .arg(format!("ControlPath={}", self.socket_path.display()))
 152    }
 153
 154    fn ssh_args(&self) -> Vec<String> {
 155        vec![
 156            "-o".to_string(),
 157            "ControlMaster=no".to_string(),
 158            "-o".to_string(),
 159            format!("ControlPath={}", self.socket_path.display()),
 160            self.connection_options.ssh_url(),
 161        ]
 162    }
 163}
 164
 165async fn run_cmd(command: &mut process::Command) -> Result<String> {
 166    let output = command.output().await?;
 167    if output.status.success() {
 168        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 169    } else {
 170        Err(anyhow!(
 171            "failed to run command: {}",
 172            String::from_utf8_lossy(&output.stderr)
 173        ))
 174    }
 175}
 176
/// Relays envelopes between the long-lived client channels and a
/// per-connection pair of proxy channels.
///
/// On reconnect, [`ChannelForwarder::into_channels`] recovers the client
/// channels so a new forwarder can be attached to a new connection.
struct ChannelForwarder {
    /// Tells the forwarding task to stop.
    quit_tx: UnboundedSender<()>,
    /// Background relay task; resolves to the client channels it owned once it stops.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}
 181
impl ChannelForwarder {
    /// Spawns a background task that relays envelopes in both directions:
    ///
    /// - envelopes sent on the returned `UnboundedSender` (proxy incoming) are
    ///   forwarded to `incoming_tx`;
    /// - envelopes received from `outgoing_rx` are forwarded to the returned
    ///   `UnboundedReceiver` (proxy outgoing).
    ///
    /// The task stops when asked to quit (or when `quit_tx` is dropped), when
    /// either source channel ends, or when a send fails. It then returns
    /// `incoming_tx` and `outgoing_rx` so they can be reused.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // `select_biased!` polls branches top to bottom, so a quit
                // request wins over any envelopes that are also ready.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // Hand the client-side channel ends back to whoever awaits the task.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops the forwarding task and returns the `incoming_tx` / `outgoing_rx`
    /// channels it owned.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // A send error only means the task already exited; awaiting it below
        // still yields the channels.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
 238
/// Consecutive failed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Delay between heartbeat pings.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout passed to every ping, including the post-connect checks in
/// `SshRemoteClient::new` and `reconnect`.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts allowed before the state becomes `ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
 244
/// Internal lifecycle of an [`SshRemoteClient`]'s connection.
///
/// Variants that hold a connection own its `SshRemoteConnection`, delegate and
/// channel forwarder, plus (while connected) the multiplex and heartbeat tasks.
/// Replacing the state therefore drops those resources.
enum State {
    /// Initial state while `SshRemoteClient::new` sets up the first connection.
    Connecting,
    /// Connection established; no heartbeat is currently failing.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Still connected, but one or more consecutive heartbeats have failed.
    HeartbeatMissed {
        /// Consecutive missed heartbeats recorded so far.
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed. `reconnect` is invoked again from
    /// this state and gives up once `MAX_RECONNECT_ATTEMPTS` is exceeded.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        /// Why the attempt failed.
        error: anyhow::Error,
        /// Attempts made so far.
        attempts: usize,
    },
    /// All reconnect attempts have been used up.
    ReconnectExhausted,
    /// The proxy's exit code reported that the remote server is not running.
    ServerNotRunning,
}
 277
 278impl fmt::Display for State {
 279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 280        match self {
 281            Self::Connecting => write!(f, "connecting"),
 282            Self::Connected { .. } => write!(f, "connected"),
 283            Self::Reconnecting => write!(f, "reconnecting"),
 284            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 285            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 286            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 287            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 288        }
 289    }
 290}
 291
 292impl State {
 293    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 294        match self {
 295            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 296            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 297            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 298            _ => None,
 299        }
 300    }
 301
 302    fn can_reconnect(&self) -> bool {
 303        match self {
 304            Self::Connected { .. }
 305            | Self::HeartbeatMissed { .. }
 306            | Self::ReconnectFailed { .. } => true,
 307            State::Connecting
 308            | State::Reconnecting
 309            | State::ReconnectExhausted
 310            | State::ServerNotRunning => false,
 311        }
 312    }
 313
 314    fn is_reconnect_failed(&self) -> bool {
 315        matches!(self, Self::ReconnectFailed { .. })
 316    }
 317
 318    fn is_reconnecting(&self) -> bool {
 319        matches!(self, Self::Reconnecting { .. })
 320    }
 321
 322    fn heartbeat_recovered(self) -> Self {
 323        match self {
 324            Self::HeartbeatMissed {
 325                ssh_connection,
 326                delegate,
 327                forwarder,
 328                multiplex_task,
 329                heartbeat_task,
 330                ..
 331            } => Self::Connected {
 332                ssh_connection,
 333                delegate,
 334                forwarder,
 335                multiplex_task,
 336                heartbeat_task,
 337            },
 338            _ => self,
 339        }
 340    }
 341
 342    fn heartbeat_missed(self) -> Self {
 343        match self {
 344            Self::Connected {
 345                ssh_connection,
 346                delegate,
 347                forwarder,
 348                multiplex_task,
 349                heartbeat_task,
 350            } => Self::HeartbeatMissed {
 351                missed_heartbeats: 1,
 352                ssh_connection,
 353                delegate,
 354                forwarder,
 355                multiplex_task,
 356                heartbeat_task,
 357            },
 358            Self::HeartbeatMissed {
 359                missed_heartbeats,
 360                ssh_connection,
 361                delegate,
 362                forwarder,
 363                multiplex_task,
 364                heartbeat_task,
 365            } => Self::HeartbeatMissed {
 366                missed_heartbeats: missed_heartbeats + 1,
 367                ssh_connection,
 368                delegate,
 369                forwarder,
 370                multiplex_task,
 371                heartbeat_task,
 372            },
 373            _ => self,
 374        }
 375    }
 376}
 377
/// The state of the ssh connection, as exposed outside this module.
///
/// A coarse view of the internal `State`, produced by `From<&State>`.
#[derive(Clone, Copy, Debug)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// Connected, with no heartbeat currently failing.
    Connected,
    /// Connected, but recent heartbeats have failed.
    HeartbeatMissed,
    /// A reconnect is in flight, or the last reconnect attempt failed.
    Reconnecting,
    /// Reconnect attempts are exhausted, or the remote server is not running.
    Disconnected,
}
 387
 388impl From<&State> for ConnectionState {
 389    fn from(value: &State) -> Self {
 390        match value {
 391            State::Connecting => Self::Connecting,
 392            State::Connected { .. } => Self::Connected,
 393            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 394            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 395            State::ReconnectExhausted => Self::Disconnected,
 396            State::ServerNotRunning => Self::Disconnected,
 397        }
 398    }
 399}
 400
/// Client side of a connection to a remote server over SSH.
///
/// The `ChannelClient` is created once and never replaced; reconnects swap
/// out the connection held in `state` and re-attach the client's channels
/// through a new `ChannelForwarder`.
pub struct SshRemoteClient {
    /// Channel-based RPC client used for requests and heartbeat pings.
    client: Arc<ChannelClient>,
    /// Passed to `establish_connection` on connect and on every reconnect.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
 407
 408impl Drop for SshRemoteClient {
 409    fn drop(&mut self) {
 410        self.shutdown_processes();
 411    }
 412}
 413
 414impl SshRemoteClient {
    /// Connects to the host in `connection_options` and resolves to the client
    /// model once the connection has answered a ping.
    ///
    /// Sequence: create the client channels and `ChannelClient`; create the
    /// model in `State::Connecting` (registering an app-quit hook that calls
    /// `shutdown_processes`); wire a `ChannelForwarder`; establish the ssh
    /// connection; start the multiplex task; ping the server; then start the
    /// heartbeat task and store `State::Connected`.
    ///
    /// # Errors
    ///
    /// Fails if the model cannot be created or updated, if
    /// `establish_connection` fails, or if the initial ping (with
    /// `HEARTBEAT_TIMEOUT`) fails. In the ping case, the delegate's
    /// `set_error` is called first.
    pub fn new(
        unique_identifier: String,
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &AppContext,
    ) -> Task<Result<Model<Self>>> {
        cx.spawn(|mut cx| async move {
            // `outgoing_*` carries envelopes from the client towards the server
            // (ultimately to the proxy's stdin); `incoming_*` carries the reverse.
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();

            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
            let this = cx.new_model(|cx| {
                cx.on_app_quit(|this: &mut Self, _| {
                    this.shutdown_processes();
                    futures::future::ready(())
                })
                .detach();

                Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                }
            })?;

            // The forwarder owns the client-side channel ends so they can be
            // re-attached to a new connection on reconnect.
            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
                unique_identifier,
                false,
                connection_options,
                delegate.clone(),
                &mut cx,
            )
            .await?;

            let multiplex_task = Self::multiplex(
                this.downgrade(),
                ssh_proxy_process,
                proxy_incoming_tx,
                proxy_outgoing_rx,
                &mut cx,
            );

            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                log::error!("failed to establish connection: {}", error);
                delegate.set_error(error.to_string(), &mut cx);
                return Err(error);
            }

            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);

            // NOTE(review): this writes the state directly instead of going
            // through `set_state`, so the transition is neither logged nor
            // followed by `cx.notify()` — confirm that is intended.
            this.update(&mut cx, |this, _| {
                *this.state.lock() = Some(State::Connected {
                    ssh_connection,
                    delegate,
                    forwarder: proxy,
                    multiplex_task,
                    heartbeat_task,
                });
            })?;

            Ok(this)
        })
    }
 482
 483    fn shutdown_processes(&self) {
 484        let Some(state) = self.state.lock().take() else {
 485            return;
 486        };
 487        log::info!("shutting down ssh processes");
 488
 489        let State::Connected {
 490            multiplex_task,
 491            heartbeat_task,
 492            ..
 493        } = state
 494        else {
 495            return;
 496        };
 497        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 498        // child of master_process.
 499        drop(multiplex_task);
 500        // Now drop the rest of state, which kills master process.
 501        drop(heartbeat_task);
 502    }
 503
    /// Starts an asynchronous attempt to re-establish the ssh connection.
    ///
    /// Allowed only from `Connected`, `HeartbeatMissed` or `ReconnectFailed`.
    /// The current state is taken. Its multiplex and heartbeat tasks (if any)
    /// are dropped, and the attempt counter is incremented; it restarts from 0
    /// when coming from a live connection. If the counter exceeds
    /// `MAX_RECONNECT_ATTEMPTS`, the state becomes `ReconnectExhausted` and
    /// nothing is spawned.
    ///
    /// Otherwise the state becomes `Reconnecting` and a task is spawned. It
    /// kills the old master process, rebuilds the channel forwarder on the
    /// existing client channels, establishes a new connection, restarts
    /// multiplexing, and pings the server. A follow-up task installs the
    /// result, but only if the state is still `Reconnecting`. If the state is
    /// then `ReconnectFailed`, the follow-up calls `reconnect` again.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving the state untouched) if no state is set or
    /// the current state does not allow reconnecting. Failures inside the
    /// spawned attempt are captured as `State::ReconnectFailed` instead.
    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            return Err(anyhow!(error));
        }

        // Cannot panic: `can_reconnect` is only true when a state is set.
        let state = lock.take().unwrap();
        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                forwarder,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                forwarder,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Stop I/O and heartbeats for the old connection.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate, forwarder)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                forwarder,
                ..
            } => (attempts, ssh_connection, delegate, forwarder),
            // Excluded by the `can_reconnect` check above.
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // Release the guard first: `set_state` locks `self.state` itself.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        // Release the guard first: `set_state` locks `self.state` itself.
        drop(lock);
        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        let reconnect_task = cx.spawn(|this, mut cx| async move {
            // Early-returns a `ReconnectFailed` state carrying whichever
            // connection, delegate and forwarder are in scope at the call site,
            // so the next attempt can start from them.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                        forwarder: $forwarder,
                    };
                };
            }

            if let Err(error) = ssh_connection.master_process.kill() {
                failed!(error, attempts, ssh_connection, delegate, forwarder);
            };

            // Wait for the killed master process to actually exit.
            if let Err(error) = ssh_connection
                .master_process
                .status()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate, forwarder);
            }

            let connection_options = ssh_connection.socket.connection_options.clone();

            // Stop the old forwarder and re-attach the same client channels to
            // a new one, so the `ChannelClient` keeps working across reconnects.
            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

            let (ssh_connection, ssh_process) = match Self::establish_connection(
                identifier,
                true,
                connection_options,
                delegate.clone(),
                &mut cx,
            )
            .await
            {
                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
                Err(error) => {
                    // `ssh_connection` here is still the old, killed connection.
                    failed!(error, attempts, ssh_connection, delegate, forwarder);
                }
            };

            let multiplex_task = Self::multiplex(
                this.clone(),
                ssh_process,
                proxy_incoming_tx,
                proxy_outgoing_rx,
                &mut cx,
            );

            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate, forwarder);
            };

            State::Connected {
                ssh_connection,
                delegate,
                forwarder,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
            }
        });

        cx.spawn(|this, mut cx| async move {
            let new_state = reconnect_task.await;
            this.update(&mut cx, |this, cx| {
                // Install the attempt's result only if nothing else moved the
                // state away from `Reconnecting` while the attempt ran.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting { .. }
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                // NOTE(review): the debug message below is also logged after a
                // *successful* reconnect (state is `Connected`, not
                // `ReconnectFailed`), where "Ignoring new state" is misleading.
                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
 684
 685    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
 686        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 687            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 688        };
 689        cx.spawn(|mut cx| {
 690            let this = this.clone();
 691            async move {
 692                let mut missed_heartbeats = 0;
 693
 694                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
 695                loop {
 696                    timer.next().await;
 697
 698                    log::debug!("Sending heartbeat to server...");
 699
 700                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
 701                    if result.is_err() {
 702                        missed_heartbeats += 1;
 703                        log::warn!(
 704                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 705                            HEARTBEAT_TIMEOUT,
 706                            missed_heartbeats,
 707                            MAX_MISSED_HEARTBEATS
 708                        );
 709                    } else if missed_heartbeats != 0 {
 710                        missed_heartbeats = 0;
 711                    } else {
 712                        continue;
 713                    }
 714
 715                    let result = this.update(&mut cx, |this, mut cx| {
 716                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 717                    })?;
 718                    if result.is_break() {
 719                        return Ok(());
 720                    }
 721                }
 722            }
 723        })
 724    }
 725
 726    fn handle_heartbeat_result(
 727        &mut self,
 728        missed_heartbeats: usize,
 729        cx: &mut ModelContext<Self>,
 730    ) -> ControlFlow<()> {
 731        let state = self.state.lock().take().unwrap();
 732        let next_state = if missed_heartbeats > 0 {
 733            state.heartbeat_missed()
 734        } else {
 735            state.heartbeat_recovered()
 736        };
 737        self.set_state(next_state, cx);
 738
 739        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 740            log::error!(
 741                "Missed last {} heartbeats. Reconnecting...",
 742                missed_heartbeats
 743            );
 744
 745            self.reconnect(cx)
 746                .context("failed to start reconnect process after missing heartbeats")
 747                .log_err();
 748            ControlFlow::Break(())
 749        } else {
 750            ControlFlow::Continue(())
 751        }
 752    }
 753
    /// Pumps envelopes between the channels and the ssh proxy process's stdio.
    ///
    /// A background I/O task:
    /// - writes each envelope from `outgoing_rx` to the child's stdin;
    /// - reads length-prefixed messages from the child's stdout and sends the
    ///   decoded envelopes to `incoming_tx` (decode errors are logged and the
    ///   message skipped);
    /// - splits the child's stderr into lines. Lines that parse as `LogRecord`
    ///   JSON are re-logged with a `(remote)` prefix; any other line is
    ///   printed to this process's stderr.
    ///
    /// The I/O task ends with `Ok(None)` when `outgoing_rx` closes. When stdout
    /// reaches EOF, it ends with the child's exit code, or `None` if the child
    /// has no code. The returned task then:
    /// - sets `State::ServerNotRunning` for the matching
    ///   `ProxyLaunchError` exit code;
    /// - logs any other positive exit code;
    /// - starts a reconnect if the I/O task failed.
    fn multiplex(
        this: WeakModel<Self>,
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> Task<Result<()>> {
        // NOTE(review): assumes the proxy was spawned with all stdio piped;
        // these unwraps panic otherwise.
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let io_task = cx.background_executor().spawn(async move {
            let mut stdin_buffer = Vec::new();
            let mut stdout_buffer = Vec::new();
            let mut stderr_buffer = Vec::new();
            // Bytes at the front of `stderr_buffer` that do not yet form a complete line.
            let mut stderr_offset = 0;

            loop {
                // Each stdout read starts with exactly the fixed-size length
                // prefix; stderr always gets 1024 free bytes after pending data.
                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                stderr_buffer.resize(stderr_offset + 1024, 0);

                select_biased! {
                    outgoing = outgoing_rx.next().fuse() => {
                        let Some(outgoing) = outgoing else {
                            return anyhow::Ok(None);
                        };

                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
                    }

                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
                        match result {
                            Ok(0) => {
                                // stdout EOF: close our side, then report how the proxy exited.
                                child_stdin.close().await?;
                                outgoing_rx.close();
                                let status = ssh_proxy_process.status().await?;
                                return Ok(status.code());
                            }
                            Ok(len) => {
                                // Complete a short read of the length prefix.
                                if len < stdout_buffer.len() {
                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                                }

                                let message_len = message_len_from_buffer(&stdout_buffer);
                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
                                    Ok(envelope) => {
                                        incoming_tx.unbounded_send(envelope).ok();
                                    }
                                    Err(error) => {
                                        log::error!("error decoding message {error:?}");
                                    }
                                }
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stdout: {error:?}"))?;
                            }
                        }
                    }

                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
                        match result {
                            // NOTE(review): stderr EOF (`Ok(0)`) also lands in this arm. Nothing is
                            // consumed and the next iteration polls stderr again, which returns EOF
                            // immediately, so the loop spins until stdout hits EOF or `outgoing_rx`
                            // closes. Presumably short-lived because both streams close on exit.
                            Ok(len) => {
                                stderr_offset += len;
                                let mut start_ix = 0;
                                // Emit every complete line; keep a trailing partial line for later.
                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
                                    let line_ix = start_ix + ix;
                                    let content = &stderr_buffer[start_ix..line_ix];
                                    start_ix = line_ix + 1;
                                    if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
                                        record.message = format!("(remote) {}", record.message);
                                        record.log(log::logger())
                                    } else {
                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                                    }
                                }
                                stderr_buffer.drain(0..start_ix);
                                stderr_offset -= start_ix;
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stderr: {error:?}"))?;
                            }
                        }
                    }
                }
            }
        });

        cx.spawn(|mut cx| async move {
            let result = io_task.await;

            match result {
                Ok(Some(exit_code)) => {
                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
                        match error {
                            ProxyLaunchError::ServerNotRunning => {
                                log::error!("failed to reconnect because server is not running");
                                this.update(&mut cx, |this, cx| {
                                    this.set_state(State::ServerNotRunning, cx);
                                })?;
                            }
                        }
                    } else if exit_code > 0 {
                        log::error!("proxy process terminated unexpectedly");
                    }
                }
                Ok(None) => {}
                Err(error) => {
                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
                    this.update(&mut cx, |this, cx| {
                        this.reconnect(cx).ok();
                    })?;
                }
            }
            Ok(())
        })
    }
 870
 871    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 872        self.state.lock().as_ref().map_or(false, check)
 873    }
 874
 875    fn try_set_state(
 876        &self,
 877        cx: &mut ModelContext<Self>,
 878        map: impl FnOnce(&State) -> Option<State>,
 879    ) {
 880        if let Some(new_state) = self.state.lock().as_ref().and_then(map) {
 881            self.set_state(new_state, cx);
 882        }
 883    }
 884
 885    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 886        log::info!("setting state to '{}'", &state);
 887        self.state.lock().replace(state);
 888        cx.notify();
 889    }
 890
 891    async fn establish_connection(
 892        unique_identifier: String,
 893        reconnect: bool,
 894        connection_options: SshConnectionOptions,
 895        delegate: Arc<dyn SshClientDelegate>,
 896        cx: &mut AsyncAppContext,
 897    ) -> Result<(SshRemoteConnection, Child)> {
 898        let ssh_connection =
 899            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 900
 901        let platform = ssh_connection.query_platform().await?;
 902        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 903        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 904        ssh_connection
 905            .ensure_server_binary(
 906                &delegate,
 907                &local_binary_path,
 908                &remote_binary_path,
 909                version,
 910                cx,
 911            )
 912            .await?;
 913
 914        let socket = ssh_connection.socket.clone();
 915        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 916
 917        delegate.set_status(Some("Starting proxy"), cx);
 918
 919        let mut start_proxy_command = format!(
 920            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 921            std::env::var("RUST_LOG").unwrap_or_default(),
 922            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 923            remote_binary_path,
 924            unique_identifier,
 925        );
 926        if reconnect {
 927            start_proxy_command.push_str(" --reconnect");
 928        }
 929
 930        let ssh_proxy_process = socket
 931            .ssh_command(start_proxy_command)
 932            // IMPORTANT: we kill this process when we drop the task that uses it.
 933            .kill_on_drop(true)
 934            .spawn()
 935            .context("failed to spawn remote server")?;
 936
 937        Ok((ssh_connection, ssh_proxy_process))
 938    }
 939
 940    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 941        self.client.subscribe_to_entity(remote_id, entity);
 942    }
 943
 944    pub fn ssh_args(&self) -> Option<Vec<String>> {
 945        self.state
 946            .lock()
 947            .as_ref()
 948            .and_then(|state| state.ssh_connection())
 949            .map(|ssh_connection| ssh_connection.socket.ssh_args())
 950    }
 951
 952    pub fn to_proto_client(&self) -> AnyProtoClient {
 953        self.client.clone().into()
 954    }
 955
 956    pub fn connection_string(&self) -> String {
 957        self.connection_options.connection_string()
 958    }
 959
 960    pub fn connection_state(&self) -> ConnectionState {
 961        self.state
 962            .lock()
 963            .as_ref()
 964            .map(ConnectionState::from)
 965            .unwrap_or(ConnectionState::Disconnected)
 966    }
 967
 968    #[cfg(any(test, feature = "test-support"))]
 969    pub fn fake(
 970        client_cx: &mut gpui::TestAppContext,
 971        server_cx: &mut gpui::TestAppContext,
 972    ) -> (Model<Self>, Arc<ChannelClient>) {
 973        use gpui::Context;
 974
 975        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
 976        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
 977
 978        (
 979            client_cx.update(|cx| {
 980                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
 981                cx.new_model(|_| Self {
 982                    client,
 983                    unique_identifier: "fake".to_string(),
 984                    connection_options: SshConnectionOptions::default(),
 985                    state: Arc::new(Mutex::new(None)),
 986                })
 987            }),
 988            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
 989        )
 990    }
 991}
 992
 993impl From<SshRemoteClient> for AnyProtoClient {
 994    fn from(client: SshRemoteClient) -> Self {
 995        AnyProtoClient::new(client.client.clone())
 996    }
 997}
 998
/// An authenticated SSH master connection plus the resources it depends on.
///
/// Field order matters. The `Drop` impl kills `master_process` first. Fields
/// are then dropped in declaration order, so the temporary directory (holding
/// the control socket and askpass files) is removed last.
struct SshRemoteConnection {
    /// Control socket through which other `ssh` commands reuse the master connection.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process that keeps the connection open.
    master_process: process::Child,
    /// Temp dir containing `ssh.sock`, `askpass.sock`, and `askpass.sh`; removed when dropped.
    _temp_dir: TempDir,
}
1004
1005impl Drop for SshRemoteConnection {
1006    fn drop(&mut self) {
1007        if let Err(error) = self.master_process.kill() {
1008            log::error!("failed to kill SSH master process: {}", error);
1009        }
1010    }
1011}
1012
1013impl SshRemoteConnection {
1014    #[cfg(not(unix))]
1015    async fn new(
1016        _connection_options: SshConnectionOptions,
1017        _delegate: Arc<dyn SshClientDelegate>,
1018        _cx: &mut AsyncAppContext,
1019    ) -> Result<Self> {
1020        Err(anyhow!("ssh is not supported on this platform"))
1021    }
1022
1023    #[cfg(unix)]
1024    async fn new(
1025        connection_options: SshConnectionOptions,
1026        delegate: Arc<dyn SshClientDelegate>,
1027        cx: &mut AsyncAppContext,
1028    ) -> Result<Self> {
1029        use futures::{io::BufReader, AsyncBufReadExt as _};
1030        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1031        use util::ResultExt as _;
1032
1033        delegate.set_status(Some("connecting"), cx);
1034
1035        let url = connection_options.ssh_url();
1036        let temp_dir = tempfile::Builder::new()
1037            .prefix("zed-ssh-session")
1038            .tempdir()?;
1039
1040        // Create a domain socket listener to handle requests from the askpass program.
1041        let askpass_socket = temp_dir.path().join("askpass.sock");
1042        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1043        let listener =
1044            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1045
1046        let askpass_task = cx.spawn({
1047            let delegate = delegate.clone();
1048            |mut cx| async move {
1049                let mut askpass_opened_tx = Some(askpass_opened_tx);
1050
1051                while let Ok((mut stream, _)) = listener.accept().await {
1052                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1053                        askpass_opened_tx.send(()).ok();
1054                    }
1055                    let mut buffer = Vec::new();
1056                    let mut reader = BufReader::new(&mut stream);
1057                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1058                        buffer.clear();
1059                    }
1060                    let password_prompt = String::from_utf8_lossy(&buffer);
1061                    if let Some(password) = delegate
1062                        .ask_password(password_prompt.to_string(), &mut cx)
1063                        .await
1064                        .context("failed to get ssh password")
1065                        .and_then(|p| p)
1066                        .log_err()
1067                    {
1068                        stream.write_all(password.as_bytes()).await.log_err();
1069                    }
1070                }
1071            }
1072        });
1073
1074        // Create an askpass script that communicates back to this process.
1075        let askpass_script = format!(
1076            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1077            askpass_socket = askpass_socket.display(),
1078            print_args = "printf '%s\\0' \"$@\"",
1079            shebang = "#!/bin/sh",
1080        );
1081        let askpass_script_path = temp_dir.path().join("askpass.sh");
1082        fs::write(&askpass_script_path, askpass_script).await?;
1083        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1084
1085        // Start the master SSH process, which does not do anything except for establish
1086        // the connection and keep it open, allowing other ssh commands to reuse it
1087        // via a control socket.
1088        let socket_path = temp_dir.path().join("ssh.sock");
1089        let mut master_process = process::Command::new("ssh")
1090            .stdin(Stdio::null())
1091            .stdout(Stdio::piped())
1092            .stderr(Stdio::piped())
1093            .env("SSH_ASKPASS_REQUIRE", "force")
1094            .env("SSH_ASKPASS", &askpass_script_path)
1095            .args(["-N", "-o", "ControlMaster=yes", "-o"])
1096            .arg(format!("ControlPath={}", socket_path.display()))
1097            .arg(&url)
1098            .spawn()?;
1099
1100        // Wait for this ssh process to close its stdout, indicating that authentication
1101        // has completed.
1102        let stdout = master_process.stdout.as_mut().unwrap();
1103        let mut output = Vec::new();
1104        let connection_timeout = Duration::from_secs(10);
1105
1106        let result = select_biased! {
1107            _ = askpass_opened_rx.fuse() => {
1108                // If the askpass script has opened, that means the user is typing
1109                // their password, in which case we don't want to timeout anymore,
1110                // since we know a connection has been established.
1111                stdout.read_to_end(&mut output).await?;
1112                Ok(())
1113            }
1114            result = stdout.read_to_end(&mut output).fuse() => {
1115                result?;
1116                Ok(())
1117            }
1118            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1119                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1120            }
1121        };
1122
1123        if let Err(e) = result {
1124            let error_message = format!("Failed to connect to host: {}.", e);
1125            delegate.set_error(error_message, cx);
1126            return Err(e);
1127        }
1128
1129        drop(askpass_task);
1130
1131        if master_process.try_status()?.is_some() {
1132            output.clear();
1133            let mut stderr = master_process.stderr.take().unwrap();
1134            stderr.read_to_end(&mut output).await?;
1135
1136            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1137            delegate.set_error(error_message.clone(), cx);
1138            Err(anyhow!(error_message))?;
1139        }
1140
1141        Ok(Self {
1142            socket: SshSocket {
1143                connection_options,
1144                socket_path,
1145            },
1146            master_process,
1147            _temp_dir: temp_dir,
1148        })
1149    }
1150
1151    async fn ensure_server_binary(
1152        &self,
1153        delegate: &Arc<dyn SshClientDelegate>,
1154        src_path: &Path,
1155        dst_path: &Path,
1156        version: SemanticVersion,
1157        cx: &mut AsyncAppContext,
1158    ) -> Result<()> {
1159        let mut dst_path_gz = dst_path.to_path_buf();
1160        dst_path_gz.set_extension("gz");
1161
1162        if let Some(parent) = dst_path.parent() {
1163            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1164        }
1165
1166        let mut server_binary_exists = false;
1167        if cfg!(not(debug_assertions)) {
1168            if let Ok(installed_version) =
1169                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1170            {
1171                if installed_version.trim() == version.to_string() {
1172                    server_binary_exists = true;
1173                }
1174            }
1175        }
1176
1177        if server_binary_exists {
1178            log::info!("remote development server already present",);
1179            return Ok(());
1180        }
1181
1182        let src_stat = fs::metadata(src_path).await?;
1183        let size = src_stat.len();
1184        let server_mode = 0o755;
1185
1186        let t0 = Instant::now();
1187        delegate.set_status(Some("uploading remote development server"), cx);
1188        log::info!("uploading remote development server ({}kb)", size / 1024);
1189        self.upload_file(src_path, &dst_path_gz)
1190            .await
1191            .context("failed to upload server binary")?;
1192        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1193
1194        delegate.set_status(Some("extracting remote development server"), cx);
1195        run_cmd(
1196            self.socket
1197                .ssh_command("gunzip")
1198                .arg("--force")
1199                .arg(&dst_path_gz),
1200        )
1201        .await?;
1202
1203        delegate.set_status(Some("unzipping remote development server"), cx);
1204        run_cmd(
1205            self.socket
1206                .ssh_command("chmod")
1207                .arg(format!("{:o}", server_mode))
1208                .arg(dst_path),
1209        )
1210        .await?;
1211
1212        Ok(())
1213    }
1214
1215    async fn query_platform(&self) -> Result<SshPlatform> {
1216        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1217        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1218
1219        let os = match os.trim() {
1220            "Darwin" => "macos",
1221            "Linux" => "linux",
1222            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1223        };
1224        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1225            "aarch64"
1226        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1227            "x86_64"
1228        } else {
1229            Err(anyhow!("unknown uname architecture {arch:?}"))?
1230        };
1231
1232        Ok(SshPlatform { os, arch })
1233    }
1234
1235    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1236        let mut command = process::Command::new("scp");
1237        let output = self
1238            .socket
1239            .ssh_options(&mut command)
1240            .args(
1241                self.socket
1242                    .connection_options
1243                    .port
1244                    .map(|port| vec!["-P".to_string(), port.to_string()])
1245                    .unwrap_or_default(),
1246            )
1247            .arg(src_path)
1248            .arg(format!(
1249                "{}:{}",
1250                self.socket.connection_options.scp_url(),
1251                dest_path.display()
1252            ))
1253            .output()
1254            .await?;
1255
1256        if output.status.success() {
1257            Ok(())
1258        } else {
1259            Err(anyhow!(
1260                "failed to upload file {} -> {}: {}",
1261                src_path.display(),
1262                dest_path.display(),
1263                String::from_utf8_lossy(&output.stderr)
1264            ))
1265        }
1266    }
1267}
1268
/// Pending requests awaiting a response, keyed by the request's message id.
///
/// Each sender delivers the response envelope together with a
/// `oneshot::Sender<()>`. The incoming-message loop waits on the paired
/// receiver, so it resumes once the requester drops (or signals) that sender.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1270
/// A protobuf RPC client that exchanges [`Envelope`]s over a pair of unbounded
/// channels.
///
/// Whatever sits on the other side of those channels is set up by the caller.
pub struct ChannelClient {
    /// Source of the ids assigned to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Channel on which outgoing envelopes are queued.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Responders for in-flight requests (mutex-guarded).
    response_channels: ResponseChannels,
    /// Handlers for incoming non-response messages (mutex-guarded).
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1277
1278impl ChannelClient {
1279    pub fn new(
1280        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1281        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1282        cx: &AppContext,
1283    ) -> Arc<Self> {
1284        let this = Arc::new(Self {
1285            outgoing_tx,
1286            next_message_id: AtomicU32::new(0),
1287            response_channels: ResponseChannels::default(),
1288            message_handlers: Default::default(),
1289        });
1290
1291        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1292
1293        this
1294    }
1295
1296    fn start_handling_messages(
1297        this: Arc<Self>,
1298        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1299        cx: &AppContext,
1300    ) {
1301        cx.spawn(|cx| {
1302            let this = Arc::downgrade(&this);
1303            async move {
1304                let peer_id = PeerId { owner_id: 0, id: 0 };
1305                while let Some(incoming) = incoming_rx.next().await {
1306                    let Some(this) = this.upgrade() else {
1307                        return anyhow::Ok(());
1308                    };
1309
1310                    if let Some(request_id) = incoming.responding_to {
1311                        let request_id = MessageId(request_id);
1312                        let sender = this.response_channels.lock().remove(&request_id);
1313                        if let Some(sender) = sender {
1314                            let (tx, rx) = oneshot::channel();
1315                            if incoming.payload.is_some() {
1316                                sender.send((incoming, tx)).ok();
1317                            }
1318                            rx.await.ok();
1319                        }
1320                    } else if let Some(envelope) =
1321                        build_typed_envelope(peer_id, Instant::now(), incoming)
1322                    {
1323                        let type_name = envelope.payload_type_name();
1324                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1325                            &this.message_handlers,
1326                            envelope,
1327                            this.clone().into(),
1328                            cx.clone(),
1329                        ) {
1330                            log::debug!("ssh message received. name:{type_name}");
1331                            match future.await {
1332                                Ok(_) => {
1333                                    log::debug!("ssh message handled. name:{type_name}");
1334                                }
1335                                Err(error) => {
1336                                    log::error!(
1337                                        "error handling message. type:{type_name}, error:{error}",
1338                                    );
1339                                }
1340                            }
1341                        } else {
1342                            log::error!("unhandled ssh message name:{type_name}");
1343                        }
1344                    }
1345                }
1346                anyhow::Ok(())
1347            }
1348        })
1349        .detach();
1350    }
1351
1352    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1353        let id = (TypeId::of::<E>(), remote_id);
1354
1355        let mut message_handlers = self.message_handlers.lock();
1356        if message_handlers
1357            .entities_by_type_and_remote_id
1358            .contains_key(&id)
1359        {
1360            panic!("already subscribed to entity");
1361        }
1362
1363        message_handlers.entities_by_type_and_remote_id.insert(
1364            id,
1365            EntityMessageSubscriber::Entity {
1366                handle: entity.downgrade().into(),
1367            },
1368        );
1369    }
1370
1371    pub fn request<T: RequestMessage>(
1372        &self,
1373        payload: T,
1374    ) -> impl 'static + Future<Output = Result<T::Response>> {
1375        log::debug!("ssh request start. name:{}", T::NAME);
1376        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1377        async move {
1378            let response = response.await?;
1379            log::debug!("ssh request finish. name:{}", T::NAME);
1380            T::Response::from_envelope(response)
1381                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1382        }
1383    }
1384
1385    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1386        smol::future::or(
1387            async {
1388                self.request(proto::Ping {}).await?;
1389                Ok(())
1390            },
1391            async {
1392                smol::Timer::after(timeout).await;
1393                Err(anyhow!("Timeout detected"))
1394            },
1395        )
1396        .await
1397    }
1398
1399    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1400        log::debug!("ssh send name:{}", T::NAME);
1401        self.send_dynamic(payload.into_envelope(0, None, None))
1402    }
1403
1404    pub fn request_dynamic(
1405        &self,
1406        mut envelope: proto::Envelope,
1407        type_name: &'static str,
1408    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1409        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1410        let (tx, rx) = oneshot::channel();
1411        let mut response_channels_lock = self.response_channels.lock();
1412        response_channels_lock.insert(MessageId(envelope.id), tx);
1413        drop(response_channels_lock);
1414        let result = self.outgoing_tx.unbounded_send(envelope);
1415        async move {
1416            if let Err(error) = &result {
1417                log::error!("failed to send message: {}", error);
1418                return Err(anyhow!("failed to send message: {}", error));
1419            }
1420
1421            let response = rx.await.context("connection lost")?.0;
1422            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1423                return Err(RpcError::from_proto(error, type_name));
1424            }
1425            Ok(response)
1426        }
1427    }
1428
1429    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1430        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1431        self.outgoing_tx.unbounded_send(envelope)?;
1432        Ok(())
1433    }
1434}
1435
1436impl ProtoClient for ChannelClient {
1437    fn request(
1438        &self,
1439        envelope: proto::Envelope,
1440        request_type: &'static str,
1441    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1442        self.request_dynamic(envelope, request_type).boxed()
1443    }
1444
1445    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1446        self.send_dynamic(envelope)
1447    }
1448
1449    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1450        self.send_dynamic(envelope)
1451    }
1452
1453    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1454        &self.message_handlers
1455    }
1456
1457    fn is_via_collab(&self) -> bool {
1458        false
1459    }
1460}