1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 future::BoxFuture,
12 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
13};
14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion};
15use parking_lot::Mutex;
16use rpc::{
17 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
18 EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
19};
20use smol::{
21 fs,
22 process::{self, Stdio},
23};
24use std::{
25 any::TypeId,
26 ffi::OsStr,
27 path::{Path, PathBuf},
28 sync::{
29 atomic::{AtomicU32, Ordering::SeqCst},
30 Arc,
31 },
32 time::Instant,
33};
34use tempfile::TempDir;
35
/// Handle for running commands over an already-established SSH connection.
///
/// Commands built from this socket pass `ControlMaster=no` and point
/// `ControlPath` at `socket_path`, so they reuse the master connection that
/// `SshClientState::new` opened instead of authenticating again.
#[derive(Clone)]
pub struct SshSocket {
    /// Host, user, and port to connect to.
    connection_options: SshConnectionOptions,
    /// Path of the control socket owned by the master `ssh` process.
    socket_path: PathBuf,
}
41
/// A protobuf message channel to the remote peer (the remote development
/// server when created via `SshSession::client`).
pub struct SshSession {
    /// Source of `Envelope::id` values for every outgoing message.
    next_message_id: AtomicU32,
    /// Senders for in-flight requests, keyed by the id of the request they answer.
    response_channels: ResponseChannels, // Lock
    /// Envelopes pushed here are delivered to the remote peer (on the client,
    /// a background task writes them to the remote server's stdin).
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests to spawn additional commands over the SSH connection.
    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
    /// The SSH socket; `None` for server-side and fake sessions.
    client_socket: Option<SshSocket>,
    /// Registered message handlers and entity subscriptions.
    state: Mutex<ProtoMessageHandlerSet>, // Lock
}
50
/// Resources owned by the client side of an SSH connection.
///
/// The underscore-prefixed fields are never read; they are held so that the
/// master `ssh` process (which owns the control socket) and the temporary
/// directory (holding the control socket and askpass files) stay alive as
/// long as this value does.
struct SshClientState {
    socket: SshSocket,
    _master_process: process::Child,
    _temp_dir: TempDir,
}
56
/// Parameters describing how to reach a remote host over SSH.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// Returns an `ssh://[user@]host[:port]` URL, usable as the destination
    /// argument to `ssh`.
    ///
    /// Bare IPv6 literal hosts are wrapped in brackets so that their colons
    /// are not mistaken for the port separator.
    pub fn ssh_url(&self) -> String {
        let mut result = format!("ssh://{}", self.user_at(&self.url_host()));
        if let Some(port) = self.port {
            result.push(':');
            result.push_str(&port.to_string());
        }
        result
    }

    /// Returns the `[user@]host` part of an `scp` remote target; the caller
    /// appends `:path`. Bare IPv6 literal hosts are bracketed, as `scp`
    /// requires, so the path separator stays unambiguous.
    fn scp_url(&self) -> String {
        self.user_at(&self.url_host())
    }

    /// Returns `[user@]host[:port]`, with the host used verbatim (no IPv6
    /// bracketing).
    pub fn connection_string(&self) -> String {
        let host = self.user_at(&self.host);
        match self.port {
            Some(port) => format!("{host}:{port}"),
            None => host,
        }
    }

    /// Prefixes `host` with `username@` when a username is configured.
    fn user_at(&self, host: &str) -> String {
        match &self.username {
            Some(username) => format!("{username}@{host}"),
            None => host.to_string(),
        }
    }

    /// Returns the host, wrapped in `[...]` when it is a bare IPv6 literal.
    ///
    /// An IPv6 literal always contains at least two colons, so a single colon
    /// (for example a stray `host:port`) is left untouched, as are hosts that
    /// are already bracketed.
    fn url_host(&self) -> String {
        let is_bare_ipv6 = !self.host.starts_with('[') && self.host.matches(':').count() >= 2;
        if is_bare_ipv6 {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        }
    }
}
101
/// A request to run `command` on the remote host over the SSH connection,
/// serviced by the background task started in `SshSession::client`.
struct SpawnRequest {
    /// Remote command line, passed to `SshSocket::ssh_command`.
    command: String,
    /// Receives the spawned local `ssh` child process.
    process_tx: oneshot::Sender<process::Child>,
}
106
/// Operating system and CPU architecture of the remote host.
///
/// `SshClientState::query_platform` derives these from `uname` and produces
/// `"macos"`/`"linux"` for `os` and `"aarch64"`/`"x86_64"` for `arch`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
112
/// Hooks through which `SshSession::client` interacts with the application
/// while connecting.
pub trait SshClientDelegate {
    /// Asks for the answer to `prompt`, which is forwarded from ssh's askpass
    /// helper script; on success the answer is written back to ssh.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path on the remote host where the server binary is
    /// installed and run from.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Supplies a local server binary built for `platform`, together with its
    /// version (in release builds, compared against the output of the
    /// installed binary's `version` subcommand to skip re-uploading).
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports connection progress, e.g. "connecting".
    /// NOTE(review): presumably `None` clears the status; no caller here passes it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
}
127
/// Pending requests keyed by message id.
///
/// Each sender receives the response envelope plus a `oneshot::Sender<()>`.
/// The incoming-message loop in `SshSession::new` waits on the matching
/// receiver (which resolves once that sender is dropped or used) before it
/// processes the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
129
130impl SshSession {
131 pub async fn client(
132 connection_options: SshConnectionOptions,
133 delegate: Arc<dyn SshClientDelegate>,
134 cx: &mut AsyncAppContext,
135 ) -> Result<Arc<Self>> {
136 let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
137
138 let platform = client_state.query_platform().await?;
139 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
140 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
141 client_state
142 .ensure_server_binary(
143 &delegate,
144 &local_binary_path,
145 &remote_binary_path,
146 version,
147 cx,
148 )
149 .await?;
150
151 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
152 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
153 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
154
155 let socket = client_state.socket.clone();
156 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
157
158 let mut remote_server_child = socket
159 .ssh_command(format!(
160 "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
161 std::env::var("RUST_LOG").unwrap_or_default(),
162 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
163 remote_binary_path,
164 ))
165 .spawn()
166 .context("failed to spawn remote server")?;
167 let mut child_stderr = remote_server_child.stderr.take().unwrap();
168 let mut child_stdout = remote_server_child.stdout.take().unwrap();
169 let mut child_stdin = remote_server_child.stdin.take().unwrap();
170
171 let executor = cx.background_executor().clone();
172 executor.clone().spawn(async move {
173 let mut stdin_buffer = Vec::new();
174 let mut stdout_buffer = Vec::new();
175 let mut stderr_buffer = Vec::new();
176 let mut stderr_offset = 0;
177
178 loop {
179 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
180 stderr_buffer.resize(stderr_offset + 1024, 0);
181
182 select_biased! {
183 outgoing = outgoing_rx.next().fuse() => {
184 let Some(outgoing) = outgoing else {
185 return anyhow::Ok(());
186 };
187
188 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
189 }
190
191 request = spawn_process_rx.next().fuse() => {
192 let Some(request) = request else {
193 return Ok(());
194 };
195
196 log::info!("spawn process: {:?}", request.command);
197 let child = client_state.socket
198 .ssh_command(&request.command)
199 .spawn()
200 .context("failed to create channel")?;
201 request.process_tx.send(child).ok();
202 }
203
204 result = child_stdout.read(&mut stdout_buffer).fuse() => {
205 match result {
206 Ok(len) => {
207 if len == 0 {
208 child_stdin.close().await?;
209 let status = remote_server_child.status().await?;
210 if !status.success() {
211 log::info!("channel exited with status: {status:?}");
212 }
213 return Ok(());
214 }
215
216 if len < stdout_buffer.len() {
217 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
218 }
219
220 let message_len = message_len_from_buffer(&stdout_buffer);
221 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
222 Ok(envelope) => {
223 incoming_tx.unbounded_send(envelope).ok();
224 }
225 Err(error) => {
226 log::error!("error decoding message {error:?}");
227 }
228 }
229 }
230 Err(error) => {
231 Err(anyhow!("error reading stdout: {error:?}"))?;
232 }
233 }
234 }
235
236 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
237 match result {
238 Ok(len) => {
239 stderr_offset += len;
240 let mut start_ix = 0;
241 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
242 let line_ix = start_ix + ix;
243 let content = &stderr_buffer[start_ix..line_ix];
244 start_ix = line_ix + 1;
245 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
246 record.log(log::logger())
247 } else {
248 eprintln!("(remote) {}", String::from_utf8_lossy(content));
249 }
250 }
251 stderr_buffer.drain(0..start_ix);
252 stderr_offset -= start_ix;
253 }
254 Err(error) => {
255 Err(anyhow!("error reading stderr: {error:?}"))?;
256 }
257 }
258 }
259 }
260 }
261 }).detach();
262
263 cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, Some(socket), cx))
264 }
265
266 pub fn server(
267 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
268 outgoing_tx: mpsc::UnboundedSender<Envelope>,
269 cx: &AppContext,
270 ) -> Arc<SshSession> {
271 let (tx, _rx) = mpsc::unbounded();
272 Self::new(incoming_rx, outgoing_tx, tx, None, cx)
273 }
274
275 #[cfg(any(test, feature = "test-support"))]
276 pub fn fake(
277 client_cx: &mut gpui::TestAppContext,
278 server_cx: &mut gpui::TestAppContext,
279 ) -> (Arc<Self>, Arc<Self>) {
280 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
281 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
282 let (tx, _rx) = mpsc::unbounded();
283 (
284 client_cx.update(|cx| {
285 Self::new(
286 server_to_client_rx,
287 client_to_server_tx,
288 tx.clone(),
289 None, // todo()
290 cx,
291 )
292 }),
293 server_cx.update(|cx| {
294 Self::new(
295 client_to_server_rx,
296 server_to_client_tx,
297 tx.clone(),
298 None,
299 cx,
300 )
301 }),
302 )
303 }
304
305 fn new(
306 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
307 outgoing_tx: mpsc::UnboundedSender<Envelope>,
308 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
309 client_socket: Option<SshSocket>,
310 cx: &AppContext,
311 ) -> Arc<SshSession> {
312 let this = Arc::new(Self {
313 next_message_id: AtomicU32::new(0),
314 response_channels: ResponseChannels::default(),
315 outgoing_tx,
316 spawn_process_tx,
317 client_socket,
318 state: Default::default(),
319 });
320
321 cx.spawn(|cx| {
322 let this = this.clone();
323 async move {
324 let peer_id = PeerId { owner_id: 0, id: 0 };
325 while let Some(incoming) = incoming_rx.next().await {
326 if let Some(request_id) = incoming.responding_to {
327 let request_id = MessageId(request_id);
328 let sender = this.response_channels.lock().remove(&request_id);
329 if let Some(sender) = sender {
330 let (tx, rx) = oneshot::channel();
331 if incoming.payload.is_some() {
332 sender.send((incoming, tx)).ok();
333 }
334 rx.await.ok();
335 }
336 } else if let Some(envelope) =
337 build_typed_envelope(peer_id, Instant::now(), incoming)
338 {
339 let type_name = envelope.payload_type_name();
340 if let Some(future) = ProtoMessageHandlerSet::handle_message(
341 &this.state,
342 envelope,
343 this.clone().into(),
344 cx.clone(),
345 ) {
346 log::debug!("ssh message received. name:{type_name}");
347 match future.await {
348 Ok(_) => {
349 log::debug!("ssh message handled. name:{type_name}");
350 }
351 Err(error) => {
352 log::error!(
353 "error handling message. type:{type_name}, error:{error}",
354 );
355 }
356 }
357 } else {
358 log::error!("unhandled ssh message name:{type_name}");
359 }
360 }
361 }
362 anyhow::Ok(())
363 }
364 })
365 .detach();
366
367 this
368 }
369
370 pub fn request<T: RequestMessage>(
371 &self,
372 payload: T,
373 ) -> impl 'static + Future<Output = Result<T::Response>> {
374 log::debug!("ssh request start. name:{}", T::NAME);
375 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
376 async move {
377 let response = response.await?;
378 log::debug!("ssh request finish. name:{}", T::NAME);
379 T::Response::from_envelope(response)
380 .ok_or_else(|| anyhow!("received a response of the wrong type"))
381 }
382 }
383
384 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
385 log::debug!("ssh send name:{}", T::NAME);
386 self.send_dynamic(payload.into_envelope(0, None, None))
387 }
388
389 pub fn request_dynamic(
390 &self,
391 mut envelope: proto::Envelope,
392 type_name: &'static str,
393 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
394 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
395 let (tx, rx) = oneshot::channel();
396 let mut response_channels_lock = self.response_channels.lock();
397 response_channels_lock.insert(MessageId(envelope.id), tx);
398 drop(response_channels_lock);
399 self.outgoing_tx.unbounded_send(envelope).ok();
400 async move {
401 let response = rx.await.context("connection lost")?.0;
402 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
403 return Err(RpcError::from_proto(error, type_name));
404 }
405 Ok(response)
406 }
407 }
408
409 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
410 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
411 self.outgoing_tx.unbounded_send(envelope)?;
412 Ok(())
413 }
414
415 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
416 let id = (TypeId::of::<E>(), remote_id);
417
418 let mut state = self.state.lock();
419 if state.entities_by_type_and_remote_id.contains_key(&id) {
420 panic!("already subscribed to entity");
421 }
422
423 state.entities_by_type_and_remote_id.insert(
424 id,
425 EntityMessageSubscriber::Entity {
426 handle: entity.downgrade().into(),
427 },
428 );
429 }
430
431 pub async fn spawn_process(&self, command: String) -> process::Child {
432 let (process_tx, process_rx) = oneshot::channel();
433 self.spawn_process_tx
434 .unbounded_send(SpawnRequest {
435 command,
436 process_tx,
437 })
438 .ok();
439 process_rx.await.unwrap()
440 }
441
442 pub fn ssh_args(&self) -> Vec<String> {
443 self.client_socket.as_ref().unwrap().ssh_args()
444 }
445}
446
447impl ProtoClient for SshSession {
448 fn request(
449 &self,
450 envelope: proto::Envelope,
451 request_type: &'static str,
452 ) -> BoxFuture<'static, Result<proto::Envelope>> {
453 self.request_dynamic(envelope, request_type).boxed()
454 }
455
456 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
457 self.send_dynamic(envelope)
458 }
459
460 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
461 self.send_dynamic(envelope)
462 }
463
464 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
465 &self.state
466 }
467}
468
469impl SshClientState {
470 #[cfg(not(unix))]
471 async fn new(
472 _connection_options: SshConnectionOptions,
473 _delegate: Arc<dyn SshClientDelegate>,
474 _cx: &mut AsyncAppContext,
475 ) -> Result<Self> {
476 Err(anyhow!("ssh is not supported on this platform"))
477 }
478
479 #[cfg(unix)]
480 async fn new(
481 connection_options: SshConnectionOptions,
482 delegate: Arc<dyn SshClientDelegate>,
483 cx: &mut AsyncAppContext,
484 ) -> Result<Self> {
485 use futures::{io::BufReader, AsyncBufReadExt as _};
486 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
487 use util::ResultExt as _;
488
489 delegate.set_status(Some("connecting"), cx);
490
491 let url = connection_options.ssh_url();
492 let temp_dir = tempfile::Builder::new()
493 .prefix("zed-ssh-session")
494 .tempdir()?;
495
496 // Create a domain socket listener to handle requests from the askpass program.
497 let askpass_socket = temp_dir.path().join("askpass.sock");
498 let listener =
499 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
500
501 let askpass_task = cx.spawn(|mut cx| async move {
502 while let Ok((mut stream, _)) = listener.accept().await {
503 let mut buffer = Vec::new();
504 let mut reader = BufReader::new(&mut stream);
505 if reader.read_until(b'\0', &mut buffer).await.is_err() {
506 buffer.clear();
507 }
508 let password_prompt = String::from_utf8_lossy(&buffer);
509 if let Some(password) = delegate
510 .ask_password(password_prompt.to_string(), &mut cx)
511 .await
512 .context("failed to get ssh password")
513 .and_then(|p| p)
514 .log_err()
515 {
516 stream.write_all(password.as_bytes()).await.log_err();
517 }
518 }
519 });
520
521 // Create an askpass script that communicates back to this process.
522 let askpass_script = format!(
523 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
524 askpass_socket = askpass_socket.display(),
525 print_args = "printf '%s\\0' \"$@\"",
526 shebang = "#!/bin/sh",
527 );
528 let askpass_script_path = temp_dir.path().join("askpass.sh");
529 fs::write(&askpass_script_path, askpass_script).await?;
530 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
531
532 // Start the master SSH process, which does not do anything except for establish
533 // the connection and keep it open, allowing other ssh commands to reuse it
534 // via a control socket.
535 let socket_path = temp_dir.path().join("ssh.sock");
536 let mut master_process = process::Command::new("ssh")
537 .stdin(Stdio::null())
538 .stdout(Stdio::piped())
539 .stderr(Stdio::piped())
540 .env("SSH_ASKPASS_REQUIRE", "force")
541 .env("SSH_ASKPASS", &askpass_script_path)
542 .args(["-N", "-o", "ControlMaster=yes", "-o"])
543 .arg(format!("ControlPath={}", socket_path.display()))
544 .arg(&url)
545 .spawn()?;
546
547 // Wait for this ssh process to close its stdout, indicating that authentication
548 // has completed.
549 let stdout = master_process.stdout.as_mut().unwrap();
550 let mut output = Vec::new();
551 stdout.read_to_end(&mut output).await?;
552 drop(askpass_task);
553
554 if master_process.try_status()?.is_some() {
555 output.clear();
556 let mut stderr = master_process.stderr.take().unwrap();
557 stderr.read_to_end(&mut output).await?;
558 Err(anyhow!(
559 "failed to connect: {}",
560 String::from_utf8_lossy(&output)
561 ))?;
562 }
563
564 Ok(Self {
565 socket: SshSocket {
566 connection_options,
567 socket_path,
568 },
569 _master_process: master_process,
570 _temp_dir: temp_dir,
571 })
572 }
573
574 async fn ensure_server_binary(
575 &self,
576 delegate: &Arc<dyn SshClientDelegate>,
577 src_path: &Path,
578 dst_path: &Path,
579 version: SemanticVersion,
580 cx: &mut AsyncAppContext,
581 ) -> Result<()> {
582 let mut dst_path_gz = dst_path.to_path_buf();
583 dst_path_gz.set_extension("gz");
584
585 if let Some(parent) = dst_path.parent() {
586 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
587 }
588
589 let mut server_binary_exists = false;
590 if cfg!(not(debug_assertions)) {
591 if let Ok(installed_version) =
592 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
593 {
594 if installed_version.trim() == version.to_string() {
595 server_binary_exists = true;
596 }
597 }
598 }
599
600 if server_binary_exists {
601 log::info!("remote development server already present",);
602 return Ok(());
603 }
604
605 let src_stat = fs::metadata(src_path).await?;
606 let size = src_stat.len();
607 let server_mode = 0o755;
608
609 let t0 = Instant::now();
610 delegate.set_status(Some("uploading remote development server"), cx);
611 log::info!("uploading remote development server ({}kb)", size / 1024);
612 self.upload_file(src_path, &dst_path_gz)
613 .await
614 .context("failed to upload server binary")?;
615 log::info!("uploaded remote development server in {:?}", t0.elapsed());
616
617 delegate.set_status(Some("extracting remote development server"), cx);
618 run_cmd(
619 self.socket
620 .ssh_command("gunzip")
621 .arg("--force")
622 .arg(&dst_path_gz),
623 )
624 .await?;
625
626 delegate.set_status(Some("unzipping remote development server"), cx);
627 run_cmd(
628 self.socket
629 .ssh_command("chmod")
630 .arg(format!("{:o}", server_mode))
631 .arg(dst_path),
632 )
633 .await?;
634
635 Ok(())
636 }
637
638 async fn query_platform(&self) -> Result<SshPlatform> {
639 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
640 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
641
642 let os = match os.trim() {
643 "Darwin" => "macos",
644 "Linux" => "linux",
645 _ => Err(anyhow!("unknown uname os {os:?}"))?,
646 };
647 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
648 "aarch64"
649 } else if arch.starts_with("x86") || arch.starts_with("i686") {
650 "x86_64"
651 } else {
652 Err(anyhow!("unknown uname architecture {arch:?}"))?
653 };
654
655 Ok(SshPlatform { os, arch })
656 }
657
658 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
659 let mut command = process::Command::new("scp");
660 let output = self
661 .socket
662 .ssh_options(&mut command)
663 .args(
664 self.socket
665 .connection_options
666 .port
667 .map(|port| vec!["-P".to_string(), port.to_string()])
668 .unwrap_or_default(),
669 )
670 .arg(src_path)
671 .arg(format!(
672 "{}:{}",
673 self.socket.connection_options.scp_url(),
674 dest_path.display()
675 ))
676 .output()
677 .await?;
678
679 if output.status.success() {
680 Ok(())
681 } else {
682 Err(anyhow!(
683 "failed to upload file {} -> {}: {}",
684 src_path.display(),
685 dest_path.display(),
686 String::from_utf8_lossy(&output.stderr)
687 ))
688 }
689 }
690}
691
692impl SshSocket {
693 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
694 let mut command = process::Command::new("ssh");
695 self.ssh_options(&mut command)
696 .arg(self.connection_options.ssh_url())
697 .arg(program);
698 command
699 }
700
701 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
702 command
703 .stdin(Stdio::piped())
704 .stdout(Stdio::piped())
705 .stderr(Stdio::piped())
706 .args(["-o", "ControlMaster=no", "-o"])
707 .arg(format!("ControlPath={}", self.socket_path.display()))
708 }
709
710 fn ssh_args(&self) -> Vec<String> {
711 vec![
712 "-o".to_string(),
713 "ControlMaster=no".to_string(),
714 "-o".to_string(),
715 format!("ControlPath={}", self.socket_path.display()),
716 self.connection_options.ssh_url(),
717 ]
718 }
719}
720
721async fn run_cmd(command: &mut process::Command) -> Result<String> {
722 let output = command.output().await?;
723 if output.status.success() {
724 Ok(String::from_utf8_lossy(&output.stdout).to_string())
725 } else {
726 Err(anyhow!(
727 "failed to run command: {}",
728 String::from_utf8_lossy(&output.stderr)
729 ))
730 }
731}