1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 future::BoxFuture,
12 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
13};
14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion};
15use parking_lot::Mutex;
16use rpc::{
17 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
18 EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
19};
20use smol::{
21 fs,
22 process::{self, Stdio},
23};
24use std::{
25 any::TypeId,
26 ffi::OsStr,
27 path::{Path, PathBuf},
28 sync::{
29 atomic::{AtomicU32, Ordering::SeqCst},
30 Arc,
31 },
32 time::Instant,
33};
34use tempfile::TempDir;
35
36#[derive(
37 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
38)]
39pub struct SshProjectId(pub u64);
40
41#[derive(Clone)]
42pub struct SshSocket {
43 connection_options: SshConnectionOptions,
44 socket_path: PathBuf,
45}
46
47pub struct SshSession {
48 next_message_id: AtomicU32,
49 response_channels: ResponseChannels, // Lock
50 outgoing_tx: mpsc::UnboundedSender<Envelope>,
51 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
52 client_socket: Option<SshSocket>,
53 state: Mutex<ProtoMessageHandlerSet>, // Lock
54}
55
56struct SshClientState {
57 socket: SshSocket,
58 _master_process: process::Child,
59 _temp_dir: TempDir,
60}
61
/// Parameters describing how to reach an SSH host.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Host name or address.
    pub host: String,
    /// Login user; omitted from generated destinations when `None`.
    pub username: Option<String>,
    /// TCP port; omitted from generated destinations when `None`.
    pub port: Option<u16>,
    /// Optional password (not read by the code in this file).
    pub password: Option<String>,
}
69
70impl SshConnectionOptions {
71 pub fn ssh_url(&self) -> String {
72 let mut result = String::from("ssh://");
73 if let Some(username) = &self.username {
74 result.push_str(username);
75 result.push('@');
76 }
77 result.push_str(&self.host);
78 if let Some(port) = self.port {
79 result.push(':');
80 result.push_str(&port.to_string());
81 }
82 result
83 }
84
85 fn scp_url(&self) -> String {
86 if let Some(username) = &self.username {
87 format!("{}@{}", username, self.host)
88 } else {
89 self.host.clone()
90 }
91 }
92
93 pub fn connection_string(&self) -> String {
94 let host = if let Some(username) = &self.username {
95 format!("{}@{}", username, self.host)
96 } else {
97 self.host.clone()
98 };
99 if let Some(port) = &self.port {
100 format!("{}:{}", host, port)
101 } else {
102 host
103 }
104 }
105}
106
107struct SpawnRequest {
108 command: String,
109 process_tx: oneshot::Sender<process::Child>,
110}
111
/// Operating system and CPU architecture of the remote host.
///
/// When produced by the platform query in this file, `os` is `"macos"` or `"linux"` and
/// `arch` is `"aarch64"` or `"x86_64"`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// Normalized operating-system name.
    pub os: &'static str,
    /// Normalized architecture name.
    pub arch: &'static str,
}
117
118pub trait SshClientDelegate {
119 fn ask_password(
120 &self,
121 prompt: String,
122 cx: &mut AsyncAppContext,
123 ) -> oneshot::Receiver<Result<String>>;
124 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
125 fn get_server_binary(
126 &self,
127 platform: SshPlatform,
128 cx: &mut AsyncAppContext,
129 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
130 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
131}
132
133type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
134
135impl SshSession {
136 pub async fn client(
137 connection_options: SshConnectionOptions,
138 delegate: Arc<dyn SshClientDelegate>,
139 cx: &mut AsyncAppContext,
140 ) -> Result<Arc<Self>> {
141 let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
142
143 let platform = client_state.query_platform().await?;
144 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
145 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
146 client_state
147 .ensure_server_binary(
148 &delegate,
149 &local_binary_path,
150 &remote_binary_path,
151 version,
152 cx,
153 )
154 .await?;
155
156 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
157 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
158 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
159
160 let socket = client_state.socket.clone();
161 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
162
163 let mut remote_server_child = socket
164 .ssh_command(format!(
165 "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
166 std::env::var("RUST_LOG").unwrap_or_default(),
167 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
168 remote_binary_path,
169 ))
170 .spawn()
171 .context("failed to spawn remote server")?;
172 let mut child_stderr = remote_server_child.stderr.take().unwrap();
173 let mut child_stdout = remote_server_child.stdout.take().unwrap();
174 let mut child_stdin = remote_server_child.stdin.take().unwrap();
175
176 let executor = cx.background_executor().clone();
177 executor.clone().spawn(async move {
178 let mut stdin_buffer = Vec::new();
179 let mut stdout_buffer = Vec::new();
180 let mut stderr_buffer = Vec::new();
181 let mut stderr_offset = 0;
182
183 loop {
184 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
185 stderr_buffer.resize(stderr_offset + 1024, 0);
186
187 select_biased! {
188 outgoing = outgoing_rx.next().fuse() => {
189 let Some(outgoing) = outgoing else {
190 return anyhow::Ok(());
191 };
192
193 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
194 }
195
196 request = spawn_process_rx.next().fuse() => {
197 let Some(request) = request else {
198 return Ok(());
199 };
200
201 log::info!("spawn process: {:?}", request.command);
202 let child = client_state.socket
203 .ssh_command(&request.command)
204 .spawn()
205 .context("failed to create channel")?;
206 request.process_tx.send(child).ok();
207 }
208
209 result = child_stdout.read(&mut stdout_buffer).fuse() => {
210 match result {
211 Ok(len) => {
212 if len == 0 {
213 child_stdin.close().await?;
214 let status = remote_server_child.status().await?;
215 if !status.success() {
216 log::info!("channel exited with status: {status:?}");
217 }
218 return Ok(());
219 }
220
221 if len < stdout_buffer.len() {
222 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
223 }
224
225 let message_len = message_len_from_buffer(&stdout_buffer);
226 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
227 Ok(envelope) => {
228 incoming_tx.unbounded_send(envelope).ok();
229 }
230 Err(error) => {
231 log::error!("error decoding message {error:?}");
232 }
233 }
234 }
235 Err(error) => {
236 Err(anyhow!("error reading stdout: {error:?}"))?;
237 }
238 }
239 }
240
241 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
242 match result {
243 Ok(len) => {
244 stderr_offset += len;
245 let mut start_ix = 0;
246 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
247 let line_ix = start_ix + ix;
248 let content = &stderr_buffer[start_ix..line_ix];
249 start_ix = line_ix + 1;
250 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
251 record.log(log::logger())
252 } else {
253 eprintln!("(remote) {}", String::from_utf8_lossy(content));
254 }
255 }
256 stderr_buffer.drain(0..start_ix);
257 stderr_offset -= start_ix;
258 }
259 Err(error) => {
260 Err(anyhow!("error reading stderr: {error:?}"))?;
261 }
262 }
263 }
264 }
265 }
266 }).detach();
267
268 cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, Some(socket), cx))
269 }
270
271 pub fn server(
272 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
273 outgoing_tx: mpsc::UnboundedSender<Envelope>,
274 cx: &AppContext,
275 ) -> Arc<SshSession> {
276 let (tx, _rx) = mpsc::unbounded();
277 Self::new(incoming_rx, outgoing_tx, tx, None, cx)
278 }
279
280 #[cfg(any(test, feature = "test-support"))]
281 pub fn fake(
282 client_cx: &mut gpui::TestAppContext,
283 server_cx: &mut gpui::TestAppContext,
284 ) -> (Arc<Self>, Arc<Self>) {
285 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
286 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
287 let (tx, _rx) = mpsc::unbounded();
288 (
289 client_cx.update(|cx| {
290 Self::new(
291 server_to_client_rx,
292 client_to_server_tx,
293 tx.clone(),
294 None, // todo()
295 cx,
296 )
297 }),
298 server_cx.update(|cx| {
299 Self::new(
300 client_to_server_rx,
301 server_to_client_tx,
302 tx.clone(),
303 None,
304 cx,
305 )
306 }),
307 )
308 }
309
310 fn new(
311 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
312 outgoing_tx: mpsc::UnboundedSender<Envelope>,
313 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
314 client_socket: Option<SshSocket>,
315 cx: &AppContext,
316 ) -> Arc<SshSession> {
317 let this = Arc::new(Self {
318 next_message_id: AtomicU32::new(0),
319 response_channels: ResponseChannels::default(),
320 outgoing_tx,
321 spawn_process_tx,
322 client_socket,
323 state: Default::default(),
324 });
325
326 cx.spawn(|cx| {
327 let this = this.clone();
328 async move {
329 let peer_id = PeerId { owner_id: 0, id: 0 };
330 while let Some(incoming) = incoming_rx.next().await {
331 if let Some(request_id) = incoming.responding_to {
332 let request_id = MessageId(request_id);
333 let sender = this.response_channels.lock().remove(&request_id);
334 if let Some(sender) = sender {
335 let (tx, rx) = oneshot::channel();
336 if incoming.payload.is_some() {
337 sender.send((incoming, tx)).ok();
338 }
339 rx.await.ok();
340 }
341 } else if let Some(envelope) =
342 build_typed_envelope(peer_id, Instant::now(), incoming)
343 {
344 let type_name = envelope.payload_type_name();
345 if let Some(future) = ProtoMessageHandlerSet::handle_message(
346 &this.state,
347 envelope,
348 this.clone().into(),
349 cx.clone(),
350 ) {
351 log::debug!("ssh message received. name:{type_name}");
352 match future.await {
353 Ok(_) => {
354 log::debug!("ssh message handled. name:{type_name}");
355 }
356 Err(error) => {
357 log::error!(
358 "error handling message. type:{type_name}, error:{error}",
359 );
360 }
361 }
362 } else {
363 log::error!("unhandled ssh message name:{type_name}");
364 }
365 }
366 }
367 anyhow::Ok(())
368 }
369 })
370 .detach();
371
372 this
373 }
374
375 pub fn request<T: RequestMessage>(
376 &self,
377 payload: T,
378 ) -> impl 'static + Future<Output = Result<T::Response>> {
379 log::debug!("ssh request start. name:{}", T::NAME);
380 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
381 async move {
382 let response = response.await?;
383 log::debug!("ssh request finish. name:{}", T::NAME);
384 T::Response::from_envelope(response)
385 .ok_or_else(|| anyhow!("received a response of the wrong type"))
386 }
387 }
388
389 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
390 log::debug!("ssh send name:{}", T::NAME);
391 self.send_dynamic(payload.into_envelope(0, None, None))
392 }
393
394 pub fn request_dynamic(
395 &self,
396 mut envelope: proto::Envelope,
397 type_name: &'static str,
398 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
399 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
400 let (tx, rx) = oneshot::channel();
401 let mut response_channels_lock = self.response_channels.lock();
402 response_channels_lock.insert(MessageId(envelope.id), tx);
403 drop(response_channels_lock);
404 self.outgoing_tx.unbounded_send(envelope).ok();
405 async move {
406 let response = rx.await.context("connection lost")?.0;
407 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
408 return Err(RpcError::from_proto(error, type_name));
409 }
410 Ok(response)
411 }
412 }
413
414 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
415 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
416 self.outgoing_tx.unbounded_send(envelope)?;
417 Ok(())
418 }
419
420 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
421 let id = (TypeId::of::<E>(), remote_id);
422
423 let mut state = self.state.lock();
424 if state.entities_by_type_and_remote_id.contains_key(&id) {
425 panic!("already subscribed to entity");
426 }
427
428 state.entities_by_type_and_remote_id.insert(
429 id,
430 EntityMessageSubscriber::Entity {
431 handle: entity.downgrade().into(),
432 },
433 );
434 }
435
436 pub async fn spawn_process(&self, command: String) -> process::Child {
437 let (process_tx, process_rx) = oneshot::channel();
438 self.spawn_process_tx
439 .unbounded_send(SpawnRequest {
440 command,
441 process_tx,
442 })
443 .ok();
444 process_rx.await.unwrap()
445 }
446
447 pub fn ssh_args(&self) -> Vec<String> {
448 self.client_socket.as_ref().unwrap().ssh_args()
449 }
450}
451
452impl ProtoClient for SshSession {
453 fn request(
454 &self,
455 envelope: proto::Envelope,
456 request_type: &'static str,
457 ) -> BoxFuture<'static, Result<proto::Envelope>> {
458 self.request_dynamic(envelope, request_type).boxed()
459 }
460
461 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
462 self.send_dynamic(envelope)
463 }
464
465 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
466 self.send_dynamic(envelope)
467 }
468
469 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
470 &self.state
471 }
472}
473
474impl SshClientState {
475 #[cfg(not(unix))]
476 async fn new(
477 _connection_options: SshConnectionOptions,
478 _delegate: Arc<dyn SshClientDelegate>,
479 _cx: &mut AsyncAppContext,
480 ) -> Result<Self> {
481 Err(anyhow!("ssh is not supported on this platform"))
482 }
483
484 #[cfg(unix)]
485 async fn new(
486 connection_options: SshConnectionOptions,
487 delegate: Arc<dyn SshClientDelegate>,
488 cx: &mut AsyncAppContext,
489 ) -> Result<Self> {
490 use futures::{io::BufReader, AsyncBufReadExt as _};
491 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
492 use util::ResultExt as _;
493
494 delegate.set_status(Some("connecting"), cx);
495
496 let url = connection_options.ssh_url();
497 let temp_dir = tempfile::Builder::new()
498 .prefix("zed-ssh-session")
499 .tempdir()?;
500
501 // Create a domain socket listener to handle requests from the askpass program.
502 let askpass_socket = temp_dir.path().join("askpass.sock");
503 let listener =
504 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
505
506 let askpass_task = cx.spawn(|mut cx| async move {
507 while let Ok((mut stream, _)) = listener.accept().await {
508 let mut buffer = Vec::new();
509 let mut reader = BufReader::new(&mut stream);
510 if reader.read_until(b'\0', &mut buffer).await.is_err() {
511 buffer.clear();
512 }
513 let password_prompt = String::from_utf8_lossy(&buffer);
514 if let Some(password) = delegate
515 .ask_password(password_prompt.to_string(), &mut cx)
516 .await
517 .context("failed to get ssh password")
518 .and_then(|p| p)
519 .log_err()
520 {
521 stream.write_all(password.as_bytes()).await.log_err();
522 }
523 }
524 });
525
526 // Create an askpass script that communicates back to this process.
527 let askpass_script = format!(
528 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
529 askpass_socket = askpass_socket.display(),
530 print_args = "printf '%s\\0' \"$@\"",
531 shebang = "#!/bin/sh",
532 );
533 let askpass_script_path = temp_dir.path().join("askpass.sh");
534 fs::write(&askpass_script_path, askpass_script).await?;
535 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
536
537 // Start the master SSH process, which does not do anything except for establish
538 // the connection and keep it open, allowing other ssh commands to reuse it
539 // via a control socket.
540 let socket_path = temp_dir.path().join("ssh.sock");
541 let mut master_process = process::Command::new("ssh")
542 .stdin(Stdio::null())
543 .stdout(Stdio::piped())
544 .stderr(Stdio::piped())
545 .env("SSH_ASKPASS_REQUIRE", "force")
546 .env("SSH_ASKPASS", &askpass_script_path)
547 .args(["-N", "-o", "ControlMaster=yes", "-o"])
548 .arg(format!("ControlPath={}", socket_path.display()))
549 .arg(&url)
550 .spawn()?;
551
552 // Wait for this ssh process to close its stdout, indicating that authentication
553 // has completed.
554 let stdout = master_process.stdout.as_mut().unwrap();
555 let mut output = Vec::new();
556 stdout.read_to_end(&mut output).await?;
557 drop(askpass_task);
558
559 if master_process.try_status()?.is_some() {
560 output.clear();
561 let mut stderr = master_process.stderr.take().unwrap();
562 stderr.read_to_end(&mut output).await?;
563 Err(anyhow!(
564 "failed to connect: {}",
565 String::from_utf8_lossy(&output)
566 ))?;
567 }
568
569 Ok(Self {
570 socket: SshSocket {
571 connection_options,
572 socket_path,
573 },
574 _master_process: master_process,
575 _temp_dir: temp_dir,
576 })
577 }
578
579 async fn ensure_server_binary(
580 &self,
581 delegate: &Arc<dyn SshClientDelegate>,
582 src_path: &Path,
583 dst_path: &Path,
584 version: SemanticVersion,
585 cx: &mut AsyncAppContext,
586 ) -> Result<()> {
587 let mut dst_path_gz = dst_path.to_path_buf();
588 dst_path_gz.set_extension("gz");
589
590 if let Some(parent) = dst_path.parent() {
591 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
592 }
593
594 let mut server_binary_exists = false;
595 if cfg!(not(debug_assertions)) {
596 if let Ok(installed_version) =
597 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
598 {
599 if installed_version.trim() == version.to_string() {
600 server_binary_exists = true;
601 }
602 }
603 }
604
605 if server_binary_exists {
606 log::info!("remote development server already present",);
607 return Ok(());
608 }
609
610 let src_stat = fs::metadata(src_path).await?;
611 let size = src_stat.len();
612 let server_mode = 0o755;
613
614 let t0 = Instant::now();
615 delegate.set_status(Some("uploading remote development server"), cx);
616 log::info!("uploading remote development server ({}kb)", size / 1024);
617 self.upload_file(src_path, &dst_path_gz)
618 .await
619 .context("failed to upload server binary")?;
620 log::info!("uploaded remote development server in {:?}", t0.elapsed());
621
622 delegate.set_status(Some("extracting remote development server"), cx);
623 run_cmd(
624 self.socket
625 .ssh_command("gunzip")
626 .arg("--force")
627 .arg(&dst_path_gz),
628 )
629 .await?;
630
631 delegate.set_status(Some("unzipping remote development server"), cx);
632 run_cmd(
633 self.socket
634 .ssh_command("chmod")
635 .arg(format!("{:o}", server_mode))
636 .arg(dst_path),
637 )
638 .await?;
639
640 Ok(())
641 }
642
643 async fn query_platform(&self) -> Result<SshPlatform> {
644 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
645 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
646
647 let os = match os.trim() {
648 "Darwin" => "macos",
649 "Linux" => "linux",
650 _ => Err(anyhow!("unknown uname os {os:?}"))?,
651 };
652 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
653 "aarch64"
654 } else if arch.starts_with("x86") || arch.starts_with("i686") {
655 "x86_64"
656 } else {
657 Err(anyhow!("unknown uname architecture {arch:?}"))?
658 };
659
660 Ok(SshPlatform { os, arch })
661 }
662
663 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
664 let mut command = process::Command::new("scp");
665 let output = self
666 .socket
667 .ssh_options(&mut command)
668 .args(
669 self.socket
670 .connection_options
671 .port
672 .map(|port| vec!["-P".to_string(), port.to_string()])
673 .unwrap_or_default(),
674 )
675 .arg(src_path)
676 .arg(format!(
677 "{}:{}",
678 self.socket.connection_options.scp_url(),
679 dest_path.display()
680 ))
681 .output()
682 .await?;
683
684 if output.status.success() {
685 Ok(())
686 } else {
687 Err(anyhow!(
688 "failed to upload file {} -> {}: {}",
689 src_path.display(),
690 dest_path.display(),
691 String::from_utf8_lossy(&output.stderr)
692 ))
693 }
694 }
695}
696
697impl SshSocket {
698 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
699 let mut command = process::Command::new("ssh");
700 self.ssh_options(&mut command)
701 .arg(self.connection_options.ssh_url())
702 .arg(program);
703 command
704 }
705
706 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
707 command
708 .stdin(Stdio::piped())
709 .stdout(Stdio::piped())
710 .stderr(Stdio::piped())
711 .args(["-o", "ControlMaster=no", "-o"])
712 .arg(format!("ControlPath={}", self.socket_path.display()))
713 }
714
715 fn ssh_args(&self) -> Vec<String> {
716 vec![
717 "-o".to_string(),
718 "ControlMaster=no".to_string(),
719 "-o".to_string(),
720 format!("ControlPath={}", self.socket_path.display()),
721 self.connection_options.ssh_url(),
722 ]
723 }
724}
725
726async fn run_cmd(command: &mut process::Command) -> Result<String> {
727 let output = command.output().await?;
728 if output.status.success() {
729 Ok(String::from_utf8_lossy(&output.stdout).to_string())
730 } else {
731 Err(anyhow!(
732 "failed to run command: {}",
733 String::from_utf8_lossy(&output.stderr)
734 ))
735 }
736}