ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use async_trait::async_trait;
  10use collections::HashMap;
  11use futures::{
  12    channel::{
  13        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  14        oneshot,
  15    },
  16    future::BoxFuture,
  17    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  18    StreamExt as _,
  19};
  20use gpui::{
  21    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  22    WeakModel,
  23};
  24use parking_lot::Mutex;
  25use rpc::{
  26    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  27    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  28};
  29use smol::{
  30    fs,
  31    process::{self, Child, Stdio},
  32};
  33use std::{
  34    any::TypeId,
  35    collections::VecDeque,
  36    ffi::OsStr,
  37    fmt,
  38    ops::ControlFlow,
  39    path::{Path, PathBuf},
  40    sync::{
  41        atomic::{AtomicU32, Ordering::SeqCst},
  42        Arc,
  43    },
  44    time::{Duration, Instant},
  45};
  46use tempfile::TempDir;
  47use util::ResultExt;
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  51)]
  52pub struct SshProjectId(pub u64);
  53
/// Handle for running commands over an ssh control (multiplexing) socket.
///
/// Commands built from this socket pass `-o ControlMaster=no -o ControlPath=<socket_path>`
/// (see the `impl SshSocket` methods), so ssh attempts to reuse the master connection
/// listening on that socket rather than acting as a master itself.
#[derive(Clone)]
pub struct SshSocket {
    /// Host, user, port and extra arguments of the remote machine.
    connection_options: SshConnectionOptions,
    /// Path of the control socket, passed to ssh as `ControlPath`.
    socket_path: PathBuf,
}
  59
  60#[derive(Debug, Default, Clone, PartialEq, Eq)]
  61pub struct SshConnectionOptions {
  62    pub host: String,
  63    pub username: Option<String>,
  64    pub port: Option<u16>,
  65    pub password: Option<String>,
  66    pub args: Option<Vec<String>>,
  67}
  68
  69impl SshConnectionOptions {
  70    pub fn parse_command_line(input: &str) -> Result<Self> {
  71        let input = input.trim_start_matches("ssh ");
  72        let mut hostname: Option<String> = None;
  73        let mut username: Option<String> = None;
  74        let mut port: Option<u16> = None;
  75        let mut args = Vec::new();
  76
  77        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  78        const ALLOWED_OPTS: &[&str] = &[
  79            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  80        ];
  81        const ALLOWED_ARGS: &[&str] = &[
  82            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
  83            "-w",
  84        ];
  85
  86        let mut tokens = shlex::split(input)
  87            .ok_or_else(|| anyhow!("invalid input"))?
  88            .into_iter();
  89
  90        'outer: while let Some(arg) = tokens.next() {
  91            if ALLOWED_OPTS.contains(&(&arg as &str)) {
  92                args.push(arg.to_string());
  93                continue;
  94            }
  95            if arg == "-p" {
  96                port = tokens.next().and_then(|arg| arg.parse().ok());
  97                continue;
  98            } else if let Some(p) = arg.strip_prefix("-p") {
  99                port = p.parse().ok();
 100                continue;
 101            }
 102            if arg == "-l" {
 103                username = tokens.next();
 104                continue;
 105            } else if let Some(l) = arg.strip_prefix("-l") {
 106                username = Some(l.to_string());
 107                continue;
 108            }
 109            for a in ALLOWED_ARGS {
 110                if arg == *a {
 111                    args.push(arg);
 112                    if let Some(next) = tokens.next() {
 113                        args.push(next);
 114                    }
 115                    continue 'outer;
 116                } else if arg.starts_with(a) {
 117                    args.push(arg);
 118                    continue 'outer;
 119                }
 120            }
 121            if arg.starts_with("-") || hostname.is_some() {
 122                anyhow::bail!("unsupported argument: {:?}", arg);
 123            }
 124            let mut input = &arg as &str;
 125            if let Some((u, rest)) = input.split_once('@') {
 126                input = rest;
 127                username = Some(u.to_string());
 128            }
 129            if let Some((rest, p)) = input.split_once(':') {
 130                input = rest;
 131                port = p.parse().ok()
 132            }
 133            hostname = Some(input.to_string())
 134        }
 135
 136        let Some(hostname) = hostname else {
 137            anyhow::bail!("missing hostname");
 138        };
 139
 140        Ok(Self {
 141            host: hostname.to_string(),
 142            username: username.clone(),
 143            port,
 144            password: None,
 145            args: Some(args),
 146        })
 147    }
 148
 149    pub fn ssh_url(&self) -> String {
 150        let mut result = String::from("ssh://");
 151        if let Some(username) = &self.username {
 152            result.push_str(username);
 153            result.push('@');
 154        }
 155        result.push_str(&self.host);
 156        if let Some(port) = self.port {
 157            result.push(':');
 158            result.push_str(&port.to_string());
 159        }
 160        result
 161    }
 162
 163    pub fn additional_args(&self) -> Option<&Vec<String>> {
 164        self.args.as_ref()
 165    }
 166
 167    fn scp_url(&self) -> String {
 168        if let Some(username) = &self.username {
 169            format!("{}@{}", username, self.host)
 170        } else {
 171            self.host.clone()
 172        }
 173    }
 174
 175    pub fn connection_string(&self) -> String {
 176        let host = if let Some(username) = &self.username {
 177            format!("{}@{}", username, self.host)
 178        } else {
 179            self.host.clone()
 180        };
 181        if let Some(port) = &self.port {
 182            format!("{}:{}", host, port)
 183        } else {
 184            host
 185        }
 186    }
 187
 188    // Uniquely identifies dev server projects on a remote host. Needs to be
 189    // stable for the same dev server project.
 190    pub fn dev_server_identifier(&self) -> String {
 191        let mut identifier = format!("dev-server-{:?}", self.host);
 192        if let Some(username) = self.username.as_ref() {
 193            identifier.push('-');
 194            identifier.push_str(&username);
 195        }
 196        identifier
 197    }
 198}
 199
 200#[derive(Copy, Clone, Debug)]
 201pub struct SshPlatform {
 202    pub os: &'static str,
 203    pub arch: &'static str,
 204}
 205
 206impl SshPlatform {
 207    pub fn triple(&self) -> Option<String> {
 208        Some(format!(
 209            "{}-{}",
 210            self.arch,
 211            match self.os {
 212                "linux" => "unknown-linux-gnu",
 213                "macos" => "apple-darwin",
 214                _ => return None,
 215            }
 216        ))
 217    }
 218}
 219
/// Callbacks through which the ssh connection code interacts with the rest of the app.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password for `prompt`; the answer (or an error) arrives on the
    /// returned channel. NOTE(review): presumably prompts the user interactively — confirm.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path of the server binary for a host running `platform`.
    /// NOTE(review): assumed to be the path on the remote host — confirm.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Obtains a server binary for `platform`; resolves to a path and its version.
    /// NOTE(review): assumed to be a local path (e.g. downloaded or built) — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Updates the connection status message, or clears it with `None`.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message, e.g. when the initial ping in `SshRemoteClient::new` fails.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
 239
 240impl SshSocket {
 241    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 242        let mut command = process::Command::new("ssh");
 243        self.ssh_options(&mut command)
 244            .arg(self.connection_options.ssh_url())
 245            .arg(program);
 246        command
 247    }
 248
 249    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 250        command
 251            .stdin(Stdio::piped())
 252            .stdout(Stdio::piped())
 253            .stderr(Stdio::piped())
 254            .args(["-o", "ControlMaster=no", "-o"])
 255            .arg(format!("ControlPath={}", self.socket_path.display()))
 256    }
 257
 258    fn ssh_args(&self) -> Vec<String> {
 259        vec![
 260            "-o".to_string(),
 261            "ControlMaster=no".to_string(),
 262            "-o".to_string(),
 263            format!("ControlPath={}", self.socket_path.display()),
 264            self.connection_options.ssh_url(),
 265        ]
 266    }
 267}
 268
 269async fn run_cmd(command: &mut process::Command) -> Result<String> {
 270    let output = command.output().await?;
 271    if output.status.success() {
 272        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 273    } else {
 274        Err(anyhow!(
 275            "failed to run command: {}",
 276            String::from_utf8_lossy(&output.stderr)
 277        ))
 278    }
 279}
 280
 281pub struct ChannelForwarder {
 282    quit_tx: UnboundedSender<()>,
 283    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 284}
 285
 286impl ChannelForwarder {
 287    fn new(
 288        mut incoming_tx: UnboundedSender<Envelope>,
 289        mut outgoing_rx: UnboundedReceiver<Envelope>,
 290        cx: &AsyncAppContext,
 291    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 292        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 293
 294        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 295        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 296
 297        let forwarding_task = cx.background_executor().spawn(async move {
 298            loop {
 299                select_biased! {
 300                    _ = quit_rx.next().fuse() => {
 301                        break;
 302                    },
 303                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 304                        if let Some(envelope) = incoming_envelope {
 305                            if incoming_tx.send(envelope).await.is_err() {
 306                                break;
 307                            }
 308                        } else {
 309                            break;
 310                        }
 311                    }
 312                    outgoing_envelope = outgoing_rx.next().fuse() => {
 313                        if let Some(envelope) = outgoing_envelope {
 314                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 315                                break;
 316                            }
 317                        } else {
 318                            break;
 319                        }
 320                    }
 321                }
 322            }
 323
 324            (incoming_tx, outgoing_rx)
 325        });
 326
 327        (
 328            Self {
 329                forwarding_task,
 330                quit_tx,
 331            },
 332            proxy_incoming_tx,
 333            proxy_outgoing_rx,
 334        )
 335    }
 336
 337    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 338        let _ = self.quit_tx.send(()).await;
 339        self.forwarding_task.await
 340    }
 341}
 342
 343const MAX_MISSED_HEARTBEATS: usize = 5;
 344const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 345const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 346
 347const MAX_RECONNECT_ATTEMPTS: usize = 3;
 348
/// Internal connection state machine of `SshRemoteClient`.
///
/// Variants that still own a connection carry the ssh process, the delegate and the
/// `ChannelForwarder`. `Connected` and `HeartbeatMissed` also own the background
/// tasks. When a state value is dropped implicitly, its fields drop in declaration
/// order. `SshRemoteClient::shutdown_processes`, by contrast, drops `multiplex_task`
/// first on purpose, so keep field order in mind when editing.
enum State {
    /// Initial connection in progress.
    Connecting,
    /// Connection established.
    Connected {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Still connected, but `missed_heartbeats` consecutive heartbeats went unanswered.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is running.
    Reconnecting,
    /// The latest reconnect attempt failed with `error`; `attempts` counts attempts so far.
    ReconnectFailed {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Gave up after `MAX_RECONNECT_ATTEMPTS` reconnect attempts.
    ReconnectExhausted,
    /// NOTE(review): not constructed in the visible code; presumably set when the remote
    /// server process is found not running — confirm.
    ServerNotRunning,
}
 381
 382impl fmt::Display for State {
 383    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 384        match self {
 385            Self::Connecting => write!(f, "connecting"),
 386            Self::Connected { .. } => write!(f, "connected"),
 387            Self::Reconnecting => write!(f, "reconnecting"),
 388            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 389            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 390            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 391            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 392        }
 393    }
 394}
 395
 396impl State {
 397    fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
 398        match self {
 399            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 400            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 401            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
 402            _ => None,
 403        }
 404    }
 405
 406    fn can_reconnect(&self) -> bool {
 407        match self {
 408            Self::Connected { .. }
 409            | Self::HeartbeatMissed { .. }
 410            | Self::ReconnectFailed { .. } => true,
 411            State::Connecting
 412            | State::Reconnecting
 413            | State::ReconnectExhausted
 414            | State::ServerNotRunning => false,
 415        }
 416    }
 417
 418    fn is_reconnect_failed(&self) -> bool {
 419        matches!(self, Self::ReconnectFailed { .. })
 420    }
 421
 422    fn is_reconnect_exhausted(&self) -> bool {
 423        matches!(self, Self::ReconnectExhausted { .. })
 424    }
 425
 426    fn is_reconnecting(&self) -> bool {
 427        matches!(self, Self::Reconnecting { .. })
 428    }
 429
 430    fn heartbeat_recovered(self) -> Self {
 431        match self {
 432            Self::HeartbeatMissed {
 433                ssh_connection,
 434                delegate,
 435                forwarder,
 436                multiplex_task,
 437                heartbeat_task,
 438                ..
 439            } => Self::Connected {
 440                ssh_connection,
 441                delegate,
 442                forwarder,
 443                multiplex_task,
 444                heartbeat_task,
 445            },
 446            _ => self,
 447        }
 448    }
 449
 450    fn heartbeat_missed(self) -> Self {
 451        match self {
 452            Self::Connected {
 453                ssh_connection,
 454                delegate,
 455                forwarder,
 456                multiplex_task,
 457                heartbeat_task,
 458            } => Self::HeartbeatMissed {
 459                missed_heartbeats: 1,
 460                ssh_connection,
 461                delegate,
 462                forwarder,
 463                multiplex_task,
 464                heartbeat_task,
 465            },
 466            Self::HeartbeatMissed {
 467                missed_heartbeats,
 468                ssh_connection,
 469                delegate,
 470                forwarder,
 471                multiplex_task,
 472                heartbeat_task,
 473            } => Self::HeartbeatMissed {
 474                missed_heartbeats: missed_heartbeats + 1,
 475                ssh_connection,
 476                delegate,
 477                forwarder,
 478                multiplex_task,
 479                heartbeat_task,
 480            },
 481            _ => self,
 482        }
 483    }
 484}
 485
 486/// The state of the ssh connection.
 487#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 488pub enum ConnectionState {
 489    Connecting,
 490    Connected,
 491    HeartbeatMissed,
 492    Reconnecting,
 493    Disconnected,
 494}
 495
 496impl From<&State> for ConnectionState {
 497    fn from(value: &State) -> Self {
 498        match value {
 499            State::Connecting => Self::Connecting,
 500            State::Connected { .. } => Self::Connected,
 501            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 502            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 503            State::ReconnectExhausted => Self::Disconnected,
 504            State::ServerNotRunning => Self::Disconnected,
 505        }
 506    }
 507}
 508
/// Client for a remote server reached over ssh.
///
/// RPC traffic goes through `client`. The connection itself (ssh process, background
/// tasks) lives in `state`, which `shutdown_processes` takes, leaving `None`.
pub struct SshRemoteClient {
    /// RPC channel client. The same instance is kept across reconnects; its channels are
    /// carried over through the `ChannelForwarder`.
    client: Arc<ChannelClient>,
    /// Identifier passed to `establish_connection` on connect and on every reconnect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Current connection state; shared behind a mutex.
    state: Arc<Mutex<Option<State>>>,
}
 515
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted when reconnect attempts are exhausted and the connection is given up.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 522
 523impl SshRemoteClient {
 524    pub fn new(
 525        unique_identifier: String,
 526        connection_options: SshConnectionOptions,
 527        delegate: Arc<dyn SshClientDelegate>,
 528        cx: &AppContext,
 529    ) -> Task<Result<Model<Self>>> {
 530        cx.spawn(|mut cx| async move {
 531            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 532            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 533            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 534
 535            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
 536            let this = cx.new_model(|_| Self {
 537                client: client.clone(),
 538                unique_identifier: unique_identifier.clone(),
 539                connection_options: connection_options.clone(),
 540                state: Arc::new(Mutex::new(Some(State::Connecting))),
 541            })?;
 542
 543            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 544                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 545
 546            let (ssh_connection, io_task) = Self::establish_connection(
 547                unique_identifier,
 548                false,
 549                connection_options,
 550                proxy_incoming_tx,
 551                proxy_outgoing_rx,
 552                connection_activity_tx,
 553                delegate.clone(),
 554                &mut cx,
 555            )
 556            .await?;
 557
 558            let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
 559
 560            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 561                log::error!("failed to establish connection: {}", error);
 562                delegate.set_error(error.to_string(), &mut cx);
 563                return Err(error);
 564            }
 565
 566            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 567
 568            this.update(&mut cx, |this, _| {
 569                *this.state.lock() = Some(State::Connected {
 570                    ssh_connection,
 571                    delegate,
 572                    forwarder: proxy,
 573                    multiplex_task,
 574                    heartbeat_task,
 575                });
 576            })?;
 577
 578            Ok(this)
 579        })
 580    }
 581
 582    pub fn shutdown_processes<T: RequestMessage>(
 583        &self,
 584        shutdown_request: Option<T>,
 585    ) -> Option<impl Future<Output = ()>> {
 586        let state = self.state.lock().take()?;
 587        log::info!("shutting down ssh processes");
 588
 589        let State::Connected {
 590            multiplex_task,
 591            heartbeat_task,
 592            ssh_connection,
 593            delegate,
 594            forwarder,
 595        } = state
 596        else {
 597            return None;
 598        };
 599
 600        let client = self.client.clone();
 601
 602        Some(async move {
 603            if let Some(shutdown_request) = shutdown_request {
 604                client.send(shutdown_request).log_err();
 605                // We wait 50ms instead of waiting for a response, because
 606                // waiting for a response would require us to wait on the main thread
 607                // which we want to avoid in an `on_app_quit` callback.
 608                smol::Timer::after(Duration::from_millis(50)).await;
 609            }
 610
 611            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 612            // child of master_process.
 613            drop(multiplex_task);
 614            // Now drop the rest of state, which kills master process.
 615            drop(heartbeat_task);
 616            drop(ssh_connection);
 617            drop(delegate);
 618            drop(forwarder);
 619        })
 620    }
 621
 622    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 623        let mut lock = self.state.lock();
 624
 625        let can_reconnect = lock
 626            .as_ref()
 627            .map(|state| state.can_reconnect())
 628            .unwrap_or(false);
 629        if !can_reconnect {
 630            let error = if let Some(state) = lock.as_ref() {
 631                format!("invalid state, cannot reconnect while in state {state}")
 632            } else {
 633                "no state set".to_string()
 634            };
 635            log::info!("aborting reconnect, because not in state that allows reconnecting");
 636            return Err(anyhow!(error));
 637        }
 638
 639        let state = lock.take().unwrap();
 640        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 641            State::Connected {
 642                ssh_connection,
 643                delegate,
 644                forwarder,
 645                multiplex_task,
 646                heartbeat_task,
 647            }
 648            | State::HeartbeatMissed {
 649                ssh_connection,
 650                delegate,
 651                forwarder,
 652                multiplex_task,
 653                heartbeat_task,
 654                ..
 655            } => {
 656                drop(multiplex_task);
 657                drop(heartbeat_task);
 658                (0, ssh_connection, delegate, forwarder)
 659            }
 660            State::ReconnectFailed {
 661                attempts,
 662                ssh_connection,
 663                delegate,
 664                forwarder,
 665                ..
 666            } => (attempts, ssh_connection, delegate, forwarder),
 667            State::Connecting
 668            | State::Reconnecting
 669            | State::ReconnectExhausted
 670            | State::ServerNotRunning => unreachable!(),
 671        };
 672
 673        let attempts = attempts + 1;
 674        if attempts > MAX_RECONNECT_ATTEMPTS {
 675            log::error!(
 676                "Failed to reconnect to after {} attempts, giving up",
 677                MAX_RECONNECT_ATTEMPTS
 678            );
 679            drop(lock);
 680            self.set_state(State::ReconnectExhausted, cx);
 681            return Ok(());
 682        }
 683        drop(lock);
 684
 685        self.set_state(State::Reconnecting, cx);
 686
 687        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 688
 689        let identifier = self.unique_identifier.clone();
 690        let client = self.client.clone();
 691        let reconnect_task = cx.spawn(|this, mut cx| async move {
 692            macro_rules! failed {
 693                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 694                    return State::ReconnectFailed {
 695                        error: anyhow!($error),
 696                        attempts: $attempts,
 697                        ssh_connection: $ssh_connection,
 698                        delegate: $delegate,
 699                        forwarder: $forwarder,
 700                    };
 701                };
 702            }
 703
 704            if let Err(error) = ssh_connection.kill().await.context("Failed to kill ssh process") {
 705                failed!(error, attempts, ssh_connection, delegate, forwarder);
 706            };
 707
 708            let connection_options = ssh_connection.connection_options();
 709
 710            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 711            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 712                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 713            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 714
 715            let (ssh_connection, io_task) = match Self::establish_connection(
 716                identifier,
 717                true,
 718                connection_options,
 719                proxy_incoming_tx,
 720                proxy_outgoing_rx,
 721                connection_activity_tx,
 722                delegate.clone(),
 723                &mut cx,
 724            )
 725            .await
 726            {
 727                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 728                Err(error) => {
 729                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 730                }
 731            };
 732
 733            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
 734
 735            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
 736                failed!(error, attempts, ssh_connection, delegate, forwarder);
 737            };
 738
 739            State::Connected {
 740                ssh_connection,
 741                delegate,
 742                forwarder,
 743                multiplex_task,
 744                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 745            }
 746        });
 747
 748        cx.spawn(|this, mut cx| async move {
 749            let new_state = reconnect_task.await;
 750            this.update(&mut cx, |this, cx| {
 751                this.try_set_state(cx, |old_state| {
 752                    if old_state.is_reconnecting() {
 753                        match &new_state {
 754                            State::Connecting
 755                            | State::Reconnecting { .. }
 756                            | State::HeartbeatMissed { .. }
 757                            | State::ServerNotRunning => {}
 758                            State::Connected { .. } => {
 759                                log::info!("Successfully reconnected");
 760                            }
 761                            State::ReconnectFailed {
 762                                error, attempts, ..
 763                            } => {
 764                                log::error!(
 765                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 766                                    attempts,
 767                                    error
 768                                );
 769                            }
 770                            State::ReconnectExhausted => {
 771                                log::error!("Reconnect attempt failed and all attempts exhausted");
 772                            }
 773                        }
 774                        Some(new_state)
 775                    } else {
 776                        None
 777                    }
 778                });
 779
 780                if this.state_is(State::is_reconnect_failed) {
 781                    this.reconnect(cx)
 782                } else if this.state_is(State::is_reconnect_exhausted) {
 783                    cx.emit(SshRemoteEvent::Disconnected);
 784                    Ok(())
 785                } else {
 786                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 787                    Ok(())
 788                }
 789            })
 790        })
 791        .detach_and_log_err(cx);
 792
 793        Ok(())
 794    }
 795
 796    fn heartbeat(
 797        this: WeakModel<Self>,
 798        mut connection_activity_rx: mpsc::Receiver<()>,
 799        cx: &mut AsyncAppContext,
 800    ) -> Task<Result<()>> {
 801        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 802            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 803        };
 804
 805        cx.spawn(|mut cx| {
 806            let this = this.clone();
 807            async move {
 808                let mut missed_heartbeats = 0;
 809
 810                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 811                futures::pin_mut!(keepalive_timer);
 812
 813                loop {
 814                    select_biased! {
 815                        result = connection_activity_rx.next().fuse() => {
 816                            if result.is_none() {
 817                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 818                                return Ok(());
 819                            }
 820
 821                            keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 822
 823                            if missed_heartbeats != 0 {
 824                                missed_heartbeats = 0;
 825                                this.update(&mut cx, |this, mut cx| {
 826                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 827                                })?;
 828                            }
 829                        }
 830                        _ = keepalive_timer => {
 831                            log::debug!("Sending heartbeat to server...");
 832
 833                            let result = select_biased! {
 834                                _ = connection_activity_rx.next().fuse() => {
 835                                    Ok(())
 836                                }
 837                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 838                                    ping_result
 839                                }
 840                            };
 841
 842                            if result.is_err() {
 843                                missed_heartbeats += 1;
 844                                log::warn!(
 845                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 846                                    HEARTBEAT_TIMEOUT,
 847                                    missed_heartbeats,
 848                                    MAX_MISSED_HEARTBEATS
 849                                );
 850                            } else if missed_heartbeats != 0 {
 851                                missed_heartbeats = 0;
 852                            } else {
 853                                continue;
 854                            }
 855
 856                            let result = this.update(&mut cx, |this, mut cx| {
 857                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 858                            })?;
 859                            if result.is_break() {
 860                                return Ok(());
 861                            }
 862                        }
 863                    }
 864                }
 865            }
 866        })
 867    }
 868
 869    fn handle_heartbeat_result(
 870        &mut self,
 871        missed_heartbeats: usize,
 872        cx: &mut ModelContext<Self>,
 873    ) -> ControlFlow<()> {
 874        let state = self.state.lock().take().unwrap();
 875        let next_state = if missed_heartbeats > 0 {
 876            state.heartbeat_missed()
 877        } else {
 878            state.heartbeat_recovered()
 879        };
 880
 881        self.set_state(next_state, cx);
 882
 883        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 884            log::error!(
 885                "Missed last {} heartbeats. Reconnecting...",
 886                missed_heartbeats
 887            );
 888
 889            self.reconnect(cx)
 890                .context("failed to start reconnect process after missing heartbeats")
 891                .log_err();
 892            ControlFlow::Break(())
 893        } else {
 894            ControlFlow::Continue(())
 895        }
 896    }
 897
 898    fn multiplex(
 899        mut ssh_proxy_process: Child,
 900        incoming_tx: UnboundedSender<Envelope>,
 901        mut outgoing_rx: UnboundedReceiver<Envelope>,
 902        mut connection_activity_tx: Sender<()>,
 903        cx: &AsyncAppContext,
 904    ) -> Task<Result<Option<i32>>> {
 905        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 906        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 907        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 908
 909        cx.background_executor().spawn(async move {
 910            let mut stdin_buffer = Vec::new();
 911            let mut stdout_buffer = Vec::new();
 912            let mut stderr_buffer = Vec::new();
 913            let mut stderr_offset = 0;
 914
 915            loop {
 916                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 917                stderr_buffer.resize(stderr_offset + 1024, 0);
 918
 919                select_biased! {
 920                    outgoing = outgoing_rx.next().fuse() => {
 921                        let Some(outgoing) = outgoing else {
 922                            return anyhow::Ok(None);
 923                        };
 924
 925                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 926                    }
 927
 928                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 929                        match result {
 930                            Ok(0) => {
 931                                child_stdin.close().await?;
 932                                outgoing_rx.close();
 933                                let status = ssh_proxy_process.status().await?;
 934                                // If we don't have a code, we assume process
 935                                // has been killed and treat it as non-zero exit
 936                                // code
 937                                return Ok(status.code().or_else(|| Some(1)));
 938                            }
 939                            Ok(len) => {
 940                                if len < stdout_buffer.len() {
 941                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 942                                }
 943
 944                                let message_len = message_len_from_buffer(&stdout_buffer);
 945                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 946                                    Ok(envelope) => {
 947                                        connection_activity_tx.try_send(()).ok();
 948                                        incoming_tx.unbounded_send(envelope).ok();
 949                                    }
 950                                    Err(error) => {
 951                                        log::error!("error decoding message {error:?}");
 952                                    }
 953                                }
 954                            }
 955                            Err(error) => {
 956                                Err(anyhow!("error reading stdout: {error:?}"))?;
 957                            }
 958                        }
 959                    }
 960
 961                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 962                        match result {
 963                            Ok(len) => {
 964                                stderr_offset += len;
 965                                let mut start_ix = 0;
 966                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 967                                    let line_ix = start_ix + ix;
 968                                    let content = &stderr_buffer[start_ix..line_ix];
 969                                    start_ix = line_ix + 1;
 970                                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 971                                        record.log(log::logger())
 972                                    } else {
 973                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 974                                    }
 975                                }
 976                                stderr_buffer.drain(0..start_ix);
 977                                stderr_offset -= start_ix;
 978
 979                                connection_activity_tx.try_send(()).ok();
 980                            }
 981                            Err(error) => {
 982                                Err(anyhow!("error reading stderr: {error:?}"))?;
 983                            }
 984                        }
 985                    }
 986                }
 987            }
 988        })
 989    }
 990
 991    fn monitor(
 992        this: WeakModel<Self>,
 993        io_task: Task<Result<Option<i32>>>,
 994        cx: &AsyncAppContext,
 995    ) -> Task<Result<()>> {
 996        cx.spawn(|mut cx| async move {
 997            let result = io_task.await;
 998
 999            match result {
1000                Ok(Some(exit_code)) => {
1001                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1002                        match error {
1003                            ProxyLaunchError::ServerNotRunning => {
1004                                log::error!("failed to reconnect because server is not running");
1005                                this.update(&mut cx, |this, cx| {
1006                                    this.set_state(State::ServerNotRunning, cx);
1007                                    cx.emit(SshRemoteEvent::Disconnected);
1008                                })?;
1009                            }
1010                        }
1011                    } else if exit_code > 0 {
1012                        log::error!("proxy process terminated unexpectedly");
1013                        this.update(&mut cx, |this, cx| {
1014                            this.reconnect(cx).ok();
1015                        })?;
1016                    }
1017                }
1018                Ok(None) => {}
1019                Err(error) => {
1020                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1021                    this.update(&mut cx, |this, cx| {
1022                        this.reconnect(cx).ok();
1023                    })?;
1024                }
1025            }
1026            Ok(())
1027        })
1028    }
1029
1030    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1031        self.state.lock().as_ref().map_or(false, check)
1032    }
1033
1034    fn try_set_state(
1035        &self,
1036        cx: &mut ModelContext<Self>,
1037        map: impl FnOnce(&State) -> Option<State>,
1038    ) {
1039        let mut lock = self.state.lock();
1040        let new_state = lock.as_ref().and_then(map);
1041
1042        if let Some(new_state) = new_state {
1043            lock.replace(new_state);
1044            cx.notify();
1045        }
1046    }
1047
1048    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1049        log::info!("setting state to '{}'", &state);
1050        self.state.lock().replace(state);
1051        cx.notify();
1052    }
1053
1054    #[allow(clippy::too_many_arguments)]
1055    async fn establish_connection(
1056        unique_identifier: String,
1057        reconnect: bool,
1058        connection_options: SshConnectionOptions,
1059        proxy_incoming_tx: UnboundedSender<Envelope>,
1060        proxy_outgoing_rx: UnboundedReceiver<Envelope>,
1061        connection_activity_tx: Sender<()>,
1062        delegate: Arc<dyn SshClientDelegate>,
1063        cx: &mut AsyncAppContext,
1064    ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<Option<i32>>>)> {
1065        #[cfg(any(test, feature = "test-support"))]
1066        if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
1067            let io_task = fake::SshRemoteConnection::multiplex(
1068                fake.connection_options(),
1069                proxy_incoming_tx,
1070                proxy_outgoing_rx,
1071                connection_activity_tx,
1072                cx,
1073            )
1074            .await;
1075            return Ok((fake, io_task));
1076        }
1077
1078        let ssh_connection =
1079            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1080
1081        let platform = ssh_connection.query_platform().await?;
1082        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1083        if !reconnect {
1084            ssh_connection
1085                .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1086                .await?;
1087        }
1088
1089        let socket = ssh_connection.socket.clone();
1090        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1091
1092        delegate.set_status(Some("Starting proxy"), cx);
1093
1094        let mut start_proxy_command = format!(
1095            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1096            std::env::var("RUST_LOG").unwrap_or_default(),
1097            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1098            remote_binary_path,
1099            unique_identifier,
1100        );
1101        if reconnect {
1102            start_proxy_command.push_str(" --reconnect");
1103        }
1104
1105        let ssh_proxy_process = socket
1106            .ssh_command(start_proxy_command)
1107            // IMPORTANT: we kill this process when we drop the task that uses it.
1108            .kill_on_drop(true)
1109            .spawn()
1110            .context("failed to spawn remote server")?;
1111
1112        let io_task = Self::multiplex(
1113            ssh_proxy_process,
1114            proxy_incoming_tx,
1115            proxy_outgoing_rx,
1116            connection_activity_tx,
1117            &cx,
1118        );
1119
1120        Ok((Box::new(ssh_connection), io_task))
1121    }
1122
1123    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1124        self.client.subscribe_to_entity(remote_id, entity);
1125    }
1126
1127    pub fn ssh_args(&self) -> Option<Vec<String>> {
1128        self.state
1129            .lock()
1130            .as_ref()
1131            .and_then(|state| state.ssh_connection())
1132            .map(|ssh_connection| ssh_connection.ssh_args())
1133    }
1134
1135    pub fn proto_client(&self) -> AnyProtoClient {
1136        self.client.clone().into()
1137    }
1138
1139    pub fn connection_string(&self) -> String {
1140        self.connection_options.connection_string()
1141    }
1142
1143    pub fn connection_options(&self) -> SshConnectionOptions {
1144        self.connection_options.clone()
1145    }
1146
1147    pub fn connection_state(&self) -> ConnectionState {
1148        self.state
1149            .lock()
1150            .as_ref()
1151            .map(ConnectionState::from)
1152            .unwrap_or(ConnectionState::Disconnected)
1153    }
1154
1155    pub fn is_disconnected(&self) -> bool {
1156        self.connection_state() == ConnectionState::Disconnected
1157    }
1158
1159    #[cfg(any(test, feature = "test-support"))]
1160    pub fn simulate_disconnect(&self, cx: &mut AppContext) -> Task<()> {
1161        use gpui::BorrowAppContext;
1162
1163        let port = self.connection_options().port.unwrap();
1164
1165        let disconnect =
1166            cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.take(port).into_channels());
1167        cx.spawn(|mut cx| async move {
1168            let (input_rx, output_tx) = disconnect.await;
1169            let (forwarder, _, _) = ChannelForwarder::new(input_rx, output_tx, &mut cx);
1170            cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.replace(port, forwarder))
1171                .unwrap()
1172        })
1173    }
1174
1175    #[cfg(any(test, feature = "test-support"))]
1176    pub fn fake_server(
1177        server_cx: &mut gpui::TestAppContext,
1178    ) -> (ChannelForwarder, Arc<ChannelClient>) {
1179        server_cx.update(|cx| {
1180            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
1181            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
1182
1183            // We use the forwarder on the server side (in production we only use one on the client side)
1184            // the idea is that we can simulate a disconnect/reconnect by just messing with the forwarder.
1185            let (forwarder, _, _) =
1186                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx.to_async());
1187
1188            let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
1189            (forwarder, client)
1190        })
1191    }
1192
1193    #[cfg(any(test, feature = "test-support"))]
1194    pub async fn fake_client(
1195        forwarder: ChannelForwarder,
1196        client_cx: &mut gpui::TestAppContext,
1197    ) -> Model<Self> {
1198        use gpui::BorrowAppContext;
1199        client_cx
1200            .update(|cx| {
1201                let port = cx.update_default_global(|c: &mut fake::GlobalConnections, _cx| {
1202                    c.push(forwarder)
1203                });
1204
1205                Self::new(
1206                    "fake".to_string(),
1207                    SshConnectionOptions {
1208                        host: "<fake>".to_string(),
1209                        port: Some(port),
1210                        ..Default::default()
1211                    },
1212                    Arc::new(fake::Delegate),
1213                    cx,
1214                )
1215            })
1216            .await
1217            .unwrap()
1218    }
1219}
1220
1221impl From<SshRemoteClient> for AnyProtoClient {
1222    fn from(client: SshRemoteClient) -> Self {
1223        AnyProtoClient::new(client.client.clone())
1224    }
1225}
1226
/// A live connection to a remote host, real (`SshRemoteConnection`) or fake in
/// tests, that can be inspected and torn down.
#[async_trait]
trait SshRemoteProcess: Send + Sync {
    /// Terminates the connection; for `SshRemoteConnection` this kills the SSH
    /// master process and waits for it to exit.
    async fn kill(&mut self) -> Result<()>;
    /// Arguments to pass to `ssh` for commands that should use this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
}
1233
/// A real SSH connection held open by an `ssh` master process.
struct SshRemoteConnection {
    /// Connection options plus the control-socket path used to run further
    /// `ssh`/`scp` commands over this connection.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process that keeps the connection open.
    master_process: process::Child,
    /// Temporary directory holding the control socket and askpass files; kept
    /// only so it is not deleted while the connection is alive.
    _temp_dir: TempDir,
}
1239
1240impl Drop for SshRemoteConnection {
1241    fn drop(&mut self) {
1242        if let Err(error) = self.master_process.kill() {
1243            log::error!("failed to kill SSH master process: {}", error);
1244        }
1245    }
1246}
1247
1248#[async_trait]
1249impl SshRemoteProcess for SshRemoteConnection {
1250    async fn kill(&mut self) -> Result<()> {
1251        self.master_process.kill()?;
1252
1253        self.master_process.status().await?;
1254
1255        Ok(())
1256    }
1257
1258    fn ssh_args(&self) -> Vec<String> {
1259        self.socket.ssh_args()
1260    }
1261
1262    fn connection_options(&self) -> SshConnectionOptions {
1263        self.socket.connection_options.clone()
1264    }
1265}
1266
1267impl SshRemoteConnection {
1268    #[cfg(not(unix))]
1269    async fn new(
1270        _connection_options: SshConnectionOptions,
1271        _delegate: Arc<dyn SshClientDelegate>,
1272        _cx: &mut AsyncAppContext,
1273    ) -> Result<Self> {
1274        Err(anyhow!("ssh is not supported on this platform"))
1275    }
1276
1277    #[cfg(unix)]
1278    async fn new(
1279        connection_options: SshConnectionOptions,
1280        delegate: Arc<dyn SshClientDelegate>,
1281        cx: &mut AsyncAppContext,
1282    ) -> Result<Self> {
1283        use futures::{io::BufReader, AsyncBufReadExt as _};
1284        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1285        use util::ResultExt as _;
1286
1287        delegate.set_status(Some("connecting"), cx);
1288
1289        let url = connection_options.ssh_url();
1290        let temp_dir = tempfile::Builder::new()
1291            .prefix("zed-ssh-session")
1292            .tempdir()?;
1293
1294        // Create a domain socket listener to handle requests from the askpass program.
1295        let askpass_socket = temp_dir.path().join("askpass.sock");
1296        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1297        let listener =
1298            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1299
1300        let askpass_task = cx.spawn({
1301            let delegate = delegate.clone();
1302            |mut cx| async move {
1303                let mut askpass_opened_tx = Some(askpass_opened_tx);
1304
1305                while let Ok((mut stream, _)) = listener.accept().await {
1306                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1307                        askpass_opened_tx.send(()).ok();
1308                    }
1309                    let mut buffer = Vec::new();
1310                    let mut reader = BufReader::new(&mut stream);
1311                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1312                        buffer.clear();
1313                    }
1314                    let password_prompt = String::from_utf8_lossy(&buffer);
1315                    if let Some(password) = delegate
1316                        .ask_password(password_prompt.to_string(), &mut cx)
1317                        .await
1318                        .context("failed to get ssh password")
1319                        .and_then(|p| p)
1320                        .log_err()
1321                    {
1322                        stream.write_all(password.as_bytes()).await.log_err();
1323                    }
1324                }
1325            }
1326        });
1327
1328        // Create an askpass script that communicates back to this process.
1329        let askpass_script = format!(
1330            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1331            askpass_socket = askpass_socket.display(),
1332            print_args = "printf '%s\\0' \"$@\"",
1333            shebang = "#!/bin/sh",
1334        );
1335        let askpass_script_path = temp_dir.path().join("askpass.sh");
1336        fs::write(&askpass_script_path, askpass_script).await?;
1337        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1338
1339        // Start the master SSH process, which does not do anything except for establish
1340        // the connection and keep it open, allowing other ssh commands to reuse it
1341        // via a control socket.
1342        let socket_path = temp_dir.path().join("ssh.sock");
1343        let mut master_process = process::Command::new("ssh")
1344            .stdin(Stdio::null())
1345            .stdout(Stdio::piped())
1346            .stderr(Stdio::piped())
1347            .env("SSH_ASKPASS_REQUIRE", "force")
1348            .env("SSH_ASKPASS", &askpass_script_path)
1349            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1350            .args([
1351                "-N",
1352                "-o",
1353                "ControlPersist=no",
1354                "-o",
1355                "ControlMaster=yes",
1356                "-o",
1357            ])
1358            .arg(format!("ControlPath={}", socket_path.display()))
1359            .arg(&url)
1360            .spawn()?;
1361
1362        // Wait for this ssh process to close its stdout, indicating that authentication
1363        // has completed.
1364        let stdout = master_process.stdout.as_mut().unwrap();
1365        let mut output = Vec::new();
1366        let connection_timeout = Duration::from_secs(10);
1367
1368        let result = select_biased! {
1369            _ = askpass_opened_rx.fuse() => {
1370                // If the askpass script has opened, that means the user is typing
1371                // their password, in which case we don't want to timeout anymore,
1372                // since we know a connection has been established.
1373                stdout.read_to_end(&mut output).await?;
1374                Ok(())
1375            }
1376            result = stdout.read_to_end(&mut output).fuse() => {
1377                result?;
1378                Ok(())
1379            }
1380            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1381                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1382            }
1383        };
1384
1385        if let Err(e) = result {
1386            let error_message = format!("Failed to connect to host: {}.", e);
1387            delegate.set_error(error_message, cx);
1388            return Err(e);
1389        }
1390
1391        drop(askpass_task);
1392
1393        if master_process.try_status()?.is_some() {
1394            output.clear();
1395            let mut stderr = master_process.stderr.take().unwrap();
1396            stderr.read_to_end(&mut output).await?;
1397
1398            let error_message = format!(
1399                "failed to connect: {}",
1400                String::from_utf8_lossy(&output).trim()
1401            );
1402            delegate.set_error(error_message.clone(), cx);
1403            Err(anyhow!(error_message))?;
1404        }
1405
1406        Ok(Self {
1407            socket: SshSocket {
1408                connection_options,
1409                socket_path,
1410            },
1411            master_process,
1412            _temp_dir: temp_dir,
1413        })
1414    }
1415
1416    async fn ensure_server_binary(
1417        &self,
1418        delegate: &Arc<dyn SshClientDelegate>,
1419        dst_path: &Path,
1420        platform: SshPlatform,
1421        cx: &mut AsyncAppContext,
1422    ) -> Result<()> {
1423        if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1424            if let Ok(installed_version) =
1425                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1426            {
1427                log::info!("using cached server binary version {}", installed_version);
1428                return Ok(());
1429            }
1430        }
1431
1432        let mut dst_path_gz = dst_path.to_path_buf();
1433        dst_path_gz.set_extension("gz");
1434
1435        if let Some(parent) = dst_path.parent() {
1436            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1437        }
1438
1439        let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1440
1441        let mut server_binary_exists = false;
1442        if !server_binary_exists && cfg!(not(debug_assertions)) {
1443            if let Ok(installed_version) =
1444                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1445            {
1446                if installed_version.trim() == version.to_string() {
1447                    server_binary_exists = true;
1448                }
1449            }
1450        }
1451
1452        if server_binary_exists {
1453            log::info!("remote development server already present",);
1454            return Ok(());
1455        }
1456
1457        let src_stat = fs::metadata(&src_path).await?;
1458        let size = src_stat.len();
1459        let server_mode = 0o755;
1460
1461        let t0 = Instant::now();
1462        delegate.set_status(Some("Uploading remote development server"), cx);
1463        log::info!("uploading remote development server ({}kb)", size / 1024);
1464        self.upload_file(&src_path, &dst_path_gz)
1465            .await
1466            .context("failed to upload server binary")?;
1467        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1468
1469        delegate.set_status(Some("Extracting remote development server"), cx);
1470        run_cmd(
1471            self.socket
1472                .ssh_command("gunzip")
1473                .arg("--force")
1474                .arg(&dst_path_gz),
1475        )
1476        .await?;
1477
1478        delegate.set_status(Some("Marking remote development server executable"), cx);
1479        run_cmd(
1480            self.socket
1481                .ssh_command("chmod")
1482                .arg(format!("{:o}", server_mode))
1483                .arg(dst_path),
1484        )
1485        .await?;
1486
1487        Ok(())
1488    }
1489
1490    async fn query_platform(&self) -> Result<SshPlatform> {
1491        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1492        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1493
1494        let os = match os.trim() {
1495            "Darwin" => "macos",
1496            "Linux" => "linux",
1497            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1498        };
1499        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1500            "aarch64"
1501        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1502            "x86_64"
1503        } else {
1504            Err(anyhow!("unknown uname architecture {arch:?}"))?
1505        };
1506
1507        Ok(SshPlatform { os, arch })
1508    }
1509
1510    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1511        let mut command = process::Command::new("scp");
1512        let output = self
1513            .socket
1514            .ssh_options(&mut command)
1515            .args(
1516                self.socket
1517                    .connection_options
1518                    .port
1519                    .map(|port| vec!["-P".to_string(), port.to_string()])
1520                    .unwrap_or_default(),
1521            )
1522            .arg(src_path)
1523            .arg(format!(
1524                "{}:{}",
1525                self.socket.connection_options.scp_url(),
1526                dest_path.display()
1527            ))
1528            .output()
1529            .await?;
1530
1531        if output.status.success() {
1532            Ok(())
1533        } else {
1534            Err(anyhow!(
1535                "failed to upload file {} -> {}: {}",
1536                src_path.display(),
1537                dest_path.display(),
1538                String::from_utf8_lossy(&output.stderr)
1539            ))
1540        }
1541    }
1542}
1543
/// Pending requests keyed by the id of the outgoing request. When the response
/// arrives, the sender receives the response envelope plus a oneshot sender. The
/// incoming-message loop waits until that oneshot is signalled or dropped before
/// reading the next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// An RPC client that exchanges protobuf `Envelope`s over a pair of in-process
/// unbounded channels.
pub struct ChannelClient {
    /// Counter for outgoing message ids (presumably incremented per sent message;
    /// not shown in this section).
    next_message_id: AtomicU32,
    /// Channel on which outgoing envelopes are sent.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Buffered outgoing envelopes, presumably those sent but not yet
    /// acknowledged. Entries with `id <= ack_id` are dropped when an ack arrives;
    /// the rest are re-sent on `FlushBufferedMessages`.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Senders for requests still awaiting a response.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received message (flush requests excluded).
    max_received: AtomicU32,
}
1554
1555impl ChannelClient {
1556    pub fn new(
1557        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1558        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1559        cx: &AppContext,
1560    ) -> Arc<Self> {
1561        let this = Arc::new(Self {
1562            outgoing_tx,
1563            next_message_id: AtomicU32::new(0),
1564            max_received: AtomicU32::new(0),
1565            response_channels: ResponseChannels::default(),
1566            message_handlers: Default::default(),
1567            buffer: Mutex::new(VecDeque::new()),
1568        });
1569
1570        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1571
1572        this
1573    }
1574
1575    fn start_handling_messages(
1576        this: Arc<Self>,
1577        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1578        cx: &AppContext,
1579    ) {
1580        cx.spawn(|cx| {
1581            let this = Arc::downgrade(&this);
1582            async move {
1583                let peer_id = PeerId { owner_id: 0, id: 0 };
1584                while let Some(incoming) = incoming_rx.next().await {
1585                    let Some(this) = this.upgrade() else {
1586                        return anyhow::Ok(());
1587                    };
1588                    if let Some(ack_id) = incoming.ack_id {
1589                        let mut buffer = this.buffer.lock();
1590                        while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1591                            buffer.pop_front();
1592                        }
1593                    }
1594                    if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload {
1595                        {
1596                            let buffer = this.buffer.lock();
1597                            for envelope in buffer.iter() {
1598                                this.outgoing_tx.unbounded_send(envelope.clone()).ok();
1599                            }
1600                        }
1601                        let response = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
1602                        this.send_dynamic(response).ok();
1603                        continue;
1604                    }
1605
1606                    this.max_received.store(incoming.id, SeqCst);
1607
1608                    if let Some(request_id) = incoming.responding_to {
1609                        let request_id = MessageId(request_id);
1610                        let sender = this.response_channels.lock().remove(&request_id);
1611                        if let Some(sender) = sender {
1612                            let (tx, rx) = oneshot::channel();
1613                            if incoming.payload.is_some() {
1614                                sender.send((incoming, tx)).ok();
1615                            }
1616                            rx.await.ok();
1617                        }
1618                    } else if let Some(envelope) =
1619                        build_typed_envelope(peer_id, Instant::now(), incoming)
1620                    {
1621                        let type_name = envelope.payload_type_name();
1622                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1623                            &this.message_handlers,
1624                            envelope,
1625                            this.clone().into(),
1626                            cx.clone(),
1627                        ) {
1628                            log::debug!("ssh message received. name:{type_name}");
1629                            match future.await {
1630                                Ok(_) => {
1631                                    log::debug!("ssh message handled. name:{type_name}");
1632                                }
1633                                Err(error) => {
1634                                    log::error!(
1635                                        "error handling message. type:{type_name}, error:{error}",
1636                                    );
1637                                }
1638                            }
1639                        } else {
1640                            log::error!("unhandled ssh message name:{type_name}");
1641                        }
1642                    }
1643                }
1644                anyhow::Ok(())
1645            }
1646        })
1647        .detach();
1648    }
1649
1650    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1651        let id = (TypeId::of::<E>(), remote_id);
1652
1653        let mut message_handlers = self.message_handlers.lock();
1654        if message_handlers
1655            .entities_by_type_and_remote_id
1656            .contains_key(&id)
1657        {
1658            panic!("already subscribed to entity");
1659        }
1660
1661        message_handlers.entities_by_type_and_remote_id.insert(
1662            id,
1663            EntityMessageSubscriber::Entity {
1664                handle: entity.downgrade().into(),
1665            },
1666        );
1667    }
1668
1669    pub fn request<T: RequestMessage>(
1670        &self,
1671        payload: T,
1672    ) -> impl 'static + Future<Output = Result<T::Response>> {
1673        log::debug!("ssh request start. name:{}", T::NAME);
1674        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1675        async move {
1676            let response = response.await?;
1677            log::debug!("ssh request finish. name:{}", T::NAME);
1678            T::Response::from_envelope(response)
1679                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1680        }
1681    }
1682
1683    pub async fn resync(&self, timeout: Duration) -> Result<()> {
1684        smol::future::or(
1685            async {
1686                self.request(proto::FlushBufferedMessages {}).await?;
1687                for envelope in self.buffer.lock().iter() {
1688                    self.outgoing_tx.unbounded_send(envelope.clone()).ok();
1689                }
1690                Ok(())
1691            },
1692            async {
1693                smol::Timer::after(timeout).await;
1694                Err(anyhow!("Timeout detected"))
1695            },
1696        )
1697        .await
1698    }
1699
1700    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1701        smol::future::or(
1702            async {
1703                self.request(proto::Ping {}).await?;
1704                Ok(())
1705            },
1706            async {
1707                smol::Timer::after(timeout).await;
1708                Err(anyhow!("Timeout detected"))
1709            },
1710        )
1711        .await
1712    }
1713
1714    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1715        log::debug!("ssh send name:{}", T::NAME);
1716        self.send_dynamic(payload.into_envelope(0, None, None))
1717    }
1718
1719    pub fn request_dynamic(
1720        &self,
1721        mut envelope: proto::Envelope,
1722        type_name: &'static str,
1723    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1724        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1725        let (tx, rx) = oneshot::channel();
1726        let mut response_channels_lock = self.response_channels.lock();
1727        response_channels_lock.insert(MessageId(envelope.id), tx);
1728        drop(response_channels_lock);
1729
1730        let result = self.send_buffered(envelope);
1731        async move {
1732            if let Err(error) = &result {
1733                log::error!("failed to send message: {}", error);
1734                return Err(anyhow!("failed to send message: {}", error));
1735            }
1736
1737            let response = rx.await.context("connection lost")?.0;
1738            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1739                return Err(RpcError::from_proto(error, type_name));
1740            }
1741            Ok(response)
1742        }
1743    }
1744
1745    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1746        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1747        self.send_buffered(envelope)
1748    }
1749
1750    pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1751        envelope.ack_id = Some(self.max_received.load(SeqCst));
1752        self.buffer.lock().push_back(envelope.clone());
1753        self.outgoing_tx.unbounded_send(envelope)?;
1754        Ok(())
1755    }
1756}
1757
1758impl ProtoClient for ChannelClient {
1759    fn request(
1760        &self,
1761        envelope: proto::Envelope,
1762        request_type: &'static str,
1763    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1764        self.request_dynamic(envelope, request_type).boxed()
1765    }
1766
1767    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1768        self.send_dynamic(envelope)
1769    }
1770
1771    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1772        self.send_dynamic(envelope)
1773    }
1774
1775    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1776        &self.message_handlers
1777    }
1778
1779    fn is_via_collab(&self) -> bool {
1780        false
1781    }
1782}
1783
/// In-process stand-ins for the SSH transport, compiled for tests and the
/// `test-support` feature.
///
/// A connection whose host is `"<fake>"` does not spawn `ssh`. It is connected,
/// by its port number, to a `ChannelForwarder` stored in the
/// `GlobalConnections` global.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::path::PathBuf;

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
    use rpc::proto::Envelope;

    use super::{
        ChannelForwarder, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
    };

    /// Fake SSH process handle. It owns no process and only remembers the
    /// options it was created with.
    pub(super) struct SshRemoteConnection {
        connection_options: SshConnectionOptions,
    }

    impl SshRemoteConnection {
        /// Returns a fake connection when `connection_options.host` is the
        /// `"<fake>"` sentinel, and `None` for any other host.
        pub(super) fn new(
            connection_options: &SshConnectionOptions,
        ) -> Option<Box<dyn SshRemoteProcess>> {
            if connection_options.host != "<fake>" {
                return None;
            }
            Some(Box::new(Self {
                connection_options: connection_options.clone(),
            }))
        }

        /// Connects the client's channels to the fake server registered under
        /// `connection_options.port` and relays messages in both directions.
        ///
        /// The registered `ChannelForwarder` is taken out of
        /// `GlobalConnections`. Its channels are handed to a new
        /// `ChannelForwarder`, which is stored back in the same slot. Each
        /// server-to-client message also pings `connection_activity_tx`.
        ///
        /// The returned task resolves to `Ok(Some(1))` when the server side
        /// closes, and to `Ok(None)` when the client side closes.
        ///
        /// # Panics
        ///
        /// Panics in any of these cases:
        /// - `connection_options.port` is `None`;
        /// - no fake server is registered (or available) for that port;
        /// - `cx.update` fails.
        pub(super) async fn multiplex(
            connection_options: SshConnectionOptions,
            mut client_tx: mpsc::UnboundedSender<Envelope>,
            mut client_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<Option<i32>>> {
            let port = connection_options
                .port
                .expect("fake ssh connections require a port");

            let (server_tx, server_rx) = cx
                .update(|cx| {
                    cx.update_global(|conns: &mut GlobalConnections, _| conns.take(port))
                })
                .unwrap()
                .into_channels()
                .await;

            let (forwarder, mut proxy_tx, mut proxy_rx) =
                ChannelForwarder::new(server_tx, server_rx, cx);

            cx.update(|cx| {
                cx.update_global(|conns: &mut GlobalConnections, _| {
                    conns.replace(port, forwarder)
                })
            })
            .unwrap();

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = proxy_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(Some(1))
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(None)
                            };
                            proxy_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    #[async_trait]
    impl SshRemoteProcess for SshRemoteConnection {
        /// There is no real process to kill, so this always succeeds.
        async fn kill(&mut self) -> Result<()> {
            Ok(())
        }

        /// Fake connections pass no extra `ssh` arguments.
        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        /// Returns a copy of the options this connection was created with.
        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }
    }

    /// Registry of fake servers, indexed by the "port" that fake connections
    /// use to find them.
    ///
    /// A slot is `None` while its forwarder has been taken out by `take` and
    /// not yet put back by `replace`.
    #[derive(Default)]
    pub(super) struct GlobalConnections(Vec<Option<ChannelForwarder>>);
    impl Global for GlobalConnections {}

    impl GlobalConnections {
        /// Registers `forwarder` and returns the port (slot index) that fake
        /// connections should use to reach it.
        ///
        /// # Panics
        ///
        /// Panics if the new slot index does not fit in a `u16`.
        pub(super) fn push(&mut self, forwarder: ChannelForwarder) -> u16 {
            // The new slot's index equals the current length. Check it before
            // pushing instead of truncating with `as`.
            let port = u16::try_from(self.0.len()).expect("too many fake ssh servers");
            self.0.push(Some(forwarder));
            port
        }

        /// Removes and returns the forwarder registered for `port`, leaving
        /// its slot empty.
        ///
        /// # Panics
        ///
        /// Panics if `port` was never registered or the slot is already empty.
        pub(super) fn take(&mut self, port: u16) -> ChannelForwarder {
            self.0
                .get_mut(port as usize)
                .expect("no fake server for port")
                .take()
                .expect("fake server is already borrowed")
        }

        /// Puts `forwarder` back into the slot for `port`.
        ///
        /// # Panics
        ///
        /// Panics if `port` was never registered or the slot is already filled.
        pub(super) fn replace(&mut self, port: u16, forwarder: ChannelForwarder) {
            let previous = self
                .0
                .get_mut(port as usize)
                .expect("no fake server for port")
                .replace(forwarder);
            assert!(previous.is_none(), "fake server is already replaced");
        }
    }

    /// Client delegate for fake connections. Every callback is
    /// `unreachable!()`, so fake connections are expected never to invoke
    /// them.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }
        fn remote_server_binary_path(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
        fn get_server_binary(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
            unreachable!()
        }
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
            unreachable!()
        }
        fn set_error(&self, _: String, _: &mut AsyncAppContext) {
            unreachable!()
        }
    }
}