1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 future::{BoxFuture, LocalBoxFuture},
12 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
13};
14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
15use parking_lot::Mutex;
16use rpc::{
17 proto::{
18 self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
19 ProtoClient, RequestMessage,
20 },
21 TypedEnvelope,
22};
23use smol::{
24 fs,
25 process::{self, Stdio},
26};
27use std::{
28 any::TypeId,
29 ffi::OsStr,
30 path::{Path, PathBuf},
31 sync::{
32 atomic::{AtomicU32, Ordering::SeqCst},
33 Arc,
34 },
35 time::Instant,
36};
37use tempfile::TempDir;
38
39pub struct SshSession {
40 next_message_id: AtomicU32,
41 response_channels: ResponseChannels,
42 outgoing_tx: mpsc::UnboundedSender<Envelope>,
43 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
44 message_handlers: Mutex<
45 HashMap<
46 TypeId,
47 Arc<
48 dyn Send
49 + Sync
50 + Fn(
51 Box<dyn AnyTypedEnvelope>,
52 Arc<SshSession>,
53 AsyncAppContext,
54 ) -> Option<LocalBoxFuture<'static, Result<()>>>,
55 >,
56 >,
57 >,
58}
59
60struct SshClientState {
61 connection_options: SshConnectionOptions,
62 socket_path: PathBuf,
63 _master_process: process::Child,
64 _temp_dir: TempDir,
65}
66
/// Describes how to reach a remote host over ssh.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Host name or IP address. A bare IPv6 literal (e.g. `::1`) is bracketed
    /// automatically when building `ssh`/`scp` destinations.
    pub host: String,
    /// Remote user; omitted from destinations when `None`.
    pub username: Option<String>,
    /// Remote port; omitted from destinations when `None`.
    pub port: Option<u16>,
    /// Optional password; not used when building command-line destinations.
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// Returns an `ssh://[user@]host[:port]` URL suitable as the destination
    /// argument of `ssh`. IPv6 literals are written as `[addr]` so a trailing
    /// `:port` is not read as part of the address.
    pub fn ssh_url(&self) -> String {
        let mut result = String::from("ssh://");
        result.push_str(&self.user_at(&self.bracketed_host()));
        if let Some(port) = self.port {
            result.push(':');
            result.push_str(&port.to_string());
        }
        result
    }

    /// Returns `[user@]host`, the host part of an `scp` target
    /// (`<scp_url>:<path>`). IPv6 literals are bracketed, as scp requires.
    fn scp_url(&self) -> String {
        self.user_at(&self.bracketed_host())
    }

    /// Returns a human-readable `[user@]host[:port]` string.
    pub fn connection_string(&self) -> String {
        let host = self.user_at(&self.host);
        match self.port {
            Some(port) => format!("{host}:{port}"),
            None => host,
        }
    }

    /// Prefixes `host` with `username@` when a username is set.
    fn user_at(&self, host: &str) -> String {
        match &self.username {
            Some(username) => format!("{username}@{host}"),
            None => host.to_string(),
        }
    }

    /// Returns the host, wrapped in brackets if it is a bare IPv6 literal.
    /// Other hosts, including already-bracketed ones, are returned unchanged.
    fn bracketed_host(&self) -> String {
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        }
    }
}
111
112struct SpawnRequest {
113 command: String,
114 process_tx: oneshot::Sender<process::Child>,
115}
116
/// Remote operating system and CPU architecture, normalized to the names
/// used for server binaries.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// Normalized OS name, e.g. `"linux"` or `"macos"`.
    pub os: &'static str,
    /// Normalized architecture, e.g. `"x86_64"` or `"aarch64"`.
    pub arch: &'static str,
}
122
123pub trait SshClientDelegate {
124 fn ask_password(
125 &self,
126 prompt: String,
127 cx: &mut AsyncAppContext,
128 ) -> oneshot::Receiver<Result<String>>;
129 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
130 fn get_server_binary(
131 &self,
132 platform: SshPlatform,
133 cx: &mut AsyncAppContext,
134 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
135 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
136}
137
138type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
139
140impl SshSession {
141 pub async fn client(
142 connection_options: SshConnectionOptions,
143 delegate: Arc<dyn SshClientDelegate>,
144 cx: &mut AsyncAppContext,
145 ) -> Result<Arc<Self>> {
146 let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
147
148 let platform = client_state.query_platform().await?;
149 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
150 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
151 client_state
152 .ensure_server_binary(
153 &delegate,
154 &local_binary_path,
155 &remote_binary_path,
156 version,
157 cx,
158 )
159 .await?;
160
161 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
162 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
163 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
164
165 run_cmd(client_state.ssh_command(&remote_binary_path).arg("version")).await?;
166
167 let mut remote_server_child = client_state
168 .ssh_command(&format!(
169 "RUST_LOG={} {:?} run",
170 std::env::var("RUST_LOG").unwrap_or(String::new()),
171 remote_binary_path,
172 ))
173 .spawn()
174 .context("failed to spawn remote server")?;
175 let mut child_stderr = remote_server_child.stderr.take().unwrap();
176 let mut child_stdout = remote_server_child.stdout.take().unwrap();
177 let mut child_stdin = remote_server_child.stdin.take().unwrap();
178
179 let executor = cx.background_executor().clone();
180 executor.clone().spawn(async move {
181 let mut stdin_buffer = Vec::new();
182 let mut stdout_buffer = Vec::new();
183 let mut stderr_buffer = Vec::new();
184 let mut stderr_offset = 0;
185
186 loop {
187 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
188 stderr_buffer.resize(stderr_offset + 1024, 0);
189
190 select_biased! {
191 outgoing = outgoing_rx.next().fuse() => {
192 let Some(outgoing) = outgoing else {
193 return anyhow::Ok(());
194 };
195
196 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
197 }
198
199 request = spawn_process_rx.next().fuse() => {
200 let Some(request) = request else {
201 return Ok(());
202 };
203
204 log::info!("spawn process: {:?}", request.command);
205 let child = client_state
206 .ssh_command(&request.command)
207 .spawn()
208 .context("failed to create channel")?;
209 request.process_tx.send(child).ok();
210 }
211
212 result = child_stdout.read(&mut stdout_buffer).fuse() => {
213 match result {
214 Ok(len) => {
215 if len == 0 {
216 child_stdin.close().await?;
217 let status = remote_server_child.status().await?;
218 if !status.success() {
219 log::info!("channel exited with status: {status:?}");
220 }
221 return Ok(());
222 }
223
224 if len < stdout_buffer.len() {
225 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
226 }
227
228 let message_len = message_len_from_buffer(&stdout_buffer);
229 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
230 Ok(envelope) => {
231 incoming_tx.unbounded_send(envelope).ok();
232 }
233 Err(error) => {
234 log::error!("error decoding message {error:?}");
235 }
236 }
237 }
238 Err(error) => {
239 Err(anyhow!("error reading stdout: {error:?}"))?;
240 }
241 }
242 }
243
244 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
245 match result {
246 Ok(len) => {
247 stderr_offset += len;
248 let mut start_ix = 0;
249 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
250 let line_ix = start_ix + ix;
251 let content = &stderr_buffer[start_ix..line_ix];
252 start_ix = line_ix + 1;
253 if let Ok(record) = serde_json::from_slice::<LogRecord>(&content) {
254 record.log(log::logger())
255 } else {
256 eprintln!("(remote) {}", String::from_utf8_lossy(content));
257 }
258 }
259 stderr_buffer.drain(0..start_ix);
260 stderr_offset -= start_ix;
261 }
262 Err(error) => {
263 Err(anyhow!("error reading stderr: {error:?}"))?;
264 }
265 }
266 }
267 }
268 }
269 }).detach();
270
271 cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
272 }
273
274 pub fn server(
275 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
276 outgoing_tx: mpsc::UnboundedSender<Envelope>,
277 cx: &AppContext,
278 ) -> Arc<SshSession> {
279 let (tx, _rx) = mpsc::unbounded();
280 Self::new(incoming_rx, outgoing_tx, tx, cx)
281 }
282
283 #[cfg(any(test, feature = "test-support"))]
284 pub fn fake(
285 client_cx: &mut gpui::TestAppContext,
286 server_cx: &mut gpui::TestAppContext,
287 ) -> (Arc<Self>, Arc<Self>) {
288 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
289 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
290 let (tx, _rx) = mpsc::unbounded();
291 (
292 client_cx
293 .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
294 server_cx
295 .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
296 )
297 }
298
299 fn new(
300 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
301 outgoing_tx: mpsc::UnboundedSender<Envelope>,
302 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
303 cx: &AppContext,
304 ) -> Arc<SshSession> {
305 let this = Arc::new(Self {
306 next_message_id: AtomicU32::new(0),
307 response_channels: ResponseChannels::default(),
308 outgoing_tx,
309 spawn_process_tx,
310 message_handlers: Default::default(),
311 });
312
313 cx.spawn(|cx| {
314 let this = this.clone();
315 async move {
316 let peer_id = PeerId { owner_id: 0, id: 0 };
317 while let Some(incoming) = incoming_rx.next().await {
318 if let Some(request_id) = incoming.responding_to {
319 let request_id = MessageId(request_id);
320 let sender = this.response_channels.lock().remove(&request_id);
321 if let Some(sender) = sender {
322 let (tx, rx) = oneshot::channel();
323 if incoming.payload.is_some() {
324 sender.send((incoming, tx)).ok();
325 }
326 rx.await.ok();
327 }
328 } else if let Some(envelope) =
329 build_typed_envelope(peer_id, Instant::now(), incoming)
330 {
331 log::debug!(
332 "ssh message received. name:{}",
333 envelope.payload_type_name()
334 );
335 let type_id = envelope.payload_type_id();
336 let handler = this.message_handlers.lock().get(&type_id).cloned();
337 if let Some(handler) = handler {
338 if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
339 future.await.ok();
340 } else {
341 this.message_handlers.lock().remove(&type_id);
342 }
343 }
344 }
345 }
346 anyhow::Ok(())
347 }
348 })
349 .detach();
350
351 this
352 }
353
354 pub fn request<T: RequestMessage>(
355 &self,
356 payload: T,
357 ) -> impl 'static + Future<Output = Result<T::Response>> {
358 log::debug!("ssh request start. name:{}", T::NAME);
359 let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
360 async move {
361 let response = response.await?;
362 log::debug!("ssh request finish. name:{}", T::NAME);
363 T::Response::from_envelope(response)
364 .ok_or_else(|| anyhow!("received a response of the wrong type"))
365 }
366 }
367
368 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
369 self.send_dynamic(payload.into_envelope(0, None, None))
370 }
371
372 pub fn request_dynamic(
373 &self,
374 mut envelope: proto::Envelope,
375 _request_type: &'static str,
376 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
377 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
378 let (tx, rx) = oneshot::channel();
379 self.response_channels
380 .lock()
381 .insert(MessageId(envelope.id), tx);
382 self.outgoing_tx.unbounded_send(envelope).ok();
383 async move { Ok(rx.await.context("connection lost")?.0) }
384 }
385
386 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
387 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
388 self.outgoing_tx.unbounded_send(envelope)?;
389 Ok(())
390 }
391
392 pub async fn spawn_process(&self, command: String) -> process::Child {
393 let (process_tx, process_rx) = oneshot::channel();
394 self.spawn_process_tx
395 .unbounded_send(SpawnRequest {
396 command,
397 process_tx,
398 })
399 .ok();
400 process_rx.await.unwrap()
401 }
402
403 pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
404 where
405 M: EnvelopedMessage,
406 E: 'static,
407 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
408 F: 'static + Future<Output = Result<()>>,
409 {
410 let message_type_id = TypeId::of::<M>();
411 self.message_handlers.lock().insert(
412 message_type_id,
413 Arc::new(move |envelope, _, cx| {
414 let entity = entity.upgrade()?;
415 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
416 Some(handler(entity, *envelope, cx).boxed_local())
417 }),
418 );
419 }
420
421 pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
422 where
423 M: EnvelopedMessage + RequestMessage,
424 E: 'static,
425 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
426 F: 'static + Future<Output = Result<M::Response>>,
427 {
428 let message_type_id = TypeId::of::<M>();
429 self.message_handlers.lock().insert(
430 message_type_id,
431 Arc::new(move |envelope, this, cx| {
432 let entity = entity.upgrade()?;
433 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
434 let request_id = envelope.message_id();
435 Some(
436 handler(entity, *envelope, cx)
437 .then(move |result| async move {
438 this.outgoing_tx.unbounded_send(result?.into_envelope(
439 this.next_message_id.fetch_add(1, SeqCst),
440 Some(request_id),
441 None,
442 ))?;
443 Ok(())
444 })
445 .boxed_local(),
446 )
447 }),
448 );
449 }
450}
451
452impl ProtoClient for SshSession {
453 fn request(
454 &self,
455 envelope: proto::Envelope,
456 request_type: &'static str,
457 ) -> BoxFuture<'static, Result<proto::Envelope>> {
458 self.request_dynamic(envelope, request_type).boxed()
459 }
460
461 fn send(&self, envelope: proto::Envelope) -> Result<()> {
462 self.send_dynamic(envelope)
463 }
464}
465
466impl SshClientState {
467 #[cfg(not(unix))]
468 async fn new(
469 _connection_options: SshConnectionOptions,
470 _delegate: Arc<dyn SshClientDelegate>,
471 _cx: &mut AsyncAppContext,
472 ) -> Result<Self> {
473 Err(anyhow!("ssh is not supported on this platform"))
474 }
475
476 #[cfg(unix)]
477 async fn new(
478 connection_options: SshConnectionOptions,
479 delegate: Arc<dyn SshClientDelegate>,
480 cx: &mut AsyncAppContext,
481 ) -> Result<Self> {
482 use futures::{io::BufReader, AsyncBufReadExt as _};
483 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
484 use util::ResultExt as _;
485
486 delegate.set_status(Some("connecting"), cx);
487
488 let url = connection_options.ssh_url();
489 let temp_dir = tempfile::Builder::new()
490 .prefix("zed-ssh-session")
491 .tempdir()?;
492
493 // Create a domain socket listener to handle requests from the askpass program.
494 let askpass_socket = temp_dir.path().join("askpass.sock");
495 let listener =
496 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
497
498 let askpass_task = cx.spawn(|mut cx| async move {
499 while let Ok((mut stream, _)) = listener.accept().await {
500 let mut buffer = Vec::new();
501 let mut reader = BufReader::new(&mut stream);
502 if reader.read_until(b'\0', &mut buffer).await.is_err() {
503 buffer.clear();
504 }
505 let password_prompt = String::from_utf8_lossy(&buffer);
506 if let Some(password) = delegate
507 .ask_password(password_prompt.to_string(), &mut cx)
508 .await
509 .context("failed to get ssh password")
510 .and_then(|p| p)
511 .log_err()
512 {
513 stream.write_all(password.as_bytes()).await.log_err();
514 }
515 }
516 });
517
518 // Create an askpass script that communicates back to this process.
519 let askpass_script = format!(
520 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
521 askpass_socket = askpass_socket.display(),
522 print_args = "printf '%s\\0' \"$@\"",
523 shebang = "#!/bin/sh",
524 );
525 let askpass_script_path = temp_dir.path().join("askpass.sh");
526 fs::write(&askpass_script_path, askpass_script).await?;
527 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
528
529 // Start the master SSH process, which does not do anything except for establish
530 // the connection and keep it open, allowing other ssh commands to reuse it
531 // via a control socket.
532 let socket_path = temp_dir.path().join("ssh.sock");
533 let mut master_process = process::Command::new("ssh")
534 .stdin(Stdio::null())
535 .stdout(Stdio::piped())
536 .stderr(Stdio::piped())
537 .env("SSH_ASKPASS_REQUIRE", "force")
538 .env("SSH_ASKPASS", &askpass_script_path)
539 .args(["-N", "-o", "ControlMaster=yes", "-o"])
540 .arg(format!("ControlPath={}", socket_path.display()))
541 .arg(&url)
542 .spawn()?;
543
544 // Wait for this ssh process to close its stdout, indicating that authentication
545 // has completed.
546 let stdout = master_process.stdout.as_mut().unwrap();
547 let mut output = Vec::new();
548 stdout.read_to_end(&mut output).await?;
549 drop(askpass_task);
550
551 if master_process.try_status()?.is_some() {
552 output.clear();
553 let mut stderr = master_process.stderr.take().unwrap();
554 stderr.read_to_end(&mut output).await?;
555 Err(anyhow!(
556 "failed to connect: {}",
557 String::from_utf8_lossy(&output)
558 ))?;
559 }
560
561 Ok(Self {
562 connection_options,
563 socket_path,
564 _master_process: master_process,
565 _temp_dir: temp_dir,
566 })
567 }
568
569 async fn ensure_server_binary(
570 &self,
571 delegate: &Arc<dyn SshClientDelegate>,
572 src_path: &Path,
573 dst_path: &Path,
574 version: SemanticVersion,
575 cx: &mut AsyncAppContext,
576 ) -> Result<()> {
577 let mut dst_path_gz = dst_path.to_path_buf();
578 dst_path_gz.set_extension("gz");
579
580 if let Some(parent) = dst_path.parent() {
581 run_cmd(self.ssh_command("mkdir").arg("-p").arg(parent)).await?;
582 }
583
584 let mut server_binary_exists = false;
585 if cfg!(not(debug_assertions)) {
586 if let Ok(installed_version) = run_cmd(self.ssh_command(&dst_path).arg("version")).await
587 {
588 if installed_version.trim() == version.to_string() {
589 server_binary_exists = true;
590 }
591 }
592 }
593
594 if server_binary_exists {
595 log::info!("remote development server already present",);
596 return Ok(());
597 }
598
599 let src_stat = fs::metadata(src_path).await?;
600 let size = src_stat.len();
601 let server_mode = 0o755;
602
603 let t0 = Instant::now();
604 delegate.set_status(Some("uploading remote development server"), cx);
605 log::info!("uploading remote development server ({}kb)", size / 1024);
606 self.upload_file(src_path, &dst_path_gz)
607 .await
608 .context("failed to upload server binary")?;
609 log::info!("uploaded remote development server in {:?}", t0.elapsed());
610
611 delegate.set_status(Some("extracting remote development server"), cx);
612 run_cmd(self.ssh_command("gunzip").arg("--force").arg(&dst_path_gz)).await?;
613
614 delegate.set_status(Some("unzipping remote development server"), cx);
615 run_cmd(
616 self.ssh_command("chmod")
617 .arg(format!("{:o}", server_mode))
618 .arg(&dst_path),
619 )
620 .await?;
621
622 Ok(())
623 }
624
625 async fn query_platform(&self) -> Result<SshPlatform> {
626 let os = run_cmd(self.ssh_command("uname").arg("-s")).await?;
627 let arch = run_cmd(self.ssh_command("uname").arg("-m")).await?;
628
629 let os = match os.trim() {
630 "Darwin" => "macos",
631 "Linux" => "linux",
632 _ => Err(anyhow!("unknown uname os {os:?}"))?,
633 };
634 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
635 "aarch64"
636 } else if arch.starts_with("x86") || arch.starts_with("i686") {
637 "x86_64"
638 } else {
639 Err(anyhow!("unknown uname architecture {arch:?}"))?
640 };
641
642 Ok(SshPlatform { os, arch })
643 }
644
645 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
646 let mut command = process::Command::new("scp");
647 let output = self
648 .ssh_options(&mut command)
649 .args(
650 self.connection_options
651 .port
652 .map(|port| vec!["-P".to_string(), port.to_string()])
653 .unwrap_or_default(),
654 )
655 .arg(&src_path)
656 .arg(&format!(
657 "{}:{}",
658 self.connection_options.scp_url(),
659 dest_path.display()
660 ))
661 .output()
662 .await?;
663
664 if output.status.success() {
665 Ok(())
666 } else {
667 Err(anyhow!(
668 "failed to upload file {} -> {}: {}",
669 src_path.display(),
670 dest_path.display(),
671 String::from_utf8_lossy(&output.stderr)
672 ))
673 }
674 }
675
676 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
677 let mut command = process::Command::new("ssh");
678 self.ssh_options(&mut command)
679 .arg(self.connection_options.ssh_url())
680 .arg(program);
681 command
682 }
683
684 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
685 command
686 .stdin(Stdio::piped())
687 .stdout(Stdio::piped())
688 .stderr(Stdio::piped())
689 .args(["-o", "ControlMaster=no", "-o"])
690 .arg(format!("ControlPath={}", self.socket_path.display()))
691 }
692}
693
694async fn run_cmd(command: &mut process::Command) -> Result<String> {
695 let output = command.output().await?;
696 if output.status.success() {
697 Ok(String::from_utf8_lossy(&output.stdout).to_string())
698 } else {
699 Err(anyhow!(
700 "failed to run command: {}",
701 String::from_utf8_lossy(&output.stderr)
702 ))
703 }
704}