ssh_session.rs

   1use crate::{
   2    json_log::LogRecord,
   3    protocol::{
   4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
   5    },
   6};
   7use anyhow::{anyhow, Context as _, Result};
   8use collections::HashMap;
   9use futures::{
  10    channel::{
  11        mpsc::{self, UnboundedReceiver, UnboundedSender},
  12        oneshot,
  13    },
  14    future::BoxFuture,
  15    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
  16    StreamExt as _,
  17};
  18use gpui::{
  19    AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
  20};
  21use parking_lot::Mutex;
  22use rpc::{
  23    proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
  24    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
  25};
  26use smol::{
  27    fs,
  28    process::{self, Child, Stdio},
  29    Timer,
  30};
  31use std::{
  32    any::TypeId,
  33    ffi::OsStr,
  34    mem,
  35    path::{Path, PathBuf},
  36    sync::{
  37        atomic::{AtomicU32, Ordering::SeqCst},
  38        Arc,
  39    },
  40    time::{Duration, Instant},
  41};
  42use tempfile::TempDir;
  43use util::maybe;
  44
  45#[derive(
  46    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  47)]
  48pub struct SshProjectId(pub u64);
  49
  50#[derive(Clone)]
  51pub struct SshSocket {
  52    connection_options: SshConnectionOptions,
  53    socket_path: PathBuf,
  54}
  55
  56#[derive(Debug, Default, Clone, PartialEq, Eq)]
  57pub struct SshConnectionOptions {
  58    pub host: String,
  59    pub username: Option<String>,
  60    pub port: Option<u16>,
  61    pub password: Option<String>,
  62}
  63
  64impl SshConnectionOptions {
  65    pub fn ssh_url(&self) -> String {
  66        let mut result = String::from("ssh://");
  67        if let Some(username) = &self.username {
  68            result.push_str(username);
  69            result.push('@');
  70        }
  71        result.push_str(&self.host);
  72        if let Some(port) = self.port {
  73            result.push(':');
  74            result.push_str(&port.to_string());
  75        }
  76        result
  77    }
  78
  79    fn scp_url(&self) -> String {
  80        if let Some(username) = &self.username {
  81            format!("{}@{}", username, self.host)
  82        } else {
  83            self.host.clone()
  84        }
  85    }
  86
  87    pub fn connection_string(&self) -> String {
  88        let host = if let Some(username) = &self.username {
  89            format!("{}@{}", username, self.host)
  90        } else {
  91            self.host.clone()
  92        };
  93        if let Some(port) = &self.port {
  94            format!("{}:{}", host, port)
  95        } else {
  96            host
  97        }
  98    }
  99
 100    // Uniquely identifies dev server projects on a remote host. Needs to be
 101    // stable for the same dev server project.
 102    pub fn dev_server_identifier(&self) -> String {
 103        let mut identifier = format!("dev-server-{:?}", self.host);
 104        if let Some(username) = self.username.as_ref() {
 105            identifier.push('-');
 106            identifier.push_str(&username);
 107        }
 108        identifier
 109    }
 110}
 111
 112#[derive(Copy, Clone, Debug)]
 113pub struct SshPlatform {
 114    pub os: &'static str,
 115    pub arch: &'static str,
 116}
 117
/// Callbacks the SSH client uses to obtain credentials and binaries and to report
/// progress while connecting.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password for `prompt`, the text the askpass helper received from `ssh`.
    /// When the receiver yields `Ok(password)`, the password is written back to `ssh`.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Path on the remote host where the server binary is installed and run from.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Resolves a local server binary for `platform` together with its version. The
    /// binary is uploaded when the remote copy is missing or reports a different version.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports connection progress text (e.g. "connecting"). Presumably `None` clears
    /// it — confirm against implementors.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports a connection error message (e.g. an authentication timeout).
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
 133
 134impl SshSocket {
 135    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
 136        let mut command = process::Command::new("ssh");
 137        self.ssh_options(&mut command)
 138            .arg(self.connection_options.ssh_url())
 139            .arg(program);
 140        command
 141    }
 142
 143    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
 144        command
 145            .stdin(Stdio::piped())
 146            .stdout(Stdio::piped())
 147            .stderr(Stdio::piped())
 148            .args(["-o", "ControlMaster=no", "-o"])
 149            .arg(format!("ControlPath={}", self.socket_path.display()))
 150    }
 151
 152    fn ssh_args(&self) -> Vec<String> {
 153        vec![
 154            "-o".to_string(),
 155            "ControlMaster=no".to_string(),
 156            "-o".to_string(),
 157            format!("ControlPath={}", self.socket_path.display()),
 158            self.connection_options.ssh_url(),
 159        ]
 160    }
 161}
 162
 163async fn run_cmd(command: &mut process::Command) -> Result<String> {
 164    let output = command.output().await?;
 165    if output.status.success() {
 166        Ok(String::from_utf8_lossy(&output.stdout).to_string())
 167    } else {
 168        Err(anyhow!(
 169            "failed to run command: {}",
 170            String::from_utf8_lossy(&output.stderr)
 171        ))
 172    }
 173}
 174#[cfg(unix)]
 175async fn read_with_timeout(
 176    stdout: &mut process::ChildStdout,
 177    timeout: Duration,
 178    output: &mut Vec<u8>,
 179) -> Result<(), std::io::Error> {
 180    smol::future::or(
 181        async {
 182            stdout.read_to_end(output).await?;
 183            Ok::<_, std::io::Error>(())
 184        },
 185        async {
 186            smol::Timer::after(timeout).await;
 187
 188            Err(std::io::Error::new(
 189                std::io::ErrorKind::TimedOut,
 190                "Read operation timed out",
 191            ))
 192        },
 193    )
 194    .await
 195}
 196
/// Relays envelopes between the client's long-lived channel ends and a per-connection
/// pair of proxy channels.
///
/// This lets the long-lived ends be detached and re-wired to a new connection on
/// reconnect.
struct ChannelForwarder {
    // Tells the forwarding task to stop; it is the first branch of its `select_biased!`.
    quit_tx: UnboundedSender<()>,
    // Resolves to the original `(incoming_tx, outgoing_rx)` once the forwarding loop exits.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}

impl ChannelForwarder {
    /// Spawns a background task that forwards `proxy_incoming_rx -> incoming_tx` and
    /// `outgoing_rx -> proxy_outgoing_tx`.
    ///
    /// Returns the forwarder plus the proxy ends to hand to the connection: a sender for
    /// incoming envelopes and a receiver of outgoing ones.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // `select_biased!` polls branches in order, so a quit request takes
                // priority over pending traffic. Envelopes not yet pulled from
                // `outgoing_rx` stay queued in it. The loop also ends when a source
                // channel closes or a forward fails.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // Hand the long-lived ends back so they can be reused for a new connection.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops the forwarding task and returns the original `(incoming_tx, outgoing_rx)`.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // A failed send means the task already exited; awaiting it still yields the ends.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
 258
/// Everything tied to one live SSH connection. It is replaced wholesale on reconnect.
///
/// NOTE: fields drop in declaration order. A plain drop therefore drops `ssh_connection`
/// (killing the master process) before `multiplex_task`. `shutdown_processes` drops
/// `multiplex_task` explicitly first for that reason.
struct SshRemoteClientState {
    ssh_connection: SshRemoteConnection,
    delegate: Arc<dyn SshClientDelegate>,
    // Connects the client's long-lived channels to this connection's proxy channels.
    forwarder: ChannelForwarder,
    // Pumps envelopes over the proxy process's stdio; the task owns that process.
    multiplex_task: Task<Result<()>>,
    // Periodic ping; calls `reconnect` after too many missed replies.
    heartbeat_task: Task<Result<()>>,
}

/// Client side of a remote-development session over SSH.
///
/// `client` is kept across reconnects. `inner_state` is `None` in three cases: until the
/// first connection is stored, while a reconnect is in progress, and after shutdown.
pub struct SshRemoteClient {
    client: Arc<ChannelClient>,
    // Passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
}

impl Drop for SshRemoteClient {
    /// Tears down the SSH processes when the client goes away.
    fn drop(&mut self) {
        self.shutdown_processes();
    }
}
 279
 280impl SshRemoteClient {
 281    pub fn new(
 282        unique_identifier: String,
 283        connection_options: SshConnectionOptions,
 284        delegate: Arc<dyn SshClientDelegate>,
 285        cx: &AppContext,
 286    ) -> Task<Result<Model<Self>>> {
 287        cx.spawn(|mut cx| async move {
 288            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
 289            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 290
 291            let this = cx.new_model(|cx| {
 292                cx.on_app_quit(|this: &mut Self, _| {
 293                    this.shutdown_processes();
 294                    futures::future::ready(())
 295                })
 296                .detach();
 297
 298                let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
 299                Self {
 300                    client,
 301                    unique_identifier: unique_identifier.clone(),
 302                    connection_options: SshConnectionOptions::default(),
 303                    inner_state: Arc::new(Mutex::new(None)),
 304                }
 305            })?;
 306
 307            let inner_state = {
 308                let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
 309                    ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
 310
 311                let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
 312                    unique_identifier,
 313                    connection_options,
 314                    delegate.clone(),
 315                    &mut cx,
 316                )
 317                .await?;
 318
 319                let multiplex_task = Self::multiplex(
 320                    this.downgrade(),
 321                    ssh_proxy_process,
 322                    proxy_incoming_tx,
 323                    proxy_outgoing_rx,
 324                    &mut cx,
 325                );
 326
 327                SshRemoteClientState {
 328                    ssh_connection,
 329                    delegate,
 330                    forwarder: proxy,
 331                    multiplex_task,
 332                    heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
 333                }
 334            };
 335
 336            this.update(&mut cx, |this, cx| {
 337                this.inner_state.lock().replace(inner_state);
 338                cx.notify();
 339            })?;
 340
 341            Ok(this)
 342        })
 343    }
 344
 345    fn shutdown_processes(&self) {
 346        let Some(mut state) = self.inner_state.lock().take() else {
 347            return;
 348        };
 349        log::info!("shutting down ssh processes");
 350        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
 351        // child of master_process.
 352        let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
 353        drop(task);
 354        // Now drop the rest of state, which kills master process.
 355        drop(state);
 356    }
 357
    /// Replaces the current connection with a freshly established one.
    ///
    /// This call synchronously takes `inner_state`, so concurrent callers get an error,
    /// and stops the old multiplex and heartbeat tasks. The rest runs in a detached task:
    /// - recover the long-lived channels from the old forwarder;
    /// - kill and reap the old master process;
    /// - establish a new connection;
    /// - store the new state.
    ///
    /// # Errors
    /// Returns an error immediately if `inner_state` is `None`. That means a reconnect is
    /// already running, no connection has been stored yet, or the client was shut down.
    ///
    /// NOTE(review): the detached task's `Result` is discarded. If any step fails,
    /// `inner_state` stays `None`. `is_reconnect_underway` then keeps returning `true`,
    /// and later `reconnect` calls fail with "reconnect is already in progress".
    fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
        log::info!("Trying to reconnect to ssh server...");
        let Some(state) = self.inner_state.lock().take() else {
            return Err(anyhow!("reconnect is already in progress"));
        };

        let workspace_identifier = self.unique_identifier.clone();

        let SshRemoteClientState {
            mut ssh_connection,
            delegate,
            forwarder: proxy,
            multiplex_task,
            heartbeat_task,
        } = state;
        // The multiplex task owns the old proxy process; dropping it stops that process.
        drop(multiplex_task);
        drop(heartbeat_task);

        cx.spawn(|this, mut cx| async move {
            // Detach the long-lived channels so they can be wired to the new connection.
            let (incoming_tx, outgoing_rx) = proxy.into_channels().await;

            ssh_connection.master_process.kill()?;
            ssh_connection
                .master_process
                .status()
                .await
                .context("Failed to kill ssh process")?;

            let connection_options = ssh_connection.socket.connection_options.clone();

            let (ssh_connection, ssh_process) = Self::establish_connection(
                workspace_identifier,
                connection_options,
                delegate.clone(),
                &mut cx,
            )
            .await?;

            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

            let inner_state = SshRemoteClientState {
                ssh_connection,
                delegate,
                forwarder: proxy,
                multiplex_task: Self::multiplex(
                    this.clone(),
                    ssh_process,
                    proxy_incoming_tx,
                    proxy_outgoing_rx,
                    &mut cx,
                ),
                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
            };

            this.update(&mut cx, |this, _| {
                this.inner_state.lock().replace(inner_state);
            })
        })
        .detach();
        Ok(())
    }
 420
    /// Spawns a task that pings the server every `HEARTBEAT_INTERVAL`.
    ///
    /// After `MAX_MISSED_HEARTBEATS` consecutive failed or timed-out pings, the task
    /// calls `reconnect`.
    ///
    /// If the model is already gone, an already-failed task is returned. The spawned task
    /// ends in one of two ways:
    /// - with `Ok` after requesting a reconnect;
    /// - with an error if the model is gone or `reconnect` itself fails.
    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
        };
        cx.spawn(|mut cx| {
            let this = this.clone();
            async move {
                const MAX_MISSED_HEARTBEATS: usize = 5;
                const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
                const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

                let mut missed_heartbeats = 0;

                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
                loop {
                    timer.next().await;

                    log::info!("Sending heartbeat to server...");

                    // Race the ping against a timeout. An RPC error and a timeout both
                    // count as a missed heartbeat.
                    let result = smol::future::or(
                        async {
                            client.request(proto::Ping {}).await?;
                            Ok(())
                        },
                        async {
                            smol::Timer::after(HEARTBEAT_TIMEOUT).await;

                            Err(anyhow!("Timeout detected"))
                        },
                    )
                    .await;

                    if result.is_err() {
                        missed_heartbeats += 1;
                        log::warn!(
                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
                            HEARTBEAT_TIMEOUT,
                            missed_heartbeats,
                            MAX_MISSED_HEARTBEATS
                        );
                    } else {
                        // Any successful ping resets the streak.
                        missed_heartbeats = 0;
                    }

                    if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
                        log::error!(
                            "Missed last {} hearbeats. Reconnecting...",
                            missed_heartbeats
                        );

                        // NOTE(review): `reconnect` takes the state and drops its
                        // `heartbeat_task`, i.e. the handle of this very task. The new
                        // connection gets a fresh heartbeat task.
                        this.update(&mut cx, |this, cx| {
                            this.reconnect(cx)
                                .context("failed to reconnect after missing heartbeats")
                        })
                        .context("failed to update weak reference, SshRemoteClient lost?")??;
                        return Ok(());
                    }
                }
            }
        })
    }
 482
    /// Drives the stdio of the remote proxy process `ssh_proxy_process`.
    ///
    /// A background task does three things:
    /// - writes envelopes from `outgoing_rx` to the process's stdin;
    /// - decodes length-prefixed envelopes from its stdout into `incoming_tx`;
    /// - relays its stderr line by line. JSON `LogRecord` lines are re-logged with a
    ///   `(remote)` prefix; other lines are printed to local stderr.
    ///
    /// If that task fails, `reconnect` is requested on `this`.
    ///
    /// The process is moved into the task. `establish_connection` spawns it with
    /// `kill_on_drop(true)`, so dropping the returned task stops it.
    fn multiplex(
        this: WeakModel<Self>,
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> Task<Result<()>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let io_task = cx.background_executor().spawn(async move {
            let mut stdin_buffer = Vec::new();
            let mut stdout_buffer = Vec::new();
            let mut stderr_buffer = Vec::new();
            // Length of a partial (not yet newline-terminated) stderr line kept in the buffer.
            let mut stderr_offset = 0;

            loop {
                // stdout reads start with the fixed-size length prefix. stderr gets
                // 1024 bytes of fresh room after any partial line.
                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                stderr_buffer.resize(stderr_offset + 1024, 0);

                select_biased! {
                    outgoing = outgoing_rx.next().fuse() => {
                        let Some(outgoing) = outgoing else {
                            return anyhow::Ok(());
                        };

                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
                    }

                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
                        match result {
                            // stdout EOF: the proxy is exiting. Stop sending and report
                            // its exit status.
                            Ok(0) => {
                                child_stdin.close().await?;
                                outgoing_rx.close();
                                let status = ssh_proxy_process.status().await?;
                                if !status.success() {
                                    log::error!("ssh process exited with status: {status:?}");
                                    return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
                                }
                                return Ok(());
                            }
                            Ok(len) => {
                                // The read may return only part of the prefix; finish
                                // it before decoding the body.
                                if len < stdout_buffer.len() {
                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                                }

                                let message_len = message_len_from_buffer(&stdout_buffer);
                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
                                    Ok(envelope) => {
                                        incoming_tx.unbounded_send(envelope).ok();
                                    }
                                    // NOTE(review): decode failures are only logged; the
                                    // stream may be out of sync afterwards.
                                    Err(error) => {
                                        log::error!("error decoding message {error:?}");
                                    }
                                }
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stdout: {error:?}"))?;
                            }
                        }
                    }

                    // NOTE(review): `Ok(0)` (stderr EOF) is not handled specially. Once
                    // stderr is closed, this branch is always ready, so the loop may spin
                    // until stdout also reports EOF.
                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
                        match result {
                            Ok(len) => {
                                stderr_offset += len;
                                let mut start_ix = 0;
                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
                                    let line_ix = start_ix + ix;
                                    let content = &stderr_buffer[start_ix..line_ix];
                                    start_ix = line_ix + 1;
                                    if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
                                        record.message = format!("(remote) {}", record.message);
                                        record.log(log::logger())
                                    } else {
                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                                    }
                                }
                                // Keep any trailing partial line for the next read.
                                stderr_buffer.drain(0..start_ix);
                                stderr_offset -= start_ix;
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stderr: {error:?}"))?;
                            }
                        }
                    }
                }
            }
        });

        cx.spawn(|mut cx| async move {
            let result = io_task.await;

            // Only failures trigger a reconnect; a clean proxy exit just ends the task.
            if let Err(error) = result {
                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
                this.update(&mut cx, |this, cx| {
                    this.reconnect(cx).ok();
                })?;
            }

            Ok(())
        })
    }
 587
 588    async fn establish_connection(
 589        unique_identifier: String,
 590        connection_options: SshConnectionOptions,
 591        delegate: Arc<dyn SshClientDelegate>,
 592        cx: &mut AsyncAppContext,
 593    ) -> Result<(SshRemoteConnection, Child)> {
 594        let ssh_connection =
 595            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 596
 597        let platform = ssh_connection.query_platform().await?;
 598        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
 599        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
 600        ssh_connection
 601            .ensure_server_binary(
 602                &delegate,
 603                &local_binary_path,
 604                &remote_binary_path,
 605                version,
 606                cx,
 607            )
 608            .await?;
 609
 610        let socket = ssh_connection.socket.clone();
 611        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 612
 613        delegate.set_status(Some("Starting proxy"), cx);
 614
 615        let ssh_proxy_process = socket
 616            .ssh_command(format!(
 617                "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
 618                std::env::var("RUST_LOG").unwrap_or_default(),
 619                std::env::var("RUST_BACKTRACE").unwrap_or_default(),
 620                remote_binary_path,
 621                unique_identifier,
 622            ))
 623            // IMPORTANT: we kill this process when we drop the task that uses it.
 624            .kill_on_drop(true)
 625            .spawn()
 626            .context("failed to spawn remote server")?;
 627
 628        Ok((ssh_connection, ssh_proxy_process))
 629    }
 630
 631    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
 632        self.client.subscribe_to_entity(remote_id, entity);
 633    }
 634
 635    pub fn ssh_args(&self) -> Option<Vec<String>> {
 636        let state = self.inner_state.lock();
 637        state
 638            .as_ref()
 639            .map(|state| state.ssh_connection.socket.ssh_args())
 640    }
 641
 642    pub fn to_proto_client(&self) -> AnyProtoClient {
 643        self.client.clone().into()
 644    }
 645
 646    pub fn connection_string(&self) -> String {
 647        self.connection_options.connection_string()
 648    }
 649
 650    pub fn is_reconnect_underway(&self) -> bool {
 651        maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
 652    }
 653
 654    #[cfg(any(test, feature = "test-support"))]
 655    pub fn fake(
 656        client_cx: &mut gpui::TestAppContext,
 657        server_cx: &mut gpui::TestAppContext,
 658    ) -> (Model<Self>, Arc<ChannelClient>) {
 659        use gpui::Context;
 660
 661        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
 662        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
 663
 664        (
 665            client_cx.update(|cx| {
 666                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
 667                cx.new_model(|_| Self {
 668                    client,
 669                    unique_identifier: "fake".to_string(),
 670                    connection_options: SshConnectionOptions::default(),
 671                    inner_state: Arc::new(Mutex::new(None)),
 672                })
 673            }),
 674            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
 675        )
 676    }
 677}
 678
 679impl From<SshRemoteClient> for AnyProtoClient {
 680    fn from(client: SshRemoteClient) -> Self {
 681        AnyProtoClient::new(client.client.clone())
 682    }
 683}
 684
/// An authenticated SSH control-master connection.
///
/// After the `Drop` impl kills `master_process`, the fields drop in declaration order.
/// `_temp_dir` is therefore removed last; it holds the control socket and the askpass
/// socket/script.
struct SshRemoteConnection {
    socket: SshSocket,
    master_process: process::Child,
    // Held only for its lifetime: dropping it deletes the temporary directory.
    _temp_dir: TempDir,
}
 690
 691impl Drop for SshRemoteConnection {
 692    fn drop(&mut self) {
 693        if let Err(error) = self.master_process.kill() {
 694            log::error!("failed to kill SSH master process: {}", error);
 695        }
 696    }
 697}
 698
 699impl SshRemoteConnection {
 700    #[cfg(not(unix))]
 701    async fn new(
 702        _connection_options: SshConnectionOptions,
 703        _delegate: Arc<dyn SshClientDelegate>,
 704        _cx: &mut AsyncAppContext,
 705    ) -> Result<Self> {
 706        Err(anyhow!("ssh is not supported on this platform"))
 707    }
 708
 709    #[cfg(unix)]
 710    async fn new(
 711        connection_options: SshConnectionOptions,
 712        delegate: Arc<dyn SshClientDelegate>,
 713        cx: &mut AsyncAppContext,
 714    ) -> Result<Self> {
 715        use futures::{io::BufReader, AsyncBufReadExt as _};
 716        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
 717        use util::ResultExt as _;
 718
 719        delegate.set_status(Some("connecting"), cx);
 720
 721        let url = connection_options.ssh_url();
 722        let temp_dir = tempfile::Builder::new()
 723            .prefix("zed-ssh-session")
 724            .tempdir()?;
 725
 726        // Create a domain socket listener to handle requests from the askpass program.
 727        let askpass_socket = temp_dir.path().join("askpass.sock");
 728        let listener =
 729            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 730
 731        let askpass_task = cx.spawn({
 732            let delegate = delegate.clone();
 733            |mut cx| async move {
 734                while let Ok((mut stream, _)) = listener.accept().await {
 735                    let mut buffer = Vec::new();
 736                    let mut reader = BufReader::new(&mut stream);
 737                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
 738                        buffer.clear();
 739                    }
 740                    let password_prompt = String::from_utf8_lossy(&buffer);
 741                    if let Some(password) = delegate
 742                        .ask_password(password_prompt.to_string(), &mut cx)
 743                        .await
 744                        .context("failed to get ssh password")
 745                        .and_then(|p| p)
 746                        .log_err()
 747                    {
 748                        stream.write_all(password.as_bytes()).await.log_err();
 749                    }
 750                }
 751            }
 752        });
 753
 754        // Create an askpass script that communicates back to this process.
 755        let askpass_script = format!(
 756            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
 757            askpass_socket = askpass_socket.display(),
 758            print_args = "printf '%s\\0' \"$@\"",
 759            shebang = "#!/bin/sh",
 760        );
 761        let askpass_script_path = temp_dir.path().join("askpass.sh");
 762        fs::write(&askpass_script_path, askpass_script).await?;
 763        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
 764
 765        // Start the master SSH process, which does not do anything except for establish
 766        // the connection and keep it open, allowing other ssh commands to reuse it
 767        // via a control socket.
 768        let socket_path = temp_dir.path().join("ssh.sock");
 769        let mut master_process = process::Command::new("ssh")
 770            .stdin(Stdio::null())
 771            .stdout(Stdio::piped())
 772            .stderr(Stdio::piped())
 773            .env("SSH_ASKPASS_REQUIRE", "force")
 774            .env("SSH_ASKPASS", &askpass_script_path)
 775            .args(["-N", "-o", "ControlMaster=yes", "-o"])
 776            .arg(format!("ControlPath={}", socket_path.display()))
 777            .arg(&url)
 778            .spawn()?;
 779
 780        // Wait for this ssh process to close its stdout, indicating that authentication
 781        // has completed.
 782        let stdout = master_process.stdout.as_mut().unwrap();
 783        let mut output = Vec::new();
 784        let connection_timeout = Duration::from_secs(10);
 785        let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
 786        if let Err(e) = result {
 787            let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
 788                format!(
 789                    "Failed to connect to host. Timed out after {:?}.",
 790                    connection_timeout
 791                )
 792            } else {
 793                format!("Failed to connect to host: {}.", e)
 794            };
 795
 796            delegate.set_error(error_message, cx);
 797            return Err(e.into());
 798        }
 799
 800        drop(askpass_task);
 801
 802        if master_process.try_status()?.is_some() {
 803            output.clear();
 804            let mut stderr = master_process.stderr.take().unwrap();
 805            stderr.read_to_end(&mut output).await?;
 806            Err(anyhow!(
 807                "failed to connect: {}",
 808                String::from_utf8_lossy(&output)
 809            ))?;
 810        }
 811
 812        Ok(Self {
 813            socket: SshSocket {
 814                connection_options,
 815                socket_path,
 816            },
 817            master_process,
 818            _temp_dir: temp_dir,
 819        })
 820    }
 821
 822    async fn ensure_server_binary(
 823        &self,
 824        delegate: &Arc<dyn SshClientDelegate>,
 825        src_path: &Path,
 826        dst_path: &Path,
 827        version: SemanticVersion,
 828        cx: &mut AsyncAppContext,
 829    ) -> Result<()> {
 830        let mut dst_path_gz = dst_path.to_path_buf();
 831        dst_path_gz.set_extension("gz");
 832
 833        if let Some(parent) = dst_path.parent() {
 834            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
 835        }
 836
 837        let mut server_binary_exists = false;
 838        if cfg!(not(debug_assertions)) {
 839            if let Ok(installed_version) =
 840                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
 841            {
 842                if installed_version.trim() == version.to_string() {
 843                    server_binary_exists = true;
 844                }
 845            }
 846        }
 847
 848        if server_binary_exists {
 849            log::info!("remote development server already present",);
 850            return Ok(());
 851        }
 852
 853        let src_stat = fs::metadata(src_path).await?;
 854        let size = src_stat.len();
 855        let server_mode = 0o755;
 856
 857        let t0 = Instant::now();
 858        delegate.set_status(Some("uploading remote development server"), cx);
 859        log::info!("uploading remote development server ({}kb)", size / 1024);
 860        self.upload_file(src_path, &dst_path_gz)
 861            .await
 862            .context("failed to upload server binary")?;
 863        log::info!("uploaded remote development server in {:?}", t0.elapsed());
 864
 865        delegate.set_status(Some("extracting remote development server"), cx);
 866        run_cmd(
 867            self.socket
 868                .ssh_command("gunzip")
 869                .arg("--force")
 870                .arg(&dst_path_gz),
 871        )
 872        .await?;
 873
 874        delegate.set_status(Some("unzipping remote development server"), cx);
 875        run_cmd(
 876            self.socket
 877                .ssh_command("chmod")
 878                .arg(format!("{:o}", server_mode))
 879                .arg(dst_path),
 880        )
 881        .await?;
 882
 883        Ok(())
 884    }
 885
 886    async fn query_platform(&self) -> Result<SshPlatform> {
 887        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
 888        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
 889
 890        let os = match os.trim() {
 891            "Darwin" => "macos",
 892            "Linux" => "linux",
 893            _ => Err(anyhow!("unknown uname os {os:?}"))?,
 894        };
 895        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
 896            "aarch64"
 897        } else if arch.starts_with("x86") || arch.starts_with("i686") {
 898            "x86_64"
 899        } else {
 900            Err(anyhow!("unknown uname architecture {arch:?}"))?
 901        };
 902
 903        Ok(SshPlatform { os, arch })
 904    }
 905
 906    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
 907        let mut command = process::Command::new("scp");
 908        let output = self
 909            .socket
 910            .ssh_options(&mut command)
 911            .args(
 912                self.socket
 913                    .connection_options
 914                    .port
 915                    .map(|port| vec!["-P".to_string(), port.to_string()])
 916                    .unwrap_or_default(),
 917            )
 918            .arg(src_path)
 919            .arg(format!(
 920                "{}:{}",
 921                self.socket.connection_options.scp_url(),
 922                dest_path.display()
 923            ))
 924            .output()
 925            .await?;
 926
 927        if output.status.success() {
 928            Ok(())
 929        } else {
 930            Err(anyhow!(
 931                "failed to upload file {} -> {}: {}",
 932                src_path.display(),
 933                dest_path.display(),
 934                String::from_utf8_lossy(&output.stderr)
 935            ))
 936        }
 937    }
 938}
 939
/// Responders for requests that are still waiting for a reply, keyed by the
/// [`MessageId`] of the outgoing request envelope.
///
/// Each responder receives the response envelope together with a
/// `oneshot::Sender<()>` acknowledgement channel. The incoming-message loop
/// awaits the receiving half of that channel before it handles the next
/// incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 941
/// RPC client that exchanges [`Envelope`]s over a pair of unbounded mpsc channels.
///
/// The other ends of the channels are presumably bridged to the SSH connection
/// elsewhere (not shown here).
pub struct ChannelClient {
    /// Source of the `id` stamped onto every outgoing envelope; starts at 0.
    next_message_id: AtomicU32,
    /// Outgoing envelopes are queued here.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests awaiting a response. The mutex is held only for single
    /// insert/remove calls.
    response_channels: ResponseChannels,
    /// Handlers for incoming non-response messages, including entity subscriptions.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
 948
 949impl ChannelClient {
 950    pub fn new(
 951        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 952        outgoing_tx: mpsc::UnboundedSender<Envelope>,
 953        cx: &AppContext,
 954    ) -> Arc<Self> {
 955        let this = Arc::new(Self {
 956            outgoing_tx,
 957            next_message_id: AtomicU32::new(0),
 958            response_channels: ResponseChannels::default(),
 959            message_handlers: Default::default(),
 960        });
 961
 962        Self::start_handling_messages(this.clone(), incoming_rx, cx);
 963
 964        this
 965    }
 966
 967    fn start_handling_messages(
 968        this: Arc<Self>,
 969        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
 970        cx: &AppContext,
 971    ) {
 972        cx.spawn(|cx| {
 973            let this = Arc::downgrade(&this);
 974            async move {
 975                let peer_id = PeerId { owner_id: 0, id: 0 };
 976                while let Some(incoming) = incoming_rx.next().await {
 977                    let Some(this) = this.upgrade() else {
 978                        return anyhow::Ok(());
 979                    };
 980
 981                    if let Some(request_id) = incoming.responding_to {
 982                        let request_id = MessageId(request_id);
 983                        let sender = this.response_channels.lock().remove(&request_id);
 984                        if let Some(sender) = sender {
 985                            let (tx, rx) = oneshot::channel();
 986                            if incoming.payload.is_some() {
 987                                sender.send((incoming, tx)).ok();
 988                            }
 989                            rx.await.ok();
 990                        }
 991                    } else if let Some(envelope) =
 992                        build_typed_envelope(peer_id, Instant::now(), incoming)
 993                    {
 994                        let type_name = envelope.payload_type_name();
 995                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
 996                            &this.message_handlers,
 997                            envelope,
 998                            this.clone().into(),
 999                            cx.clone(),
1000                        ) {
1001                            log::debug!("ssh message received. name:{type_name}");
1002                            match future.await {
1003                                Ok(_) => {
1004                                    log::debug!("ssh message handled. name:{type_name}");
1005                                }
1006                                Err(error) => {
1007                                    log::error!(
1008                                        "error handling message. type:{type_name}, error:{error}",
1009                                    );
1010                                }
1011                            }
1012                        } else {
1013                            log::error!("unhandled ssh message name:{type_name}");
1014                        }
1015                    }
1016                }
1017                anyhow::Ok(())
1018            }
1019        })
1020        .detach();
1021    }
1022
1023    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1024        let id = (TypeId::of::<E>(), remote_id);
1025
1026        let mut message_handlers = self.message_handlers.lock();
1027        if message_handlers
1028            .entities_by_type_and_remote_id
1029            .contains_key(&id)
1030        {
1031            panic!("already subscribed to entity");
1032        }
1033
1034        message_handlers.entities_by_type_and_remote_id.insert(
1035            id,
1036            EntityMessageSubscriber::Entity {
1037                handle: entity.downgrade().into(),
1038            },
1039        );
1040    }
1041
1042    pub fn request<T: RequestMessage>(
1043        &self,
1044        payload: T,
1045    ) -> impl 'static + Future<Output = Result<T::Response>> {
1046        log::debug!("ssh request start. name:{}", T::NAME);
1047        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1048        async move {
1049            let response = response.await?;
1050            log::debug!("ssh request finish. name:{}", T::NAME);
1051            T::Response::from_envelope(response)
1052                .ok_or_else(|| anyhow!("received a response of the wrong type"))
1053        }
1054    }
1055
1056    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1057        log::debug!("ssh send name:{}", T::NAME);
1058        self.send_dynamic(payload.into_envelope(0, None, None))
1059    }
1060
1061    pub fn request_dynamic(
1062        &self,
1063        mut envelope: proto::Envelope,
1064        type_name: &'static str,
1065    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1066        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1067        let (tx, rx) = oneshot::channel();
1068        let mut response_channels_lock = self.response_channels.lock();
1069        response_channels_lock.insert(MessageId(envelope.id), tx);
1070        drop(response_channels_lock);
1071        let result = self.outgoing_tx.unbounded_send(envelope);
1072        async move {
1073            if let Err(error) = &result {
1074                log::error!("failed to send message: {}", error);
1075                return Err(anyhow!("failed to send message: {}", error));
1076            }
1077
1078            let response = rx.await.context("connection lost")?.0;
1079            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1080                return Err(RpcError::from_proto(error, type_name));
1081            }
1082            Ok(response)
1083        }
1084    }
1085
1086    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1087        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1088        self.outgoing_tx.unbounded_send(envelope)?;
1089        Ok(())
1090    }
1091}
1092
1093impl ProtoClient for ChannelClient {
1094    fn request(
1095        &self,
1096        envelope: proto::Envelope,
1097        request_type: &'static str,
1098    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1099        self.request_dynamic(envelope, request_type).boxed()
1100    }
1101
1102    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1103        self.send_dynamic(envelope)
1104    }
1105
1106    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1107        self.send_dynamic(envelope)
1108    }
1109
1110    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1111        &self.message_handlers
1112    }
1113
1114    fn is_via_collab(&self) -> bool {
1115        false
1116    }
1117}