ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6};
   7use anyhow::{anyhow, Context as _, Result};
   8use collections::HashMap;
   9use futures::{
  10    channel::{
  11        mpsc::{self, UnboundedReceiver, UnboundedSender},
  12        oneshot,
  13    },
  14    future::BoxFuture,
  15    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  16    StreamExt as _,
  17};
  18use gpui::{
  19    AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
  20};
  21use parking_lot::Mutex;
  22use rpc::{
  23    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  24    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  25};
  26use smol::{
  27    fs,
  28    process::{self, Child, Stdio},
  29    Timer,
  30};
  31use std::{
  32    any::TypeId,
  33    ffi::OsStr,
  34    mem,
  35    path::{Path, PathBuf},
  36    sync::{
  37        atomic::{AtomicU32, Ordering::SeqCst},
  38        Arc,
  39    },
  40    time::{Duration, Instant},
  41};
  42use tempfile::TempDir;
  43use util::maybe;
  44
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  47)]
  48pub struct SshProjectId(pub u64);
  49
  50#[derive(Clone)]
  51pub struct SshSocket {
  52    connection_options: SshConnectionOptions,
  53    socket_path: PathBuf,
  54}
  55
  56#[derive(Debug, Default, Clone, PartialEq, Eq)]
  57pub struct SshConnectionOptions {
  58    pub host: String,
  59    pub username: Option<String>,
  60    pub port: Option<u16>,
  61    pub password: Option<String>,
  62}
  63
  64impl SshConnectionOptions {
  65    pub fn ssh_url(&self) -> String {
  66        let mut result = String::from("ssh://");
  67        if let Some(username) = &self.username {
  68            result.push_str(username);
  69            result.push('@');
  70        }
  71        result.push_str(&self.host);
  72        if let Some(port) = self.port {
  73            result.push(':');
  74            result.push_str(&port.to_string());
  75        }
  76        result
  77    }
  78
  79    fn scp_url(&self) -> String {
  80        if let Some(username) = &self.username {
  81            format!("{}@{}", username, self.host)
  82        } else {
  83            self.host.clone()
  84        }
  85    }
  86
  87    pub fn connection_string(&self) -> String {
  88        let host = if let Some(username) = &self.username {
  89            format!("{}@{}", username, self.host)
  90        } else {
  91            self.host.clone()
  92        };
  93        if let Some(port) = &self.port {
  94            format!("{}:{}", host, port)
  95        } else {
  96            host
  97        }
  98    }
  99
 100    // Uniquely identifies dev server projects on a remote host. Needs to be
 101    // stable for the same dev server project.
 102    pub fn dev_server_identifier(&self) -> String {
 103        let mut identifier = format!("dev-server-{:?}", self.host);
 104        if let Some(username) = self.username.as_ref() {
 105            identifier.push('-');
 106            identifier.push_str(&username);
 107        }
 108        identifier
 109    }
 110}
 111
 112#[derive(Copy, Clone, Debug)]
 113pub struct SshPlatform {
 114    pub os: &'static str,
 115    pub arch: &'static str,
 116}
 117
 118pub trait SshClientDelegate: Send + Sync {
 119    fn ask_password(
 120        &self,
 121        prompt: String,
 122        cx: &mut AsyncAppContext,
 123    ) -> oneshot::Receiver<Result<String>>;
 124    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 125    fn get_server_binary(
 126        &self,
 127        platform: SshPlatform,
 128        cx: &mut AsyncAppContext,
 129    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 130    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 131    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 132}
 133
 134impl SshSocket {
 135    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 136        let mut command = process::Command::new("ssh");
 137        self.ssh_options(&mut command)
 138            .arg(self.connection_options.ssh_url())
 139            .arg(program);
 140        command
 141    }
 142
 143    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 144        command
 145            .stdin(Stdio::piped())
 146            .stdout(Stdio::piped())
 147            .stderr(Stdio::piped())
 148            .args(["-o", "ControlMaster=no", "-o"])
 149            .arg(format!("ControlPath={}", self.socket_path.display()))
 150    }
 151
 152    fn ssh_args(&self) -> Vec<String> {
 153        vec![
 154            "-o".to_string(),
 155            "ControlMaster=no".to_string(),
 156            "-o".to_string(),
 157            format!("ControlPath={}", self.socket_path.display()),
 158            self.connection_options.ssh_url(),
 159        ]
 160    }
 161}
 162
 163async fn run_cmd(command: &mut process::Command) -> Result<String> {
 164    let output = command.output().await?;
 165    if output.status.success() {
 166        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 167    } else {
 168        Err(anyhow!(
 169            "failed to run command: {}",
 170            String::from_utf8_lossy(&output.stderr)
 171        ))
 172    }
 173}
 174
 175struct ChannelForwarder {
 176    quit_tx: UnboundedSender<()>,
 177    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
 178}
 179
 180impl ChannelForwarder {
 181    fn new(
 182        mut incoming_tx: UnboundedSender<Envelope>,
 183        mut outgoing_rx: UnboundedReceiver<Envelope>,
 184        cx: &AsyncAppContext,
 185    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 186        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
 187
 188        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
 189        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
 190
 191        let forwarding_task = cx.background_executor().spawn(async move {
 192            loop {
 193                select_biased! {
 194                    _ = quit_rx.next().fuse() => {
 195                        break;
 196                    },
 197                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
 198                        if let Some(envelope) = incoming_envelope {
 199                            if incoming_tx.send(envelope).await.is_err() {
 200                                break;
 201                            }
 202                        } else {
 203                            break;
 204                        }
 205                    }
 206                    outgoing_envelope = outgoing_rx.next().fuse() => {
 207                        if let Some(envelope) = outgoing_envelope {
 208                            if proxy_outgoing_tx.send(envelope).await.is_err() {
 209                                break;
 210                            }
 211                        } else {
 212                            break;
 213                        }
 214                    }
 215                }
 216            }
 217
 218            (incoming_tx, outgoing_rx)
 219        });
 220
 221        (
 222            Self {
 223                forwarding_task,
 224                quit_tx,
 225            },
 226            proxy_incoming_tx,
 227            proxy_outgoing_rx,
 228        )
 229    }
 230
 231    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
 232        let _ = self.quit_tx.send(()).await;
 233        self.forwarding_task.await
 234    }
 235}
 236
 237struct SshRemoteClientState {
 238    ssh_connection: SshRemoteConnection,
 239    delegate: Arc<dyn SshClientDelegate>,
 240    forwarder: ChannelForwarder,
 241    multiplex_task: Task<Result<()>>,
 242    heartbeat_task: Task<Result<()>>,
 243}
 244
 245pub struct SshRemoteClient {
 246    client: Arc<ChannelClient>,
 247    unique_identifier: String,
 248    connection_options: SshConnectionOptions,
 249    inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
 250}
 251
 252impl Drop for SshRemoteClient {
 253    fn drop(&mut self) {
 254        self.shutdown_processes();
 255    }
 256}
 257
 258impl SshRemoteClient {
 259    pub fn new(
 260        unique_identifier: String,
 261        connection_options: SshConnectionOptions,
 262        delegate: Arc<dyn SshClientDelegate>,
 263        cx: &AppContext,
 264    ) -> Task<Result<Model<Self>>> {
 265        cx.spawn(|mut cx| async move {
 266            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 267            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 268
 269            let this = cx.new_model(|cx| {
 270                cx.on_app_quit(|this: &mut Self, _| {
 271                    this.shutdown_processes();
 272                    futures::future::ready(())
 273                })
 274                .detach();
 275
 276                let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
 277                Self {
 278                    client,
 279                    unique_identifier: unique_identifier.clone(),
 280                    connection_options: SshConnectionOptions::default(),
 281                    inner_state: Arc::new(Mutex::new(None)),
 282                }
 283            })?;
 284
 285            let inner_state = {
 286                let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 287                    ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 288
 289                let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 290                    unique_identifier,
 291                    connection_options,
 292                    delegate.clone(),
 293                    &mut cx,
 294                )
 295                .await?;
 296
 297                let multiplex_task = Self::multiplex(
 298                    this.downgrade(),
 299                    ssh_proxy_process,
 300                    proxy_incoming_tx,
 301                    proxy_outgoing_rx,
 302                    &mut cx,
 303                );
 304
 305                SshRemoteClientState {
 306                    ssh_connection,
 307                    delegate,
 308                    forwarder: proxy,
 309                    multiplex_task,
 310                    heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
 311                }
 312            };
 313
 314            this.update(&mut cx, |this, cx| {
 315                this.inner_state.lock().replace(inner_state);
 316                cx.notify();
 317            })?;
 318
 319            Ok(this)
 320        })
 321    }
 322
 323    fn shutdown_processes(&self) {
 324        let Some(mut state) = self.inner_state.lock().take() else {
 325            return;
 326        };
 327        log::info!("shutting down ssh processes");
 328        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 329        // child of master_process.
 330        let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
 331        drop(task);
 332        // Now drop the rest of state, which kills master process.
 333        drop(state);
 334    }
 335
 336    fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
 337        log::info!("Trying to reconnect to ssh server...");
 338        let Some(state) = self.inner_state.lock().take() else {
 339            return Err(anyhow!("reconnect is already in progress"));
 340        };
 341
 342        let workspace_identifier = self.unique_identifier.clone();
 343
 344        let SshRemoteClientState {
 345            mut ssh_connection,
 346            delegate,
 347            forwarder: proxy,
 348            multiplex_task,
 349            heartbeat_task,
 350        } = state;
 351        drop(multiplex_task);
 352        drop(heartbeat_task);
 353
 354        cx.spawn(|this, mut cx| async move {
 355            let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
 356
 357            ssh_connection.master_process.kill()?;
 358            ssh_connection
 359                .master_process
 360                .status()
 361                .await
 362                .context("Failed to kill ssh process")?;
 363
 364            let connection_options = ssh_connection.socket.connection_options.clone();
 365
 366            let (ssh_connection, ssh_process) = Self::establish_connection(
 367                workspace_identifier,
 368                connection_options,
 369                delegate.clone(),
 370                &mut cx,
 371            )
 372            .await?;
 373
 374            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 375                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 376
 377            let inner_state = SshRemoteClientState {
 378                ssh_connection,
 379                delegate,
 380                forwarder: proxy,
 381                multiplex_task: Self::multiplex(
 382                    this.clone(),
 383                    ssh_process,
 384                    proxy_incoming_tx,
 385                    proxy_outgoing_rx,
 386                    &mut cx,
 387                ),
 388                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
 389            };
 390
 391            this.update(&mut cx, |this, _| {
 392                this.inner_state.lock().replace(inner_state);
 393            })
 394        })
 395        .detach();
 396        Ok(())
 397    }
 398
 399    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
 400        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
 401            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
 402        };
 403        cx.spawn(|mut cx| {
 404            let this = this.clone();
 405            async move {
 406                const MAX_MISSED_HEARTBEATS: usize = 5;
 407                const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 408                const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
 409
 410                let mut missed_heartbeats = 0;
 411
 412                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
 413                loop {
 414                    timer.next().await;
 415
 416                    log::info!("Sending heartbeat to server...");
 417
 418                    let result = smol::future::or(
 419                        async {
 420                            client.request(proto::Ping {}).await?;
 421                            Ok(())
 422                        },
 423                        async {
 424                            smol::Timer::after(HEARTBEAT_TIMEOUT).await;
 425
 426                            Err(anyhow!("Timeout detected"))
 427                        },
 428                    )
 429                    .await;
 430
 431                    if result.is_err() {
 432                        missed_heartbeats += 1;
 433                        log::warn!(
 434                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
 435                            HEARTBEAT_TIMEOUT,
 436                            missed_heartbeats,
 437                            MAX_MISSED_HEARTBEATS
 438                        );
 439                    } else {
 440                        missed_heartbeats = 0;
 441                    }
 442
 443                    if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
 444                        log::error!(
 445                            "Missed last {} hearbeats. Reconnecting...",
 446                            missed_heartbeats
 447                        );
 448
 449                        this.update(&mut cx, |this, cx| {
 450                            this.reconnect(cx)
 451                                .context("failed to reconnect after missing heartbeats")
 452                        })
 453                        .context("failed to update weak reference, SshRemoteClient lost?")??;
 454                        return Ok(());
 455                    }
 456                }
 457            }
 458        })
 459    }
 460
 461    fn multiplex(
 462        this: WeakModel<Self>,
 463        mut ssh_proxy_process: Child,
 464        incoming_tx: UnboundedSender<Envelope>,
 465        mut outgoing_rx: UnboundedReceiver<Envelope>,
 466        cx: &AsyncAppContext,
 467    ) -> Task<Result<()>> {
 468        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 469        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 470        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 471
 472        let io_task = cx.background_executor().spawn(async move {
 473            let mut stdin_buffer = Vec::new();
 474            let mut stdout_buffer = Vec::new();
 475            let mut stderr_buffer = Vec::new();
 476            let mut stderr_offset = 0;
 477
 478            loop {
 479                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 480                stderr_buffer.resize(stderr_offset + 1024, 0);
 481
 482                select_biased! {
 483                    outgoing = outgoing_rx.next().fuse() => {
 484                        let Some(outgoing) = outgoing else {
 485                            return anyhow::Ok(());
 486                        };
 487
 488                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 489                    }
 490
 491                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
 492                        match result {
 493                            Ok(0) => {
 494                                child_stdin.close().await?;
 495                                outgoing_rx.close();
 496                                let status = ssh_proxy_process.status().await?;
 497                                if !status.success() {
 498                                    log::error!("ssh process exited with status: {status:?}");
 499                                    return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
 500                                }
 501                                return Ok(());
 502                            }
 503                            Ok(len) => {
 504                                if len < stdout_buffer.len() {
 505                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 506                                }
 507
 508                                let message_len = message_len_from_buffer(&stdout_buffer);
 509                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
 510                                    Ok(envelope) => {
 511                                        incoming_tx.unbounded_send(envelope).ok();
 512                                    }
 513                                    Err(error) => {
 514                                        log::error!("error decoding message {error:?}");
 515                                    }
 516                                }
 517                            }
 518                            Err(error) => {
 519                                Err(anyhow!("error reading stdout: {error:?}"))?;
 520                            }
 521                        }
 522                    }
 523
 524                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
 525                        match result {
 526                            Ok(len) => {
 527                                stderr_offset += len;
 528                                let mut start_ix = 0;
 529                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
 530                                    let line_ix = start_ix + ix;
 531                                    let content = &stderr_buffer[start_ix..line_ix];
 532                                    start_ix = line_ix + 1;
 533                                    if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
 534                                        record.message = format!("(remote) {}", record.message);
 535                                        record.log(log::logger())
 536                                    } else {
 537                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
 538                                    }
 539                                }
 540                                stderr_buffer.drain(0..start_ix);
 541                                stderr_offset -= start_ix;
 542                            }
 543                            Err(error) => {
 544                                Err(anyhow!("error reading stderr: {error:?}"))?;
 545                            }
 546                        }
 547                    }
 548                }
 549            }
 550        });
 551
 552        cx.spawn(|mut cx| async move {
 553            let result = io_task.await;
 554
 555            if let Err(error) = result {
 556                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
 557                this.update(&mut cx, |this, cx| {
 558                    this.reconnect(cx).ok();
 559                })?;
 560            }
 561
 562            Ok(())
 563        })
 564    }
 565
 566    async fn establish_connection(
 567        unique_identifier: String,
 568        connection_options: SshConnectionOptions,
 569        delegate: Arc<dyn SshClientDelegate>,
 570        cx: &mut AsyncAppContext,
 571    ) -> Result<(SshRemoteConnection, Child)> {
 572        let ssh_connection =
 573            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 574
 575        let platform = ssh_connection.query_platform().await?;
 576        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 577        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 578        ssh_connection
 579            .ensure_server_binary(
 580                &delegate,
 581                &local_binary_path,
 582                &remote_binary_path,
 583                version,
 584                cx,
 585            )
 586            .await?;
 587
 588        let socket = ssh_connection.socket.clone();
 589        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 590
 591        delegate.set_status(Some("Starting proxy"), cx);
 592
 593        let ssh_proxy_process = socket
 594            .ssh_command(format!(
 595                "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 596                std::env::var("RUST_LOG").unwrap_or_default(),
 597                std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 598                remote_binary_path,
 599                unique_identifier,
 600            ))
 601            // IMPORTANT: we kill this process when we drop the task that uses it.
 602            .kill_on_drop(true)
 603            .spawn()
 604            .context("failed to spawn remote server")?;
 605
 606        Ok((ssh_connection, ssh_proxy_process))
 607    }
 608
 609    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 610        self.client.subscribe_to_entity(remote_id, entity);
 611    }
 612
 613    pub fn ssh_args(&self) -> Option<Vec<String>> {
 614        let state = self.inner_state.lock();
 615        state
 616            .as_ref()
 617            .map(|state| state.ssh_connection.socket.ssh_args())
 618    }
 619
 620    pub fn to_proto_client(&self) -> AnyProtoClient {
 621        self.client.clone().into()
 622    }
 623
 624    pub fn connection_string(&self) -> String {
 625        self.connection_options.connection_string()
 626    }
 627
 628    pub fn is_reconnect_underway(&self) -> bool {
 629        maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
 630    }
 631
 632    #[cfg(any(test, feature = "test-support"))]
 633    pub fn fake(
 634        client_cx: &mut gpui::TestAppContext,
 635        server_cx: &mut gpui::TestAppContext,
 636    ) -> (Model<Self>, Arc<ChannelClient>) {
 637        use gpui::Context;
 638
 639        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
 640        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
 641
 642        (
 643            client_cx.update(|cx| {
 644                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
 645                cx.new_model(|_| Self {
 646                    client,
 647                    unique_identifier: "fake".to_string(),
 648                    connection_options: SshConnectionOptions::default(),
 649                    inner_state: Arc::new(Mutex::new(None)),
 650                })
 651            }),
 652            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
 653        )
 654    }
 655}
 656
 657impl From<SshRemoteClient> for AnyProtoClient {
 658    fn from(client: SshRemoteClient) -> Self {
 659        AnyProtoClient::new(client.client.clone())
 660    }
 661}
 662
 663struct SshRemoteConnection {
 664    socket: SshSocket,
 665    master_process: process::Child,
 666    _temp_dir: TempDir,
 667}
 668
 669impl Drop for SshRemoteConnection {
 670    fn drop(&mut self) {
 671        if let Err(error) = self.master_process.kill() {
 672            log::error!("failed to kill SSH master process: {}", error);
 673        }
 674    }
 675}
 676
 677impl SshRemoteConnection {
 678    #[cfg(not(unix))]
 679    async fn new(
 680        _connection_options: SshConnectionOptions,
 681        _delegate: Arc<dyn SshClientDelegate>,
 682        _cx: &mut AsyncAppContext,
 683    ) -> Result<Self> {
 684        Err(anyhow!("ssh is not supported on this platform"))
 685    }
 686
 687    #[cfg(unix)]
 688    async fn new(
 689        connection_options: SshConnectionOptions,
 690        delegate: Arc<dyn SshClientDelegate>,
 691        cx: &mut AsyncAppContext,
 692    ) -> Result<Self> {
 693        use futures::{io::BufReader, AsyncBufReadExt as _};
 694        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
 695        use util::ResultExt as _;
 696
 697        delegate.set_status(Some("connecting"), cx);
 698
 699        let url = connection_options.ssh_url();
 700        let temp_dir = tempfile::Builder::new()
 701            .prefix("zed-ssh-session")
 702            .tempdir()?;
 703
 704        // Create a domain socket listener to handle requests from the askpass program.
 705        let askpass_socket = temp_dir.path().join("askpass.sock");
 706        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 707        let listener =
 708            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 709
 710        let askpass_task = cx.spawn({
 711            let delegate = delegate.clone();
 712            |mut cx| async move {
 713                let mut askpass_opened_tx = Some(askpass_opened_tx);
 714
 715                while let Ok((mut stream, _)) = listener.accept().await {
 716                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
 717                        askpass_opened_tx.send(()).ok();
 718                    }
 719                    let mut buffer = Vec::new();
 720                    let mut reader = BufReader::new(&mut stream);
 721                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
 722                        buffer.clear();
 723                    }
 724                    let password_prompt = String::from_utf8_lossy(&buffer);
 725                    if let Some(password) = delegate
 726                        .ask_password(password_prompt.to_string(), &mut cx)
 727                        .await
 728                        .context("failed to get ssh password")
 729                        .and_then(|p| p)
 730                        .log_err()
 731                    {
 732                        stream.write_all(password.as_bytes()).await.log_err();
 733                    }
 734                }
 735            }
 736        });
 737
 738        // Create an askpass script that communicates back to this process.
 739        let askpass_script = format!(
 740            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
 741            askpass_socket = askpass_socket.display(),
 742            print_args = "printf '%s\\0' \"$@\"",
 743            shebang = "#!/bin/sh",
 744        );
 745        let askpass_script_path = temp_dir.path().join("askpass.sh");
 746        fs::write(&askpass_script_path, askpass_script).await?;
 747        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
 748
 749        // Start the master SSH process, which does not do anything except for establish
 750        // the connection and keep it open, allowing other ssh commands to reuse it
 751        // via a control socket.
 752        let socket_path = temp_dir.path().join("ssh.sock");
 753        let mut master_process = process::Command::new("ssh")
 754            .stdin(Stdio::null())
 755            .stdout(Stdio::piped())
 756            .stderr(Stdio::piped())
 757            .env("SSH_ASKPASS_REQUIRE", "force")
 758            .env("SSH_ASKPASS", &askpass_script_path)
 759            .args(["-N", "-o", "ControlMaster=yes", "-o"])
 760            .arg(format!("ControlPath={}", socket_path.display()))
 761            .arg(&url)
 762            .spawn()?;
 763
 764        // Wait for this ssh process to close its stdout, indicating that authentication
 765        // has completed.
 766        let stdout = master_process.stdout.as_mut().unwrap();
 767        let mut output = Vec::new();
 768        let connection_timeout = Duration::from_secs(10);
 769
 770        let result = select_biased! {
 771            _ = askpass_opened_rx.fuse() => {
 772                // If the askpass script has opened, that means the user is typing
 773                // their password, in which case we don't want to timeout anymore,
 774                // since we know a connection has been established.
 775                stdout.read_to_end(&mut output).await?;
 776                Ok(())
 777            }
 778            result = stdout.read_to_end(&mut output).fuse() => {
 779                result?;
 780                Ok(())
 781            }
 782            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
 783                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
 784            }
 785        };
 786
 787        if let Err(e) = result {
 788            let error_message = format!("Failed to connect to host: {}.", e);
 789            delegate.set_error(error_message, cx);
 790            return Err(e);
 791        }
 792
 793        drop(askpass_task);
 794
 795        if master_process.try_status()?.is_some() {
 796            output.clear();
 797            let mut stderr = master_process.stderr.take().unwrap();
 798            stderr.read_to_end(&mut output).await?;
 799
 800            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
 801            delegate.set_error(error_message.clone(), cx);
 802            Err(anyhow!(error_message))?;
 803        }
 804
 805        Ok(Self {
 806            socket: SshSocket {
 807                connection_options,
 808                socket_path,
 809            },
 810            master_process,
 811            _temp_dir: temp_dir,
 812        })
 813    }
 814
 815    async fn ensure_server_binary(
 816        &self,
 817        delegate: &Arc<dyn SshClientDelegate>,
 818        src_path: &Path,
 819        dst_path: &Path,
 820        version: SemanticVersion,
 821        cx: &mut AsyncAppContext,
 822    ) -> Result<()> {
 823        let mut dst_path_gz = dst_path.to_path_buf();
 824        dst_path_gz.set_extension("gz");
 825
 826        if let Some(parent) = dst_path.parent() {
 827            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
 828        }
 829
 830        let mut server_binary_exists = false;
 831        if cfg!(not(debug_assertions)) {
 832            if let Ok(installed_version) =
 833                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
 834            {
 835                if installed_version.trim() == version.to_string() {
 836                    server_binary_exists = true;
 837                }
 838            }
 839        }
 840
 841        if server_binary_exists {
 842            log::info!("remote development server already present",);
 843            return Ok(());
 844        }
 845
 846        let src_stat = fs::metadata(src_path).await?;
 847        let size = src_stat.len();
 848        let server_mode = 0o755;
 849
 850        let t0 = Instant::now();
 851        delegate.set_status(Some("uploading remote development server"), cx);
 852        log::info!("uploading remote development server ({}kb)", size / 1024);
 853        self.upload_file(src_path, &dst_path_gz)
 854            .await
 855            .context("failed to upload server binary")?;
 856        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 857
 858        delegate.set_status(Some("extracting remote development server"), cx);
 859        run_cmd(
 860            self.socket
 861                .ssh_command("gunzip")
 862                .arg("--force")
 863                .arg(&dst_path_gz),
 864        )
 865        .await?;
 866
 867        delegate.set_status(Some("unzipping remote development server"), cx);
 868        run_cmd(
 869            self.socket
 870                .ssh_command("chmod")
 871                .arg(format!("{:o}", server_mode))
 872                .arg(dst_path),
 873        )
 874        .await?;
 875
 876        Ok(())
 877    }
 878
 879    async fn query_platform(&self) -> Result<SshPlatform> {
 880        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
 881        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
 882
 883        let os = match os.trim() {
 884            "Darwin" => "macos",
 885            "Linux" => "linux",
 886            _ => Err(anyhow!("unknown uname os {os:?}"))?,
 887        };
 888        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
 889            "aarch64"
 890        } else if arch.starts_with("x86") || arch.starts_with("i686") {
 891            "x86_64"
 892        } else {
 893            Err(anyhow!("unknown uname architecture {arch:?}"))?
 894        };
 895
 896        Ok(SshPlatform { os, arch })
 897    }
 898
 899    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
 900        let mut command = process::Command::new("scp");
 901        let output = self
 902            .socket
 903            .ssh_options(&mut command)
 904            .args(
 905                self.socket
 906                    .connection_options
 907                    .port
 908                    .map(|port| vec!["-P".to_string(), port.to_string()])
 909                    .unwrap_or_default(),
 910            )
 911            .arg(src_path)
 912            .arg(format!(
 913                "{}:{}",
 914                self.socket.connection_options.scp_url(),
 915                dest_path.display()
 916            ))
 917            .output()
 918            .await?;
 919
 920        if output.status.success() {
 921            Ok(())
 922        } else {
 923            Err(anyhow!(
 924                "failed to upload file {} -> {}: {}",
 925                src_path.display(),
 926                dest_path.display(),
 927                String::from_utf8_lossy(&output.stderr)
 928            ))
 929        }
 930    }
 931}
 932
/// Pending requests, keyed by the request's message id.
///
/// Each entry delivers the response envelope together with a completion sender. The
/// incoming-message loop waits for that completion sender to be dropped (or signalled) before
/// handling the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 934
/// An RPC client that exchanges proto `Envelope`s over a pair of unbounded channels.
///
/// Requests are assigned a fresh message id and matched to their responses through
/// `response_channels`. Other incoming messages are dispatched to `message_handlers`.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes; starts at 0 and is incremented per message.
    next_message_id: AtomicU32,
    /// Outgoing envelopes are pushed here (presumably forwarded to the remote side by whoever
    /// owns the receiving end).
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests still waiting for a response; guarded by a mutex.
    response_channels: ResponseChannels,
    /// Registered handlers and entity subscriptions for incoming messages; guarded by a mutex.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
 941
 942impl ChannelClient {
 943    pub fn new(
 944        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 945        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 946        cx: &AppContext,
 947    ) -> Arc<Self> {
 948        let this = Arc::new(Self {
 949            outgoing_tx,
 950            next_message_id: AtomicU32::new(0),
 951            response_channels: ResponseChannels::default(),
 952            message_handlers: Default::default(),
 953        });
 954
 955        Self::start_handling_messages(this.clone(), incoming_rx, cx);
 956
 957        this
 958    }
 959
 960    fn start_handling_messages(
 961        this: Arc<Self>,
 962        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 963        cx: &AppContext,
 964    ) {
 965        cx.spawn(|cx| {
 966            let this = Arc::downgrade(&this);
 967            async move {
 968                let peer_id = PeerId { owner_id: 0, id: 0 };
 969                while let Some(incoming) = incoming_rx.next().await {
 970                    let Some(this) = this.upgrade() else {
 971                        return anyhow::Ok(());
 972                    };
 973
 974                    if let Some(request_id) = incoming.responding_to {
 975                        let request_id = MessageId(request_id);
 976                        let sender = this.response_channels.lock().remove(&request_id);
 977                        if let Some(sender) = sender {
 978                            let (tx, rx) = oneshot::channel();
 979                            if incoming.payload.is_some() {
 980                                sender.send((incoming, tx)).ok();
 981                            }
 982                            rx.await.ok();
 983                        }
 984                    } else if let Some(envelope) =
 985                        build_typed_envelope(peer_id, Instant::now(), incoming)
 986                    {
 987                        let type_name = envelope.payload_type_name();
 988                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
 989                            &this.message_handlers,
 990                            envelope,
 991                            this.clone().into(),
 992                            cx.clone(),
 993                        ) {
 994                            log::debug!("ssh message received. name:{type_name}");
 995                            match future.await {
 996                                Ok(_) => {
 997                                    log::debug!("ssh message handled. name:{type_name}");
 998                                }
 999                                Err(error) => {
1000                                    log::error!(
1001                                        "error handling message. type:{type_name}, error:{error}",
1002                                    );
1003                                }
1004                            }
1005                        } else {
1006                            log::error!("unhandled ssh message name:{type_name}");
1007                        }
1008                    }
1009                }
1010                anyhow::Ok(())
1011            }
1012        })
1013        .detach();
1014    }
1015
1016    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1017        let id = (TypeId::of::<E>(), remote_id);
1018
1019        let mut message_handlers = self.message_handlers.lock();
1020        if message_handlers
1021            .entities_by_type_and_remote_id
1022            .contains_key(&id)
1023        {
1024            panic!("already subscribed to entity");
1025        }
1026
1027        message_handlers.entities_by_type_and_remote_id.insert(
1028            id,
1029            EntityMessageSubscriber::Entity {
1030                handle: entity.downgrade().into(),
1031            },
1032        );
1033    }
1034
1035    pub fn request<T: RequestMessage>(
1036        &self,
1037        payload: T,
1038    ) -> impl 'static + Future<Output = Result<T::Response>> {
1039        log::debug!("ssh request start. name:{}", T::NAME);
1040        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1041        async move {
1042            let response = response.await?;
1043            log::debug!("ssh request finish. name:{}", T::NAME);
1044            T::Response::from_envelope(response)
1045                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1046        }
1047    }
1048
1049    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1050        log::debug!("ssh send name:{}", T::NAME);
1051        self.send_dynamic(payload.into_envelope(0, None, None))
1052    }
1053
1054    pub fn request_dynamic(
1055        &self,
1056        mut envelope: proto::Envelope,
1057        type_name: &'static str,
1058    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1059        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1060        let (tx, rx) = oneshot::channel();
1061        let mut response_channels_lock = self.response_channels.lock();
1062        response_channels_lock.insert(MessageId(envelope.id), tx);
1063        drop(response_channels_lock);
1064        let result = self.outgoing_tx.unbounded_send(envelope);
1065        async move {
1066            if let Err(error) = &result {
1067                log::error!("failed to send message: {}", error);
1068                return Err(anyhow!("failed to send message: {}", error));
1069            }
1070
1071            let response = rx.await.context("connection lost")?.0;
1072            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1073                return Err(RpcError::from_proto(error, type_name));
1074            }
1075            Ok(response)
1076        }
1077    }
1078
1079    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1080        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1081        self.outgoing_tx.unbounded_send(envelope)?;
1082        Ok(())
1083    }
1084}
1085
1086impl ProtoClient for ChannelClient {
1087    fn request(
1088        &self,
1089        envelope: proto::Envelope,
1090        request_type: &'static str,
1091    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1092        self.request_dynamic(envelope, request_type).boxed()
1093    }
1094
1095    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1096        self.send_dynamic(envelope)
1097    }
1098
1099    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1100        self.send_dynamic(envelope)
1101    }
1102
1103    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1104        &self.message_handlers
1105    }
1106
1107    fn is_via_collab(&self) -> bool {
1108        false
1109    }
1110}