1use crate::protocol::{
2 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
3};
4use anyhow::{anyhow, Context as _, Result};
5use collections::HashMap;
6use futures::{
7 channel::{mpsc, oneshot},
8 future::{BoxFuture, LocalBoxFuture},
9 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
10};
11use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel};
12use parking_lot::Mutex;
13use rpc::{
14 proto::{
15 self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId,
16 ProtoClient, RequestMessage,
17 },
18 TypedEnvelope,
19};
20use smol::{
21 fs,
22 process::{self, Stdio},
23};
24use std::{
25 any::TypeId,
26 ffi::OsStr,
27 path::{Path, PathBuf},
28 sync::{
29 atomic::{AtomicU32, Ordering::SeqCst},
30 Arc,
31 },
32 time::Instant,
33};
34use tempfile::TempDir;
35
36pub struct SshSession {
37 next_message_id: AtomicU32,
38 response_channels: ResponseChannels,
39 outgoing_tx: mpsc::UnboundedSender<Envelope>,
40 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
41 message_handlers: Mutex<
42 HashMap<
43 TypeId,
44 Arc<
45 dyn Send
46 + Sync
47 + Fn(
48 Box<dyn AnyTypedEnvelope>,
49 Arc<SshSession>,
50 AsyncAppContext,
51 ) -> Option<LocalBoxFuture<'static, Result<()>>>,
52 >,
53 >,
54 >,
55}
56
/// Connection details shared by every `ssh`/`scp` invocation of one session.
///
/// Built by `SshClientState::new`, which starts the master `ssh` process and
/// creates the temporary directory referenced below.
struct SshClientState {
    /// Path handed to `ssh`/`scp` as `-o ControlPath=…`.
    socket_path: PathBuf,
    /// Port handed to `ssh -p` and `scp -P`.
    port: u16,
    /// Destination in `user@host` form.
    url: String,
    /// Handle of the master `ssh` process (started with `ControlMaster=yes`).
    _master_process: process::Child,
    /// Temporary directory that contains the askpass script and the control socket.
    _temp_dir: TempDir,
}
64
/// Request, sent to the session's I/O task, to run a command over the SSH connection.
struct SpawnRequest {
    /// Command passed to `ssh` as the program to run.
    command: String,
    /// Channel on which the spawned child process is handed back to the requester.
    process_tx: oneshot::Sender<process::Child>,
}
69
/// Operating system and CPU architecture of the remote host.
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    /// Operating system name, e.g. `"macos"` or `"linux"`.
    pub os: &'static str,
    /// CPU architecture name, e.g. `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
75
/// Callbacks through which the SSH client obtains what it cannot produce itself:
/// passwords and the server binary.
pub trait SshClientDelegate {
    /// Asks for the secret described by `prompt`.
    ///
    /// The answer (or an error) is delivered through the returned channel.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path at which the server binary is expected on the remote host.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Provides the local path and version of a server binary for `platform`.
    ///
    /// The result is delivered through the returned channel.
    /// NOTE(review): the file is presumably gzip-compressed, since it is uploaded
    /// with a `.gz` suffix and then run through `gunzip` — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
}
89
/// Senders for requests that are still waiting for a response, keyed by the id of
/// the request envelope.
///
/// Each sender delivers the response envelope together with a second oneshot
/// sender; the session's receive loop awaits that second channel before it reads
/// the next incoming envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
91
92impl SshSession {
    /// Connects to `user@host` on `port`, makes sure the server binary is installed
    /// on the remote host, starts it, and returns a session that talks to it.
    ///
    /// A background task owns the remote process: it writes outgoing envelopes to
    /// the process' stdin, decodes envelopes from its stdout, echoes its stderr
    /// line by line, and serves requests to spawn further processes.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established, the platform cannot be
    /// determined, the delegate cannot provide a binary, installing the binary
    /// fails, or the remote server cannot be spawned.
    ///
    /// # Panics
    ///
    /// Panics if the spawned child does not expose piped stdin/stdout/stderr.
    pub async fn client(
        user: String,
        host: String,
        port: u16,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Result<Arc<Self>> {
        let client_state = SshClientState::new(user, host, port, delegate.clone(), cx).await?;

        let platform = query_platform(&client_state).await?;
        // First `?`: the delegate dropped the channel; second `?`: the delegate reported an error.
        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
        ensure_server_binary(
            &client_state,
            &local_binary_path,
            &remote_binary_path,
            version,
        )
        .await?;

        let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();

        let mut remote_server_child = client_state
            .ssh_command(&remote_binary_path)
            .arg("run")
            .spawn()
            .context("failed to spawn remote server")?;
        let mut child_stderr = remote_server_child.stderr.take().unwrap();
        let mut child_stdout = remote_server_child.stdout.take().unwrap();
        let mut child_stdin = remote_server_child.stdin.take().unwrap();

        let executor = cx.background_executor().clone();
        // `client_state` is moved into this task, so it lives as long as the I/O loop does.
        // NOTE(review): the task is detached, so an error that ends the loop appears
        // to go unobserved — confirm whether it should be logged.
        executor.clone().spawn(async move {
            let mut stdin_buffer = Vec::new();
            let mut stdout_buffer = Vec::new();
            let mut stderr_buffer = Vec::new();
            // Number of bytes at the start of `stderr_buffer` that hold an unfinished line.
            let mut stderr_offset = 0;

            loop {
                // Make room for the length prefix of the next message and for
                // another 1024 bytes of stderr output after the buffered partial line.
                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                stderr_buffer.resize(stderr_offset + 1024, 0);

                select_biased! {
                    // An envelope is waiting to be sent: write it to the remote process.
                    outgoing = outgoing_rx.next().fuse() => {
                        let Some(outgoing) = outgoing else {
                            return anyhow::Ok(());
                        };

                        write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
                    }

                    // Someone asked for another process on this connection.
                    request = spawn_process_rx.next().fuse() => {
                        let Some(request) = request else {
                            return Ok(());
                        };

                        log::info!("spawn process: {:?}", request.command);
                        let child = client_state
                            .ssh_command(&request.command)
                            .spawn()
                            .context("failed to create channel")?;
                        request.process_tx.send(child).ok();
                    }

                    // The remote process wrote to stdout: decode one message.
                    result = child_stdout.read(&mut stdout_buffer).fuse() => {
                        match result {
                            Ok(len) => {
                                // Zero bytes means stdout was closed: shut down and stop.
                                if len == 0 {
                                    child_stdin.close().await?;
                                    let status = remote_server_child.status().await?;
                                    if !status.success() {
                                        log::info!("channel exited with status: {status:?}");
                                    }
                                    return Ok(());
                                }

                                // The first read may return only part of the length prefix.
                                if len < stdout_buffer.len() {
                                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                                }

                                let message_len = message_len_from_buffer(&stdout_buffer);
                                match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
                                    Ok(envelope) => {
                                        incoming_tx.unbounded_send(envelope).ok();
                                    }
                                    Err(error) => {
                                        log::error!("error decoding message {error:?}");
                                    }
                                }
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stdout: {error:?}"))?;
                            }
                        }
                    }

                    // The remote process wrote to stderr: print every completed line.
                    result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
                        match result {
                            Ok(len) => {
                                stderr_offset += len;
                                let mut start_ix = 0;
                                while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
                                    let line_ix = start_ix + ix;
                                    let content = String::from_utf8_lossy(&stderr_buffer[start_ix..line_ix]);
                                    start_ix = line_ix + 1;
                                    eprintln!("(remote) {}", content);
                                }
                                // Keep only the unfinished trailing line for the next read.
                                stderr_buffer.drain(0..start_ix);
                                stderr_offset -= start_ix;
                            }
                            Err(error) => {
                                Err(anyhow!("error reading stderr: {error:?}"))?;
                            }
                        }
                    }
                }
            }
        }).detach();

        cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, cx))
    }
216
217 pub fn server(
218 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
219 outgoing_tx: mpsc::UnboundedSender<Envelope>,
220 cx: &AppContext,
221 ) -> Arc<SshSession> {
222 let (tx, _rx) = mpsc::unbounded();
223 Self::new(incoming_rx, outgoing_tx, tx, cx)
224 }
225
226 #[cfg(any(test, feature = "test-support"))]
227 pub fn fake(
228 client_cx: &mut gpui::TestAppContext,
229 server_cx: &mut gpui::TestAppContext,
230 ) -> (Arc<Self>, Arc<Self>) {
231 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
232 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
233 let (tx, _rx) = mpsc::unbounded();
234 (
235 client_cx
236 .update(|cx| Self::new(server_to_client_rx, client_to_server_tx, tx.clone(), cx)),
237 server_cx
238 .update(|cx| Self::new(client_to_server_rx, server_to_client_tx, tx.clone(), cx)),
239 )
240 }
241
242 fn new(
243 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
244 outgoing_tx: mpsc::UnboundedSender<Envelope>,
245 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
246 cx: &AppContext,
247 ) -> Arc<SshSession> {
248 let this = Arc::new(Self {
249 next_message_id: AtomicU32::new(0),
250 response_channels: ResponseChannels::default(),
251 outgoing_tx,
252 spawn_process_tx,
253 message_handlers: Default::default(),
254 });
255
256 cx.spawn(|cx| {
257 let this = this.clone();
258 async move {
259 let peer_id = PeerId { owner_id: 0, id: 0 };
260 while let Some(incoming) = incoming_rx.next().await {
261 if let Some(request_id) = incoming.responding_to {
262 let request_id = MessageId(request_id);
263 let sender = this.response_channels.lock().remove(&request_id);
264 if let Some(sender) = sender {
265 let (tx, rx) = oneshot::channel();
266 if incoming.payload.is_some() {
267 sender.send((incoming, tx)).ok();
268 }
269 rx.await.ok();
270 }
271 } else if let Some(envelope) =
272 build_typed_envelope(peer_id, Instant::now(), incoming)
273 {
274 log::debug!(
275 "ssh message received. name:{}",
276 envelope.payload_type_name()
277 );
278 let type_id = envelope.payload_type_id();
279 let handler = this.message_handlers.lock().get(&type_id).cloned();
280 if let Some(handler) = handler {
281 if let Some(future) = handler(envelope, this.clone(), cx.clone()) {
282 future.await.ok();
283 } else {
284 this.message_handlers.lock().remove(&type_id);
285 }
286 }
287 }
288 }
289 anyhow::Ok(())
290 }
291 })
292 .detach();
293
294 this
295 }
296
297 pub fn request<T: RequestMessage>(
298 &self,
299 payload: T,
300 ) -> impl 'static + Future<Output = Result<T::Response>> {
301 log::debug!("ssh request start. name:{}", T::NAME);
302 let response = self.request_dynamic(payload.into_envelope(0, None, None), "");
303 async move {
304 let response = response.await?;
305 log::debug!("ssh request finish. name:{}", T::NAME);
306 T::Response::from_envelope(response)
307 .ok_or_else(|| anyhow!("received a response of the wrong type"))
308 }
309 }
310
311 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
312 self.send_dynamic(payload.into_envelope(0, None, None))
313 }
314
315 pub fn request_dynamic(
316 &self,
317 mut envelope: proto::Envelope,
318 _request_type: &'static str,
319 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
320 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
321 let (tx, rx) = oneshot::channel();
322 self.response_channels
323 .lock()
324 .insert(MessageId(envelope.id), tx);
325 self.outgoing_tx.unbounded_send(envelope).ok();
326 async move { Ok(rx.await.context("connection lost")?.0) }
327 }
328
329 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
330 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
331 self.outgoing_tx.unbounded_send(envelope)?;
332 Ok(())
333 }
334
335 pub async fn spawn_process(&self, command: String) -> process::Child {
336 let (process_tx, process_rx) = oneshot::channel();
337 self.spawn_process_tx
338 .unbounded_send(SpawnRequest {
339 command,
340 process_tx,
341 })
342 .ok();
343 process_rx.await.unwrap()
344 }
345
346 pub fn add_message_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
347 where
348 M: EnvelopedMessage,
349 E: 'static,
350 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
351 F: 'static + Future<Output = Result<()>>,
352 {
353 let message_type_id = TypeId::of::<M>();
354 self.message_handlers.lock().insert(
355 message_type_id,
356 Arc::new(move |envelope, _, cx| {
357 let entity = entity.upgrade()?;
358 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
359 Some(handler(entity, *envelope, cx).boxed_local())
360 }),
361 );
362 }
363
364 pub fn add_request_handler<M, E, H, F>(&self, entity: WeakModel<E>, handler: H)
365 where
366 M: EnvelopedMessage + RequestMessage,
367 E: 'static,
368 H: 'static + Sync + Send + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
369 F: 'static + Future<Output = Result<M::Response>>,
370 {
371 let message_type_id = TypeId::of::<M>();
372 self.message_handlers.lock().insert(
373 message_type_id,
374 Arc::new(move |envelope, this, cx| {
375 let entity = entity.upgrade()?;
376 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
377 let request_id = envelope.message_id();
378 Some(
379 handler(entity, *envelope, cx)
380 .then(move |result| async move {
381 this.outgoing_tx.unbounded_send(result?.into_envelope(
382 this.next_message_id.fetch_add(1, SeqCst),
383 Some(request_id),
384 None,
385 ))?;
386 Ok(())
387 })
388 .boxed_local(),
389 )
390 }),
391 );
392 }
393}
394
395impl ProtoClient for SshSession {
396 fn request(
397 &self,
398 envelope: proto::Envelope,
399 request_type: &'static str,
400 ) -> BoxFuture<'static, Result<proto::Envelope>> {
401 self.request_dynamic(envelope, request_type).boxed()
402 }
403
404 fn send(&self, envelope: proto::Envelope) -> Result<()> {
405 self.send_dynamic(envelope)
406 }
407}
408
409impl SshClientState {
410 #[cfg(not(unix))]
411 async fn new(
412 user: String,
413 host: String,
414 port: u16,
415 delegate: Arc<dyn SshClientDelegate>,
416 cx: &AsyncAppContext,
417 ) -> Result<Self> {
418 Err(anyhow!("ssh is not supported on this platform"))
419 }
420
421 #[cfg(unix)]
422 async fn new(
423 user: String,
424 host: String,
425 port: u16,
426 delegate: Arc<dyn SshClientDelegate>,
427 cx: &AsyncAppContext,
428 ) -> Result<Self> {
429 use smol::fs::unix::PermissionsExt as _;
430 use util::ResultExt as _;
431
432 let url = format!("{user}@{host}");
433 let temp_dir = tempfile::Builder::new()
434 .prefix("zed-ssh-session")
435 .tempdir()?;
436
437 // Create a TCP listener to handle requests from the askpass program.
438 let listener = smol::net::TcpListener::bind("127.0.0.1:0")
439 .await
440 .expect("failed to find open port");
441 let askpass_port = listener.local_addr().unwrap().port();
442 let askpass_task = cx.spawn(|mut cx| async move {
443 while let Ok((mut stream, _)) = listener.accept().await {
444 let mut buffer = Vec::new();
445 if stream.read_to_end(&mut buffer).await.is_err() {
446 buffer.clear();
447 }
448 let password_prompt = String::from_utf8_lossy(&buffer);
449 if let Some(password) = delegate
450 .ask_password(password_prompt.to_string(), &mut cx)
451 .await
452 .context("failed to get ssh password")
453 .and_then(|p| p)
454 .log_err()
455 {
456 stream.write_all(password.as_bytes()).await.log_err();
457 }
458 }
459 });
460
461 // Create an askpass script that communicates back to this process using TCP.
462 let askpass_script = format!(
463 "{shebang}\n echo \"$@\" | nc 127.0.0.1 {askpass_port} 2> /dev/null",
464 shebang = "#!/bin/sh"
465 );
466 let askpass_script_path = temp_dir.path().join("askpass.sh");
467 fs::write(&askpass_script_path, askpass_script).await?;
468 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
469
470 // Start the master SSH process, which does not do anything except for establish
471 // the connection and keep it open, allowing other ssh commands to reuse it
472 // via a control socket.
473 let socket_path = temp_dir.path().join("ssh.sock");
474 let mut master_process = process::Command::new("ssh")
475 .stdin(Stdio::null())
476 .stdout(Stdio::piped())
477 .stderr(Stdio::piped())
478 .env("SSH_ASKPASS_REQUIRE", "force")
479 .env("SSH_ASKPASS", &askpass_script_path)
480 .args(["-N", "-o", "ControlMaster=yes", "-o"])
481 .arg(format!("ControlPath={}", socket_path.display()))
482 .args(["-p", &port.to_string()])
483 .arg(&url)
484 .spawn()?;
485
486 // Wait for this ssh process to close its stdout, indicating that authentication
487 // has completed.
488 let stdout = master_process.stdout.as_mut().unwrap();
489 let mut output = Vec::new();
490 stdout.read_to_end(&mut output).await?;
491 drop(askpass_task);
492
493 if master_process.try_status()?.is_some() {
494 output.clear();
495 let mut stderr = master_process.stderr.take().unwrap();
496 stderr.read_to_end(&mut output).await?;
497 Err(anyhow!(
498 "failed to connect: {}",
499 String::from_utf8_lossy(&output)
500 ))?;
501 }
502
503 Ok(Self {
504 _master_process: master_process,
505 port,
506 _temp_dir: temp_dir,
507 socket_path,
508 url,
509 })
510 }
511
512 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
513 let mut command = process::Command::new("scp");
514 let output = self
515 .ssh_options(&mut command)
516 .arg("-P")
517 .arg(&self.port.to_string())
518 .arg(&src_path)
519 .arg(&format!("{}:{}", self.url, dest_path.display()))
520 .output()
521 .await?;
522
523 if output.status.success() {
524 Ok(())
525 } else {
526 Err(anyhow!(
527 "failed to upload file {} -> {}: {}",
528 src_path.display(),
529 dest_path.display(),
530 String::from_utf8_lossy(&output.stderr)
531 ))
532 }
533 }
534
535 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
536 let mut command = process::Command::new("ssh");
537 self.ssh_options(&mut command)
538 .arg("-p")
539 .arg(&self.port.to_string())
540 .arg(&self.url)
541 .arg(program);
542 command
543 }
544
545 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
546 command
547 .stdin(Stdio::piped())
548 .stdout(Stdio::piped())
549 .stderr(Stdio::piped())
550 .args(["-o", "ControlMaster=no", "-o"])
551 .arg(format!("ControlPath={}", self.socket_path.display()))
552 }
553}
554
555async fn run_cmd(command: &mut process::Command) -> Result<String> {
556 let output = command.output().await?;
557 if output.status.success() {
558 Ok(String::from_utf8_lossy(&output.stdout).to_string())
559 } else {
560 Err(anyhow!(
561 "failed to run command: {}",
562 String::from_utf8_lossy(&output.stderr)
563 ))
564 }
565}
566
567async fn query_platform(session: &SshClientState) -> Result<SshPlatform> {
568 let os = run_cmd(session.ssh_command("uname").arg("-s")).await?;
569 let arch = run_cmd(session.ssh_command("uname").arg("-m")).await?;
570
571 let os = match os.trim() {
572 "Darwin" => "macos",
573 "Linux" => "linux",
574 _ => Err(anyhow!("unknown uname os {os:?}"))?,
575 };
576 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
577 "aarch64"
578 } else if arch.starts_with("x86") || arch.starts_with("i686") {
579 "x86_64"
580 } else {
581 Err(anyhow!("unknown uname architecture {arch:?}"))?
582 };
583
584 Ok(SshPlatform { os, arch })
585}
586
587async fn ensure_server_binary(
588 session: &SshClientState,
589 src_path: &Path,
590 dst_path: &Path,
591 version: SemanticVersion,
592) -> Result<()> {
593 let mut dst_path_gz = dst_path.to_path_buf();
594 dst_path_gz.set_extension("gz");
595
596 if let Some(parent) = dst_path.parent() {
597 run_cmd(session.ssh_command("mkdir").arg("-p").arg(parent)).await?;
598 }
599
600 let mut server_binary_exists = false;
601 if let Ok(installed_version) = run_cmd(session.ssh_command(&dst_path).arg("version")).await {
602 if installed_version.trim() == version.to_string() {
603 server_binary_exists = true;
604 }
605 }
606
607 if server_binary_exists {
608 log::info!("remote development server already present",);
609 return Ok(());
610 }
611
612 let src_stat = fs::metadata(src_path).await?;
613 let size = src_stat.len();
614 let server_mode = 0o755;
615
616 let t0 = Instant::now();
617 log::info!("uploading remote development server ({}kb)", size / 1024);
618 session
619 .upload_file(src_path, &dst_path_gz)
620 .await
621 .context("failed to upload server binary")?;
622 log::info!("uploaded remote development server in {:?}", t0.elapsed());
623
624 log::info!("extracting remote development server");
625 run_cmd(
626 session
627 .ssh_command("gunzip")
628 .arg("--force")
629 .arg(&dst_path_gz),
630 )
631 .await?;
632
633 log::info!("unzipping remote development server");
634 run_cmd(
635 session
636 .ssh_command("chmod")
637 .arg(format!("{:o}", server_mode))
638 .arg(&dst_path),
639 )
640 .await?;
641
642 Ok(())
643}