ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    channel::{
  13        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  14        oneshot,
  15    },
  16    future::{BoxFuture, Shared},
  17    select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
  18};
  19use gpui::{
  20    AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
  21    ModelContext, SemanticVersion, Task, WeakModel,
  22};
  23use itertools::Itertools;
  24use parking_lot::Mutex;
  25use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
  26use rpc::{
  27    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  28    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
  29    RpcError,
  30};
  31use smol::{
  32    fs,
  33    process::{self, Child, Stdio},
  34};
  35use std::{
  36    any::TypeId,
  37    collections::VecDeque,
  38    fmt, iter,
  39    ops::ControlFlow,
  40    path::{Path, PathBuf},
  41    sync::{
  42        atomic::{AtomicU32, Ordering::SeqCst},
  43        Arc, Weak,
  44    },
  45    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
  46};
  47use tempfile::TempDir;
  48use util::ResultExt;
  49
/// Newtype wrapper around a `u64` identifying a project reached over SSH; serializable with serde.
/// NOTE(review): which id space the `u64` belongs to is not visible in this section — confirm.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);

/// An ssh target together with the local control-socket path that every `ssh`
/// invocation built from it passes as `-o ControlPath=...`.
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    socket_path: PathBuf,
}

/// Everything needed to address a remote host over ssh.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host to connect to.
    pub host: String,
    /// Remote user; when set, URLs/connection strings render as `user@host`.
    pub username: Option<String>,
    /// TCP port; when set, URLs/connection strings render as `host:port`.
    pub port: Option<u16>,
    /// NOTE(review): presumably the login password; `parse_command_line` always leaves it `None`.
    pub password: Option<String>,
    /// Extra ssh arguments forwarded verbatim (`parse_command_line` fills this from its allow-lists).
    pub args: Option<Vec<String>>,

    /// NOTE(review): presumably a user-facing display name; `parse_command_line` leaves it `None`.
    pub nickname: Option<String>,
    /// NOTE(review): presumably chooses uploading the server binary over ssh rather than
    /// downloading it on the remote — confirm. `parse_command_line` sets it to `false`.
    pub upload_binary_over_ssh: bool,
}
  72
  73#[macro_export]
  74macro_rules! shell_script {
  75    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
  76        format!(
  77            $fmt,
  78            $(
  79                $name = shlex::try_quote($arg).unwrap()
  80            ),+
  81        )
  82    }};
  83}
  84
  85impl SshConnectionOptions {
  86    pub fn parse_command_line(input: &str) -> Result<Self> {
  87        let input = input.trim_start_matches("ssh ");
  88        let mut hostname: Option<String> = None;
  89        let mut username: Option<String> = None;
  90        let mut port: Option<u16> = None;
  91        let mut args = Vec::new();
  92
  93        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  94        const ALLOWED_OPTS: &[&str] = &[
  95            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  96        ];
  97        const ALLOWED_ARGS: &[&str] = &[
  98            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
  99            "-w",
 100        ];
 101
 102        let mut tokens = shlex::split(input)
 103            .ok_or_else(|| anyhow!("invalid input"))?
 104            .into_iter();
 105
 106        'outer: while let Some(arg) = tokens.next() {
 107            if ALLOWED_OPTS.contains(&(&arg as &str)) {
 108                args.push(arg.to_string());
 109                continue;
 110            }
 111            if arg == "-p" {
 112                port = tokens.next().and_then(|arg| arg.parse().ok());
 113                continue;
 114            } else if let Some(p) = arg.strip_prefix("-p") {
 115                port = p.parse().ok();
 116                continue;
 117            }
 118            if arg == "-l" {
 119                username = tokens.next();
 120                continue;
 121            } else if let Some(l) = arg.strip_prefix("-l") {
 122                username = Some(l.to_string());
 123                continue;
 124            }
 125            for a in ALLOWED_ARGS {
 126                if arg == *a {
 127                    args.push(arg);
 128                    if let Some(next) = tokens.next() {
 129                        args.push(next);
 130                    }
 131                    continue 'outer;
 132                } else if arg.starts_with(a) {
 133                    args.push(arg);
 134                    continue 'outer;
 135                }
 136            }
 137            if arg.starts_with("-") || hostname.is_some() {
 138                anyhow::bail!("unsupported argument: {:?}", arg);
 139            }
 140            let mut input = &arg as &str;
 141            if let Some((u, rest)) = input.split_once('@') {
 142                input = rest;
 143                username = Some(u.to_string());
 144            }
 145            if let Some((rest, p)) = input.split_once(':') {
 146                input = rest;
 147                port = p.parse().ok()
 148            }
 149            hostname = Some(input.to_string())
 150        }
 151
 152        let Some(hostname) = hostname else {
 153            anyhow::bail!("missing hostname");
 154        };
 155
 156        Ok(Self {
 157            host: hostname.to_string(),
 158            username: username.clone(),
 159            port,
 160            args: Some(args),
 161            password: None,
 162            nickname: None,
 163            upload_binary_over_ssh: false,
 164        })
 165    }
 166
 167    pub fn ssh_url(&self) -> String {
 168        let mut result = String::from("ssh://");
 169        if let Some(username) = &self.username {
 170            result.push_str(username);
 171            result.push('@');
 172        }
 173        result.push_str(&self.host);
 174        if let Some(port) = self.port {
 175            result.push(':');
 176            result.push_str(&port.to_string());
 177        }
 178        result
 179    }
 180
 181    pub fn additional_args(&self) -> Option<&Vec<String>> {
 182        self.args.as_ref()
 183    }
 184
 185    fn scp_url(&self) -> String {
 186        if let Some(username) = &self.username {
 187            format!("{}@{}", username, self.host)
 188        } else {
 189            self.host.clone()
 190        }
 191    }
 192
 193    pub fn connection_string(&self) -> String {
 194        let host = if let Some(username) = &self.username {
 195            format!("{}@{}", username, self.host)
 196        } else {
 197            self.host.clone()
 198        };
 199        if let Some(port) = &self.port {
 200            format!("{}:{}", host, port)
 201        } else {
 202            host
 203        }
 204    }
 205}
 206
 207#[derive(Copy, Clone, Debug)]
 208pub struct SshPlatform {
 209    pub os: &'static str,
 210    pub arch: &'static str,
 211}
 212
 213impl SshPlatform {
 214    pub fn triple(&self) -> Option<String> {
 215        Some(format!(
 216            "{}-{}",
 217            self.arch,
 218            match self.os {
 219                "linux" => "unknown-linux-gnu",
 220                "macos" => "apple-darwin",
 221                _ => return None,
 222            }
 223        ))
 224    }
 225}
 226
 227pub enum ServerBinary {
 228    LocalBinary(PathBuf),
 229    ReleaseUrl { url: String, body: String },
 230}
 231
 232#[derive(Clone, Debug, PartialEq, Eq)]
 233pub enum ServerVersion {
 234    Semantic(SemanticVersion),
 235    Commit(String),
 236}
 237impl ServerVersion {
 238    pub fn semantic_version(&self) -> Option<SemanticVersion> {
 239        match self {
 240            Self::Semantic(version) => Some(*version),
 241            _ => None,
 242        }
 243    }
 244}
 245
 246impl std::fmt::Display for ServerVersion {
 247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 248        match self {
 249            Self::Semantic(version) => write!(f, "{}", version),
 250            Self::Commit(commit) => write!(f, "{}", commit),
 251        }
 252    }
 253}
 254
/// Hooks into the host application that the SSH client calls while connecting.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password in response to `prompt` (presumably by prompting the user);
    /// the answer, or an error, arrives through the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns a local path to the remote-server binary for `platform`.
    /// NOTE(review): when this is preferred over downloading is not visible here — confirm.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Resolves download parameters for the server matching `platform`, `release_channel`
    /// and `version`.
    /// NOTE(review): the `(String, String)` pair presumably maps to the `url`/`body` of
    /// `ServerBinary::ReleaseUrl`, and a `None` version presumably means "latest" — confirm.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<(String, String)>>;

    /// Downloads the server binary for `platform` / `release_channel` / `version` to the
    /// local machine and resolves to its local path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<PathBuf>>;
    /// Updates status text shown to the user; `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
}
 283
 284impl SshSocket {
 285    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
 286    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
 287    // and passes -l as an argument to sh, not to ls.
 288    // You need to do it like this: $ ssh host "sh -c 'ls -l /tmp'"
 289    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
 290        let mut command = process::Command::new("ssh");
 291        let to_run = iter::once(&program)
 292            .chain(args.iter())
 293            .map(|token| shlex::try_quote(token).unwrap())
 294            .join(" ");
 295        self.ssh_options(&mut command)
 296            .arg(self.connection_options.ssh_url())
 297            .arg(to_run);
 298        command
 299    }
 300
 301    fn shell_script(&self, script: impl AsRef<str>) -> process::Command {
 302        return self.ssh_command("sh", &["-c", script.as_ref()]);
 303    }
 304
 305    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 306        command
 307            .stdin(Stdio::piped())
 308            .stdout(Stdio::piped())
 309            .stderr(Stdio::piped())
 310            .args(["-o", "ControlMaster=no", "-o"])
 311            .arg(format!("ControlPath={}", self.socket_path.display()))
 312    }
 313
 314    fn ssh_args(&self) -> Vec<String> {
 315        vec![
 316            "-o".to_string(),
 317            "ControlMaster=no".to_string(),
 318            "-o".to_string(),
 319            format!("ControlPath={}", self.socket_path.display()),
 320            self.connection_options.ssh_url(),
 321        ]
 322    }
 323}
 324
 325async fn run_cmd(mut command: process::Command) -> Result<String> {
 326    let output = command.output().await?;
 327    if output.status.success() {
 328        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 329    } else {
 330        Err(anyhow!(
 331            "failed to run command: {}",
 332            String::from_utf8_lossy(&output.stderr)
 333        ))
 334    }
 335}
 336
 337const MAX_MISSED_HEARTBEATS: usize = 5;
 338const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 339const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 340
 341const MAX_RECONNECT_ATTEMPTS: usize = 3;
 342
/// Internal connection state machine of [`SshRemoteClient`].
///
/// Variants carrying `multiplex_task` / `heartbeat_task` own those background tasks, so
/// replacing or dropping the state drops them as well.
enum State {
    /// Set by `SshRemoteClient::new` before the first connection is established.
    Connecting,
    /// Connection established; the I/O multiplexer and heartbeat tasks are running.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but `missed_heartbeats` consecutive heartbeats went unanswered
    /// (incremented by `heartbeat_missed`, cleared by `heartbeat_recovered`).
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight (set by `reconnect` before spawning it).
    Reconnecting,
    /// The last reconnect attempt failed with `error`; `attempts` counts attempts so far.
    /// The old connection and delegate are kept so another attempt can be made.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts were made; no further reconnects.
    ReconnectExhausted,
    /// NOTE(review): not constructed in this section; presumably set when the remote
    /// server process is found not to be running — confirm.
    ServerNotRunning,
}
 372
 373impl fmt::Display for State {
 374    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 375        match self {
 376            Self::Connecting => write!(f, "connecting"),
 377            Self::Connected { .. } => write!(f, "connected"),
 378            Self::Reconnecting => write!(f, "reconnecting"),
 379            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 380            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 381            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 382            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 383        }
 384    }
 385}
 386
 387impl State {
 388    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
 389        match self {
 390            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 391            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 392            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 393            _ => None,
 394        }
 395    }
 396
 397    fn can_reconnect(&self) -> bool {
 398        match self {
 399            Self::Connected { .. }
 400            | Self::HeartbeatMissed { .. }
 401            | Self::ReconnectFailed { .. } => true,
 402            State::Connecting
 403            | State::Reconnecting
 404            | State::ReconnectExhausted
 405            | State::ServerNotRunning => false,
 406        }
 407    }
 408
 409    fn is_reconnect_failed(&self) -> bool {
 410        matches!(self, Self::ReconnectFailed { .. })
 411    }
 412
 413    fn is_reconnect_exhausted(&self) -> bool {
 414        matches!(self, Self::ReconnectExhausted { .. })
 415    }
 416
 417    fn is_server_not_running(&self) -> bool {
 418        matches!(self, Self::ServerNotRunning)
 419    }
 420
 421    fn is_reconnecting(&self) -> bool {
 422        matches!(self, Self::Reconnecting { .. })
 423    }
 424
 425    fn heartbeat_recovered(self) -> Self {
 426        match self {
 427            Self::HeartbeatMissed {
 428                ssh_connection,
 429                delegate,
 430                multiplex_task,
 431                heartbeat_task,
 432                ..
 433            } => Self::Connected {
 434                ssh_connection,
 435                delegate,
 436                multiplex_task,
 437                heartbeat_task,
 438            },
 439            _ => self,
 440        }
 441    }
 442
 443    fn heartbeat_missed(self) -> Self {
 444        match self {
 445            Self::Connected {
 446                ssh_connection,
 447                delegate,
 448                multiplex_task,
 449                heartbeat_task,
 450            } => Self::HeartbeatMissed {
 451                missed_heartbeats: 1,
 452                ssh_connection,
 453                delegate,
 454                multiplex_task,
 455                heartbeat_task,
 456            },
 457            Self::HeartbeatMissed {
 458                missed_heartbeats,
 459                ssh_connection,
 460                delegate,
 461                multiplex_task,
 462                heartbeat_task,
 463            } => Self::HeartbeatMissed {
 464                missed_heartbeats: missed_heartbeats + 1,
 465                ssh_connection,
 466                delegate,
 467                multiplex_task,
 468                heartbeat_task,
 469            },
 470            _ => self,
 471        }
 472    }
 473}
 474
 475/// The state of the ssh connection.
 476#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 477pub enum ConnectionState {
 478    Connecting,
 479    Connected,
 480    HeartbeatMissed,
 481    Reconnecting,
 482    Disconnected,
 483}
 484
 485impl From<&State> for ConnectionState {
 486    fn from(value: &State) -> Self {
 487        match value {
 488            State::Connecting => Self::Connecting,
 489            State::Connected { .. } => Self::Connected,
 490            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 491            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 492            State::ReconnectExhausted => Self::Disconnected,
 493            State::ServerNotRunning => Self::Disconnected,
 494        }
 495    }
 496}
 497
/// Client side of an SSH remote connection: the message channel to the remote server
/// plus the connection state machine that drives (re)connects and heartbeats.
pub struct SshRemoteClient {
    /// Channel carrying `Envelope` messages to and from the remote server.
    client: Arc<ChannelClient>,
    /// Name derived from a `ConnectionIdentifier`; passed to `start_proxy` on the first
    /// connect and on every reconnect.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}

/// Events emitted by [`SshRemoteClient`].
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// NOTE(review): the emission site is not visible in this section; presumably sent when
    /// the connection is given up — confirm.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 511
 512// Identifies the socket on the remote server so that reconnects
 513// can re-join the same project.
 514pub enum ConnectionIdentifier {
 515    Setup,
 516    Workspace(i64),
 517}
 518
 519impl ConnectionIdentifier {
 520    // This string gets used in a socket name, and so must be relatively short.
 521    // The total length of:
 522    //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
 523    // Must be less than about 100 characters
 524    //   https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
 525    // So our strings should be at most 20 characters or so.
 526    fn to_string(&self, cx: &AppContext) -> String {
 527        let identifier_prefix = match ReleaseChannel::global(cx) {
 528            ReleaseChannel::Stable => "".to_string(),
 529            release_channel => format!("{}-", release_channel.dev_name()),
 530        };
 531        match self {
 532            Self::Setup => format!("{identifier_prefix}setup"),
 533            Self::Workspace(workspace_id) => {
 534                format!("{identifier_prefix}workspace-{workspace_id}",)
 535            }
 536        }
 537    }
 538}
 539
 540impl SshRemoteClient {
    /// Connects to the host in `connection_options`, starts the remote server proxy, and
    /// resolves to the new client model once the server answers a ping.
    ///
    /// Resolves to `Ok(None)` if `cancellation` completes first.
    /// NOTE(review): a `oneshot::Receiver` also completes (with `Err(Canceled)`) when its
    /// sender is dropped, and the `select!` arm ignores the value; confirm that dropping
    /// the sender is meant to cancel.
    ///
    /// # Errors
    /// Propagates failures from connecting through the `ConnectionPool`, resolving the
    /// remote binary path, model updates, and the initial ping.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AppContext,
    ) -> Task<Result<Option<Model<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(|mut cx| async move {
            // Boxed and pinned so it can be raced against `cancellation` below.
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
                // The model exists before the connection does; it starts in `Connecting`.
                let this = cx.new_model(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options: connection_options.clone(),
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // The ConnectionPool global is created with its default if not yet present.
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options, &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;
                // NOTE(review): `reconnect` passes `true` here and to `start_proxy`;
                // presumably a "reconnecting" flag — confirm.
                let remote_binary_path = ssh_connection
                    .get_remote_binary_path(&delegate, false, &mut cx)
                    .await?;

                let io_task = ssh_connection.start_proxy(
                    remote_binary_path,
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    &mut cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Fail fast if the server does not answer within HEARTBEAT_TIMEOUT;
                // `multiplex_task` is dropped along with this early return.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task =
                    Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);

                this.update(&mut cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() =>  result
            }
        })
    }
 617
 618    pub fn shutdown_processes<T: RequestMessage>(
 619        &self,
 620        shutdown_request: Option<T>,
 621    ) -> Option<impl Future<Output = ()>> {
 622        let state = self.state.lock().take()?;
 623        log::info!("shutting down ssh processes");
 624
 625        let State::Connected {
 626            multiplex_task,
 627            heartbeat_task,
 628            ssh_connection,
 629            delegate,
 630        } = state
 631        else {
 632            return None;
 633        };
 634
 635        let client = self.client.clone();
 636
 637        Some(async move {
 638            if let Some(shutdown_request) = shutdown_request {
 639                client.send(shutdown_request).log_err();
 640                // We wait 50ms instead of waiting for a response, because
 641                // waiting for a response would require us to wait on the main thread
 642                // which we want to avoid in an `on_app_quit` callback.
 643                smol::Timer::after(Duration::from_millis(50)).await;
 644            }
 645
 646            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 647            // child of master_process.
 648            drop(multiplex_task);
 649            // Now drop the rest of state, which kills master process.
 650            drop(heartbeat_task);
 651            drop(ssh_connection);
 652            drop(delegate);
 653        })
 654    }
 655
 656    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 657        let mut lock = self.state.lock();
 658
 659        let can_reconnect = lock
 660            .as_ref()
 661            .map(|state| state.can_reconnect())
 662            .unwrap_or(false);
 663        if !can_reconnect {
 664            let error = if let Some(state) = lock.as_ref() {
 665                format!("invalid state, cannot reconnect while in state {state}")
 666            } else {
 667                "no state set".to_string()
 668            };
 669            log::info!("aborting reconnect, because not in state that allows reconnecting");
 670            return Err(anyhow!(error));
 671        }
 672
 673        let state = lock.take().unwrap();
 674        let (attempts, ssh_connection, delegate) = match state {
 675            State::Connected {
 676                ssh_connection,
 677                delegate,
 678                multiplex_task,
 679                heartbeat_task,
 680            }
 681            | State::HeartbeatMissed {
 682                ssh_connection,
 683                delegate,
 684                multiplex_task,
 685                heartbeat_task,
 686                ..
 687            } => {
 688                drop(multiplex_task);
 689                drop(heartbeat_task);
 690                (0, ssh_connection, delegate)
 691            }
 692            State::ReconnectFailed {
 693                attempts,
 694                ssh_connection,
 695                delegate,
 696                ..
 697            } => (attempts, ssh_connection, delegate),
 698            State::Connecting
 699            | State::Reconnecting
 700            | State::ReconnectExhausted
 701            | State::ServerNotRunning => unreachable!(),
 702        };
 703
 704        let attempts = attempts + 1;
 705        if attempts > MAX_RECONNECT_ATTEMPTS {
 706            log::error!(
 707                "Failed to reconnect to after {} attempts, giving up",
 708                MAX_RECONNECT_ATTEMPTS
 709            );
 710            drop(lock);
 711            self.set_state(State::ReconnectExhausted, cx);
 712            return Ok(());
 713        }
 714        drop(lock);
 715
 716        self.set_state(State::Reconnecting, cx);
 717
 718        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 719
 720        let unique_identifier = self.unique_identifier.clone();
 721        let client = self.client.clone();
 722        let reconnect_task = cx.spawn(|this, mut cx| async move {
 723            macro_rules! failed {
 724                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
 725                    return State::ReconnectFailed {
 726                        error: anyhow!($error),
 727                        attempts: $attempts,
 728                        ssh_connection: $ssh_connection,
 729                        delegate: $delegate,
 730                    };
 731                };
 732            }
 733
 734            if let Err(error) = ssh_connection
 735                .kill()
 736                .await
 737                .context("Failed to kill ssh process")
 738            {
 739                failed!(error, attempts, ssh_connection, delegate);
 740            };
 741
 742            let connection_options = ssh_connection.connection_options();
 743
 744            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 745            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 746            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 747
 748            let (ssh_connection, io_task) = match async {
 749                let ssh_connection = cx
 750                    .update_global(|pool: &mut ConnectionPool, cx| {
 751                        pool.connect(connection_options, &delegate, cx)
 752                    })?
 753                    .await
 754                    .map_err(|error| error.cloned())?;
 755
 756                let remote_binary_path = ssh_connection
 757                    .get_remote_binary_path(&delegate, true, &mut cx)
 758                    .await?;
 759
 760                let io_task = ssh_connection.start_proxy(
 761                    remote_binary_path,
 762                    unique_identifier,
 763                    true,
 764                    incoming_tx,
 765                    outgoing_rx,
 766                    connection_activity_tx,
 767                    delegate.clone(),
 768                    &mut cx,
 769                );
 770                anyhow::Ok((ssh_connection, io_task))
 771            }
 772            .await
 773            {
 774                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
 775                Err(error) => {
 776                    failed!(error, attempts, ssh_connection, delegate);
 777                }
 778            };
 779
 780            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 781            client.reconnect(incoming_rx, outgoing_tx, &cx);
 782
 783            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 784                failed!(error, attempts, ssh_connection, delegate);
 785            };
 786
 787            State::Connected {
 788                ssh_connection,
 789                delegate,
 790                multiplex_task,
 791                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 792            }
 793        });
 794
 795        cx.spawn(|this, mut cx| async move {
 796            let new_state = reconnect_task.await;
 797            this.update(&mut cx, |this, cx| {
 798                this.try_set_state(cx, |old_state| {
 799                    if old_state.is_reconnecting() {
 800                        match &new_state {
 801                            State::Connecting
 802                            | State::Reconnecting { .. }
 803                            | State::HeartbeatMissed { .. }
 804                            | State::ServerNotRunning => {}
 805                            State::Connected { .. } => {
 806                                log::info!("Successfully reconnected");
 807                            }
 808                            State::ReconnectFailed {
 809                                error, attempts, ..
 810                            } => {
 811                                log::error!(
 812                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 813                                    attempts,
 814                                    error
 815                                );
 816                            }
 817                            State::ReconnectExhausted => {
 818                                log::error!("Reconnect attempt failed and all attempts exhausted");
 819                            }
 820                        }
 821                        Some(new_state)
 822                    } else {
 823                        None
 824                    }
 825                });
 826
 827                if this.state_is(State::is_reconnect_failed) {
 828                    this.reconnect(cx)
 829                } else if this.state_is(State::is_reconnect_exhausted) {
 830                    Ok(())
 831                } else {
 832                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
 833                    Ok(())
 834                }
 835            })
 836        })
 837        .detach_and_log_err(cx);
 838
 839        Ok(())
 840    }
 841
 842    fn heartbeat(
 843        this: WeakModel<Self>,
 844        mut connection_activity_rx: mpsc::Receiver<()>,
 845        cx: &mut AsyncAppContext,
 846    ) -> Task<Result<()>> {
 847        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 848            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 849        };
 850
 851        cx.spawn(|mut cx| {
 852            let this = this.clone();
 853            async move {
 854                let mut missed_heartbeats = 0;
 855
 856                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 857                futures::pin_mut!(keepalive_timer);
 858
 859                loop {
 860                    select_biased! {
 861                        result = connection_activity_rx.next().fuse() => {
 862                            if result.is_none() {
 863                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 864                                return Ok(());
 865                            }
 866
 867                            if missed_heartbeats != 0 {
 868                                missed_heartbeats = 0;
 869                                this.update(&mut cx, |this, mut cx| {
 870                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 871                                })?;
 872                            }
 873                        }
 874                        _ = keepalive_timer => {
 875                            log::debug!("Sending heartbeat to server...");
 876
 877                            let result = select_biased! {
 878                                _ = connection_activity_rx.next().fuse() => {
 879                                    Ok(())
 880                                }
 881                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 882                                    ping_result
 883                                }
 884                            };
 885
 886                            if result.is_err() {
 887                                missed_heartbeats += 1;
 888                                log::warn!(
 889                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 890                                    HEARTBEAT_TIMEOUT,
 891                                    missed_heartbeats,
 892                                    MAX_MISSED_HEARTBEATS
 893                                );
 894                            } else if missed_heartbeats != 0 {
 895                                missed_heartbeats = 0;
 896                            } else {
 897                                continue;
 898                            }
 899
 900                            let result = this.update(&mut cx, |this, mut cx| {
 901                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 902                            })?;
 903                            if result.is_break() {
 904                                return Ok(());
 905                            }
 906                        }
 907                    }
 908
 909                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 910                }
 911            }
 912        })
 913    }
 914
 915    fn handle_heartbeat_result(
 916        &mut self,
 917        missed_heartbeats: usize,
 918        cx: &mut ModelContext<Self>,
 919    ) -> ControlFlow<()> {
 920        let state = self.state.lock().take().unwrap();
 921        let next_state = if missed_heartbeats > 0 {
 922            state.heartbeat_missed()
 923        } else {
 924            state.heartbeat_recovered()
 925        };
 926
 927        self.set_state(next_state, cx);
 928
 929        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 930            log::error!(
 931                "Missed last {} heartbeats. Reconnecting...",
 932                missed_heartbeats
 933            );
 934
 935            self.reconnect(cx)
 936                .context("failed to start reconnect process after missing heartbeats")
 937                .log_err();
 938            ControlFlow::Break(())
 939        } else {
 940            ControlFlow::Continue(())
 941        }
 942    }
 943
 944    fn monitor(
 945        this: WeakModel<Self>,
 946        io_task: Task<Result<i32>>,
 947        cx: &AsyncAppContext,
 948    ) -> Task<Result<()>> {
 949        cx.spawn(|mut cx| async move {
 950            let result = io_task.await;
 951
 952            match result {
 953                Ok(exit_code) => {
 954                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
 955                        match error {
 956                            ProxyLaunchError::ServerNotRunning => {
 957                                log::error!("failed to reconnect because server is not running");
 958                                this.update(&mut cx, |this, cx| {
 959                                    this.set_state(State::ServerNotRunning, cx);
 960                                })?;
 961                            }
 962                        }
 963                    } else if exit_code > 0 {
 964                        log::error!("proxy process terminated unexpectedly");
 965                        this.update(&mut cx, |this, cx| {
 966                            this.reconnect(cx).ok();
 967                        })?;
 968                    }
 969                }
 970                Err(error) => {
 971                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 972                    this.update(&mut cx, |this, cx| {
 973                        this.reconnect(cx).ok();
 974                    })?;
 975                }
 976            }
 977
 978            Ok(())
 979        })
 980    }
 981
 982    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
 983        self.state.lock().as_ref().map_or(false, check)
 984    }
 985
 986    fn try_set_state(
 987        &self,
 988        cx: &mut ModelContext<Self>,
 989        map: impl FnOnce(&State) -> Option<State>,
 990    ) {
 991        let mut lock = self.state.lock();
 992        let new_state = lock.as_ref().and_then(map);
 993
 994        if let Some(new_state) = new_state {
 995            lock.replace(new_state);
 996            cx.notify();
 997        }
 998    }
 999
1000    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1001        log::info!("setting state to '{}'", &state);
1002
1003        let is_reconnect_exhausted = state.is_reconnect_exhausted();
1004        let is_server_not_running = state.is_server_not_running();
1005        self.state.lock().replace(state);
1006
1007        if is_reconnect_exhausted || is_server_not_running {
1008            cx.emit(SshRemoteEvent::Disconnected);
1009        }
1010        cx.notify();
1011    }
1012
1013    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1014        self.client.subscribe_to_entity(remote_id, entity);
1015    }
1016
1017    pub fn ssh_args(&self) -> Option<Vec<String>> {
1018        self.state
1019            .lock()
1020            .as_ref()
1021            .and_then(|state| state.ssh_connection())
1022            .map(|ssh_connection| ssh_connection.ssh_args())
1023    }
1024
1025    pub fn proto_client(&self) -> AnyProtoClient {
1026        self.client.clone().into()
1027    }
1028
1029    pub fn connection_string(&self) -> String {
1030        self.connection_options.connection_string()
1031    }
1032
1033    pub fn connection_options(&self) -> SshConnectionOptions {
1034        self.connection_options.clone()
1035    }
1036
1037    pub fn connection_state(&self) -> ConnectionState {
1038        self.state
1039            .lock()
1040            .as_ref()
1041            .map(ConnectionState::from)
1042            .unwrap_or(ConnectionState::Disconnected)
1043    }
1044
1045    pub fn is_disconnected(&self) -> bool {
1046        self.connection_state() == ConnectionState::Disconnected
1047    }
1048
1049    #[cfg(any(test, feature = "test-support"))]
1050    pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
1051        let opts = self.connection_options();
1052        client_cx.spawn(|cx| async move {
1053            let connection = cx
1054                .update_global(|c: &mut ConnectionPool, _| {
1055                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1056                        c.clone()
1057                    } else {
1058                        panic!("missing test connection")
1059                    }
1060                })
1061                .unwrap()
1062                .await
1063                .unwrap();
1064
1065            connection.simulate_disconnect(&cx);
1066        })
1067    }
1068
1069    #[cfg(any(test, feature = "test-support"))]
1070    pub fn fake_server(
1071        client_cx: &mut gpui::TestAppContext,
1072        server_cx: &mut gpui::TestAppContext,
1073    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1074        let port = client_cx
1075            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1076        let opts = SshConnectionOptions {
1077            host: "<fake>".to_string(),
1078            port: Some(port),
1079            ..Default::default()
1080        };
1081        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1082        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1083        let server_client =
1084            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1085        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1086            connection_options: opts.clone(),
1087            server_cx: fake::SendableCx::new(server_cx),
1088            server_channel: server_client.clone(),
1089        });
1090
1091        client_cx.update(|cx| {
1092            cx.update_default_global(|c: &mut ConnectionPool, cx| {
1093                c.connections.insert(
1094                    opts.clone(),
1095                    ConnectionPoolEntry::Connecting(
1096                        cx.foreground_executor()
1097                            .spawn({
1098                                let connection = connection.clone();
1099                                async move { Ok(connection.clone()) }
1100                            })
1101                            .shared(),
1102                    ),
1103                );
1104            })
1105        });
1106
1107        (opts, server_client)
1108    }
1109
1110    #[cfg(any(test, feature = "test-support"))]
1111    pub async fn fake_client(
1112        opts: SshConnectionOptions,
1113        client_cx: &mut gpui::TestAppContext,
1114    ) -> Model<Self> {
1115        let (_tx, rx) = oneshot::channel();
1116        client_cx
1117            .update(|cx| {
1118                Self::new(
1119                    ConnectionIdentifier::Setup,
1120                    opts,
1121                    rx,
1122                    Arc::new(fake::Delegate),
1123                    cx,
1124                )
1125            })
1126            .await
1127            .unwrap()
1128            .unwrap()
1129    }
1130}
1131
/// Pool bookkeeping for one set of [`SshConnectionOptions`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. The shared task lets every caller asking for the
    /// same options await the same attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool itself does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
1136
/// App-global registry that deduplicates remote connections by their options.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// A gpui global, reached via `update_global` / `default_global` / `update_default_global`.
impl Global for ConnectionPool {}
1143
impl ConnectionPool {
    /// Returns a shared task that resolves to a connection for `opts`, reusing work where
    /// possible.
    ///
    /// - If an attempt for `opts` is already in flight, its shared task is returned, and
    ///   the delegate's status is updated to say we are waiting for it.
    /// - If a previous connection is still alive and has not been killed, it is returned
    ///   immediately.
    /// - Otherwise a new [`SshRemoteConnection`] is started. It is recorded as
    ///   `Connecting` until it finishes. On success the entry becomes a weak `Connected`
    ///   entry; on failure it is removed.
    pub fn connect(
        &mut self,
        opts: SshConnectionOptions,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AppContext,
    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
        let connection = self.connections.get(&opts);
        match connection {
            Some(ConnectionPoolEntry::Connecting(task)) => {
                // `set_status` takes an async context, so report from a detached task.
                let delegate = delegate.clone();
                cx.spawn(|mut cx| async move {
                    delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
                })
                .detach();
                return task.clone();
            }
            Some(ConnectionPoolEntry::Connected(ssh)) => {
                if let Some(ssh) = ssh.upgrade() {
                    if !ssh.has_been_killed() {
                        return Task::ready(Ok(ssh)).shared();
                    }
                }
                // The connection is gone or killed: forget it and connect afresh below.
                self.connections.remove(&opts);
            }
            None => {}
        }

        let task = cx
            .spawn({
                let opts = opts.clone();
                let delegate = delegate.clone();
                |mut cx| async move {
                    let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
                        .await
                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);

                    cx.update_global(|pool: &mut Self, _| {
                        // NOTE(review): relies on the spawned future not running before the
                        // `insert` at the end of `connect` registers this attempt; this
                        // assertion checks that in debug builds.
                        debug_assert!(matches!(
                            pool.connections.get(&opts),
                            Some(ConnectionPoolEntry::Connecting(_))
                        ));
                        match connection {
                            Ok(connection) => {
                                pool.connections.insert(
                                    opts.clone(),
                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
                                );
                                Ok(connection)
                            }
                            Err(error) => {
                                pool.connections.remove(&opts);
                                Err(Arc::new(error))
                            }
                        }
                    })?
                    // `?` turns an `update_global` failure into the task's `Arc<anyhow::Error>`.
                }
            })
            .shared();

        self.connections
            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
        task
    }
}
1209
1210impl From<SshRemoteClient> for AnyProtoClient {
1211    fn from(client: SshRemoteClient) -> Self {
1212        AnyProtoClient::new(client.client.clone())
1213    }
1214}
1215
/// A transport that can launch the remote server proxy and manage its own lifetime.
///
/// Implemented by [`SshRemoteConnection`]; tests use `fake::FakeRemoteConnection`.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Launches the remote proxy and returns a task that shuttles envelopes between
    /// `outgoing_rx`/`incoming_tx` and the proxy, signalling `connection_activity_tx` on
    /// traffic. For the SSH implementation the task resolves to the proxy's exit code.
    #[allow(clippy::too_many_arguments)]
    fn start_proxy(
        &self,
        remote_binary_path: PathBuf,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<i32>>;
    /// Returns the path of the server binary to launch on the remote host. The SSH
    /// implementation only installs or updates the binary when `reconnect` is false.
    async fn get_remote_binary_path(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        reconnect: bool,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Tears the connection down.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has run; `ConnectionPool` does not reuse killed connections.
    fn has_been_killed(&self) -> bool;
    /// Arguments for running additional `ssh` commands over this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;

    /// Test hook to simulate a dropped connection. No-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncAppContext) {}
}
1244
/// A live SSH connection built around an `ssh -N` `ControlMaster` process.
struct SshRemoteConnection {
    /// Control socket (plus options) used to run further commands over the master.
    socket: SshSocket,
    /// The master `ssh` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<process::Child>>,
    /// Remote OS/architecture, detected via `uname` when connecting.
    platform: SshPlatform,
    /// Owns the temp dir holding the askpass script/socket and the control socket, so it
    /// lives exactly as long as the connection.
    _temp_dir: TempDir,
}
1251
1252#[async_trait(?Send)]
1253impl RemoteConnection for SshRemoteConnection {
1254    async fn kill(&self) -> Result<()> {
1255        let Some(mut process) = self.master_process.lock().take() else {
1256            return Ok(());
1257        };
1258        process.kill().ok();
1259        process.status().await?;
1260        Ok(())
1261    }
1262
1263    fn has_been_killed(&self) -> bool {
1264        self.master_process.lock().is_none()
1265    }
1266
1267    fn ssh_args(&self) -> Vec<String> {
1268        self.socket.ssh_args()
1269    }
1270
1271    fn connection_options(&self) -> SshConnectionOptions {
1272        self.socket.connection_options.clone()
1273    }
1274
1275    async fn get_remote_binary_path(
1276        &self,
1277        delegate: &Arc<dyn SshClientDelegate>,
1278        reconnect: bool,
1279        cx: &mut AsyncAppContext,
1280    ) -> Result<PathBuf> {
1281        let platform = self.platform;
1282        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1283        if !reconnect {
1284            self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1285                .await?;
1286        }
1287
1288        let socket = self.socket.clone();
1289        run_cmd(socket.ssh_command(&remote_binary_path.to_string_lossy(), &["version"])).await?;
1290        Ok(remote_binary_path)
1291    }
1292
1293    fn start_proxy(
1294        &self,
1295        remote_binary_path: PathBuf,
1296        unique_identifier: String,
1297        reconnect: bool,
1298        incoming_tx: UnboundedSender<Envelope>,
1299        outgoing_rx: UnboundedReceiver<Envelope>,
1300        connection_activity_tx: Sender<()>,
1301        delegate: Arc<dyn SshClientDelegate>,
1302        cx: &mut AsyncAppContext,
1303    ) -> Task<Result<i32>> {
1304        delegate.set_status(Some("Starting proxy"), cx);
1305
1306        let mut start_proxy_command = shell_script!(
1307            "exec {binary_path} proxy --identifier {identifier}",
1308            binary_path = &remote_binary_path.to_string_lossy(),
1309            identifier = &unique_identifier,
1310        );
1311
1312        if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1313            start_proxy_command = format!(
1314                "RUST_LOG={} {}",
1315                shlex::try_quote(&rust_log).unwrap(),
1316                start_proxy_command
1317            )
1318        }
1319        if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1320            start_proxy_command = format!(
1321                "RUST_BACKTRACE={} {}",
1322                shlex::try_quote(&rust_backtrace).unwrap(),
1323                start_proxy_command
1324            )
1325        }
1326        if reconnect {
1327            start_proxy_command.push_str(" --reconnect");
1328        }
1329
1330        let ssh_proxy_process = match self
1331            .socket
1332            .shell_script(start_proxy_command)
1333            // IMPORTANT: we kill this process when we drop the task that uses it.
1334            .kill_on_drop(true)
1335            .spawn()
1336        {
1337            Ok(process) => process,
1338            Err(error) => {
1339                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
1340            }
1341        };
1342
1343        Self::multiplex(
1344            ssh_proxy_process,
1345            incoming_tx,
1346            outgoing_rx,
1347            connection_activity_tx,
1348            &cx,
1349        )
1350    }
1351}
1352
1353impl SshRemoteConnection {
1354    #[cfg(not(unix))]
1355    async fn new(
1356        _connection_options: SshConnectionOptions,
1357        _delegate: Arc<dyn SshClientDelegate>,
1358        _cx: &mut AsyncAppContext,
1359    ) -> Result<Self> {
1360        Err(anyhow!("ssh is not supported on this platform"))
1361    }
1362
    /// Establishes an SSH `ControlMaster` connection and detects the remote platform.
    ///
    /// Steps:
    /// 1. Creates a temp dir holding an askpass Unix socket, an askpass shell script and
    ///    the SSH control socket.
    /// 2. Spawns `ssh -N -o ControlMaster=yes` with `SSH_ASKPASS` pointing at the script.
    ///    Each prompt is relayed over the socket to `delegate.ask_password`.
    /// 3. Waits for the master process to close its stdout. A 10-second timeout applies
    ///    only until the askpass script first connects. If obtaining a password fails, the
    ///    master is killed and an "SSH connection canceled" error is returned.
    /// 4. Runs `uname -s` / `uname -m` over the control socket to build `SshPlatform`.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the temp dir, socket or script cannot be created;
    /// - `ssh` cannot be spawned;
    /// - the connection times out or is canceled;
    /// - the master has already exited (its stderr goes into the message);
    /// - the OS or architecture is not recognized.
    #[cfg(unix)]
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        use futures::AsyncWriteExt as _;
        use futures::{io::BufReader, AsyncBufReadExt as _};
        use smol::net::unix::UnixStream;
        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
        use util::ResultExt as _;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;

        // Create a domain socket listener to handle requests from the askpass program.
        let askpass_socket = temp_dir.path().join("askpass.sock");
        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
        let listener =
            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;

        // When a password cannot be obtained, the askpass task sends the pending stream
        // here so the connecting code below can kill the master process.
        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<UnixStream>();
        let mut kill_tx = Some(askpass_kill_master_tx);

        let askpass_task = cx.spawn({
            let delegate = delegate.clone();
            |mut cx| async move {
                let mut askpass_opened_tx = Some(askpass_opened_tx);

                while let Ok((mut stream, _)) = listener.accept().await {
                    // Signal (once) that a prompt arrived; this disables the timeout below.
                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
                        askpass_opened_tx.send(()).ok();
                    }
                    // The script writes its arguments NUL-separated; read the first one
                    // (the prompt). On a read error, fall back to an empty prompt.
                    let mut buffer = Vec::new();
                    let mut reader = BufReader::new(&mut stream);
                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
                        buffer.clear();
                    }
                    let password_prompt = String::from_utf8_lossy(&buffer);
                    if let Some(password) = delegate
                        .ask_password(password_prompt.to_string(), &mut cx)
                        .await
                        .context("failed to get ssh password")
                        .and_then(|p| p)
                        .log_err()
                    {
                        // The script relays this reply to ssh via `nc`'s stdout.
                        stream.write_all(password.as_bytes()).await.log_err();
                    } else {
                        if let Some(kill_tx) = kill_tx.take() {
                            kill_tx.send(stream).log_err();
                            break;
                        }
                    }
                }
            }
        });

        // Create an askpass script that communicates back to this process.
        let askpass_script = format!(
            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
            askpass_socket = askpass_socket.display(),
            print_args = "printf '%s\\0' \"$@\"",
            shebang = "#!/bin/sh",
        );
        let askpass_script_path = temp_dir.path().join("askpass.sh");
        fs::write(&askpass_script_path, askpass_script).await?;
        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = process::Command::new("ssh")
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .env("SSH_ASKPASS_REQUIRE", "force")
            .env("SSH_ASKPASS", &askpass_script_path)
            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
            .args([
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ])
            // Value for the trailing `-o` above.
            .arg(format!("ControlPath={}", socket_path.display()))
            .arg(&url)
            .kill_on_drop(true)
            .spawn()?;

        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();
        let connection_timeout = Duration::from_secs(10);

        // Outer race: askpass opened / stdout closed / timeout. Once askpass has opened,
        // the inner race has no timeout: either the password request fails (kill the
        // master) or stdout closes.
        let result = select_biased! {
            _ = askpass_opened_rx.fuse() => {
                select_biased! {
                    stream = askpass_kill_master_rx.fuse() => {
                        master_process.kill().ok();
                        drop(stream);
                        Err(anyhow!("SSH connection canceled"))
                    }
                    // If the askpass script has opened, that means the user is typing
                    // their password, in which case we don't want to timeout anymore,
                    // since we know a connection has been established.
                    result = stdout.read_to_end(&mut output).fuse() => {
                        result?;
                        Ok(())
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                Ok(())
            }
            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // Authentication is over; drop the askpass listener task.
        drop(askpass_task);

        // If the master already exited, the connection failed: surface its stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            Err(anyhow!(error_message))?;
        }

        let socket = SshSocket {
            connection_options,
            socket_path,
        };

        // Detect the remote platform; it is later used to pick the server binary.
        let os = run_cmd(socket.ssh_command("uname", &["-s"])).await?;
        let arch = run_cmd(socket.ssh_command("uname", &["-m"])).await?;

        let os = match os.trim() {
            "Darwin" => "macos",
            "Linux" => "linux",
            _ => Err(anyhow!("unknown uname os {os:?}"))?,
        };
        // NOTE(review): `i686` (32-bit) is mapped to `x86_64`; confirm that is intended.
        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
            "aarch64"
        } else if arch.starts_with("x86") || arch.starts_with("i686") {
            "x86_64"
        } else {
            Err(anyhow!("unknown uname architecture {arch:?}"))?
        };

        let platform = SshPlatform { os, arch };

        Ok(Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            platform,
            _temp_dir: temp_dir,
        })
    }
1539
1540    fn multiplex(
1541        mut ssh_proxy_process: Child,
1542        incoming_tx: UnboundedSender<Envelope>,
1543        mut outgoing_rx: UnboundedReceiver<Envelope>,
1544        mut connection_activity_tx: Sender<()>,
1545        cx: &AsyncAppContext,
1546    ) -> Task<Result<i32>> {
1547        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1548        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1549        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1550
1551        let mut stdin_buffer = Vec::new();
1552        let mut stdout_buffer = Vec::new();
1553        let mut stderr_buffer = Vec::new();
1554        let mut stderr_offset = 0;
1555
1556        let stdin_task = cx.background_executor().spawn(async move {
1557            while let Some(outgoing) = outgoing_rx.next().await {
1558                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1559            }
1560            anyhow::Ok(())
1561        });
1562
1563        let stdout_task = cx.background_executor().spawn({
1564            let mut connection_activity_tx = connection_activity_tx.clone();
1565            async move {
1566                loop {
1567                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1568                    let len = child_stdout.read(&mut stdout_buffer).await?;
1569
1570                    if len == 0 {
1571                        return anyhow::Ok(());
1572                    }
1573
1574                    if len < MESSAGE_LEN_SIZE {
1575                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1576                    }
1577
1578                    let message_len = message_len_from_buffer(&stdout_buffer);
1579                    let envelope =
1580                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1581                            .await?;
1582                    connection_activity_tx.try_send(()).ok();
1583                    incoming_tx.unbounded_send(envelope).ok();
1584                }
1585            }
1586        });
1587
1588        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
1589            loop {
1590                stderr_buffer.resize(stderr_offset + 1024, 0);
1591
1592                let len = child_stderr
1593                    .read(&mut stderr_buffer[stderr_offset..])
1594                    .await?;
1595                if len == 0 {
1596                    return anyhow::Ok(());
1597                }
1598
1599                stderr_offset += len;
1600                let mut start_ix = 0;
1601                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1602                    .iter()
1603                    .position(|b| b == &b'\n')
1604                {
1605                    let line_ix = start_ix + ix;
1606                    let content = &stderr_buffer[start_ix..line_ix];
1607                    start_ix = line_ix + 1;
1608                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1609                        record.log(log::logger())
1610                    } else {
1611                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
1612                    }
1613                }
1614                stderr_buffer.drain(0..start_ix);
1615                stderr_offset -= start_ix;
1616
1617                connection_activity_tx.try_send(()).ok();
1618            }
1619        });
1620
1621        cx.spawn(|_| async move {
1622            let result = futures::select! {
1623                result = stdin_task.fuse() => {
1624                    result.context("stdin")
1625                }
1626                result = stdout_task.fuse() => {
1627                    result.context("stdout")
1628                }
1629                result = stderr_task.fuse() => {
1630                    result.context("stderr")
1631                }
1632            };
1633
1634            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1635            match result {
1636                Ok(_) => Ok(status),
1637                Err(error) => Err(error),
1638            }
1639        })
1640    }
1641
1642    async fn ensure_server_binary(
1643        &self,
1644        delegate: &Arc<dyn SshClientDelegate>,
1645        dst_path: &Path,
1646        platform: SshPlatform,
1647        cx: &mut AsyncAppContext,
1648    ) -> Result<()> {
1649        let lock_file = dst_path.with_extension("lock");
1650        let lock_content = {
1651            let timestamp = SystemTime::now()
1652                .duration_since(UNIX_EPOCH)
1653                .context("failed to get timestamp")?
1654                .as_secs();
1655            let source_port = self.get_ssh_source_port().await?;
1656            format!("{} {}", source_port, timestamp)
1657        };
1658
1659        let lock_stale_age = Duration::from_secs(10 * 60);
1660        let max_wait_time = Duration::from_secs(10 * 60);
1661        let check_interval = Duration::from_secs(5);
1662        let start_time = Instant::now();
1663
1664        loop {
1665            let lock_acquired = self.create_lock_file(&lock_file, &lock_content).await?;
1666            if lock_acquired {
1667                delegate.set_status(Some("Acquired lock file on host"), cx);
1668                let result = self
1669                    .update_server_binary_if_needed(delegate, dst_path, platform, cx)
1670                    .await;
1671
1672                self.remove_lock_file(&lock_file).await.ok();
1673
1674                return result;
1675            } else {
1676                if let Ok(is_stale) = self.is_lock_stale(&lock_file, &lock_stale_age).await {
1677                    if is_stale {
1678                        delegate.set_status(
1679                            Some("Detected lock file on host being stale. Removing"),
1680                            cx,
1681                        );
1682                        self.remove_lock_file(&lock_file).await?;
1683                        continue;
1684                    } else {
1685                        if start_time.elapsed() > max_wait_time {
1686                            return Err(anyhow!("Timeout waiting for lock to be released"));
1687                        }
1688                        log::info!(
1689                            "Found lockfile: {:?}. Will check again in {:?}",
1690                            lock_file,
1691                            check_interval
1692                        );
1693                        delegate.set_status(
1694                            Some("Waiting for another Zed instance to finish uploading binary"),
1695                            cx,
1696                        );
1697                        smol::Timer::after(check_interval).await;
1698                        continue;
1699                    }
1700                } else {
1701                    // Unable to check lock, assume it's valid and wait
1702                    if start_time.elapsed() > max_wait_time {
1703                        return Err(anyhow!("Timeout waiting for lock to be released"));
1704                    }
1705                    smol::Timer::after(check_interval).await;
1706                    continue;
1707                }
1708            }
1709        }
1710    }
1711
1712    async fn get_ssh_source_port(&self) -> Result<String> {
1713        let output = run_cmd(self.socket.shell_script("echo $SSH_CLIENT | cut -d' ' -f2"))
1714            .await
1715            .context("failed to get source port from SSH_CLIENT on host")?;
1716
1717        Ok(output.trim().to_string())
1718    }
1719
1720    async fn create_lock_file(&self, lock_file: &Path, content: &str) -> Result<bool> {
1721        let parent_dir = lock_file
1722            .parent()
1723            .ok_or_else(|| anyhow!("Lock file path has no parent directory"))?;
1724
1725        let script = format!(
1726            r#"mkdir -p "{parent_dir}" && [ ! -f "{lock_file}" ] && echo "{content}" > "{lock_file}" && echo "created" || echo "exists""#,
1727            parent_dir = parent_dir.display(),
1728            lock_file = lock_file.display(),
1729            content = content,
1730        );
1731
1732        let output = run_cmd(self.socket.shell_script(&script))
1733            .await
1734            .with_context(|| format!("failed to create a lock file at {:?}", lock_file))?;
1735
1736        Ok(output.trim() == "created")
1737    }
1738
    /// Builds the shell script that decides whether `lock_file` on the host is stale.
    ///
    /// The script expects the lock file to contain `"<port> <unix timestamp>"`, the
    /// format written in `ensure_server_binary`. It prints `"recent"` only when the lock
    /// still looks held. Otherwise it prints a short reason:
    /// - the lock file does not exist;
    /// - `ss` (or, if `ss` is missing, `netstat`) lists no socket on that port;
    /// - the timestamp is more than `max_age` seconds old.
    ///
    /// If neither `ss` nor `netstat` is installed, only the timestamp is checked.
    fn generate_stale_check_script(lock_file: &Path, max_age: u64) -> String {
        shell_script!(
            r#"
            if [ ! -f "{lock_file}" ]; then
                echo "lock file does not exist"
                exit 0
            fi

            read -r port timestamp < "{lock_file}"

            # Check if port is still active
            if command -v ss >/dev/null 2>&1; then
                if ! ss -n | grep -q ":$port[[:space:]]"; then
                    echo "ss reports port $port is not open"
                    exit 0
                fi
            elif command -v netstat >/dev/null 2>&1; then
                if ! netstat -n | grep -q ":$port[[:space:]]"; then
                    echo "netstat reports port $port is not open"
                    exit 0
                fi
            fi

            # Check timestamp
            if [ $(( $(date +%s) - timestamp )) -gt {max_age} ]; then
                echo "timestamp in lockfile is too old"
            else
                echo "recent"
            fi"#,
            lock_file = &lock_file.to_string_lossy(),
            max_age = &max_age.to_string()
        )
    }
1772
1773    async fn is_lock_stale(&self, lock_file: &Path, max_age: &Duration) -> Result<bool> {
1774        let script = Self::generate_stale_check_script(lock_file, max_age.as_secs());
1775
1776        let output = run_cmd(self.socket.shell_script(script))
1777            .await
1778            .with_context(|| {
1779                format!("failed to check whether lock file {:?} is stale", lock_file)
1780            })?;
1781
1782        let trimmed = output.trim();
1783        let is_stale = trimmed != "recent";
1784        log::info!("checked lockfile for staleness. stale: {is_stale}, output: {trimmed:?}");
1785        Ok(is_stale)
1786    }
1787
1788    async fn remove_lock_file(&self, lock_file: &Path) -> Result<()> {
1789        run_cmd(
1790            self.socket
1791                .ssh_command("rm", &["-f", &lock_file.to_string_lossy()]),
1792        )
1793        .await
1794        .context("failed to remove lock file")?;
1795        Ok(())
1796    }
1797
    /// Makes sure the server binary at `dst_path` on the host matches the version this
    /// client wants, installing it if not.
    ///
    /// The installed version is read by running `<dst_path> version` on the host. The
    /// wanted version depends on the release channel: the app commit SHA on Nightly
    /// (when one is available), none on Dev, and the app's semantic version otherwise.
    ///
    /// When an install is needed:
    /// - With no wanted version (Dev), an existing binary is reused unless
    ///   `ZED_BUILD_REMOTE_SERVER` is set. If it is set, the server is built locally
    ///   (debug builds only) and uploaded.
    /// - Otherwise, unless `upload_binary_over_ssh` is set, the host tries to download
    ///   the binary itself. If that fails, or if uploading is forced, the binary is
    ///   downloaded locally and uploaded over SSH.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the installed semantic version is newer than the wanted one;
    /// - the binary is in use (release builds only);
    /// - the Dev-channel prerequisites above are not met;
    /// - downloading, building, uploading, or extracting fails.
    async fn update_server_binary_if_needed(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        dst_path: &Path,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<()> {
        // `None` when running the binary fails (presumably because it is not installed
        // yet). Output that does not parse as a semantic version is treated as a
        // commit SHA.
        let current_version = match run_cmd(
            self.socket
                .ssh_command(&dst_path.to_string_lossy(), &["version"]),
        )
        .await
        {
            Ok(version_output) => {
                if let Ok(version) = version_output.trim().parse::<SemanticVersion>() {
                    Some(ServerVersion::Semantic(version))
                } else {
                    Some(ServerVersion::Commit(version_output.trim().to_string()))
                }
            }
            Err(_) => None,
        };
        let (release_channel, wanted_version) = cx.update(|cx| {
            let release_channel = ReleaseChannel::global(cx);
            let wanted_version = match release_channel {
                ReleaseChannel::Nightly => {
                    AppCommitSha::try_global(cx).map(|sha| ServerVersion::Commit(sha.0))
                }
                ReleaseChannel::Dev => None,
                _ => Some(ServerVersion::Semantic(AppVersion::global(cx))),
            };
            (release_channel, wanted_version)
        })?;

        // Already up to date: done. A newer installed semantic version is an error
        // rather than a downgrade. Anything else falls through to an install.
        match (&current_version, &wanted_version) {
            (Some(current), Some(wanted)) if current == wanted => {
                log::info!("remote development server present and matching client version");
                return Ok(());
            }
            (Some(ServerVersion::Semantic(current)), Some(ServerVersion::Semantic(wanted)))
                if current > wanted =>
            {
                anyhow::bail!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", current, wanted);
            }
            _ => {
                log::info!("Installing remote development server");
            }
        }

        if self.is_binary_in_use(dst_path).await? {
            // When we're not in dev mode, we don't want to switch out the binary if it's
            // still open.
            // In dev mode, that's fine, since we often kill Zed processes with Ctrl-C and want
            // to still replace the binary.
            if cfg!(not(debug_assertions)) {
                anyhow::bail!("The remote server version ({:?}) does not match the wanted version ({:?}), but is in use by another Zed client so cannot be upgraded.", &current_version, &wanted_version)
            } else {
                log::info!("Binary is currently in use, ignoring because this is a dev build")
            }
        }

        // No wanted version: this client is a Dev build.
        if wanted_version.is_none() {
            if std::env::var("ZED_BUILD_REMOTE_SERVER").is_err() {
                if let Some(current_version) = current_version {
                    log::warn!(
                        "In development, using cached remote server binary version ({})",
                        current_version
                    );

                    return Ok(());
                } else {
                    anyhow::bail!(
                        "ZED_BUILD_REMOTE_SERVER is not set, but no remote server exists at ({:?})",
                        dst_path
                    )
                }
            }

            // Debug builds compile the server from source and upload it. Release builds
            // cannot do this and bail instead.
            #[cfg(debug_assertions)]
            {
                let src_path = self.build_local(platform, delegate, cx).await?;

                return self
                    .upload_local_server_binary(&src_path, dst_path, delegate, cx)
                    .await;
            }

            #[cfg(not(debug_assertions))]
            anyhow::bail!("Running development build in release mode, cannot cross compile (unset ZED_BUILD_REMOTE_SERVER)")
        }

        let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;

        // Prefer having the host download the binary directly. Failures are only
        // logged, and we fall back to the local download + upload path below.
        if !upload_binary_over_ssh {
            let (url, body) = delegate
                .get_download_params(
                    platform,
                    release_channel,
                    wanted_version.clone().and_then(|v| v.semantic_version()),
                    cx,
                )
                .await?;

            match self
                .download_binary_on_server(&url, &body, dst_path, delegate, cx)
                .await
            {
                Ok(_) => return Ok(()),
                Err(e) => {
                    log::error!(
                        "Failed to download binary on server, attempting to upload server: {}",
                        e
                    )
                }
            }
        }

        let src_path = delegate
            .download_server_binary_locally(
                platform,
                release_channel,
                wanted_version.and_then(|v| v.semantic_version()),
                cx,
            )
            .await?;

        self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
            .await
    }
1927
1928    async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> {
1929        let script = shell_script!(
1930            r#"
1931            if command -v lsof >/dev/null 2>&1; then
1932                if lsof "{binary_path}" >/dev/null 2>&1; then
1933                    echo "in_use"
1934                    exit 0
1935                fi
1936            elif command -v fuser >/dev/null 2>&1; then
1937                if fuser "{binary_path}" >/dev/null 2>&1; then
1938                    echo "in_use"
1939                    exit 0
1940                fi
1941            fi
1942            echo "not_in_use"
1943            "#,
1944            binary_path = &binary_path.to_string_lossy(),
1945        );
1946
1947        let output = run_cmd(self.socket.shell_script(script))
1948            .await
1949            .context("failed to check if binary is in use")?;
1950
1951        Ok(output.trim() == "in_use")
1952    }
1953
1954    async fn download_binary_on_server(
1955        &self,
1956        url: &str,
1957        body: &str,
1958        dst_path: &Path,
1959        delegate: &Arc<dyn SshClientDelegate>,
1960        cx: &mut AsyncAppContext,
1961    ) -> Result<()> {
1962        let mut dst_path_gz = dst_path.to_path_buf();
1963        dst_path_gz.set_extension("gz");
1964
1965        if let Some(parent) = dst_path.parent() {
1966            run_cmd(
1967                self.socket
1968                    .ssh_command("mkdir", &["-p", &parent.to_string_lossy()]),
1969            )
1970            .await?;
1971        }
1972
1973        delegate.set_status(Some("Downloading remote development server on host"), cx);
1974
1975        let script = shell_script!(
1976            r#"
1977            if command -v curl >/dev/null 2>&1; then
1978                curl -f -L -X GET -H "Content-Type: application/json" -d {body} {url} -o {dst_path} && echo "curl"
1979            elif command -v wget >/dev/null 2>&1; then
1980                wget --max-redirect=5 --method=GET --header="Content-Type: application/json" --body-data={body} {url} -O {dst_path} && echo "wget"
1981            else
1982                echo "Neither curl nor wget is available" >&2
1983                exit 1
1984            fi
1985            "#,
1986            body = body,
1987            url = url,
1988            dst_path = &dst_path_gz.to_string_lossy(),
1989        );
1990
1991        let output = run_cmd(self.socket.shell_script(script))
1992            .await
1993            .context("Failed to download server binary")?;
1994
1995        if !output.contains("curl") && !output.contains("wget") {
1996            return Err(anyhow!("Failed to download server binary: {}", output));
1997        }
1998
1999        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
2000            .await
2001    }
2002
2003    async fn upload_local_server_binary(
2004        &self,
2005        src_path: &Path,
2006        dst_path: &Path,
2007        delegate: &Arc<dyn SshClientDelegate>,
2008        cx: &mut AsyncAppContext,
2009    ) -> Result<()> {
2010        let mut dst_path_gz = dst_path.to_path_buf();
2011        dst_path_gz.set_extension("gz");
2012
2013        if let Some(parent) = dst_path.parent() {
2014            run_cmd(
2015                self.socket
2016                    .ssh_command("mkdir", &["-p", &parent.to_string_lossy()]),
2017            )
2018            .await?;
2019        }
2020
2021        let src_stat = fs::metadata(&src_path).await?;
2022        let size = src_stat.len();
2023
2024        let t0 = Instant::now();
2025        delegate.set_status(Some("Uploading remote development server"), cx);
2026        log::info!("uploading remote development server ({}kb)", size / 1024);
2027        self.upload_file(&src_path, &dst_path_gz)
2028            .await
2029            .context("failed to upload server binary")?;
2030        log::info!("uploaded remote development server in {:?}", t0.elapsed());
2031
2032        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
2033            .await
2034    }
2035
2036    async fn extract_server_binary(
2037        &self,
2038        dst_path: &Path,
2039        dst_path_gz: &Path,
2040        delegate: &Arc<dyn SshClientDelegate>,
2041        cx: &mut AsyncAppContext,
2042    ) -> Result<()> {
2043        delegate.set_status(Some("Extracting remote development server"), cx);
2044        run_cmd(
2045            self.socket
2046                .ssh_command("gunzip", &["-f", &dst_path_gz.to_string_lossy()]),
2047        )
2048        .await?;
2049
2050        let server_mode = 0o755;
2051        delegate.set_status(Some("Marking remote development server executable"), cx);
2052        run_cmd(self.socket.ssh_command(
2053            "chmod",
2054            &[&format!("{:o}", server_mode), &dst_path.to_string_lossy()],
2055        ))
2056        .await?;
2057
2058        Ok(())
2059    }
2060
2061    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
2062        let mut command = process::Command::new("scp");
2063        let output = self
2064            .socket
2065            .ssh_options(&mut command)
2066            .args(
2067                self.socket
2068                    .connection_options
2069                    .port
2070                    .map(|port| vec!["-P".to_string(), port.to_string()])
2071                    .unwrap_or_default(),
2072            )
2073            .arg(src_path)
2074            .arg(format!(
2075                "{}:{}",
2076                self.socket.connection_options.scp_url(),
2077                dest_path.display()
2078            ))
2079            .output()
2080            .await?;
2081
2082        if output.status.success() {
2083            Ok(())
2084        } else {
2085            Err(anyhow!(
2086                "failed to upload file {} -> {}: {}",
2087                src_path.display(),
2088                dest_path.display(),
2089                String::from_utf8_lossy(&output.stderr)
2090            ))
2091        }
2092    }
2093
2094    #[cfg(debug_assertions)]
2095    async fn build_local(
2096        &self,
2097        platform: SshPlatform,
2098        delegate: &Arc<dyn SshClientDelegate>,
2099        cx: &mut AsyncAppContext,
2100    ) -> Result<PathBuf> {
2101        use smol::process::{Command, Stdio};
2102
2103        async fn run_cmd(command: &mut Command) -> Result<()> {
2104            let output = command
2105                .kill_on_drop(true)
2106                .stderr(Stdio::inherit())
2107                .output()
2108                .await?;
2109            if !output.status.success() {
2110                Err(anyhow!("Failed to run command: {:?}", command))?;
2111            }
2112            Ok(())
2113        }
2114
2115        if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
2116            delegate.set_status(Some("Building remote server binary from source"), cx);
2117            log::info!("building remote server binary from source");
2118            run_cmd(Command::new("cargo").args([
2119                "build",
2120                "--package",
2121                "remote_server",
2122                "--features",
2123                "debug-embed",
2124                "--target-dir",
2125                "target/remote_server",
2126            ]))
2127            .await?;
2128
2129            delegate.set_status(Some("Compressing binary"), cx);
2130
2131            run_cmd(Command::new("gzip").args([
2132                "-9",
2133                "-f",
2134                "target/remote_server/debug/remote_server",
2135            ]))
2136            .await?;
2137
2138            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
2139            return Ok(path);
2140        }
2141        let Some(triple) = platform.triple() else {
2142            anyhow::bail!("can't cross compile for: {:?}", platform);
2143        };
2144        smol::fs::create_dir_all("target/remote_server").await?;
2145
2146        delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
2147        log::info!("installing cross");
2148        run_cmd(Command::new("cargo").args([
2149            "install",
2150            "cross",
2151            "--git",
2152            "https://github.com/cross-rs/cross",
2153        ]))
2154        .await?;
2155
2156        delegate.set_status(
2157            Some(&format!(
2158                "Building remote server binary from source for {} with Docker",
2159                &triple
2160            )),
2161            cx,
2162        );
2163        log::info!("building remote server binary from source for {}", &triple);
2164        run_cmd(
2165            Command::new("cross")
2166                .args([
2167                    "build",
2168                    "--package",
2169                    "remote_server",
2170                    "--features",
2171                    "debug-embed",
2172                    "--target-dir",
2173                    "target/remote_server",
2174                    "--target",
2175                    &triple,
2176                ])
2177                .env(
2178                    "CROSS_CONTAINER_OPTS",
2179                    "--mount type=bind,src=./target,dst=/app/target",
2180                ),
2181        )
2182        .await?;
2183
2184        delegate.set_status(Some("Compressing binary"), cx);
2185
2186        run_cmd(Command::new("gzip").args([
2187            "-9",
2188            "-f",
2189            &format!("target/remote_server/{}/debug/remote_server", triple),
2190        ]))
2191        .await?;
2192
2193        let path = std::env::current_dir()?.join(format!(
2194            "target/remote_server/{}/debug/remote_server.gz",
2195            triple
2196        ));
2197
2198        return Ok(path);
2199    }
2200}
2201
/// Senders for requests that are still waiting for a response, keyed by the id of the
/// outgoing request.
///
/// Each sender delivers the response envelope together with a second oneshot sender.
/// The message loop waits on that second channel, until it is signalled or dropped,
/// before handling further incoming messages.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2203
/// Client end of an envelope channel, used here for the SSH remote connection.
///
/// It sends `proto::Envelope`s, routes responses back to pending requests, and
/// dispatches other incoming messages to registered handlers. Buffered outgoing
/// envelopes are kept until the peer acknowledges them so they can be re-sent after
/// a reconnect.
pub struct ChannelClient {
    // Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    // Where outgoing envelopes are written; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Sent envelopes the peer has not acknowledged yet, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    // Senders for in-flight requests, keyed by request id.
    response_channels: ResponseChannels,
    // Registered message handlers and entity subscriptions.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Id of the most recently processed incoming message; sent as `ack_id`.
    max_received: AtomicU32,
    // Prefix used in log messages.
    name: &'static str,
    // Task that processes incoming envelopes; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
2214
2215impl ChannelClient {
2216    pub fn new(
2217        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2218        outgoing_tx: mpsc::UnboundedSender<Envelope>,
2219        cx: &AppContext,
2220        name: &'static str,
2221    ) -> Arc<Self> {
2222        Arc::new_cyclic(|this| Self {
2223            outgoing_tx: Mutex::new(outgoing_tx),
2224            next_message_id: AtomicU32::new(0),
2225            max_received: AtomicU32::new(0),
2226            response_channels: ResponseChannels::default(),
2227            message_handlers: Default::default(),
2228            buffer: Mutex::new(VecDeque::new()),
2229            name,
2230            task: Mutex::new(Self::start_handling_messages(
2231                this.clone(),
2232                incoming_rx,
2233                &cx.to_async(),
2234            )),
2235        })
2236    }
2237
2238    fn start_handling_messages(
2239        this: Weak<Self>,
2240        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2241        cx: &AsyncAppContext,
2242    ) -> Task<Result<()>> {
2243        cx.spawn(|cx| async move {
2244            let peer_id = PeerId { owner_id: 0, id: 0 };
2245            while let Some(incoming) = incoming_rx.next().await {
2246                let Some(this) = this.upgrade() else {
2247                    return anyhow::Ok(());
2248                };
2249                if let Some(ack_id) = incoming.ack_id {
2250                    let mut buffer = this.buffer.lock();
2251                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2252                        buffer.pop_front();
2253                    }
2254                }
2255                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2256                {
2257                    log::debug!(
2258                        "{}:ssh message received. name:FlushBufferedMessages",
2259                        this.name
2260                    );
2261                    {
2262                        let buffer = this.buffer.lock();
2263                        for envelope in buffer.iter() {
2264                            this.outgoing_tx
2265                                .lock()
2266                                .unbounded_send(envelope.clone())
2267                                .ok();
2268                        }
2269                    }
2270                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2271                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2272                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
2273                    continue;
2274                }
2275
2276                this.max_received.store(incoming.id, SeqCst);
2277
2278                if let Some(request_id) = incoming.responding_to {
2279                    let request_id = MessageId(request_id);
2280                    let sender = this.response_channels.lock().remove(&request_id);
2281                    if let Some(sender) = sender {
2282                        let (tx, rx) = oneshot::channel();
2283                        if incoming.payload.is_some() {
2284                            sender.send((incoming, tx)).ok();
2285                        }
2286                        rx.await.ok();
2287                    }
2288                } else if let Some(envelope) =
2289                    build_typed_envelope(peer_id, Instant::now(), incoming)
2290                {
2291                    let type_name = envelope.payload_type_name();
2292                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
2293                        &this.message_handlers,
2294                        envelope,
2295                        this.clone().into(),
2296                        cx.clone(),
2297                    ) {
2298                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
2299                        cx.foreground_executor()
2300                            .spawn(async move {
2301                                match future.await {
2302                                    Ok(_) => {
2303                                        log::debug!(
2304                                            "{}:ssh message handled. name:{type_name}",
2305                                            this.name
2306                                        );
2307                                    }
2308                                    Err(error) => {
2309                                        log::error!(
2310                                            "{}:error handling message. type:{}, error:{}",
2311                                            this.name,
2312                                            type_name,
2313                                            format!("{error:#}").lines().fold(
2314                                                String::new(),
2315                                                |mut message, line| {
2316                                                    if !message.is_empty() {
2317                                                        message.push(' ');
2318                                                    }
2319                                                    message.push_str(line);
2320                                                    message
2321                                                }
2322                                            )
2323                                        );
2324                                    }
2325                                }
2326                            })
2327                            .detach()
2328                    } else {
2329                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2330                    }
2331                }
2332            }
2333            anyhow::Ok(())
2334        })
2335    }
2336
2337    pub fn reconnect(
2338        self: &Arc<Self>,
2339        incoming_rx: UnboundedReceiver<Envelope>,
2340        outgoing_tx: UnboundedSender<Envelope>,
2341        cx: &AsyncAppContext,
2342    ) {
2343        *self.outgoing_tx.lock() = outgoing_tx;
2344        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2345    }
2346
2347    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
2348        let id = (TypeId::of::<E>(), remote_id);
2349
2350        let mut message_handlers = self.message_handlers.lock();
2351        if message_handlers
2352            .entities_by_type_and_remote_id
2353            .contains_key(&id)
2354        {
2355            panic!("already subscribed to entity");
2356        }
2357
2358        message_handlers.entities_by_type_and_remote_id.insert(
2359            id,
2360            EntityMessageSubscriber::Entity {
2361                handle: entity.downgrade().into(),
2362            },
2363        );
2364    }
2365
2366    pub fn request<T: RequestMessage>(
2367        &self,
2368        payload: T,
2369    ) -> impl 'static + Future<Output = Result<T::Response>> {
2370        self.request_internal(payload, true)
2371    }
2372
2373    fn request_internal<T: RequestMessage>(
2374        &self,
2375        payload: T,
2376        use_buffer: bool,
2377    ) -> impl 'static + Future<Output = Result<T::Response>> {
2378        log::debug!("ssh request start. name:{}", T::NAME);
2379        let response =
2380            self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2381        async move {
2382            let response = response.await?;
2383            log::debug!("ssh request finish. name:{}", T::NAME);
2384            T::Response::from_envelope(response)
2385                .ok_or_else(|| anyhow!("received a response of the wrong type"))
2386        }
2387    }
2388
2389    pub async fn resync(&self, timeout: Duration) -> Result<()> {
2390        smol::future::or(
2391            async {
2392                self.request_internal(proto::FlushBufferedMessages {}, false)
2393                    .await?;
2394
2395                for envelope in self.buffer.lock().iter() {
2396                    self.outgoing_tx
2397                        .lock()
2398                        .unbounded_send(envelope.clone())
2399                        .ok();
2400                }
2401                Ok(())
2402            },
2403            async {
2404                smol::Timer::after(timeout).await;
2405                Err(anyhow!("Timeout detected"))
2406            },
2407        )
2408        .await
2409    }
2410
2411    pub async fn ping(&self, timeout: Duration) -> Result<()> {
2412        smol::future::or(
2413            async {
2414                self.request(proto::Ping {}).await?;
2415                Ok(())
2416            },
2417            async {
2418                smol::Timer::after(timeout).await;
2419                Err(anyhow!("Timeout detected"))
2420            },
2421        )
2422        .await
2423    }
2424
2425    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2426        log::debug!("ssh send name:{}", T::NAME);
2427        self.send_dynamic(payload.into_envelope(0, None, None))
2428    }
2429
2430    fn request_dynamic(
2431        &self,
2432        mut envelope: proto::Envelope,
2433        type_name: &'static str,
2434        use_buffer: bool,
2435    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2436        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2437        let (tx, rx) = oneshot::channel();
2438        let mut response_channels_lock = self.response_channels.lock();
2439        response_channels_lock.insert(MessageId(envelope.id), tx);
2440        drop(response_channels_lock);
2441
2442        let result = if use_buffer {
2443            self.send_buffered(envelope)
2444        } else {
2445            self.send_unbuffered(envelope)
2446        };
2447        async move {
2448            if let Err(error) = &result {
2449                log::error!("failed to send message: {}", error);
2450                return Err(anyhow!("failed to send message: {}", error));
2451            }
2452
2453            let response = rx.await.context("connection lost")?.0;
2454            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2455                return Err(RpcError::from_proto(error, type_name));
2456            }
2457            Ok(response)
2458        }
2459    }
2460
2461    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2462        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2463        self.send_buffered(envelope)
2464    }
2465
2466    fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2467        envelope.ack_id = Some(self.max_received.load(SeqCst));
2468        self.buffer.lock().push_back(envelope.clone());
2469        // ignore errors on send (happen while we're reconnecting)
2470        // assume that the global "disconnected" overlay is sufficient.
2471        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2472        Ok(())
2473    }
2474
2475    fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2476        envelope.ack_id = Some(self.max_received.load(SeqCst));
2477        self.outgoing_tx.lock().unbounded_send(envelope).ok();
2478        Ok(())
2479    }
2480}
2481
2482impl ProtoClient for ChannelClient {
2483    fn request(
2484        &self,
2485        envelope: proto::Envelope,
2486        request_type: &'static str,
2487    ) -> BoxFuture<'static, Result<proto::Envelope>> {
2488        self.request_dynamic(envelope, request_type, true).boxed()
2489    }
2490
2491    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2492        self.send_dynamic(envelope)
2493    }
2494
2495    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2496        self.send_dynamic(envelope)
2497    }
2498
2499    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2500        &self.message_handlers
2501    }
2502
2503    fn is_via_collab(&self) -> bool {
2504        false
2505    }
2506}
2507
/// Test doubles for the SSH client.
///
/// Provides an in-memory `RemoteConnection` that relays envelopes to a server-side
/// `ChannelClient` on a background task, plus a stub `SshClientDelegate`.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, RemoteConnection, SshClientDelegate, SshConnectionOptions, SshPlatform,
    };

    /// In-memory `RemoteConnection`.
    ///
    /// Instead of spawning a proxy, it forwards envelopes between the client's channels and
    /// `server_channel` on a background task (see `start_proxy`).
    pub(super) struct FakeRemoteConnection {
        /// Returned (cloned) from `connection_options()`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side channel. `start_proxy` and `simulate_disconnect` re-point it at new
        /// streams via `ChannelClient::reconnect`.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Context passed to `ChannelClient::reconnect` for the server side.
        pub(super) server_cx: SendableCx,
    }

    /// Wraps an `AsyncAppContext` and is manually marked `Send` + `Sync` below.
    ///
    /// Presumably this exists so `FakeRemoteConnection` can satisfy thread-safety bounds
    /// imposed elsewhere (confirm). The soundness argument is in the SAFETY comments.
    pub(super) struct SendableCx(AsyncAppContext);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncAppContext
        // The argument is only a witness: it is ignored, and a clone of the stored context is
        // returned.
        fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to tear down; always returns `Ok(())`.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// Always `false`, even after `kill` has been called.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Returns an empty argument list.
        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        /// Returns a clone of the options this fake was built with.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Simulates a dropped connection on the server side.
        ///
        /// The server channel is reconnected to a fresh pair of channels whose other halves are
        /// dropped at once. As a result, the server's incoming stream has no senders left, and
        /// its outgoing receiver is already gone.
        fn simulate_disconnect(&self, cx: &AsyncAppContext) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
        }

        /// Returns an empty path. `start_proxy` ignores the binary path anyway.
        async fn get_remote_binary_path(
            &self,
            _delegate: &Arc<dyn SshClientDelegate>,
            _reconnect: bool,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            Ok(PathBuf::new())
        }

        /// Connects the client's channels to `server_channel` through an in-memory relay.
        ///
        /// Fresh server channels are created and installed with `ChannelClient::reconnect`.
        /// A background task then forwards envelopes in both directions:
        /// - Server-to-client traffic is polled first (`select_biased!`).
        /// - Each server-to-client message also pings `connection_activity_tx`; `try_send`
        ///   failures are ignored.
        /// - Forwarding failures in either direction are ignored.
        ///
        /// The task resolves to `Ok(1)` as soon as either source stream ends. This presumably
        /// stands in for a nonzero proxy exit status (confirm against the real connection).
        fn start_proxy(
            &self,
            _remote_binary_path: PathBuf,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// `SshClientDelegate` stub for fake connections.
    ///
    /// Every hook except `set_status` (a no-op) panics with `unreachable!()`. Presumably this is
    /// because the fake never prompts for a password or fetches a server binary.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncAppContext,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncAppContext,
        ) -> Task<Result<(String, String)>> {
            unreachable!()
        }

        /// Status updates are discarded.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}

        fn remote_server_binary_path(
            &self,
            _platform: SshPlatform,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
    }
}
2670
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Runs the stale-check script under `bash` and returns its trimmed stdout.
    ///
    /// The script comes from `SshRemoteConnection::generate_stale_check_script`, and `ss` and
    /// `netstat` are replaced by shell functions.
    ///
    /// When `simulate_port_open` is `Some(port)`, the mocked `ss -n` prints one `ESTAB` line
    /// mentioning that port; with `None` it prints nothing. Anything the script writes to stderr
    /// is echoed to the test output to aid debugging.
    fn run_stale_check_script(
        lock_file: &Path,
        max_age: Duration,
        simulate_port_open: Option<&str>,
    ) -> Result<String> {
        let simulated_port = simulate_port_open.unwrap_or("");
        let script =
            SshRemoteConnection::generate_stale_check_script(lock_file, max_age.as_secs());
        let wrapper = format!(
            r#"
            # Mock ss/netstat commands
            ss() {{
                # Only handle the -n argument
                if [ "$1" = "-n" ]; then
                    # If we're simulating an open port, output a line containing that port
                    if [ "{simulated_port}" != "" ]; then
                        echo "ESTAB 0 0 1.2.3.4:{simulated_port} 5.6.7.8:12345"
                    fi
                fi
            }}
            netstat() {{
                ss "$@"
            }}
            export -f ss netstat

            # Real script starts here
            {script}"#
        );

        let output = std::process::Command::new("bash")
            .args(["-c", wrapper.as_str()])
            .output()?;

        if !output.stderr.is_empty() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            eprintln!("Script stderr: {}", stderr);
        }

        let stdout = String::from_utf8(output.stdout)?;
        Ok(stdout.trim().to_owned())
    }

    /// Walks one lock file through the four outcomes of the stale-check script.
    ///
    /// The cases are: missing file, port not reported by `ss`, reported port with an expired
    /// timestamp, and reported port with a fresh timestamp.
    #[test]
    fn test_lock_staleness() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let lock_file = temp_dir.path().join("test.lock");
        let max_age = Duration::from_secs(600);
        // Unix time, in seconds, `seconds` seconds before now.
        let unix_time_ago = |seconds: u64| -> Result<u64> {
            Ok(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() - seconds)
        };

        // 1. Nothing has been written yet.
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, None)?,
            "lock file does not exist"
        );

        // 2. The lock names port 54321, but `ss` only reports 98765.
        fs::write(&lock_file, "54321 1234567890")?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("98765"))?,
            "ss reports port 54321 is not open"
        );

        // 3. The port is reported open, but the timestamp is 700s old (max age is 600s).
        fs::write(&lock_file, format!("54321 {}", unix_time_ago(700)?))?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("54321"))?,
            "timestamp in lockfile is too old"
        );

        // 4. The port is reported open and the timestamp is only 60s old.
        fs::write(&lock_file, format!("54321 {}", unix_time_ago(60)?))?;
        assert_eq!(
            run_stale_check_script(&lock_file, max_age, Some("54321"))?,
            "recent"
        );

        Ok(())
    }
}