1use crate::protocol::{
2 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
3};
4use anyhow::{anyhow, Context as _, Result};
5use collections::HashMap;
6use futures::{
7 channel::{mpsc, oneshot},
8 future::{BoxFuture, LocalBoxFuture},
9 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
10};
11use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
12use parking_lot::Mutex;
13use rpc::{
14 proto::{
15 self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
16 ProtoClient, RequestMessage,
17 },
18 TypedEnvelope,
19};
20use smol::{
21 fs,
22 process::{self, Stdio},
23};
24use std::{
25 any::TypeId,
26 ffi::OsStr,
27 path::{Path, PathBuf},
28 sync::{
29 atomic::{AtomicU32, Ordering::SeqCst},
30 Arc,
31 },
32 time::Instant,
33};
34use tempfile::TempDir;
35
36pub struct SshSession {
37 next_message_id: AtomicU32,
38 response_channels: ResponseChannels,
39 outgoing_tx: mpsc::UnboundedSender<Envelope>,
40 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
41 message_handlers: Mutex<
42 HashMap<
43 TypeId,
44 Arc<
45 dyn Send
46 + Sync
47 + Fn(
48 Box<dyn AnyTypedEnvelope>,
49 Arc<SshSession>,
50 AsyncAppContext,
51 ) -> Option<LocalBoxFuture<'static, Result<()>>>,
52 >,
53 >,
54 >,
55}
56
/// An authenticated SSH master connection whose control socket is reused by
/// the other `ssh`/`scp` commands this state builds.
struct SshClientState {
    /// Control socket of the master connection (passed as `ControlPath`).
    socket_path: PathBuf,
    /// Remote port, passed as `ssh -p` / `scp -P`.
    port: u16,
    /// Destination in `user@host` form.
    url: String,
    /// The master `ssh` process; held so it stays alive as long as this state.
    _master_process: process::Child,
    /// Temporary directory containing the control socket and askpass files;
    /// held for the lifetime of this state. Fields drop in declaration order,
    /// so this is dropped after `_master_process`.
    _temp_dir: TempDir,
}
64
/// A request, handled by the client's background I/O task, to run `command`
/// on the remote host through the shared SSH connection.
struct SpawnRequest {
    /// Remote command passed to `ssh` after the connection options.
    command: String,
    /// Receives the spawned local `ssh` child process. It is dropped without a
    /// value if no process was delivered.
    process_tx: oneshot::Sender<process::Child>,
}
69
/// Operating system and CPU architecture of a remote host, normalized from
/// the output of `uname`.
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    /// Normalized OS name, for example `"macos"` or `"linux"`.
    pub os: &'static str,
    /// Normalized architecture name, for example `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
75
/// Callbacks an SSH client session uses to interact with the rest of the
/// application while it connects and installs the remote server.
pub trait SshClientDelegate {
    /// Asks for the secret requested by `prompt`, which is the text `ssh`
    /// passed to its askpass program. A successful answer is written back to
    /// `ssh`. An `Err`, or a dropped sender, is logged and no answer is written.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Path on the remote host where the server binary is installed and run from.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Resolves a local server binary for `platform`, together with its version.
    /// The file is uploaded as `<remote path>.gz` and `gunzip`ped remotely, so it
    /// is expected to be gzip-compressed. It is skipped when the remote binary
    /// already reports the same version.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports human-readable connection progress, such as "connecting".
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
}
90
/// Requests awaiting a response, keyed by the request's message id.
///
/// Each sender is handed the response envelope plus an acknowledgement
/// `oneshot::Sender<()>`. The session's receive loop waits until that
/// acknowledgement is sent or dropped before it handles the next incoming
/// message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
92
93impl SshSession {
94 pub async fn client(
95 user: String,
96 host: String,
97 port: u16,
98 delegate: Arc<dyn SshClientDelegate>,
99 cx: &mut AsyncAppContext,
100 ) -> Result<Arc<Self>> {
101 let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;
102
103 let platform = client_state.query_platform().await?;
104 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
105 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
106 client_state
107 .ensure_server_binary(
108 &delegate,
109 &local_binary_path,
110 &remote_binary_path,
111 version,
112 cx,
113 )
114 .await?;
115
116 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
117 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
118 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
119
120 let mut remote_server_child = client_state
121 .ssh_command(&remote_binary_path)
122 .arg("run")
123 .spawn()
124 .context("failed to spawn remote server")?;
125 let mut child_stderr = remote_server_child.stderr.take().unwrap();
126 let mut child_stdout = remote_server_child.stdout.take().unwrap();
127 let mut child_stdin = remote_server_child.stdin.take().unwrap();
128
129 let executor = cx.background_executor().clone();
130 executor.clone().spawn(async move {
131 let mut stdin_buffer = Vec::new();
132 let mut stdout_buffer = Vec::new();
133 let mut stderr_buffer = Vec::new();
134 let mut stderr_offset = 0;
135
136 loop {
137 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
138 stderr_buffer.resize(stderr_offset + 1024, 0);
139
140 select_biased! {
141 outgoing = outgoing_rx.next().fuse() => {
142 let Some(outgoing) = outgoing else {
143 return anyhow::Ok(());
144 };
145
146 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
147 }
148
149 request = spawn_process_rx.next().fuse() => {
150 let Some(request) = request else {
151 return Ok(());
152 };
153
154 log::info!("spawn process: {:?}", request.command);
155 let child = client_state
156 .ssh_command(&request.command)
157 .spawn()
158 .context("failed to create channel")?;
159 request.process_tx.send(child).ok();
160 }
161
162 result = child_stdout.read(&mut stdout_buffer).fuse() => {
163 match result {
164 Ok(len) => {
165 if len == 0 {
166 child_stdin.close().await?;
167 let status = remote_server_child.status().await?;
168 if !status.success() {
169 log::info!("channel exited with status: {status:?}");
170 }
171 return Ok(());
172 }
173
174 if len < stdout_buffer.len() {
175 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
176 }
177
178 let message_len = message_len_from_buffer(&stdout_buffer);
179 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
180 Ok(envelope) => {
181 incoming_tx.unbounded_send(envelope).ok();
182 }
183 Err(error) => {
184 log::error!("error decoding message {error:?}");
185 }
186 }
187 }
188 Err(error) => {
189 Err(anyhow!("error reading stdout: {error:?}"))?;
190 }
191 }
192 }
193
194 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
195 match result {
196 Ok(len) => {
197 stderr_offset += len;
198 let mut start_ix = 0;
199 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
200 let line_ix = start_ix + ix;
201 let content = String::from_utf8_lossy(&stderr_buffer[start_ix..line_ix]);
202 start_ix = line_ix + 1;
203 eprintln!("(remote) {}", content);
204 }
205 stderr_buffer.drain(0..start_ix);
206 stderr_offset -= start_ix;
207 }
208 Err(error) => {
209 Err(anyhow!("error reading stderr: {error:?}"))?;
210 }
211 }
212 }
213 }
214 }
215 }).detach();
216
217 cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
218 }
219
220 pub fn server(
221 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
222 outgoing_tx: mpsc::UnboundedSender<Envelope>,
223 cx: &AppContext,
224 ) -> Arc<SshSession> {
225 let (tx, _rx) = mpsc::unbounded();
226 Self::new(incoming_rx, outgoing_tx, tx, cx)
227 }
228
229 #[cfg(any(test, feature = "test-support"))]
230 pub fn fake(
231 client_cx: &mut gpui::TestAppContext,
232 server_cx: &mut gpui::TestAppContext,
233 ) -> (Arc<Self>, Arc<Self>) {
234 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
235 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
236 let (tx, _rx) = mpsc::unbounded();
237 (
238 client_cx
239 .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
240 server_cx
241 .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
242 )
243 }
244
245 fn new(
246 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
247 outgoing_tx: mpsc::UnboundedSender<Envelope>,
248 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
249 cx: &AppContext,
250 ) -> Arc<SshSession> {
251 let this = Arc::new(Self {
252 next_message_id: AtomicU32::new(0),
253 response_channels: ResponseChannels::default(),
254 outgoing_tx,
255 spawn_process_tx,
256 message_handlers: Default::default(),
257 });
258
259 cx.spawn(|cx| {
260 let this = this.clone();
261 async move {
262 let peer_id = PeerId { owner_id: 0, id: 0 };
263 while let Some(incoming) = incoming_rx.next().await {
264 if let Some(request_id) = incoming.responding_to {
265 let request_id = MessageId(request_id);
266 let sender = this.response_channels.lock().remove(&request_id);
267 if let Some(sender) = sender {
268 let (tx, rx) = oneshot::channel();
269 if incoming.payload.is_some() {
270 sender.send((incoming, tx)).ok();
271 }
272 rx.await.ok();
273 }
274 } else if let Some(envelope) =
275 build_typed_envelope(peer_id, Instant::now(), incoming)
276 {
277 log::debug!(
278 "ssh message received. name:{}",
279 envelope.payload_type_name()
280 );
281 let type_id = envelope.payload_type_id();
282 let handler = this.message_handlers.lock().get(&type_id).cloned();
283 if let Some(handler) = handler {
284 if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
285 future.await.ok();
286 } else {
287 this.message_handlers.lock().remove(&type_id);
288 }
289 }
290 }
291 }
292 anyhow::Ok(())
293 }
294 })
295 .detach();
296
297 this
298 }
299
300 pub fn request<T: RequestMessage>(
301 &self,
302 payload: T,
303 ) -> impl 'static + Future<Output = Result<T::Response>> {
304 log::debug!("ssh request start. name:{}", T::NAME);
305 let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
306 async move {
307 let response = response.await?;
308 log::debug!("ssh request finish. name:{}", T::NAME);
309 T::Response::from_envelope(response)
310 .ok_or_else(|| anyhow!("received a response of the wrong type"))
311 }
312 }
313
314 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
315 self.send_dynamic(payload.into_envelope(0, None, None))
316 }
317
318 pub fn request_dynamic(
319 &self,
320 mut envelope: proto::Envelope,
321 _request_type: &'static str,
322 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
323 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
324 let (tx, rx) = oneshot::channel();
325 self.response_channels
326 .lock()
327 .insert(MessageId(envelope.id), tx);
328 self.outgoing_tx.unbounded_send(envelope).ok();
329 async move { Ok(rx.await.context("connection lost")?.0) }
330 }
331
332 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
333 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
334 self.outgoing_tx.unbounded_send(envelope)?;
335 Ok(())
336 }
337
338 pub async fn spawn_process(&self, command: String) -> process::Child {
339 let (process_tx, process_rx) = oneshot::channel();
340 self.spawn_process_tx
341 .unbounded_send(SpawnRequest {
342 command,
343 process_tx,
344 })
345 .ok();
346 process_rx.await.unwrap()
347 }
348
349 pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
350 where
351 M: EnvelopedMessage,
352 E: 'static,
353 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
354 F: 'static + Future<Output = Result<()>>,
355 {
356 let message_type_id = TypeId::of::<M>();
357 self.message_handlers.lock().insert(
358 message_type_id,
359 Arc::new(move |envelope, _, cx| {
360 let entity = entity.upgrade()?;
361 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
362 Some(handler(entity, *envelope, cx).boxed_local())
363 }),
364 );
365 }
366
367 pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
368 where
369 M: EnvelopedMessage + RequestMessage,
370 E: 'static,
371 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
372 F: 'static + Future<Output = Result<M::Response>>,
373 {
374 let message_type_id = TypeId::of::<M>();
375 self.message_handlers.lock().insert(
376 message_type_id,
377 Arc::new(move |envelope, this, cx| {
378 let entity = entity.upgrade()?;
379 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
380 let request_id = envelope.message_id();
381 Some(
382 handler(entity, *envelope, cx)
383 .then(move |result| async move {
384 this.outgoing_tx.unbounded_send(result?.into_envelope(
385 this.next_message_id.fetch_add(1, SeqCst),
386 Some(request_id),
387 None,
388 ))?;
389 Ok(())
390 })
391 .boxed_local(),
392 )
393 }),
394 );
395 }
396}
397
398impl ProtoClient for SshSession {
399 fn request(
400 &self,
401 envelope: proto::Envelope,
402 request_type: &'static str,
403 ) -> BoxFuture<'static, Result<proto::Envelope>> {
404 self.request_dynamic(envelope, request_type).boxed()
405 }
406
407 fn send(&self, envelope: proto::Envelope) -> Result<()> {
408 self.send_dynamic(envelope)
409 }
410}
411
412impl SshClientState {
413 #[cfg(not(unix))]
414 async fn new(
415 user: String,
416 host: String,
417 port: u16,
418 delegate: Arc<dyn SshClientDelegate>,
419 cx: &mut AsyncAppContext,
420 ) -> Result<Self> {
421 Err(anyhow!("ssh is not supported on this platform"))
422 }
423
424 #[cfg(unix)]
425 async fn new(
426 user: String,
427 host: String,
428 port: u16,
429 delegate: Arc<dyn SshClientDelegate>,
430 cx: &mut AsyncAppContext,
431 ) -> Result<Self> {
432 use futures::{io::BufReader, AsyncBufReadExt as _};
433 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
434 use util::ResultExt as _;
435
436 delegate.set_status(Some("connecting"), cx);
437
438 let url = format!("{user}@{host}");
439 let temp_dir = tempfile::Builder::new()
440 .prefix("zed-ssh-session")
441 .tempdir()?;
442
443 // Create a domain socket listener to handle requests from the askpass program.
444 let askpass_socket = temp_dir.path().join("askpass.sock");
445 let listener =
446 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
447
448 let askpass_task = cx.spawn(|mut cx| async move {
449 while let Ok((mut stream, _)) = listener.accept().await {
450 let mut buffer = Vec::new();
451 let mut reader = BufReader::new(&mut stream);
452 if reader.read_until(b'\0', &mut buffer).await.is_err() {
453 buffer.clear();
454 }
455 let password_prompt = String::from_utf8_lossy(&buffer);
456 if let Some(password) = delegate
457 .ask_password(password_prompt.to_string(), &mut cx)
458 .await
459 .context("failed to get ssh password")
460 .and_then(|p| p)
461 .log_err()
462 {
463 stream.write_all(password.as_bytes()).await.log_err();
464 }
465 }
466 });
467
468 // Create an askpass script that communicates back to this process.
469 let askpass_script = format!(
470 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
471 askpass_socket = askpass_socket.display(),
472 print_args = "printf '%s\\0' \"$@\"",
473 shebang = "#!/bin/sh",
474 );
475 let askpass_script_path = temp_dir.path().join("askpass.sh");
476 fs::write(&askpass_script_path, askpass_script).await?;
477 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
478
479 // Start the master SSH process, which does not do anything except for establish
480 // the connection and keep it open, allowing other ssh commands to reuse it
481 // via a control socket.
482 let socket_path = temp_dir.path().join("ssh.sock");
483 let mut master_process = process::Command::new("ssh")
484 .stdin(Stdio::null())
485 .stdout(Stdio::piped())
486 .stderr(Stdio::piped())
487 .env("SSH_ASKPASS_REQUIRE", "force")
488 .env("SSH_ASKPASS", &askpass_script_path)
489 .args(["-N", "-o", "ControlMaster=yes", "-o"])
490 .arg(format!("ControlPath={}", socket_path.display()))
491 .args(["-p", &port.to_string()])
492 .arg(&url)
493 .spawn()?;
494
495 // Wait for this ssh process to close its stdout, indicating that authentication
496 // has completed.
497 let stdout = master_process.stdout.as_mut().unwrap();
498 let mut output = Vec::new();
499 stdout.read_to_end(&mut output).await?;
500 drop(askpass_task);
501
502 if master_process.try_status()?.is_some() {
503 output.clear();
504 let mut stderr = master_process.stderr.take().unwrap();
505 stderr.read_to_end(&mut output).await?;
506 Err(anyhow!(
507 "failed to connect: {}",
508 String::from_utf8_lossy(&output)
509 ))?;
510 }
511
512 Ok(Self {
513 url,
514 port,
515 socket_path,
516 _master_process: master_process,
517 _temp_dir: temp_dir,
518 })
519 }
520
521 async fn ensure_server_binary(
522 &self,
523 delegate: &Arc<dyn SshClientDelegate>,
524 src_path: &Path,
525 dst_path: &Path,
526 version: SemanticVersion,
527 cx: &mut AsyncAppContext,
528 ) -> Result<()> {
529 let mut dst_path_gz = dst_path.to_path_buf();
530 dst_path_gz.set_extension("gz");
531
532 if let Some(parent) = dst_path.parent() {
533 run_cmd(self.ssh_command("mkdir").arg("-p").arg(parent)).await?;
534 }
535
536 let mut server_binary_exists = false;
537 if let Ok(installed_version) = run_cmd(self.ssh_command(&dst_path).arg("version")).await {
538 if installed_version.trim() == version.to_string() {
539 server_binary_exists = true;
540 }
541 }
542
543 if server_binary_exists {
544 log::info!("remote development server already present",);
545 return Ok(());
546 }
547
548 let src_stat = fs::metadata(src_path).await?;
549 let size = src_stat.len();
550 let server_mode = 0o755;
551
552 let t0 = Instant::now();
553 delegate.set_status(Some("uploading remote development server"), cx);
554 log::info!("uploading remote development server ({}kb)", size / 1024);
555 self.upload_file(src_path, &dst_path_gz)
556 .await
557 .context("failed to upload server binary")?;
558 log::info!("uploaded remote development server in {:?}", t0.elapsed());
559
560 delegate.set_status(Some("extracting remote development server"), cx);
561 run_cmd(self.ssh_command("gunzip").arg("--force").arg(&dst_path_gz)).await?;
562
563 delegate.set_status(Some("unzipping remote development server"), cx);
564 run_cmd(
565 self.ssh_command("chmod")
566 .arg(format!("{:o}", server_mode))
567 .arg(&dst_path),
568 )
569 .await?;
570
571 Ok(())
572 }
573
574 async fn query_platform(&self) -> Result<SshPlatform> {
575 let os = run_cmd(self.ssh_command("uname").arg("-s")).await?;
576 let arch = run_cmd(self.ssh_command("uname").arg("-m")).await?;
577
578 let os = match os.trim() {
579 "Darwin" => "macos",
580 "Linux" => "linux",
581 _ => Err(anyhow!("unknown uname os {os:?}"))?,
582 };
583 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
584 "aarch64"
585 } else if arch.starts_with("x86") || arch.starts_with("i686") {
586 "x86_64"
587 } else {
588 Err(anyhow!("unknown uname architecture {arch:?}"))?
589 };
590
591 Ok(SshPlatform { os, arch })
592 }
593
594 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
595 let mut command = process::Command::new("scp");
596 let output = self
597 .ssh_options(&mut command)
598 .arg("-P")
599 .arg(&self.port.to_string())
600 .arg(&src_path)
601 .arg(&format!("{}:{}", self.url, dest_path.display()))
602 .output()
603 .await?;
604
605 if output.status.success() {
606 Ok(())
607 } else {
608 Err(anyhow!(
609 "failed to upload file {} -> {}: {}",
610 src_path.display(),
611 dest_path.display(),
612 String::from_utf8_lossy(&output.stderr)
613 ))
614 }
615 }
616
617 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
618 let mut command = process::Command::new("ssh");
619 self.ssh_options(&mut command)
620 .arg("-p")
621 .arg(&self.port.to_string())
622 .arg(&self.url)
623 .arg(program);
624 command
625 }
626
627 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
628 command
629 .stdin(Stdio::piped())
630 .stdout(Stdio::piped())
631 .stderr(Stdio::piped())
632 .args(["-o", "ControlMaster=no", "-o"])
633 .arg(format!("ControlPath={}", self.socket_path.display()))
634 }
635}
636
637async fn run_cmd(command: &mut process::Command) -> Result<String> {
638 let output = command.output().await?;
639 if output.status.success() {
640 Ok(String::from_utf8_lossy(&output.stdout).to_string())
641 } else {
642 Err(anyhow!(
643 "failed to run command: {}",
644 String::from_utf8_lossy(&output.stderr)
645 ))
646 }
647}