ssh_session.rs

  1use crate::protocol::{
  2    message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
  3};
  4use anyhow::{anyhow, Context as _, Result};
  5use collections::HashMap;
  6use futures::{
  7    channel::{mpsc, oneshot},
  8    future::{BoxFuture, LocalBoxFuture},
  9    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
 10};
 11use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
 12use parking_lot::Mutex;
 13use rpc::{
 14    proto::{
 15        self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
 16        ProtoClient, RequestMessage,
 17    },
 18    TypedEnvelope,
 19};
 20use smol::{
 21    fs,
 22    process::{self, Stdio},
 23};
 24use std::{
 25    any::TypeId,
 26    ffi::OsStr,
 27    path::{Path, PathBuf},
 28    sync::{
 29        atomic::{AtomicU32, Ordering::SeqCst},
 30        Arc,
 31    },
 32    time::Instant,
 33};
 34use tempfile::TempDir;
 35
 36pub struct SshSession {
 37    next_message_id: AtomicU32,
 38    response_channels: ResponseChannels,
 39    outgoing_tx: mpsc::UnboundedSender<Envelope>,
 40    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
 41    message_handlers: Mutex<
 42        HashMap<
 43            TypeId,
 44            Arc<
 45                dyn Send
 46                    + Sync
 47                    + Fn(
 48                        Box<dyn AnyTypedEnvelope>,
 49                        Arc<SshSession>,
 50                        AsyncAppContext,
 51                    ) -> Option<LocalBoxFuture<'static, Result<()>>>,
 52            >,
 53        >,
 54    >,
 55}
 56
 57struct SshClientState {
 58    socket_path: PathBuf,
 59    port: u16,
 60    url: String,
 61    _master_process: process::Child,
 62    _temp_dir: TempDir,
 63}
 64
 65struct SpawnRequest {
 66    command: String,
 67    process_tx: oneshot::Sender<process::Child>,
 68}
 69
 /// Operating system and CPU architecture of a remote host, used to pick the
/// matching server binary.
///
/// As detected by the ssh client, `os` is `"macos"` or `"linux"` and `arch` is
/// `"aarch64"` or `"x86_64"`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
 75
 76pub trait SshClientDelegate {
 77    fn ask_password(
 78        &self,
 79        prompt: String,
 80        cx: &mut AsyncAppContext,
 81    ) -> oneshot::Receiver<Result<String>>;
 82    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 83    fn get_server_binary(
 84        &self,
 85        platform: SshPlatform,
 86        cx: &mut AsyncAppContext,
 87    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 88    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 89}
 90
 91type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 92
 93impl SshSession {
 94    pub async fn client(
 95        user: String,
 96        host: String,
 97        port: u16,
 98        delegate: Arc<dyn SshClientDelegate>,
 99        cx: &mut AsyncAppContext,
100    ) -> Result<Arc<Self>> {
101        let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;
102
103        let platform = client_state.query_platform().await?;
104        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
105        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
106        client_state
107            .ensure_server_binary(
108                &delegate,
109                &local_binary_path,
110                &remote_binary_path,
111                version,
112                cx,
113            )
114            .await?;
115
116        let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
117        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
118        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
119
120        let mut remote_server_child = client_state
121            .ssh_command(&remote_binary_path)
122            .arg("run")
123            .spawn()
124            .context("failed to spawn remote server")?;
125        let mut child_stderr = remote_server_child.stderr.take().unwrap();
126        let mut child_stdout = remote_server_child.stdout.take().unwrap();
127        let mut child_stdin = remote_server_child.stdin.take().unwrap();
128
129        let executor = cx.background_executor().clone();
130        executor.clone().spawn(async move {
131            let mut stdin_buffer = Vec::new();
132            let mut stdout_buffer = Vec::new();
133            let mut stderr_buffer = Vec::new();
134            let mut stderr_offset = 0;
135
136            loop {
137                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
138                stderr_buffer.resize(stderr_offset + 1024, 0);
139
140                select_biased! {
141                    outgoing = outgoing_rx.next().fuse() => {
142                        let Some(outgoing) = outgoing else {
143                            return anyhow::Ok(());
144                        };
145
146                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
147                    }
148
149                    request = spawn_process_rx.next().fuse() => {
150                        let Some(request) = request else {
151                            return Ok(());
152                        };
153
154                        log::info!("spawn process: {:?}", request.command);
155                        let child = client_state
156                            .ssh_command(&request.command)
157                            .spawn()
158                            .context("failed to create channel")?;
159                        request.process_tx.send(child).ok();
160                    }
161
162                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
163                        match result {
164                            Ok(len) => {
165                                if len == 0 {
166                                    child_stdin.close().await?;
167                                    let status = remote_server_child.status().await?;
168                                    if !status.success() {
169                                        log::info!("channel exited with status: {status:?}");
170                                    }
171                                    return Ok(());
172                                }
173
174                                if len < stdout_buffer.len() {
175                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
176                                }
177
178                                let message_len = message_len_from_buffer(&stdout_buffer);
179                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
180                                    Ok(envelope) => {
181                                        incoming_tx.unbounded_send(envelope).ok();
182                                    }
183                                    Err(error) => {
184                                        log::error!("error decoding message {error:?}");
185                                    }
186                                }
187                            }
188                            Err(error) => {
189                                Err(anyhow!("error reading stdout: {error:?}"))?;
190                            }
191                        }
192                    }
193
194                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
195                        match result {
196                            Ok(len) => {
197                                stderr_offset += len;
198                                let mut start_ix = 0;
199                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
200                                    let line_ix = start_ix + ix;
201                                    let content = String::from_utf8_lossy(&stderr_buffer[start_ix..line_ix]);
202                                    start_ix = line_ix + 1;
203                                    eprintln!("(remote) {}", content);
204                                }
205                                stderr_buffer.drain(0..start_ix);
206                                stderr_offset -= start_ix;
207                            }
208                            Err(error) => {
209                                Err(anyhow!("error reading stderr: {error:?}"))?;
210                            }
211                        }
212                    }
213                }
214            }
215        }).detach();
216
217        cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
218    }
219
220    pub fn server(
221        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
222        outgoing_tx: mpsc::UnboundedSender<Envelope>,
223        cx: &AppContext,
224    ) -> Arc<SshSession> {
225        let (tx, _rx) = mpsc::unbounded();
226        Self::new(incoming_rx, outgoing_tx, tx, cx)
227    }
228
229    #[cfg(any(test, feature = "test-support"))]
230    pub fn fake(
231        client_cx: &mut gpui::TestAppContext,
232        server_cx: &mut gpui::TestAppContext,
233    ) -> (Arc<Self>, Arc<Self>) {
234        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
235        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
236        let (tx, _rx) = mpsc::unbounded();
237        (
238            client_cx
239                .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
240            server_cx
241                .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
242        )
243    }
244
245    fn new(
246        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
247        outgoing_tx: mpsc::UnboundedSender<Envelope>,
248        spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
249        cx: &AppContext,
250    ) -> Arc<SshSession> {
251        let this = Arc::new(Self {
252            next_message_id: AtomicU32::new(0),
253            response_channels: ResponseChannels::default(),
254            outgoing_tx,
255            spawn_process_tx,
256            message_handlers: Default::default(),
257        });
258
259        cx.spawn(|cx| {
260            let this = this.clone();
261            async move {
262                let peer_id = PeerId { owner_id: 0, id: 0 };
263                while let Some(incoming) = incoming_rx.next().await {
264                    if let Some(request_id) = incoming.responding_to {
265                        let request_id = MessageId(request_id);
266                        let sender = this.response_channels.lock().remove(&request_id);
267                        if let Some(sender) = sender {
268                            let (tx, rx) = oneshot::channel();
269                            if incoming.payload.is_some() {
270                                sender.send((incoming, tx)).ok();
271                            }
272                            rx.await.ok();
273                        }
274                    } else if let Some(envelope) =
275                        build_typed_envelope(peer_id, Instant::now(), incoming)
276                    {
277                        log::debug!(
278                            "ssh message received. name:{}",
279                            envelope.payload_type_name()
280                        );
281                        let type_id = envelope.payload_type_id();
282                        let handler = this.message_handlers.lock().get(&type_id).cloned();
283                        if let Some(handler) = handler {
284                            if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
285                                future.await.ok();
286                            } else {
287                                this.message_handlers.lock().remove(&type_id);
288                            }
289                        }
290                    }
291                }
292                anyhow::Ok(())
293            }
294        })
295        .detach();
296
297        this
298    }
299
300    pub fn request<T: RequestMessage>(
301        &self,
302        payload: T,
303    ) -> impl 'static + Future<Output = Result<T::Response>> {
304        log::debug!("ssh request start. name:{}", T::NAME);
305        let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
306        async move {
307            let response = response.await?;
308            log::debug!("ssh request finish. name:{}", T::NAME);
309            T::Response::from_envelope(response)
310                .ok_or_else(|| anyhow!("received a response of the wrong type"))
311        }
312    }
313
314    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
315        self.send_dynamic(payload.into_envelope(0, None, None))
316    }
317
318    pub fn request_dynamic(
319        &self,
320        mut envelope: proto::Envelope,
321        _request_type: &'static str,
322    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
323        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
324        let (tx, rx) = oneshot::channel();
325        self.response_channels
326            .lock()
327            .insert(MessageId(envelope.id), tx);
328        self.outgoing_tx.unbounded_send(envelope).ok();
329        async move { Ok(rx.await.context("connection lost")?.0) }
330    }
331
332    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
333        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
334        self.outgoing_tx.unbounded_send(envelope)?;
335        Ok(())
336    }
337
338    pub async fn spawn_process(&self, command: String) -> process::Child {
339        let (process_tx, process_rx) = oneshot::channel();
340        self.spawn_process_tx
341            .unbounded_send(SpawnRequest {
342                command,
343                process_tx,
344            })
345            .ok();
346        process_rx.await.unwrap()
347    }
348
349    pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
350    where
351        M: EnvelopedMessage,
352        E: 'static,
353        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
354        F: 'static + Future<Output = Result<()>>,
355    {
356        let message_type_id = TypeId::of::<M>();
357        self.message_handlers.lock().insert(
358            message_type_id,
359            Arc::new(move |envelope, _, cx| {
360                let entity = entity.upgrade()?;
361                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
362                Some(handler(entity, *envelope, cx).boxed_local())
363            }),
364        );
365    }
366
367    pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
368    where
369        M: EnvelopedMessage + RequestMessage,
370        E: 'static,
371        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
372        F: 'static + Future<Output = Result<M::Response>>,
373    {
374        let message_type_id = TypeId::of::<M>();
375        self.message_handlers.lock().insert(
376            message_type_id,
377            Arc::new(move |envelope, this, cx| {
378                let entity = entity.upgrade()?;
379                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
380                let request_id = envelope.message_id();
381                Some(
382                    handler(entity, *envelope, cx)
383                        .then(move |result| async move {
384                            this.outgoing_tx.unbounded_send(result?.into_envelope(
385                                this.next_message_id.fetch_add(1, SeqCst),
386                                Some(request_id),
387                                None,
388                            ))?;
389                            Ok(())
390                        })
391                        .boxed_local(),
392                )
393            }),
394        );
395    }
396}
397
398impl ProtoClient for SshSession {
399    fn request(
400        &self,
401        envelope: proto::Envelope,
402        request_type: &'static str,
403    ) -> BoxFuture<'static, Result<proto::Envelope>> {
404        self.request_dynamic(envelope, request_type).boxed()
405    }
406
407    fn send(&self, envelope: proto::Envelope) -> Result<()> {
408        self.send_dynamic(envelope)
409    }
410}
411
412impl SshClientState {
413    #[cfg(not(unix))]
414    async fn new(
415        user: String,
416        host: String,
417        port: u16,
418        delegate: Arc<dyn SshClientDelegate>,
419        cx: &mut AsyncAppContext,
420    ) -> Result<Self> {
421        Err(anyhow!("ssh is not supported on this platform"))
422    }
423
424    #[cfg(unix)]
425    async fn new(
426        user: String,
427        host: String,
428        port: u16,
429        delegate: Arc<dyn SshClientDelegate>,
430        cx: &mut AsyncAppContext,
431    ) -> Result<Self> {
432        use futures::{io::BufReader, AsyncBufReadExt as _};
433        use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
434        use util::ResultExt as _;
435
436        delegate.set_status(Some("connecting"), cx);
437
438        let url = format!("{user}@{host}");
439        let temp_dir = tempfile::Builder::new()
440            .prefix("zed-ssh-session")
441            .tempdir()?;
442
443        // Create a domain socket listener to handle requests from the askpass program.
444        let askpass_socket = temp_dir.path().join("askpass.sock");
445        let listener =
446            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
447
448        let askpass_task = cx.spawn(|mut cx| async move {
449            while let Ok((mut stream, _)) = listener.accept().await {
450                let mut buffer = Vec::new();
451                let mut reader = BufReader::new(&mut stream);
452                if reader.read_until(b'\0', &mut buffer).await.is_err() {
453                    buffer.clear();
454                }
455                let password_prompt = String::from_utf8_lossy(&buffer);
456                if let Some(password) = delegate
457                    .ask_password(password_prompt.to_string(), &mut cx)
458                    .await
459                    .context("failed to get ssh password")
460                    .and_then(|p| p)
461                    .log_err()
462                {
463                    stream.write_all(password.as_bytes()).await.log_err();
464                }
465            }
466        });
467
468        // Create an askpass script that communicates back to this process.
469        let askpass_script = format!(
470            "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
471            askpass_socket = askpass_socket.display(),
472            print_args = "printf '%s\\0' \"$@\"",
473            shebang = "#!/bin/sh",
474        );
475        let askpass_script_path = temp_dir.path().join("askpass.sh");
476        fs::write(&askpass_script_path, askpass_script).await?;
477        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
478
479        // Start the master SSH process, which does not do anything except for establish
480        // the connection and keep it open, allowing other ssh commands to reuse it
481        // via a control socket.
482        let socket_path = temp_dir.path().join("ssh.sock");
483        let mut master_process = process::Command::new("ssh")
484            .stdin(Stdio::null())
485            .stdout(Stdio::piped())
486            .stderr(Stdio::piped())
487            .env("SSH_ASKPASS_REQUIRE", "force")
488            .env("SSH_ASKPASS", &askpass_script_path)
489            .args(["-N", "-o", "ControlMaster=yes", "-o"])
490            .arg(format!("ControlPath={}", socket_path.display()))
491            .args(["-p", &port.to_string()])
492            .arg(&url)
493            .spawn()?;
494
495        // Wait for this ssh process to close its stdout, indicating that authentication
496        // has completed.
497        let stdout = master_process.stdout.as_mut().unwrap();
498        let mut output = Vec::new();
499        stdout.read_to_end(&mut output).await?;
500        drop(askpass_task);
501
502        if master_process.try_status()?.is_some() {
503            output.clear();
504            let mut stderr = master_process.stderr.take().unwrap();
505            stderr.read_to_end(&mut output).await?;
506            Err(anyhow!(
507                "failed to connect: {}",
508                String::from_utf8_lossy(&output)
509            ))?;
510        }
511
512        Ok(Self {
513            url,
514            port,
515            socket_path,
516            _master_process: master_process,
517            _temp_dir: temp_dir,
518        })
519    }
520
521    async fn ensure_server_binary(
522        &self,
523        delegate: &Arc<dyn SshClientDelegate>,
524        src_path: &Path,
525        dst_path: &Path,
526        version: SemanticVersion,
527        cx: &mut AsyncAppContext,
528    ) -> Result<()> {
529        let mut dst_path_gz = dst_path.to_path_buf();
530        dst_path_gz.set_extension("gz");
531
532        if let Some(parent) = dst_path.parent() {
533            run_cmd(self.ssh_command("mkdir").arg("-p").arg(parent)).await?;
534        }
535
536        let mut server_binary_exists = false;
537        if let Ok(installed_version) = run_cmd(self.ssh_command(&dst_path).arg("version")).await {
538            if installed_version.trim() == version.to_string() {
539                server_binary_exists = true;
540            }
541        }
542
543        if server_binary_exists {
544            log::info!("remote development server already present",);
545            return Ok(());
546        }
547
548        let src_stat = fs::metadata(src_path).await?;
549        let size = src_stat.len();
550        let server_mode = 0o755;
551
552        let t0 = Instant::now();
553        delegate.set_status(Some("uploading remote development server"), cx);
554        log::info!("uploading remote development server ({}kb)", size / 1024);
555        self.upload_file(src_path, &dst_path_gz)
556            .await
557            .context("failed to upload server binary")?;
558        log::info!("uploaded remote development server in {:?}", t0.elapsed());
559
560        delegate.set_status(Some("extracting remote development server"), cx);
561        run_cmd(self.ssh_command("gunzip").arg("--force").arg(&dst_path_gz)).await?;
562
563        delegate.set_status(Some("unzipping remote development server"), cx);
564        run_cmd(
565            self.ssh_command("chmod")
566                .arg(format!("{:o}", server_mode))
567                .arg(&dst_path),
568        )
569        .await?;
570
571        Ok(())
572    }
573
574    async fn query_platform(&self) -> Result<SshPlatform> {
575        let os = run_cmd(self.ssh_command("uname").arg("-s")).await?;
576        let arch = run_cmd(self.ssh_command("uname").arg("-m")).await?;
577
578        let os = match os.trim() {
579            "Darwin" => "macos",
580            "Linux" => "linux",
581            _ => Err(anyhow!("unknown uname os {os:?}"))?,
582        };
583        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
584            "aarch64"
585        } else if arch.starts_with("x86") || arch.starts_with("i686") {
586            "x86_64"
587        } else {
588            Err(anyhow!("unknown uname architecture {arch:?}"))?
589        };
590
591        Ok(SshPlatform { os, arch })
592    }
593
594    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
595        let mut command = process::Command::new("scp");
596        let output = self
597            .ssh_options(&mut command)
598            .arg("-P")
599            .arg(&self.port.to_string())
600            .arg(&src_path)
601            .arg(&format!("{}:{}", self.url, dest_path.display()))
602            .output()
603            .await?;
604
605        if output.status.success() {
606            Ok(())
607        } else {
608            Err(anyhow!(
609                "failed to upload file {} -> {}: {}",
610                src_path.display(),
611                dest_path.display(),
612                String::from_utf8_lossy(&output.stderr)
613            ))
614        }
615    }
616
617    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
618        let mut command = process::Command::new("ssh");
619        self.ssh_options(&mut command)
620            .arg("-p")
621            .arg(&self.port.to_string())
622            .arg(&self.url)
623            .arg(program);
624        command
625    }
626
627    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
628        command
629            .stdin(Stdio::piped())
630            .stdout(Stdio::piped())
631            .stderr(Stdio::piped())
632            .args(["-o", "ControlMaster=no", "-o"])
633            .arg(format!("ControlPath={}", self.socket_path.display()))
634    }
635}
636
637async fn run_cmd(command: &mut process::Command) -> Result<String> {
638    let output = command.output().await?;
639    if output.status.success() {
640        Ok(String::from_utf8_lossy(&output.stdout).to_string())
641    } else {
642        Err(anyhow!(
643            "failed to run command: {}",
644            String::from_utf8_lossy(&output.stderr)
645        ))
646    }
647}