ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    channel::{
  13        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  14        oneshot,
  15    },
  16    future::{BoxFuture, Shared},
  17    select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
  21    ModelContext, SemanticVersion, Task, WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  27    RpcError,
  28};
  29use smol::{
  30    fs,
  31    process::{self, Child, Stdio},
  32};
  33use std::{
  34    any::TypeId,
  35    collections::VecDeque,
  36    ffi::OsStr,
  37    fmt,
  38    ops::ControlFlow,
  39    path::{Path, PathBuf},
  40    sync::{
  41        atomic::{AtomicU32, Ordering::SeqCst},
  42        Arc, Weak,
  43    },
  44    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
  45};
  46use tempfile::TempDir;
  47use util::ResultExt;
  48
/// Identifier of a project opened over SSH: a `u64` newtype that is
/// ordered, hashable, copyable and (de)serializable via serde.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
  53
/// Connection options for a host paired with the path of the local ssh control
/// socket used for that host (passed to `ssh` as `-o ControlPath=...`).
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    /// Value used for the `ControlPath` option of every spawned `ssh` command.
    socket_path: PathBuf,
}
  59
/// Everything needed to address a remote host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Remote host name or address.
    pub host: String,
    /// Login user; when `None`, no `user@` part is emitted in URLs.
    pub username: Option<String>,
    /// Remote port; when `None`, no `:port` part is emitted in URLs.
    pub port: Option<u16>,
    /// Password, if known up front. `parse_command_line` always leaves this `None`.
    pub password: Option<String>,
    /// Extra `ssh` flags (and their values) accepted by `parse_command_line`.
    pub args: Option<Vec<String>>,
}
  68
  69impl SshConnectionOptions {
  70    pub fn parse_command_line(input: &str) -> Result<Self> {
  71        let input = input.trim_start_matches("ssh ");
  72        let mut hostname: Option<String> = None;
  73        let mut username: Option<String> = None;
  74        let mut port: Option<u16> = None;
  75        let mut args = Vec::new();
  76
  77        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  78        const ALLOWED_OPTS: &[&str] = &[
  79            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  80        ];
  81        const ALLOWED_ARGS: &[&str] = &[
  82            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
  83            "-w",
  84        ];
  85
  86        let mut tokens = shlex::split(input)
  87            .ok_or_else(|| anyhow!("invalid input"))?
  88            .into_iter();
  89
  90        'outer: while let Some(arg) = tokens.next() {
  91            if ALLOWED_OPTS.contains(&(&arg as &str)) {
  92                args.push(arg.to_string());
  93                continue;
  94            }
  95            if arg == "-p" {
  96                port = tokens.next().and_then(|arg| arg.parse().ok());
  97                continue;
  98            } else if let Some(p) = arg.strip_prefix("-p") {
  99                port = p.parse().ok();
 100                continue;
 101            }
 102            if arg == "-l" {
 103                username = tokens.next();
 104                continue;
 105            } else if let Some(l) = arg.strip_prefix("-l") {
 106                username = Some(l.to_string());
 107                continue;
 108            }
 109            for a in ALLOWED_ARGS {
 110                if arg == *a {
 111                    args.push(arg);
 112                    if let Some(next) = tokens.next() {
 113                        args.push(next);
 114                    }
 115                    continue 'outer;
 116                } else if arg.starts_with(a) {
 117                    args.push(arg);
 118                    continue 'outer;
 119                }
 120            }
 121            if arg.starts_with("-") || hostname.is_some() {
 122                anyhow::bail!("unsupported argument: {:?}", arg);
 123            }
 124            let mut input = &arg as &str;
 125            if let Some((u, rest)) = input.split_once('@') {
 126                input = rest;
 127                username = Some(u.to_string());
 128            }
 129            if let Some((rest, p)) = input.split_once(':') {
 130                input = rest;
 131                port = p.parse().ok()
 132            }
 133            hostname = Some(input.to_string())
 134        }
 135
 136        let Some(hostname) = hostname else {
 137            anyhow::bail!("missing hostname");
 138        };
 139
 140        Ok(Self {
 141            host: hostname.to_string(),
 142            username: username.clone(),
 143            port,
 144            password: None,
 145            args: Some(args),
 146        })
 147    }
 148
 149    pub fn ssh_url(&self) -> String {
 150        let mut result = String::from("ssh://");
 151        if let Some(username) = &self.username {
 152            result.push_str(username);
 153            result.push('@');
 154        }
 155        result.push_str(&self.host);
 156        if let Some(port) = self.port {
 157            result.push(':');
 158            result.push_str(&port.to_string());
 159        }
 160        result
 161    }
 162
 163    pub fn additional_args(&self) -> Option<&Vec<String>> {
 164        self.args.as_ref()
 165    }
 166
 167    fn scp_url(&self) -> String {
 168        if let Some(username) = &self.username {
 169            format!("{}@{}", username, self.host)
 170        } else {
 171            self.host.clone()
 172        }
 173    }
 174
 175    pub fn connection_string(&self) -> String {
 176        let host = if let Some(username) = &self.username {
 177            format!("{}@{}", username, self.host)
 178        } else {
 179            self.host.clone()
 180        };
 181        if let Some(port) = &self.port {
 182            format!("{}:{}", host, port)
 183        } else {
 184            host
 185        }
 186    }
 187
 188    // Uniquely identifies dev server projects on a remote host. Needs to be
 189    // stable for the same dev server project.
 190    pub fn remote_server_identifier(&self) -> String {
 191        let mut identifier = format!("dev-server-{:?}", self.host);
 192        if let Some(username) = self.username.as_ref() {
 193            identifier.push('-');
 194            identifier.push_str(&username);
 195        }
 196        identifier
 197    }
 198}
 199
 200#[derive(Copy, Clone, Debug)]
 201pub struct SshPlatform {
 202    pub os: &'static str,
 203    pub arch: &'static str,
 204}
 205
 206impl SshPlatform {
 207    pub fn triple(&self) -> Option<String> {
 208        Some(format!(
 209            "{}-{}",
 210            self.arch,
 211            match self.os {
 212                "linux" => "unknown-linux-gnu",
 213                "macos" => "apple-darwin",
 214                _ => return None,
 215            }
 216        ))
 217    }
 218}
 219
/// Source of the remote server binary.
pub enum ServerBinary {
    /// A binary already present on the local filesystem.
    LocalBinary(PathBuf),
    /// A binary to fetch from `url`; `body` presumably carries the request payload
    /// for that download — TODO confirm against the code that consumes it.
    ReleaseUrl { url: String, body: String },
}
 224
/// Hooks the SSH client calls into while connecting; implementors must be `Send + Sync`.
pub trait SshClientDelegate: Send + Sync {
    /// Returns a receiver that resolves to the password (or an error) for `prompt`;
    /// presumably this prompts the user — confirm with implementors.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path at which the server binary for `platform` is expected.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Returns a receiver resolving to where the server binary for `platform` comes from,
    /// together with its version.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>>;
    /// Reports a status message, or `None` to clear it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
}
 243
 244impl SshSocket {
 245    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 246        let mut command = process::Command::new("ssh");
 247        self.ssh_options(&mut command)
 248            .arg(self.connection_options.ssh_url())
 249            .arg(program);
 250        command
 251    }
 252
 253    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 254        command
 255            .stdin(Stdio::piped())
 256            .stdout(Stdio::piped())
 257            .stderr(Stdio::piped())
 258            .args(["-o", "ControlMaster=no", "-o"])
 259            .arg(format!("ControlPath={}", self.socket_path.display()))
 260    }
 261
 262    fn ssh_args(&self) -> Vec<String> {
 263        vec![
 264            "-o".to_string(),
 265            "ControlMaster=no".to_string(),
 266            "-o".to_string(),
 267            format!("ControlPath={}", self.socket_path.display()),
 268            self.connection_options.ssh_url(),
 269        ]
 270    }
 271}
 272
 273async fn run_cmd(command: &mut process::Command) -> Result<String> {
 274    let output = command.output().await?;
 275    if output.status.success() {
 276        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 277    } else {
 278        Err(anyhow!(
 279            "failed to run command: {}",
 280            String::from_utf8_lossy(&output.stderr)
 281        ))
 282    }
 283}
 284
 285const MAX_MISSED_HEARTBEATS: usize = 5;
 286const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 287const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 288
 289const MAX_RECONNECT_ATTEMPTS: usize = 3;
 290
/// Internal connection state machine of `SshRemoteClient`.
enum State {
    /// Initial state while `SshRemoteClient::new` is still establishing the connection.
    Connecting,
    /// A live connection together with its background tasks.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Still connected, but the last `missed_heartbeats` heartbeats went unanswered.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; keeps what is needed for the next attempt and the
    /// number of attempts made so far.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` reconnects were attempted; no further retries.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
 320
 321impl fmt::Display for State {
 322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 323        match self {
 324            Self::Connecting => write!(f, "connecting"),
 325            Self::Connected { .. } => write!(f, "connected"),
 326            Self::Reconnecting => write!(f, "reconnecting"),
 327            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 328            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 329            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 330            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 331        }
 332    }
 333}
 334
 335impl State {
 336    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 337        match self {
 338            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 339            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 340            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 341            _ => None,
 342        }
 343    }
 344
 345    fn can_reconnect(&self) -> bool {
 346        match self {
 347            Self::Connected { .. }
 348            | Self::HeartbeatMissed { .. }
 349            | Self::ReconnectFailed { .. } => true,
 350            State::Connecting
 351            | State::Reconnecting
 352            | State::ReconnectExhausted
 353            | State::ServerNotRunning => false,
 354        }
 355    }
 356
 357    fn is_reconnect_failed(&self) -> bool {
 358        matches!(self, Self::ReconnectFailed { .. })
 359    }
 360
 361    fn is_reconnect_exhausted(&self) -> bool {
 362        matches!(self, Self::ReconnectExhausted { .. })
 363    }
 364
 365    fn is_server_not_running(&self) -> bool {
 366        matches!(self, Self::ServerNotRunning)
 367    }
 368
 369    fn is_reconnecting(&self) -> bool {
 370        matches!(self, Self::Reconnecting { .. })
 371    }
 372
 373    fn heartbeat_recovered(self) -> Self {
 374        match self {
 375            Self::HeartbeatMissed {
 376                ssh_connection,
 377                delegate,
 378                multiplex_task,
 379                heartbeat_task,
 380                ..
 381            } => Self::Connected {
 382                ssh_connection,
 383                delegate,
 384                multiplex_task,
 385                heartbeat_task,
 386            },
 387            _ => self,
 388        }
 389    }
 390
 391    fn heartbeat_missed(self) -> Self {
 392        match self {
 393            Self::Connected {
 394                ssh_connection,
 395                delegate,
 396                multiplex_task,
 397                heartbeat_task,
 398            } => Self::HeartbeatMissed {
 399                missed_heartbeats: 1,
 400                ssh_connection,
 401                delegate,
 402                multiplex_task,
 403                heartbeat_task,
 404            },
 405            Self::HeartbeatMissed {
 406                missed_heartbeats,
 407                ssh_connection,
 408                delegate,
 409                multiplex_task,
 410                heartbeat_task,
 411            } => Self::HeartbeatMissed {
 412                missed_heartbeats: missed_heartbeats + 1,
 413                ssh_connection,
 414                delegate,
 415                multiplex_task,
 416                heartbeat_task,
 417            },
 418            _ => self,
 419        }
 420    }
 421}
 422
/// The state of the ssh connection.
///
/// A coarse, copyable view of the client's internal state, obtained via `From<&State>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    /// A reconnect attempt is in flight or the last attempt just failed.
    Reconnecting,
    /// Reconnect attempts were exhausted, or the remote server is not running.
    Disconnected,
}
 432
 433impl From<&State> for ConnectionState {
 434    fn from(value: &State) -> Self {
 435        match value {
 436            State::Connecting => Self::Connecting,
 437            State::Connected { .. } => Self::Connected,
 438            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 439            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 440            State::ReconnectExhausted => Self::Disconnected,
 441            State::ServerNotRunning => Self::Disconnected,
 442        }
 443    }
 444}
 445
/// Client side of a connection to a remote server reached over SSH.
pub struct SshRemoteClient {
    /// Message channel to the remote server; rewired on each reconnect.
    client: Arc<ChannelClient>,
    /// Passed to the proxy on every (re)connect.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
 452
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection is gone. NOTE(review): the emitting code is not visible in this
    /// section; presumably emitted when the state becomes disconnected — confirm.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 459
 460impl SshRemoteClient {
    /// Connects to `connection_options` and resolves to the new client model.
    ///
    /// The returned task resolves to:
    /// - `Ok(None)` if `cancellation` completes first. A futures oneshot receiver also
    ///   completes, with `Canceled`, when its sender is dropped.
    /// - `Ok(Some(model))` once the initial ping succeeds; the state is then `Connected`.
    /// - `Err` if any connection step fails.
    pub fn new(
        unique_identifier: String,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AppContext,
    ) -> Task<Result<Option<Model<Self>>>> {
        cx.spawn(|mut cx| async move {
            let success = Box::pin(async move {
                // Envelope channels between the `ChannelClient` and the proxy I/O task, plus a
                // bounded channel handed to the proxy whose receiver feeds the heartbeat task.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
                // The model exists (in `Connecting` state) before the connection is up, so the
                // background tasks below can hold a weak handle to it.
                let this = cx.new_model(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // Obtain the ssh connection through the global `ConnectionPool`.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options, &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;
                let remote_binary_path = ssh_connection
                    .get_remote_binary_path(&delegate, false, &mut cx)
                    .await?;

                let io_task = ssh_connection.start_proxy(
                    remote_binary_path,
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    &mut cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // The connection only counts as established once the server answers a ping.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task =
                    Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);

                this.update(&mut cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Whichever finishes first wins; losing the race drops the other future.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 536
    /// Takes the connection state and returns a future that shuts the ssh processes down.
    ///
    /// Returns `None` if the state was already taken, or if it was not `Connected`. In the
    /// latter case the state is still taken and dropped immediately. The future optionally
    /// sends `shutdown_request`, waits a fixed 50ms, and then drops the connection pieces
    /// in a deliberate order.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
    ) -> Option<impl Future<Output = ()>> {
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                smol::Timer::after(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
 574
    /// Tears down the current connection and schedules an attempt to establish a new one.
    ///
    /// Allowed only from `Connected`, `HeartbeatMissed` or `ReconnectFailed`. From any other
    /// state (or with no state) an error is returned and the state is left untouched.
    ///
    /// The attempt counter restarts from `Connected`/`HeartbeatMissed` and continues from
    /// `ReconnectFailed`. Once it exceeds `MAX_RECONNECT_ATTEMPTS`, the state becomes
    /// `ReconnectExhausted`.
    ///
    /// Otherwise the state becomes `Reconnecting` and a task:
    /// 1. kills the old ssh process,
    /// 2. connects through the `ConnectionPool`,
    /// 3. restarts the proxy, and
    /// 4. resyncs the `ChannelClient`.
    ///
    /// The task ends in `Connected` or `ReconnectFailed`, and the latter triggers another
    /// `reconnect`. Returns `Ok(())` once the attempt has been scheduled, or once attempts
    /// are exhausted.
    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            return Err(anyhow!(error));
        }

        let state = lock.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Stop the old connection's background tasks before reconnecting.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            // Excluded by the `can_reconnect` check above.
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // Release the lock before `set_state`, which touches the state again.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        drop(lock);

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        // Runs one reconnect attempt and evaluates to the state it ended in.
        let reconnect_task = cx.spawn(|this, mut cx| async move {
            // Ends the attempt early with `ReconnectFailed`, keeping what the next attempt needs.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            // Fresh channels for the new proxy; the `ChannelClient` is rewired to them below.
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                let remote_binary_path = ssh_connection
                    .get_remote_binary_path(&delegate, true, &mut cx)
                    .await?;

                let io_task = ssh_connection.start_proxy(
                    remote_binary_path,
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    &mut cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    // `ssh_connection` here is the previous (already killed) connection.
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
            client.reconnect(incoming_rx, outgoing_tx, &cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
            }
        });

        cx.spawn(|this, mut cx| async move {
            let new_state = reconnect_task.await;
            this.update(&mut cx, |this, cx| {
                // Only apply the attempt's result if nothing else moved the state away from
                // `Reconnecting` in the meantime.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting { .. }
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                // A failed attempt immediately schedules the next one.
                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
 760
 761    fn heartbeat(
 762        this: WeakModel<Self>,
 763        mut connection_activity_rx: mpsc::Receiver<()>,
 764        cx: &mut AsyncAppContext,
 765    ) -> Task<Result<()>> {
 766        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 767            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 768        };
 769
 770        cx.spawn(|mut cx| {
 771            let this = this.clone();
 772            async move {
 773                let mut missed_heartbeats = 0;
 774
 775                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 776                futures::pin_mut!(keepalive_timer);
 777
 778                loop {
 779                    select_biased! {
 780                        result = connection_activity_rx.next().fuse() => {
 781                            if result.is_none() {
 782                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 783                                return Ok(());
 784                            }
 785
 786                            if missed_heartbeats != 0 {
 787                                missed_heartbeats = 0;
 788                                this.update(&mut cx, |this, mut cx| {
 789                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 790                                })?;
 791                            }
 792                        }
 793                        _ = keepalive_timer => {
 794                            log::debug!("Sending heartbeat to server...");
 795
 796                            let result = select_biased! {
 797                                _ = connection_activity_rx.next().fuse() => {
 798                                    Ok(())
 799                                }
 800                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 801                                    ping_result
 802                                }
 803                            };
 804
 805                            if result.is_err() {
 806                                missed_heartbeats += 1;
 807                                log::warn!(
 808                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 809                                    HEARTBEAT_TIMEOUT,
 810                                    missed_heartbeats,
 811                                    MAX_MISSED_HEARTBEATS
 812                                );
 813                            } else if missed_heartbeats != 0 {
 814                                missed_heartbeats = 0;
 815                            } else {
 816                                continue;
 817                            }
 818
 819                            let result = this.update(&mut cx, |this, mut cx| {
 820                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 821                            })?;
 822                            if result.is_break() {
 823                                return Ok(());
 824                            }
 825                        }
 826                    }
 827
 828                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 829                }
 830            }
 831        })
 832    }
 833
 834    fn handle_heartbeat_result(
 835        &mut self,
 836        missed_heartbeats: usize,
 837        cx: &mut ModelContext<Self>,
 838    ) -> ControlFlow<()> {
 839        let state = self.state.lock().take().unwrap();
 840        let next_state = if missed_heartbeats > 0 {
 841            state.heartbeat_missed()
 842        } else {
 843            state.heartbeat_recovered()
 844        };
 845
 846        self.set_state(next_state, cx);
 847
 848        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 849            log::error!(
 850                "Missed last {} heartbeats. Reconnecting...",
 851                missed_heartbeats
 852            );
 853
 854            self.reconnect(cx)
 855                .context("failed to start reconnect process after missing heartbeats")
 856                .log_err();
 857            ControlFlow::Break(())
 858        } else {
 859            ControlFlow::Continue(())
 860        }
 861    }
 862
 863    fn monitor(
 864        this: WeakModel<Self>,
 865        io_task: Task<Result<i32>>,
 866        cx: &AsyncAppContext,
 867    ) -> Task<Result<()>> {
 868        cx.spawn(|mut cx| async move {
 869            let result = io_task.await;
 870
 871            match result {
 872                Ok(exit_code) => {
 873                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 874                        match error {
 875                            ProxyLaunchError::ServerNotRunning => {
 876                                log::error!("failed to reconnect because server is not running");
 877                                this.update(&mut cx, |this, cx| {
 878                                    this.set_state(State::ServerNotRunning, cx);
 879                                })?;
 880                            }
 881                        }
 882                    } else if exit_code > 0 {
 883                        log::error!("proxy process terminated unexpectedly");
 884                        this.update(&mut cx, |this, cx| {
 885                            this.reconnect(cx).ok();
 886                        })?;
 887                    }
 888                }
 889                Err(error) => {
 890                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 891                    this.update(&mut cx, |this, cx| {
 892                        this.reconnect(cx).ok();
 893                    })?;
 894                }
 895            }
 896
 897            Ok(())
 898        })
 899    }
 900
 901    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 902        self.state.lock().as_ref().map_or(false, check)
 903    }
 904
 905    fn try_set_state(
 906        &self,
 907        cx: &mut ModelContext<Self>,
 908        map: impl FnOnce(&State) -> Option<State>,
 909    ) {
 910        let mut lock = self.state.lock();
 911        let new_state = lock.as_ref().and_then(map);
 912
 913        if let Some(new_state) = new_state {
 914            lock.replace(new_state);
 915            cx.notify();
 916        }
 917    }
 918
 919    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 920        log::info!("setting state to '{}'", &state);
 921
 922        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 923        let is_server_not_running = state.is_server_not_running();
 924        self.state.lock().replace(state);
 925
 926        if is_reconnect_exhausted || is_server_not_running {
 927            cx.emit(SshRemoteEvent::Disconnected);
 928        }
 929        cx.notify();
 930    }
 931
 932    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 933        self.client.subscribe_to_entity(remote_id, entity);
 934    }
 935
 936    pub fn ssh_args(&self) -> Option<Vec<String>> {
 937        self.state
 938            .lock()
 939            .as_ref()
 940            .and_then(|state| state.ssh_connection())
 941            .map(|ssh_connection| ssh_connection.ssh_args())
 942    }
 943
 944    pub fn proto_client(&self) -> AnyProtoClient {
 945        self.client.clone().into()
 946    }
 947
 948    pub fn connection_string(&self) -> String {
 949        self.connection_options.connection_string()
 950    }
 951
 952    pub fn connection_options(&self) -> SshConnectionOptions {
 953        self.connection_options.clone()
 954    }
 955
 956    pub fn connection_state(&self) -> ConnectionState {
 957        self.state
 958            .lock()
 959            .as_ref()
 960            .map(ConnectionState::from)
 961            .unwrap_or(ConnectionState::Disconnected)
 962    }
 963
 964    pub fn is_disconnected(&self) -> bool {
 965        self.connection_state() == ConnectionState::Disconnected
 966    }
 967
 968    #[cfg(any(test, feature = "test-support"))]
 969    pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
 970        let opts = self.connection_options();
 971        client_cx.spawn(|cx| async move {
 972            let connection = cx
 973                .update_global(|c: &mut ConnectionPool, _| {
 974                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
 975                        c.clone()
 976                    } else {
 977                        panic!("missing test connection")
 978                    }
 979                })
 980                .unwrap()
 981                .await
 982                .unwrap();
 983
 984            connection.simulate_disconnect(&cx);
 985        })
 986    }
 987
 988    #[cfg(any(test, feature = "test-support"))]
 989    pub fn fake_server(
 990        client_cx: &mut gpui::TestAppContext,
 991        server_cx: &mut gpui::TestAppContext,
 992    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
 993        let port = client_cx
 994            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
 995        let opts = SshConnectionOptions {
 996            host: "<fake>".to_string(),
 997            port: Some(port),
 998            ..Default::default()
 999        };
1000        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1001        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1002        let server_client =
1003            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1004        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1005            connection_options: opts.clone(),
1006            server_cx: fake::SendableCx::new(server_cx.to_async()),
1007            server_channel: server_client.clone(),
1008        });
1009
1010        client_cx.update(|cx| {
1011            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1012                c.connections.insert(
1013                    opts.clone(),
1014                    ConnectionPoolEntry::Connecting(
1015                        cx.foreground_executor()
1016                            .spawn({
1017                                let connection = connection.clone();
1018                                async move { Ok(connection.clone()) }
1019                            })
1020                            .shared(),
1021                    ),
1022                );
1023            })
1024        });
1025
1026        (opts, server_client)
1027    }
1028
1029    #[cfg(any(test, feature = "test-support"))]
1030    pub async fn fake_client(
1031        opts: SshConnectionOptions,
1032        client_cx: &mut gpui::TestAppContext,
1033    ) -> Model<Self> {
1034        let (_tx, rx) = oneshot::channel();
1035        client_cx
1036            .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx))
1037            .await
1038            .unwrap()
1039            .unwrap()
1040    }
1041}
1042
/// What the `ConnectionPool` holds for one set of `SshConnectionOptions`.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; concurrent callers receive clones of this shared
    /// task instead of starting a second attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool alone does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
1047
/// App-global registry of ssh connections keyed by their options.
///
/// Lets multiple clients targeting the same options share, or wait on, one connection.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
1054
1055impl ConnectionPool {
1056    pub fn connect(
1057        &mut self,
1058        opts: SshConnectionOptions,
1059        delegate: &Arc<dyn SshClientDelegate>,
1060        cx: &mut AppContext,
1061    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1062        let connection = self.connections.get(&opts);
1063        match connection {
1064            Some(ConnectionPoolEntry::Connecting(task)) => {
1065                let delegate = delegate.clone();
1066                cx.spawn(|mut cx| async move {
1067                    delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
1068                })
1069                .detach();
1070                return task.clone();
1071            }
1072            Some(ConnectionPoolEntry::Connected(ssh)) => {
1073                if let Some(ssh) = ssh.upgrade() {
1074                    if !ssh.has_been_killed() {
1075                        return Task::ready(Ok(ssh)).shared();
1076                    }
1077                }
1078                self.connections.remove(&opts);
1079            }
1080            None => {}
1081        }
1082
1083        let task = cx
1084            .spawn({
1085                let opts = opts.clone();
1086                let delegate = delegate.clone();
1087                |mut cx| async move {
1088                    let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
1089                        .await
1090                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1091
1092                    cx.update_global(|pool: &mut Self, _| {
1093                        debug_assert!(matches!(
1094                            pool.connections.get(&opts),
1095                            Some(ConnectionPoolEntry::Connecting(_))
1096                        ));
1097                        match connection {
1098                            Ok(connection) => {
1099                                pool.connections.insert(
1100                                    opts.clone(),
1101                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1102                                );
1103                                Ok(connection)
1104                            }
1105                            Err(error) => {
1106                                pool.connections.remove(&opts);
1107                                Err(Arc::new(error))
1108                            }
1109                        }
1110                    })?
1111                }
1112            })
1113            .shared();
1114
1115        self.connections
1116            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1117        task
1118    }
1119}
1120
1121impl From<SshRemoteClient> for AnyProtoClient {
1122    fn from(client: SshRemoteClient) -> Self {
1123        AnyProtoClient::new(client.client.clone())
1124    }
1125}
1126
/// Transport underneath a remote client.
///
/// Implemented here by `SshRemoteConnection`; `fake_server` also installs a
/// `fake::FakeRemoteConnection` behind this trait.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote `proxy` for `unique_identifier` and pumps envelopes between the
    /// channels and the remote server.
    ///
    /// In `SshRemoteConnection`, envelopes from `outgoing_rx` go to the server, server
    /// envelopes arrive on `incoming_tx`, and `connection_activity_tx` is signalled on
    /// traffic. The returned task yields the proxy's exit code.
    #[allow(clippy::too_many_arguments)]
    fn start_proxy(
        &self,
        remote_binary_path: PathBuf,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<i32>>;
    /// Returns the path of the server binary on the remote host.
    ///
    /// The ssh implementation installs or updates the binary first when `reconnect` is false.
    async fn get_remote_binary_path(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        reconnect: bool,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Tears down the underlying connection.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has already been performed.
    fn has_been_killed(&self) -> bool;
    /// Command-line arguments for running further ssh commands over this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;

    /// Test-only hook to simulate a dropped connection; the default does nothing.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncAppContext) {}
}
1155
/// A connection multiplexed through an OpenSSH ControlMaster process (see `new`).
struct SshRemoteConnection {
    /// Control-socket handle used to run additional ssh commands over the master connection.
    socket: SshSocket,
    /// The `ssh -N` ControlMaster process; `None` once `kill` has taken it.
    master_process: Mutex<Option<process::Child>>,
    /// Remote OS/architecture detected with `uname` during `new`.
    platform: SshPlatform,
    /// Owns the directory holding the control socket and askpass files.
    ///
    /// `TempDir` deletes the directory on drop, so it must live as long as the connection.
    _temp_dir: TempDir,
}
1162
1163#[async_trait(?Send)]
1164impl RemoteConnection for SshRemoteConnection {
1165    async fn kill(&self) -> Result<()> {
1166        let Some(mut process) = self.master_process.lock().take() else {
1167            return Ok(());
1168        };
1169        process.kill().ok();
1170        process.status().await?;
1171        Ok(())
1172    }
1173
1174    fn has_been_killed(&self) -> bool {
1175        self.master_process.lock().is_none()
1176    }
1177
1178    fn ssh_args(&self) -> Vec<String> {
1179        self.socket.ssh_args()
1180    }
1181
1182    fn connection_options(&self) -> SshConnectionOptions {
1183        self.socket.connection_options.clone()
1184    }
1185
1186    async fn get_remote_binary_path(
1187        &self,
1188        delegate: &Arc<dyn SshClientDelegate>,
1189        reconnect: bool,
1190        cx: &mut AsyncAppContext,
1191    ) -> Result<PathBuf> {
1192        let platform = self.platform;
1193        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1194        if !reconnect {
1195            self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1196                .await?;
1197        }
1198
1199        let socket = self.socket.clone();
1200        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1201        Ok(remote_binary_path)
1202    }
1203
1204    fn start_proxy(
1205        &self,
1206        remote_binary_path: PathBuf,
1207        unique_identifier: String,
1208        reconnect: bool,
1209        incoming_tx: UnboundedSender<Envelope>,
1210        outgoing_rx: UnboundedReceiver<Envelope>,
1211        connection_activity_tx: Sender<()>,
1212        delegate: Arc<dyn SshClientDelegate>,
1213        cx: &mut AsyncAppContext,
1214    ) -> Task<Result<i32>> {
1215        delegate.set_status(Some("Starting proxy"), cx);
1216
1217        let mut start_proxy_command = format!(
1218            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1219            std::env::var("RUST_LOG").unwrap_or_default(),
1220            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1221            remote_binary_path,
1222            unique_identifier,
1223        );
1224        if reconnect {
1225            start_proxy_command.push_str(" --reconnect");
1226        }
1227
1228        let ssh_proxy_process = match self
1229            .socket
1230            .ssh_command(start_proxy_command)
1231            // IMPORTANT: we kill this process when we drop the task that uses it.
1232            .kill_on_drop(true)
1233            .spawn()
1234        {
1235            Ok(process) => process,
1236            Err(error) => {
1237                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
1238            }
1239        };
1240
1241        Self::multiplex(
1242            ssh_proxy_process,
1243            incoming_tx,
1244            outgoing_rx,
1245            connection_activity_tx,
1246            &cx,
1247        )
1248    }
1249}
1250
1251impl SshRemoteConnection {
1252    #[cfg(not(unix))]
1253    async fn new(
1254        _connection_options: SshConnectionOptions,
1255        _delegate: Arc<dyn SshClientDelegate>,
1256        _cx: &mut AsyncAppContext,
1257    ) -> Result<Self> {
1258        Err(anyhow!("ssh is not supported on this platform"))
1259    }
1260
1261    #[cfg(unix)]
1262    async fn new(
1263        connection_options: SshConnectionOptions,
1264        delegate: Arc<dyn SshClientDelegate>,
1265        cx: &mut AsyncAppContext,
1266    ) -> Result<Self> {
1267        use futures::AsyncWriteExt as _;
1268        use futures::{io::BufReader, AsyncBufReadExt as _};
1269        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1270        use util::ResultExt as _;
1271
1272        delegate.set_status(Some("Connecting"), cx);
1273
1274        let url = connection_options.ssh_url();
1275        let temp_dir = tempfile::Builder::new()
1276            .prefix("zed-ssh-session")
1277            .tempdir()?;
1278
1279        // Create a domain socket listener to handle requests from the askpass program.
1280        let askpass_socket = temp_dir.path().join("askpass.sock");
1281        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1282        let listener =
1283            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1284
1285        let askpass_task = cx.spawn({
1286            let delegate = delegate.clone();
1287            |mut cx| async move {
1288                let mut askpass_opened_tx = Some(askpass_opened_tx);
1289
1290                while let Ok((mut stream, _)) = listener.accept().await {
1291                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1292                        askpass_opened_tx.send(()).ok();
1293                    }
1294                    let mut buffer = Vec::new();
1295                    let mut reader = BufReader::new(&mut stream);
1296                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1297                        buffer.clear();
1298                    }
1299                    let password_prompt = String::from_utf8_lossy(&buffer);
1300                    if let Some(password) = delegate
1301                        .ask_password(password_prompt.to_string(), &mut cx)
1302                        .await
1303                        .context("failed to get ssh password")
1304                        .and_then(|p| p)
1305                        .log_err()
1306                    {
1307                        stream.write_all(password.as_bytes()).await.log_err();
1308                    }
1309                }
1310            }
1311        });
1312
1313        // Create an askpass script that communicates back to this process.
1314        let askpass_script = format!(
1315            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1316            askpass_socket = askpass_socket.display(),
1317            print_args = "printf '%s\\0' \"$@\"",
1318            shebang = "#!/bin/sh",
1319        );
1320        let askpass_script_path = temp_dir.path().join("askpass.sh");
1321        fs::write(&askpass_script_path, askpass_script).await?;
1322        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1323
1324        // Start the master SSH process, which does not do anything except for establish
1325        // the connection and keep it open, allowing other ssh commands to reuse it
1326        // via a control socket.
1327        let socket_path = temp_dir.path().join("ssh.sock");
1328        let mut master_process = process::Command::new("ssh")
1329            .stdin(Stdio::null())
1330            .stdout(Stdio::piped())
1331            .stderr(Stdio::piped())
1332            .env("SSH_ASKPASS_REQUIRE", "force")
1333            .env("SSH_ASKPASS", &askpass_script_path)
1334            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1335            .args([
1336                "-N",
1337                "-o",
1338                "ControlPersist=no",
1339                "-o",
1340                "ControlMaster=yes",
1341                "-o",
1342            ])
1343            .arg(format!("ControlPath={}", socket_path.display()))
1344            .arg(&url)
1345            .kill_on_drop(true)
1346            .spawn()?;
1347
1348        // Wait for this ssh process to close its stdout, indicating that authentication
1349        // has completed.
1350        let stdout = master_process.stdout.as_mut().unwrap();
1351        let mut output = Vec::new();
1352        let connection_timeout = Duration::from_secs(10);
1353
1354        let result = select_biased! {
1355            _ = askpass_opened_rx.fuse() => {
1356                // If the askpass script has opened, that means the user is typing
1357                // their password, in which case we don't want to timeout anymore,
1358                // since we know a connection has been established.
1359                stdout.read_to_end(&mut output).await?;
1360                Ok(())
1361            }
1362            result = stdout.read_to_end(&mut output).fuse() => {
1363                result?;
1364                Ok(())
1365            }
1366            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1367                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1368            }
1369        };
1370
1371        if let Err(e) = result {
1372            return Err(e.context("Failed to connect to host"));
1373        }
1374
1375        drop(askpass_task);
1376
1377        if master_process.try_status()?.is_some() {
1378            output.clear();
1379            let mut stderr = master_process.stderr.take().unwrap();
1380            stderr.read_to_end(&mut output).await?;
1381
1382            let error_message = format!(
1383                "failed to connect: {}",
1384                String::from_utf8_lossy(&output).trim()
1385            );
1386            Err(anyhow!(error_message))?;
1387        }
1388
1389        let socket = SshSocket {
1390            connection_options,
1391            socket_path,
1392        };
1393
1394        let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?;
1395        let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?;
1396
1397        let os = match os.trim() {
1398            "Darwin" => "macos",
1399            "Linux" => "linux",
1400            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1401        };
1402        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1403            "aarch64"
1404        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1405            "x86_64"
1406        } else {
1407            Err(anyhow!("unknown uname architecture {arch:?}"))?
1408        };
1409
1410        let platform = SshPlatform { os, arch };
1411
1412        Ok(Self {
1413            socket,
1414            master_process: Mutex::new(Some(master_process)),
1415            platform,
1416            _temp_dir: temp_dir,
1417        })
1418    }
1419
    /// Pumps envelopes between the local channels and a spawned ssh proxy process.
    ///
    /// Three background tasks are spawned:
    /// - stdin: writes each envelope received on `outgoing_rx` to the child's stdin.
    /// - stdout: reads length-prefixed envelopes from the child's stdout and forwards them to
    ///   `incoming_tx`.
    /// - stderr: splits the child's stderr into lines. Lines that parse as a `LogRecord` are
    ///   re-logged locally; all other lines are printed to our stderr with a `(remote)` prefix.
    ///
    /// The stdout and stderr tasks `try_send` on `connection_activity_tx` whenever they
    /// make progress; send failures are ignored.
    ///
    /// The returned task resolves after the first of the three tasks finishes. It then waits
    /// for the proxy process to exit. If that task succeeded, it returns the process's exit
    /// code, or `1` when the exit status has no code. If that task failed, it returns the
    /// task's error, annotated with the stream name. An error while waiting for the exit
    /// status takes precedence.
    ///
    /// Panics if the child was spawned without piped stdin, stdout and stderr.
    fn multiplex(
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        mut connection_activity_tx: Sender<()>,
        cx: &AsyncAppContext,
    ) -> Task<Result<i32>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Number of buffered stderr bytes not yet consumed as complete lines.
        let mut stderr_offset = 0;

        let stdin_task = cx.background_executor().spawn(async move {
            while let Some(outgoing) = outgoing_rx.next().await {
                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
            }
            anyhow::Ok(())
        });

        let stdout_task = cx.background_executor().spawn({
            let mut connection_activity_tx = connection_activity_tx.clone();
            async move {
                loop {
                    // Read the fixed-size length prefix; a zero-byte read means EOF.
                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                    let len = child_stdout.read(&mut stdout_buffer).await?;

                    if len == 0 {
                        return anyhow::Ok(());
                    }

                    // Short read: fetch the remainder of the prefix.
                    if len < MESSAGE_LEN_SIZE {
                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                    }

                    let message_len = message_len_from_buffer(&stdout_buffer);
                    let envelope =
                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                            .await?;
                    connection_activity_tx.try_send(()).ok();
                    incoming_tx.unbounded_send(envelope).ok();
                }
            }
        });

        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
            loop {
                // Keep 1024 bytes of free space after any partial line carried over.
                stderr_buffer.resize(stderr_offset + 1024, 0);

                let len = child_stderr
                    .read(&mut stderr_buffer[stderr_offset..])
                    .await?;
                if len == 0 {
                    return anyhow::Ok(());
                }

                stderr_offset += len;
                // Emit every complete ('\n'-terminated) line in the buffered bytes.
                let mut start_ix = 0;
                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                    .iter()
                    .position(|b| b == &b'\n')
                {
                    let line_ix = start_ix + ix;
                    let content = &stderr_buffer[start_ix..line_ix];
                    start_ix = line_ix + 1;
                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                        record.log(log::logger())
                    } else {
                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                    }
                }
                // Shift the trailing partial line (if any) to the front of the buffer.
                stderr_buffer.drain(0..start_ix);
                stderr_offset -= start_ix;

                connection_activity_tx.try_send(()).ok();
            }
        });

        cx.spawn(|_| async move {
            let result = futures::select! {
                result = stdin_task.fuse() => {
                    result.context("stdin")
                }
                result = stdout_task.fuse() => {
                    result.context("stdout")
                }
                result = stderr_task.fuse() => {
                    result.context("stderr")
                }
            };

            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
            match result {
                Ok(_) => Ok(status),
                Err(error) => Err(error),
            }
        })
    }
1521
1522    async fn ensure_server_binary(
1523        &self,
1524        delegate: &Arc<dyn SshClientDelegate>,
1525        dst_path: &Path,
1526        platform: SshPlatform,
1527        cx: &mut AsyncAppContext,
1528    ) -> Result<()> {
1529        let lock_file = dst_path.with_extension("lock");
1530        let timestamp = SystemTime::now()
1531            .duration_since(UNIX_EPOCH)
1532            .unwrap()
1533            .as_secs();
1534        let lock_content = timestamp.to_string();
1535
1536        let lock_stale_age = Duration::from_secs(10 * 60);
1537        let max_wait_time = Duration::from_secs(10 * 60);
1538        let check_interval = Duration::from_secs(5);
1539        let start_time = Instant::now();
1540
1541        loop {
1542            let lock_acquired = self.create_lock_file(&lock_file, &lock_content).await?;
1543            if lock_acquired {
1544                let result = self
1545                    .update_server_binary_if_needed(delegate, dst_path, platform, cx)
1546                    .await;
1547
1548                self.remove_lock_file(&lock_file).await.ok();
1549
1550                return result;
1551            } else {
1552                if let Ok(is_stale) = self.is_lock_stale(&lock_file, &lock_stale_age).await {
1553                    if is_stale {
1554                        self.remove_lock_file(&lock_file).await?;
1555                        continue;
1556                    } else {
1557                        if start_time.elapsed() > max_wait_time {
1558                            return Err(anyhow!("Timeout waiting for lock to be released"));
1559                        }
1560                        log::info!(
1561                            "Found lockfile: {:?}. Will check again in {:?}",
1562                            lock_file,
1563                            check_interval
1564                        );
1565                        delegate.set_status(
1566                            Some("Waiting for another Zed instance to finish uploading binary"),
1567                            cx,
1568                        );
1569                        smol::Timer::after(check_interval).await;
1570                        continue;
1571                    }
1572                } else {
1573                    // Unable to check lock, assume it's valid and wait
1574                    if start_time.elapsed() > max_wait_time {
1575                        return Err(anyhow!("Timeout waiting for lock to be released"));
1576                    }
1577                    smol::Timer::after(check_interval).await;
1578                    continue;
1579                }
1580            }
1581        }
1582    }
1583
1584    async fn create_lock_file(&self, lock_file: &Path, content: &str) -> Result<bool> {
1585        let parent_dir = lock_file
1586            .parent()
1587            .ok_or_else(|| anyhow!("Lock file path has no parent directory"))?;
1588
1589        // Be mindful of the escaping here: we need to make sure that we have quotes
1590        // inside the string, so that `sh -c` gets a quoted string passed to it.
1591        let script = format!(
1592            "\"mkdir -p '{0}' &&  [ ! -f '{1}' ] && echo '{2}' > '{1}' && echo 'created' || echo 'exists'\"",
1593            parent_dir.display(),
1594            lock_file.display(),
1595            content
1596        );
1597
1598        let output = run_cmd(self.socket.ssh_command("sh").arg("-c").arg(&script))
1599            .await
1600            .with_context(|| format!("failed to create a lock file at {:?}", lock_file))?;
1601
1602        Ok(output.trim() == "created")
1603    }
1604
1605    async fn is_lock_stale(&self, lock_file: &Path, max_age: &Duration) -> Result<bool> {
1606        let threshold = max_age.as_secs();
1607
1608        // Be mindful of the escaping here: we need to make sure that we have quotes
1609        // inside the string, so that `sh -c` gets a quoted string passed to it.
1610        let script = format!(
1611            "\"[ -f '{0}' ] && [ $(( $(date +%s) - $(date -r '{0}' +%s) )) -gt {1} ] && echo 'stale' ||  echo 'recent'\"",
1612            lock_file.display(),
1613            threshold
1614        );
1615
1616        let output = run_cmd(self.socket.ssh_command("sh").arg("-c").arg(script))
1617            .await
1618            .with_context(|| {
1619                format!("failed to check whether lock file {:?} is stale", lock_file)
1620            })?;
1621
1622        Ok(output.trim() == "stale")
1623    }
1624
1625    async fn remove_lock_file(&self, lock_file: &Path) -> Result<()> {
1626        run_cmd(self.socket.ssh_command("rm").arg("-f").arg(lock_file))
1627            .await
1628            .context("failed to remove lock file")?;
1629        Ok(())
1630    }
1631
1632    async fn update_server_binary_if_needed(
1633        &self,
1634        delegate: &Arc<dyn SshClientDelegate>,
1635        dst_path: &Path,
1636        platform: SshPlatform,
1637        cx: &mut AsyncAppContext,
1638    ) -> Result<()> {
1639        if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1640            if let Ok(installed_version) =
1641                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1642            {
1643                log::info!("using cached server binary version {}", installed_version);
1644                return Ok(());
1645            }
1646        }
1647
1648        let (binary, version) = delegate.get_server_binary(platform, cx).await??;
1649
1650        let mut server_binary_exists = false;
1651        if !server_binary_exists && cfg!(not(debug_assertions)) {
1652            if let Ok(installed_version) =
1653                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1654            {
1655                if installed_version.trim() == version.to_string() {
1656                    server_binary_exists = true;
1657                }
1658                log::info!("checked remote server binary for version. latest version: {}. remote server version: {}", version.to_string(), installed_version.trim());
1659            }
1660        }
1661
1662        if server_binary_exists {
1663            log::info!("remote development server already present",);
1664            return Ok(());
1665        }
1666
1667        match binary {
1668            ServerBinary::LocalBinary(src_path) => {
1669                self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
1670                    .await
1671            }
1672            ServerBinary::ReleaseUrl { url, body } => {
1673                self.download_binary_on_server(&url, &body, dst_path, delegate, cx)
1674                    .await
1675            }
1676        }
1677    }
1678
1679    async fn download_binary_on_server(
1680        &self,
1681        url: &str,
1682        body: &str,
1683        dst_path: &Path,
1684        delegate: &Arc<dyn SshClientDelegate>,
1685        cx: &mut AsyncAppContext,
1686    ) -> Result<()> {
1687        let mut dst_path_gz = dst_path.to_path_buf();
1688        dst_path_gz.set_extension("gz");
1689
1690        if let Some(parent) = dst_path.parent() {
1691            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1692        }
1693
1694        delegate.set_status(Some("Downloading remote development server on host"), cx);
1695
1696        let script = format!(
1697            r#"
1698            if command -v wget >/dev/null 2>&1; then
1699                wget --max-redirect=5 --method=GET --header="Content-Type: application/json" --body-data='{}' '{}' -O '{}' && echo "wget"
1700            elif command -v curl >/dev/null 2>&1; then
1701                curl -L -X GET -H "Content-Type: application/json" -d '{}' '{}' -o '{}' && echo "curl"
1702            else
1703                echo "Neither curl nor wget is available" >&2
1704                exit 1
1705            fi
1706            "#,
1707            body.replace("'", r#"\'"#),
1708            url,
1709            dst_path_gz.display(),
1710            body.replace("'", r#"\'"#),
1711            url,
1712            dst_path_gz.display(),
1713        );
1714
1715        let output = run_cmd(self.socket.ssh_command("bash").arg("-c").arg(script))
1716            .await
1717            .context("Failed to download server binary")?;
1718
1719        if !output.contains("curl") && !output.contains("wget") {
1720            return Err(anyhow!("Failed to download server binary: {}", output));
1721        }
1722
1723        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
1724            .await
1725    }
1726
1727    async fn upload_local_server_binary(
1728        &self,
1729        src_path: &Path,
1730        dst_path: &Path,
1731        delegate: &Arc<dyn SshClientDelegate>,
1732        cx: &mut AsyncAppContext,
1733    ) -> Result<()> {
1734        let mut dst_path_gz = dst_path.to_path_buf();
1735        dst_path_gz.set_extension("gz");
1736
1737        if let Some(parent) = dst_path.parent() {
1738            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1739        }
1740
1741        let src_stat = fs::metadata(&src_path).await?;
1742        let size = src_stat.len();
1743
1744        let t0 = Instant::now();
1745        delegate.set_status(Some("Uploading remote development server"), cx);
1746        log::info!("uploading remote development server ({}kb)", size / 1024);
1747        self.upload_file(&src_path, &dst_path_gz)
1748            .await
1749            .context("failed to upload server binary")?;
1750        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1751
1752        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
1753            .await
1754    }
1755
1756    async fn extract_server_binary(
1757        &self,
1758        dst_path: &Path,
1759        dst_path_gz: &Path,
1760        delegate: &Arc<dyn SshClientDelegate>,
1761        cx: &mut AsyncAppContext,
1762    ) -> Result<()> {
1763        delegate.set_status(Some("Extracting remote development server"), cx);
1764        run_cmd(
1765            self.socket
1766                .ssh_command("gunzip")
1767                .arg("--force")
1768                .arg(&dst_path_gz),
1769        )
1770        .await?;
1771
1772        let server_mode = 0o755;
1773        delegate.set_status(Some("Marking remote development server executable"), cx);
1774        run_cmd(
1775            self.socket
1776                .ssh_command("chmod")
1777                .arg(format!("{:o}", server_mode))
1778                .arg(dst_path),
1779        )
1780        .await?;
1781
1782        Ok(())
1783    }
1784
1785    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1786        let mut command = process::Command::new("scp");
1787        let output = self
1788            .socket
1789            .ssh_options(&mut command)
1790            .args(
1791                self.socket
1792                    .connection_options
1793                    .port
1794                    .map(|port| vec!["-P".to_string(), port.to_string()])
1795                    .unwrap_or_default(),
1796            )
1797            .arg(src_path)
1798            .arg(format!(
1799                "{}:{}",
1800                self.socket.connection_options.scp_url(),
1801                dest_path.display()
1802            ))
1803            .output()
1804            .await?;
1805
1806        if output.status.success() {
1807            Ok(())
1808        } else {
1809            Err(anyhow!(
1810                "failed to upload file {} -> {}: {}",
1811                src_path.display(),
1812                dest_path.display(),
1813                String::from_utf8_lossy(&output.stderr)
1814            ))
1815        }
1816    }
1817}
1818
/// Pending requests, keyed by the id of the request envelope.
///
/// Each sender receives the response envelope together with a `oneshot::Sender<()>`. The
/// incoming-message loop waits on the paired receiver, which resolves once the requester
/// drops or signals that sender, before it handles the next incoming envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// A bidirectional protobuf message channel carried over a pair of `mpsc` envelope streams.
///
/// Envelopes sent through the buffered path are kept until the peer acknowledges them. This
/// lets them be re-sent (on `FlushBufferedMessages` or in `resync`) after the streams are
/// replaced by `reconnect`.
pub struct ChannelClient {
    /// Source of the ids assigned to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Sink for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Buffered envelopes not yet acknowledged by an incoming `ack_id`, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests still waiting for a response.
    response_channels: ResponseChannels,
    /// Handlers for incoming envelopes that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope (`FlushBufferedMessages` excluded). It is
    /// stamped as `ack_id` on every outgoing envelope.
    max_received: AtomicU32,
    /// Label used as a prefix in log messages.
    name: &'static str,
    /// Task running the incoming-message loop; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
1831
impl ChannelClient {
    /// Creates a client that reads envelopes from `incoming_rx` and writes them to
    /// `outgoing_tx`, and immediately spawns the loop that processes incoming envelopes.
    ///
    /// `name` is only used as a prefix in log messages.
    pub fn new(
        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        outgoing_tx: mpsc::UnboundedSender<Envelope>,
        cx: &AppContext,
        name: &'static str,
    ) -> Arc<Self> {
        // `new_cyclic` gives the message loop a `Weak` handle. The loop only upgrades it per
        // envelope and exits once the client has been dropped.
        Arc::new_cyclic(|this| Self {
            outgoing_tx: Mutex::new(outgoing_tx),
            next_message_id: AtomicU32::new(0),
            max_received: AtomicU32::new(0),
            response_channels: ResponseChannels::default(),
            message_handlers: Default::default(),
            buffer: Mutex::new(VecDeque::new()),
            name,
            task: Mutex::new(Self::start_handling_messages(
                this.clone(),
                incoming_rx,
                &cx.to_async(),
            )),
        })
    }

    /// Spawns the loop that consumes `incoming_rx` until the stream ends or the client is
    /// dropped.
    ///
    /// For each incoming envelope the loop:
    /// - discards buffered outgoing envelopes whose id is `<=` the envelope's `ack_id`;
    /// - on `FlushBufferedMessages`, re-sends every still-buffered envelope and replies with
    ///   an `Ack` responding to it;
    /// - otherwise records the id in `max_received`, then either hands the envelope to the
    ///   pending request it responds to, or dispatches it to the registered message handlers.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> Task<Result<()>> {
        cx.spawn(|cx| {
            async move {
                // Fixed placeholder sender id for typed envelopes; presumably there is only
                // ever one peer on this channel.
                let peer_id = PeerId { owner_id: 0, id: 0 };
                while let Some(incoming) = incoming_rx.next().await {
                    let Some(this) = this.upgrade() else {
                        return anyhow::Ok(());
                    };
                    // The peer has seen everything up to `ack_id`; stop retaining those envelopes.
                    if let Some(ack_id) = incoming.ack_id {
                        let mut buffer = this.buffer.lock();
                        while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                            buffer.pop_front();
                        }
                    }
                    if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
                        &incoming.payload
                    {
                        log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name);
                        {
                            // Lock order here (buffer, then outgoing_tx) matches `resync`.
                            let buffer = this.buffer.lock();
                            for envelope in buffer.iter() {
                                this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok();
                            }
                        }
                        // The `Ack` answers the flush request, so it is sent unbuffered and
                        // carries no `ack_id` of its own.
                        let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
                        envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                        this.outgoing_tx.lock().unbounded_send(envelope).ok();
                        continue;
                    }

                    // NOTE(review): ids are not compared with `max_received`, so an envelope the
                    // peer replays after a reconnect is handled again and can move
                    // `max_received` backwards. Confirm the peer never replays envelopes that
                    // were already received.
                    this.max_received.store(incoming.id, SeqCst);

                    if let Some(request_id) = incoming.responding_to {
                        let request_id = MessageId(request_id);
                        let sender = this.response_channels.lock().remove(&request_id);
                        if let Some(sender) = sender {
                            let (tx, rx) = oneshot::channel();
                            // With no payload, `tx` and `sender` are dropped unused. The
                            // requester then sees its channel close ("connection lost").
                            if incoming.payload.is_some() {
                                sender.send((incoming, tx)).ok();
                            }
                            // Pause the loop until the requester has taken the response
                            // (`request_dynamic` drops `tx` when it reads the envelope).
                            rx.await.ok();
                        }
                    } else if let Some(envelope) =
                        build_typed_envelope(peer_id, Instant::now(), incoming)
                    {
                        let type_name = envelope.payload_type_name();
                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
                            &this.message_handlers,
                            envelope,
                            this.clone().into(),
                            cx.clone(),
                        ) {
                            log::debug!("{}:ssh message received. name:{type_name}", this.name);
                            // Handlers run detached on the foreground executor, so the loop
                            // does not wait for them; their errors are only logged.
                            cx.foreground_executor().spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!("{}:ssh message handled. name:{type_name}", this.name);
                                    }
                                    Err(error) => {
                                        log::error!(
                                            "{}:error handling message. type:{type_name}, error:{error}", this.name,
                                        );
                                    }
                                }
                            }).detach()
                        } else {
                            log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                        }
                    }
                }
                anyhow::Ok(())
            }
        })
    }

    /// Attaches the client to new transport streams and restarts the incoming-message loop.
    ///
    /// Replacing `task` drops the previous loop's task. gpui tasks are cancelled on drop,
    /// so presumably that loop stops here.
    pub fn reconnect(
        self: &Arc<Self>,
        incoming_rx: UnboundedReceiver<Envelope>,
        outgoing_tx: UnboundedSender<Envelope>,
        cx: &AsyncAppContext,
    ) {
        *self.outgoing_tx.lock() = outgoing_tx;
        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
    }

    /// Routes incoming messages addressed to entity type `E` with `remote_id` to `entity`.
    ///
    /// # Panics
    /// Panics if a subscription for the same type and `remote_id` already exists.
    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
        let id = (TypeId::of::<E>(), remote_id);

        let mut message_handlers = self.message_handlers.lock();
        if message_handlers
            .entities_by_type_and_remote_id
            .contains_key(&id)
        {
            panic!("already subscribed to entity");
        }

        message_handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: entity.downgrade().into(),
            },
        );
    }

    /// Sends `payload` as a buffered request and resolves to its typed response.
    pub fn request<T: RequestMessage>(
        &self,
        payload: T,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        self.request_internal(payload, true)
    }

    /// Sends `payload` as a request; `use_buffer` decides whether the envelope is kept in
    /// `buffer` until acknowledged.
    ///
    /// # Errors
    /// Fails if the request fails (see `request_dynamic`) or if the response does not
    /// decode as `T::Response`.
    fn request_internal<T: RequestMessage>(
        &self,
        payload: T,
        use_buffer: bool,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        log::debug!("ssh request start. name:{}", T::NAME);
        let response =
            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
        async move {
            let response = response.await?;
            log::debug!("ssh request finish. name:{}", T::NAME);
            T::Response::from_envelope(response)
                .ok_or_else(|| anyhow!("received a response of the wrong type"))
        }
    }

    /// Asks the peer to flush its buffered envelopes, then re-sends this side's
    /// still-buffered envelopes.
    ///
    /// The flush request itself is sent unbuffered. Fails with "Timeout detected" if this
    /// does not finish within `timeout`.
    pub async fn resync(&self, timeout: Duration) -> Result<()> {
        smol::future::or(
            async {
                self.request_internal(proto::FlushBufferedMessages {}, false)
                    .await?;

                for envelope in self.buffer.lock().iter() {
                    self.outgoing_tx
                        .lock()
                        .unbounded_send(envelope.clone())
                        .ok();
                }
                Ok(())
            },
            async {
                smol::Timer::after(timeout).await;
                Err(anyhow!("Timeout detected"))
            },
        )
        .await
    }

    /// Round-trips a `Ping` request; fails with "Timeout detected" if no response arrives
    /// within `timeout`.
    pub async fn ping(&self, timeout: Duration) -> Result<()> {
        smol::future::or(
            async {
                self.request(proto::Ping {}).await?;
                Ok(())
            },
            async {
                smol::Timer::after(timeout).await;
                Err(anyhow!("Timeout detected"))
            },
        )
        .await
    }

    /// Sends `payload` as a buffered one-way message.
    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
        log::debug!("ssh send name:{}", T::NAME);
        self.send_dynamic(payload.into_envelope(0, None, None))
    }

    /// Assigns the next message id, registers a response channel for it, and sends `envelope`.
    ///
    /// The returned future resolves to the response envelope. An `Error` payload is turned
    /// into an `RpcError`, and a closed response channel is reported as "connection lost".
    ///
    /// NOTE(review): the `response_channels` entry is removed only when a response arrives.
    /// If the returned future is dropped first (e.g. on a `ping`/`resync` timeout), the entry
    /// stays until then.
    fn request_dynamic(
        &self,
        mut envelope: proto::Envelope,
        type_name: &'static str,
        use_buffer: bool,
    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
        let (tx, rx) = oneshot::channel();
        // Register before sending so a fast response cannot arrive before its channel exists.
        let mut response_channels_lock = self.response_channels.lock();
        response_channels_lock.insert(MessageId(envelope.id), tx);
        drop(response_channels_lock);

        let result = if use_buffer {
            self.send_buffered(envelope)
        } else {
            self.send_unbuffered(envelope)
        };
        async move {
            // Both send paths currently always return `Ok`, so this branch is defensive.
            if let Err(error) = &result {
                log::error!("failed to send message: {}", error);
                return Err(anyhow!("failed to send message: {}", error));
            }

            // Taking `.0` drops the paired `oneshot::Sender<()>`, which releases the
            // incoming-message loop waiting on it.
            let response = rx.await.context("connection lost")?.0;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                return Err(RpcError::from_proto(error, type_name));
            }
            Ok(response)
        }
    }

    /// Assigns the next message id to `envelope` and sends it through the buffered path.
    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
        self.send_buffered(envelope)
    }

    /// Stamps `ack_id`, keeps a copy in `buffer` until the peer acknowledges it, and sends it.
    ///
    /// Always returns `Ok`: send failures are ignored.
    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.ack_id = Some(self.max_received.load(SeqCst));
        self.buffer.lock().push_back(envelope.clone());
        // ignore errors on send (happen while we're reconnecting)
        // assume that the global "disconnected" overlay is sufficient.
        self.outgoing_tx.lock().unbounded_send(envelope).ok();
        Ok(())
    }

    /// Stamps `ack_id` and sends without keeping a copy, so the envelope is never re-sent.
    ///
    /// Always returns `Ok`: send failures are ignored.
    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
        envelope.ack_id = Some(self.max_received.load(SeqCst));
        self.outgoing_tx.lock().unbounded_send(envelope).ok();
        Ok(())
    }
}
2078
2079impl ProtoClient for ChannelClient {
2080    fn request(
2081        &self,
2082        envelope: proto::Envelope,
2083        request_type: &'static str,
2084    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2085        self.request_dynamic(envelope, request_type, true).boxed()
2086    }
2087
2088    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2089        self.send_dynamic(envelope)
2090    }
2091
2092    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2093        self.send_dynamic(envelope)
2094    }
2095
2096    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2097        &self.message_handlers
2098    }
2099
2100    fn is_via_collab(&self) -> bool {
2101        false
2102    }
2103}
2104
#[cfg(any(test, feature = "test-support"))]
mod fake {
    //! In-process fake of `RemoteConnection` for tests. Instead of spawning `ssh`, it forwards
    //! envelopes between the client's channels and a server-side `ChannelClient`.

    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, SemanticVersion, Task};
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A connection whose "remote server" is `server_channel`, driven with `server_cx`.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Wraps the server-side `AsyncAppContext` so `FakeRemoteConnection` can store it.
    pub(super) struct SendableCx(AsyncAppContext);

    impl SendableCx {
        pub(super) fn new(cx: AsyncAppContext) -> Self {
            Self(cx)
        }

        /// Returns a clone of the wrapped context.
        ///
        /// The argument is unused at runtime. Requiring an `&AsyncAppContext` is how access
        /// is restricted to the main thread; see the `Send`/`Sync` impls below.
        fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
            self.0.clone()
        }
    }

    // SAFETY: the wrapped context is only handed out by `get`, whose caller must already hold
    // an `AsyncAppContext`. This relies on such a context only being available on the main
    // thread. NOTE(review): nothing enforces this at compile time; tests must uphold it.
    unsafe impl Send for SendableCx {}
    // SAFETY: same invariant as the `Send` impl; shared access only goes through `get`.
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server channel to streams whose other ends are dropped immediately.
        ///
        /// As a result, the server's incoming stream ends and its outgoing sends fail.
        fn simulate_disconnect(&self, cx: &AsyncAppContext) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        async fn get_remote_binary_path(
            &self,
            _delegate: &Arc<dyn SshClientDelegate>,
            _reconnect: bool,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            Ok(PathBuf::new())
        }

        /// Points the server channel at fresh streams and spawns a background task that relays
        /// envelopes in both directions.
        ///
        /// `connection_activity_tx` is signalled for every server-to-client envelope. The
        /// task resolves to `Ok(1)` once either side's stream ends.
        fn start_proxy(
            &self,
            _remote_binary_path: PathBuf,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// Test delegate: status updates are ignored, and every other method is `unreachable!()`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }
        fn remote_server_binary_path(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
        fn get_server_binary(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}
    }
}