1use crate::{
2 SshConnectionOptions, protocol::MessageId, proxy::ProxyLaunchError,
3 transport::ssh::SshRemoteConnection,
4};
5use anyhow::{Context as _, Result, anyhow};
6use async_trait::async_trait;
7use collections::HashMap;
8use futures::{
9 Future, FutureExt as _, StreamExt as _,
10 channel::{
11 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
12 oneshot,
13 },
14 future::{BoxFuture, Shared},
15 select, select_biased,
16};
17use gpui::{
18 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
19 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
20};
21use parking_lot::Mutex;
22
23use release_channel::ReleaseChannel;
24use rpc::{
25 AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
26 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
27};
28use std::{
29 collections::VecDeque,
30 fmt,
31 ops::ControlFlow,
32 path::PathBuf,
33 sync::{
34 Arc, Weak,
35 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
36 },
37 time::{Duration, Instant},
38};
39use util::{
40 ResultExt,
41 paths::{PathStyle, RemotePathBuf},
42};
43
/// Operating system and CPU architecture of the remote host, passed to
/// [`RemoteClientDelegate`] when locating a server binary for that host.
#[derive(Copy, Clone, Debug)]
pub struct RemotePlatform {
    /// Operating-system identifier of the remote host.
    pub os: &'static str,
    /// CPU-architecture identifier of the remote host.
    pub arch: &'static str,
}
49
/// A command to run, as produced by `RemoteConnection::build_command`:
/// the program, its arguments, and extra environment variables.
#[derive(Clone, Debug)]
pub struct CommandTemplate {
    /// Program to execute.
    pub program: String,
    /// Arguments passed to `program`.
    pub args: Vec<String>,
    /// Environment variables to set for the command.
    pub env: HashMap<String, String>,
}
56
/// Callbacks through which the remote client interacts with the embedding
/// application while connecting: password prompts, server-binary acquisition,
/// and status reporting.
pub trait RemoteClientDelegate: Send + Sync {
    /// Asks the user to answer `prompt`; the answer is delivered through `tx`.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves to an optional pair of strings describing how to download the server
    /// binary for `platform` / `release_channel` / `version`.
    /// NOTE(review): the meaning of the two strings is not visible here — confirm with
    /// implementors.
    fn get_download_params(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;
    /// Resolves to a local path for the server binary matching `platform` /
    /// `release_channel` / `version` (presumably downloading it first — TODO confirm).
    fn download_server_binary_locally(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a human-readable status message (e.g. while waiting on an existing
    /// connection attempt); `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
75
/// Consecutive failed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time without connection activity before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Upper bound on a single ping (also used for the post-reconnect resync).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reconnect attempts allowed before the client gives up and reports itself disconnected.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
81
/// Internal lifecycle of a [`RemoteClient`] connection.
///
/// `Connected`, `HeartbeatMissed` and `ReconnectFailed` still own the connection handle
/// and delegate; the first two also own the running proxy-monitor and heartbeat tasks.
enum State {
    /// Initial state while the first connection is being established.
    Connecting,
    /// Connection established and the background tasks are running.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        // Task that waits on the proxy I/O task and reacts to its exit.
        multiplex_task: Task<Result<()>>,
        // Keepalive loop pinging the server.
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but recent heartbeats have failed.
    HeartbeatMissed {
        // Number of consecutive heartbeats missed so far.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; another attempt is started from here.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        // Why the last attempt failed.
        error: anyhow::Error,
        // Reconnect attempts made so far.
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts failed; the client is disconnected.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running; the client is disconnected.
    ServerNotRunning,
}
111
112impl fmt::Display for State {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match self {
115 Self::Connecting => write!(f, "connecting"),
116 Self::Connected { .. } => write!(f, "connected"),
117 Self::Reconnecting => write!(f, "reconnecting"),
118 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
119 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
120 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
121 Self::ServerNotRunning { .. } => write!(f, "server not running"),
122 }
123 }
124}
125
126impl State {
127 fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
128 match self {
129 Self::Connected { ssh_connection, .. } => Some(ssh_connection.clone()),
130 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
131 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
132 _ => None,
133 }
134 }
135
136 fn can_reconnect(&self) -> bool {
137 match self {
138 Self::Connected { .. }
139 | Self::HeartbeatMissed { .. }
140 | Self::ReconnectFailed { .. } => true,
141 State::Connecting
142 | State::Reconnecting
143 | State::ReconnectExhausted
144 | State::ServerNotRunning => false,
145 }
146 }
147
148 fn is_reconnect_failed(&self) -> bool {
149 matches!(self, Self::ReconnectFailed { .. })
150 }
151
152 fn is_reconnect_exhausted(&self) -> bool {
153 matches!(self, Self::ReconnectExhausted { .. })
154 }
155
156 fn is_server_not_running(&self) -> bool {
157 matches!(self, Self::ServerNotRunning)
158 }
159
160 fn is_reconnecting(&self) -> bool {
161 matches!(self, Self::Reconnecting { .. })
162 }
163
164 fn heartbeat_recovered(self) -> Self {
165 match self {
166 Self::HeartbeatMissed {
167 ssh_connection,
168 delegate,
169 multiplex_task,
170 heartbeat_task,
171 ..
172 } => Self::Connected {
173 ssh_connection,
174 delegate,
175 multiplex_task,
176 heartbeat_task,
177 },
178 _ => self,
179 }
180 }
181
182 fn heartbeat_missed(self) -> Self {
183 match self {
184 Self::Connected {
185 ssh_connection,
186 delegate,
187 multiplex_task,
188 heartbeat_task,
189 } => Self::HeartbeatMissed {
190 missed_heartbeats: 1,
191 ssh_connection,
192 delegate,
193 multiplex_task,
194 heartbeat_task,
195 },
196 Self::HeartbeatMissed {
197 missed_heartbeats,
198 ssh_connection,
199 delegate,
200 multiplex_task,
201 heartbeat_task,
202 } => Self::HeartbeatMissed {
203 missed_heartbeats: missed_heartbeats + 1,
204 ssh_connection,
205 delegate,
206 multiplex_task,
207 heartbeat_task,
208 },
209 _ => self,
210 }
211 }
212}
213
/// Public, simplified view of a [`RemoteClient`]'s connection state.
///
/// `Reconnecting` covers both an in-flight and a just-failed reconnect attempt;
/// `Disconnected` covers exhausted reconnects, a server that is not running, and a
/// client with no state at all.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    Reconnecting,
    Disconnected,
}
223
224impl From<&State> for ConnectionState {
225 fn from(value: &State) -> Self {
226 match value {
227 State::Connecting => Self::Connecting,
228 State::Connected { .. } => Self::Connected,
229 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
230 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
231 State::ReconnectExhausted => Self::Disconnected,
232 State::ServerNotRunning => Self::Disconnected,
233 }
234 }
235}
236
/// A client for a remote Zed server reached over SSH, with automatic heartbeat
/// monitoring and reconnection.
pub struct RemoteClient {
    // Proto client multiplexed over the proxy's stdin/stdout channels.
    client: Arc<ChannelClient>,
    // Identifier passed to `start_proxy` so reconnects re-join the same remote socket.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    path_style: PathStyle,
    // `None` once `shutdown_processes` has taken the state.
    state: Option<State>,
}
244
/// Events emitted by [`RemoteClient`].
#[derive(Debug)]
pub enum RemoteClientEvent {
    /// Emitted when the client enters a terminal state: reconnect attempts were
    /// exhausted, or the remote server is not running.
    Disconnected,
}

impl EventEmitter<RemoteClientEvent> for RemoteClient {}
251
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Temporary identifier taken from a process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    /// Identifier tied to a workspace id.
    Workspace(i64),
}
258
// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
260
261impl ConnectionIdentifier {
262 pub fn setup() -> Self {
263 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
264 }
265
266 // This string gets used in a socket name, and so must be relatively short.
267 // The total length of:
268 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
269 // Must be less than about 100 characters
270 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
271 // So our strings should be at most 20 characters or so.
272 fn to_string(&self, cx: &App) -> String {
273 let identifier_prefix = match ReleaseChannel::global(cx) {
274 ReleaseChannel::Stable => "".to_string(),
275 release_channel => format!("{}-", release_channel.dev_name()),
276 };
277 match self {
278 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
279 Self::Workspace(workspace_id) => {
280 format!("{identifier_prefix}workspace-{workspace_id}",)
281 }
282 }
283 }
284}
285
286impl RemoteClient {
287 pub fn ssh(
288 unique_identifier: ConnectionIdentifier,
289 connection_options: SshConnectionOptions,
290 cancellation: oneshot::Receiver<()>,
291 delegate: Arc<dyn RemoteClientDelegate>,
292 cx: &mut App,
293 ) -> Task<Result<Option<Entity<Self>>>> {
294 let unique_identifier = unique_identifier.to_string(cx);
295 cx.spawn(async move |cx| {
296 let success = Box::pin(async move {
297 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
298 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
299 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
300
301 let client =
302 cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
303
304 let ssh_connection = cx
305 .update(|cx| {
306 cx.update_default_global(|pool: &mut ConnectionPool, cx| {
307 pool.connect(connection_options.clone(), &delegate, cx)
308 })
309 })?
310 .await
311 .map_err(|e| e.cloned())?;
312
313 let path_style = ssh_connection.path_style();
314 let this = cx.new(|_| Self {
315 client: client.clone(),
316 unique_identifier: unique_identifier.clone(),
317 connection_options,
318 path_style,
319 state: Some(State::Connecting),
320 })?;
321
322 let io_task = ssh_connection.start_proxy(
323 unique_identifier,
324 false,
325 incoming_tx,
326 outgoing_rx,
327 connection_activity_tx,
328 delegate.clone(),
329 cx,
330 );
331
332 let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
333
334 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
335 log::error!("failed to establish connection: {}", error);
336 return Err(error);
337 }
338
339 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
340
341 this.update(cx, |this, _| {
342 this.state = Some(State::Connected {
343 ssh_connection,
344 delegate,
345 multiplex_task,
346 heartbeat_task,
347 });
348 })?;
349
350 Ok(Some(this))
351 });
352
353 select! {
354 _ = cancellation.fuse() => {
355 Ok(None)
356 }
357 result = success.fuse() => result
358 }
359 })
360 }
361
362 pub fn proto_client_from_channels(
363 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
364 outgoing_tx: mpsc::UnboundedSender<Envelope>,
365 cx: &App,
366 name: &'static str,
367 ) -> AnyProtoClient {
368 ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
369 }
370
371 pub fn shutdown_processes<T: RequestMessage>(
372 &mut self,
373 shutdown_request: Option<T>,
374 executor: BackgroundExecutor,
375 ) -> Option<impl Future<Output = ()> + use<T>> {
376 let state = self.state.take()?;
377 log::info!("shutting down ssh processes");
378
379 let State::Connected {
380 multiplex_task,
381 heartbeat_task,
382 ssh_connection,
383 delegate,
384 } = state
385 else {
386 return None;
387 };
388
389 let client = self.client.clone();
390
391 Some(async move {
392 if let Some(shutdown_request) = shutdown_request {
393 client.send(shutdown_request).log_err();
394 // We wait 50ms instead of waiting for a response, because
395 // waiting for a response would require us to wait on the main thread
396 // which we want to avoid in an `on_app_quit` callback.
397 executor.timer(Duration::from_millis(50)).await;
398 }
399
400 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
401 // child of master_process.
402 drop(multiplex_task);
403 // Now drop the rest of state, which kills master process.
404 drop(heartbeat_task);
405 drop(ssh_connection);
406 drop(delegate);
407 })
408 }
409
410 fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
411 let can_reconnect = self
412 .state
413 .as_ref()
414 .map(|state| state.can_reconnect())
415 .unwrap_or(false);
416 if !can_reconnect {
417 log::info!("aborting reconnect, because not in state that allows reconnecting");
418 let error = if let Some(state) = self.state.as_ref() {
419 format!("invalid state, cannot reconnect while in state {state}")
420 } else {
421 "no state set".to_string()
422 };
423 anyhow::bail!(error);
424 }
425
426 let state = self.state.take().unwrap();
427 let (attempts, ssh_connection, delegate) = match state {
428 State::Connected {
429 ssh_connection,
430 delegate,
431 multiplex_task,
432 heartbeat_task,
433 }
434 | State::HeartbeatMissed {
435 ssh_connection,
436 delegate,
437 multiplex_task,
438 heartbeat_task,
439 ..
440 } => {
441 drop(multiplex_task);
442 drop(heartbeat_task);
443 (0, ssh_connection, delegate)
444 }
445 State::ReconnectFailed {
446 attempts,
447 ssh_connection,
448 delegate,
449 ..
450 } => (attempts, ssh_connection, delegate),
451 State::Connecting
452 | State::Reconnecting
453 | State::ReconnectExhausted
454 | State::ServerNotRunning => unreachable!(),
455 };
456
457 let attempts = attempts + 1;
458 if attempts > MAX_RECONNECT_ATTEMPTS {
459 log::error!(
460 "Failed to reconnect to after {} attempts, giving up",
461 MAX_RECONNECT_ATTEMPTS
462 );
463 self.set_state(State::ReconnectExhausted, cx);
464 return Ok(());
465 }
466
467 self.set_state(State::Reconnecting, cx);
468
469 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
470
471 let unique_identifier = self.unique_identifier.clone();
472 let client = self.client.clone();
473 let reconnect_task = cx.spawn(async move |this, cx| {
474 macro_rules! failed {
475 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
476 return State::ReconnectFailed {
477 error: anyhow!($error),
478 attempts: $attempts,
479 ssh_connection: $ssh_connection,
480 delegate: $delegate,
481 };
482 };
483 }
484
485 if let Err(error) = ssh_connection
486 .kill()
487 .await
488 .context("Failed to kill ssh process")
489 {
490 failed!(error, attempts, ssh_connection, delegate);
491 };
492
493 let connection_options = ssh_connection.connection_options();
494
495 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
496 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
497 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
498
499 let (ssh_connection, io_task) = match async {
500 let ssh_connection = cx
501 .update_global(|pool: &mut ConnectionPool, cx| {
502 pool.connect(connection_options, &delegate, cx)
503 })?
504 .await
505 .map_err(|error| error.cloned())?;
506
507 let io_task = ssh_connection.start_proxy(
508 unique_identifier,
509 true,
510 incoming_tx,
511 outgoing_rx,
512 connection_activity_tx,
513 delegate.clone(),
514 cx,
515 );
516 anyhow::Ok((ssh_connection, io_task))
517 }
518 .await
519 {
520 Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
521 Err(error) => {
522 failed!(error, attempts, ssh_connection, delegate);
523 }
524 };
525
526 let multiplex_task = Self::monitor(this.clone(), io_task, cx);
527 client.reconnect(incoming_rx, outgoing_tx, cx);
528
529 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
530 failed!(error, attempts, ssh_connection, delegate);
531 };
532
533 State::Connected {
534 ssh_connection,
535 delegate,
536 multiplex_task,
537 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
538 }
539 });
540
541 cx.spawn(async move |this, cx| {
542 let new_state = reconnect_task.await;
543 this.update(cx, |this, cx| {
544 this.try_set_state(cx, |old_state| {
545 if old_state.is_reconnecting() {
546 match &new_state {
547 State::Connecting
548 | State::Reconnecting
549 | State::HeartbeatMissed { .. }
550 | State::ServerNotRunning => {}
551 State::Connected { .. } => {
552 log::info!("Successfully reconnected");
553 }
554 State::ReconnectFailed {
555 error, attempts, ..
556 } => {
557 log::error!(
558 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
559 attempts,
560 error
561 );
562 }
563 State::ReconnectExhausted => {
564 log::error!("Reconnect attempt failed and all attempts exhausted");
565 }
566 }
567 Some(new_state)
568 } else {
569 None
570 }
571 });
572
573 if this.state_is(State::is_reconnect_failed) {
574 this.reconnect(cx)
575 } else if this.state_is(State::is_reconnect_exhausted) {
576 Ok(())
577 } else {
578 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
579 Ok(())
580 }
581 })
582 })
583 .detach_and_log_err(cx);
584
585 Ok(())
586 }
587
588 fn heartbeat(
589 this: WeakEntity<Self>,
590 mut connection_activity_rx: mpsc::Receiver<()>,
591 cx: &mut AsyncApp,
592 ) -> Task<Result<()>> {
593 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
594 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
595 };
596
597 cx.spawn(async move |cx| {
598 let mut missed_heartbeats = 0;
599
600 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
601 futures::pin_mut!(keepalive_timer);
602
603 loop {
604 select_biased! {
605 result = connection_activity_rx.next().fuse() => {
606 if result.is_none() {
607 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
608 return Ok(());
609 }
610
611 if missed_heartbeats != 0 {
612 missed_heartbeats = 0;
613 let _ =this.update(cx, |this, cx| {
614 this.handle_heartbeat_result(missed_heartbeats, cx)
615 })?;
616 }
617 }
618 _ = keepalive_timer => {
619 log::debug!("Sending heartbeat to server...");
620
621 let result = select_biased! {
622 _ = connection_activity_rx.next().fuse() => {
623 Ok(())
624 }
625 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
626 ping_result
627 }
628 };
629
630 if result.is_err() {
631 missed_heartbeats += 1;
632 log::warn!(
633 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
634 HEARTBEAT_TIMEOUT,
635 missed_heartbeats,
636 MAX_MISSED_HEARTBEATS
637 );
638 } else if missed_heartbeats != 0 {
639 missed_heartbeats = 0;
640 } else {
641 continue;
642 }
643
644 let result = this.update(cx, |this, cx| {
645 this.handle_heartbeat_result(missed_heartbeats, cx)
646 })?;
647 if result.is_break() {
648 return Ok(());
649 }
650 }
651 }
652
653 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
654 }
655 })
656 }
657
658 fn handle_heartbeat_result(
659 &mut self,
660 missed_heartbeats: usize,
661 cx: &mut Context<Self>,
662 ) -> ControlFlow<()> {
663 let state = self.state.take().unwrap();
664 let next_state = if missed_heartbeats > 0 {
665 state.heartbeat_missed()
666 } else {
667 state.heartbeat_recovered()
668 };
669
670 self.set_state(next_state, cx);
671
672 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
673 log::error!(
674 "Missed last {} heartbeats. Reconnecting...",
675 missed_heartbeats
676 );
677
678 self.reconnect(cx)
679 .context("failed to start reconnect process after missing heartbeats")
680 .log_err();
681 ControlFlow::Break(())
682 } else {
683 ControlFlow::Continue(())
684 }
685 }
686
687 fn monitor(
688 this: WeakEntity<Self>,
689 io_task: Task<Result<i32>>,
690 cx: &AsyncApp,
691 ) -> Task<Result<()>> {
692 cx.spawn(async move |cx| {
693 let result = io_task.await;
694
695 match result {
696 Ok(exit_code) => {
697 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
698 match error {
699 ProxyLaunchError::ServerNotRunning => {
700 log::error!("failed to reconnect because server is not running");
701 this.update(cx, |this, cx| {
702 this.set_state(State::ServerNotRunning, cx);
703 })?;
704 }
705 }
706 } else if exit_code > 0 {
707 log::error!("proxy process terminated unexpectedly");
708 this.update(cx, |this, cx| {
709 this.reconnect(cx).ok();
710 })?;
711 }
712 }
713 Err(error) => {
714 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
715 this.update(cx, |this, cx| {
716 this.reconnect(cx).ok();
717 })?;
718 }
719 }
720
721 Ok(())
722 })
723 }
724
725 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
726 self.state.as_ref().is_some_and(check)
727 }
728
729 fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
730 let new_state = self.state.as_ref().and_then(map);
731 if let Some(new_state) = new_state {
732 self.state.replace(new_state);
733 cx.notify();
734 }
735 }
736
737 fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
738 log::info!("setting state to '{}'", &state);
739
740 let is_reconnect_exhausted = state.is_reconnect_exhausted();
741 let is_server_not_running = state.is_server_not_running();
742 self.state.replace(state);
743
744 if is_reconnect_exhausted || is_server_not_running {
745 cx.emit(RemoteClientEvent::Disconnected);
746 }
747 cx.notify();
748 }
749
750 pub fn shell(&self) -> Option<String> {
751 Some(self.state.as_ref()?.remote_connection()?.shell())
752 }
753
754 pub fn build_command(
755 &self,
756 program: Option<String>,
757 args: &[String],
758 env: &HashMap<String, String>,
759 working_dir: Option<String>,
760 port_forward: Option<(u16, String, u16)>,
761 ) -> Result<CommandTemplate> {
762 let Some(connection) = self
763 .state
764 .as_ref()
765 .and_then(|state| state.remote_connection())
766 else {
767 return Err(anyhow!("no connection"));
768 };
769 connection.build_command(program, args, env, working_dir, port_forward)
770 }
771
772 pub fn upload_directory(
773 &self,
774 src_path: PathBuf,
775 dest_path: RemotePathBuf,
776 cx: &App,
777 ) -> Task<Result<()>> {
778 let Some(connection) = self
779 .state
780 .as_ref()
781 .and_then(|state| state.remote_connection())
782 else {
783 return Task::ready(Err(anyhow!("no ssh connection")));
784 };
785 connection.upload_directory(src_path, dest_path, cx)
786 }
787
788 pub fn proto_client(&self) -> AnyProtoClient {
789 self.client.clone().into()
790 }
791
792 pub fn host(&self) -> String {
793 self.connection_options.host.clone()
794 }
795
796 pub fn connection_options(&self) -> SshConnectionOptions {
797 self.connection_options.clone()
798 }
799
800 pub fn connection_state(&self) -> ConnectionState {
801 self.state
802 .as_ref()
803 .map(ConnectionState::from)
804 .unwrap_or(ConnectionState::Disconnected)
805 }
806
807 pub fn is_disconnected(&self) -> bool {
808 self.connection_state() == ConnectionState::Disconnected
809 }
810
811 pub fn path_style(&self) -> PathStyle {
812 self.path_style
813 }
814
815 #[cfg(any(test, feature = "test-support"))]
816 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
817 let opts = self.connection_options();
818 client_cx.spawn(async move |cx| {
819 let connection = cx
820 .update_global(|c: &mut ConnectionPool, _| {
821 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
822 c.clone()
823 } else {
824 panic!("missing test connection")
825 }
826 })
827 .unwrap()
828 .await
829 .unwrap();
830
831 connection.simulate_disconnect(cx);
832 })
833 }
834
835 #[cfg(any(test, feature = "test-support"))]
836 pub fn fake_server(
837 client_cx: &mut gpui::TestAppContext,
838 server_cx: &mut gpui::TestAppContext,
839 ) -> (SshConnectionOptions, AnyProtoClient) {
840 let port = client_cx
841 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
842 let opts = SshConnectionOptions {
843 host: "<fake>".to_string(),
844 port: Some(port),
845 ..Default::default()
846 };
847 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
848 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
849 let server_client =
850 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
851 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
852 connection_options: opts.clone(),
853 server_cx: fake::SendableCx::new(server_cx),
854 server_channel: server_client.clone(),
855 });
856
857 client_cx.update(|cx| {
858 cx.update_default_global(|c: &mut ConnectionPool, cx| {
859 c.connections.insert(
860 opts.clone(),
861 ConnectionPoolEntry::Connecting(
862 cx.background_spawn({
863 let connection = connection.clone();
864 async move { Ok(connection.clone()) }
865 })
866 .shared(),
867 ),
868 );
869 })
870 });
871
872 (opts, server_client.into())
873 }
874
875 #[cfg(any(test, feature = "test-support"))]
876 pub async fn fake_client(
877 opts: SshConnectionOptions,
878 client_cx: &mut gpui::TestAppContext,
879 ) -> Entity<Self> {
880 let (_tx, rx) = oneshot::channel();
881 client_cx
882 .update(|cx| {
883 Self::ssh(
884 ConnectionIdentifier::setup(),
885 opts,
886 rx,
887 Arc::new(fake::Delegate),
888 cx,
889 )
890 })
891 .await
892 .unwrap()
893 .unwrap()
894 }
895}
896
/// An entry in the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt in flight; callers share its result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}
901
/// App-global registry of SSH connections, so that clients with identical
/// [`SshConnectionOptions`] share a single connection (or connection attempt).
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
908
909impl ConnectionPool {
910 pub fn connect(
911 &mut self,
912 opts: SshConnectionOptions,
913 delegate: &Arc<dyn RemoteClientDelegate>,
914 cx: &mut App,
915 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
916 let connection = self.connections.get(&opts);
917 match connection {
918 Some(ConnectionPoolEntry::Connecting(task)) => {
919 let delegate = delegate.clone();
920 cx.spawn(async move |cx| {
921 delegate.set_status(Some("Waiting for existing connection attempt"), cx);
922 })
923 .detach();
924 return task.clone();
925 }
926 Some(ConnectionPoolEntry::Connected(ssh)) => {
927 if let Some(ssh) = ssh.upgrade()
928 && !ssh.has_been_killed()
929 {
930 return Task::ready(Ok(ssh)).shared();
931 }
932 self.connections.remove(&opts);
933 }
934 None => {}
935 }
936
937 let task = cx
938 .spawn({
939 let opts = opts.clone();
940 let delegate = delegate.clone();
941 async move |cx| {
942 let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
943 .await
944 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
945
946 cx.update_global(|pool: &mut Self, _| {
947 debug_assert!(matches!(
948 pool.connections.get(&opts),
949 Some(ConnectionPoolEntry::Connecting(_))
950 ));
951 match connection {
952 Ok(connection) => {
953 pool.connections.insert(
954 opts.clone(),
955 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
956 );
957 Ok(connection)
958 }
959 Err(error) => {
960 pool.connections.remove(&opts);
961 Err(Arc::new(error))
962 }
963 }
964 })?
965 }
966 })
967 .shared();
968
969 self.connections
970 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
971 task
972 }
973}
974
/// A transport to a remote host capable of running the remote-server proxy.
#[async_trait(?Send)]
pub(crate) trait RemoteConnection: Send + Sync {
    /// Launches the proxy on the remote host and wires it to the given channels.
    /// `reconnect` is `false` for the first connection and `true` when re-joining after a
    /// reconnect. Activity on the connection is signalled through `connection_activity_tx`.
    /// The task resolves to the proxy's exit code (interpreted via
    /// `ProxyLaunchError::from_exit_code`).
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the underlying connection process.
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has been called; killed connections are not reused by the pool.
    fn has_been_killed(&self) -> bool;
    /// Builds a command that runs `program` (or presumably a shell when `None` — TODO
    /// confirm) on the remote host, optionally forwarding a port.
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    fn connection_options(&self) -> SshConnectionOptions;
    fn path_style(&self) -> PathStyle;
    fn shell(&self) -> String;

    /// Test hook used by `RemoteClient::simulate_disconnect`; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1010
// Pending requests keyed by message id. Each sender delivers the response envelope together
// with a oneshot the requester completes (or drops) once it has consumed the response; the
// message loop waits on that before reading the next incoming envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1012
/// Proto client that exchanges [`Envelope`]s over a pair of mpsc channels, with support
/// for swapping the channels on reconnect and replaying unacknowledged messages.
struct ChannelClient {
    // Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    // Behind a mutex so `reconnect` can swap in a fresh channel.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Outgoing envelopes not yet acknowledged: trimmed on `ack_id` and re-sent on flush/resync
    // (presumably filled by the buffered send path — TODO confirm).
    buffer: Mutex<VecDeque<Envelope>>,
    response_channels: ResponseChannels,
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Id of the most recently received (non-flush) envelope.
    max_received: AtomicU32,
    // Prefix for log messages.
    name: &'static str,
    // Incoming-message loop; replaced on reconnect.
    task: Mutex<Task<Result<()>>>,
}
1023
1024impl ChannelClient {
1025 fn new(
1026 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1027 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1028 cx: &App,
1029 name: &'static str,
1030 ) -> Arc<Self> {
1031 Arc::new_cyclic(|this| Self {
1032 outgoing_tx: Mutex::new(outgoing_tx),
1033 next_message_id: AtomicU32::new(0),
1034 max_received: AtomicU32::new(0),
1035 response_channels: ResponseChannels::default(),
1036 message_handlers: Default::default(),
1037 buffer: Mutex::new(VecDeque::new()),
1038 name,
1039 task: Mutex::new(Self::start_handling_messages(
1040 this.clone(),
1041 incoming_rx,
1042 &cx.to_async(),
1043 )),
1044 })
1045 }
1046
1047 fn start_handling_messages(
1048 this: Weak<Self>,
1049 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1050 cx: &AsyncApp,
1051 ) -> Task<Result<()>> {
1052 cx.spawn(async move |cx| {
1053 let peer_id = PeerId { owner_id: 0, id: 0 };
1054 while let Some(incoming) = incoming_rx.next().await {
1055 let Some(this) = this.upgrade() else {
1056 return anyhow::Ok(());
1057 };
1058 if let Some(ack_id) = incoming.ack_id {
1059 let mut buffer = this.buffer.lock();
1060 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1061 buffer.pop_front();
1062 }
1063 }
1064 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
1065 {
1066 log::debug!(
1067 "{}:ssh message received. name:FlushBufferedMessages",
1068 this.name
1069 );
1070 {
1071 let buffer = this.buffer.lock();
1072 for envelope in buffer.iter() {
1073 this.outgoing_tx
1074 .lock()
1075 .unbounded_send(envelope.clone())
1076 .ok();
1077 }
1078 }
1079 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
1080 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1081 this.outgoing_tx.lock().unbounded_send(envelope).ok();
1082 continue;
1083 }
1084
1085 this.max_received.store(incoming.id, SeqCst);
1086
1087 if let Some(request_id) = incoming.responding_to {
1088 let request_id = MessageId(request_id);
1089 let sender = this.response_channels.lock().remove(&request_id);
1090 if let Some(sender) = sender {
1091 let (tx, rx) = oneshot::channel();
1092 if incoming.payload.is_some() {
1093 sender.send((incoming, tx)).ok();
1094 }
1095 rx.await.ok();
1096 }
1097 } else if let Some(envelope) =
1098 build_typed_envelope(peer_id, Instant::now(), incoming)
1099 {
1100 let type_name = envelope.payload_type_name();
1101 let message_id = envelope.message_id();
1102 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1103 &this.message_handlers,
1104 envelope,
1105 this.clone().into(),
1106 cx.clone(),
1107 ) {
1108 log::debug!("{}:ssh message received. name:{type_name}", this.name);
1109 cx.foreground_executor()
1110 .spawn(async move {
1111 match future.await {
1112 Ok(_) => {
1113 log::debug!(
1114 "{}:ssh message handled. name:{type_name}",
1115 this.name
1116 );
1117 }
1118 Err(error) => {
1119 log::error!(
1120 "{}:error handling message. type:{}, error:{}",
1121 this.name,
1122 type_name,
1123 format!("{error:#}").lines().fold(
1124 String::new(),
1125 |mut message, line| {
1126 if !message.is_empty() {
1127 message.push(' ');
1128 }
1129 message.push_str(line);
1130 message
1131 }
1132 )
1133 );
1134 }
1135 }
1136 })
1137 .detach()
1138 } else {
1139 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
1140 if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
1141 message_id,
1142 anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
1143 ) {
1144 log::error!(
1145 "{}:error sending error response for {type_name}:{e:#}",
1146 this.name
1147 );
1148 }
1149 }
1150 }
1151 }
1152 anyhow::Ok(())
1153 })
1154 }
1155
1156 fn reconnect(
1157 self: &Arc<Self>,
1158 incoming_rx: UnboundedReceiver<Envelope>,
1159 outgoing_tx: UnboundedSender<Envelope>,
1160 cx: &AsyncApp,
1161 ) {
1162 *self.outgoing_tx.lock() = outgoing_tx;
1163 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1164 }
1165
    /// Sends `payload` as a request via the buffered send path and resolves to its
    /// typed response.
    fn request<T: RequestMessage>(
        &self,
        payload: T,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        self.request_internal(payload, true)
    }
1172
1173 fn request_internal<T: RequestMessage>(
1174 &self,
1175 payload: T,
1176 use_buffer: bool,
1177 ) -> impl 'static + Future<Output = Result<T::Response>> {
1178 log::debug!("ssh request start. name:{}", T::NAME);
1179 let response =
1180 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
1181 async move {
1182 let response = response.await?;
1183 log::debug!("ssh request finish. name:{}", T::NAME);
1184 T::Response::from_envelope(response).context("received a response of the wrong type")
1185 }
1186 }
1187
1188 async fn resync(&self, timeout: Duration) -> Result<()> {
1189 smol::future::or(
1190 async {
1191 self.request_internal(proto::FlushBufferedMessages {}, false)
1192 .await?;
1193
1194 for envelope in self.buffer.lock().iter() {
1195 self.outgoing_tx
1196 .lock()
1197 .unbounded_send(envelope.clone())
1198 .ok();
1199 }
1200 Ok(())
1201 },
1202 async {
1203 smol::Timer::after(timeout).await;
1204 anyhow::bail!("Timed out resyncing remote client")
1205 },
1206 )
1207 .await
1208 }
1209
1210 async fn ping(&self, timeout: Duration) -> Result<()> {
1211 smol::future::or(
1212 async {
1213 self.request(proto::Ping {}).await?;
1214 Ok(())
1215 },
1216 async {
1217 smol::Timer::after(timeout).await;
1218 anyhow::bail!("Timed out pinging remote client")
1219 },
1220 )
1221 .await
1222 }
1223
1224 fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1225 log::debug!("ssh send name:{}", T::NAME);
1226 self.send_dynamic(payload.into_envelope(0, None, None))
1227 }
1228
1229 fn request_dynamic(
1230 &self,
1231 mut envelope: proto::Envelope,
1232 type_name: &'static str,
1233 use_buffer: bool,
1234 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1235 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1236 let (tx, rx) = oneshot::channel();
1237 let mut response_channels_lock = self.response_channels.lock();
1238 response_channels_lock.insert(MessageId(envelope.id), tx);
1239 drop(response_channels_lock);
1240
1241 let result = if use_buffer {
1242 self.send_buffered(envelope)
1243 } else {
1244 self.send_unbuffered(envelope)
1245 };
1246 async move {
1247 if let Err(error) = &result {
1248 log::error!("failed to send message: {error}");
1249 anyhow::bail!("failed to send message: {error}");
1250 }
1251
1252 let response = rx.await.context("connection lost")?.0;
1253 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1254 return Err(RpcError::from_proto(error, type_name));
1255 }
1256 Ok(response)
1257 }
1258 }
1259
1260 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1261 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1262 self.send_buffered(envelope)
1263 }
1264
1265 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1266 envelope.ack_id = Some(self.max_received.load(SeqCst));
1267 self.buffer.lock().push_back(envelope.clone());
1268 // ignore errors on send (happen while we're reconnecting)
1269 // assume that the global "disconnected" overlay is sufficient.
1270 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1271 Ok(())
1272 }
1273
1274 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1275 envelope.ack_id = Some(self.max_received.load(SeqCst));
1276 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1277 Ok(())
1278 }
1279}
1280
1281impl ProtoClient for ChannelClient {
1282 fn request(
1283 &self,
1284 envelope: proto::Envelope,
1285 request_type: &'static str,
1286 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1287 self.request_dynamic(envelope, request_type, true).boxed()
1288 }
1289
1290 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1291 self.send_dynamic(envelope)
1292 }
1293
1294 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1295 self.send_dynamic(envelope)
1296 }
1297
1298 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1299 &self.message_handlers
1300 }
1301
1302 fn is_via_collab(&self) -> bool {
1303 false
1304 }
1305}
1306
/// Test doubles for the remote client.
///
/// The "server" side is an in-process [`ChannelClient`], so there is no real SSH connection.
/// Only compiled for tests or with the `test-support` feature.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
    use crate::{SshConnectionOptions, remote_client::CommandTemplate};
    use anyhow::Result;
    use async_trait::async_trait;
    use collections::HashMap;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use std::{path::PathBuf, sync::Arc};
    use util::paths::{PathStyle, RemotePathBuf};

    /// A [`RemoteConnection`] whose remote end is the in-process `server_channel`.
    pub(super) struct FakeRemoteConnection {
        /// Returned verbatim from `connection_options`.
        pub(super) connection_options: SshConnectionOptions,
        /// The server-side client that `start_proxy` / `simulate_disconnect` rewire.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Async context used when rewiring `server_channel`.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an [`AsyncApp`] be stored in a `Send + Sync` struct (see the
    /// `unsafe impl`s below).
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        // (the argument is only a witness; the stored context is what gets cloned and returned).
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// No-op: there is no process to kill.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// Always `false`; `kill` never changes any state.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// Builds a template whose program is `ssh`. Its args are `program` (default `"sh"`)
        /// followed by `args`, and its env is a copy of `env`. The last two parameters are
        /// ignored.
        fn build_command(
            &self,
            program: Option<String>,
            args: &[String],
            env: &HashMap<String, String>,
            _: Option<String>,
            _: Option<(u16, String, u16)>,
        ) -> Result<CommandTemplate> {
            let ssh_program = program.unwrap_or_else(|| "sh".to_string());
            let mut ssh_args = Vec::new();
            ssh_args.push(ssh_program);
            ssh_args.extend(args.iter().cloned());
            Ok(CommandTemplate {
                program: "ssh".into(),
                args: ssh_args,
                env: env.clone(),
            })
        }

        /// Panics if called. Presumably tests using this fake never upload; confirm
        /// before relying on it.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Rewires `server_channel` onto brand-new channels whose other halves are dropped
        /// immediately (bound to `_`). This leaves the server side attached to a dead
        /// transport.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Connects the client's channels to `server_channel` through a background
        /// forwarding task.
        ///
        /// - Server-to-client envelopes also signal `connection_activity_tx`.
        /// - `select_biased!` polls the server-to-client branch first.
        /// - The task resolves to `Ok(1)` as soon as either side's stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn RemoteClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            // Activity pings are best-effort; a full channel is ignored.
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        /// The local platform's path style.
        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }

        fn shell(&self) -> String {
            "sh".to_owned()
        }
    }

    /// Delegate for tests. Every prompting/downloading hook panics if called, and
    /// `set_status` is a no-op.
    pub(super) struct Delegate;

    impl RemoteClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: RemotePlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: RemotePlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}