1use crate::{
2 SshConnectionOptions, protocol::MessageId, proxy::ProxyLaunchError,
3 transport::ssh::SshRemoteConnection,
4};
5use anyhow::{Context as _, Result, anyhow};
6use async_trait::async_trait;
7use collections::HashMap;
8use futures::{
9 Future, FutureExt as _, StreamExt as _,
10 channel::{
11 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
12 oneshot,
13 },
14 future::{BoxFuture, Shared},
15 select, select_biased,
16};
17use gpui::{
18 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
19 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
20};
21use parking_lot::Mutex;
22
23use release_channel::ReleaseChannel;
24use rpc::{
25 AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
26 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
27};
28use std::{
29 collections::VecDeque,
30 fmt,
31 ops::ControlFlow,
32 path::PathBuf,
33 sync::{
34 Arc, Weak,
35 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
36 },
37 time::{Duration, Instant},
38};
39use util::{
40 ResultExt,
41 paths::{PathStyle, RemotePathBuf},
42};
43
/// Identifies the platform of a remote machine as an OS / architecture pair.
///
/// Values of this type are handed to the download-related methods of
/// [`RemoteClientDelegate`].
#[derive(Copy, Clone, Debug)]
pub struct RemotePlatform {
    /// Operating-system identifier. The exact vocabulary is chosen by whoever
    /// constructs this value and is not visible in this file.
    pub os: &'static str,
    /// CPU-architecture identifier. The exact vocabulary is chosen by whoever
    /// constructs this value and is not visible in this file.
    pub arch: &'static str,
}
49
/// A description of a command: the program, its arguments and its environment.
///
/// This is the value produced by `RemoteClient::build_command` and
/// `RemoteConnection::build_command`.
#[derive(Clone, Debug)]
pub struct CommandTemplate {
    /// The program to run.
    pub program: String,
    /// Arguments passed to `program`.
    pub args: Vec<String>,
    /// Environment variables, keyed by variable name.
    pub env: HashMap<String, String>,
}
56
/// Callbacks the remote client uses to interact with its host application
/// while a connection is being established or re-established.
///
/// Implementations must be `Send + Sync` because the delegate is shared via
/// `Arc<dyn RemoteClientDelegate>`.
pub trait RemoteClientDelegate: Send + Sync {
    /// Asks for a password using `prompt`; the answer is expected to be
    /// delivered through `tx`.
    // NOTE(review): how the prompt is presented is up to the implementor and
    // is not visible in this file.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for a server binary matching `platform`,
    /// `release_channel` and (optionally) `version`.
    // NOTE(review): the meaning of the two returned strings, and of `None`,
    // is not visible in this file — confirm against the implementors.
    fn get_download_params(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;
    /// Downloads a server binary for `platform` / `release_channel` /
    /// `version` and resolves to a path (presumably the local location of the
    /// downloaded binary — confirm against the implementors).
    fn download_server_binary_locally(
        &self,
        platform: RemotePlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message, or `None`.
    ///
    /// `ConnectionPool::connect` calls this while waiting for an existing
    /// connection attempt.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
75
/// Number of missed heartbeats at which `handle_heartbeat_result` starts a
/// reconnect.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Delay between heartbeat rounds in `RemoteClient::heartbeat`.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a single ping; also used for the initial ping in
/// `RemoteClient::ssh` and for the resync after a reconnect.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Number of reconnect attempts after which `RemoteClient::reconnect` gives up
/// and moves to `State::ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
81
/// Internal connection state machine of a [`RemoteClient`].
///
/// The variants that represent a live connection own the tasks that keep it
/// running, so replacing such a state drops those tasks.
enum State {
    /// Initial state, set when the client entity is created in
    /// `RemoteClient::ssh`.
    Connecting,
    /// The connection is established.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        /// Task created by `RemoteClient::monitor`.
        multiplex_task: Task<Result<()>>,
        /// Task created by `RemoteClient::heartbeat`.
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but at least one heartbeat has been missed.
    HeartbeatMissed {
        /// Number of consecutive misses recorded by `State::heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; another attempt may follow.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn RemoteClientDelegate>,

        /// Why the attempt failed.
        error: anyhow::Error,
        /// Number of attempts made so far.
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts were made; no further
    /// attempts follow.
    ReconnectExhausted,
    /// The proxy reported `ProxyLaunchError::ServerNotRunning`.
    ServerNotRunning,
}
111
112impl fmt::Display for State {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match self {
115 Self::Connecting => write!(f, "connecting"),
116 Self::Connected { .. } => write!(f, "connected"),
117 Self::Reconnecting => write!(f, "reconnecting"),
118 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
119 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
120 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
121 Self::ServerNotRunning { .. } => write!(f, "server not running"),
122 }
123 }
124}
125
126impl State {
127 fn remote_connection(&self) -> Option<Arc<dyn RemoteConnection>> {
128 match self {
129 Self::Connected { ssh_connection, .. } => Some(ssh_connection.clone()),
130 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.clone()),
131 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.clone()),
132 _ => None,
133 }
134 }
135
136 fn can_reconnect(&self) -> bool {
137 match self {
138 Self::Connected { .. }
139 | Self::HeartbeatMissed { .. }
140 | Self::ReconnectFailed { .. } => true,
141 State::Connecting
142 | State::Reconnecting
143 | State::ReconnectExhausted
144 | State::ServerNotRunning => false,
145 }
146 }
147
148 fn is_reconnect_failed(&self) -> bool {
149 matches!(self, Self::ReconnectFailed { .. })
150 }
151
152 fn is_reconnect_exhausted(&self) -> bool {
153 matches!(self, Self::ReconnectExhausted { .. })
154 }
155
156 fn is_server_not_running(&self) -> bool {
157 matches!(self, Self::ServerNotRunning)
158 }
159
160 fn is_reconnecting(&self) -> bool {
161 matches!(self, Self::Reconnecting { .. })
162 }
163
164 fn heartbeat_recovered(self) -> Self {
165 match self {
166 Self::HeartbeatMissed {
167 ssh_connection,
168 delegate,
169 multiplex_task,
170 heartbeat_task,
171 ..
172 } => Self::Connected {
173 ssh_connection,
174 delegate,
175 multiplex_task,
176 heartbeat_task,
177 },
178 _ => self,
179 }
180 }
181
182 fn heartbeat_missed(self) -> Self {
183 match self {
184 Self::Connected {
185 ssh_connection,
186 delegate,
187 multiplex_task,
188 heartbeat_task,
189 } => Self::HeartbeatMissed {
190 missed_heartbeats: 1,
191 ssh_connection,
192 delegate,
193 multiplex_task,
194 heartbeat_task,
195 },
196 Self::HeartbeatMissed {
197 missed_heartbeats,
198 ssh_connection,
199 delegate,
200 multiplex_task,
201 heartbeat_task,
202 } => Self::HeartbeatMissed {
203 missed_heartbeats: missed_heartbeats + 1,
204 ssh_connection,
205 delegate,
206 multiplex_task,
207 heartbeat_task,
208 },
209 _ => self,
210 }
211 }
212}
213
/// The state of the ssh connection.
///
/// This is the public, data-free view of the internal `State`; see the
/// `From<&State>` implementation for the mapping.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is established.
    Connected,
    /// The connection is established but at least one heartbeat was missed.
    HeartbeatMissed,
    /// A reconnect is in progress (also reported between failed attempts).
    Reconnecting,
    /// Reconnect attempts are exhausted, the server is not running, or the
    /// client holds no state at all.
    Disconnected,
}
223
224impl From<&State> for ConnectionState {
225 fn from(value: &State) -> Self {
226 match value {
227 State::Connecting => Self::Connecting,
228 State::Connected { .. } => Self::Connected,
229 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
230 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
231 State::ReconnectExhausted => Self::Disconnected,
232 State::ServerNotRunning => Self::Disconnected,
233 }
234 }
235}
236
/// A client for a remote server reached through a `RemoteConnection`.
///
/// It owns the message channel to the server and the connection state
/// machine, including heartbeat and reconnect handling.
pub struct RemoteClient {
    /// Message channel to the server; kept across reconnects and re-attached
    /// to the new connection via `ChannelClient::reconnect`.
    client: Arc<ChannelClient>,
    /// Identifier passed to `RemoteConnection::start_proxy` on connect and on
    /// every reconnect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection when the client was created.
    path_style: PathStyle,
    /// Current state. `None` after `shutdown_processes` has taken it, and
    /// briefly while a transition is being computed.
    state: Option<State>,
}
244
/// Events emitted by [`RemoteClient`].
#[derive(Debug)]
pub enum RemoteClientEvent {
    /// Emitted by `set_state` when the client enters `ReconnectExhausted` or
    /// `ServerNotRunning`.
    Disconnected,
}
249
// Lets `RemoteClient` emit `RemoteClientEvent`s through its entity context.
impl EventEmitter<RemoteClientEvent> for RemoteClient {}
251
// Identifies the socket on the remote server so that reconnects
// can re-join the same project.
pub enum ConnectionIdentifier {
    /// An identifier drawn from a process-wide counter (see `setup()`).
    Setup(u64),
    /// An identifier derived from a caller-supplied id
    /// (presumably a workspace id, going by the name — confirm with callers).
    Workspace(i64),
}
258
// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
260
261impl ConnectionIdentifier {
262 pub fn setup() -> Self {
263 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
264 }
265
266 // This string gets used in a socket name, and so must be relatively short.
267 // The total length of:
268 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
269 // Must be less than about 100 characters
270 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
271 // So our strings should be at most 20 characters or so.
272 fn to_string(&self, cx: &App) -> String {
273 let identifier_prefix = match ReleaseChannel::global(cx) {
274 ReleaseChannel::Stable => "".to_string(),
275 release_channel => format!("{}-", release_channel.dev_name()),
276 };
277 match self {
278 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
279 Self::Workspace(workspace_id) => {
280 format!("{identifier_prefix}workspace-{workspace_id}",)
281 }
282 }
283 }
284}
285
286impl RemoteClient {
    /// Connects to a remote server over ssh and creates the client entity.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(client))` once the connection is up and the first ping
    ///   succeeded,
    /// - `Ok(None)` if `cancellation` resolves first (the `_` pattern below
    ///   matches any outcome of the receiver, not only a sent value),
    /// - `Err(_)` if connecting, creating the entity, or the first ping fails.
    ///
    /// `unique_identifier` is turned into a string and passed to
    /// `RemoteConnection::start_proxy`; `delegate` is passed to the connection
    /// pool and the proxy.
    pub fn ssh(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Channels between the `ChannelClient` and the proxy I/O task.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Consumed by the heartbeat task to observe connection activity.
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Obtain a connection from the pool (the pool decides whether
                // an existing one can be used).
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Some(State::Connecting),
                })?;

                // `false`: this is the initial connect, not a reconnect.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                // Verify the server answers before reporting success.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    this.state = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Race the connection attempt against cancellation.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() => result
            }
        })
    }
361
362 pub fn proto_client_from_channels(
363 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
364 outgoing_tx: mpsc::UnboundedSender<Envelope>,
365 cx: &App,
366 name: &'static str,
367 ) -> AnyProtoClient {
368 ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
369 }
370
    /// Takes the client's state and, if it was `Connected`, returns a future
    /// that shuts the connection down.
    ///
    /// Returns `None` if there is no state or the state is not `Connected`.
    /// In both cases the state has been taken: `self.state` is `None`
    /// afterwards and the previous state is dropped.
    ///
    /// The returned future optionally sends `shutdown_request`, waits 50ms,
    /// and then drops the connection's tasks and handles in a fixed order.
    pub fn shutdown_processes<T: RequestMessage>(
        &mut self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        let state = self.state.take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
409
    /// Starts an asynchronous attempt to re-establish the connection.
    ///
    /// Returns an error if the current state does not allow reconnecting (see
    /// `State::can_reconnect`) or if no state is set. Otherwise it returns
    /// `Ok(())` immediately; the attempt itself runs in spawned tasks.
    ///
    /// Once more than `MAX_RECONNECT_ATTEMPTS` attempts have been made, the
    /// state becomes `ReconnectExhausted` and no further attempt is started.
    /// A failed attempt triggers the next one from the completion task at the
    /// end of this function.
    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
        let can_reconnect = self
            .state
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            let error = if let Some(state) = self.state.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            anyhow::bail!(error);
        }

        // `can_reconnect` was true, so a state is present.
        let state = self.state.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Coming from a live connection: stop its tasks and start
                // counting attempts from zero.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            // Excluded by the `can_reconnect` check above.
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        // Performs the attempt and resolves to the state the client should
        // move to: `Connected` on success, `ReconnectFailed` otherwise.
        let reconnect_task = cx.spawn(async move |this, cx| {
            // Early-returns a `ReconnectFailed` state from the enclosing task.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            // Kill the old connection before asking the pool for a new one.
            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                // `true`: tell the proxy this is a reconnect.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    // On failure the old connection handle is kept in the
                    // `ReconnectFailed` state.
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
            // Point the existing channel client at the new channels.
            client.reconnect(incoming_rx, outgoing_tx, cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
            }
        });

        // Applies the outcome of the attempt and, if it failed, starts the
        // next attempt.
        cx.spawn(async move |this, cx| {
            let new_state = reconnect_task.await;
            this.update(cx, |this, cx| {
                this.try_set_state(cx, |old_state| {
                    // Only apply the result if nothing else changed the state
                    // while the attempt was running.
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
587
    /// Spawns the heartbeat loop for a connection.
    ///
    /// Every `HEARTBEAT_INTERVAL` without observed connection activity, the
    /// loop pings the server (bounded by `HEARTBEAT_TIMEOUT`) and reports
    /// changes in the missed-heartbeat count via `handle_heartbeat_result`.
    /// Any message on `connection_activity_rx` counts as a sign of life.
    ///
    /// The task ends with `Ok(())` when the activity channel is closed or when
    /// `handle_heartbeat_result` asks to stop, and with an error if the client
    /// entity can no longer be updated.
    fn heartbeat(
        this: WeakEntity<Self>,
        mut connection_activity_rx: mpsc::Receiver<()>,
        cx: &mut AsyncApp,
    ) -> Task<Result<()>> {
        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
        };

        cx.spawn(async move |cx| {
            let mut missed_heartbeats = 0;

            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            loop {
                // Biased: connection activity is checked before the timer.
                select_biased! {
                    result = connection_activity_rx.next().fuse() => {
                        if result.is_none() {
                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
                            return Ok(());
                        }

                        // Activity after missed heartbeats counts as recovery.
                        // The returned `ControlFlow` is ignored here; with a
                        // count of 0 `handle_heartbeat_result` never asks to
                        // stop after a reconnect.
                        if missed_heartbeats != 0 {
                            missed_heartbeats = 0;
                            let _ =this.update(cx, |this, cx| {
                                this.handle_heartbeat_result(missed_heartbeats, cx)
                            })?;
                        }
                    }
                    _ = keepalive_timer => {
                        log::debug!("Sending heartbeat to server...");

                        // Activity arriving while the ping is in flight is
                        // treated as a successful heartbeat.
                        let result = select_biased! {
                            _ = connection_activity_rx.next().fuse() => {
                                Ok(())
                            }
                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
                                ping_result
                            }
                        };

                        if result.is_err() {
                            missed_heartbeats += 1;
                            log::warn!(
                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
                                HEARTBEAT_TIMEOUT,
                                missed_heartbeats,
                                MAX_MISSED_HEARTBEATS
                            );
                        } else if missed_heartbeats != 0 {
                            missed_heartbeats = 0;
                        } else {
                            // Healthy before and healthy now: nothing to report.
                            // Note that `continue` skips the timer reset at the
                            // bottom of the loop.
                            continue;
                        }

                        let result = this.update(cx, |this, cx| {
                            this.handle_heartbeat_result(missed_heartbeats, cx)
                        })?;
                        if result.is_break() {
                            return Ok(());
                        }
                    }
                }

                // Restart the interval after activity or a reported change.
                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
            }
        })
    }
657
658 fn handle_heartbeat_result(
659 &mut self,
660 missed_heartbeats: usize,
661 cx: &mut Context<Self>,
662 ) -> ControlFlow<()> {
663 let state = self.state.take().unwrap();
664 let next_state = if missed_heartbeats > 0 {
665 state.heartbeat_missed()
666 } else {
667 state.heartbeat_recovered()
668 };
669
670 self.set_state(next_state, cx);
671
672 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
673 log::error!(
674 "Missed last {} heartbeats. Reconnecting...",
675 missed_heartbeats
676 );
677
678 self.reconnect(cx)
679 .context("failed to start reconnect process after missing heartbeats")
680 .log_err();
681 ControlFlow::Break(())
682 } else {
683 ControlFlow::Continue(())
684 }
685 }
686
687 fn monitor(
688 this: WeakEntity<Self>,
689 io_task: Task<Result<i32>>,
690 cx: &AsyncApp,
691 ) -> Task<Result<()>> {
692 cx.spawn(async move |cx| {
693 let result = io_task.await;
694
695 match result {
696 Ok(exit_code) => {
697 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
698 match error {
699 ProxyLaunchError::ServerNotRunning => {
700 log::error!("failed to reconnect because server is not running");
701 this.update(cx, |this, cx| {
702 this.set_state(State::ServerNotRunning, cx);
703 })?;
704 }
705 }
706 } else if exit_code > 0 {
707 log::error!("proxy process terminated unexpectedly");
708 this.update(cx, |this, cx| {
709 this.reconnect(cx).ok();
710 })?;
711 }
712 }
713 Err(error) => {
714 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
715 this.update(cx, |this, cx| {
716 this.reconnect(cx).ok();
717 })?;
718 }
719 }
720
721 Ok(())
722 })
723 }
724
725 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
726 self.state.as_ref().is_some_and(check)
727 }
728
729 fn try_set_state(&mut self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
730 let new_state = self.state.as_ref().and_then(map);
731 if let Some(new_state) = new_state {
732 self.state.replace(new_state);
733 cx.notify();
734 }
735 }
736
737 fn set_state(&mut self, state: State, cx: &mut Context<Self>) {
738 log::info!("setting state to '{}'", &state);
739
740 let is_reconnect_exhausted = state.is_reconnect_exhausted();
741 let is_server_not_running = state.is_server_not_running();
742 self.state.replace(state);
743
744 if is_reconnect_exhausted || is_server_not_running {
745 cx.emit(RemoteClientEvent::Disconnected);
746 }
747 cx.notify();
748 }
749
750 pub fn shell(&self) -> Option<String> {
751 Some(self.state.as_ref()?.remote_connection()?.shell())
752 }
753
754 pub fn build_command(
755 &self,
756 program: Option<String>,
757 args: &[String],
758 env: &HashMap<String, String>,
759 working_dir: Option<String>,
760 activation_script: Option<String>,
761 port_forward: Option<(u16, String, u16)>,
762 ) -> Result<CommandTemplate> {
763 let Some(connection) = self
764 .state
765 .as_ref()
766 .and_then(|state| state.remote_connection())
767 else {
768 return Err(anyhow!("no connection"));
769 };
770 connection.build_command(
771 program,
772 args,
773 env,
774 working_dir,
775 activation_script,
776 port_forward,
777 )
778 }
779
780 pub fn upload_directory(
781 &self,
782 src_path: PathBuf,
783 dest_path: RemotePathBuf,
784 cx: &App,
785 ) -> Task<Result<()>> {
786 let Some(connection) = self
787 .state
788 .as_ref()
789 .and_then(|state| state.remote_connection())
790 else {
791 return Task::ready(Err(anyhow!("no ssh connection")));
792 };
793 connection.upload_directory(src_path, dest_path, cx)
794 }
795
796 pub fn proto_client(&self) -> AnyProtoClient {
797 self.client.clone().into()
798 }
799
800 pub fn host(&self) -> String {
801 self.connection_options.host.clone()
802 }
803
804 pub fn connection_options(&self) -> SshConnectionOptions {
805 self.connection_options.clone()
806 }
807
808 pub fn connection_state(&self) -> ConnectionState {
809 self.state
810 .as_ref()
811 .map(ConnectionState::from)
812 .unwrap_or(ConnectionState::Disconnected)
813 }
814
815 pub fn is_disconnected(&self) -> bool {
816 self.connection_state() == ConnectionState::Disconnected
817 }
818
819 pub fn path_style(&self) -> PathStyle {
820 self.path_style
821 }
822
823 #[cfg(any(test, feature = "test-support"))]
824 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
825 let opts = self.connection_options();
826 client_cx.spawn(async move |cx| {
827 let connection = cx
828 .update_global(|c: &mut ConnectionPool, _| {
829 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
830 c.clone()
831 } else {
832 panic!("missing test connection")
833 }
834 })
835 .unwrap()
836 .await
837 .unwrap();
838
839 connection.simulate_disconnect(cx);
840 })
841 }
842
843 #[cfg(any(test, feature = "test-support"))]
844 pub fn fake_server(
845 client_cx: &mut gpui::TestAppContext,
846 server_cx: &mut gpui::TestAppContext,
847 ) -> (SshConnectionOptions, AnyProtoClient) {
848 let port = client_cx
849 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
850 let opts = SshConnectionOptions {
851 host: "<fake>".to_string(),
852 port: Some(port),
853 ..Default::default()
854 };
855 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
856 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
857 let server_client =
858 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
859 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
860 connection_options: opts.clone(),
861 server_cx: fake::SendableCx::new(server_cx),
862 server_channel: server_client.clone(),
863 });
864
865 client_cx.update(|cx| {
866 cx.update_default_global(|c: &mut ConnectionPool, cx| {
867 c.connections.insert(
868 opts.clone(),
869 ConnectionPoolEntry::Connecting(
870 cx.background_spawn({
871 let connection = connection.clone();
872 async move { Ok(connection.clone()) }
873 })
874 .shared(),
875 ),
876 );
877 })
878 });
879
880 (opts, server_client.into())
881 }
882
883 #[cfg(any(test, feature = "test-support"))]
884 pub async fn fake_client(
885 opts: SshConnectionOptions,
886 client_cx: &mut gpui::TestAppContext,
887 ) -> Entity<Self> {
888 let (_tx, rx) = oneshot::channel();
889 client_cx
890 .update(|cx| {
891 Self::ssh(
892 ConnectionIdentifier::setup(),
893 opts,
894 rx,
895 Arc::new(fake::Delegate),
896 cx,
897 )
898 })
899 .await
900 .unwrap()
901 .unwrap()
902 }
903}
904
/// An entry in the [`ConnectionPool`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by
    /// every caller interested in the result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is kept, so the
    /// pool does not keep the connection alive by itself.
    Connected(Weak<dyn RemoteConnection>),
}
909
/// Global registry of remote connections, keyed by connection options, so
/// that callers using the same options can share a connection.
#[derive(Default)]
struct ConnectionPool {
    /// Pending and established connections per set of options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}
914
// Allows the pool to be stored and accessed as a gpui global.
impl Global for ConnectionPool {}
916
impl ConnectionPool {
    /// Returns a shared task resolving to a connection for `opts`.
    ///
    /// - If an attempt for `opts` is already in flight, its task is returned
    ///   and the delegate is told it is waiting for that attempt.
    /// - If a previous connection for `opts` is still alive and has not been
    ///   killed, it is returned right away.
    /// - Otherwise a new `SshRemoteConnection` is started and recorded as
    ///   `Connecting`; when it finishes, the entry becomes `Connected` on
    ///   success and is removed on failure.
    pub fn connect(
        &mut self,
        opts: SshConnectionOptions,
        delegate: &Arc<dyn RemoteClientDelegate>,
        cx: &mut App,
    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
        let connection = self.connections.get(&opts);
        match connection {
            Some(ConnectionPoolEntry::Connecting(task)) => {
                let delegate = delegate.clone();
                cx.spawn(async move |cx| {
                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
                })
                .detach();
                return task.clone();
            }
            Some(ConnectionPoolEntry::Connected(ssh)) => {
                if let Some(ssh) = ssh.upgrade()
                    && !ssh.has_been_killed()
                {
                    return Task::ready(Ok(ssh)).shared();
                }
                // The connection was dropped or killed: forget the entry and
                // fall through to creating a new one.
                self.connections.remove(&opts);
            }
            None => {}
        }

        let task = cx
            .spawn({
                let opts = opts.clone();
                let delegate = delegate.clone();
                async move |cx| {
                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
                        .await
                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);

                    // Record the outcome in the pool.
                    cx.update_global(|pool: &mut Self, _| {
                        debug_assert!(matches!(
                            pool.connections.get(&opts),
                            Some(ConnectionPoolEntry::Connecting(_))
                        ));
                        match connection {
                            Ok(connection) => {
                                pool.connections.insert(
                                    opts.clone(),
                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
                                );
                                Ok(connection)
                            }
                            Err(error) => {
                                pool.connections.remove(&opts);
                                Err(Arc::new(error))
                            }
                        }
                    })?
                }
            })
            .shared();

        // Register the attempt so concurrent callers share it.
        self.connections
            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
        task
    }
}
982
/// A transport to a remote machine, as used by [`RemoteClient`] and the
/// [`ConnectionPool`].
#[async_trait(?Send)]
pub(crate) trait RemoteConnection: Send + Sync {
    /// Starts the proxy that carries messages between the given channels and
    /// the remote server.
    ///
    /// `reconnect` is `false` for the initial connect and `true` when called
    /// from `RemoteClient::reconnect`. The resolved `i32` is interpreted as an
    /// exit code by `RemoteClient::monitor`.
    // NOTE(review): when exactly implementations send on
    // `connection_activity_tx` is not visible here; the heartbeat loop treats
    // any message as a sign of life.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn RemoteClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection. Called before a reconnect attempt.
    async fn kill(&self) -> Result<()>;
    /// Whether the connection has been killed; the pool does not hand out
    /// killed connections.
    fn has_been_killed(&self) -> bool;
    /// Builds a command template from the given program, arguments,
    /// environment and options.
    // NOTE(review): how `working_dir`, `activation_script` and `port_forward`
    // are applied is implementation-defined and not visible in this file.
    fn build_command(
        &self,
        program: Option<String>,
        args: &[String],
        env: &HashMap<String, String>,
        working_dir: Option<String>,
        activation_script: Option<String>,
        port_forward: Option<(u16, String, u16)>,
    ) -> Result<CommandTemplate>;
    /// The options this connection was created with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// The path style of the remote.
    fn path_style(&self) -> PathStyle;
    /// The shell of the remote.
    fn shell(&self) -> String;

    /// Test hook; the default implementation does nothing.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1019
// Pending requests: maps a request's message id to the channel on which its
// response is delivered. The response arrives with a second sender, which the
// message loop waits on before handling the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1021
/// A proto client that exchanges `Envelope`s over a pair of channels.
///
/// The channels can be swapped with `reconnect`, which lets the same client
/// keep working across connections.
struct ChannelClient {
    /// Source of ids for outgoing envelopes.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes; replaced on reconnect.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes kept for re-sending. Entries are removed once acknowledged
    /// and re-sent on a flush request or resync.
    // NOTE(review): the code that fills this buffer is outside this part of
    // the file.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests waiting for a response.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received message (flush requests excluded).
    max_received: AtomicU32,
    /// Label used in log messages.
    name: &'static str,
    /// The task running the incoming-message loop; replaced on reconnect.
    task: Mutex<Task<Result<()>>>,
}
1032
1033impl ChannelClient {
1034 fn new(
1035 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1036 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1037 cx: &App,
1038 name: &'static str,
1039 ) -> Arc<Self> {
1040 Arc::new_cyclic(|this| Self {
1041 outgoing_tx: Mutex::new(outgoing_tx),
1042 next_message_id: AtomicU32::new(0),
1043 max_received: AtomicU32::new(0),
1044 response_channels: ResponseChannels::default(),
1045 message_handlers: Default::default(),
1046 buffer: Mutex::new(VecDeque::new()),
1047 name,
1048 task: Mutex::new(Self::start_handling_messages(
1049 this.clone(),
1050 incoming_rx,
1051 &cx.to_async(),
1052 )),
1053 })
1054 }
1055
    /// Spawns the loop that processes envelopes arriving on `incoming_rx`.
    ///
    /// For each envelope the loop:
    /// 1. drops buffered outgoing envelopes acknowledged by `ack_id`,
    /// 2. answers a `FlushBufferedMessages` request by re-sending the buffer
    ///    followed by an `Ack`,
    /// 3. routes responses to the matching pending request, or
    /// 4. dispatches other messages to the registered handlers, replying with
    ///    an error when no handler is registered.
    ///
    /// The loop ends when the channel closes or the client has been dropped.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncApp,
    ) -> Task<Result<()>> {
        cx.spawn(async move |cx| {
            let peer_id = PeerId { owner_id: 0, id: 0 };
            while let Some(incoming) = incoming_rx.next().await {
                let Some(this) = this.upgrade() else {
                    return anyhow::Ok(());
                };
                // Everything up to and including `ack_id` has been received
                // by the peer and no longer needs to be kept.
                if let Some(ack_id) = incoming.ack_id {
                    let mut buffer = this.buffer.lock();
                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                        buffer.pop_front();
                    }
                }
                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
                {
                    log::debug!(
                        "{}:ssh message received. name:FlushBufferedMessages",
                        this.name
                    );
                    // Re-send everything still buffered; the inner scope
                    // releases the buffer lock before the ack is sent.
                    {
                        let buffer = this.buffer.lock();
                        for envelope in buffer.iter() {
                            this.outgoing_tx
                                .lock()
                                .unbounded_send(envelope.clone())
                                .ok();
                        }
                    }
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    continue;
                }

                this.max_received.store(incoming.id, SeqCst);

                if let Some(request_id) = incoming.responding_to {
                    let request_id = MessageId(request_id);
                    let sender = this.response_channels.lock().remove(&request_id);
                    if let Some(sender) = sender {
                        let (tx, rx) = oneshot::channel();
                        if incoming.payload.is_some() {
                            sender.send((incoming, tx)).ok();
                        }
                        // Wait until the requester signals or drops `tx`
                        // before handling the next message.
                        rx.await.ok();
                    }
                } else if let Some(envelope) =
                    build_typed_envelope(peer_id, Instant::now(), incoming)
                {
                    let type_name = envelope.payload_type_name();
                    let message_id = envelope.message_id();
                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
                        &this.message_handlers,
                        envelope,
                        this.clone().into(),
                        cx.clone(),
                    ) {
                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
                        // Handlers run detached so the loop is not blocked.
                        cx.foreground_executor()
                            .spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!(
                                            "{}:ssh message handled. name:{type_name}",
                                            this.name
                                        );
                                    }
                                    Err(error) => {
                                        // The error text is flattened onto a
                                        // single line for the log.
                                        log::error!(
                                            "{}:error handling message. type:{}, error:{}",
                                            this.name,
                                            type_name,
                                            format!("{error:#}").lines().fold(
                                                String::new(),
                                                |mut message, line| {
                                                    if !message.is_empty() {
                                                        message.push(' ');
                                                    }
                                                    message.push_str(line);
                                                    message
                                                }
                                            )
                                        );
                                    }
                                }
                            })
                            .detach()
                    } else {
                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                        if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
                            message_id,
                            anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
                        ) {
                            log::error!(
                                "{}:error sending error response for {type_name}:{e:#}",
                                this.name
                            );
                        }
                    }
                }
            }
            anyhow::Ok(())
        })
    }
1164
1165 fn reconnect(
1166 self: &Arc<Self>,
1167 incoming_rx: UnboundedReceiver<Envelope>,
1168 outgoing_tx: UnboundedSender<Envelope>,
1169 cx: &AsyncApp,
1170 ) {
1171 *self.outgoing_tx.lock() = outgoing_tx;
1172 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1173 }
1174
1175 fn request<T: RequestMessage>(
1176 &self,
1177 payload: T,
1178 ) -> impl 'static + Future<Output = Result<T::Response>> {
1179 self.request_internal(payload, true)
1180 }
1181
1182 fn request_internal<T: RequestMessage>(
1183 &self,
1184 payload: T,
1185 use_buffer: bool,
1186 ) -> impl 'static + Future<Output = Result<T::Response>> {
1187 log::debug!("ssh request start. name:{}", T::NAME);
1188 let response =
1189 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
1190 async move {
1191 let response = response.await?;
1192 log::debug!("ssh request finish. name:{}", T::NAME);
1193 T::Response::from_envelope(response).context("received a response of the wrong type")
1194 }
1195 }
1196
1197 async fn resync(&self, timeout: Duration) -> Result<()> {
1198 smol::future::or(
1199 async {
1200 self.request_internal(proto::FlushBufferedMessages {}, false)
1201 .await?;
1202
1203 for envelope in self.buffer.lock().iter() {
1204 self.outgoing_tx
1205 .lock()
1206 .unbounded_send(envelope.clone())
1207 .ok();
1208 }
1209 Ok(())
1210 },
1211 async {
1212 smol::Timer::after(timeout).await;
1213 anyhow::bail!("Timed out resyncing remote client")
1214 },
1215 )
1216 .await
1217 }
1218
1219 async fn ping(&self, timeout: Duration) -> Result<()> {
1220 smol::future::or(
1221 async {
1222 self.request(proto::Ping {}).await?;
1223 Ok(())
1224 },
1225 async {
1226 smol::Timer::after(timeout).await;
1227 anyhow::bail!("Timed out pinging remote client")
1228 },
1229 )
1230 .await
1231 }
1232
1233 fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1234 log::debug!("ssh send name:{}", T::NAME);
1235 self.send_dynamic(payload.into_envelope(0, None, None))
1236 }
1237
1238 fn request_dynamic(
1239 &self,
1240 mut envelope: proto::Envelope,
1241 type_name: &'static str,
1242 use_buffer: bool,
1243 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1244 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1245 let (tx, rx) = oneshot::channel();
1246 let mut response_channels_lock = self.response_channels.lock();
1247 response_channels_lock.insert(MessageId(envelope.id), tx);
1248 drop(response_channels_lock);
1249
1250 let result = if use_buffer {
1251 self.send_buffered(envelope)
1252 } else {
1253 self.send_unbuffered(envelope)
1254 };
1255 async move {
1256 if let Err(error) = &result {
1257 log::error!("failed to send message: {error}");
1258 anyhow::bail!("failed to send message: {error}");
1259 }
1260
1261 let response = rx.await.context("connection lost")?.0;
1262 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1263 return Err(RpcError::from_proto(error, type_name));
1264 }
1265 Ok(response)
1266 }
1267 }
1268
1269 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1270 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1271 self.send_buffered(envelope)
1272 }
1273
1274 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1275 envelope.ack_id = Some(self.max_received.load(SeqCst));
1276 self.buffer.lock().push_back(envelope.clone());
1277 // ignore errors on send (happen while we're reconnecting)
1278 // assume that the global "disconnected" overlay is sufficient.
1279 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1280 Ok(())
1281 }
1282
1283 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1284 envelope.ack_id = Some(self.max_received.load(SeqCst));
1285 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1286 Ok(())
1287 }
1288}
1289
1290impl ProtoClient for ChannelClient {
1291 fn request(
1292 &self,
1293 envelope: proto::Envelope,
1294 request_type: &'static str,
1295 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1296 self.request_dynamic(envelope, request_type, true).boxed()
1297 }
1298
1299 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1300 self.send_dynamic(envelope)
1301 }
1302
1303 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1304 self.send_dynamic(envelope)
1305 }
1306
1307 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1308 &self.message_handlers
1309 }
1310
1311 fn is_via_collab(&self) -> bool {
1312 false
1313 }
1314}
1315
1316#[cfg(any(test, feature = "test-support"))]
1317mod fake {
1318 use super::{ChannelClient, RemoteClientDelegate, RemoteConnection, RemotePlatform};
1319 use crate::{SshConnectionOptions, remote_client::CommandTemplate};
1320 use anyhow::Result;
1321 use async_trait::async_trait;
1322 use collections::HashMap;
1323 use futures::{
1324 FutureExt, SinkExt, StreamExt,
1325 channel::{
1326 mpsc::{self, Sender},
1327 oneshot,
1328 },
1329 select_biased,
1330 };
1331 use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
1332 use release_channel::ReleaseChannel;
1333 use rpc::proto::Envelope;
1334 use std::{path::PathBuf, sync::Arc};
1335 use util::paths::{PathStyle, RemotePathBuf};
1336
/// `RemoteConnection` implementation compiled only for tests / `test-support`.
pub(super) struct FakeRemoteConnection {
    /// Options returned (cloned) by `connection_options()`.
    pub(super) connection_options: SshConnectionOptions,
    /// Channel client that `start_proxy` and `simulate_disconnect` reconnect onto
    /// freshly created channels.
    pub(super) server_channel: Arc<ChannelClient>,
    /// Context passed to `server_channel.reconnect`.
    pub(super) server_cx: SendableCx,
}
1342
1343 pub(super) struct SendableCx(AsyncApp);
1344 impl SendableCx {
1345 // SAFETY: When run in test mode, GPUI is always single threaded.
1346 pub(super) fn new(cx: &TestAppContext) -> Self {
1347 Self(cx.to_async())
1348 }
1349
1350 // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
1351 fn get(&self, _: &AsyncApp) -> AsyncApp {
1352 self.0.clone()
1353 }
1354 }
1355
1356 // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
1357 unsafe impl Send for SendableCx {}
1358 unsafe impl Sync for SendableCx {}
1359
1360 #[async_trait(?Send)]
1361 impl RemoteConnection for FakeRemoteConnection {
1362 async fn kill(&self) -> Result<()> {
1363 Ok(())
1364 }
1365
1366 fn has_been_killed(&self) -> bool {
1367 false
1368 }
1369
1370 fn build_command(
1371 &self,
1372 program: Option<String>,
1373 args: &[String],
1374 env: &HashMap<String, String>,
1375 _: Option<String>,
1376 _: Option<String>,
1377 _: Option<(u16, String, u16)>,
1378 ) -> Result<CommandTemplate> {
1379 let ssh_program = program.unwrap_or_else(|| "sh".to_string());
1380 let mut ssh_args = Vec::new();
1381 ssh_args.push(ssh_program);
1382 ssh_args.extend(args.iter().cloned());
1383 Ok(CommandTemplate {
1384 program: "ssh".into(),
1385 args: ssh_args,
1386 env: env.clone(),
1387 })
1388 }
1389
1390 fn upload_directory(
1391 &self,
1392 _src_path: PathBuf,
1393 _dest_path: RemotePathBuf,
1394 _cx: &App,
1395 ) -> Task<Result<()>> {
1396 unreachable!()
1397 }
1398
1399 fn connection_options(&self) -> SshConnectionOptions {
1400 self.connection_options.clone()
1401 }
1402
1403 fn simulate_disconnect(&self, cx: &AsyncApp) {
1404 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1405 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1406 self.server_channel
1407 .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
1408 }
1409
1410 fn start_proxy(
1411 &self,
1412 _unique_identifier: String,
1413 _reconnect: bool,
1414 mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
1415 mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
1416 mut connection_activity_tx: Sender<()>,
1417 _delegate: Arc<dyn RemoteClientDelegate>,
1418 cx: &mut AsyncApp,
1419 ) -> Task<Result<i32>> {
1420 let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
1421 let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
1422
1423 self.server_channel.reconnect(
1424 server_incoming_rx,
1425 server_outgoing_tx,
1426 &self.server_cx.get(cx),
1427 );
1428
1429 cx.background_spawn(async move {
1430 loop {
1431 select_biased! {
1432 server_to_client = server_outgoing_rx.next().fuse() => {
1433 let Some(server_to_client) = server_to_client else {
1434 return Ok(1)
1435 };
1436 connection_activity_tx.try_send(()).ok();
1437 client_incoming_tx.send(server_to_client).await.ok();
1438 }
1439 client_to_server = client_outgoing_rx.next().fuse() => {
1440 let Some(client_to_server) = client_to_server else {
1441 return Ok(1)
1442 };
1443 server_incoming_tx.send(client_to_server).await.ok();
1444 }
1445 }
1446 }
1447 })
1448 }
1449
1450 fn path_style(&self) -> PathStyle {
1451 PathStyle::current()
1452 }
1453
1454 fn shell(&self) -> String {
1455 "sh".to_owned()
1456 }
1457 }
1458
1459 pub(super) struct Delegate;
1460
1461 impl RemoteClientDelegate for Delegate {
1462 fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
1463 unreachable!()
1464 }
1465
1466 fn download_server_binary_locally(
1467 &self,
1468 _: RemotePlatform,
1469 _: ReleaseChannel,
1470 _: Option<SemanticVersion>,
1471 _: &mut AsyncApp,
1472 ) -> Task<Result<PathBuf>> {
1473 unreachable!()
1474 }
1475
1476 fn get_download_params(
1477 &self,
1478 _platform: RemotePlatform,
1479 _release_channel: ReleaseChannel,
1480 _version: Option<SemanticVersion>,
1481 _cx: &mut AsyncApp,
1482 ) -> Task<Result<Option<(String, String)>>> {
1483 unreachable!()
1484 }
1485
1486 fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
1487 }
1488}