1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{
11 mpsc::{self, UnboundedReceiver, UnboundedSender},
12 oneshot,
13 },
14 future::BoxFuture,
15 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
16 StreamExt as _,
17};
18use gpui::{
19 AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
20};
21use parking_lot::Mutex;
22use rpc::{
23 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
24 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
25};
26use smol::{
27 fs,
28 process::{self, Child, Stdio},
29};
30use std::{
31 any::TypeId,
32 ffi::OsStr,
33 mem,
34 path::{Path, PathBuf},
35 sync::{
36 atomic::{AtomicU32, Ordering::SeqCst},
37 Arc,
38 },
39 time::Instant,
40};
41use tempfile::TempDir;
42use util::maybe;
43
44#[derive(
45 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
46)]
47pub struct SshProjectId(pub u64);
48
49#[derive(Clone)]
50pub struct SshSocket {
51 connection_options: SshConnectionOptions,
52 socket_path: PathBuf,
53}
54
/// Parameters describing how to reach a remote host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// `[user@]host[:port]` prefixed with the `ssh://` scheme.
    pub fn ssh_url(&self) -> String {
        format!("ssh://{}", self.connection_string())
    }

    /// `[user@]host`, the form `scp` expects before the `:path` suffix.
    fn scp_url(&self) -> String {
        match &self.username {
            Some(username) => format!("{username}@{}", self.host),
            None => self.host.clone(),
        }
    }

    /// Human-readable `[user@]host[:port]`.
    pub fn connection_string(&self) -> String {
        let user_and_host = self.scp_url();
        match self.port {
            Some(port) => format!("{user_and_host}:{port}"),
            None => user_and_host,
        }
    }

    /// Uniquely identifies dev server projects on a remote host. Needs to be
    /// stable for the same dev server project.
    ///
    /// The host is embedded using its `Debug` form (i.e. quoted); keep it that
    /// way, since changing the format would change existing identifiers.
    pub fn dev_server_identifier(&self) -> String {
        let host_part = format!("dev-server-{:?}", self.host);
        match &self.username {
            Some(username) => format!("{host_part}-{username}"),
            None => host_part,
        }
    }
}
110
/// Operating system and CPU architecture of a remote host, expressed with the
/// normalized names used when asking for a server binary (e.g. `"linux"`,
/// `"x86_64"`).
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
116
117pub trait SshClientDelegate: Send + Sync {
118 fn ask_password(
119 &self,
120 prompt: String,
121 cx: &mut AsyncAppContext,
122 ) -> oneshot::Receiver<Result<String>>;
123 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
124 fn get_server_binary(
125 &self,
126 platform: SshPlatform,
127 cx: &mut AsyncAppContext,
128 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
129 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
130 fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
131}
132
133impl SshSocket {
134 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
135 let mut command = process::Command::new("ssh");
136 self.ssh_options(&mut command)
137 .arg(self.connection_options.ssh_url())
138 .arg(program);
139 command
140 }
141
142 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
143 command
144 .stdin(Stdio::piped())
145 .stdout(Stdio::piped())
146 .stderr(Stdio::piped())
147 .args(["-o", "ControlMaster=no", "-o"])
148 .arg(format!("ControlPath={}", self.socket_path.display()))
149 }
150
151 fn ssh_args(&self) -> Vec<String> {
152 vec![
153 "-o".to_string(),
154 "ControlMaster=no".to_string(),
155 "-o".to_string(),
156 format!("ControlPath={}", self.socket_path.display()),
157 self.connection_options.ssh_url(),
158 ]
159 }
160}
161
162async fn run_cmd(command: &mut process::Command) -> Result<String> {
163 let output = command.output().await?;
164 if output.status.success() {
165 Ok(String::from_utf8_lossy(&output.stdout).to_string())
166 } else {
167 Err(anyhow!(
168 "failed to run command: {}",
169 String::from_utf8_lossy(&output.stderr)
170 ))
171 }
172}
173#[cfg(unix)]
174async fn read_with_timeout(
175 stdout: &mut process::ChildStdout,
176 timeout: std::time::Duration,
177 output: &mut Vec<u8>,
178) -> Result<(), std::io::Error> {
179 smol::future::or(
180 async {
181 stdout.read_to_end(output).await?;
182 Ok::<_, std::io::Error>(())
183 },
184 async {
185 smol::Timer::after(timeout).await;
186
187 Err(std::io::Error::new(
188 std::io::ErrorKind::TimedOut,
189 "Read operation timed out",
190 ))
191 },
192 )
193 .await
194}
195
196struct ChannelForwarder {
197 quit_tx: UnboundedSender<()>,
198 forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
199}
200
201impl ChannelForwarder {
202 fn new(
203 mut incoming_tx: UnboundedSender<Envelope>,
204 mut outgoing_rx: UnboundedReceiver<Envelope>,
205 cx: &AsyncAppContext,
206 ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
207 let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
208
209 let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
210 let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
211
212 let forwarding_task = cx.background_executor().spawn(async move {
213 loop {
214 select_biased! {
215 _ = quit_rx.next().fuse() => {
216 break;
217 },
218 incoming_envelope = proxy_incoming_rx.next().fuse() => {
219 if let Some(envelope) = incoming_envelope {
220 if incoming_tx.send(envelope).await.is_err() {
221 break;
222 }
223 } else {
224 break;
225 }
226 }
227 outgoing_envelope = outgoing_rx.next().fuse() => {
228 if let Some(envelope) = outgoing_envelope {
229 if proxy_outgoing_tx.send(envelope).await.is_err() {
230 break;
231 }
232 } else {
233 break;
234 }
235 }
236 }
237 }
238
239 (incoming_tx, outgoing_rx)
240 });
241
242 (
243 Self {
244 forwarding_task,
245 quit_tx,
246 },
247 proxy_incoming_tx,
248 proxy_outgoing_rx,
249 )
250 }
251
252 async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
253 let _ = self.quit_tx.send(()).await;
254 self.forwarding_task.await
255 }
256}
257
258struct SshRemoteClientState {
259 ssh_connection: SshRemoteConnection,
260 delegate: Arc<dyn SshClientDelegate>,
261 forwarder: ChannelForwarder,
262 multiplex_task: Task<Result<()>>,
263}
264
265pub struct SshRemoteClient {
266 client: Arc<ChannelClient>,
267 unique_identifier: String,
268 connection_options: SshConnectionOptions,
269 inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
270}
271
272impl Drop for SshRemoteClient {
273 fn drop(&mut self) {
274 self.shutdown_processes();
275 }
276}
277
278impl SshRemoteClient {
279 pub fn new(
280 unique_identifier: String,
281 connection_options: SshConnectionOptions,
282 delegate: Arc<dyn SshClientDelegate>,
283 cx: &AppContext,
284 ) -> Task<Result<Model<Self>>> {
285 cx.spawn(|mut cx| async move {
286 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
287 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
288
289 let this = cx.new_model(|cx| {
290 cx.on_app_quit(|this: &mut Self, _| {
291 this.shutdown_processes();
292 futures::future::ready(())
293 })
294 .detach();
295
296 let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
297 Self {
298 client,
299 unique_identifier: unique_identifier.clone(),
300 connection_options: SshConnectionOptions::default(),
301 inner_state: Arc::new(Mutex::new(None)),
302 }
303 })?;
304
305 let inner_state = {
306 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
307 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
308
309 let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
310 unique_identifier,
311 connection_options,
312 delegate.clone(),
313 &mut cx,
314 )
315 .await?;
316
317 let multiplex_task = Self::multiplex(
318 this.downgrade(),
319 ssh_proxy_process,
320 proxy_incoming_tx,
321 proxy_outgoing_rx,
322 &mut cx,
323 );
324
325 SshRemoteClientState {
326 ssh_connection,
327 delegate,
328 forwarder: proxy,
329 multiplex_task,
330 }
331 };
332
333 this.update(&mut cx, |this, cx| {
334 this.inner_state.lock().replace(inner_state);
335 cx.notify();
336 })?;
337
338 Ok(this)
339 })
340 }
341
342 fn shutdown_processes(&self) {
343 let Some(mut state) = self.inner_state.lock().take() else {
344 return;
345 };
346 log::info!("shutting down ssh processes");
347 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
348 // child of master_process.
349 let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
350 drop(task);
351 // Now drop the rest of state, which kills master process.
352 drop(state);
353 }
354
355 fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
356 let Some(state) = self.inner_state.lock().take() else {
357 return Err(anyhow!("reconnect is already in progress"));
358 };
359
360 let workspace_identifier = self.unique_identifier.clone();
361
362 let SshRemoteClientState {
363 mut ssh_connection,
364 delegate,
365 forwarder: proxy,
366 multiplex_task,
367 } = state;
368 drop(multiplex_task);
369
370 cx.spawn(|this, mut cx| async move {
371 let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
372
373 ssh_connection.master_process.kill()?;
374 ssh_connection
375 .master_process
376 .status()
377 .await
378 .context("Failed to kill ssh process")?;
379
380 let connection_options = ssh_connection.socket.connection_options.clone();
381
382 let (ssh_connection, ssh_process) = Self::establish_connection(
383 workspace_identifier,
384 connection_options,
385 delegate.clone(),
386 &mut cx,
387 )
388 .await?;
389
390 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
391 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
392
393 let inner_state = SshRemoteClientState {
394 ssh_connection,
395 delegate,
396 forwarder: proxy,
397 multiplex_task: Self::multiplex(
398 this.clone(),
399 ssh_process,
400 proxy_incoming_tx,
401 proxy_outgoing_rx,
402 &mut cx,
403 ),
404 };
405
406 this.update(&mut cx, |this, _| {
407 this.inner_state.lock().replace(inner_state);
408 })
409 })
410 .detach();
411 Ok(())
412 }
413
414 fn multiplex(
415 this: WeakModel<Self>,
416 mut ssh_proxy_process: Child,
417 incoming_tx: UnboundedSender<Envelope>,
418 mut outgoing_rx: UnboundedReceiver<Envelope>,
419 cx: &AsyncAppContext,
420 ) -> Task<Result<()>> {
421 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
422 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
423 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
424
425 let io_task = cx.background_executor().spawn(async move {
426 let mut stdin_buffer = Vec::new();
427 let mut stdout_buffer = Vec::new();
428 let mut stderr_buffer = Vec::new();
429 let mut stderr_offset = 0;
430
431 loop {
432 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
433 stderr_buffer.resize(stderr_offset + 1024, 0);
434
435 select_biased! {
436 outgoing = outgoing_rx.next().fuse() => {
437 let Some(outgoing) = outgoing else {
438 return anyhow::Ok(());
439 };
440
441 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
442 }
443
444 result = child_stdout.read(&mut stdout_buffer).fuse() => {
445 match result {
446 Ok(0) => {
447 child_stdin.close().await?;
448 outgoing_rx.close();
449 let status = ssh_proxy_process.status().await?;
450 if !status.success() {
451 log::error!("ssh process exited with status: {status:?}");
452 return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
453 }
454 return Ok(());
455 }
456 Ok(len) => {
457 if len < stdout_buffer.len() {
458 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
459 }
460
461 let message_len = message_len_from_buffer(&stdout_buffer);
462 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
463 Ok(envelope) => {
464 incoming_tx.unbounded_send(envelope).ok();
465 }
466 Err(error) => {
467 log::error!("error decoding message {error:?}");
468 }
469 }
470 }
471 Err(error) => {
472 Err(anyhow!("error reading stdout: {error:?}"))?;
473 }
474 }
475 }
476
477 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
478 match result {
479 Ok(len) => {
480 stderr_offset += len;
481 let mut start_ix = 0;
482 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
483 let line_ix = start_ix + ix;
484 let content = &stderr_buffer[start_ix..line_ix];
485 start_ix = line_ix + 1;
486 if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
487 record.message = format!("(remote) {}", record.message);
488 record.log(log::logger())
489 } else {
490 eprintln!("(remote) {}", String::from_utf8_lossy(content));
491 }
492 }
493 stderr_buffer.drain(0..start_ix);
494 stderr_offset -= start_ix;
495 }
496 Err(error) => {
497 Err(anyhow!("error reading stderr: {error:?}"))?;
498 }
499 }
500 }
501 }
502 }
503 });
504
505 cx.spawn(|mut cx| async move {
506 let result = io_task.await;
507
508 if let Err(error) = result {
509 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
510 this.update(&mut cx, |this, cx| {
511 this.reconnect(cx).ok();
512 })?;
513 }
514
515 Ok(())
516 })
517 }
518
519 async fn establish_connection(
520 unique_identifier: String,
521 connection_options: SshConnectionOptions,
522 delegate: Arc<dyn SshClientDelegate>,
523 cx: &mut AsyncAppContext,
524 ) -> Result<(SshRemoteConnection, Child)> {
525 let ssh_connection =
526 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
527
528 let platform = ssh_connection.query_platform().await?;
529 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
530 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
531 ssh_connection
532 .ensure_server_binary(
533 &delegate,
534 &local_binary_path,
535 &remote_binary_path,
536 version,
537 cx,
538 )
539 .await?;
540
541 let socket = ssh_connection.socket.clone();
542 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
543
544 delegate.set_status(Some("Starting proxy"), cx);
545
546 let ssh_proxy_process = socket
547 .ssh_command(format!(
548 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
549 std::env::var("RUST_LOG").unwrap_or_default(),
550 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
551 remote_binary_path,
552 unique_identifier,
553 ))
554 // IMPORTANT: we kill this process when we drop the task that uses it.
555 .kill_on_drop(true)
556 .spawn()
557 .context("failed to spawn remote server")?;
558
559 Ok((ssh_connection, ssh_proxy_process))
560 }
561
562 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
563 self.client.subscribe_to_entity(remote_id, entity);
564 }
565
566 pub fn ssh_args(&self) -> Option<Vec<String>> {
567 let state = self.inner_state.lock();
568 state
569 .as_ref()
570 .map(|state| state.ssh_connection.socket.ssh_args())
571 }
572
573 pub fn to_proto_client(&self) -> AnyProtoClient {
574 self.client.clone().into()
575 }
576
577 pub fn connection_string(&self) -> String {
578 self.connection_options.connection_string()
579 }
580
581 pub fn is_reconnect_underway(&self) -> bool {
582 maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
583 }
584
585 #[cfg(any(test, feature = "test-support"))]
586 pub fn fake(
587 client_cx: &mut gpui::TestAppContext,
588 server_cx: &mut gpui::TestAppContext,
589 ) -> (Model<Self>, Arc<ChannelClient>) {
590 use gpui::Context;
591
592 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
593 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
594
595 (
596 client_cx.update(|cx| {
597 let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
598 cx.new_model(|_| Self {
599 client,
600 unique_identifier: "fake".to_string(),
601 connection_options: SshConnectionOptions::default(),
602 inner_state: Arc::new(Mutex::new(None)),
603 })
604 }),
605 server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
606 )
607 }
608}
609
610impl From<SshRemoteClient> for AnyProtoClient {
611 fn from(client: SshRemoteClient) -> Self {
612 AnyProtoClient::new(client.client.clone())
613 }
614}
615
616struct SshRemoteConnection {
617 socket: SshSocket,
618 master_process: process::Child,
619 _temp_dir: TempDir,
620}
621
622impl Drop for SshRemoteConnection {
623 fn drop(&mut self) {
624 if let Err(error) = self.master_process.kill() {
625 log::error!("failed to kill SSH master process: {}", error);
626 }
627 }
628}
629
630impl SshRemoteConnection {
631 #[cfg(not(unix))]
632 async fn new(
633 _connection_options: SshConnectionOptions,
634 _delegate: Arc<dyn SshClientDelegate>,
635 _cx: &mut AsyncAppContext,
636 ) -> Result<Self> {
637 Err(anyhow!("ssh is not supported on this platform"))
638 }
639
640 #[cfg(unix)]
641 async fn new(
642 connection_options: SshConnectionOptions,
643 delegate: Arc<dyn SshClientDelegate>,
644 cx: &mut AsyncAppContext,
645 ) -> Result<Self> {
646 use futures::{io::BufReader, AsyncBufReadExt as _};
647 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
648 use util::ResultExt as _;
649
650 delegate.set_status(Some("connecting"), cx);
651
652 let url = connection_options.ssh_url();
653 let temp_dir = tempfile::Builder::new()
654 .prefix("zed-ssh-session")
655 .tempdir()?;
656
657 // Create a domain socket listener to handle requests from the askpass program.
658 let askpass_socket = temp_dir.path().join("askpass.sock");
659 let listener =
660 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
661
662 let askpass_task = cx.spawn({
663 let delegate = delegate.clone();
664 |mut cx| async move {
665 while let Ok((mut stream, _)) = listener.accept().await {
666 let mut buffer = Vec::new();
667 let mut reader = BufReader::new(&mut stream);
668 if reader.read_until(b'\0', &mut buffer).await.is_err() {
669 buffer.clear();
670 }
671 let password_prompt = String::from_utf8_lossy(&buffer);
672 if let Some(password) = delegate
673 .ask_password(password_prompt.to_string(), &mut cx)
674 .await
675 .context("failed to get ssh password")
676 .and_then(|p| p)
677 .log_err()
678 {
679 stream.write_all(password.as_bytes()).await.log_err();
680 }
681 }
682 }
683 });
684
685 // Create an askpass script that communicates back to this process.
686 let askpass_script = format!(
687 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
688 askpass_socket = askpass_socket.display(),
689 print_args = "printf '%s\\0' \"$@\"",
690 shebang = "#!/bin/sh",
691 );
692 let askpass_script_path = temp_dir.path().join("askpass.sh");
693 fs::write(&askpass_script_path, askpass_script).await?;
694 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
695
696 // Start the master SSH process, which does not do anything except for establish
697 // the connection and keep it open, allowing other ssh commands to reuse it
698 // via a control socket.
699 let socket_path = temp_dir.path().join("ssh.sock");
700 let mut master_process = process::Command::new("ssh")
701 .stdin(Stdio::null())
702 .stdout(Stdio::piped())
703 .stderr(Stdio::piped())
704 .env("SSH_ASKPASS_REQUIRE", "force")
705 .env("SSH_ASKPASS", &askpass_script_path)
706 .args(["-N", "-o", "ControlMaster=yes", "-o"])
707 .arg(format!("ControlPath={}", socket_path.display()))
708 .arg(&url)
709 .spawn()?;
710
711 // Wait for this ssh process to close its stdout, indicating that authentication
712 // has completed.
713 let stdout = master_process.stdout.as_mut().unwrap();
714 let mut output = Vec::new();
715 let connection_timeout = std::time::Duration::from_secs(10);
716 let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
717 if let Err(e) = result {
718 let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
719 format!(
720 "Failed to connect to host. Timed out after {:?}.",
721 connection_timeout
722 )
723 } else {
724 format!("Failed to connect to host: {}.", e)
725 };
726
727 delegate.set_error(error_message, cx);
728 return Err(e.into());
729 }
730
731 drop(askpass_task);
732
733 if master_process.try_status()?.is_some() {
734 output.clear();
735 let mut stderr = master_process.stderr.take().unwrap();
736 stderr.read_to_end(&mut output).await?;
737 Err(anyhow!(
738 "failed to connect: {}",
739 String::from_utf8_lossy(&output)
740 ))?;
741 }
742
743 Ok(Self {
744 socket: SshSocket {
745 connection_options,
746 socket_path,
747 },
748 master_process,
749 _temp_dir: temp_dir,
750 })
751 }
752
753 async fn ensure_server_binary(
754 &self,
755 delegate: &Arc<dyn SshClientDelegate>,
756 src_path: &Path,
757 dst_path: &Path,
758 version: SemanticVersion,
759 cx: &mut AsyncAppContext,
760 ) -> Result<()> {
761 let mut dst_path_gz = dst_path.to_path_buf();
762 dst_path_gz.set_extension("gz");
763
764 if let Some(parent) = dst_path.parent() {
765 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
766 }
767
768 let mut server_binary_exists = false;
769 if cfg!(not(debug_assertions)) {
770 if let Ok(installed_version) =
771 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
772 {
773 if installed_version.trim() == version.to_string() {
774 server_binary_exists = true;
775 }
776 }
777 }
778
779 if server_binary_exists {
780 log::info!("remote development server already present",);
781 return Ok(());
782 }
783
784 let src_stat = fs::metadata(src_path).await?;
785 let size = src_stat.len();
786 let server_mode = 0o755;
787
788 let t0 = Instant::now();
789 delegate.set_status(Some("uploading remote development server"), cx);
790 log::info!("uploading remote development server ({}kb)", size / 1024);
791 self.upload_file(src_path, &dst_path_gz)
792 .await
793 .context("failed to upload server binary")?;
794 log::info!("uploaded remote development server in {:?}", t0.elapsed());
795
796 delegate.set_status(Some("extracting remote development server"), cx);
797 run_cmd(
798 self.socket
799 .ssh_command("gunzip")
800 .arg("--force")
801 .arg(&dst_path_gz),
802 )
803 .await?;
804
805 delegate.set_status(Some("unzipping remote development server"), cx);
806 run_cmd(
807 self.socket
808 .ssh_command("chmod")
809 .arg(format!("{:o}", server_mode))
810 .arg(dst_path),
811 )
812 .await?;
813
814 Ok(())
815 }
816
817 async fn query_platform(&self) -> Result<SshPlatform> {
818 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
819 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
820
821 let os = match os.trim() {
822 "Darwin" => "macos",
823 "Linux" => "linux",
824 _ => Err(anyhow!("unknown uname os {os:?}"))?,
825 };
826 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
827 "aarch64"
828 } else if arch.starts_with("x86") || arch.starts_with("i686") {
829 "x86_64"
830 } else {
831 Err(anyhow!("unknown uname architecture {arch:?}"))?
832 };
833
834 Ok(SshPlatform { os, arch })
835 }
836
837 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
838 let mut command = process::Command::new("scp");
839 let output = self
840 .socket
841 .ssh_options(&mut command)
842 .args(
843 self.socket
844 .connection_options
845 .port
846 .map(|port| vec!["-P".to_string(), port.to_string()])
847 .unwrap_or_default(),
848 )
849 .arg(src_path)
850 .arg(format!(
851 "{}:{}",
852 self.socket.connection_options.scp_url(),
853 dest_path.display()
854 ))
855 .output()
856 .await?;
857
858 if output.status.success() {
859 Ok(())
860 } else {
861 Err(anyhow!(
862 "failed to upload file {} -> {}: {}",
863 src_path.display(),
864 dest_path.display(),
865 String::from_utf8_lossy(&output.stderr)
866 ))
867 }
868 }
869}
870
871type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
872
873pub struct ChannelClient {
874 next_message_id: AtomicU32,
875 outgoing_tx: mpsc::UnboundedSender<Envelope>,
876 response_channels: ResponseChannels, // Lock
877 message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
878}
879
880impl ChannelClient {
881 pub fn new(
882 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
883 outgoing_tx: mpsc::UnboundedSender<Envelope>,
884 cx: &AppContext,
885 ) -> Arc<Self> {
886 let this = Arc::new(Self {
887 outgoing_tx,
888 next_message_id: AtomicU32::new(0),
889 response_channels: ResponseChannels::default(),
890 message_handlers: Default::default(),
891 });
892
893 Self::start_handling_messages(this.clone(), incoming_rx, cx);
894
895 this
896 }
897
898 fn start_handling_messages(
899 this: Arc<Self>,
900 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
901 cx: &AppContext,
902 ) {
903 cx.spawn(|cx| {
904 let this = Arc::downgrade(&this);
905 async move {
906 let peer_id = PeerId { owner_id: 0, id: 0 };
907 while let Some(incoming) = incoming_rx.next().await {
908 let Some(this) = this.upgrade() else {
909 return anyhow::Ok(());
910 };
911
912 if let Some(request_id) = incoming.responding_to {
913 let request_id = MessageId(request_id);
914 let sender = this.response_channels.lock().remove(&request_id);
915 if let Some(sender) = sender {
916 let (tx, rx) = oneshot::channel();
917 if incoming.payload.is_some() {
918 sender.send((incoming, tx)).ok();
919 }
920 rx.await.ok();
921 }
922 } else if let Some(envelope) =
923 build_typed_envelope(peer_id, Instant::now(), incoming)
924 {
925 let type_name = envelope.payload_type_name();
926 if let Some(future) = ProtoMessageHandlerSet::handle_message(
927 &this.message_handlers,
928 envelope,
929 this.clone().into(),
930 cx.clone(),
931 ) {
932 log::debug!("ssh message received. name:{type_name}");
933 match future.await {
934 Ok(_) => {
935 log::debug!("ssh message handled. name:{type_name}");
936 }
937 Err(error) => {
938 log::error!(
939 "error handling message. type:{type_name}, error:{error}",
940 );
941 }
942 }
943 } else {
944 log::error!("unhandled ssh message name:{type_name}");
945 }
946 }
947 }
948 anyhow::Ok(())
949 }
950 })
951 .detach();
952 }
953
954 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
955 let id = (TypeId::of::<E>(), remote_id);
956
957 let mut message_handlers = self.message_handlers.lock();
958 if message_handlers
959 .entities_by_type_and_remote_id
960 .contains_key(&id)
961 {
962 panic!("already subscribed to entity");
963 }
964
965 message_handlers.entities_by_type_and_remote_id.insert(
966 id,
967 EntityMessageSubscriber::Entity {
968 handle: entity.downgrade().into(),
969 },
970 );
971 }
972
973 pub fn request<T: RequestMessage>(
974 &self,
975 payload: T,
976 ) -> impl 'static + Future<Output = Result<T::Response>> {
977 log::debug!("ssh request start. name:{}", T::NAME);
978 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
979 async move {
980 let response = response.await?;
981 log::debug!("ssh request finish. name:{}", T::NAME);
982 T::Response::from_envelope(response)
983 .ok_or_else(|| anyhow!("received a response of the wrong type"))
984 }
985 }
986
987 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
988 log::debug!("ssh send name:{}", T::NAME);
989 self.send_dynamic(payload.into_envelope(0, None, None))
990 }
991
992 pub fn request_dynamic(
993 &self,
994 mut envelope: proto::Envelope,
995 type_name: &'static str,
996 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
997 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
998 let (tx, rx) = oneshot::channel();
999 let mut response_channels_lock = self.response_channels.lock();
1000 response_channels_lock.insert(MessageId(envelope.id), tx);
1001 drop(response_channels_lock);
1002 let result = self.outgoing_tx.unbounded_send(envelope);
1003 async move {
1004 if let Err(error) = &result {
1005 log::error!("failed to send message: {}", error);
1006 return Err(anyhow!("failed to send message: {}", error));
1007 }
1008
1009 let response = rx.await.context("connection lost")?.0;
1010 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1011 return Err(RpcError::from_proto(error, type_name));
1012 }
1013 Ok(response)
1014 }
1015 }
1016
1017 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1018 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1019 self.outgoing_tx.unbounded_send(envelope)?;
1020 Ok(())
1021 }
1022}
1023
1024impl ProtoClient for ChannelClient {
1025 fn request(
1026 &self,
1027 envelope: proto::Envelope,
1028 request_type: &'static str,
1029 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1030 self.request_dynamic(envelope, request_type).boxed()
1031 }
1032
1033 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1034 self.send_dynamic(envelope)
1035 }
1036
1037 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1038 self.send_dynamic(envelope)
1039 }
1040
1041 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1042 &self.message_handlers
1043 }
1044
1045 fn is_via_collab(&self) -> bool {
1046 false
1047 }
1048}