ssh_session.rs

  1use crate::protocol::{
  2    message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
  3};
  4use anyhow::{anyhow, Context as _, Result};
  5use collections::HashMap;
  6use futures::{
  7    channel::{mpsc, oneshot},
  8    future::{BoxFuture, LocalBoxFuture},
  9    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
 10};
 11use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
 12use parking_lot::Mutex;
 13use rpc::{
 14    proto::{
 15        self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
 16        ProtoClient, RequestMessage,
 17    },
 18    TypedEnvelope,
 19};
 20use smol::{
 21    fs,
 22    process::{self, Stdio},
 23};
 24use std::{
 25    any::TypeId,
 26    ffi::OsStr,
 27    path::{Path, PathBuf},
 28    sync::{
 29        atomic::{AtomicU32, Ordering::SeqCst},
 30        Arc,
 31    },
 32    time::Instant,
 33};
 34use tempfile::TempDir;
 35
 36pub struct SshSession {
 37    next_message_id: AtomicU32,
 38    response_channels: ResponseChannels,
 39    outgoing_tx: mpsc::UnboundedSender<Envelope>,
 40    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
 41    message_handlers: Mutex<
 42        HashMap<
 43            TypeId,
 44            Arc<
 45                dyn Send
 46                    + Sync
 47                    + Fn(
 48                        Box<dyn AnyTypedEnvelope>,
 49                        Arc<SshSession>,
 50                        AsyncAppContext,
 51                    ) -> Option<LocalBoxFuture<'static, Result<()>>>,
 52            >,
 53        >,
 54    >,
 55}
 56
 57struct SshClientState {
 58    socket_path: PathBuf,
 59    port: u16,
 60    url: String,
 61    _master_process: process::Child,
 62    _temp_dir: TempDir,
 63}
 64
 65struct SpawnRequest {
 66    command: String,
 67    process_tx: oneshot::Sender<process::Child>,
 68}
 69
 70#[derive(Copy, Clone, Debug)]
 71pub struct SshPlatform {
 72    pub os: &'static str,
 73    pub arch: &'static str,
 74}
 75
 76pub trait SshClientDelegate {
 77    fn ask_password(
 78        &self,
 79        prompt: String,
 80        cx: &mut AsyncAppContext,
 81    ) -> oneshot::Receiver<Result<String>>;
 82    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
 83    fn get_server_binary(
 84        &self,
 85        platform: SshPlatform,
 86        cx: &mut AsyncAppContext,
 87    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
 88}
 89
 90type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 91
 92impl SshSession {
 93    pub async fn client(
 94        user: String,
 95        host: String,
 96        port: u16,
 97        delegate: Arc<dyn SshClientDelegate>,
 98        cx: &mut AsyncAppContext,
 99    ) -> Result<Arc<Self>> {
100        let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;
101
102        let platform = query_platform(&client_state).await?;
103        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
104        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
105        ensure_server_binary(
106            &client_state,
107            &local_binary_path,
108            &remote_binary_path,
109            version,
110        )
111        .await?;
112
113        let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
114        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
115        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
116
117        let mut remote_server_child = client_state
118            .ssh_command(&remote_binary_path)
119            .arg("run")
120            .spawn()
121            .context("failed to spawn remote server")?;
122        let mut child_stderr = remote_server_child.stderr.take().unwrap();
123        let mut child_stdout = remote_server_child.stdout.take().unwrap();
124        let mut child_stdin = remote_server_child.stdin.take().unwrap();
125
126        let executor = cx.background_executor().clone();
127        executor.clone().spawn(async move {
128            let mut stdin_buffer = Vec::new();
129            let mut stdout_buffer = Vec::new();
130            let mut stderr_buffer = Vec::new();
131            let mut stderr_offset = 0;
132
133            loop {
134                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
135                stderr_buffer.resize(stderr_offset + 1024, 0);
136
137                select_biased! {
138                    outgoing = outgoing_rx.next().fuse() => {
139                        let Some(outgoing) = outgoing else {
140                            return anyhow::Ok(());
141                        };
142
143                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
144                    }
145
146                    request = spawn_process_rx.next().fuse() => {
147                        let Some(request) = request else {
148                            return Ok(());
149                        };
150
151                        log::info!("spawn process: {:?}", request.command);
152                        let child = client_state
153                            .ssh_command(&request.command)
154                            .spawn()
155                            .context("failed to create channel")?;
156                        request.process_tx.send(child).ok();
157                    }
158
159                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
160                        match result {
161                            Ok(len) => {
162                                if len == 0 {
163                                    child_stdin.close().await?;
164                                    let status = remote_server_child.status().await?;
165                                    if !status.success() {
166                                        log::info!("channel exited with status: {status:?}");
167                                    }
168                                    return Ok(());
169                                }
170
171                                if len < stdout_buffer.len() {
172                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
173                                }
174
175                                let message_len = message_len_from_buffer(&stdout_buffer);
176                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
177                                    Ok(envelope) => {
178                                        incoming_tx.unbounded_send(envelope).ok();
179                                    }
180                                    Err(error) => {
181                                        log::error!("error decoding message {error:?}");
182                                    }
183                                }
184                            }
185                            Err(error) => {
186                                Err(anyhow!("error reading stdout: {error:?}"))?;
187                            }
188                        }
189                    }
190
191                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
192                        match result {
193                            Ok(len) => {
194                                stderr_offset += len;
195                                let mut start_ix = 0;
196                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
197                                    let line_ix = start_ix + ix;
198                                    let content = String::from_utf8_lossy(&stderr_buffer[start_ix..line_ix]);
199                                    start_ix = line_ix + 1;
200                                    eprintln!("(remote) {}", content);
201                                }
202                                stderr_buffer.drain(0..start_ix);
203                                stderr_offset -= start_ix;
204                            }
205                            Err(error) => {
206                                Err(anyhow!("error reading stderr: {error:?}"))?;
207                            }
208                        }
209                    }
210                }
211            }
212        }).detach();
213
214        cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
215    }
216
217    pub fn server(
218        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
219        outgoing_tx: mpsc::UnboundedSender<Envelope>,
220        cx: &AppContext,
221    ) -> Arc<SshSession> {
222        let (tx, _rx) = mpsc::unbounded();
223        Self::new(incoming_rx, outgoing_tx, tx, cx)
224    }
225
226    #[cfg(any(test, feature = "test-support"))]
227    pub fn fake(
228        client_cx: &mut gpui::TestAppContext,
229        server_cx: &mut gpui::TestAppContext,
230    ) -> (Arc<Self>, Arc<Self>) {
231        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
232        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
233        let (tx, _rx) = mpsc::unbounded();
234        (
235            client_cx
236                .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
237            server_cx
238                .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
239        )
240    }
241
242    fn new(
243        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
244        outgoing_tx: mpsc::UnboundedSender<Envelope>,
245        spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
246        cx: &AppContext,
247    ) -> Arc<SshSession> {
248        let this = Arc::new(Self {
249            next_message_id: AtomicU32::new(0),
250            response_channels: ResponseChannels::default(),
251            outgoing_tx,
252            spawn_process_tx,
253            message_handlers: Default::default(),
254        });
255
256        cx.spawn(|cx| {
257            let this = this.clone();
258            async move {
259                let peer_id = PeerId { owner_id: 0, id: 0 };
260                while let Some(incoming) = incoming_rx.next().await {
261                    if let Some(request_id) = incoming.responding_to {
262                        let request_id = MessageId(request_id);
263                        let sender = this.response_channels.lock().remove(&request_id);
264                        if let Some(sender) = sender {
265                            let (tx, rx) = oneshot::channel();
266                            if incoming.payload.is_some() {
267                                sender.send((incoming, tx)).ok();
268                            }
269                            rx.await.ok();
270                        }
271                    } else if let Some(envelope) =
272                        build_typed_envelope(peer_id, Instant::now(), incoming)
273                    {
274                        log::debug!(
275                            "ssh message received. name:{}",
276                            envelope.payload_type_name()
277                        );
278                        let type_id = envelope.payload_type_id();
279                        let handler = this.message_handlers.lock().get(&type_id).cloned();
280                        if let Some(handler) = handler {
281                            if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
282                                future.await.ok();
283                            } else {
284                                this.message_handlers.lock().remove(&type_id);
285                            }
286                        }
287                    }
288                }
289                anyhow::Ok(())
290            }
291        })
292        .detach();
293
294        this
295    }
296
297    pub fn request<T: RequestMessage>(
298        &self,
299        payload: T,
300    ) -> impl 'static + Future<Output = Result<T::Response>> {
301        log::debug!("ssh request start. name:{}", T::NAME);
302        let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
303        async move {
304            let response = response.await?;
305            log::debug!("ssh request finish. name:{}", T::NAME);
306            T::Response::from_envelope(response)
307                .ok_or_else(|| anyhow!("received a response of the wrong type"))
308        }
309    }
310
311    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
312        self.send_dynamic(payload.into_envelope(0, None, None))
313    }
314
315    pub fn request_dynamic(
316        &self,
317        mut envelope: proto::Envelope,
318        _request_type: &'static str,
319    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
320        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
321        let (tx, rx) = oneshot::channel();
322        self.response_channels
323            .lock()
324            .insert(MessageId(envelope.id), tx);
325        self.outgoing_tx.unbounded_send(envelope).ok();
326        async move { Ok(rx.await.context("connection lost")?.0) }
327    }
328
329    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
330        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
331        self.outgoing_tx.unbounded_send(envelope)?;
332        Ok(())
333    }
334
335    pub async fn spawn_process(&self, command: String) -> process::Child {
336        let (process_tx, process_rx) = oneshot::channel();
337        self.spawn_process_tx
338            .unbounded_send(SpawnRequest {
339                command,
340                process_tx,
341            })
342            .ok();
343        process_rx.await.unwrap()
344    }
345
346    pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
347    where
348        M: EnvelopedMessage,
349        E: 'static,
350        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
351        F: 'static + Future<Output = Result<()>>,
352    {
353        let message_type_id = TypeId::of::<M>();
354        self.message_handlers.lock().insert(
355            message_type_id,
356            Arc::new(move |envelope, _, cx| {
357                let entity = entity.upgrade()?;
358                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
359                Some(handler(entity, *envelope, cx).boxed_local())
360            }),
361        );
362    }
363
364    pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
365    where
366        M: EnvelopedMessage + RequestMessage,
367        E: 'static,
368        H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
369        F: 'static + Future<Output = Result<M::Response>>,
370    {
371        let message_type_id = TypeId::of::<M>();
372        self.message_handlers.lock().insert(
373            message_type_id,
374            Arc::new(move |envelope, this, cx| {
375                let entity = entity.upgrade()?;
376                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
377                let request_id = envelope.message_id();
378                Some(
379                    handler(entity, *envelope, cx)
380                        .then(move |result| async move {
381                            this.outgoing_tx.unbounded_send(result?.into_envelope(
382                                this.next_message_id.fetch_add(1, SeqCst),
383                                Some(request_id),
384                                None,
385                            ))?;
386                            Ok(())
387                        })
388                        .boxed_local(),
389                )
390            }),
391        );
392    }
393}
394
395impl ProtoClient for SshSession {
396    fn request(
397        &self,
398        envelope: proto::Envelope,
399        request_type: &'static str,
400    ) -> BoxFuture<'static, Result<proto::Envelope>> {
401        self.request_dynamic(envelope, request_type).boxed()
402    }
403
404    fn send(&self, envelope: proto::Envelope) -> Result<()> {
405        self.send_dynamic(envelope)
406    }
407}
408
409impl SshClientState {
410    #[cfg(not(unix))]
411    async fn new(
412        user: String,
413        host: String,
414        port: u16,
415        delegate: Arc<dyn SshClientDelegate>,
416        cx: &AsyncAppContext,
417    ) -> Result<Self> {
418        Err(anyhow!("ssh is not supported on this platform"))
419    }
420
421    #[cfg(unix)]
422    async fn new(
423        user: String,
424        host: String,
425        port: u16,
426        delegate: Arc<dyn SshClientDelegate>,
427        cx: &AsyncAppContext,
428    ) -> Result<Self> {
429        use smol::fs::unix::PermissionsExt as _;
430        use util::ResultExt as _;
431
432        let url = format!("{user}@{host}");
433        let temp_dir = tempfile::Builder::new()
434            .prefix("zed-ssh-session")
435            .tempdir()?;
436
437        // Create a TCP listener to handle requests from the askpass program.
438        let listener = smol::net::TcpListener::bind("127.0.0.1:0")
439            .await
440            .expect("failed to find open port");
441        let askpass_port = listener.local_addr().unwrap().port();
442        let askpass_task = cx.spawn(|mut cx| async move {
443            while let Ok((mut stream, _)) = listener.accept().await {
444                let mut buffer = Vec::new();
445                if stream.read_to_end(&mut buffer).await.is_err() {
446                    buffer.clear();
447                }
448                let password_prompt = String::from_utf8_lossy(&buffer);
449                if let Some(password) = delegate
450                    .ask_password(password_prompt.to_string(), &mut cx)
451                    .await
452                    .context("failed to get ssh password")
453                    .and_then(|p| p)
454                    .log_err()
455                {
456                    stream.write_all(password.as_bytes()).await.log_err();
457                }
458            }
459        });
460
461        // Create an askpass script that communicates back to this process using TCP.
462        let askpass_script = format!(
463            "{shebang}\n echo \"$@\" | nc 127.0.0.1 {askpass_port} 2> /dev/null",
464            shebang = "#!/bin/sh"
465        );
466        let askpass_script_path = temp_dir.path().join("askpass.sh");
467        fs::write(&askpass_script_path, askpass_script).await?;
468        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
469
470        // Start the master SSH process, which does not do anything except for establish
471        // the connection and keep it open, allowing other ssh commands to reuse it
472        // via a control socket.
473        let socket_path = temp_dir.path().join("ssh.sock");
474        let mut master_process = process::Command::new("ssh")
475            .stdin(Stdio::null())
476            .stdout(Stdio::piped())
477            .stderr(Stdio::piped())
478            .env("SSH_ASKPASS_REQUIRE", "force")
479            .env("SSH_ASKPASS", &askpass_script_path)
480            .args(["-N", "-o", "ControlMaster=yes", "-o"])
481            .arg(format!("ControlPath={}", socket_path.display()))
482            .args(["-p", &port.to_string()])
483            .arg(&url)
484            .spawn()?;
485
486        // Wait for this ssh process to close its stdout, indicating that authentication
487        // has completed.
488        let stdout = master_process.stdout.as_mut().unwrap();
489        let mut output = Vec::new();
490        stdout.read_to_end(&mut output).await?;
491        drop(askpass_task);
492
493        if master_process.try_status()?.is_some() {
494            output.clear();
495            let mut stderr = master_process.stderr.take().unwrap();
496            stderr.read_to_end(&mut output).await?;
497            Err(anyhow!(
498                "failed to connect: {}",
499                String::from_utf8_lossy(&output)
500            ))?;
501        }
502
503        Ok(Self {
504            _master_process: master_process,
505            port,
506            _temp_dir: temp_dir,
507            socket_path,
508            url,
509        })
510    }
511
512    async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
513        let mut command = process::Command::new("scp");
514        let output = self
515            .ssh_options(&mut command)
516            .arg("-P")
517            .arg(&self.port.to_string())
518            .arg(&src_path)
519            .arg(&format!("{}:{}", self.url, dest_path.display()))
520            .output()
521            .await?;
522
523        if output.status.success() {
524            Ok(())
525        } else {
526            Err(anyhow!(
527                "failed to upload file {} -> {}: {}",
528                src_path.display(),
529                dest_path.display(),
530                String::from_utf8_lossy(&output.stderr)
531            ))
532        }
533    }
534
535    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
536        let mut command = process::Command::new("ssh");
537        self.ssh_options(&mut command)
538            .arg("-p")
539            .arg(&self.port.to_string())
540            .arg(&self.url)
541            .arg(program);
542        command
543    }
544
545    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
546        command
547            .stdin(Stdio::piped())
548            .stdout(Stdio::piped())
549            .stderr(Stdio::piped())
550            .args(["-o", "ControlMaster=no", "-o"])
551            .arg(format!("ControlPath={}", self.socket_path.display()))
552    }
553}
554
555async fn run_cmd(command: &mut process::Command) -> Result<String> {
556    let output = command.output().await?;
557    if output.status.success() {
558        Ok(String::from_utf8_lossy(&output.stdout).to_string())
559    } else {
560        Err(anyhow!(
561            "failed to run command: {}",
562            String::from_utf8_lossy(&output.stderr)
563        ))
564    }
565}
566
567async fn query_platform(session: &SshClientState) -> Result<SshPlatform> {
568    let os = run_cmd(session.ssh_command("uname").arg("-s")).await?;
569    let arch = run_cmd(session.ssh_command("uname").arg("-m")).await?;
570
571    let os = match os.trim() {
572        "Darwin" => "macos",
573        "Linux" => "linux",
574        _ => Err(anyhow!("unknown uname os {os:?}"))?,
575    };
576    let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
577        "aarch64"
578    } else if arch.starts_with("x86") || arch.starts_with("i686") {
579        "x86_64"
580    } else {
581        Err(anyhow!("unknown uname architecture {arch:?}"))?
582    };
583
584    Ok(SshPlatform { os, arch })
585}
586
587async fn ensure_server_binary(
588    session: &SshClientState,
589    src_path: &Path,
590    dst_path: &Path,
591    version: SemanticVersion,
592) -> Result<()> {
593    let mut dst_path_gz = dst_path.to_path_buf();
594    dst_path_gz.set_extension("gz");
595
596    if let Some(parent) = dst_path.parent() {
597        run_cmd(session.ssh_command("mkdir").arg("-p").arg(parent)).await?;
598    }
599
600    let mut server_binary_exists = false;
601    if let Ok(installed_version) = run_cmd(session.ssh_command(&dst_path).arg("version")).await {
602        if installed_version.trim() == version.to_string() {
603            server_binary_exists = true;
604        }
605    }
606
607    if server_binary_exists {
608        log::info!("remote development server already present",);
609        return Ok(());
610    }
611
612    let src_stat = fs::metadata(src_path).await?;
613    let size = src_stat.len();
614    let server_mode = 0o755;
615
616    let t0 = Instant::now();
617    log::info!("uploading remote development server ({}kb)", size / 1024);
618    session
619        .upload_file(src_path, &dst_path_gz)
620        .await
621        .context("failed to upload server binary")?;
622    log::info!("uploaded remote development server in {:?}", t0.elapsed());
623
624    log::info!("extracting remote development server");
625    run_cmd(
626        session
627            .ssh_command("gunzip")
628            .arg("--force")
629            .arg(&dst_path_gz),
630    )
631    .await?;
632
633    log::info!("unzipping remote development server");
634    run_cmd(
635        session
636            .ssh_command("chmod")
637            .arg(format!("{:o}", server_mode))
638            .arg(&dst_path),
639    )
640    .await?;
641
642    Ok(())
643}