ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    channel::{
  13        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  14        oneshot,
  15    },
  16    future::{BoxFuture, Shared},
  17    select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
  21    ModelContext, SemanticVersion, Task, WeakModel,
  22};
  23use parking_lot::Mutex;
  24use rpc::{
  25    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  26    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  27    RpcError,
  28};
  29use smol::{
  30    fs,
  31    process::{self, Child, Stdio},
  32};
  33use std::{
  34    any::TypeId,
  35    collections::VecDeque,
  36    ffi::OsStr,
  37    fmt,
  38    ops::ControlFlow,
  39    path::{Path, PathBuf},
  40    sync::{
  41        atomic::{AtomicU32, Ordering::SeqCst},
  42        Arc, Weak,
  43    },
  44    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
  45};
  46use tempfile::TempDir;
  47use util::ResultExt;
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  51)]
  52pub struct SshProjectId(pub u64);
  53
  54#[derive(Clone)]
  55pub struct SshSocket {
  56    connection_options: SshConnectionOptions,
  57    socket_path: PathBuf,
  58}
  59
  60#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
  61pub struct SshConnectionOptions {
  62    pub host: String,
  63    pub username: Option<String>,
  64    pub port: Option<u16>,
  65    pub password: Option<String>,
  66    pub args: Option<Vec<String>>,
  67}
  68
  69impl SshConnectionOptions {
  70    pub fn parse_command_line(input: &str) -> Result<Self> {
  71        let input = input.trim_start_matches("ssh ");
  72        let mut hostname: Option<String> = None;
  73        let mut username: Option<String> = None;
  74        let mut port: Option<u16> = None;
  75        let mut args = Vec::new();
  76
  77        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  78        const ALLOWED_OPTS: &[&str] = &[
  79            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  80        ];
  81        const ALLOWED_ARGS: &[&str] = &[
  82            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
  83            "-w",
  84        ];
  85
  86        let mut tokens = shlex::split(input)
  87            .ok_or_else(|| anyhow!("invalid input"))?
  88            .into_iter();
  89
  90        'outer: while let Some(arg) = tokens.next() {
  91            if ALLOWED_OPTS.contains(&(&arg as &str)) {
  92                args.push(arg.to_string());
  93                continue;
  94            }
  95            if arg == "-p" {
  96                port = tokens.next().and_then(|arg| arg.parse().ok());
  97                continue;
  98            } else if let Some(p) = arg.strip_prefix("-p") {
  99                port = p.parse().ok();
 100                continue;
 101            }
 102            if arg == "-l" {
 103                username = tokens.next();
 104                continue;
 105            } else if let Some(l) = arg.strip_prefix("-l") {
 106                username = Some(l.to_string());
 107                continue;
 108            }
 109            for a in ALLOWED_ARGS {
 110                if arg == *a {
 111                    args.push(arg);
 112                    if let Some(next) = tokens.next() {
 113                        args.push(next);
 114                    }
 115                    continue 'outer;
 116                } else if arg.starts_with(a) {
 117                    args.push(arg);
 118                    continue 'outer;
 119                }
 120            }
 121            if arg.starts_with("-") || hostname.is_some() {
 122                anyhow::bail!("unsupported argument: {:?}", arg);
 123            }
 124            let mut input = &arg as &str;
 125            if let Some((u, rest)) = input.split_once('@') {
 126                input = rest;
 127                username = Some(u.to_string());
 128            }
 129            if let Some((rest, p)) = input.split_once(':') {
 130                input = rest;
 131                port = p.parse().ok()
 132            }
 133            hostname = Some(input.to_string())
 134        }
 135
 136        let Some(hostname) = hostname else {
 137            anyhow::bail!("missing hostname");
 138        };
 139
 140        Ok(Self {
 141            host: hostname.to_string(),
 142            username: username.clone(),
 143            port,
 144            password: None,
 145            args: Some(args),
 146        })
 147    }
 148
 149    pub fn ssh_url(&self) -> String {
 150        let mut result = String::from("ssh://");
 151        if let Some(username) = &self.username {
 152            result.push_str(username);
 153            result.push('@');
 154        }
 155        result.push_str(&self.host);
 156        if let Some(port) = self.port {
 157            result.push(':');
 158            result.push_str(&port.to_string());
 159        }
 160        result
 161    }
 162
 163    pub fn additional_args(&self) -> Option<&Vec<String>> {
 164        self.args.as_ref()
 165    }
 166
 167    fn scp_url(&self) -> String {
 168        if let Some(username) = &self.username {
 169            format!("{}@{}", username, self.host)
 170        } else {
 171            self.host.clone()
 172        }
 173    }
 174
 175    pub fn connection_string(&self) -> String {
 176        let host = if let Some(username) = &self.username {
 177            format!("{}@{}", username, self.host)
 178        } else {
 179            self.host.clone()
 180        };
 181        if let Some(port) = &self.port {
 182            format!("{}:{}", host, port)
 183        } else {
 184            host
 185        }
 186    }
 187
 188    // Uniquely identifies dev server projects on a remote host. Needs to be
 189    // stable for the same dev server project.
 190    pub fn remote_server_identifier(&self) -> String {
 191        let mut identifier = format!("dev-server-{:?}", self.host);
 192        if let Some(username) = self.username.as_ref() {
 193            identifier.push('-');
 194            identifier.push_str(&username);
 195        }
 196        identifier
 197    }
 198}
 199
 200#[derive(Copy, Clone, Debug)]
 201pub struct SshPlatform {
 202    pub os: &'static str,
 203    pub arch: &'static str,
 204}
 205
 206impl SshPlatform {
 207    pub fn triple(&self) -> Option<String> {
 208        Some(format!(
 209            "{}-{}",
 210            self.arch,
 211            match self.os {
 212                "linux" => "unknown-linux-gnu",
 213                "macos" => "apple-darwin",
 214                _ => return None,
 215            }
 216        ))
 217    }
 218}
 219
 220pub enum ServerBinary {
 221    LocalBinary(PathBuf),
 222    ReleaseUrl { url: String, body: String },
 223}
 224
 225pub trait SshClientDelegate: Send + Sync {
 226    fn ask_password(
 227        &self,
 228        prompt: String,
 229        cx: &mut AsyncAppContext,
 230    ) -> oneshot::Receiver<Result<String>>;
 231    fn remote_server_binary_path(
 232        &self,
 233        platform: SshPlatform,
 234        cx: &mut AsyncAppContext,
 235    ) -> Result<PathBuf>;
 236    fn get_server_binary(
 237        &self,
 238        platform: SshPlatform,
 239        cx: &mut AsyncAppContext,
 240    ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>>;
 241    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 242}
 243
 244impl SshSocket {
 245    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 246        let mut command = process::Command::new("ssh");
 247        self.ssh_options(&mut command)
 248            .arg(self.connection_options.ssh_url())
 249            .arg(program);
 250        command
 251    }
 252
 253    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 254        command
 255            .stdin(Stdio::piped())
 256            .stdout(Stdio::piped())
 257            .stderr(Stdio::piped())
 258            .args(["-o", "ControlMaster=no", "-o"])
 259            .arg(format!("ControlPath={}", self.socket_path.display()))
 260    }
 261
 262    fn ssh_args(&self) -> Vec<String> {
 263        vec![
 264            "-o".to_string(),
 265            "ControlMaster=no".to_string(),
 266            "-o".to_string(),
 267            format!("ControlPath={}", self.socket_path.display()),
 268            self.connection_options.ssh_url(),
 269        ]
 270    }
 271}
 272
 273async fn run_cmd(command: &mut process::Command) -> Result<String> {
 274    let output = command.output().await?;
 275    if output.status.success() {
 276        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 277    } else {
 278        Err(anyhow!(
 279            "failed to run command: {}",
 280            String::from_utf8_lossy(&output.stderr)
 281        ))
 282    }
 283}
 284
 285const MAX_MISSED_HEARTBEATS: usize = 5;
 286const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 287const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 288
 289const MAX_RECONNECT_ATTEMPTS: usize = 3;
 290
 291enum State {
 292    Connecting,
 293    Connected {
 294        ssh_connection: Arc<dyn RemoteConnection>,
 295        delegate: Arc<dyn SshClientDelegate>,
 296
 297        multiplex_task: Task<Result<()>>,
 298        heartbeat_task: Task<Result<()>>,
 299    },
 300    HeartbeatMissed {
 301        missed_heartbeats: usize,
 302
 303        ssh_connection: Arc<dyn RemoteConnection>,
 304        delegate: Arc<dyn SshClientDelegate>,
 305
 306        multiplex_task: Task<Result<()>>,
 307        heartbeat_task: Task<Result<()>>,
 308    },
 309    Reconnecting,
 310    ReconnectFailed {
 311        ssh_connection: Arc<dyn RemoteConnection>,
 312        delegate: Arc<dyn SshClientDelegate>,
 313
 314        error: anyhow::Error,
 315        attempts: usize,
 316    },
 317    ReconnectExhausted,
 318    ServerNotRunning,
 319}
 320
 321impl fmt::Display for State {
 322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 323        match self {
 324            Self::Connecting => write!(f, "connecting"),
 325            Self::Connected { .. } => write!(f, "connected"),
 326            Self::Reconnecting => write!(f, "reconnecting"),
 327            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 328            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 329            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 330            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 331        }
 332    }
 333}
 334
 335impl State {
 336    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 337        match self {
 338            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 339            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 340            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 341            _ => None,
 342        }
 343    }
 344
 345    fn can_reconnect(&self) -> bool {
 346        match self {
 347            Self::Connected { .. }
 348            | Self::HeartbeatMissed { .. }
 349            | Self::ReconnectFailed { .. } => true,
 350            State::Connecting
 351            | State::Reconnecting
 352            | State::ReconnectExhausted
 353            | State::ServerNotRunning => false,
 354        }
 355    }
 356
 357    fn is_reconnect_failed(&self) -> bool {
 358        matches!(self, Self::ReconnectFailed { .. })
 359    }
 360
 361    fn is_reconnect_exhausted(&self) -> bool {
 362        matches!(self, Self::ReconnectExhausted { .. })
 363    }
 364
 365    fn is_server_not_running(&self) -> bool {
 366        matches!(self, Self::ServerNotRunning)
 367    }
 368
 369    fn is_reconnecting(&self) -> bool {
 370        matches!(self, Self::Reconnecting { .. })
 371    }
 372
 373    fn heartbeat_recovered(self) -> Self {
 374        match self {
 375            Self::HeartbeatMissed {
 376                ssh_connection,
 377                delegate,
 378                multiplex_task,
 379                heartbeat_task,
 380                ..
 381            } => Self::Connected {
 382                ssh_connection,
 383                delegate,
 384                multiplex_task,
 385                heartbeat_task,
 386            },
 387            _ => self,
 388        }
 389    }
 390
 391    fn heartbeat_missed(self) -> Self {
 392        match self {
 393            Self::Connected {
 394                ssh_connection,
 395                delegate,
 396                multiplex_task,
 397                heartbeat_task,
 398            } => Self::HeartbeatMissed {
 399                missed_heartbeats: 1,
 400                ssh_connection,
 401                delegate,
 402                multiplex_task,
 403                heartbeat_task,
 404            },
 405            Self::HeartbeatMissed {
 406                missed_heartbeats,
 407                ssh_connection,
 408                delegate,
 409                multiplex_task,
 410                heartbeat_task,
 411            } => Self::HeartbeatMissed {
 412                missed_heartbeats: missed_heartbeats + 1,
 413                ssh_connection,
 414                delegate,
 415                multiplex_task,
 416                heartbeat_task,
 417            },
 418            _ => self,
 419        }
 420    }
 421}
 422
 423/// The state of the ssh connection.
 424#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 425pub enum ConnectionState {
 426    Connecting,
 427    Connected,
 428    HeartbeatMissed,
 429    Reconnecting,
 430    Disconnected,
 431}
 432
 433impl From<&State> for ConnectionState {
 434    fn from(value: &State) -> Self {
 435        match value {
 436            State::Connecting => Self::Connecting,
 437            State::Connected { .. } => Self::Connected,
 438            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 439            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 440            State::ReconnectExhausted => Self::Disconnected,
 441            State::ServerNotRunning => Self::Disconnected,
 442        }
 443    }
 444}
 445
 446pub struct SshRemoteClient {
 447    client: Arc<ChannelClient>,
 448    unique_identifier: String,
 449    connection_options: SshConnectionOptions,
 450    state: Arc<Mutex<Option<State>>>,
 451}
 452
 453#[derive(Debug)]
 454pub enum SshRemoteEvent {
 455    Disconnected,
 456}
 457
 458impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 459
 460impl SshRemoteClient {
 461    pub fn new(
 462        unique_identifier: String,
 463        connection_options: SshConnectionOptions,
 464        cancellation: oneshot::Receiver<()>,
 465        delegate: Arc<dyn SshClientDelegate>,
 466        cx: &mut AppContext,
 467    ) -> Task<Result<Option<Model<Self>>>> {
 468        cx.spawn(|mut cx| async move {
 469            let success = Box::pin(async move {
 470                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 471                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 472                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 473
 474                let client =
 475                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
 476                let this = cx.new_model(|_| Self {
 477                    client: client.clone(),
 478                    unique_identifier: unique_identifier.clone(),
 479                    connection_options: connection_options.clone(),
 480                    state: Arc::new(Mutex::new(Some(State::Connecting))),
 481                })?;
 482
 483                let ssh_connection = cx
 484                    .update(|cx| {
 485                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
 486                            pool.connect(connection_options, &delegate, cx)
 487                        })
 488                    })?
 489                    .await
 490                    .map_err(|e| e.cloned())?;
 491                let remote_binary_path = ssh_connection
 492                    .get_remote_binary_path(&delegate, false, &mut cx)
 493                    .await?;
 494
 495                let io_task = ssh_connection.start_proxy(
 496                    remote_binary_path,
 497                    unique_identifier,
 498                    false,
 499                    incoming_tx,
 500                    outgoing_rx,
 501                    connection_activity_tx,
 502                    delegate.clone(),
 503                    &mut cx,
 504                );
 505
 506                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
 507
 508                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 509                    log::error!("failed to establish connection: {}", error);
 510                    return Err(error);
 511                }
 512
 513                let heartbeat_task =
 514                    Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 515
 516                this.update(&mut cx, |this, _| {
 517                    *this.state.lock() = Some(State::Connected {
 518                        ssh_connection,
 519                        delegate,
 520                        multiplex_task,
 521                        heartbeat_task,
 522                    });
 523                })?;
 524
 525                Ok(Some(this))
 526            });
 527
 528            select! {
 529                _ = cancellation.fuse() => {
 530                    Ok(None)
 531                }
 532                result = success.fuse() =>  result
 533            }
 534        })
 535    }
 536
 537    pub fn shutdown_processes<T: RequestMessage>(
 538        &self,
 539        shutdown_request: Option<T>,
 540    ) -> Option<impl Future<Output = ()>> {
 541        let state = self.state.lock().take()?;
 542        log::info!("shutting down ssh processes");
 543
 544        let State::Connected {
 545            multiplex_task,
 546            heartbeat_task,
 547            ssh_connection,
 548            delegate,
 549        } = state
 550        else {
 551            return None;
 552        };
 553
 554        let client = self.client.clone();
 555
 556        Some(async move {
 557            if let Some(shutdown_request) = shutdown_request {
 558                client.send(shutdown_request).log_err();
 559                // We wait 50ms instead of waiting for a response, because
 560                // waiting for a response would require us to wait on the main thread
 561                // which we want to avoid in an `on_app_quit` callback.
 562                smol::Timer::after(Duration::from_millis(50)).await;
 563            }
 564
 565            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 566            // child of master_process.
 567            drop(multiplex_task);
 568            // Now drop the rest of state, which kills master process.
 569            drop(heartbeat_task);
 570            drop(ssh_connection);
 571            drop(delegate);
 572        })
 573    }
 574
 575    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 576        let mut lock = self.state.lock();
 577
 578        let can_reconnect = lock
 579            .as_ref()
 580            .map(|state| state.can_reconnect())
 581            .unwrap_or(false);
 582        if !can_reconnect {
 583            let error = if let Some(state) = lock.as_ref() {
 584                format!("invalid state, cannot reconnect while in state {state}")
 585            } else {
 586                "no state set".to_string()
 587            };
 588            log::info!("aborting reconnect, because not in state that allows reconnecting");
 589            return Err(anyhow!(error));
 590        }
 591
 592        let state = lock.take().unwrap();
 593        let (attempts, ssh_connection, delegate) = match state {
 594            State::Connected {
 595                ssh_connection,
 596                delegate,
 597                multiplex_task,
 598                heartbeat_task,
 599            }
 600            | State::HeartbeatMissed {
 601                ssh_connection,
 602                delegate,
 603                multiplex_task,
 604                heartbeat_task,
 605                ..
 606            } => {
 607                drop(multiplex_task);
 608                drop(heartbeat_task);
 609                (0, ssh_connection, delegate)
 610            }
 611            State::ReconnectFailed {
 612                attempts,
 613                ssh_connection,
 614                delegate,
 615                ..
 616            } => (attempts, ssh_connection, delegate),
 617            State::Connecting
 618            | State::Reconnecting
 619            | State::ReconnectExhausted
 620            | State::ServerNotRunning => unreachable!(),
 621        };
 622
 623        let attempts = attempts + 1;
 624        if attempts > MAX_RECONNECT_ATTEMPTS {
 625            log::error!(
 626                "Failed to reconnect to after {} attempts, giving up",
 627                MAX_RECONNECT_ATTEMPTS
 628            );
 629            drop(lock);
 630            self.set_state(State::ReconnectExhausted, cx);
 631            return Ok(());
 632        }
 633        drop(lock);
 634
 635        self.set_state(State::Reconnecting, cx);
 636
 637        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 638
 639        let unique_identifier = self.unique_identifier.clone();
 640        let client = self.client.clone();
 641        let reconnect_task = cx.spawn(|this, mut cx| async move {
 642            macro_rules! failed {
 643                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 644                    return State::ReconnectFailed {
 645                        error: anyhow!($error),
 646                        attempts: $attempts,
 647                        ssh_connection: $ssh_connection,
 648                        delegate: $delegate,
 649                    };
 650                };
 651            }
 652
 653            if let Err(error) = ssh_connection
 654                .kill()
 655                .await
 656                .context("Failed to kill ssh process")
 657            {
 658                failed!(error, attempts, ssh_connection, delegate);
 659            };
 660
 661            let connection_options = ssh_connection.connection_options();
 662
 663            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 664            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 665            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 666
 667            let (ssh_connection, io_task) = match async {
 668                let ssh_connection = cx
 669                    .update_global(|pool: &mut ConnectionPool, cx| {
 670                        pool.connect(connection_options, &delegate, cx)
 671                    })?
 672                    .await
 673                    .map_err(|error| error.cloned())?;
 674
 675                let remote_binary_path = ssh_connection
 676                    .get_remote_binary_path(&delegate, true, &mut cx)
 677                    .await?;
 678
 679                let io_task = ssh_connection.start_proxy(
 680                    remote_binary_path,
 681                    unique_identifier,
 682                    true,
 683                    incoming_tx,
 684                    outgoing_rx,
 685                    connection_activity_tx,
 686                    delegate.clone(),
 687                    &mut cx,
 688                );
 689                anyhow::Ok((ssh_connection, io_task))
 690            }
 691            .await
 692            {
 693                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 694                Err(error) => {
 695                    failed!(error, attempts, ssh_connection, delegate);
 696                }
 697            };
 698
 699            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 700            client.reconnect(incoming_rx, outgoing_tx, &cx);
 701
 702            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 703                failed!(error, attempts, ssh_connection, delegate);
 704            };
 705
 706            State::Connected {
 707                ssh_connection,
 708                delegate,
 709                multiplex_task,
 710                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 711            }
 712        });
 713
 714        cx.spawn(|this, mut cx| async move {
 715            let new_state = reconnect_task.await;
 716            this.update(&mut cx, |this, cx| {
 717                this.try_set_state(cx, |old_state| {
 718                    if old_state.is_reconnecting() {
 719                        match &new_state {
 720                            State::Connecting
 721                            | State::Reconnecting { .. }
 722                            | State::HeartbeatMissed { .. }
 723                            | State::ServerNotRunning => {}
 724                            State::Connected { .. } => {
 725                                log::info!("Successfully reconnected");
 726                            }
 727                            State::ReconnectFailed {
 728                                error, attempts, ..
 729                            } => {
 730                                log::error!(
 731                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 732                                    attempts,
 733                                    error
 734                                );
 735                            }
 736                            State::ReconnectExhausted => {
 737                                log::error!("Reconnect attempt failed and all attempts exhausted");
 738                            }
 739                        }
 740                        Some(new_state)
 741                    } else {
 742                        None
 743                    }
 744                });
 745
 746                if this.state_is(State::is_reconnect_failed) {
 747                    this.reconnect(cx)
 748                } else if this.state_is(State::is_reconnect_exhausted) {
 749                    Ok(())
 750                } else {
 751                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 752                    Ok(())
 753                }
 754            })
 755        })
 756        .detach_and_log_err(cx);
 757
 758        Ok(())
 759    }
 760
 761    fn heartbeat(
 762        this: WeakModel<Self>,
 763        mut connection_activity_rx: mpsc::Receiver<()>,
 764        cx: &mut AsyncAppContext,
 765    ) -> Task<Result<()>> {
 766        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 767            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 768        };
 769
 770        cx.spawn(|mut cx| {
 771            let this = this.clone();
 772            async move {
 773                let mut missed_heartbeats = 0;
 774
 775                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 776                futures::pin_mut!(keepalive_timer);
 777
 778                loop {
 779                    select_biased! {
 780                        result = connection_activity_rx.next().fuse() => {
 781                            if result.is_none() {
 782                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 783                                return Ok(());
 784                            }
 785
 786                            if missed_heartbeats != 0 {
 787                                missed_heartbeats = 0;
 788                                this.update(&mut cx, |this, mut cx| {
 789                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 790                                })?;
 791                            }
 792                        }
 793                        _ = keepalive_timer => {
 794                            log::debug!("Sending heartbeat to server...");
 795
 796                            let result = select_biased! {
 797                                _ = connection_activity_rx.next().fuse() => {
 798                                    Ok(())
 799                                }
 800                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 801                                    ping_result
 802                                }
 803                            };
 804
 805                            if result.is_err() {
 806                                missed_heartbeats += 1;
 807                                log::warn!(
 808                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 809                                    HEARTBEAT_TIMEOUT,
 810                                    missed_heartbeats,
 811                                    MAX_MISSED_HEARTBEATS
 812                                );
 813                            } else if missed_heartbeats != 0 {
 814                                missed_heartbeats = 0;
 815                            } else {
 816                                continue;
 817                            }
 818
 819                            let result = this.update(&mut cx, |this, mut cx| {
 820                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 821                            })?;
 822                            if result.is_break() {
 823                                return Ok(());
 824                            }
 825                        }
 826                    }
 827
 828                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 829                }
 830            }
 831        })
 832    }
 833
 834    fn handle_heartbeat_result(
 835        &mut self,
 836        missed_heartbeats: usize,
 837        cx: &mut ModelContext<Self>,
 838    ) -> ControlFlow<()> {
 839        let state = self.state.lock().take().unwrap();
 840        let next_state = if missed_heartbeats > 0 {
 841            state.heartbeat_missed()
 842        } else {
 843            state.heartbeat_recovered()
 844        };
 845
 846        self.set_state(next_state, cx);
 847
 848        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 849            log::error!(
 850                "Missed last {} heartbeats. Reconnecting...",
 851                missed_heartbeats
 852            );
 853
 854            self.reconnect(cx)
 855                .context("failed to start reconnect process after missing heartbeats")
 856                .log_err();
 857            ControlFlow::Break(())
 858        } else {
 859            ControlFlow::Continue(())
 860        }
 861    }
 862
 863    fn monitor(
 864        this: WeakModel<Self>,
 865        io_task: Task<Result<i32>>,
 866        cx: &AsyncAppContext,
 867    ) -> Task<Result<()>> {
 868        cx.spawn(|mut cx| async move {
 869            let result = io_task.await;
 870
 871            match result {
 872                Ok(exit_code) => {
 873                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 874                        match error {
 875                            ProxyLaunchError::ServerNotRunning => {
 876                                log::error!("failed to reconnect because server is not running");
 877                                this.update(&mut cx, |this, cx| {
 878                                    this.set_state(State::ServerNotRunning, cx);
 879                                })?;
 880                            }
 881                        }
 882                    } else if exit_code > 0 {
 883                        log::error!("proxy process terminated unexpectedly");
 884                        this.update(&mut cx, |this, cx| {
 885                            this.reconnect(cx).ok();
 886                        })?;
 887                    }
 888                }
 889                Err(error) => {
 890                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 891                    this.update(&mut cx, |this, cx| {
 892                        this.reconnect(cx).ok();
 893                    })?;
 894                }
 895            }
 896
 897            Ok(())
 898        })
 899    }
 900
 901    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 902        self.state.lock().as_ref().map_or(false, check)
 903    }
 904
 905    fn try_set_state(
 906        &self,
 907        cx: &mut ModelContext<Self>,
 908        map: impl FnOnce(&State) -> Option<State>,
 909    ) {
 910        let mut lock = self.state.lock();
 911        let new_state = lock.as_ref().and_then(map);
 912
 913        if let Some(new_state) = new_state {
 914            lock.replace(new_state);
 915            cx.notify();
 916        }
 917    }
 918
 919    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
 920        log::info!("setting state to '{}'", &state);
 921
 922        let is_reconnect_exhausted = state.is_reconnect_exhausted();
 923        let is_server_not_running = state.is_server_not_running();
 924        self.state.lock().replace(state);
 925
 926        if is_reconnect_exhausted || is_server_not_running {
 927            cx.emit(SshRemoteEvent::Disconnected);
 928        }
 929        cx.notify();
 930    }
 931
    /// Subscribes `entity` under `remote_id` by forwarding to the underlying client's
    /// `subscribe_to_entity`.
    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
        self.client.subscribe_to_entity(remote_id, entity);
    }
 935
 936    pub fn ssh_args(&self) -> Option<Vec<String>> {
 937        self.state
 938            .lock()
 939            .as_ref()
 940            .and_then(|state| state.ssh_connection())
 941            .map(|ssh_connection| ssh_connection.ssh_args())
 942    }
 943
    /// Returns a clone of the underlying client, converted into an `AnyProtoClient`.
    pub fn proto_client(&self) -> AnyProtoClient {
        self.client.clone().into()
    }
 947
    /// Returns the connection string produced by this client's connection options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
 951
    /// Returns a clone of the options this client was configured with.
    pub fn connection_options(&self) -> SshConnectionOptions {
        self.connection_options.clone()
    }
 955
 956    pub fn connection_state(&self) -> ConnectionState {
 957        self.state
 958            .lock()
 959            .as_ref()
 960            .map(ConnectionState::from)
 961            .unwrap_or(ConnectionState::Disconnected)
 962    }
 963
    /// Returns `true` when `connection_state()` is `ConnectionState::Disconnected`.
    pub fn is_disconnected(&self) -> bool {
        self.connection_state() == ConnectionState::Disconnected
    }
 967
 968    #[cfg(any(test, feature = "test-support"))]
 969    pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
 970        let opts = self.connection_options();
 971        client_cx.spawn(|cx| async move {
 972            let connection = cx
 973                .update_global(|c: &mut ConnectionPool, _| {
 974                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
 975                        c.clone()
 976                    } else {
 977                        panic!("missing test connection")
 978                    }
 979                })
 980                .unwrap()
 981                .await
 982                .unwrap();
 983
 984            connection.simulate_disconnect(&cx);
 985        })
 986    }
 987
 988    #[cfg(any(test, feature = "test-support"))]
 989    pub fn fake_server(
 990        client_cx: &mut gpui::TestAppContext,
 991        server_cx: &mut gpui::TestAppContext,
 992    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
 993        let port = client_cx
 994            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
 995        let opts = SshConnectionOptions {
 996            host: "<fake>".to_string(),
 997            port: Some(port),
 998            ..Default::default()
 999        };
1000        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1001        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1002        let server_client =
1003            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1004        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1005            connection_options: opts.clone(),
1006            server_cx: fake::SendableCx::new(server_cx.to_async()),
1007            server_channel: server_client.clone(),
1008        });
1009
1010        client_cx.update(|cx| {
1011            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1012                c.connections.insert(
1013                    opts.clone(),
1014                    ConnectionPoolEntry::Connecting(
1015                        cx.foreground_executor()
1016                            .spawn({
1017                                let connection = connection.clone();
1018                                async move { Ok(connection.clone()) }
1019                            })
1020                            .shared(),
1021                    ),
1022                );
1023            })
1024        });
1025
1026        (opts, server_client)
1027    }
1028
1029    #[cfg(any(test, feature = "test-support"))]
1030    pub async fn fake_client(
1031        opts: SshConnectionOptions,
1032        client_cx: &mut gpui::TestAppContext,
1033    ) -> Model<Self> {
1034        let (_tx, rx) = oneshot::channel();
1035        client_cx
1036            .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx))
1037            .await
1038            .unwrap()
1039            .unwrap()
1040    }
1041}
1042
/// One slot in the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by several callers.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is kept, so the pool entry by itself
    /// does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1047
/// Pending and established remote connections, indexed by the options used to create them.
#[derive(Default)]
struct ConnectionPool {
    // Maps connection options to the pending or established connection for them.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}
1052
// Marks the pool as a gpui global so it can be accessed through the app context
// (e.g. `update_global` / `default_global`).
impl Global for ConnectionPool {}
1054
1055impl ConnectionPool {
1056    pub fn connect(
1057        &mut self,
1058        opts: SshConnectionOptions,
1059        delegate: &Arc<dyn SshClientDelegate>,
1060        cx: &mut AppContext,
1061    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1062        let connection = self.connections.get(&opts);
1063        match connection {
1064            Some(ConnectionPoolEntry::Connecting(task)) => {
1065                let delegate = delegate.clone();
1066                cx.spawn(|mut cx| async move {
1067                    delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
1068                })
1069                .detach();
1070                return task.clone();
1071            }
1072            Some(ConnectionPoolEntry::Connected(ssh)) => {
1073                if let Some(ssh) = ssh.upgrade() {
1074                    if !ssh.has_been_killed() {
1075                        return Task::ready(Ok(ssh)).shared();
1076                    }
1077                }
1078                self.connections.remove(&opts);
1079            }
1080            None => {}
1081        }
1082
1083        let task = cx
1084            .spawn({
1085                let opts = opts.clone();
1086                let delegate = delegate.clone();
1087                |mut cx| async move {
1088                    let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
1089                        .await
1090                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1091
1092                    cx.update_global(|pool: &mut Self, _| {
1093                        debug_assert!(matches!(
1094                            pool.connections.get(&opts),
1095                            Some(ConnectionPoolEntry::Connecting(_))
1096                        ));
1097                        match connection {
1098                            Ok(connection) => {
1099                                pool.connections.insert(
1100                                    opts.clone(),
1101                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1102                                );
1103                                Ok(connection)
1104                            }
1105                            Err(error) => {
1106                                pool.connections.remove(&opts);
1107                                Err(Arc::new(error))
1108                            }
1109                        }
1110                    })?
1111                }
1112            })
1113            .shared();
1114
1115        self.connections
1116            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1117        task
1118    }
1119}
1120
/// Converts an `SshRemoteClient` into an `AnyProtoClient` by wrapping a clone of its inner
/// client.
impl From<SshRemoteClient> for AnyProtoClient {
    fn from(client: SshRemoteClient) -> Self {
        AnyProtoClient::new(client.client.clone())
    }
}
1126
/// Abstraction over an established connection to a remote host.
///
/// `SshRemoteConnection` is the implementation defined in this file; the test-support code
/// also uses a `fake::FakeRemoteConnection` as a `dyn RemoteConnection`.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy on the remote side and returns a task that resolves to its exit code.
    ///
    /// `incoming_tx` receives envelopes read from the proxy and `outgoing_rx` supplies envelopes
    /// to send to it. `connection_activity_tx` is signalled when data arrives from the proxy
    /// (in `SshRemoteConnection`, on both stdout messages and stderr output).
    #[allow(clippy::too_many_arguments)]
    fn start_proxy(
        &self,
        remote_binary_path: PathBuf,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<i32>>;
    /// Resolves the path of the server binary on the remote host.
    ///
    /// In `SshRemoteConnection`, the binary is installed/updated first unless `reconnect` is
    /// set, and the binary's `version` command is run before the path is returned.
    async fn get_remote_binary_path(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        reconnect: bool,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Reports whether the connection has already been terminated.
    fn has_been_killed(&self) -> bool;
    /// Returns the ssh arguments associated with this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// Returns the options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncAppContext) {}
}
1155
/// A connection to a remote host backed by a master `ssh` process and its control socket.
struct SshRemoteConnection {
    // Connection options plus the control socket path used to issue further ssh commands.
    socket: SshSocket,
    // The master ssh process; `None` once the connection has been killed.
    master_process: Mutex<Option<process::Child>>,
    // OS/architecture of the remote host, detected via `uname` on connect.
    platform: SshPlatform,
    // Holds the askpass script and sockets; kept so the directory lives as long as the connection.
    _temp_dir: TempDir,
}
1162
1163#[async_trait(?Send)]
1164impl RemoteConnection for SshRemoteConnection {
1165    async fn kill(&self) -> Result<()> {
1166        let Some(mut process) = self.master_process.lock().take() else {
1167            return Ok(());
1168        };
1169        process.kill().ok();
1170        process.status().await?;
1171        Ok(())
1172    }
1173
1174    fn has_been_killed(&self) -> bool {
1175        self.master_process.lock().is_none()
1176    }
1177
1178    fn ssh_args(&self) -> Vec<String> {
1179        self.socket.ssh_args()
1180    }
1181
1182    fn connection_options(&self) -> SshConnectionOptions {
1183        self.socket.connection_options.clone()
1184    }
1185
1186    async fn get_remote_binary_path(
1187        &self,
1188        delegate: &Arc<dyn SshClientDelegate>,
1189        reconnect: bool,
1190        cx: &mut AsyncAppContext,
1191    ) -> Result<PathBuf> {
1192        let platform = self.platform;
1193        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1194        if !reconnect {
1195            self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1196                .await?;
1197        }
1198
1199        let socket = self.socket.clone();
1200        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1201        Ok(remote_binary_path)
1202    }
1203
1204    fn start_proxy(
1205        &self,
1206        remote_binary_path: PathBuf,
1207        unique_identifier: String,
1208        reconnect: bool,
1209        incoming_tx: UnboundedSender<Envelope>,
1210        outgoing_rx: UnboundedReceiver<Envelope>,
1211        connection_activity_tx: Sender<()>,
1212        delegate: Arc<dyn SshClientDelegate>,
1213        cx: &mut AsyncAppContext,
1214    ) -> Task<Result<i32>> {
1215        delegate.set_status(Some("Starting proxy"), cx);
1216
1217        let mut start_proxy_command = format!(
1218            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1219            std::env::var("RUST_LOG").unwrap_or_default(),
1220            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1221            remote_binary_path,
1222            unique_identifier,
1223        );
1224        if reconnect {
1225            start_proxy_command.push_str(" --reconnect");
1226        }
1227
1228        let ssh_proxy_process = match self
1229            .socket
1230            .ssh_command(start_proxy_command)
1231            // IMPORTANT: we kill this process when we drop the task that uses it.
1232            .kill_on_drop(true)
1233            .spawn()
1234        {
1235            Ok(process) => process,
1236            Err(error) => {
1237                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
1238            }
1239        };
1240
1241        Self::multiplex(
1242            ssh_proxy_process,
1243            incoming_tx,
1244            outgoing_rx,
1245            connection_activity_tx,
1246            &cx,
1247        )
1248    }
1249}
1250
1251impl SshRemoteConnection {
    /// Non-unix stub: always fails, because ssh connections are only implemented for unix
    /// targets.
    #[cfg(not(unix))]
    async fn new(
        _connection_options: SshConnectionOptions,
        _delegate: Arc<dyn SshClientDelegate>,
        _cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        Err(anyhow!("ssh is not supported on this platform"))
    }
1260
    /// Establishes a master ssh connection to the host described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. Create a temporary directory holding an askpass socket, an askpass script and the ssh
    ///    control socket.
    /// 2. Spawn a task that serves password prompts from the askpass script via the delegate.
    /// 3. Start `ssh` as a control master and wait for it to close stdout (with a 10 second
    ///    timeout, unless a password prompt was opened).
    /// 4. Detect the remote OS and architecture with `uname`.
    ///
    /// # Errors
    ///
    /// Fails if any local setup step fails, if the connection times out, if the master process
    /// has already exited after authentication, or if the remote OS/architecture is not
    /// recognized.
    #[cfg(unix)]
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        use futures::AsyncWriteExt as _;
        use futures::{io::BufReader, AsyncBufReadExt as _};
        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
        use util::ResultExt as _;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;

        // Create a domain socket listener to handle requests from the askpass program.
        let askpass_socket = temp_dir.path().join("askpass.sock");
        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
        let listener =
            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;

        let askpass_task = cx.spawn({
            let delegate = delegate.clone();
            |mut cx| async move {
                let mut askpass_opened_tx = Some(askpass_opened_tx);

                while let Ok((mut stream, _)) = listener.accept().await {
                    // Signal (once) that a prompt was opened, which disables the connect timeout.
                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
                        askpass_opened_tx.send(()).ok();
                    }
                    // The askpass script sends its arguments NUL-terminated; read up to the
                    // first NUL byte and use that as the prompt. A read error yields an empty
                    // prompt.
                    let mut buffer = Vec::new();
                    let mut reader = BufReader::new(&mut stream);
                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
                        buffer.clear();
                    }
                    let password_prompt = String::from_utf8_lossy(&buffer);
                    // Failures to obtain the password are logged; nothing is written back then.
                    if let Some(password) = delegate
                        .ask_password(password_prompt.to_string(), &mut cx)
                        .await
                        .context("failed to get ssh password")
                        .and_then(|p| p)
                        .log_err()
                    {
                        stream.write_all(password.as_bytes()).await.log_err();
                    }
                }
            }
        });

        // Create an askpass script that communicates back to this process.
        let askpass_script = format!(
            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
            askpass_socket = askpass_socket.display(),
            print_args = "printf '%s\\0' \"$@\"",
            shebang = "#!/bin/sh",
        );
        let askpass_script_path = temp_dir.path().join("askpass.sh");
        fs::write(&askpass_script_path, askpass_script).await?;
        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        let socket_path = temp_dir.path().join("ssh.sock");
        let mut master_process = process::Command::new("ssh")
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .env("SSH_ASKPASS_REQUIRE", "force")
            .env("SSH_ASKPASS", &askpass_script_path)
            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
            .args([
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ])
            .arg(format!("ControlPath={}", socket_path.display()))
            .arg(&url)
            .kill_on_drop(true)
            .spawn()?;

        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let stdout = master_process.stdout.as_mut().unwrap();
        let mut output = Vec::new();
        let connection_timeout = Duration::from_secs(10);

        let result = select_biased! {
            _ = askpass_opened_rx.fuse() => {
                // If the askpass script has opened, that means the user is typing
                // their password, in which case we don't want to timeout anymore,
                // since we know a connection has been established.
                stdout.read_to_end(&mut output).await?;
                Ok(())
            }
            result = stdout.read_to_end(&mut output).fuse() => {
                result?;
                Ok(())
            }
            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // Authentication is over; the askpass listener task is no longer needed.
        drop(askpass_task);

        // If the master process has already exited at this point the connection failed;
        // surface whatever it wrote to stderr as the error message.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            Err(anyhow!(error_message))?;
        }

        let socket = SshSocket {
            connection_options,
            socket_path,
        };

        // Probe the remote platform through the freshly established control socket.
        let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?;
        let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?;

        let os = match os.trim() {
            "Darwin" => "macos",
            "Linux" => "linux",
            _ => Err(anyhow!("unknown uname os {os:?}"))?,
        };
        // NOTE(review): these prefix checks also match 32-bit machine names such as `armv7l`
        // and `i686`, which are then treated as 64-bit targets — confirm this is intended.
        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
            "aarch64"
        } else if arch.starts_with("x86") || arch.starts_with("i686") {
            "x86_64"
        } else {
            Err(anyhow!("unknown uname architecture {arch:?}"))?
        };

        let platform = SshPlatform { os, arch };

        Ok(Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            platform,
            _temp_dir: temp_dir,
        })
    }
1419
    /// Bridges the proxy process's stdio to the in-process message channels.
    ///
    /// Three background tasks are spawned:
    /// - stdin: writes every envelope received on `outgoing_rx` to the child's stdin;
    /// - stdout: reads length-prefixed messages from the child's stdout and forwards them on
    ///   `incoming_tx`;
    /// - stderr: splits the child's stderr into lines and logs each one.
    ///
    /// The returned task completes as soon as the first of those tasks finishes. It then waits
    /// for the child to exit and yields its exit code (`1` when no code is available), or the
    /// error of the task that finished first, tagged with the stream name.
    fn multiplex(
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        mut connection_activity_tx: Sender<()>,
        cx: &AsyncAppContext,
    ) -> Task<Result<i32>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Number of bytes at the start of `stderr_buffer` that hold a not-yet-terminated line.
        let mut stderr_offset = 0;

        let stdin_task = cx.background_executor().spawn(async move {
            while let Some(outgoing) = outgoing_rx.next().await {
                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
            }
            anyhow::Ok(())
        });

        let stdout_task = cx.background_executor().spawn({
            let mut connection_activity_tx = connection_activity_tx.clone();
            async move {
                loop {
                    // Read the fixed-size length prefix first.
                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                    let len = child_stdout.read(&mut stdout_buffer).await?;

                    // A zero-length read means the child closed its stdout.
                    if len == 0 {
                        return anyhow::Ok(());
                    }

                    // A short read: fetch the rest of the length prefix.
                    if len < MESSAGE_LEN_SIZE {
                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                    }

                    let message_len = message_len_from_buffer(&stdout_buffer);
                    let envelope =
                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                            .await?;
                    // Both sends are best-effort; failures are ignored.
                    connection_activity_tx.try_send(()).ok();
                    incoming_tx.unbounded_send(envelope).ok();
                }
            }
        });

        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
            loop {
                // Make room for up to 1024 more bytes after any buffered partial line.
                stderr_buffer.resize(stderr_offset + 1024, 0);

                let len = child_stderr
                    .read(&mut stderr_buffer[stderr_offset..])
                    .await?;
                if len == 0 {
                    return anyhow::Ok(());
                }

                stderr_offset += len;
                let mut start_ix = 0;
                // Emit every complete (newline-terminated) line currently in the buffer.
                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                    .iter()
                    .position(|b| b == &b'\n')
                {
                    let line_ix = start_ix + ix;
                    let content = &stderr_buffer[start_ix..line_ix];
                    start_ix = line_ix + 1;
                    // Lines that parse as a JSON `LogRecord` go to the logger; anything else
                    // is printed verbatim to our own stderr.
                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                        record.log(log::logger())
                    } else {
                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                    }
                }
                // Drop the consumed lines, keeping only the trailing partial line.
                stderr_buffer.drain(0..start_ix);
                stderr_offset -= start_ix;

                connection_activity_tx.try_send(()).ok();
            }
        });

        cx.spawn(|_| async move {
            // Whichever stream task finishes first decides the result.
            let result = futures::select! {
                result = stdin_task.fuse() => {
                    result.context("stdin")
                }
                result = stdout_task.fuse() => {
                    result.context("stdout")
                }
                result = stderr_task.fuse() => {
                    result.context("stderr")
                }
            };

            // The exit status is awaited before the stream result is inspected.
            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
            match result {
                Ok(_) => Ok(status),
                Err(error) => Err(error),
            }
        })
    }
1521
1522    async fn ensure_server_binary(
1523        &self,
1524        delegate: &Arc<dyn SshClientDelegate>,
1525        dst_path: &Path,
1526        platform: SshPlatform,
1527        cx: &mut AsyncAppContext,
1528    ) -> Result<()> {
1529        let lock_file = dst_path.with_extension("lock");
1530        let lock_content = {
1531            let timestamp = SystemTime::now()
1532                .duration_since(UNIX_EPOCH)
1533                .context("failed to get timestamp")?
1534                .as_secs();
1535            let source_port = self.get_ssh_source_port().await?;
1536            format!("{} {}", source_port, timestamp)
1537        };
1538
1539        let lock_stale_age = Duration::from_secs(10 * 60);
1540        let max_wait_time = Duration::from_secs(10 * 60);
1541        let check_interval = Duration::from_secs(5);
1542        let start_time = Instant::now();
1543
1544        loop {
1545            let lock_acquired = self.create_lock_file(&lock_file, &lock_content).await?;
1546            if lock_acquired {
1547                delegate.set_status(Some("Acquired lock file on host"), cx);
1548                let result = self
1549                    .update_server_binary_if_needed(delegate, dst_path, platform, cx)
1550                    .await;
1551
1552                self.remove_lock_file(&lock_file).await.ok();
1553
1554                return result;
1555            } else {
1556                if let Ok(is_stale) = self.is_lock_stale(&lock_file, &lock_stale_age).await {
1557                    if is_stale {
1558                        delegate.set_status(
1559                            Some("Detected lock file on host being stale. Removing"),
1560                            cx,
1561                        );
1562                        self.remove_lock_file(&lock_file).await?;
1563                        continue;
1564                    } else {
1565                        if start_time.elapsed() > max_wait_time {
1566                            return Err(anyhow!("Timeout waiting for lock to be released"));
1567                        }
1568                        log::info!(
1569                            "Found lockfile: {:?}. Will check again in {:?}",
1570                            lock_file,
1571                            check_interval
1572                        );
1573                        delegate.set_status(
1574                            Some("Waiting for another Zed instance to finish uploading binary"),
1575                            cx,
1576                        );
1577                        smol::Timer::after(check_interval).await;
1578                        continue;
1579                    }
1580                } else {
1581                    // Unable to check lock, assume it's valid and wait
1582                    if start_time.elapsed() > max_wait_time {
1583                        return Err(anyhow!("Timeout waiting for lock to be released"));
1584                    }
1585                    smol::Timer::after(check_interval).await;
1586                    continue;
1587                }
1588            }
1589        }
1590    }
1591
1592    async fn get_ssh_source_port(&self) -> Result<String> {
1593        let output = run_cmd(
1594            self.socket
1595                .ssh_command("sh")
1596                .arg("-c")
1597                .arg(r#""echo $SSH_CLIENT | cut -d' ' -f2""#),
1598        )
1599        .await
1600        .context("failed to get source port from SSH_CLIENT on host")?;
1601
1602        Ok(output.trim().to_string())
1603    }
1604
1605    async fn create_lock_file(&self, lock_file: &Path, content: &str) -> Result<bool> {
1606        let parent_dir = lock_file
1607            .parent()
1608            .ok_or_else(|| anyhow!("Lock file path has no parent directory"))?;
1609
1610        let script = format!(
1611            r#"'mkdir -p "{parent_dir}" && [ ! -f "{lock_file}" ] && echo "{content}" > "{lock_file}" && echo "created" || echo "exists"'"#,
1612            parent_dir = parent_dir.display(),
1613            lock_file = lock_file.display(),
1614            content = content,
1615        );
1616
1617        let output = run_cmd(self.socket.ssh_command("sh").arg("-c").arg(&script))
1618            .await
1619            .with_context(|| format!("failed to create a lock file at {:?}", lock_file))?;
1620
1621        Ok(output.trim() == "created")
1622    }
1623
1624    fn generate_stale_check_script(lock_file: &Path, max_age: u64) -> String {
1625        format!(
1626            r#"
1627            if [ ! -f "{lock_file}" ]; then
1628                echo "lock file does not exist"
1629                exit 0
1630            fi
1631
1632            read -r port timestamp < "{lock_file}"
1633
1634            # Check if port is still active
1635            if command -v ss >/dev/null 2>&1; then
1636                if ! ss -n | grep -q ":$port[[:space:]]"; then
1637                    echo "ss reports port $port is not open"
1638                    exit 0
1639                fi
1640            elif command -v netstat >/dev/null 2>&1; then
1641                if ! netstat -n | grep -q ":$port[[:space:]]"; then
1642                    echo "netstat reports port $port is not open"
1643                    exit 0
1644                fi
1645            fi
1646
1647            # Check timestamp
1648            if [ $(( $(date +%s) - timestamp )) -gt {max_age} ]; then
1649                echo "timestamp in lockfile is too old"
1650            else
1651                echo "recent"
1652            fi"#,
1653            lock_file = lock_file.display(),
1654            max_age = max_age
1655        )
1656    }
1657
1658    async fn is_lock_stale(&self, lock_file: &Path, max_age: &Duration) -> Result<bool> {
1659        let script = format!(
1660            "'{}'",
1661            Self::generate_stale_check_script(lock_file, max_age.as_secs())
1662        );
1663
1664        let output = run_cmd(self.socket.ssh_command("sh").arg("-c").arg(&script))
1665            .await
1666            .with_context(|| {
1667                format!("failed to check whether lock file {:?} is stale", lock_file)
1668            })?;
1669
1670        let trimmed = output.trim();
1671        let is_stale = trimmed != "recent";
1672        log::info!("checked lockfile for staleness. stale: {is_stale}, output: {trimmed:?}");
1673        Ok(is_stale)
1674    }
1675
1676    async fn remove_lock_file(&self, lock_file: &Path) -> Result<()> {
1677        run_cmd(self.socket.ssh_command("rm").arg("-f").arg(lock_file))
1678            .await
1679            .context("failed to remove lock file")?;
1680        Ok(())
1681    }
1682
1683    async fn update_server_binary_if_needed(
1684        &self,
1685        delegate: &Arc<dyn SshClientDelegate>,
1686        dst_path: &Path,
1687        platform: SshPlatform,
1688        cx: &mut AsyncAppContext,
1689    ) -> Result<()> {
1690        if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1691            if let Ok(installed_version) =
1692                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1693            {
1694                log::info!("using cached server binary version {}", installed_version);
1695                return Ok(());
1696            }
1697        }
1698
1699        if self.is_binary_in_use(dst_path).await? {
1700            log::info!("server binary is opened by another process. not updating");
1701            delegate.set_status(
1702                Some("Skipping update of remote development server, since it's still in use"),
1703                cx,
1704            );
1705            return Ok(());
1706        }
1707
1708        let (binary, version) = delegate.get_server_binary(platform, cx).await??;
1709
1710        let mut remote_version = None;
1711        if cfg!(not(debug_assertions)) {
1712            if let Ok(installed_version) =
1713                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1714            {
1715                if let Ok(version) = installed_version.trim().parse::<SemanticVersion>() {
1716                    remote_version = Some(version);
1717                } else {
1718                    log::warn!("failed to parse version of remote server: {installed_version:?}",);
1719                }
1720            }
1721
1722            if let Some(remote_version) = remote_version {
1723                if remote_version == version {
1724                    log::info!("remote development server present and matching client version");
1725                    return Ok(());
1726                } else if remote_version > version {
1727                    let error = anyhow!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", remote_version, version);
1728                    return Err(error);
1729                } else {
1730                    log::info!(
1731                        "remote development server has older version: {}. updating...",
1732                        remote_version
1733                    );
1734                }
1735            }
1736        }
1737
1738        match binary {
1739            ServerBinary::LocalBinary(src_path) => {
1740                self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
1741                    .await
1742            }
1743            ServerBinary::ReleaseUrl { url, body } => {
1744                self.download_binary_on_server(&url, &body, dst_path, delegate, cx)
1745                    .await
1746            }
1747        }
1748    }
1749
1750    async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> {
1751        let script = format!(
1752            r#"'
1753            if command -v lsof >/dev/null 2>&1; then
1754                if lsof "{}" >/dev/null 2>&1; then
1755                    echo "in_use"
1756                    exit 0
1757                fi
1758            elif command -v fuser >/dev/null 2>&1; then
1759                if fuser "{}" >/dev/null 2>&1; then
1760                    echo "in_use"
1761                    exit 0
1762                fi
1763            fi
1764            echo "not_in_use"
1765            '"#,
1766            binary_path.display(),
1767            binary_path.display(),
1768        );
1769
1770        let output = run_cmd(self.socket.ssh_command("sh").arg("-c").arg(script))
1771            .await
1772            .context("failed to check if binary is in use")?;
1773
1774        Ok(output.trim() == "in_use")
1775    }
1776
1777    async fn download_binary_on_server(
1778        &self,
1779        url: &str,
1780        body: &str,
1781        dst_path: &Path,
1782        delegate: &Arc<dyn SshClientDelegate>,
1783        cx: &mut AsyncAppContext,
1784    ) -> Result<()> {
1785        let mut dst_path_gz = dst_path.to_path_buf();
1786        dst_path_gz.set_extension("gz");
1787
1788        if let Some(parent) = dst_path.parent() {
1789            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1790        }
1791
1792        delegate.set_status(Some("Downloading remote development server on host"), cx);
1793
1794        let script = format!(
1795            r#"
1796            if command -v wget >/dev/null 2>&1; then
1797                wget --max-redirect=5 --method=GET --header="Content-Type: application/json" --body-data='{}' '{}' -O '{}' && echo "wget"
1798            elif command -v curl >/dev/null 2>&1; then
1799                curl -L -X GET -H "Content-Type: application/json" -d '{}' '{}' -o '{}' && echo "curl"
1800            else
1801                echo "Neither curl nor wget is available" >&2
1802                exit 1
1803            fi
1804            "#,
1805            body.replace("'", r#"\'"#),
1806            url,
1807            dst_path_gz.display(),
1808            body.replace("'", r#"\'"#),
1809            url,
1810            dst_path_gz.display(),
1811        );
1812
1813        let output = run_cmd(self.socket.ssh_command("bash").arg("-c").arg(script))
1814            .await
1815            .context("Failed to download server binary")?;
1816
1817        if !output.contains("curl") && !output.contains("wget") {
1818            return Err(anyhow!("Failed to download server binary: {}", output));
1819        }
1820
1821        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
1822            .await
1823    }
1824
1825    async fn upload_local_server_binary(
1826        &self,
1827        src_path: &Path,
1828        dst_path: &Path,
1829        delegate: &Arc<dyn SshClientDelegate>,
1830        cx: &mut AsyncAppContext,
1831    ) -> Result<()> {
1832        let mut dst_path_gz = dst_path.to_path_buf();
1833        dst_path_gz.set_extension("gz");
1834
1835        if let Some(parent) = dst_path.parent() {
1836            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1837        }
1838
1839        let src_stat = fs::metadata(&src_path).await?;
1840        let size = src_stat.len();
1841
1842        let t0 = Instant::now();
1843        delegate.set_status(Some("Uploading remote development server"), cx);
1844        log::info!("uploading remote development server ({}kb)", size / 1024);
1845        self.upload_file(&src_path, &dst_path_gz)
1846            .await
1847            .context("failed to upload server binary")?;
1848        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1849
1850        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
1851            .await
1852    }
1853
1854    async fn extract_server_binary(
1855        &self,
1856        dst_path: &Path,
1857        dst_path_gz: &Path,
1858        delegate: &Arc<dyn SshClientDelegate>,
1859        cx: &mut AsyncAppContext,
1860    ) -> Result<()> {
1861        delegate.set_status(Some("Extracting remote development server"), cx);
1862        run_cmd(
1863            self.socket
1864                .ssh_command("gunzip")
1865                .arg("--force")
1866                .arg(&dst_path_gz),
1867        )
1868        .await?;
1869
1870        let server_mode = 0o755;
1871        delegate.set_status(Some("Marking remote development server executable"), cx);
1872        run_cmd(
1873            self.socket
1874                .ssh_command("chmod")
1875                .arg(format!("{:o}", server_mode))
1876                .arg(dst_path),
1877        )
1878        .await?;
1879
1880        Ok(())
1881    }
1882
1883    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1884        let mut command = process::Command::new("scp");
1885        let output = self
1886            .socket
1887            .ssh_options(&mut command)
1888            .args(
1889                self.socket
1890                    .connection_options
1891                    .port
1892                    .map(|port| vec!["-P".to_string(), port.to_string()])
1893                    .unwrap_or_default(),
1894            )
1895            .arg(src_path)
1896            .arg(format!(
1897                "{}:{}",
1898                self.socket.connection_options.scp_url(),
1899                dest_path.display()
1900            ))
1901            .output()
1902            .await?;
1903
1904        if output.status.success() {
1905            Ok(())
1906        } else {
1907            Err(anyhow!(
1908                "failed to upload file {} -> {}: {}",
1909                src_path.display(),
1910                dest_path.display(),
1911                String::from_utf8_lossy(&output.stderr)
1912            ))
1913        }
1914    }
1915}
1916
/// Pending requests, keyed by the id of the request envelope.
///
/// Each entry holds the sender through which the response is delivered. The
/// response envelope travels together with a second sender: the message loop
/// waits on the matching receiver before it moves on to the next incoming
/// message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1918
/// A protobuf client that talks to its peer over a pair of in-process
/// channels, with buffering of sent messages so that they can be replayed
/// after a reconnect.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes; incremented for every message sent.
    next_message_id: AtomicU32,
    /// Channel on which outgoing envelopes are written. Replaced on reconnect.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes that were sent buffered and have not been acknowledged yet.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests that are still waiting for their response.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the last incoming envelope that was processed; sent back to the
    /// peer as `ack_id` on outgoing envelopes.
    max_received: AtomicU32,
    /// Label used as a prefix in log messages.
    name: &'static str,
    /// The task running the incoming-message loop. Replaced on reconnect.
    task: Mutex<Task<Result<()>>>,
}
1929
1930impl ChannelClient {
1931    pub fn new(
1932        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1933        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1934        cx: &AppContext,
1935        name: &'static str,
1936    ) -> Arc<Self> {
1937        Arc::new_cyclic(|this| Self {
1938            outgoing_tx: Mutex::new(outgoing_tx),
1939            next_message_id: AtomicU32::new(0),
1940            max_received: AtomicU32::new(0),
1941            response_channels: ResponseChannels::default(),
1942            message_handlers: Default::default(),
1943            buffer: Mutex::new(VecDeque::new()),
1944            name,
1945            task: Mutex::new(Self::start_handling_messages(
1946                this.clone(),
1947                incoming_rx,
1948                &cx.to_async(),
1949            )),
1950        })
1951    }
1952
1953    fn start_handling_messages(
1954        this: Weak<Self>,
1955        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1956        cx: &AsyncAppContext,
1957    ) -> Task<Result<()>> {
1958        cx.spawn(|cx| {
1959            async move {
1960                let peer_id = PeerId { owner_id: 0, id: 0 };
1961                while let Some(incoming) = incoming_rx.next().await {
1962                    let Some(this) = this.upgrade() else {
1963                        return anyhow::Ok(());
1964                    };
1965                    if let Some(ack_id) = incoming.ack_id {
1966                        let mut buffer = this.buffer.lock();
1967                        while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1968                            buffer.pop_front();
1969                        }
1970                    }
1971                    if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
1972                        &incoming.payload
1973                    {
1974                        log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name);
1975                        {
1976                            let buffer = this.buffer.lock();
1977                            for envelope in buffer.iter() {
1978                                this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok();
1979                            }
1980                        }
1981                        let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
1982                        envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1983                        this.outgoing_tx.lock().unbounded_send(envelope).ok();
1984                        continue;
1985                    }
1986
1987                    this.max_received.store(incoming.id, SeqCst);
1988
1989                    if let Some(request_id) = incoming.responding_to {
1990                        let request_id = MessageId(request_id);
1991                        let sender = this.response_channels.lock().remove(&request_id);
1992                        if let Some(sender) = sender {
1993                            let (tx, rx) = oneshot::channel();
1994                            if incoming.payload.is_some() {
1995                                sender.send((incoming, tx)).ok();
1996                            }
1997                            rx.await.ok();
1998                        }
1999                    } else if let Some(envelope) =
2000                        build_typed_envelope(peer_id, Instant::now(), incoming)
2001                    {
2002                        let type_name = envelope.payload_type_name();
2003                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
2004                            &this.message_handlers,
2005                            envelope,
2006                            this.clone().into(),
2007                            cx.clone(),
2008                        ) {
2009                            log::debug!("{}:ssh message received. name:{type_name}", this.name);
2010                            cx.foreground_executor().spawn(async move {
2011                                match future.await {
2012                                    Ok(_) => {
2013                                        log::debug!("{}:ssh message handled. name:{type_name}", this.name);
2014                                    }
2015                                    Err(error) => {
2016                                        log::error!(
2017                                            "{}:error handling message. type:{type_name}, error:{error}", this.name,
2018                                        );
2019                                    }
2020                                }
2021                            }).detach()
2022                        } else {
2023                            log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2024                        }
2025                    }
2026                }
2027                anyhow::Ok(())
2028            }
2029        })
2030    }
2031
2032    pub fn reconnect(
2033        self: &Arc<Self>,
2034        incoming_rx: UnboundedReceiver<Envelope>,
2035        outgoing_tx: UnboundedSender<Envelope>,
2036        cx: &AsyncAppContext,
2037    ) {
2038        *self.outgoing_tx.lock() = outgoing_tx;
2039        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2040    }
2041
2042    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
2043        let id = (TypeId::of::<E>(), remote_id);
2044
2045        let mut message_handlers = self.message_handlers.lock();
2046        if message_handlers
2047            .entities_by_type_and_remote_id
2048            .contains_key(&id)
2049        {
2050            panic!("already subscribed to entity");
2051        }
2052
2053        message_handlers.entities_by_type_and_remote_id.insert(
2054            id,
2055            EntityMessageSubscriber::Entity {
2056                handle: entity.downgrade().into(),
2057            },
2058        );
2059    }
2060
2061    pub fn request<T: RequestMessage>(
2062        &self,
2063        payload: T,
2064    ) -> impl 'static + Future<Output = Result<T::Response>> {
2065        self.request_internal(payload, true)
2066    }
2067
2068    fn request_internal<T: RequestMessage>(
2069        &self,
2070        payload: T,
2071        use_buffer: bool,
2072    ) -> impl 'static + Future<Output = Result<T::Response>> {
2073        log::debug!("ssh request start. name:{}", T::NAME);
2074        let response =
2075            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2076        async move {
2077            let response = response.await?;
2078            log::debug!("ssh request finish. name:{}", T::NAME);
2079            T::Response::from_envelope(response)
2080                .ok_or_else(|| anyhow!("received a response of the wrong type"))
2081        }
2082    }
2083
2084    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2085        smol::future::or(
2086            async {
2087                self.request_internal(proto::FlushBufferedMessages {}, false)
2088                    .await?;
2089
2090                for envelope in self.buffer.lock().iter() {
2091                    self.outgoing_tx
2092                        .lock()
2093                        .unbounded_send(envelope.clone())
2094                        .ok();
2095                }
2096                Ok(())
2097            },
2098            async {
2099                smol::Timer::after(timeout).await;
2100                Err(anyhow!("Timeout detected"))
2101            },
2102        )
2103        .await
2104    }
2105
2106    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2107        smol::future::or(
2108            async {
2109                self.request(proto::Ping {}).await?;
2110                Ok(())
2111            },
2112            async {
2113                smol::Timer::after(timeout).await;
2114                Err(anyhow!("Timeout detected"))
2115            },
2116        )
2117        .await
2118    }
2119
2120    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2121        log::debug!("ssh send name:{}", T::NAME);
2122        self.send_dynamic(payload.into_envelope(0, None, None))
2123    }
2124
2125    fn request_dynamic(
2126        &self,
2127        mut envelope: proto::Envelope,
2128        type_name: &'static str,
2129        use_buffer: bool,
2130    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2131        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2132        let (tx, rx) = oneshot::channel();
2133        let mut response_channels_lock = self.response_channels.lock();
2134        response_channels_lock.insert(MessageId(envelope.id), tx);
2135        drop(response_channels_lock);
2136
2137        let result = if use_buffer {
2138            self.send_buffered(envelope)
2139        } else {
2140            self.send_unbuffered(envelope)
2141        };
2142        async move {
2143            if let Err(error) = &result {
2144                log::error!("failed to send message: {}", error);
2145                return Err(anyhow!("failed to send message: {}", error));
2146            }
2147
2148            let response = rx.await.context("connection lost")?.0;
2149            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2150                return Err(RpcError::from_proto(error, type_name));
2151            }
2152            Ok(response)
2153        }
2154    }
2155
2156    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2157        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2158        self.send_buffered(envelope)
2159    }
2160
2161    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2162        envelope.ack_id = Some(self.max_received.load(SeqCst));
2163        self.buffer.lock().push_back(envelope.clone());
2164        // ignore errors on send (happen while we're reconnecting)
2165        // assume that the global "disconnected" overlay is sufficient.
2166        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2167        Ok(())
2168    }
2169
2170    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2171        envelope.ack_id = Some(self.max_received.load(SeqCst));
2172        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2173        Ok(())
2174    }
2175}
2176
/// Exposes `ChannelClient` through the generic `ProtoClient` interface by
/// delegating to its inherent methods.
impl ProtoClient for ChannelClient {
    /// Sends `envelope` as a buffered request and resolves to the response
    /// envelope.
    fn request(
        &self,
        envelope: proto::Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, Result<proto::Envelope>> {
        self.request_dynamic(envelope, request_type, true).boxed()
    }

    /// Sends `envelope` as a buffered message; the message type is not used.
    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
        self.send_dynamic(envelope)
    }

    /// Sends a response envelope the same way as any other buffered message;
    /// the message type is not used.
    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
        self.send_dynamic(envelope)
    }

    /// Returns the set of handlers used to dispatch incoming messages.
    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
        &self.message_handlers
    }

    /// Always `false` for this client.
    fn is_via_collab(&self) -> bool {
        false
    }
}
2202
/// In-process stand-ins for a remote connection and its delegate, available
/// in tests and with the `test-support` feature.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, SemanticVersion, Task};
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A `RemoteConnection` that forwards envelopes between the client and an
    /// in-process server `ChannelClient` instead of going through SSH.
    pub(super) struct FakeRemoteConnection {
        /// Options returned verbatim from `connection_options`.
        pub(super) connection_options: SshConnectionOptions,
        /// The server side of the connection.
        pub(super) server_channel: Arc<ChannelClient>,
        /// The app context in which the server's message loop is spawned.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an `AsyncAppContext` be stored in a type that has to
    /// be `Send + Sync`.
    pub(super) struct SendableCx(AsyncAppContext);
    // safety: you can only get the other cx on the main thread.
    impl SendableCx {
        /// Wraps `cx`.
        pub(super) fn new(cx: AsyncAppContext) -> Self {
            Self(cx)
        }
        /// Returns a clone of the wrapped context. The argument is unused; it
        /// only forces the caller to already hold an `AsyncAppContext`.
        fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
            self.0.clone()
        }
    }
    // NOTE(review): soundness of these impls rests on the convention described
    // above (`get` is only called from the main thread); the type system does
    // not enforce it.
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// No-op: there is no process to kill.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake connection never reports having been killed.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// The fake connection has no SSH arguments.
        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        /// Returns a copy of the configured connection options.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server channel to channels whose other ends are
        /// dropped immediately, so nothing can be sent or received any more.
        fn simulate_disconnect(&self, cx: &AsyncAppContext) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
        }

        /// Returns an empty path; no binary is involved.
        async fn get_remote_binary_path(
            &self,
            _delegate: &Arc<dyn SshClientDelegate>,
            _reconnect: bool,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            Ok(PathBuf::new())
        }

        /// Wires the client's channels to fresh server channels and spawns a
        /// background task that shuttles envelopes in both directions.
        ///
        /// The task resolves to `Ok(1)` as soon as either side's stream ends.
        fn start_proxy(
            &self,
            _remote_binary_path: PathBuf,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_executor().spawn(async move {
                loop {
                    // `select_biased!` polls in declaration order, so messages
                    // from the server are forwarded first when both are ready.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            // Every message from the server counts as activity.
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// A delegate for the fake connection. Only `set_status` may be called;
    /// every other method panics via `unreachable!()`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }
        fn remote_server_binary_path(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
        fn get_server_binary(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}
    }
}
2347
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Runs the stale-check script locally under `bash`, with `ss` and
    /// `netstat` replaced by shell functions.
    ///
    /// When `simulate_port_open` is `Some(port)`, the mocked commands print one
    /// connection line for that port; otherwise they print nothing. Returns the
    /// script's trimmed stdout; anything on stderr is echoed for debugging.
    fn run_stale_check_script(
        lock_file: &Path,
        max_age: Duration,
        simulate_port_open: Option<&str>,
    ) -> Result<String> {
        let simulated_port = simulate_port_open.unwrap_or("");
        let script =
            SshRemoteConnection::generate_stale_check_script(lock_file, max_age.as_secs());
        let wrapper = format!(
            r#"
            # Mock ss/netstat commands
            ss() {{
                # Only handle the -n argument
                if [ "$1" = "-n" ]; then
                    # If we're simulating an open port, output a line containing that port
                    if [ "{simulated_port}" != "" ]; then
                        echo "ESTAB 0 0 1.2.3.4:{simulated_port} 5.6.7.8:12345"
                    fi
                fi
            }}
            netstat() {{
                ss "$@"
            }}
            export -f ss netstat

            # Real script starts here
            {script}"#
        );

        let mut bash = std::process::Command::new("bash");
        bash.arg("-c").arg(&wrapper);
        let output = bash.output()?;

        if !output.stderr.is_empty() {
            eprintln!("Script stderr: {}", String::from_utf8_lossy(&output.stderr));
        }

        let stdout = String::from_utf8(output.stdout)?;
        Ok(stdout.trim().to_string())
    }

    /// Walks the lock file through every outcome the script can report.
    #[test]
    fn test_lock_staleness() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let lock_file = temp_dir.path().join("test.lock");
        let max_age = Duration::from_secs(600);
        // Unix timestamp, in seconds, of the moment `age` seconds before now.
        let seconds_ago = |age: u64| -> Result<u64> {
            Ok(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() - age)
        };

        // Test 1: No lock file
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, None)?,
            "lock file does not exist"
        );

        // Test 2: Lock file with port that's not open
        fs::write(&lock_file, "54321 1234567890")?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("98765"))?,
            "ss reports port 54321 is not open"
        );

        // Test 3: Lock file with port that is open but old timestamp
        let old_timestamp = seconds_ago(700)?; // 700 seconds ago
        fs::write(&lock_file, format!("54321 {}", old_timestamp))?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("54321"))?,
            "timestamp in lockfile is too old"
        );

        // Test 4: Lock file with port that is open and recent timestamp
        let recent_timestamp = seconds_ago(60)?; // 1 minute ago
        fs::write(&lock_file, format!("54321 {}", recent_timestamp))?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("54321"))?,
            "recent"
        );

        Ok(())
    }
}