ssh_session.rs

  1use crate::{
  2    json_log::LogRecord,
  3    protocol::{
  4        message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
  5    },
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use collections::HashMap;
  9use futures::{
 10    channel::{mpsc, oneshot},
 11    future::{BoxFuture, LocalBoxFuture},
 12    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
 13};
 14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
 15use parking_lot::Mutex;
 16use rpc::{
 17    proto::{
 18        self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
 19        ProtoClient, RequestMessage,
 20    },
 21    TypedEnvelope,
 22};
 23use smol::{
 24    fs,
 25    process::{self, Stdio},
 26};
 27use std::{
 28    any::TypeId,
 29    ffi::OsStr,
 30    path::{Path, PathBuf},
 31    sync::{
 32        atomic::{AtomicU32, Ordering::SeqCst},
 33        Arc,
 34    },
 35    time::Instant,
 36};
 37use tempfile::TempDir;
 38
   /// One end of an envelope-based RPC session.
   ///
   /// Outgoing envelopes are pushed onto `outgoing_tx`. Incoming envelopes are
   /// consumed by a task spawned in `new`, which either resolves a pending
   /// request (when the envelope is a response) or dispatches the envelope to a
   /// registered message handler.
 39pub struct SshSession {
       /// Counter used to assign an id to every outgoing envelope.
 40    next_message_id: AtomicU32,
       /// Senders for requests that are still waiting for a response, keyed by
       /// the id of the request envelope.
 41    response_channels: ResponseChannels,
       /// Queue of envelopes to be delivered to the other side.
 42    outgoing_tx: mpsc::UnboundedSender<Envelope>,
       /// Queue of requests to start an additional process; `spawn_process`
       /// pushes onto it.
 43    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
       /// Handlers for incoming non-response messages, keyed by the `TypeId` of
       /// the message payload. A handler receives the type-erased envelope, the
       /// session, and an async app context. When a handler returns `None`, the
       /// incoming-message loop removes it from this map.
 44    message_handlers: Mutex<
 45        HashMap<
 46            TypeId,
 47            Arc<
 48                dyn Send
 49                    + Sync
 50                    + Fn(
 51                        Box<dyn AnyTypedEnvelope>,
 52                        Arc<SshSession>,
 53                        AsyncAppContext,
 54                    ) -> Option<LocalBoxFuture<'static, Result<()>>>,
 55            >,
 56        >,
 57    >,
 58}
 59
   /// State of an established master SSH connection; every later `ssh`/`scp`
   /// invocation is built from these fields.
 60struct SshClientState {
       /// Path of the control socket, passed as `ControlPath=...` to each command.
 61    socket_path: PathBuf,
       /// Remote SSH port, passed with `-p` to `ssh` and `-P` to `scp`.
 62    port: u16,
       /// Connection target in `user@host` form.
 63    url: String,
       /// Handle of the master `ssh` process. Never read; stored so the handle
       /// lives as long as this state.
 64    _master_process: process::Child,
       /// Temporary directory holding the control socket and the askpass files.
       /// Never read; stored so the directory outlives the connection (a
       /// `TempDir` removes its directory when dropped).
 65    _temp_dir: TempDir,
 66}
 67
   /// A request, sent from `SshSession::spawn_process` to the session's I/O
   /// task, to run a command over the SSH connection.
 68struct SpawnRequest {
       /// The command passed to `ssh_command` as the remote program.
 69    command: String,
       /// Channel on which the spawned child process is handed back.
 70    process_tx: oneshot::Sender<process::Child>,
 71}
 72
   /// Operating system and CPU architecture of the remote host, as determined
   /// by `SshClientState::query_platform`.
 73#[derive(Copy, Clone, Debug)]
 74pub struct SshPlatform {
       /// `"macos"` or `"linux"` when produced by `query_platform`.
 75    pub os: &'static str,
       /// `"aarch64"` or `"x86_64"` when produced by `query_platform`.
 76    pub arch: &'static str,
 77}
 78
   /// Callbacks through which the SSH client asks the embedding application
   /// for credentials, for the server binary, and reports progress.
 79pub trait SshClientDelegate {
       /// Asks for a password. `prompt` is the text received from the askpass
       /// helper script; the receiver yields the password, or an error.
 80    fn ask_password(
 81        &self,
 82        prompt: String,
 83        cx: &mut AsyncAppContext,
 84    ) -> oneshot::Receiver<Result<String>>;
       /// Returns the path on the remote host where the server binary is
       /// installed and from which it is launched.
 85    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
       /// Resolves a local server binary for `platform`, together with its
       /// version. The file is uploaded as `<remote path>.gz` and unpacked with
       /// `gunzip`, so it is presumably expected to be gzip-compressed — TODO
       /// confirm with implementors.
 86    fn get_server_binary(
 87        &self,
 88        platform: SshPlatform,
 89        cx: &mut AsyncAppContext,
 90    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
       /// Reports a human-readable progress message (for example
       /// `"connecting"`). This file only ever passes `Some(..)`; the meaning of
       /// `None` is up to the implementor.
 91    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 92}
 93
   /// Pending requests, keyed by request id. Each sender delivers the response
   /// envelope plus a second sender; the incoming-message loop waits on the
   /// matching receiver before it processes the next message.
 94type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 95
 96impl SshSession {
 97    pub async fn client(
 98        user: String,
 99        host: String,
100        port: u16,
101        delegate: Arc<dyn SshClientDelegate>,
102        cx: &mut AsyncAppContext,
103    ) -> Result<Arc<Self>> {
104        let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;
105
106        let platform = client_state.query_platform().await?;
107        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
108        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
109        client_state
110            .ensure_server_binary(
111                &delegate,
112                &local_binary_path,
113                &remote_binary_path,
114                version,
115                cx,
116            )
117            .await?;
118
119        let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
120        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
121        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
122
123        run_cmd(client_state.ssh_command(&remote_binary_path).arg("version")).await?;
124
125        let mut remote_server_child = client_state
126            .ssh_command(&format!(
127                "RUST_LOG={} {:?} run",
128                std::env::var("RUST_LOG").unwrap_or(String::new()),
129                remote_binary_path,
130            ))
131            .spawn()
132            .context("failed to spawn remote server")?;
133        let mut child_stderr = remote_server_child.stderr.take().unwrap();
134        let mut child_stdout = remote_server_child.stdout.take().unwrap();
135        let mut child_stdin = remote_server_child.stdin.take().unwrap();
136
137        let executor = cx.background_executor().clone();
138        executor.clone().spawn(async move {
139            let mut stdin_buffer = Vec::new();
140            let mut stdout_buffer = Vec::new();
141            let mut stderr_buffer = Vec::new();
142            let mut stderr_offset = 0;
143
144            loop {
145                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
146                stderr_buffer.resize(stderr_offset + 1024, 0);
147
148                select_biased! {
149                    outgoing = outgoing_rx.next().fuse() => {
150                        let Some(outgoing) = outgoing else {
151                            return anyhow::Ok(());
152                        };
153
154                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
155                    }
156
157                    request = spawn_process_rx.next().fuse() => {
158                        let Some(request) = request else {
159                            return Ok(());
160                        };
161
162                        log::info!("spawn process: {:?}", request.command);
163                        let child = client_state
164                            .ssh_command(&request.command)
165                            .spawn()
166                            .context("failed to create channel")?;
167                        request.process_tx.send(child).ok();
168                    }
169
170                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
171                        match result {
172                            Ok(len) => {
173                                if len == 0 {
174                                    child_stdin.close().await?;
175                                    let status = remote_server_child.status().await?;
176                                    if !status.success() {
177                                        log::info!("channel exited with status: {status:?}");
178                                    }
179                                    return Ok(());
180                                }
181
182                                if len < stdout_buffer.len() {
183                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
184                                }
185
186                                let message_len = message_len_from_buffer(&stdout_buffer);
187                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
188                                    Ok(envelope) => {
189                                        incoming_tx.unbounded_send(envelope).ok();
190                                    }
191                                    Err(error) => {
192                                        log::error!("error decoding message {error:?}");
193                                    }
194                                }
195                            }
196                            Err(error) => {
197                                Err(anyhow!("error reading stdout: {error:?}"))?;
198                            }
199                        }
200                    }
201
202                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
203                        match result {
204                            Ok(len) => {
205                                stderr_offset += len;
206                                let mut start_ix = 0;
207                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
208                                    let line_ix = start_ix + ix;
209                                    let content = &stderr_buffer[start_ix..line_ix];
210                                    start_ix = line_ix + 1;
211                                    if let Ok(record) = serde_json::from_slice::<LogRecord>(&content) {
212                                        record.log(log::logger())
213                                    } else {
214                                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
215                                    }
216                                }
217                                stderr_buffer.drain(0..start_ix);
218                                stderr_offset -= start_ix;
219                            }
220                            Err(error) => {
221                                Err(anyhow!("error reading stderr: {error:?}"))?;
222                            }
223                        }
224                    }
225                }
226            }
227        }).detach();
228
229        cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
230    }
231
232    pub fn server(
233        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
234        outgoing_tx: mpsc::UnboundedSender<Envelope>,
235        cx: &AppContext,
236    ) -> Arc<SshSession> {
237        let (tx, _rx) = mpsc::unbounded();
238        Self::new(incoming_rx, outgoing_tx, tx, cx)
239    }
240
241    #[cfg(any(test, feature = "test-support"))]
242    pub fn fake(
243        client_cx: &mut gpui::TestAppContext,
244        server_cx: &mut gpui::TestAppContext,
245    ) -> (Arc<Self>, Arc<Self>) {
246        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
247        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
248        let (tx, _rx) = mpsc::unbounded();
249        (
250            client_cx
251                .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
252            server_cx
253                .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
254        )
255    }
256
257    fn new(
258        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
259        outgoing_tx: mpsc::UnboundedSender<Envelope>,
260        spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
261        cx: &AppContext,
262    ) -> Arc<SshSession> {
263        let this = Arc::new(Self {
264            next_message_id: AtomicU32::new(0),
265            response_channels: ResponseChannels::default(),
266            outgoing_tx,
267            spawn_process_tx,
268            message_handlers: Default::default(),
269        });
270
271        cx.spawn(|cx| {
272            let this = this.clone();
273            async move {
274                let peer_id = PeerId { owner_id: 0, id: 0 };
275                while let Some(incoming) = incoming_rx.next().await {
276                    if let Some(request_id) = incoming.responding_to {
277                        let request_id = MessageId(request_id);
278                        let sender = this.response_channels.lock().remove(&request_id);
279                        if let Some(sender) = sender {
280                            let (tx, rx) = oneshot::channel();
281                            if incoming.payload.is_some() {
282                                sender.send((incoming, tx)).ok();
283                            }
284                            rx.await.ok();
285                        }
286                    } else if let Some(envelope) =
287                        build_typed_envelope(peer_id, Instant::now(), incoming)
288                    {
289                        log::debug!(
290                            "ssh message received. name:{}",
291                            envelope.payload_type_name()
292                        );
293                        let type_id = envelope.payload_type_id();
294                        let handler = this.message_handlers.lock().get(&type_id).cloned();
295                        if let Some(handler) = handler {
296                            if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
297                                future.await.ok();
298                            } else {
299                                this.message_handlers.lock().remove(&type_id);
300                            }
301                        }
302                    }
303                }
304                anyhow::Ok(())
305            }
306        })
307        .detach();
308
309        this
310    }
311
312    pub fn request<T: RequestMessage>(
313        &self,
314        payload: T,
315    ) -> impl 'static + Future<Output = Result<T::Response>> {
316        log::debug!("ssh request start. name:{}", T::NAME);
317        let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
318        async move {
319            let response = response.await?;
320            log::debug!("ssh request finish. name:{}", T::NAME);
321            T::Response::from_envelope(response)
322                .ok_or_else(|| anyhow!("received a response of the wrong type"))
323        }
324    }
325
326    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
327        self.send_dynamic(payload.into_envelope(0, None, None))
328    }
329
330    pub fn request_dynamic(
331        &self,
332        mut envelope: proto::Envelope,
333        _request_type: &'static str,
334    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
335        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
336        let (tx, rx) = oneshot::channel();
337        self.response_channels
338            .lock()
339            .insert(MessageId(envelope.id), tx);
340        self.outgoing_tx.unbounded_send(envelope).ok();
341        async move { Ok(rx.await.context("connection lost")?.0) }
342    }
343
344    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
345        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
346        self.outgoing_tx.unbounded_send(envelope)?;
347        Ok(())
348    }
349
350    pub async fn spawn_process(&self, command: String) -> process::Child {
351        let (process_tx, process_rx) = oneshot::channel();
352        self.spawn_process_tx
353            .unbounded_send(SpawnRequest {
354                command,
355                process_tx,
356            })
357            .ok();
358        process_rx.await.unwrap()
359    }
360
361    pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
362    where
363        M: EnvelopedMessage,
364        E: 'static,
365        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
366        F: 'static + Future<Output = Result<()>>,
367    {
368        let message_type_id = TypeId::of::<M>();
369        self.message_handlers.lock().insert(
370            message_type_id,
371            Arc::new(move |envelope, _, cx| {
372                let entity = entity.upgrade()?;
373                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
374                Some(handler(entity, *envelope, cx).boxed_local())
375            }),
376        );
377    }
378
379    pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
380    where
381        M: EnvelopedMessage + RequestMessage,
382        E: 'static,
383        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
384        F: 'static + Future<Output = Result<M::Response>>,
385    {
386        let message_type_id = TypeId::of::<M>();
387        self.message_handlers.lock().insert(
388            message_type_id,
389            Arc::new(move |envelope, this, cx| {
390                let entity = entity.upgrade()?;
391                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
392                let request_id = envelope.message_id();
393                Some(
394                    handler(entity, *envelope, cx)
395                        .then(move |result| async move {
396                            this.outgoing_tx.unbounded_send(result?.into_envelope(
397                                this.next_message_id.fetch_add(1, SeqCst),
398                                Some(request_id),
399                                None,
400                            ))?;
401                            Ok(())
402                        })
403                        .boxed_local(),
404                )
405            }),
406        );
407    }
408}
409
410impl ProtoClient for SshSession {
411    fn request(
412        &self,
413        envelope: proto::Envelope,
414        request_type: &'static str,
415    ) -> BoxFuture<'static, Result<proto::Envelope>> {
416        self.request_dynamic(envelope, request_type).boxed()
417    }
418
419    fn send(&self, envelope: proto::Envelope) -> Result<()> {
420        self.send_dynamic(envelope)
421    }
422}
423
424impl SshClientState {
425    #[cfg(not(unix))]
426    async fn new(
427        _user: String,
428        _host: String,
429        _port: u16,
430        _delegate: Arc<dyn SshClientDelegate>,
431        _cx: &mut AsyncAppContext,
432    ) -> Result<Self> {
433        Err(anyhow!("ssh is not supported on this platform"))
434    }
435
436    #[cfg(unix)]
437    async fn new(
438        user: String,
439        host: String,
440        port: u16,
441        delegate: Arc<dyn SshClientDelegate>,
442        cx: &mut AsyncAppContext,
443    ) -> Result<Self> {
444        use futures::{io::BufReader, AsyncBufReadExt as _};
445        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
446        use util::ResultExt as _;
447
448        delegate.set_status(Some("connecting"), cx);
449
450        let url = format!("{user}@{host}");
451        let temp_dir = tempfile::Builder::new()
452            .prefix("zed-ssh-session")
453            .tempdir()?;
454
455        // Create a domain socket listener to handle requests from the askpass program.
456        let askpass_socket = temp_dir.path().join("askpass.sock");
457        let listener =
458            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
459
460        let askpass_task = cx.spawn(|mut cx| async move {
461            while let Ok((mut stream, _)) = listener.accept().await {
462                let mut buffer = Vec::new();
463                let mut reader = BufReader::new(&mut stream);
464                if reader.read_until(b'\0', &mut buffer).await.is_err() {
465                    buffer.clear();
466                }
467                let password_prompt = String::from_utf8_lossy(&buffer);
468                if let Some(password) = delegate
469                    .ask_password(password_prompt.to_string(), &mut cx)
470                    .await
471                    .context("failed to get ssh password")
472                    .and_then(|p| p)
473                    .log_err()
474                {
475                    stream.write_all(password.as_bytes()).await.log_err();
476                }
477            }
478        });
479
480        // Create an askpass script that communicates back to this process.
481        let askpass_script = format!(
482            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
483            askpass_socket = askpass_socket.display(),
484            print_args = "printf '%s\\0' \"$@\"",
485            shebang = "#!/bin/sh",
486        );
487        let askpass_script_path = temp_dir.path().join("askpass.sh");
488        fs::write(&askpass_script_path, askpass_script).await?;
489        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
490
491        // Start the master SSH process, which does not do anything except for establish
492        // the connection and keep it open, allowing other ssh commands to reuse it
493        // via a control socket.
494        let socket_path = temp_dir.path().join("ssh.sock");
495        let mut master_process = process::Command::new("ssh")
496            .stdin(Stdio::null())
497            .stdout(Stdio::piped())
498            .stderr(Stdio::piped())
499            .env("SSH_ASKPASS_REQUIRE", "force")
500            .env("SSH_ASKPASS", &askpass_script_path)
501            .args(["-N", "-o", "ControlMaster=yes", "-o"])
502            .arg(format!("ControlPath={}", socket_path.display()))
503            .args(["-p", &port.to_string()])
504            .arg(&url)
505            .spawn()?;
506
507        // Wait for this ssh process to close its stdout, indicating that authentication
508        // has completed.
509        let stdout = master_process.stdout.as_mut().unwrap();
510        let mut output = Vec::new();
511        stdout.read_to_end(&mut output).await?;
512        drop(askpass_task);
513
514        if master_process.try_status()?.is_some() {
515            output.clear();
516            let mut stderr = master_process.stderr.take().unwrap();
517            stderr.read_to_end(&mut output).await?;
518            Err(anyhow!(
519                "failed to connect: {}",
520                String::from_utf8_lossy(&output)
521            ))?;
522        }
523
524        Ok(Self {
525            url,
526            port,
527            socket_path,
528            _master_process: master_process,
529            _temp_dir: temp_dir,
530        })
531    }
532
533    async fn ensure_server_binary(
534        &self,
535        delegate: &Arc<dyn SshClientDelegate>,
536        src_path: &Path,
537        dst_path: &Path,
538        version: SemanticVersion,
539        cx: &mut AsyncAppContext,
540    ) -> Result<()> {
541        let mut dst_path_gz = dst_path.to_path_buf();
542        dst_path_gz.set_extension("gz");
543
544        if let Some(parent) = dst_path.parent() {
545            run_cmd(self.ssh_command("mkdir").arg("-p").arg(parent)).await?;
546        }
547
548        let mut server_binary_exists = false;
549        if cfg!(not(debug_assertions)) {
550            if let Ok(installed_version) = run_cmd(self.ssh_command(&dst_path).arg("version")).await
551            {
552                if installed_version.trim() == version.to_string() {
553                    server_binary_exists = true;
554                }
555            }
556        }
557
558        if server_binary_exists {
559            log::info!("remote development server already present",);
560            return Ok(());
561        }
562
563        let src_stat = fs::metadata(src_path).await?;
564        let size = src_stat.len();
565        let server_mode = 0o755;
566
567        let t0 = Instant::now();
568        delegate.set_status(Some("uploading remote development server"), cx);
569        log::info!("uploading remote development server ({}kb)", size / 1024);
570        self.upload_file(src_path, &dst_path_gz)
571            .await
572            .context("failed to upload server binary")?;
573        log::info!("uploaded remote development server in {:?}", t0.elapsed());
574
575        delegate.set_status(Some("extracting remote development server"), cx);
576        run_cmd(self.ssh_command("gunzip").arg("--force").arg(&dst_path_gz)).await?;
577
578        delegate.set_status(Some("unzipping remote development server"), cx);
579        run_cmd(
580            self.ssh_command("chmod")
581                .arg(format!("{:o}", server_mode))
582                .arg(&dst_path),
583        )
584        .await?;
585
586        Ok(())
587    }
588
589    async fn query_platform(&self) -> Result<SshPlatform> {
590        let os = run_cmd(self.ssh_command("uname").arg("-s")).await?;
591        let arch = run_cmd(self.ssh_command("uname").arg("-m")).await?;
592
593        let os = match os.trim() {
594            "Darwin" => "macos",
595            "Linux" => "linux",
596            _ => Err(anyhow!("unknown uname os {os:?}"))?,
597        };
598        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
599            "aarch64"
600        } else if arch.starts_with("x86") || arch.starts_with("i686") {
601            "x86_64"
602        } else {
603            Err(anyhow!("unknown uname architecture {arch:?}"))?
604        };
605
606        Ok(SshPlatform { os, arch })
607    }
608
609    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
610        let mut command = process::Command::new("scp");
611        let output = self
612            .ssh_options(&mut command)
613            .arg("-P")
614            .arg(&self.port.to_string())
615            .arg(&src_path)
616            .arg(&format!("{}:{}", self.url, dest_path.display()))
617            .output()
618            .await?;
619
620        if output.status.success() {
621            Ok(())
622        } else {
623            Err(anyhow!(
624                "failed to upload file {} -> {}: {}",
625                src_path.display(),
626                dest_path.display(),
627                String::from_utf8_lossy(&output.stderr)
628            ))
629        }
630    }
631
632    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
633        let mut command = process::Command::new("ssh");
634        self.ssh_options(&mut command)
635            .arg("-p")
636            .arg(&self.port.to_string())
637            .arg(&self.url)
638            .arg(program);
639        command
640    }
641
642    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
643        command
644            .stdin(Stdio::piped())
645            .stdout(Stdio::piped())
646            .stderr(Stdio::piped())
647            .args(["-o", "ControlMaster=no", "-o"])
648            .arg(format!("ControlPath={}", self.socket_path.display()))
649    }
650}
651
652async fn run_cmd(command: &mut process::Command) -> Result<String> {
653    let output = command.output().await?;
654    if output.status.success() {
655        Ok(String::from_utf8_lossy(&output.stdout).to_string())
656    } else {
657        Err(anyhow!(
658            "failed to run command: {}",
659            String::from_utf8_lossy(&output.stderr)
660        ))
661    }
662}