1use crate::{
2 SshConnectionOptions,
3 protocol::MessageId,
4 proxy::ProxyLaunchError,
5 transport::{
6 ssh::SshRemoteConnection,
7 wsl::{WslConnectionOptions, WslRemoteConnection},
8 },
9};
10use anyhow::{Context as _, Result, anyhow};
11use askpass::EncryptedPassword;
12use async_trait::async_trait;
13use collections::HashMap;
14use futures::{
15 Future, FutureExt as _, StreamExt as _,
16 channel::{
17 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
18 oneshot,
19 },
20 future::{BoxFuture, Shared},
21 select, select_biased,
22};
23use gpui::{
24 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
25 EventEmitter, FutureExt, Global, Task, WeakEntity,
26};
27use parking_lot::Mutex;
28
29use release_channel::ReleaseChannel;
30use rpc::{
31 AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
32 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
33};
34use semver::Version;
35use std::{
36 collections::VecDeque,
37 fmt,
38 ops::ControlFlow,
39 path::PathBuf,
40 sync::{
41 Arc, Weak,
42 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
43 },
44 time::{Duration, Instant},
45};
46use util::{
47 ResultExt,
48 paths::{PathStyle, RemotePathBuf},
49};
50
/// Operating system and CPU architecture of the remote host.
///
/// Passed to [`RemoteClientDelegate`] when locating or downloading the
/// matching remote server binary.
#[derive(Copy, Clone, Debug)]
pub struct RemotePlatform {
    /// Operating-system identifier of the remote host.
    pub os: &'static str,
    /// CPU-architecture identifier of the remote host.
    pub arch: &'static str,
}
56
/// A fully-resolved command line produced by a [`RemoteConnection`]:
/// the program to run, its arguments and the environment to set.
#[derive(Clone, Debug)]
pub struct CommandTemplate {
    /// Program to execute.
    pub program: String,
    /// Arguments passed to `program`, in order.
    pub args: Vec<String>,
    /// Environment variables to set for the process.
    pub env: HashMap<String, String>,
}
63
/// Callbacks through which the remote client interacts with its host
/// application (credential prompts, server binary acquisition, status text).
pub trait RemoteClientDelegate: Send + Sync {
    /// Asks the user for a password described by `prompt`.
    ///
    /// The answer is expected to be delivered through `tx`.
    fn ask_password(
        &self,
        prompt: String,
        tx: oneshot::Sender<EncryptedPassword>,
        cx: &mut AsyncApp,
    );
    /// Resolves a download URL for the server binary matching `platform`,
    /// `release_channel` and (optionally) `version`.
    ///
    /// `Ok(None)` presumably means no URL is available — confirm with implementors.
    fn get_download_url(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<Version>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<String>>>;
    /// Downloads the server binary for `platform` to the local machine and
    /// returns the local path of the downloaded file.
    fn download_server_binary_locally(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<Version>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a human-readable progress or status message; `None` clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
87
/// Number of consecutive missed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Delay without connection activity before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// How long a ping (or resync) may take before it counts as failed.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum time to wait for the remote side to report that it has started.
const INITIAL_CONNECTION_TIMEOUT: Duration = Duration::from_secs(60);

/// Number of reconnect attempts made before giving up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
94
/// Internal lifecycle state of a [`RemoteClient`].
///
/// Variants that hold a live connection also own the tasks driving it, so
/// replacing the state drops those tasks.
enum State {
    /// The client is being created and has not finished its initial handshake.
    Connecting,
    /// The connection is up and heartbeats are succeeding.
    Connected {
        remote_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        // Awaits the proxy I/O task and reacts to its exit (see `RemoteClient::monitor`).
        multiplex_task: Task<Result<()>>,
        // Periodically pings the server (see `RemoteClient::heartbeat`).
        heartbeat_task: Task<Result<()>>,
    },
    /// One or more consecutive heartbeats have failed.
    HeartbeatMissed {
        // Number of consecutive failed heartbeats so far.
        missed_heartbeats: usize,

        // Named `ssh_connection` for historical reasons; may be any `RemoteConnection`.
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The most recent reconnect attempt failed; another may follow.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        // Why the last attempt failed.
        error: anyhow::Error,
        // Number of attempts made so far.
        attempts: usize,
    },
    /// All reconnect attempts were used up; the client is disconnected.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
124
125impl fmt::Display for State {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match self {
128 Self::Connecting => write!(f, "connecting"),
129 Self::Connected { .. } => write!(f, "connected"),
130 Self::Reconnecting => write!(f, "reconnecting"),
131 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
132 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
133 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
134 Self::ServerNotRunning { .. } => write!(f, "server not running"),
135 }
136 }
137}
138
139impl State {
140 fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
141 match self {
142 Self::Connected {
143 remote_connection: ssh_connection,
144 ..
145 } => Some(ssh_connection.clone()),
146 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
147 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
148 _ => None,
149 }
150 }
151
152 fn can_reconnect(&self) -> bool {
153 match self {
154 Self::Connected { .. }
155 | Self::HeartbeatMissed { .. }
156 | Self::ReconnectFailed { .. } => true,
157 State::Connecting
158 | State::Reconnecting
159 | State::ReconnectExhausted
160 | State::ServerNotRunning => false,
161 }
162 }
163
164 fn is_reconnect_failed(&self) -> bool {
165 matches!(self, Self::ReconnectFailed { .. })
166 }
167
168 fn is_reconnect_exhausted(&self) -> bool {
169 matches!(self, Self::ReconnectExhausted { .. })
170 }
171
172 fn is_server_not_running(&self) -> bool {
173 matches!(self, Self::ServerNotRunning)
174 }
175
176 fn is_reconnecting(&self) -> bool {
177 matches!(self, Self::Reconnecting { .. })
178 }
179
180 fn heartbeat_recovered(self) -> Self {
181 match self {
182 Self::HeartbeatMissed {
183 ssh_connection,
184 delegate,
185 multiplex_task,
186 heartbeat_task,
187 ..
188 } => Self::Connected {
189 remote_connection: ssh_connection,
190 delegate,
191 multiplex_task,
192 heartbeat_task,
193 },
194 _ => self,
195 }
196 }
197
198 fn heartbeat_missed(self) -> Self {
199 match self {
200 Self::Connected {
201 remote_connection: ssh_connection,
202 delegate,
203 multiplex_task,
204 heartbeat_task,
205 } => Self::HeartbeatMissed {
206 missed_heartbeats: 1,
207 ssh_connection,
208 delegate,
209 multiplex_task,
210 heartbeat_task,
211 },
212 Self::HeartbeatMissed {
213 missed_heartbeats,
214 ssh_connection,
215 delegate,
216 multiplex_task,
217 heartbeat_task,
218 } => Self::HeartbeatMissed {
219 missed_heartbeats: missed_heartbeats + 1,
220 ssh_connection,
221 delegate,
222 multiplex_task,
223 heartbeat_task,
224 },
225 _ => self,
226 }
227 }
228}
229
/// Public, coarse-grained view of a remote client's connection state.
///
/// Derived from the internal `State` via `From<&State>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// Initial connection is being established.
    Connecting,
    /// Connected and heartbeats are succeeding.
    Connected,
    /// Connected, but at least one heartbeat has been missed.
    HeartbeatMissed,
    /// A reconnect is in progress or the last attempt failed and will be retried.
    Reconnecting,
    /// Reconnects were exhausted, the server is not running, or the client was shut down.
    Disconnected,
}
239
240impl From<&State> for ConnectionState {
241 fn from(value: &State) -> Self {
242 match value {
243 State::Connecting => Self::Connecting,
244 State::Connected { .. } => Self::Connected,
245 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
246 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
247 State::ReconnectExhausted => Self::Disconnected,
248 State::ServerNotRunning => Self::Disconnected,
249 }
250 }
251}
252
/// Client side of a connection to a remote (SSH or WSL) server.
///
/// Owns the message channel to the server and the connection state machine,
/// which includes heartbeating and reconnect handling.
pub struct RemoteClient {
    // Channel used to exchange proto messages with the remote server.
    client: Arc<ChannelClient>,
    // String form of the `ConnectionIdentifier`, passed to `start_proxy`.
    unique_identifier: String,
    connection_options: RemoteConnectionOptions,
    // Path conventions of the remote host.
    path_style: PathStyle,
    // `None` once `shutdown_processes` has taken the state.
    state: Option<State>,
}
260
/// Events emitted by a [`RemoteClient`] entity.
#[derive(Debug)]
pub enum RemoteClientEvent {
    /// Emitted when the client enters the reconnect-exhausted or
    /// server-not-running state.
    Disconnected,
}

impl EventEmitter<RemoteClientEvent> for RemoteClient {}
267
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// A temporary connection, numbered from a process-wide counter.
    Setup(u64),
    /// A connection tied to a workspace id.
    Workspace(i64),
}

/// Next id handed out by `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
276
277impl ConnectionIdentifier {
278 pub fn setup() -> Self {
279 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
280 }
281
282 // This string gets used in a socket name, and so must be relatively short.
283 // The total length of:
284 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
285 // Must be less than about 100 characters
286 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
287 // So our strings should be at most 20 characters or so.
288 fn to_string(&self, cx: &App) -> String {
289 let identifier_prefix = match ReleaseChannel::global(cx) {
290 ReleaseChannel::Stable => "".to_string(),
291 release_channel => format!("{}-", release_channel.dev_name()),
292 };
293 match self {
294 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
295 Self::Workspace(workspace_id) => {
296 format!("{identifier_prefix}workspace-{workspace_id}",)
297 }
298 }
299 }
300}
301
302pub async fn connect(
303 connection_options: RemoteConnectionOptions,
304 delegate: Arc<dyn RemoteClientDelegate>,
305 cx: &mut AsyncApp,
306) -> Result<Arc<dyn RemoteConnection>> {
307 cx.update(|cx| {
308 cx.update_default_global(|pool: &mut ConnectionPool, cx| {
309 pool.connect(connection_options.clone(), delegate.clone(), cx)
310 })
311 })?
312 .await
313 .map_err(|e| e.cloned())
314}
315
316impl RemoteClient {
317 pub fn new(
318 unique_identifier: ConnectionIdentifier,
319 remote_connection: Arc<dyn RemoteConnection>,
320 cancellation: oneshot::Receiver<()>,
321 delegate: Arc<dyn RemoteClientDelegate>,
322 cx: &mut App,
323 ) -> Task<Result<Option<Entity<Self>>>> {
324 let unique_identifier = unique_identifier.to_string(cx);
325 cx.spawn(async move |cx| {
326 let success = Box::pin(async move {
327 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
328 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
329 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
330
331 let client =
332 cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
333
334 let path_style = remote_connection.path_style();
335 let this = cx.new(|_| Self {
336 client: client.clone(),
337 unique_identifier: unique_identifier.clone(),
338 connection_options: remote_connection.connection_options(),
339 path_style,
340 state: Some(State::Connecting),
341 })?;
342
343 let io_task = remote_connection.start_proxy(
344 unique_identifier,
345 false,
346 incoming_tx,
347 outgoing_rx,
348 connection_activity_tx,
349 delegate.clone(),
350 cx,
351 );
352
353 let ready = client
354 .wait_for_remote_started()
355 .with_timeout(INITIAL_CONNECTION_TIMEOUT, cx.background_executor())
356 .await;
357 match ready {
358 Ok(Some(_)) => {}
359 Ok(None) => {
360 let mut error = "remote client exited before becoming ready".to_owned();
361 if let Some(status) = io_task.now_or_never() {
362 match status {
363 Ok(exit_code) => {
364 error.push_str(&format!(", exit_code={exit_code:?}"))
365 }
366 Err(e) => error.push_str(&format!(", error={e:?}")),
367 }
368 }
369 let error = anyhow::anyhow!("{error}");
370 log::error!("failed to establish connection: {}", error);
371 return Err(error);
372 }
373 Err(_) => {
374 let mut error =
375 "remote client did not become ready within the timeout".to_owned();
376 if let Some(status) = io_task.now_or_never() {
377 match status {
378 Ok(exit_code) => {
379 error.push_str(&format!(", exit_code={exit_code:?}"))
380 }
381 Err(e) => error.push_str(&format!(", error={e:?}")),
382 }
383 }
384 let error = anyhow::anyhow!("{error}");
385 log::error!("failed to establish connection: {}", error);
386 return Err(error);
387 }
388 }
389 let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
390 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
391 log::error!("failed to establish connection: {}", error);
392 return Err(error);
393 }
394
395 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
396
397 this.update(cx, |this, _| {
398 this.state = Some(State::Connected {
399 remote_connection,
400 delegate,
401 multiplex_task,
402 heartbeat_task,
403 });
404 })?;
405
406 Ok(Some(this))
407 });
408
409 select! {
410 _ = cancellation.fuse() => {
411 Ok(None)
412 }
413 result = success.fuse() => result
414 }
415 })
416 }
417
418 pub fn proto_client_from_channels(
419 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
420 outgoing_tx: mpsc::UnboundedSender<Envelope>,
421 cx: &App,
422 name: &'static str,
423 ) -> AnyProtoClient {
424 ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
425 }
426
427 pub fn shutdown_processes<T: RequestMessage>(
428 &mut self,
429 shutdown_request: Option<T>,
430 executor: BackgroundExecutor,
431 ) -> Option<impl Future<Output = ()> + use<T>> {
432 let state = self.state.take()?;
433 log::info!("shutting down ssh processes");
434
435 let State::Connected {
436 multiplex_task,
437 heartbeat_task,
438 remote_connection: ssh_connection,
439 delegate,
440 } = state
441 else {
442 return None;
443 };
444
445 let client = self.client.clone();
446
447 Some(async move {
448 if let Some(shutdown_request) = shutdown_request {
449 client.send(shutdown_request).log_err();
450 // We wait 50ms instead of waiting for a response, because
451 // waiting for a response would require us to wait on the main thread
452 // which we want to avoid in an `on_app_quit` callback.
453 executor.timer(Duration::from_millis(50)).await;
454 }
455
456 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
457 // child of master_process.
458 drop(multiplex_task);
459 // Now drop the rest of state, which kills master process.
460 drop(heartbeat_task);
461 drop(ssh_connection);
462 drop(delegate);
463 })
464 }
465
466 fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
467 let can_reconnect = self
468 .state
469 .as_ref()
470 .map(|state| state.can_reconnect())
471 .unwrap_or(false);
472 if !can_reconnect {
473 log::info!("aborting reconnect, because not in state that allows reconnecting");
474 let error = if let Some(state) = self.state.as_ref() {
475 format!("invalid state, cannot reconnect while in state {state}")
476 } else {
477 "no state set".to_string()
478 };
479 anyhow::bail!(error);
480 }
481
482 let state = self.state.take().unwrap();
483 let (attempts, remote_connection, delegate) = match state {
484 State::Connected {
485 remote_connection: ssh_connection,
486 delegate,
487 multiplex_task,
488 heartbeat_task,
489 }
490 | State::HeartbeatMissed {
491 ssh_connection,
492 delegate,
493 multiplex_task,
494 heartbeat_task,
495 ..
496 } => {
497 drop(multiplex_task);
498 drop(heartbeat_task);
499 (0, ssh_connection, delegate)
500 }
501 State::ReconnectFailed {
502 attempts,
503 ssh_connection,
504 delegate,
505 ..
506 } => (attempts, ssh_connection, delegate),
507 State::Connecting
508 | State::Reconnecting
509 | State::ReconnectExhausted
510 | State::ServerNotRunning => unreachable!(),
511 };
512
513 let attempts = attempts + 1;
514 if attempts > MAX_RECONNECT_ATTEMPTS {
515 log::error!(
516 "Failed to reconnect to after {} attempts, giving up",
517 MAX_RECONNECT_ATTEMPTS
518 );
519 self.set_state(State::ReconnectExhausted, cx);
520 return Ok(());
521 }
522
523 self.set_state(State::Reconnecting, cx);
524
525 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
526
527 let unique_identifier = self.unique_identifier.clone();
528 let client = self.client.clone();
529 let reconnect_task = cx.spawn(async move |this, cx| {
530 macro_rules! failed {
531 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
532 delegate.set_status(Some(&format!("{error:#}", error = $error)), cx);
533 return State::ReconnectFailed {
534 error: anyhow!($error),
535 attempts: $attempts,
536 ssh_connection: $ssh_connection,
537 delegate: $delegate,
538 };
539 };
540 }
541
542 if let Err(error) = remote_connection
543 .kill()
544 .await
545 .context("Failed to kill ssh process")
546 {
547 failed!(error, attempts, remote_connection, delegate);
548 };
549
550 let connection_options = remote_connection.connection_options();
551
552 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
553 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
554 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
555
556 let (ssh_connection, io_task) = match async {
557 let ssh_connection = cx
558 .update_global(|pool: &mut ConnectionPool, cx| {
559 pool.connect(connection_options, delegate.clone(), cx)
560 })?
561 .await
562 .map_err(|error| error.cloned())?;
563
564 let io_task = ssh_connection.start_proxy(
565 unique_identifier,
566 true,
567 incoming_tx,
568 outgoing_rx,
569 connection_activity_tx,
570 delegate.clone(),
571 cx,
572 );
573 anyhow::Ok((ssh_connection, io_task))
574 }
575 .await
576 {
577 Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
578 Err(error) => {
579 failed!(error, attempts, remote_connection, delegate);
580 }
581 };
582
583 let multiplex_task = Self::monitor(this.clone(), io_task, cx);
584 client.reconnect(incoming_rx, outgoing_tx, cx);
585
586 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
587 failed!(error, attempts, ssh_connection, delegate);
588 };
589
590 State::Connected {
591 remote_connection: ssh_connection,
592 delegate,
593 multiplex_task,
594 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
595 }
596 });
597
598 cx.spawn(async move |this, cx| {
599 let new_state = reconnect_task.await;
600 this.update(cx, |this, cx| {
601 this.try_set_state(cx, |old_state| {
602 if old_state.is_reconnecting() {
603 match &new_state {
604 State::Connecting
605 | State::Reconnecting
606 | State::HeartbeatMissed { .. }
607 | State::ServerNotRunning => {}
608 State::Connected { .. } => {
609 log::info!("Successfully reconnected");
610 }
611 State::ReconnectFailed {
612 error, attempts, ..
613 } => {
614 log::error!(
615 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
616 attempts,
617 error
618 );
619 }
620 State::ReconnectExhausted => {
621 log::error!("Reconnect attempt failed and all attempts exhausted");
622 }
623 }
624 Some(new_state)
625 } else {
626 None
627 }
628 });
629
630 if this.state_is(State::is_reconnect_failed) {
631 this.reconnect(cx)
632 } else if this.state_is(State::is_reconnect_exhausted) {
633 Ok(())
634 } else {
635 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
636 Ok(())
637 }
638 })
639 })
640 .detach_and_log_err(cx);
641
642 Ok(())
643 }
644
645 fn heartbeat(
646 this: WeakEntity<Self>,
647 mut connection_activity_rx: mpsc::Receiver<()>,
648 cx: &mut AsyncApp,
649 ) -> Task<Result<()>> {
650 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
651 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
652 };
653
654 cx.spawn(async move |cx| {
655 let mut missed_heartbeats = 0;
656
657 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
658 futures::pin_mut!(keepalive_timer);
659
660 loop {
661 select_biased! {
662 result = connection_activity_rx.next().fuse() => {
663 if result.is_none() {
664 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
665 return Ok(());
666 }
667
668 if missed_heartbeats != 0 {
669 missed_heartbeats = 0;
670 let _ =this.update(cx, |this, cx| {
671 this.handle_heartbeat_result(missed_heartbeats, cx)
672 })?;
673 }
674 }
675 _ = keepalive_timer => {
676 log::debug!("Sending heartbeat to server...");
677
678 let result = select_biased! {
679 _ = connection_activity_rx.next().fuse() => {
680 Ok(())
681 }
682 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
683 ping_result
684 }
685 };
686
687 if result.is_err() {
688 missed_heartbeats += 1;
689 log::warn!(
690 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
691 HEARTBEAT_TIMEOUT,
692 missed_heartbeats,
693 MAX_MISSED_HEARTBEATS
694 );
695 } else if missed_heartbeats != 0 {
696 missed_heartbeats = 0;
697 } else {
698 continue;
699 }
700
701 let result = this.update(cx, |this, cx| {
702 this.handle_heartbeat_result(missed_heartbeats, cx)
703 })?;
704 if result.is_break() {
705 return Ok(());
706 }
707 }
708 }
709
710 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
711 }
712 })
713 }
714
715 fn handle_heartbeat_result(
716 &mut self,
717 missed_heartbeats: usize,
718 cx: &mut Context<Self>,
719 ) -> ControlFlow<()> {
720 let state = self.state.take().unwrap();
721 let next_state = if missed_heartbeats > 0 {
722 state.heartbeat_missed()
723 } else {
724 state.heartbeat_recovered()
725 };
726
727 self.set_state(next_state, cx);
728
729 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
730 log::error!(
731 "Missed last {} heartbeats. Reconnecting...",
732 missed_heartbeats
733 );
734
735 self.reconnect(cx)
736 .context("failed to start reconnect process after missing heartbeats")
737 .log_err();
738 ControlFlow::Break(())
739 } else {
740 ControlFlow::Continue(())
741 }
742 }
743
744 fn monitor(
745 this: WeakEntity<Self>,
746 io_task: Task<Result<i32>>,
747 cx: &AsyncApp,
748 ) -> Task<Result<()>> {
749 cx.spawn(async move |cx| {
750 let result = io_task.await;
751
752 match result {
753 Ok(exit_code) => {
754 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
755 match error {
756 ProxyLaunchError::ServerNotRunning => {
757 log::error!("failed to reconnect because server is not running");
758 this.update(cx, |this, cx| {
759 this.set_state(State::ServerNotRunning, cx);
760 })?;
761 }
762 }
763 } else if exit_code > 0 {
764 log::error!("proxy process terminated unexpectedly");
765 this.update(cx, |this, cx| {
766 this.reconnect(cx).ok();
767 })?;
768 }
769 }
770 Err(error) => {
771 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
772 this.update(cx, |this, cx| {
773 this.reconnect(cx).ok();
774 })?;
775 }
776 }
777
778 Ok(())
779 })
780 }
781
782 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
783 self.state.as_ref().is_some_and(check)
784 }
785
786 fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
787 let new_state = self.state.as_ref().and_then(map);
788 if let Some(new_state) = new_state {
789 self.state.replace(new_state);
790 cx.notify();
791 }
792 }
793
794 fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
795 log::info!("setting state to '{}'", &state);
796
797 let is_reconnect_exhausted = state.is_reconnect_exhausted();
798 let is_server_not_running = state.is_server_not_running();
799 self.state.replace(state);
800
801 if is_reconnect_exhausted || is_server_not_running {
802 cx.emit(RemoteClientEvent::Disconnected);
803 }
804 cx.notify();
805 }
806
807 pub fn shell(&self) -> Option<String> {
808 Some(self.remote_connection()?.shell())
809 }
810
811 pub fn default_system_shell(&self) -> Option<String> {
812 Some(self.remote_connection()?.default_system_shell())
813 }
814
815 pub fn shares_network_interface(&self) -> bool {
816 self.remote_connection()
817 .map_or(false, |connection| connection.shares_network_interface())
818 }
819
820 pub fn build_command(
821 &self,
822 program: Option<String>,
823 args: &[String],
824 env: &HashMap<String, String>,
825 working_dir: Option<String>,
826 port_forward: Option<(u16, String, u16)>,
827 ) -> Result<CommandTemplate> {
828 let Some(connection) = self.remote_connection() else {
829 return Err(anyhow!("no ssh connection"));
830 };
831 connection.build_command(program, args, env, working_dir, port_forward)
832 }
833
834 pub fn build_forward_ports_command(
835 &self,
836 forwards: Vec<(u16, String, u16)>,
837 ) -> Result<CommandTemplate> {
838 let Some(connection) = self.remote_connection() else {
839 return Err(anyhow!("no ssh connection"));
840 };
841 connection.build_forward_ports_command(forwards)
842 }
843
844 pub fn upload_directory(
845 &self,
846 src_path: PathBuf,
847 dest_path: RemotePathBuf,
848 cx: &App,
849 ) -> Task<Result<()>> {
850 let Some(connection) = self.remote_connection() else {
851 return Task::ready(Err(anyhow!("no ssh connection")));
852 };
853 connection.upload_directory(src_path, dest_path, cx)
854 }
855
856 pub fn proto_client(&self) -> AnyProtoClient {
857 self.client.clone().into()
858 }
859
860 pub fn connection_options(&self) -> RemoteConnectionOptions {
861 self.connection_options.clone()
862 }
863
864 pub fn connection(&self) -> Option<Arc<dyn RemoteConnection>> {
865 if let State::Connected {
866 remote_connection, ..
867 } = self.state.as_ref()?
868 {
869 Some(remote_connection.clone())
870 } else {
871 None
872 }
873 }
874
875 pub fn connection_state(&self) -> ConnectionState {
876 self.state
877 .as_ref()
878 .map(ConnectionState::from)
879 .unwrap_or(ConnectionState::Disconnected)
880 }
881
882 pub fn is_disconnected(&self) -> bool {
883 self.connection_state() == ConnectionState::Disconnected
884 }
885
886 pub fn path_style(&self) -> PathStyle {
887 self.path_style
888 }
889
890 #[cfg(any(test, feature = "test-support"))]
891 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
892 let opts = self.connection_options();
893 client_cx.spawn(async move |cx| {
894 let connection = cx
895 .update_global(|c: &mut ConnectionPool, _| {
896 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
897 c.clone()
898 } else {
899 panic!("missing test connection")
900 }
901 })
902 .unwrap()
903 .await
904 .unwrap();
905
906 connection.simulate_disconnect(cx);
907 })
908 }
909
910 #[cfg(any(test, feature = "test-support"))]
911 pub fn fake_server(
912 client_cx: &mut gpui::TestAppContext,
913 server_cx: &mut gpui::TestAppContext,
914 ) -> (RemoteConnectionOptions, AnyProtoClient) {
915 let port = client_cx
916 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
917 let opts = RemoteConnectionOptions::Ssh(SshConnectionOptions {
918 host: "<fake>".to_string(),
919 port: Some(port),
920 ..Default::default()
921 });
922 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
923 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
924 let server_client =
925 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
926 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
927 connection_options: opts.clone(),
928 server_cx: fake::SendableCx::new(server_cx),
929 server_channel: server_client.clone(),
930 });
931
932 client_cx.update(|cx| {
933 cx.update_default_global(|c: &mut ConnectionPool, cx| {
934 c.connections.insert(
935 opts.clone(),
936 ConnectionPoolEntry::Connecting(
937 cx.background_spawn({
938 let connection = connection.clone();
939 async move { Ok(connection.clone()) }
940 })
941 .shared(),
942 ),
943 );
944 })
945 });
946
947 (opts, server_client.into())
948 }
949
950 #[cfg(any(test, feature = "test-support"))]
951 pub async fn fake_client(
952 opts: RemoteConnectionOptions,
953 client_cx: &mut gpui::TestAppContext,
954 ) -> Entity<Self> {
955 let (_tx, rx) = oneshot::channel();
956 let mut cx = client_cx.to_async();
957 let connection = connect(opts, Arc::new(fake::Delegate), &mut cx)
958 .await
959 .unwrap();
960 client_cx
961 .update(|cx| {
962 Self::new(
963 ConnectionIdentifier::setup(),
964 connection,
965 rx,
966 Arc::new(fake::Delegate),
967 cx,
968 )
969 })
970 .await
971 .unwrap()
972 .unwrap()
973 }
974
975 fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
976 self.state
977 .as_ref()
978 .and_then(|state| state.remote_connection())
979 }
980}
981
/// A slot in the [`ConnectionPool`].
enum ConnectionPoolEntry {
    /// A connection attempt in flight; shared so concurrent callers await the same attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
986
/// Global registry of remote connections keyed by their options. It lets
/// identical connection requests share one attempt or one live connection.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<RemoteConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
993
994impl ConnectionPool {
995 pub fn connect(
996 &mut self,
997 opts: RemoteConnectionOptions,
998 delegate: Arc<dyn RemoteClientDelegate>,
999 cx: &mut App,
1000 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1001 let connection = self.connections.get(&opts);
1002 match connection {
1003 Some(ConnectionPoolEntry::Connecting(task)) => {
1004 delegate.set_status(
1005 Some("Waiting for existing connection attempt"),
1006 &mut cx.to_async(),
1007 );
1008 return task.clone();
1009 }
1010 Some(ConnectionPoolEntry::Connected(ssh)) => {
1011 if let Some(ssh) = ssh.upgrade()
1012 && !ssh.has_been_killed()
1013 {
1014 return Task::ready(Ok(ssh)).shared();
1015 }
1016 self.connections.remove(&opts);
1017 }
1018 None => {}
1019 }
1020
1021 let task = cx
1022 .spawn({
1023 let opts = opts.clone();
1024 let delegate = delegate.clone();
1025 async move |cx| {
1026 let connection = match opts.clone() {
1027 RemoteConnectionOptions::Ssh(opts) => {
1028 SshRemoteConnection::new(opts, delegate, cx)
1029 .await
1030 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1031 }
1032 RemoteConnectionOptions::Wsl(opts) => {
1033 WslRemoteConnection::new(opts, delegate, cx)
1034 .await
1035 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
1036 }
1037 };
1038
1039 cx.update_global(|pool: &mut Self, _| {
1040 debug_assert!(matches!(
1041 pool.connections.get(&opts),
1042 Some(ConnectionPoolEntry::Connecting(_))
1043 ));
1044 match connection {
1045 Ok(connection) => {
1046 pool.connections.insert(
1047 opts.clone(),
1048 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1049 );
1050 Ok(connection)
1051 }
1052 Err(error) => {
1053 pool.connections.remove(&opts);
1054 Err(Arc::new(error))
1055 }
1056 }
1057 })?
1058 }
1059 })
1060 .shared();
1061
1062 self.connections
1063 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1064 task
1065 }
1066}
1067
/// How to reach a remote host. Also used as the key in the connection pool.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteConnectionOptions {
    /// Connect over SSH.
    Ssh(SshConnectionOptions),
    /// Connect to a WSL distribution.
    Wsl(WslConnectionOptions),
}
1073
1074impl RemoteConnectionOptions {
1075 pub fn display_name(&self) -> String {
1076 match self {
1077 RemoteConnectionOptions::Ssh(opts) => opts.host.clone(),
1078 RemoteConnectionOptions::Wsl(opts) => opts.distro_name.clone(),
1079 }
1080 }
1081}
1082
1083impl From<SshConnectionOptions> for RemoteConnectionOptions {
1084 fn from(opts: SshConnectionOptions) -> Self {
1085 RemoteConnectionOptions::Ssh(opts)
1086 }
1087}
1088
1089impl From<WslConnectionOptions> for RemoteConnectionOptions {
1090 fn from(opts: WslConnectionOptions) -> Self {
1091 RemoteConnectionOptions::Wsl(opts)
1092 }
1093}
1094
// Windows-only action. Its `///` line below is deliberately left untouched,
// since the `gpui::Action` derive may surface doc comments as action docs.
#[cfg(target_os = "windows")]
/// Open a wsl path (\\wsl.localhost\<distro>\path)
#[derive(Debug, Clone, PartialEq, Eq, gpui::Action)]
#[action(namespace = workspace, no_json, no_register)]
pub struct OpenWslPath {
    // Which WSL distribution the paths belong to.
    pub distro: WslConnectionOptions,
    // Paths to open within that distribution.
    pub paths: Vec<PathBuf>,
}
1103
/// A transport (SSH, WSL, ...) over which a [`RemoteClient`] talks to the remote server.
#[async_trait(?Send)]
pub trait RemoteConnection: Send + Sync {
    /// Starts the proxy process that carries envelopes between client and server.
    ///
    /// `reconnect` is `false` for the initial connection and `true` when re-joining
    /// after a reconnect. `connection_activity_tx` should be signalled when there is
    /// traffic; the heartbeat uses it to postpone pings.
    ///
    /// The returned task resolves to the proxy's exit code. `RemoteClient::monitor`
    /// interprets that code via `ProxyLaunchError::from_exit_code`.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies a local directory to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection; called before each reconnect attempt.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has taken effect. The pool discards killed connections.
    fn has_been_killed(&self) -> bool;
    /// Whether the remote shares the local network interface (default: `false`).
    fn shares_network_interface(&self) -> bool {
        false
    }
    /// Builds a command that runs `program` (or presumably a shell when `None` —
    /// confirm with implementors) on the remote host, optionally with a port forward.
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Builds a command that establishes the given port forwards.
    fn build_forward_ports_command(
        &self,
        forwards: Vec<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// Options this connection was created with (also its pool key).
    fn connection_options(&self) -> RemoteConnectionOptions;
    /// Path conventions of the remote host.
    fn path_style(&self) -> PathStyle;
    fn shell(&self) -> String;
    fn default_system_shell(&self) -> String;

    /// Test hook for simulating a dropped connection; no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1147
/// Pending requests, keyed by the id of the request message.
///
/// Each sender receives the response envelope together with a second oneshot
/// sender. NOTE(review): the purpose of that inner `oneshot::Sender<()>` is not
/// visible in this chunk (presumably a handled/ack signal) — confirm.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1149
/// A value that can be set once and awaited by any number of waiters.
struct Signal<T> {
    // Taken on the first `set`; later calls are no-ops.
    tx: Mutex<Option<oneshot::Sender<T>>>,
    // Resolves to `Some(value)` once set, or `None` if the sender is dropped unset.
    rx: Shared<Task<Option<T>>>,
}
1154
1155impl<T: Send + Clone + 'static> Signal<T> {
1156 pub fn new(cx: &App) -> Self {
1157 let (tx, rx) = oneshot::channel();
1158
1159 let task = cx
1160 .background_executor()
1161 .spawn(async move { rx.await.ok() })
1162 .shared();
1163
1164 Self {
1165 tx: Mutex::new(Some(tx)),
1166 rx: task,
1167 }
1168 }
1169
1170 fn set(&self, value: T) {
1171 if let Some(tx) = self.tx.lock().take() {
1172 let _ = tx.send(value);
1173 }
1174 }
1175
1176 fn wait(&self) -> Shared<Task<Option<T>>> {
1177 self.rx.clone()
1178 }
1179}
1180
/// RPC client that exchanges protobuf envelopes with a peer over a pair of mpsc channels.
///
/// Envelopes sent through `send_buffered` are retained until the peer acknowledges them
/// (via `ack_id`), so they can be re-sent after the channels are swapped by `reconnect`.
struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Sink for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Buffered envelopes not yet acknowledged by the peer, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Senders waiting for responses to in-flight requests.
    response_channels: ResponseChannels,
    /// Handlers for incoming requests and messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the last incoming envelope recorded by the message loop; echoed back as
    /// `ack_id` on outgoing envelopes.
    max_received: AtomicU32,
    /// Label that prefixes this client's log lines.
    name: &'static str,
    /// Task running the incoming-message loop; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
    /// Fired when the peer's `RemoteStarted` message arrives.
    remote_started: Signal<()>,
}
1192
1193impl ChannelClient {
1194 fn new(
1195 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1196 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1197 cx: &App,
1198 name: &'static str,
1199 ) -> Arc<Self> {
1200 Arc::new_cyclic(|this| Self {
1201 outgoing_tx: Mutex::new(outgoing_tx),
1202 next_message_id: AtomicU32::new(0),
1203 max_received: AtomicU32::new(0),
1204 response_channels: ResponseChannels::default(),
1205 message_handlers: Default::default(),
1206 buffer: Mutex::new(VecDeque::new()),
1207 name,
1208 task: Mutex::new(Self::start_handling_messages(
1209 this.clone(),
1210 incoming_rx,
1211 &cx.to_async(),
1212 )),
1213 remote_started: Signal::new(cx),
1214 })
1215 }
1216
1217 fn wait_for_remote_started(&self) -> Shared<Task<Option<()>>> {
1218 self.remote_started.wait()
1219 }
1220
1221 fn start_handling_messages(
1222 this: Weak<Self>,
1223 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1224 cx: &AsyncApp,
1225 ) -> Task<Result<()>> {
1226 cx.spawn(async move |cx| {
1227 if let Some(this) = this.upgrade() {
1228 let envelope = proto::RemoteStarted {}.into_envelope(0, None, None);
1229 this.outgoing_tx.lock().unbounded_send(envelope).ok();
1230 };
1231
1232 let peer_id = PeerId { owner_id: 0, id: 0 };
1233 while let Some(incoming) = incoming_rx.next().await {
1234 let Some(this) = this.upgrade() else {
1235 return anyhow::Ok(());
1236 };
1237 if let Some(ack_id) = incoming.ack_id {
1238 let mut buffer = this.buffer.lock();
1239 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1240 buffer.pop_front();
1241 }
1242 }
1243 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
1244 {
1245 log::debug!(
1246 "{}:ssh message received. name:FlushBufferedMessages",
1247 this.name
1248 );
1249 {
1250 let buffer = this.buffer.lock();
1251 for envelope in buffer.iter() {
1252 this.outgoing_tx
1253 .lock()
1254 .unbounded_send(envelope.clone())
1255 .ok();
1256 }
1257 }
1258 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1259 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1260 this.outgoing_tx.lock().unbounded_send(envelope).ok();
1261 continue;
1262 }
1263
1264 if let Some(proto::envelope::Payload::RemoteStarted(_)) = &incoming.payload {
1265 this.remote_started.set(());
1266 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1267 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1268 this.outgoing_tx.lock().unbounded_send(envelope).ok();
1269 continue;
1270 }
1271
1272 this.max_received.store(incoming.id, SeqCst);
1273
1274 if let Some(request_id) = incoming.responding_to {
1275 let request_id = MessageId(request_id);
1276 let sender = this.response_channels.lock().remove(&request_id);
1277 if let Some(sender) = sender {
1278 let (tx, rx) = oneshot::channel();
1279 if incoming.payload.is_some() {
1280 sender.send((incoming, tx)).ok();
1281 }
1282 rx.await.ok();
1283 }
1284 } else if let Some(envelope) =
1285 build_typed_envelope(peer_id, Instant::now(), incoming)
1286 {
1287 let type_name = envelope.payload_type_name();
1288 let message_id = envelope.message_id();
1289 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1290 &this.message_handlers,
1291 envelope,
1292 this.clone().into(),
1293 cx.clone(),
1294 ) {
1295 log::debug!("{}:ssh message received. name:{type_name}", this.name);
1296 cx.foreground_executor()
1297 .spawn(async move {
1298 match future.await {
1299 Ok(_) => {
1300 log::debug!(
1301 "{}:ssh message handled. name:{type_name}",
1302 this.name
1303 );
1304 }
1305 Err(error) => {
1306 log::error!(
1307 "{}:error handling message. type:{}, error:{:#}",
1308 this.name,
1309 type_name,
1310 format!("{error:#}").lines().fold(
1311 String::new(),
1312 |mut message, line| {
1313 if !message.is_empty() {
1314 message.push(' ');
1315 }
1316 message.push_str(line);
1317 message
1318 }
1319 )
1320 );
1321 }
1322 }
1323 })
1324 .detach()
1325 } else {
1326 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
1327 if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
1328 message_id,
1329 anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
1330 ) {
1331 log::error!(
1332 "{}:error sending error response for {type_name}:{e:#}",
1333 this.name
1334 );
1335 }
1336 }
1337 }
1338 }
1339 anyhow::Ok(())
1340 })
1341 }
1342
1343 fn reconnect(
1344 self: &Arc<Self>,
1345 incoming_rx: UnboundedReceiver<Envelope>,
1346 outgoing_tx: UnboundedSender<Envelope>,
1347 cx: &AsyncApp,
1348 ) {
1349 *self.outgoing_tx.lock() = outgoing_tx;
1350 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1351 }
1352
1353 fn request<T: RequestMessage>(
1354 &self,
1355 payload: T,
1356 ) -> impl 'static + Future<Output = Result<T::Response>> {
1357 self.request_internal(payload, true)
1358 }
1359
1360 fn request_internal<T: RequestMessage>(
1361 &self,
1362 payload: T,
1363 use_buffer: bool,
1364 ) -> impl 'static + Future<Output = Result<T::Response>> {
1365 log::debug!("ssh request start. name:{}", T::NAME);
1366 let response =
1367 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
1368 async move {
1369 let response = response.await?;
1370 log::debug!("ssh request finish. name:{}", T::NAME);
1371 T::Response::from_envelope(response).context("received a response of the wrong type")
1372 }
1373 }
1374
1375 async fn resync(&self, timeout: Duration) -> Result<()> {
1376 smol::future::or(
1377 async {
1378 self.request_internal(proto::FlushBufferedMessages {}, false)
1379 .await?;
1380
1381 for envelope in self.buffer.lock().iter() {
1382 self.outgoing_tx
1383 .lock()
1384 .unbounded_send(envelope.clone())
1385 .ok();
1386 }
1387 Ok(())
1388 },
1389 async {
1390 smol::Timer::after(timeout).await;
1391 anyhow::bail!("Timed out resyncing remote client")
1392 },
1393 )
1394 .await
1395 }
1396
1397 async fn ping(&self, timeout: Duration) -> Result<()> {
1398 smol::future::or(
1399 async {
1400 self.request(proto::Ping {}).await?;
1401 Ok(())
1402 },
1403 async {
1404 smol::Timer::after(timeout).await;
1405 anyhow::bail!("Timed out pinging remote client")
1406 },
1407 )
1408 .await
1409 }
1410
1411 fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1412 log::debug!("ssh send name:{}", T::NAME);
1413 self.send_dynamic(payload.into_envelope(0, None, None))
1414 }
1415
1416 fn request_dynamic(
1417 &self,
1418 mut envelope: proto::Envelope,
1419 type_name: &'static str,
1420 use_buffer: bool,
1421 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1422 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1423 let (tx, rx) = oneshot::channel();
1424 let mut response_channels_lock = self.response_channels.lock();
1425 response_channels_lock.insert(MessageId(envelope.id), tx);
1426 drop(response_channels_lock);
1427
1428 let result = if use_buffer {
1429 self.send_buffered(envelope)
1430 } else {
1431 self.send_unbuffered(envelope)
1432 };
1433 async move {
1434 if let Err(error) = &result {
1435 log::error!("failed to send message: {error}");
1436 anyhow::bail!("failed to send message: {error}");
1437 }
1438
1439 let response = rx.await.context("connection lost")?.0;
1440 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1441 return Err(RpcError::from_proto(error, type_name));
1442 }
1443 Ok(response)
1444 }
1445 }
1446
1447 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1448 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1449 self.send_buffered(envelope)
1450 }
1451
1452 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1453 envelope.ack_id = Some(self.max_received.load(SeqCst));
1454 self.buffer.lock().push_back(envelope.clone());
1455 // ignore errors on send (happen while we're reconnecting)
1456 // assume that the global "disconnected" overlay is sufficient.
1457 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1458 Ok(())
1459 }
1460
1461 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1462 envelope.ack_id = Some(self.max_received.load(SeqCst));
1463 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1464 Ok(())
1465 }
1466}
1467
1468impl ProtoClient for ChannelClient {
1469 fn request(
1470 &self,
1471 envelope: proto::Envelope,
1472 request_type: &'static str,
1473 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1474 self.request_dynamic(envelope, request_type, true).boxed()
1475 }
1476
1477 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1478 self.send_dynamic(envelope)
1479 }
1480
1481 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1482 self.send_dynamic(envelope)
1483 }
1484
1485 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1486 &self.message_handlers
1487 }
1488
1489 fn is_via_collab(&self) -> bool {
1490 false
1491 }
1492}
1493
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
    use crate::remote_client::{CommandTemplate, RemoteConnectionOptions};
    use anyhow::Result;
    use askpass::EncryptedPassword;
    use async_trait::async_trait;
    use collections::HashMap;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use semver::Version;
    use std::{path::PathBuf, sync::Arc};
    use util::paths::{PathStyle, RemotePathBuf};

    /// In-memory [`RemoteConnection`] that wires the client directly to a server-side
    /// [`ChannelClient`] running in the server's test app context.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: RemoteConnectionOptions,
        pub(super) server_channel: Arc<ChannelClient>,
        pub(super) server_cx: SendableCx,
    }

    /// Holds the server's [`AsyncApp`] so it can live inside a `Send + Sync` connection.
    pub(super) struct SendableCx(AsyncApp);

    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _main_thread: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Re-points the server channel at fresh channels and relays envelopes between
        /// client and server until either side closes, then resolves to `Ok(1)`.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn RemoteClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut to_server_tx, to_server_rx) = mpsc::unbounded::<Envelope>();
            let (from_server_tx, mut from_server_rx) = mpsc::unbounded::<Envelope>();

            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(to_server_rx, from_server_tx, &server_cx);

            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        from_server = from_server_rx.next().fuse() => {
                            match from_server {
                                Some(envelope) => {
                                    connection_activity_tx.try_send(()).ok();
                                    client_incoming_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                        from_client = client_outgoing_rx.next().fuse() => {
                            match from_client {
                                Some(envelope) => {
                                    to_server_tx.send(envelope).await.ok();
                                }
                                None => return Ok(1),
                            }
                        }
                    }
                }
            })
        }

        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// Produces `ssh <program or "sh"> <args...>` with the given environment.
        fn build_command(
            &self,
            program: Option<String>,
            args: &[String],
            env: &HashMap<String, String>,
            _: Option<String>,
            _: Option<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let remote_program = program.unwrap_or_else(|| "sh".to_string());
            let ssh_args: Vec<String> = std::iter::once(remote_program)
                .chain(args.iter().cloned())
                .collect();
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: ssh_args,
                env: env.clone(),
            })
        }

        /// Produces `ssh -N local:host:remote...` for each forward.
        fn build_forward_ports_command(
            &self,
            forwards: Vec<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let mut args = Vec::with_capacity(forwards.len() + 1);
            args.push("-N".to_owned());
            args.extend(
                forwards
                    .into_iter()
                    .map(|(local_port, host, remote_port)| {
                        format!("{local_port}:{host}:{remote_port}")
                    }),
            );
            Ok(CommandTemplate {
                program: "ssh".into(),
                args,
                env: Default::default(),
            })
        }

        fn connection_options(&self) -> RemoteConnectionOptions {
            RemoteConnectionOptions::clone(&self.connection_options)
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::local()
        }

        fn shell(&self) -> String {
            String::from("sh")
        }

        fn default_system_shell(&self) -> String {
            String::from("sh")
        }

        /// Reconnects the server channel to channels whose other ends are already dropped,
        /// so its message loop ends and its sends go nowhere.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (dead_outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, dead_incoming_rx) = mpsc::unbounded::<Envelope>();
            let server_cx = self.server_cx.get(cx);
            self.server_channel
                .reconnect(dead_incoming_rx, dead_outgoing_tx, &server_cx);
        }
    }

    /// Test delegate: every callback except `set_status` is expected never to run.
    pub(super) struct Delegate;

    impl RemoteClientDelegate for Delegate {
        fn ask_password(
            &self,
            _prompt: String,
            _tx: oneshot::Sender<EncryptedPassword>,
            _cx: &mut AsyncApp,
        ) {
            unreachable!()
        }

        fn get_download_url(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<Version>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<String>>> {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<Version>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn set_status(&self, _status: Option<&str>, _cx: &mut AsyncApp) {}
    }
}