1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 future::{BoxFuture, LocalBoxFuture},
12 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
13};
14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
15use parking_lot::Mutex;
16use rpc::{
17 proto::{
18 self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
19 ProtoClient, RequestMessage,
20 },
21 TypedEnvelope,
22};
23use smol::{
24 fs,
25 process::{self, Stdio},
26};
27use std::{
28 any::TypeId,
29 ffi::OsStr,
30 path::{Path, PathBuf},
31 sync::{
32 atomic::{AtomicU32, Ordering::SeqCst},
33 Arc,
34 },
35 time::Instant,
36};
37use tempfile::TempDir;
38
39pub struct SshSession {
40 next_message_id: AtomicU32,
41 response_channels: ResponseChannels,
42 outgoing_tx: mpsc::UnboundedSender<Envelope>,
43 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
44 message_handlers: Mutex<
45 HashMap<
46 TypeId,
47 Arc<
48 dyn Send
49 + Sync
50 + Fn(
51 Box<dyn AnyTypedEnvelope>,
52 Arc<SshSession>,
53 AsyncAppContext,
54 ) -> Option<LocalBoxFuture<'static, Result<()>>>,
55 >,
56 >,
57 >,
58}
59
60struct SshClientState {
61 socket_path: PathBuf,
62 port: u16,
63 url: String,
64 _master_process: process::Child,
65 _temp_dir: TempDir,
66}
67
68struct SpawnRequest {
69 command: String,
70 process_tx: oneshot::Sender<process::Child>,
71}
72
/// Operating system and CPU architecture of a remote host, normalized to the
/// names produced by platform detection (e.g. `"linux"`/`"macos"` and
/// `"x86_64"`/`"aarch64"`).
#[derive(Debug, Clone, Copy)]
pub struct SshPlatform {
    /// Normalized operating-system name.
    pub os: &'static str,
    /// Normalized CPU architecture name.
    pub arch: &'static str,
}
78
79pub trait SshClientDelegate {
80 fn ask_password(
81 &self,
82 prompt: String,
83 cx: &mut AsyncAppContext,
84 ) -> oneshot::Receiver<Result<String>>;
85 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
86 fn get_server_binary(
87 &self,
88 platform: SshPlatform,
89 cx: &mut AsyncAppContext,
90 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
91 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
92}
93
94type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
95
96impl SshSession {
97 pub async fn client(
98 user: String,
99 host: String,
100 port: u16,
101 delegate: Arc<dyn SshClientDelegate>,
102 cx: &mut AsyncAppContext,
103 ) -> Result<Arc<Self>> {
104 let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;
105
106 let platform = client_state.query_platform().await?;
107 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
108 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
109 client_state
110 .ensure_server_binary(
111 &delegate,
112 &local_binary_path,
113 &remote_binary_path,
114 version,
115 cx,
116 )
117 .await?;
118
119 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
120 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
121 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
122
123 run_cmd(client_state.ssh_command(&remote_binary_path).arg("version")).await?;
124
125 let mut remote_server_child = client_state
126 .ssh_command(&format!(
127 "RUST_LOG={} {:?} run",
128 std::env::var("RUST_LOG").unwrap_or(String::new()),
129 remote_binary_path,
130 ))
131 .spawn()
132 .context("failed to spawn remote server")?;
133 let mut child_stderr = remote_server_child.stderr.take().unwrap();
134 let mut child_stdout = remote_server_child.stdout.take().unwrap();
135 let mut child_stdin = remote_server_child.stdin.take().unwrap();
136
137 let executor = cx.background_executor().clone();
138 executor.clone().spawn(async move {
139 let mut stdin_buffer = Vec::new();
140 let mut stdout_buffer = Vec::new();
141 let mut stderr_buffer = Vec::new();
142 let mut stderr_offset = 0;
143
144 loop {
145 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
146 stderr_buffer.resize(stderr_offset + 1024, 0);
147
148 select_biased! {
149 outgoing = outgoing_rx.next().fuse() => {
150 let Some(outgoing) = outgoing else {
151 return anyhow::Ok(());
152 };
153
154 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
155 }
156
157 request = spawn_process_rx.next().fuse() => {
158 let Some(request) = request else {
159 return Ok(());
160 };
161
162 log::info!("spawn process: {:?}", request.command);
163 let child = client_state
164 .ssh_command(&request.command)
165 .spawn()
166 .context("failed to create channel")?;
167 request.process_tx.send(child).ok();
168 }
169
170 result = child_stdout.read(&mut stdout_buffer).fuse() => {
171 match result {
172 Ok(len) => {
173 if len == 0 {
174 child_stdin.close().await?;
175 let status = remote_server_child.status().await?;
176 if !status.success() {
177 log::info!("channel exited with status: {status:?}");
178 }
179 return Ok(());
180 }
181
182 if len < stdout_buffer.len() {
183 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
184 }
185
186 let message_len = message_len_from_buffer(&stdout_buffer);
187 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
188 Ok(envelope) => {
189 incoming_tx.unbounded_send(envelope).ok();
190 }
191 Err(error) => {
192 log::error!("error decoding message {error:?}");
193 }
194 }
195 }
196 Err(error) => {
197 Err(anyhow!("error reading stdout: {error:?}"))?;
198 }
199 }
200 }
201
202 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
203 match result {
204 Ok(len) => {
205 stderr_offset += len;
206 let mut start_ix = 0;
207 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
208 let line_ix = start_ix + ix;
209 let content = &stderr_buffer[start_ix..line_ix];
210 start_ix = line_ix + 1;
211 if let Ok(record) = serde_json::from_slice::<LogRecord>(&content) {
212 record.log(log::logger())
213 } else {
214 eprintln!("(remote) {}", String::from_utf8_lossy(content));
215 }
216 }
217 stderr_buffer.drain(0..start_ix);
218 stderr_offset -= start_ix;
219 }
220 Err(error) => {
221 Err(anyhow!("error reading stderr: {error:?}"))?;
222 }
223 }
224 }
225 }
226 }
227 }).detach();
228
229 cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
230 }
231
232 pub fn server(
233 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
234 outgoing_tx: mpsc::UnboundedSender<Envelope>,
235 cx: &AppContext,
236 ) -> Arc<SshSession> {
237 let (tx, _rx) = mpsc::unbounded();
238 Self::new(incoming_rx, outgoing_tx, tx, cx)
239 }
240
241 #[cfg(any(test, feature = "test-support"))]
242 pub fn fake(
243 client_cx: &mut gpui::TestAppContext,
244 server_cx: &mut gpui::TestAppContext,
245 ) -> (Arc<Self>, Arc<Self>) {
246 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
247 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
248 let (tx, _rx) = mpsc::unbounded();
249 (
250 client_cx
251 .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
252 server_cx
253 .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
254 )
255 }
256
257 fn new(
258 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
259 outgoing_tx: mpsc::UnboundedSender<Envelope>,
260 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
261 cx: &AppContext,
262 ) -> Arc<SshSession> {
263 let this = Arc::new(Self {
264 next_message_id: AtomicU32::new(0),
265 response_channels: ResponseChannels::default(),
266 outgoing_tx,
267 spawn_process_tx,
268 message_handlers: Default::default(),
269 });
270
271 cx.spawn(|cx| {
272 let this = this.clone();
273 async move {
274 let peer_id = PeerId { owner_id: 0, id: 0 };
275 while let Some(incoming) = incoming_rx.next().await {
276 if let Some(request_id) = incoming.responding_to {
277 let request_id = MessageId(request_id);
278 let sender = this.response_channels.lock().remove(&request_id);
279 if let Some(sender) = sender {
280 let (tx, rx) = oneshot::channel();
281 if incoming.payload.is_some() {
282 sender.send((incoming, tx)).ok();
283 }
284 rx.await.ok();
285 }
286 } else if let Some(envelope) =
287 build_typed_envelope(peer_id, Instant::now(), incoming)
288 {
289 log::debug!(
290 "ssh message received. name:{}",
291 envelope.payload_type_name()
292 );
293 let type_id = envelope.payload_type_id();
294 let handler = this.message_handlers.lock().get(&type_id).cloned();
295 if let Some(handler) = handler {
296 if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
297 future.await.ok();
298 } else {
299 this.message_handlers.lock().remove(&type_id);
300 }
301 }
302 }
303 }
304 anyhow::Ok(())
305 }
306 })
307 .detach();
308
309 this
310 }
311
312 pub fn request<T: RequestMessage>(
313 &self,
314 payload: T,
315 ) -> impl 'static + Future<Output = Result<T::Response>> {
316 log::debug!("ssh request start. name:{}", T::NAME);
317 let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
318 async move {
319 let response = response.await?;
320 log::debug!("ssh request finish. name:{}", T::NAME);
321 T::Response::from_envelope(response)
322 .ok_or_else(|| anyhow!("received a response of the wrong type"))
323 }
324 }
325
326 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
327 self.send_dynamic(payload.into_envelope(0, None, None))
328 }
329
330 pub fn request_dynamic(
331 &self,
332 mut envelope: proto::Envelope,
333 _request_type: &'static str,
334 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
335 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
336 let (tx, rx) = oneshot::channel();
337 self.response_channels
338 .lock()
339 .insert(MessageId(envelope.id), tx);
340 self.outgoing_tx.unbounded_send(envelope).ok();
341 async move { Ok(rx.await.context("connection lost")?.0) }
342 }
343
344 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
345 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
346 self.outgoing_tx.unbounded_send(envelope)?;
347 Ok(())
348 }
349
350 pub async fn spawn_process(&self, command: String) -> process::Child {
351 let (process_tx, process_rx) = oneshot::channel();
352 self.spawn_process_tx
353 .unbounded_send(SpawnRequest {
354 command,
355 process_tx,
356 })
357 .ok();
358 process_rx.await.unwrap()
359 }
360
361 pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
362 where
363 M: EnvelopedMessage,
364 E: 'static,
365 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
366 F: 'static + Future<Output = Result<()>>,
367 {
368 let message_type_id = TypeId::of::<M>();
369 self.message_handlers.lock().insert(
370 message_type_id,
371 Arc::new(move |envelope, _, cx| {
372 let entity = entity.upgrade()?;
373 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
374 Some(handler(entity, *envelope, cx).boxed_local())
375 }),
376 );
377 }
378
379 pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
380 where
381 M: EnvelopedMessage + RequestMessage,
382 E: 'static,
383 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
384 F: 'static + Future<Output = Result<M::Response>>,
385 {
386 let message_type_id = TypeId::of::<M>();
387 self.message_handlers.lock().insert(
388 message_type_id,
389 Arc::new(move |envelope, this, cx| {
390 let entity = entity.upgrade()?;
391 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
392 let request_id = envelope.message_id();
393 Some(
394 handler(entity, *envelope, cx)
395 .then(move |result| async move {
396 this.outgoing_tx.unbounded_send(result?.into_envelope(
397 this.next_message_id.fetch_add(1, SeqCst),
398 Some(request_id),
399 None,
400 ))?;
401 Ok(())
402 })
403 .boxed_local(),
404 )
405 }),
406 );
407 }
408}
409
410impl ProtoClient for SshSession {
411 fn request(
412 &self,
413 envelope: proto::Envelope,
414 request_type: &'static str,
415 ) -> BoxFuture<'static, Result<proto::Envelope>> {
416 self.request_dynamic(envelope, request_type).boxed()
417 }
418
419 fn send(&self, envelope: proto::Envelope) -> Result<()> {
420 self.send_dynamic(envelope)
421 }
422}
423
424impl SshClientState {
425 #[cfg(not(unix))]
426 async fn new(
427 _user: String,
428 _host: String,
429 _port: u16,
430 _delegate: Arc<dyn SshClientDelegate>,
431 _cx: &mut AsyncAppContext,
432 ) -> Result<Self> {
433 Err(anyhow!("ssh is not supported on this platform"))
434 }
435
436 #[cfg(unix)]
437 async fn new(
438 user: String,
439 host: String,
440 port: u16,
441 delegate: Arc<dyn SshClientDelegate>,
442 cx: &mut AsyncAppContext,
443 ) -> Result<Self> {
444 use futures::{io::BufReader, AsyncBufReadExt as _};
445 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
446 use util::ResultExt as _;
447
448 delegate.set_status(Some("connecting"), cx);
449
450 let url = format!("{user}@{host}");
451 let temp_dir = tempfile::Builder::new()
452 .prefix("zed-ssh-session")
453 .tempdir()?;
454
455 // Create a domain socket listener to handle requests from the askpass program.
456 let askpass_socket = temp_dir.path().join("askpass.sock");
457 let listener =
458 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
459
460 let askpass_task = cx.spawn(|mut cx| async move {
461 while let Ok((mut stream, _)) = listener.accept().await {
462 let mut buffer = Vec::new();
463 let mut reader = BufReader::new(&mut stream);
464 if reader.read_until(b'\0', &mut buffer).await.is_err() {
465 buffer.clear();
466 }
467 let password_prompt = String::from_utf8_lossy(&buffer);
468 if let Some(password) = delegate
469 .ask_password(password_prompt.to_string(), &mut cx)
470 .await
471 .context("failed to get ssh password")
472 .and_then(|p| p)
473 .log_err()
474 {
475 stream.write_all(password.as_bytes()).await.log_err();
476 }
477 }
478 });
479
480 // Create an askpass script that communicates back to this process.
481 let askpass_script = format!(
482 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
483 askpass_socket = askpass_socket.display(),
484 print_args = "printf '%s\\0' \"$@\"",
485 shebang = "#!/bin/sh",
486 );
487 let askpass_script_path = temp_dir.path().join("askpass.sh");
488 fs::write(&askpass_script_path, askpass_script).await?;
489 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
490
491 // Start the master SSH process, which does not do anything except for establish
492 // the connection and keep it open, allowing other ssh commands to reuse it
493 // via a control socket.
494 let socket_path = temp_dir.path().join("ssh.sock");
495 let mut master_process = process::Command::new("ssh")
496 .stdin(Stdio::null())
497 .stdout(Stdio::piped())
498 .stderr(Stdio::piped())
499 .env("SSH_ASKPASS_REQUIRE", "force")
500 .env("SSH_ASKPASS", &askpass_script_path)
501 .args(["-N", "-o", "ControlMaster=yes", "-o"])
502 .arg(format!("ControlPath={}", socket_path.display()))
503 .args(["-p", &port.to_string()])
504 .arg(&url)
505 .spawn()?;
506
507 // Wait for this ssh process to close its stdout, indicating that authentication
508 // has completed.
509 let stdout = master_process.stdout.as_mut().unwrap();
510 let mut output = Vec::new();
511 stdout.read_to_end(&mut output).await?;
512 drop(askpass_task);
513
514 if master_process.try_status()?.is_some() {
515 output.clear();
516 let mut stderr = master_process.stderr.take().unwrap();
517 stderr.read_to_end(&mut output).await?;
518 Err(anyhow!(
519 "failed to connect: {}",
520 String::from_utf8_lossy(&output)
521 ))?;
522 }
523
524 Ok(Self {
525 url,
526 port,
527 socket_path,
528 _master_process: master_process,
529 _temp_dir: temp_dir,
530 })
531 }
532
533 async fn ensure_server_binary(
534 &self,
535 delegate: &Arc<dyn SshClientDelegate>,
536 src_path: &Path,
537 dst_path: &Path,
538 version: SemanticVersion,
539 cx: &mut AsyncAppContext,
540 ) -> Result<()> {
541 let mut dst_path_gz = dst_path.to_path_buf();
542 dst_path_gz.set_extension("gz");
543
544 if let Some(parent) = dst_path.parent() {
545 run_cmd(self.ssh_command("mkdir").arg("-p").arg(parent)).await?;
546 }
547
548 let mut server_binary_exists = false;
549 if cfg!(not(debug_assertions)) {
550 if let Ok(installed_version) = run_cmd(self.ssh_command(&dst_path).arg("version")).await
551 {
552 if installed_version.trim() == version.to_string() {
553 server_binary_exists = true;
554 }
555 }
556 }
557
558 if server_binary_exists {
559 log::info!("remote development server already present",);
560 return Ok(());
561 }
562
563 let src_stat = fs::metadata(src_path).await?;
564 let size = src_stat.len();
565 let server_mode = 0o755;
566
567 let t0 = Instant::now();
568 delegate.set_status(Some("uploading remote development server"), cx);
569 log::info!("uploading remote development server ({}kb)", size / 1024);
570 self.upload_file(src_path, &dst_path_gz)
571 .await
572 .context("failed to upload server binary")?;
573 log::info!("uploaded remote development server in {:?}", t0.elapsed());
574
575 delegate.set_status(Some("extracting remote development server"), cx);
576 run_cmd(self.ssh_command("gunzip").arg("--force").arg(&dst_path_gz)).await?;
577
578 delegate.set_status(Some("unzipping remote development server"), cx);
579 run_cmd(
580 self.ssh_command("chmod")
581 .arg(format!("{:o}", server_mode))
582 .arg(&dst_path),
583 )
584 .await?;
585
586 Ok(())
587 }
588
589 async fn query_platform(&self) -> Result<SshPlatform> {
590 let os = run_cmd(self.ssh_command("uname").arg("-s")).await?;
591 let arch = run_cmd(self.ssh_command("uname").arg("-m")).await?;
592
593 let os = match os.trim() {
594 "Darwin" => "macos",
595 "Linux" => "linux",
596 _ => Err(anyhow!("unknown uname os {os:?}"))?,
597 };
598 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
599 "aarch64"
600 } else if arch.starts_with("x86") || arch.starts_with("i686") {
601 "x86_64"
602 } else {
603 Err(anyhow!("unknown uname architecture {arch:?}"))?
604 };
605
606 Ok(SshPlatform { os, arch })
607 }
608
609 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
610 let mut command = process::Command::new("scp");
611 let output = self
612 .ssh_options(&mut command)
613 .arg("-P")
614 .arg(&self.port.to_string())
615 .arg(&src_path)
616 .arg(&format!("{}:{}", self.url, dest_path.display()))
617 .output()
618 .await?;
619
620 if output.status.success() {
621 Ok(())
622 } else {
623 Err(anyhow!(
624 "failed to upload file {} -> {}: {}",
625 src_path.display(),
626 dest_path.display(),
627 String::from_utf8_lossy(&output.stderr)
628 ))
629 }
630 }
631
632 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
633 let mut command = process::Command::new("ssh");
634 self.ssh_options(&mut command)
635 .arg("-p")
636 .arg(&self.port.to_string())
637 .arg(&self.url)
638 .arg(program);
639 command
640 }
641
642 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
643 command
644 .stdin(Stdio::piped())
645 .stdout(Stdio::piped())
646 .stderr(Stdio::piped())
647 .args(["-o", "ControlMaster=no", "-o"])
648 .arg(format!("ControlPath={}", self.socket_path.display()))
649 }
650}
651
652async fn run_cmd(command: &mut process::Command) -> Result<String> {
653 let output = command.output().await?;
654 if output.status.success() {
655 Ok(String::from_utf8_lossy(&output.stdout).to_string())
656 } else {
657 Err(anyhow!(
658 "failed to run command: {}",
659 String::from_utf8_lossy(&output.stderr)
660 ))
661 }
662}