ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6    proxy::ProxyLaunchError,
   7};
   8use anyhow::{anyhow, Context as _, Result};
   9use collections::HashMap;
  10use futures::{
  11    channel::{
  12        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
  13        oneshot,
  14    },
  15    future::BoxFuture,
  16    select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _,
  17};
  18use gpui::{
  19    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
  20    WeakModel,
  21};
  22use parking_lot::Mutex;
  23use rpc::{
  24    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  25    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  26};
  27use smol::{
  28    fs,
  29    process::{self, Child, Stdio},
  30};
  31use std::{
  32    any::TypeId,
  33    ffi::OsStr,
  34    fmt,
  35    ops::ControlFlow,
  36    path::{Path, PathBuf},
  37    sync::{
  38        atomic::{AtomicU32, Ordering::SeqCst},
  39        Arc,
  40    },
  41    time::{Duration, Instant},
  42};
  43use tempfile::TempDir;
  44use util::ResultExt;
  45
  46#[derive(
  47    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  48)]
  49pub struct SshProjectId(pub u64);
  50
  51#[derive(Clone)]
  52pub struct SshSocket {
  53    connection_options: SshConnectionOptions,
  54    socket_path: PathBuf,
  55}
  56
  57#[derive(Debug, Default, Clone, PartialEq, Eq)]
  58pub struct SshConnectionOptions {
  59    pub host: String,
  60    pub username: Option<String>,
  61    pub port: Option<u16>,
  62    pub password: Option<String>,
  63    pub args: Option<Vec<String>>,
  64}
  65
  66impl SshConnectionOptions {
  67    pub fn parse_command_line(input: &str) -> Result<Self> {
  68        let input = input.trim_start_matches("ssh ");
  69        let mut hostname: Option<String> = None;
  70        let mut username: Option<String> = None;
  71        let mut port: Option<u16> = None;
  72        let mut args = Vec::new();
  73
  74        // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
  75        const ALLOWED_OPTS: &[&str] = &[
  76            "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
  77        ];
  78        const ALLOWED_ARGS: &[&str] = &[
  79            "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
  80            "-w",
  81        ];
  82
  83        let mut tokens = shlex::split(input)
  84            .ok_or_else(|| anyhow!("invalid input"))?
  85            .into_iter();
  86
  87        'outer: while let Some(arg) = tokens.next() {
  88            if ALLOWED_OPTS.contains(&(&arg as &str)) {
  89                args.push(arg.to_string());
  90                continue;
  91            }
  92            if arg == "-p" {
  93                port = tokens.next().and_then(|arg| arg.parse().ok());
  94                continue;
  95            } else if let Some(p) = arg.strip_prefix("-p") {
  96                port = p.parse().ok();
  97                continue;
  98            }
  99            if arg == "-l" {
 100                username = tokens.next();
 101                continue;
 102            } else if let Some(l) = arg.strip_prefix("-l") {
 103                username = Some(l.to_string());
 104                continue;
 105            }
 106            for a in ALLOWED_ARGS {
 107                if arg == *a {
 108                    args.push(arg);
 109                    if let Some(next) = tokens.next() {
 110                        args.push(next);
 111                    }
 112                    continue 'outer;
 113                } else if arg.starts_with(a) {
 114                    args.push(arg);
 115                    continue 'outer;
 116                }
 117            }
 118            if arg.starts_with("-") || hostname.is_some() {
 119                anyhow::bail!("unsupported argument: {:?}", arg);
 120            }
 121            let mut input = &arg as &str;
 122            if let Some((u, rest)) = input.split_once('@') {
 123                input = rest;
 124                username = Some(u.to_string());
 125            }
 126            if let Some((rest, p)) = input.split_once(':') {
 127                input = rest;
 128                port = p.parse().ok()
 129            }
 130            hostname = Some(input.to_string())
 131        }
 132
 133        let Some(hostname) = hostname else {
 134            anyhow::bail!("missing hostname");
 135        };
 136
 137        Ok(Self {
 138            host: hostname.to_string(),
 139            username: username.clone(),
 140            port,
 141            password: None,
 142            args: Some(args),
 143        })
 144    }
 145
 146    pub fn ssh_url(&self) -> String {
 147        let mut result = String::from("ssh://");
 148        if let Some(username) = &self.username {
 149            result.push_str(username);
 150            result.push('@');
 151        }
 152        result.push_str(&self.host);
 153        if let Some(port) = self.port {
 154            result.push(':');
 155            result.push_str(&port.to_string());
 156        }
 157        result
 158    }
 159
 160    pub fn additional_args(&self) -> Option<&Vec<String>> {
 161        self.args.as_ref()
 162    }
 163
 164    fn scp_url(&self) -> String {
 165        if let Some(username) = &self.username {
 166            format!("{}@{}", username, self.host)
 167        } else {
 168            self.host.clone()
 169        }
 170    }
 171
 172    pub fn connection_string(&self) -> String {
 173        let host = if let Some(username) = &self.username {
 174            format!("{}@{}", username, self.host)
 175        } else {
 176            self.host.clone()
 177        };
 178        if let Some(port) = &self.port {
 179            format!("{}:{}", host, port)
 180        } else {
 181            host
 182        }
 183    }
 184
 185    // Uniquely identifies dev server projects on a remote host. Needs to be
 186    // stable for the same dev server project.
 187    pub fn dev_server_identifier(&self) -> String {
 188        let mut identifier = format!("dev-server-{:?}", self.host);
 189        if let Some(username) = self.username.as_ref() {
 190            identifier.push('-');
 191            identifier.push_str(&username);
 192        }
 193        identifier
 194    }
 195}
 196
 197#[derive(Copy, Clone, Debug)]
 198pub struct SshPlatform {
 199    pub os: &'static str,
 200    pub arch: &'static str,
 201}
 202
 203impl SshPlatform {
 204    pub fn triple(&self) -> Option<String> {
 205        Some(format!(
 206            "{}-{}",
 207            self.arch,
 208            match self.os {
 209                "linux" => "unknown-linux-gnu",
 210                "macos" => "apple-darwin",
 211                _ => return None,
 212            }
 213        ))
 214    }
 215}
 216
 217pub trait SshClientDelegate: Send + Sync {
 218    fn ask_password(
 219        &self,
 220        prompt: String,
 221        cx: &mut AsyncAppContext,
 222    ) -> oneshot::Receiver<Result<String>>;
 223    fn remote_server_binary_path(
 224        &self,
 225        platform: SshPlatform,
 226        cx: &mut AsyncAppContext,
 227    ) -> Result<PathBuf>;
 228    fn get_server_binary(
 229        &self,
 230        platform: SshPlatform,
 231        cx: &mut AsyncAppContext,
 232    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 233    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 234    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 235}
 236
 237impl SshSocket {
 238    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 239        let mut command = process::Command::new("ssh");
 240        self.ssh_options(&mut command)
 241            .arg(self.connection_options.ssh_url())
 242            .arg(program);
 243        command
 244    }
 245
 246    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 247        command
 248            .stdin(Stdio::piped())
 249            .stdout(Stdio::piped())
 250            .stderr(Stdio::piped())
 251            .args(["-o", "ControlMaster=no", "-o"])
 252            .arg(format!("ControlPath={}", self.socket_path.display()))
 253    }
 254
 255    fn ssh_args(&self) -> Vec<String> {
 256        vec![
 257            "-o".to_string(),
 258            "ControlMaster=no".to_string(),
 259            "-o".to_string(),
 260            format!("ControlPath={}", self.socket_path.display()),
 261            self.connection_options.ssh_url(),
 262        ]
 263    }
 264}
 265
 266async fn run_cmd(command: &mut process::Command) -> Result<String> {
 267    let output = command.output().await?;
 268    if output.status.success() {
 269        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 270    } else {
 271        Err(anyhow!(
 272            "failed to run command: {}",
 273            String::from_utf8_lossy(&output.stderr)
 274        ))
 275    }
 276}
 277
 278struct ChannelForwarder {
 279    quit_tx: UnboundedSender<()>,
 280    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 281}
 282
 283impl ChannelForwarder {
 284    fn new(
 285        mut incoming_tx: UnboundedSender<Envelope>,
 286        mut outgoing_rx: UnboundedReceiver<Envelope>,
 287        cx: &AsyncAppContext,
 288    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 289        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 290
 291        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 292        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 293
 294        let forwarding_task = cx.background_executor().spawn(async move {
 295            loop {
 296                select_biased! {
 297                    _ = quit_rx.next().fuse() => {
 298                        break;
 299                    },
 300                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 301                        if let Some(envelope) = incoming_envelope {
 302                            if incoming_tx.send(envelope).await.is_err() {
 303                                break;
 304                            }
 305                        } else {
 306                            break;
 307                        }
 308                    }
 309                    outgoing_envelope = outgoing_rx.next().fuse() => {
 310                        if let Some(envelope) = outgoing_envelope {
 311                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 312                                break;
 313                            }
 314                        } else {
 315                            break;
 316                        }
 317                    }
 318                }
 319            }
 320
 321            (incoming_tx, outgoing_rx)
 322        });
 323
 324        (
 325            Self {
 326                forwarding_task,
 327                quit_tx,
 328            },
 329            proxy_incoming_tx,
 330            proxy_outgoing_rx,
 331        )
 332    }
 333
 334    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 335        let _ = self.quit_tx.send(()).await;
 336        self.forwarding_task.await
 337    }
 338}
 339
 340const MAX_MISSED_HEARTBEATS: usize = 5;
 341const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 342const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 343
 344const MAX_RECONNECT_ATTEMPTS: usize = 3;
 345
 346enum State {
 347    Connecting,
 348    Connected {
 349        ssh_connection: SshRemoteConnection,
 350        delegate: Arc<dyn SshClientDelegate>,
 351        forwarder: ChannelForwarder,
 352
 353        multiplex_task: Task<Result<()>>,
 354        heartbeat_task: Task<Result<()>>,
 355    },
 356    HeartbeatMissed {
 357        missed_heartbeats: usize,
 358
 359        ssh_connection: SshRemoteConnection,
 360        delegate: Arc<dyn SshClientDelegate>,
 361        forwarder: ChannelForwarder,
 362
 363        multiplex_task: Task<Result<()>>,
 364        heartbeat_task: Task<Result<()>>,
 365    },
 366    Reconnecting,
 367    ReconnectFailed {
 368        ssh_connection: SshRemoteConnection,
 369        delegate: Arc<dyn SshClientDelegate>,
 370        forwarder: ChannelForwarder,
 371
 372        error: anyhow::Error,
 373        attempts: usize,
 374    },
 375    ReconnectExhausted,
 376    ServerNotRunning,
 377}
 378
 379impl fmt::Display for State {
 380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 381        match self {
 382            Self::Connecting => write!(f, "connecting"),
 383            Self::Connected { .. } => write!(f, "connected"),
 384            Self::Reconnecting => write!(f, "reconnecting"),
 385            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
 386            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
 387            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
 388            Self::ServerNotRunning { .. } => write!(f, "server not running"),
 389        }
 390    }
 391}
 392
 393impl State {
 394    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
 395        match self {
 396            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
 397            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
 398            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
 399            _ => None,
 400        }
 401    }
 402
 403    fn can_reconnect(&self) -> bool {
 404        match self {
 405            Self::Connected { .. }
 406            | Self::HeartbeatMissed { .. }
 407            | Self::ReconnectFailed { .. } => true,
 408            State::Connecting
 409            | State::Reconnecting
 410            | State::ReconnectExhausted
 411            | State::ServerNotRunning => false,
 412        }
 413    }
 414
 415    fn is_reconnect_failed(&self) -> bool {
 416        matches!(self, Self::ReconnectFailed { .. })
 417    }
 418
 419    fn is_reconnect_exhausted(&self) -> bool {
 420        matches!(self, Self::ReconnectExhausted { .. })
 421    }
 422
 423    fn is_reconnecting(&self) -> bool {
 424        matches!(self, Self::Reconnecting { .. })
 425    }
 426
 427    fn heartbeat_recovered(self) -> Self {
 428        match self {
 429            Self::HeartbeatMissed {
 430                ssh_connection,
 431                delegate,
 432                forwarder,
 433                multiplex_task,
 434                heartbeat_task,
 435                ..
 436            } => Self::Connected {
 437                ssh_connection,
 438                delegate,
 439                forwarder,
 440                multiplex_task,
 441                heartbeat_task,
 442            },
 443            _ => self,
 444        }
 445    }
 446
 447    fn heartbeat_missed(self) -> Self {
 448        match self {
 449            Self::Connected {
 450                ssh_connection,
 451                delegate,
 452                forwarder,
 453                multiplex_task,
 454                heartbeat_task,
 455            } => Self::HeartbeatMissed {
 456                missed_heartbeats: 1,
 457                ssh_connection,
 458                delegate,
 459                forwarder,
 460                multiplex_task,
 461                heartbeat_task,
 462            },
 463            Self::HeartbeatMissed {
 464                missed_heartbeats,
 465                ssh_connection,
 466                delegate,
 467                forwarder,
 468                multiplex_task,
 469                heartbeat_task,
 470            } => Self::HeartbeatMissed {
 471                missed_heartbeats: missed_heartbeats + 1,
 472                ssh_connection,
 473                delegate,
 474                forwarder,
 475                multiplex_task,
 476                heartbeat_task,
 477            },
 478            _ => self,
 479        }
 480    }
 481}
 482
 483/// The state of the ssh connection.
 484#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 485pub enum ConnectionState {
 486    Connecting,
 487    Connected,
 488    HeartbeatMissed,
 489    Reconnecting,
 490    Disconnected,
 491}
 492
 493impl From<&State> for ConnectionState {
 494    fn from(value: &State) -> Self {
 495        match value {
 496            State::Connecting => Self::Connecting,
 497            State::Connected { .. } => Self::Connected,
 498            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
 499            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
 500            State::ReconnectExhausted => Self::Disconnected,
 501            State::ServerNotRunning => Self::Disconnected,
 502        }
 503    }
 504}
 505
 506pub struct SshRemoteClient {
 507    client: Arc<ChannelClient>,
 508    unique_identifier: String,
 509    connection_options: SshConnectionOptions,
 510    state: Arc<Mutex<Option<State>>>,
 511}
 512
 513#[derive(Debug)]
 514pub enum SshRemoteEvent {
 515    Disconnected,
 516}
 517
 518impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
 519
 520impl SshRemoteClient {
 521    pub fn new(
 522        unique_identifier: String,
 523        connection_options: SshConnectionOptions,
 524        delegate: Arc<dyn SshClientDelegate>,
 525        cx: &AppContext,
 526    ) -> Task<Result<Model<Self>>> {
 527        cx.spawn(|mut cx| async move {
 528            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 529            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 530            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 531
 532            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
 533            let this = cx.new_model(|_| Self {
 534                client: client.clone(),
 535                unique_identifier: unique_identifier.clone(),
 536                connection_options: connection_options.clone(),
 537                state: Arc::new(Mutex::new(Some(State::Connecting))),
 538            })?;
 539
 540            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 541                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 542
 543            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 544                unique_identifier,
 545                false,
 546                connection_options,
 547                delegate.clone(),
 548                &mut cx,
 549            )
 550            .await?;
 551
 552            let multiplex_task = Self::multiplex(
 553                this.downgrade(),
 554                ssh_proxy_process,
 555                proxy_incoming_tx,
 556                proxy_outgoing_rx,
 557                connection_activity_tx,
 558                &mut cx,
 559            );
 560
 561            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 562                log::error!("failed to establish connection: {}", error);
 563                delegate.set_error(error.to_string(), &mut cx);
 564                return Err(error);
 565            }
 566
 567            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 568
 569            this.update(&mut cx, |this, _| {
 570                *this.state.lock() = Some(State::Connected {
 571                    ssh_connection,
 572                    delegate,
 573                    forwarder: proxy,
 574                    multiplex_task,
 575                    heartbeat_task,
 576                });
 577            })?;
 578
 579            Ok(this)
 580        })
 581    }
 582
 583    pub fn shutdown_processes<T: RequestMessage>(
 584        &self,
 585        shutdown_request: Option<T>,
 586    ) -> Option<impl Future<Output = ()>> {
 587        let state = self.state.lock().take()?;
 588        log::info!("shutting down ssh processes");
 589
 590        let State::Connected {
 591            multiplex_task,
 592            heartbeat_task,
 593            ssh_connection,
 594            delegate,
 595            forwarder,
 596        } = state
 597        else {
 598            return None;
 599        };
 600
 601        let client = self.client.clone();
 602
 603        Some(async move {
 604            if let Some(shutdown_request) = shutdown_request {
 605                client.send(shutdown_request).log_err();
 606                // We wait 50ms instead of waiting for a response, because
 607                // waiting for a response would require us to wait on the main thread
 608                // which we want to avoid in an `on_app_quit` callback.
 609                smol::Timer::after(Duration::from_millis(50)).await;
 610            }
 611
 612            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 613            // child of master_process.
 614            drop(multiplex_task);
 615            // Now drop the rest of state, which kills master process.
 616            drop(heartbeat_task);
 617            drop(ssh_connection);
 618            drop(delegate);
 619            drop(forwarder);
 620        })
 621    }
 622
 623    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
 624        let mut lock = self.state.lock();
 625
 626        let can_reconnect = lock
 627            .as_ref()
 628            .map(|state| state.can_reconnect())
 629            .unwrap_or(false);
 630        if !can_reconnect {
 631            let error = if let Some(state) = lock.as_ref() {
 632                format!("invalid state, cannot reconnect while in state {state}")
 633            } else {
 634                "no state set".to_string()
 635            };
 636            log::info!("aborting reconnect, because not in state that allows reconnecting");
 637            return Err(anyhow!(error));
 638        }
 639
 640        let state = lock.take().unwrap();
 641        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
 642            State::Connected {
 643                ssh_connection,
 644                delegate,
 645                forwarder,
 646                multiplex_task,
 647                heartbeat_task,
 648            }
 649            | State::HeartbeatMissed {
 650                ssh_connection,
 651                delegate,
 652                forwarder,
 653                multiplex_task,
 654                heartbeat_task,
 655                ..
 656            } => {
 657                drop(multiplex_task);
 658                drop(heartbeat_task);
 659                (0, ssh_connection, delegate, forwarder)
 660            }
 661            State::ReconnectFailed {
 662                attempts,
 663                ssh_connection,
 664                delegate,
 665                forwarder,
 666                ..
 667            } => (attempts, ssh_connection, delegate, forwarder),
 668            State::Connecting
 669            | State::Reconnecting
 670            | State::ReconnectExhausted
 671            | State::ServerNotRunning => unreachable!(),
 672        };
 673
 674        let attempts = attempts + 1;
 675        if attempts > MAX_RECONNECT_ATTEMPTS {
 676            log::error!(
 677                "Failed to reconnect to after {} attempts, giving up",
 678                MAX_RECONNECT_ATTEMPTS
 679            );
 680            drop(lock);
 681            self.set_state(State::ReconnectExhausted, cx);
 682            return Ok(());
 683        }
 684        drop(lock);
 685
 686        self.set_state(State::Reconnecting, cx);
 687
 688        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 689
 690        let identifier = self.unique_identifier.clone();
 691        let client = self.client.clone();
 692        let reconnect_task = cx.spawn(|this, mut cx| async move {
 693            macro_rules! failed {
 694                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
 695                    return State::ReconnectFailed {
 696                        error: anyhow!($error),
 697                        attempts: $attempts,
 698                        ssh_connection: $ssh_connection,
 699                        delegate: $delegate,
 700                        forwarder: $forwarder,
 701                    };
 702                };
 703            }
 704
 705            if let Err(error) = ssh_connection.master_process.kill() {
 706                failed!(error, attempts, ssh_connection, delegate, forwarder);
 707            };
 708
 709            if let Err(error) = ssh_connection
 710                .master_process
 711                .status()
 712                .await
 713                .context("Failed to kill ssh process")
 714            {
 715                failed!(error, attempts, ssh_connection, delegate, forwarder);
 716            }
 717
 718            let connection_options = ssh_connection.socket.connection_options.clone();
 719
 720            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
 721            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
 722                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 723            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 724
 725            let (ssh_connection, ssh_process) = match Self::establish_connection(
 726                identifier,
 727                true,
 728                connection_options,
 729                delegate.clone(),
 730                &mut cx,
 731            )
 732            .await
 733            {
 734                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
 735                Err(error) => {
 736                    failed!(error, attempts, ssh_connection, delegate, forwarder);
 737                }
 738            };
 739
 740            let multiplex_task = Self::multiplex(
 741                this.clone(),
 742                ssh_process,
 743                proxy_incoming_tx,
 744                proxy_outgoing_rx,
 745                connection_activity_tx,
 746                &mut cx,
 747            );
 748
 749            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
 750                failed!(error, attempts, ssh_connection, delegate, forwarder);
 751            };
 752
 753            State::Connected {
 754                ssh_connection,
 755                delegate,
 756                forwarder,
 757                multiplex_task,
 758                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
 759            }
 760        });
 761
 762        cx.spawn(|this, mut cx| async move {
 763            let new_state = reconnect_task.await;
 764            this.update(&mut cx, |this, cx| {
 765                this.try_set_state(cx, |old_state| {
 766                    if old_state.is_reconnecting() {
 767                        match &new_state {
 768                            State::Connecting
 769                            | State::Reconnecting { .. }
 770                            | State::HeartbeatMissed { .. }
 771                            | State::ServerNotRunning => {}
 772                            State::Connected { .. } => {
 773                                log::info!("Successfully reconnected");
 774                            }
 775                            State::ReconnectFailed {
 776                                error, attempts, ..
 777                            } => {
 778                                log::error!(
 779                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
 780                                    attempts,
 781                                    error
 782                                );
 783                            }
 784                            State::ReconnectExhausted => {
 785                                log::error!("Reconnect attempt failed and all attempts exhausted");
 786                            }
 787                        }
 788                        Some(new_state)
 789                    } else {
 790                        None
 791                    }
 792                });
 793
 794                if this.state_is(State::is_reconnect_failed) {
 795                    this.reconnect(cx)
 796                } else if this.state_is(State::is_reconnect_exhausted) {
 797                    cx.emit(SshRemoteEvent::Disconnected);
 798                    Ok(())
 799                } else {
 800                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
 801                    Ok(())
 802                }
 803            })
 804        })
 805        .detach_and_log_err(cx);
 806
 807        Ok(())
 808    }
 809
 810    fn heartbeat(
 811        this: WeakModel<Self>,
 812        mut connection_activity_rx: mpsc::Receiver<()>,
 813        cx: &mut AsyncAppContext,
 814    ) -> Task<Result<()>> {
 815        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 816            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 817        };
 818
 819        cx.spawn(|mut cx| {
 820            let this = this.clone();
 821            async move {
 822                let mut missed_heartbeats = 0;
 823
 824                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
 825                futures::pin_mut!(keepalive_timer);
 826
 827                loop {
 828                    select_biased! {
 829                        result = connection_activity_rx.next().fuse() => {
 830                            if result.is_none() {
 831                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
 832                                return Ok(());
 833                            }
 834
 835                            keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
 836
 837                            if missed_heartbeats != 0 {
 838                                missed_heartbeats = 0;
 839                                this.update(&mut cx, |this, mut cx| {
 840                                    this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 841                                })?;
 842                            }
 843                        }
 844                        _ = keepalive_timer => {
 845                            log::debug!("Sending heartbeat to server...");
 846
 847                            let result = select_biased! {
 848                                _ = connection_activity_rx.next().fuse() => {
 849                                    Ok(())
 850                                }
 851                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
 852                                    ping_result
 853                                }
 854                            };
 855
 856                            if result.is_err() {
 857                                missed_heartbeats += 1;
 858                                log::warn!(
 859                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 860                                    HEARTBEAT_TIMEOUT,
 861                                    missed_heartbeats,
 862                                    MAX_MISSED_HEARTBEATS
 863                                );
 864                            } else if missed_heartbeats != 0 {
 865                                missed_heartbeats = 0;
 866                            } else {
 867                                continue;
 868                            }
 869
 870                            let result = this.update(&mut cx, |this, mut cx| {
 871                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
 872                            })?;
 873                            if result.is_break() {
 874                                return Ok(());
 875                            }
 876                        }
 877                    }
 878                }
 879            }
 880        })
 881    }
 882
 883    fn handle_heartbeat_result(
 884        &mut self,
 885        missed_heartbeats: usize,
 886        cx: &mut ModelContext<Self>,
 887    ) -> ControlFlow<()> {
 888        let state = self.state.lock().take().unwrap();
 889        let next_state = if missed_heartbeats > 0 {
 890            state.heartbeat_missed()
 891        } else {
 892            state.heartbeat_recovered()
 893        };
 894
 895        self.set_state(next_state, cx);
 896
 897        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 898            log::error!(
 899                "Missed last {} heartbeats. Reconnecting...",
 900                missed_heartbeats
 901            );
 902
 903            self.reconnect(cx)
 904                .context("failed to start reconnect process after missing heartbeats")
 905                .log_err();
 906            ControlFlow::Break(())
 907        } else {
 908            ControlFlow::Continue(())
 909        }
 910    }
 911
 912    fn multiplex(
 913        this: WeakModel<Self>,
 914        mut ssh_proxy_process: Child,
 915        incoming_tx: UnboundedSender<Envelope>,
 916        mut outgoing_rx: UnboundedReceiver<Envelope>,
 917        mut connection_activity_tx: Sender<()>,
 918        cx: &AsyncAppContext,
 919    ) -> Task<Result<()>> {
 920        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 921        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 922        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 923
 924        let mut stdin_buffer = Vec::new();
 925        let mut stdout_buffer = Vec::new();
 926        let mut stderr_buffer = Vec::new();
 927        let mut stderr_offset = 0;
 928
 929        let stdin_task = cx.background_executor().spawn(async move {
 930            while let Some(outgoing) = outgoing_rx.next().await {
 931                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 932            }
 933            anyhow::Ok(())
 934        });
 935
 936        let stdout_task = cx.background_executor().spawn({
 937            let mut connection_activity_tx = connection_activity_tx.clone();
 938            async move {
 939                loop {
 940                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 941                    let len = child_stdout.read(&mut stdout_buffer).await?;
 942
 943                    if len == 0 {
 944                        return anyhow::Ok(());
 945                    }
 946
 947                    if len < MESSAGE_LEN_SIZE {
 948                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 949                    }
 950
 951                    let message_len = message_len_from_buffer(&stdout_buffer);
 952                    let envelope =
 953                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
 954                            .await?;
 955                    connection_activity_tx.try_send(()).ok();
 956                    incoming_tx.unbounded_send(envelope).ok();
 957                }
 958            }
 959        });
 960
 961        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
 962            loop {
 963                stderr_buffer.resize(stderr_offset + 1024, 0);
 964
 965                let len = child_stderr
 966                    .read(&mut stderr_buffer[stderr_offset..])
 967                    .await?;
 968
 969                stderr_offset += len;
 970                let mut start_ix = 0;
 971                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
 972                    .iter()
 973                    .position(|b| b == &b'\n')
 974                {
 975                    let line_ix = start_ix + ix;
 976                    let content = &stderr_buffer[start_ix..line_ix];
 977                    start_ix = line_ix + 1;
 978                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 979                        record.log(log::logger())
 980                    } else {
 981                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 982                    }
 983                }
 984                stderr_buffer.drain(0..start_ix);
 985                stderr_offset -= start_ix;
 986
 987                connection_activity_tx.try_send(()).ok();
 988            }
 989        });
 990
 991        cx.spawn(|mut cx| async move {
 992            let result = futures::select! {
 993                result = stdin_task.fuse() => {
 994                    result.context("stdin")
 995                }
 996                result = stdout_task.fuse() => {
 997                    result.context("stdout")
 998                }
 999                result = stderr_task.fuse() => {
1000                    result.context("stderr")
1001                }
1002            };
1003
1004            match result {
1005                Ok(_) => {
1006                    let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1);
1007
1008                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1009                        match error {
1010                            ProxyLaunchError::ServerNotRunning => {
1011                                log::error!("failed to reconnect because server is not running");
1012                                this.update(&mut cx, |this, cx| {
1013                                    this.set_state(State::ServerNotRunning, cx);
1014                                    cx.emit(SshRemoteEvent::Disconnected);
1015                                })?;
1016                            }
1017                        }
1018                    } else if exit_code > 0 {
1019                        log::error!("proxy process terminated unexpectedly");
1020                        this.update(&mut cx, |this, cx| {
1021                            this.reconnect(cx).ok();
1022                        })?;
1023                    }
1024                }
1025                Err(error) => {
1026                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1027                    this.update(&mut cx, |this, cx| {
1028                        this.reconnect(cx).ok();
1029                    })?;
1030                }
1031            }
1032
1033            Ok(())
1034        })
1035    }
1036
1037    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1038        self.state.lock().as_ref().map_or(false, check)
1039    }
1040
1041    fn try_set_state(
1042        &self,
1043        cx: &mut ModelContext<Self>,
1044        map: impl FnOnce(&State) -> Option<State>,
1045    ) {
1046        let mut lock = self.state.lock();
1047        let new_state = lock.as_ref().and_then(map);
1048
1049        if let Some(new_state) = new_state {
1050            lock.replace(new_state);
1051            cx.notify();
1052        }
1053    }
1054
1055    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1056        log::info!("setting state to '{}'", &state);
1057        self.state.lock().replace(state);
1058        cx.notify();
1059    }
1060
1061    async fn establish_connection(
1062        unique_identifier: String,
1063        reconnect: bool,
1064        connection_options: SshConnectionOptions,
1065        delegate: Arc<dyn SshClientDelegate>,
1066        cx: &mut AsyncAppContext,
1067    ) -> Result<(SshRemoteConnection, Child)> {
1068        let ssh_connection =
1069            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1070
1071        let platform = ssh_connection.query_platform().await?;
1072        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1073        ssh_connection
1074            .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1075            .await?;
1076
1077        let socket = ssh_connection.socket.clone();
1078        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1079
1080        delegate.set_status(Some("Starting proxy"), cx);
1081
1082        let mut start_proxy_command = format!(
1083            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1084            std::env::var("RUST_LOG").unwrap_or_default(),
1085            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1086            remote_binary_path,
1087            unique_identifier,
1088        );
1089        if reconnect {
1090            start_proxy_command.push_str(" --reconnect");
1091        }
1092
1093        let ssh_proxy_process = socket
1094            .ssh_command(start_proxy_command)
1095            // IMPORTANT: we kill this process when we drop the task that uses it.
1096            .kill_on_drop(true)
1097            .spawn()
1098            .context("failed to spawn remote server")?;
1099
1100        Ok((ssh_connection, ssh_proxy_process))
1101    }
1102
1103    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1104        self.client.subscribe_to_entity(remote_id, entity);
1105    }
1106
1107    pub fn ssh_args(&self) -> Option<Vec<String>> {
1108        self.state
1109            .lock()
1110            .as_ref()
1111            .and_then(|state| state.ssh_connection())
1112            .map(|ssh_connection| ssh_connection.socket.ssh_args())
1113    }
1114
1115    pub fn proto_client(&self) -> AnyProtoClient {
1116        self.client.clone().into()
1117    }
1118
1119    pub fn connection_string(&self) -> String {
1120        self.connection_options.connection_string()
1121    }
1122
1123    pub fn connection_options(&self) -> SshConnectionOptions {
1124        self.connection_options.clone()
1125    }
1126
1127    #[cfg(not(any(test, feature = "test-support")))]
1128    pub fn connection_state(&self) -> ConnectionState {
1129        self.state
1130            .lock()
1131            .as_ref()
1132            .map(ConnectionState::from)
1133            .unwrap_or(ConnectionState::Disconnected)
1134    }
1135
1136    #[cfg(any(test, feature = "test-support"))]
1137    pub fn connection_state(&self) -> ConnectionState {
1138        ConnectionState::Connected
1139    }
1140
1141    pub fn is_disconnected(&self) -> bool {
1142        self.connection_state() == ConnectionState::Disconnected
1143    }
1144
1145    #[cfg(any(test, feature = "test-support"))]
1146    pub fn fake(
1147        client_cx: &mut gpui::TestAppContext,
1148        server_cx: &mut gpui::TestAppContext,
1149    ) -> (Model<Self>, Arc<ChannelClient>) {
1150        use gpui::Context;
1151
1152        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1153        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1154
1155        (
1156            client_cx.update(|cx| {
1157                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1158                cx.new_model(|_| Self {
1159                    client,
1160                    unique_identifier: "fake".to_string(),
1161                    connection_options: SshConnectionOptions::default(),
1162                    state: Arc::new(Mutex::new(None)),
1163                })
1164            }),
1165            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1166        )
1167    }
1168}
1169
1170impl From<SshRemoteClient> for AnyProtoClient {
1171    fn from(client: SshRemoteClient) -> Self {
1172        AnyProtoClient::new(client.client.clone())
1173    }
1174}
1175
1176struct SshRemoteConnection {
1177    socket: SshSocket,
1178    master_process: process::Child,
1179    _temp_dir: TempDir,
1180}
1181
1182impl Drop for SshRemoteConnection {
1183    fn drop(&mut self) {
1184        if let Err(error) = self.master_process.kill() {
1185            log::error!("failed to kill SSH master process: {}", error);
1186        }
1187    }
1188}
1189
1190impl SshRemoteConnection {
1191    #[cfg(not(unix))]
1192    async fn new(
1193        _connection_options: SshConnectionOptions,
1194        _delegate: Arc<dyn SshClientDelegate>,
1195        _cx: &mut AsyncAppContext,
1196    ) -> Result<Self> {
1197        Err(anyhow!("ssh is not supported on this platform"))
1198    }
1199
1200    #[cfg(unix)]
1201    async fn new(
1202        connection_options: SshConnectionOptions,
1203        delegate: Arc<dyn SshClientDelegate>,
1204        cx: &mut AsyncAppContext,
1205    ) -> Result<Self> {
1206        use futures::AsyncWriteExt as _;
1207        use futures::{io::BufReader, AsyncBufReadExt as _};
1208        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1209        use util::ResultExt as _;
1210
1211        delegate.set_status(Some("connecting"), cx);
1212
1213        let url = connection_options.ssh_url();
1214        let temp_dir = tempfile::Builder::new()
1215            .prefix("zed-ssh-session")
1216            .tempdir()?;
1217
1218        // Create a domain socket listener to handle requests from the askpass program.
1219        let askpass_socket = temp_dir.path().join("askpass.sock");
1220        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1221        let listener =
1222            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1223
1224        let askpass_task = cx.spawn({
1225            let delegate = delegate.clone();
1226            |mut cx| async move {
1227                let mut askpass_opened_tx = Some(askpass_opened_tx);
1228
1229                while let Ok((mut stream, _)) = listener.accept().await {
1230                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1231                        askpass_opened_tx.send(()).ok();
1232                    }
1233                    let mut buffer = Vec::new();
1234                    let mut reader = BufReader::new(&mut stream);
1235                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
1236                        buffer.clear();
1237                    }
1238                    let password_prompt = String::from_utf8_lossy(&buffer);
1239                    if let Some(password) = delegate
1240                        .ask_password(password_prompt.to_string(), &mut cx)
1241                        .await
1242                        .context("failed to get ssh password")
1243                        .and_then(|p| p)
1244                        .log_err()
1245                    {
1246                        stream.write_all(password.as_bytes()).await.log_err();
1247                    }
1248                }
1249            }
1250        });
1251
1252        // Create an askpass script that communicates back to this process.
1253        let askpass_script = format!(
1254            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1255            askpass_socket = askpass_socket.display(),
1256            print_args = "printf '%s\\0' \"$@\"",
1257            shebang = "#!/bin/sh",
1258        );
1259        let askpass_script_path = temp_dir.path().join("askpass.sh");
1260        fs::write(&askpass_script_path, askpass_script).await?;
1261        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1262
1263        // Start the master SSH process, which does not do anything except for establish
1264        // the connection and keep it open, allowing other ssh commands to reuse it
1265        // via a control socket.
1266        let socket_path = temp_dir.path().join("ssh.sock");
1267        let mut master_process = process::Command::new("ssh")
1268            .stdin(Stdio::null())
1269            .stdout(Stdio::piped())
1270            .stderr(Stdio::piped())
1271            .env("SSH_ASKPASS_REQUIRE", "force")
1272            .env("SSH_ASKPASS", &askpass_script_path)
1273            .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1274            .args([
1275                "-N",
1276                "-o",
1277                "ControlPersist=no",
1278                "-o",
1279                "ControlMaster=yes",
1280                "-o",
1281            ])
1282            .arg(format!("ControlPath={}", socket_path.display()))
1283            .arg(&url)
1284            .spawn()?;
1285
1286        // Wait for this ssh process to close its stdout, indicating that authentication
1287        // has completed.
1288        let stdout = master_process.stdout.as_mut().unwrap();
1289        let mut output = Vec::new();
1290        let connection_timeout = Duration::from_secs(10);
1291
1292        let result = select_biased! {
1293            _ = askpass_opened_rx.fuse() => {
1294                // If the askpass script has opened, that means the user is typing
1295                // their password, in which case we don't want to timeout anymore,
1296                // since we know a connection has been established.
1297                stdout.read_to_end(&mut output).await?;
1298                Ok(())
1299            }
1300            result = stdout.read_to_end(&mut output).fuse() => {
1301                result?;
1302                Ok(())
1303            }
1304            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1305                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1306            }
1307        };
1308
1309        if let Err(e) = result {
1310            let error_message = format!("Failed to connect to host: {}.", e);
1311            delegate.set_error(error_message, cx);
1312            return Err(e);
1313        }
1314
1315        drop(askpass_task);
1316
1317        if master_process.try_status()?.is_some() {
1318            output.clear();
1319            let mut stderr = master_process.stderr.take().unwrap();
1320            stderr.read_to_end(&mut output).await?;
1321
1322            let error_message = format!(
1323                "failed to connect: {}",
1324                String::from_utf8_lossy(&output).trim()
1325            );
1326            delegate.set_error(error_message.clone(), cx);
1327            Err(anyhow!(error_message))?;
1328        }
1329
1330        Ok(Self {
1331            socket: SshSocket {
1332                connection_options,
1333                socket_path,
1334            },
1335            master_process,
1336            _temp_dir: temp_dir,
1337        })
1338    }
1339
1340    async fn ensure_server_binary(
1341        &self,
1342        delegate: &Arc<dyn SshClientDelegate>,
1343        dst_path: &Path,
1344        platform: SshPlatform,
1345        cx: &mut AsyncAppContext,
1346    ) -> Result<()> {
1347        if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1348            if let Ok(installed_version) =
1349                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1350            {
1351                log::info!("using cached server binary version {}", installed_version);
1352                return Ok(());
1353            }
1354        }
1355
1356        let mut dst_path_gz = dst_path.to_path_buf();
1357        dst_path_gz.set_extension("gz");
1358
1359        if let Some(parent) = dst_path.parent() {
1360            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1361        }
1362
1363        let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1364
1365        let mut server_binary_exists = false;
1366        if !server_binary_exists && cfg!(not(debug_assertions)) {
1367            if let Ok(installed_version) =
1368                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1369            {
1370                if installed_version.trim() == version.to_string() {
1371                    server_binary_exists = true;
1372                }
1373            }
1374        }
1375
1376        if server_binary_exists {
1377            log::info!("remote development server already present",);
1378            return Ok(());
1379        }
1380
1381        let src_stat = fs::metadata(&src_path).await?;
1382        let size = src_stat.len();
1383        let server_mode = 0o755;
1384
1385        let t0 = Instant::now();
1386        delegate.set_status(Some("Uploading remote development server"), cx);
1387        log::info!("uploading remote development server ({}kb)", size / 1024);
1388        self.upload_file(&src_path, &dst_path_gz)
1389            .await
1390            .context("failed to upload server binary")?;
1391        log::info!("uploaded remote development server in {:?}", t0.elapsed());
1392
1393        delegate.set_status(Some("Extracting remote development server"), cx);
1394        run_cmd(
1395            self.socket
1396                .ssh_command("gunzip")
1397                .arg("--force")
1398                .arg(&dst_path_gz),
1399        )
1400        .await?;
1401
1402        delegate.set_status(Some("Marking remote development server executable"), cx);
1403        run_cmd(
1404            self.socket
1405                .ssh_command("chmod")
1406                .arg(format!("{:o}", server_mode))
1407                .arg(dst_path),
1408        )
1409        .await?;
1410
1411        Ok(())
1412    }
1413
1414    async fn query_platform(&self) -> Result<SshPlatform> {
1415        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1416        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1417
1418        let os = match os.trim() {
1419            "Darwin" => "macos",
1420            "Linux" => "linux",
1421            _ => Err(anyhow!("unknown uname os {os:?}"))?,
1422        };
1423        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1424            "aarch64"
1425        } else if arch.starts_with("x86") || arch.starts_with("i686") {
1426            "x86_64"
1427        } else {
1428            Err(anyhow!("unknown uname architecture {arch:?}"))?
1429        };
1430
1431        Ok(SshPlatform { os, arch })
1432    }
1433
1434    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1435        let mut command = process::Command::new("scp");
1436        let output = self
1437            .socket
1438            .ssh_options(&mut command)
1439            .args(
1440                self.socket
1441                    .connection_options
1442                    .port
1443                    .map(|port| vec!["-P".to_string(), port.to_string()])
1444                    .unwrap_or_default(),
1445            )
1446            .arg(src_path)
1447            .arg(format!(
1448                "{}:{}",
1449                self.socket.connection_options.scp_url(),
1450                dest_path.display()
1451            ))
1452            .output()
1453            .await?;
1454
1455        if output.status.success() {
1456            Ok(())
1457        } else {
1458            Err(anyhow!(
1459                "failed to upload file {} -> {}: {}",
1460                src_path.display(),
1461                dest_path.display(),
1462                String::from_utf8_lossy(&output.stderr)
1463            ))
1464        }
1465    }
1466}
1467
/// Requests awaiting a response, keyed by the id of the outgoing request. Each sender is given
/// the response envelope together with a oneshot sender; the incoming-message loop waits on
/// that second channel before it processes the next envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// A `ProtoClient` that exchanges protobuf `Envelope`s over a pair of `futures` mpsc channels.
/// Whatever holds the other ends (e.g. the ssh multiplexer or, in tests, another
/// `ChannelClient`) carries them to the peer.
pub struct ChannelClient {
    /// Id assigned to the next outgoing message.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Pending requests awaiting a response (behind a `parking_lot` mutex).
    response_channels: ResponseChannels,
    /// Registered message handlers and entity subscriptions (behind a `parking_lot` mutex).
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1476
1477impl ChannelClient {
1478    pub fn new(
1479        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1480        outgoing_tx: mpsc::UnboundedSender<Envelope>,
1481        cx: &AppContext,
1482    ) -> Arc<Self> {
1483        let this = Arc::new(Self {
1484            outgoing_tx,
1485            next_message_id: AtomicU32::new(0),
1486            response_channels: ResponseChannels::default(),
1487            message_handlers: Default::default(),
1488        });
1489
1490        Self::start_handling_messages(this.clone(), incoming_rx, cx);
1491
1492        this
1493    }
1494
1495    fn start_handling_messages(
1496        this: Arc<Self>,
1497        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1498        cx: &AppContext,
1499    ) {
1500        cx.spawn(|cx| {
1501            let this = Arc::downgrade(&this);
1502            async move {
1503                let peer_id = PeerId { owner_id: 0, id: 0 };
1504                while let Some(incoming) = incoming_rx.next().await {
1505                    let Some(this) = this.upgrade() else {
1506                        return anyhow::Ok(());
1507                    };
1508
1509                    if let Some(request_id) = incoming.responding_to {
1510                        let request_id = MessageId(request_id);
1511                        let sender = this.response_channels.lock().remove(&request_id);
1512                        if let Some(sender) = sender {
1513                            let (tx, rx) = oneshot::channel();
1514                            if incoming.payload.is_some() {
1515                                sender.send((incoming, tx)).ok();
1516                            }
1517                            rx.await.ok();
1518                        }
1519                    } else if let Some(envelope) =
1520                        build_typed_envelope(peer_id, Instant::now(), incoming)
1521                    {
1522                        let type_name = envelope.payload_type_name();
1523                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1524                            &this.message_handlers,
1525                            envelope,
1526                            this.clone().into(),
1527                            cx.clone(),
1528                        ) {
1529                            log::debug!("ssh message received. name:{type_name}");
1530                            match future.await {
1531                                Ok(_) => {
1532                                    log::debug!("ssh message handled. name:{type_name}");
1533                                }
1534                                Err(error) => {
1535                                    log::error!(
1536                                        "error handling message. type:{type_name}, error:{error}",
1537                                    );
1538                                }
1539                            }
1540                        } else {
1541                            log::error!("unhandled ssh message name:{type_name}");
1542                        }
1543                    }
1544                }
1545                anyhow::Ok(())
1546            }
1547        })
1548        .detach();
1549    }
1550
1551    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1552        let id = (TypeId::of::<E>(), remote_id);
1553
1554        let mut message_handlers = self.message_handlers.lock();
1555        if message_handlers
1556            .entities_by_type_and_remote_id
1557            .contains_key(&id)
1558        {
1559            panic!("already subscribed to entity");
1560        }
1561
1562        message_handlers.entities_by_type_and_remote_id.insert(
1563            id,
1564            EntityMessageSubscriber::Entity {
1565                handle: entity.downgrade().into(),
1566            },
1567        );
1568    }
1569
1570    pub fn request<T: RequestMessage>(
1571        &self,
1572        payload: T,
1573    ) -> impl 'static + Future<Output = Result<T::Response>> {
1574        log::debug!("ssh request start. name:{}", T::NAME);
1575        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1576        async move {
1577            let response = response.await?;
1578            log::debug!("ssh request finish. name:{}", T::NAME);
1579            T::Response::from_envelope(response)
1580                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1581        }
1582    }
1583
1584    pub async fn ping(&self, timeout: Duration) -> Result<()> {
1585        smol::future::or(
1586            async {
1587                self.request(proto::Ping {}).await?;
1588                Ok(())
1589            },
1590            async {
1591                smol::Timer::after(timeout).await;
1592                Err(anyhow!("Timeout detected"))
1593            },
1594        )
1595        .await
1596    }
1597
1598    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1599        log::debug!("ssh send name:{}", T::NAME);
1600        self.send_dynamic(payload.into_envelope(0, None, None))
1601    }
1602
1603    pub fn request_dynamic(
1604        &self,
1605        mut envelope: proto::Envelope,
1606        type_name: &'static str,
1607    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1608        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1609        let (tx, rx) = oneshot::channel();
1610        let mut response_channels_lock = self.response_channels.lock();
1611        response_channels_lock.insert(MessageId(envelope.id), tx);
1612        drop(response_channels_lock);
1613        let result = self.outgoing_tx.unbounded_send(envelope);
1614        async move {
1615            if let Err(error) = &result {
1616                log::error!("failed to send message: {}", error);
1617                return Err(anyhow!("failed to send message: {}", error));
1618            }
1619
1620            let response = rx.await.context("connection lost")?.0;
1621            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1622                return Err(RpcError::from_proto(error, type_name));
1623            }
1624            Ok(response)
1625        }
1626    }
1627
1628    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1629        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1630        self.outgoing_tx.unbounded_send(envelope)?;
1631        Ok(())
1632    }
1633}
1634
1635impl ProtoClient for ChannelClient {
1636    fn request(
1637        &self,
1638        envelope: proto::Envelope,
1639        request_type: &'static str,
1640    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1641        self.request_dynamic(envelope, request_type).boxed()
1642    }
1643
1644    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1645        self.send_dynamic(envelope)
1646    }
1647
1648    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1649        self.send_dynamic(envelope)
1650    }
1651
1652    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1653        &self.message_handlers
1654    }
1655
1656    fn is_via_collab(&self) -> bool {
1657        false
1658    }
1659}