1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 channel::{
13 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
14 oneshot,
15 },
16 future::BoxFuture,
17 select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
21 WeakModel,
22};
23use parking_lot::Mutex;
24use rpc::{
25 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
26 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
27};
28use smol::{
29 fs,
30 process::{self, Child, Stdio},
31};
32use std::{
33 any::TypeId,
34 collections::VecDeque,
35 ffi::OsStr,
36 fmt,
37 ops::ControlFlow,
38 path::{Path, PathBuf},
39 sync::{
40 atomic::{AtomicU32, Ordering::SeqCst},
41 Arc, Weak,
42 },
43 time::{Duration, Instant},
44};
45use tempfile::TempDir;
46use util::ResultExt;
47
/// Newtype around a `u64` project identifier used with SSH remotes.
///
/// NOTE(review): this file only declares the type; presumably the id is
/// assigned on the remote side — confirm against its users.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
52
/// Handle for running `ssh` commands through an SSH control socket.
///
/// Commands built from it pass `-o ControlMaster=no -o ControlPath=<socket_path>`,
/// so they go through the already-established master connection.
#[derive(Clone)]
pub struct SshSocket {
    // Destination (user/host/port) used when building `ssh` invocations.
    connection_options: SshConnectionOptions,
    // Control socket path passed to `ssh` as `ControlPath`.
    socket_path: PathBuf,
}
58
59#[derive(Debug, Default, Clone, PartialEq, Eq)]
60pub struct SshConnectionOptions {
61 pub host: String,
62 pub username: Option<String>,
63 pub port: Option<u16>,
64 pub password: Option<String>,
65 pub args: Option<Vec<String>>,
66}
67
68impl SshConnectionOptions {
69 pub fn parse_command_line(input: &str) -> Result<Self> {
70 let input = input.trim_start_matches("ssh ");
71 let mut hostname: Option<String> = None;
72 let mut username: Option<String> = None;
73 let mut port: Option<u16> = None;
74 let mut args = Vec::new();
75
76 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
77 const ALLOWED_OPTS: &[&str] = &[
78 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
79 ];
80 const ALLOWED_ARGS: &[&str] = &[
81 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
82 "-w",
83 ];
84
85 let mut tokens = shlex::split(input)
86 .ok_or_else(|| anyhow!("invalid input"))?
87 .into_iter();
88
89 'outer: while let Some(arg) = tokens.next() {
90 if ALLOWED_OPTS.contains(&(&arg as &str)) {
91 args.push(arg.to_string());
92 continue;
93 }
94 if arg == "-p" {
95 port = tokens.next().and_then(|arg| arg.parse().ok());
96 continue;
97 } else if let Some(p) = arg.strip_prefix("-p") {
98 port = p.parse().ok();
99 continue;
100 }
101 if arg == "-l" {
102 username = tokens.next();
103 continue;
104 } else if let Some(l) = arg.strip_prefix("-l") {
105 username = Some(l.to_string());
106 continue;
107 }
108 for a in ALLOWED_ARGS {
109 if arg == *a {
110 args.push(arg);
111 if let Some(next) = tokens.next() {
112 args.push(next);
113 }
114 continue 'outer;
115 } else if arg.starts_with(a) {
116 args.push(arg);
117 continue 'outer;
118 }
119 }
120 if arg.starts_with("-") || hostname.is_some() {
121 anyhow::bail!("unsupported argument: {:?}", arg);
122 }
123 let mut input = &arg as &str;
124 if let Some((u, rest)) = input.split_once('@') {
125 input = rest;
126 username = Some(u.to_string());
127 }
128 if let Some((rest, p)) = input.split_once(':') {
129 input = rest;
130 port = p.parse().ok()
131 }
132 hostname = Some(input.to_string())
133 }
134
135 let Some(hostname) = hostname else {
136 anyhow::bail!("missing hostname");
137 };
138
139 Ok(Self {
140 host: hostname.to_string(),
141 username: username.clone(),
142 port,
143 password: None,
144 args: Some(args),
145 })
146 }
147
148 pub fn ssh_url(&self) -> String {
149 let mut result = String::from("ssh://");
150 if let Some(username) = &self.username {
151 result.push_str(username);
152 result.push('@');
153 }
154 result.push_str(&self.host);
155 if let Some(port) = self.port {
156 result.push(':');
157 result.push_str(&port.to_string());
158 }
159 result
160 }
161
162 pub fn additional_args(&self) -> Option<&Vec<String>> {
163 self.args.as_ref()
164 }
165
166 fn scp_url(&self) -> String {
167 if let Some(username) = &self.username {
168 format!("{}@{}", username, self.host)
169 } else {
170 self.host.clone()
171 }
172 }
173
174 pub fn connection_string(&self) -> String {
175 let host = if let Some(username) = &self.username {
176 format!("{}@{}", username, self.host)
177 } else {
178 self.host.clone()
179 };
180 if let Some(port) = &self.port {
181 format!("{}:{}", host, port)
182 } else {
183 host
184 }
185 }
186
187 // Uniquely identifies dev server projects on a remote host. Needs to be
188 // stable for the same dev server project.
189 pub fn dev_server_identifier(&self) -> String {
190 let mut identifier = format!("dev-server-{:?}", self.host);
191 if let Some(username) = self.username.as_ref() {
192 identifier.push('-');
193 identifier.push_str(&username);
194 }
195 identifier
196 }
197}
198
/// Operating system and CPU architecture of a remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}

impl SshPlatform {
    /// Returns the Rust target triple for this platform, e.g.
    /// `x86_64-unknown-linux-gnu`, or `None` when the OS is neither
    /// `linux` nor `macos`.
    pub fn triple(&self) -> Option<String> {
        let vendor_and_os = match self.os {
            "linux" => "unknown-linux-gnu",
            "macos" => "apple-darwin",
            _ => return None,
        };
        Some(format!("{}-{}", self.arch, vendor_and_os))
    }
}
218
/// Callbacks through which the SSH client talks to the host application
/// (prompts, status and error reporting, server binary lookup).
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password in response to `prompt` (the text received from the
    /// askpass helper); the answer is delivered through the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path of the server binary on the remote host for `platform`.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Obtains a server binary for `platform`, resolving to a path and its version.
    /// NOTE(review): not called in this part of the file; presumably used when
    /// installing the binary on the remote — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports a progress message such as "Connecting" or "Starting proxy".
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports a connection error message.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
238
239impl SshSocket {
240 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
241 let mut command = process::Command::new("ssh");
242 self.ssh_options(&mut command)
243 .arg(self.connection_options.ssh_url())
244 .arg(program);
245 command
246 }
247
248 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
249 command
250 .stdin(Stdio::piped())
251 .stdout(Stdio::piped())
252 .stderr(Stdio::piped())
253 .args(["-o", "ControlMaster=no", "-o"])
254 .arg(format!("ControlPath={}", self.socket_path.display()))
255 }
256
257 fn ssh_args(&self) -> Vec<String> {
258 vec![
259 "-o".to_string(),
260 "ControlMaster=no".to_string(),
261 "-o".to_string(),
262 format!("ControlPath={}", self.socket_path.display()),
263 self.connection_options.ssh_url(),
264 ]
265 }
266}
267
268async fn run_cmd(command: &mut process::Command) -> Result<String> {
269 let output = command.output().await?;
270 if output.status.success() {
271 Ok(String::from_utf8_lossy(&output.stdout).to_string())
272 } else {
273 Err(anyhow!(
274 "failed to run command: {}",
275 String::from_utf8_lossy(&output.stderr)
276 ))
277 }
278}
279
/// Consecutive missed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time (no traffic from the server) after which a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for ping/resync requests to the server.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts made before giving up (state becomes `ReconnectExhausted`).
const MAX_RECONNECT_ATTEMPTS: usize = 3;
285
/// Internal connection state machine of `SshRemoteClient`.
///
/// Variants with a live connection own the ssh process handle, the delegate,
/// and (while healthy) the I/O multiplexing and heartbeat tasks, so replacing
/// or dropping the state drops those as well.
enum State {
    /// The initial connection is being established.
    Connecting,
    /// Connected and the heartbeat is healthy.
    Connected {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Connected, but `missed_heartbeats` consecutive heartbeats failed.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// Reconnect attempt number `attempts` failed with `error`; another may follow.
    ReconnectFailed {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// All reconnect attempts were used up.
    ReconnectExhausted,
    /// The remote proxy reported that the server is not running.
    ServerNotRunning,
}
315
316impl fmt::Display for State {
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 match self {
319 Self::Connecting => write!(f, "connecting"),
320 Self::Connected { .. } => write!(f, "connected"),
321 Self::Reconnecting => write!(f, "reconnecting"),
322 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
323 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
324 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
325 Self::ServerNotRunning { .. } => write!(f, "server not running"),
326 }
327 }
328}
329
330impl State {
331 fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
332 match self {
333 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
334 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
335 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
336 _ => None,
337 }
338 }
339
340 fn can_reconnect(&self) -> bool {
341 match self {
342 Self::Connected { .. }
343 | Self::HeartbeatMissed { .. }
344 | Self::ReconnectFailed { .. } => true,
345 State::Connecting
346 | State::Reconnecting
347 | State::ReconnectExhausted
348 | State::ServerNotRunning => false,
349 }
350 }
351
352 fn is_reconnect_failed(&self) -> bool {
353 matches!(self, Self::ReconnectFailed { .. })
354 }
355
356 fn is_reconnect_exhausted(&self) -> bool {
357 matches!(self, Self::ReconnectExhausted { .. })
358 }
359
360 fn is_reconnecting(&self) -> bool {
361 matches!(self, Self::Reconnecting { .. })
362 }
363
364 fn heartbeat_recovered(self) -> Self {
365 match self {
366 Self::HeartbeatMissed {
367 ssh_connection,
368 delegate,
369 multiplex_task,
370 heartbeat_task,
371 ..
372 } => Self::Connected {
373 ssh_connection,
374 delegate,
375 multiplex_task,
376 heartbeat_task,
377 },
378 _ => self,
379 }
380 }
381
382 fn heartbeat_missed(self) -> Self {
383 match self {
384 Self::Connected {
385 ssh_connection,
386 delegate,
387 multiplex_task,
388 heartbeat_task,
389 } => Self::HeartbeatMissed {
390 missed_heartbeats: 1,
391 ssh_connection,
392 delegate,
393 multiplex_task,
394 heartbeat_task,
395 },
396 Self::HeartbeatMissed {
397 missed_heartbeats,
398 ssh_connection,
399 delegate,
400 multiplex_task,
401 heartbeat_task,
402 } => Self::HeartbeatMissed {
403 missed_heartbeats: missed_heartbeats + 1,
404 ssh_connection,
405 delegate,
406 multiplex_task,
407 heartbeat_task,
408 },
409 _ => self,
410 }
411 }
412}
413
/// The state of the ssh connection, as exposed to callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// Connected with a healthy heartbeat.
    Connected,
    /// Connected, but recent heartbeats went unanswered.
    HeartbeatMissed,
    /// A reconnect is in flight, or the last attempt failed and another will follow.
    Reconnecting,
    /// Reconnects were exhausted, the server is not running, or no state is held.
    Disconnected,
}
423
424impl From<&State> for ConnectionState {
425 fn from(value: &State) -> Self {
426 match value {
427 State::Connecting => Self::Connecting,
428 State::Connected { .. } => Self::Connected,
429 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
430 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
431 State::ReconnectExhausted => Self::Disconnected,
432 State::ServerNotRunning => Self::Disconnected,
433 }
434 }
435}
436
/// Client side of a connection to a remote server over SSH.
pub struct SshRemoteClient {
    // Message channel to the server; the same client is re-wired on reconnect.
    client: Arc<ChannelClient>,
    // Passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    // Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
443
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection was given up: reconnects were exhausted or the server
    /// is not running.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
450
451impl SshRemoteClient {
452 pub fn new(
453 unique_identifier: String,
454 connection_options: SshConnectionOptions,
455 delegate: Arc<dyn SshClientDelegate>,
456 cx: &AppContext,
457 ) -> Task<Result<Model<Self>>> {
458 cx.spawn(|mut cx| async move {
459 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
460 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
461 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
462
463 let client =
464 cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
465 let this = cx.new_model(|_| Self {
466 client: client.clone(),
467 unique_identifier: unique_identifier.clone(),
468 connection_options: connection_options.clone(),
469 state: Arc::new(Mutex::new(Some(State::Connecting))),
470 })?;
471
472 let (ssh_connection, io_task) = Self::establish_connection(
473 unique_identifier,
474 false,
475 connection_options,
476 incoming_tx,
477 outgoing_rx,
478 connection_activity_tx,
479 delegate.clone(),
480 &mut cx,
481 )
482 .await?;
483
484 let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
485
486 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
487 log::error!("failed to establish connection: {}", error);
488 delegate.set_error(error.to_string(), &mut cx);
489 return Err(error);
490 }
491
492 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
493
494 this.update(&mut cx, |this, _| {
495 *this.state.lock() = Some(State::Connected {
496 ssh_connection,
497 delegate,
498 multiplex_task,
499 heartbeat_task,
500 });
501 })?;
502
503 Ok(this)
504 })
505 }
506
507 pub fn shutdown_processes<T: RequestMessage>(
508 &self,
509 shutdown_request: Option<T>,
510 ) -> Option<impl Future<Output = ()>> {
511 let state = self.state.lock().take()?;
512 log::info!("shutting down ssh processes");
513
514 let State::Connected {
515 multiplex_task,
516 heartbeat_task,
517 ssh_connection,
518 delegate,
519 } = state
520 else {
521 return None;
522 };
523
524 let client = self.client.clone();
525
526 Some(async move {
527 if let Some(shutdown_request) = shutdown_request {
528 client.send(shutdown_request).log_err();
529 // We wait 50ms instead of waiting for a response, because
530 // waiting for a response would require us to wait on the main thread
531 // which we want to avoid in an `on_app_quit` callback.
532 smol::Timer::after(Duration::from_millis(50)).await;
533 }
534
535 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
536 // child of master_process.
537 drop(multiplex_task);
538 // Now drop the rest of state, which kills master process.
539 drop(heartbeat_task);
540 drop(ssh_connection);
541 drop(delegate);
542 })
543 }
544
545 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
546 let mut lock = self.state.lock();
547
548 let can_reconnect = lock
549 .as_ref()
550 .map(|state| state.can_reconnect())
551 .unwrap_or(false);
552 if !can_reconnect {
553 let error = if let Some(state) = lock.as_ref() {
554 format!("invalid state, cannot reconnect while in state {state}")
555 } else {
556 "no state set".to_string()
557 };
558 log::info!("aborting reconnect, because not in state that allows reconnecting");
559 return Err(anyhow!(error));
560 }
561
562 let state = lock.take().unwrap();
563 let (attempts, mut ssh_connection, delegate) = match state {
564 State::Connected {
565 ssh_connection,
566 delegate,
567 multiplex_task,
568 heartbeat_task,
569 }
570 | State::HeartbeatMissed {
571 ssh_connection,
572 delegate,
573 multiplex_task,
574 heartbeat_task,
575 ..
576 } => {
577 drop(multiplex_task);
578 drop(heartbeat_task);
579 (0, ssh_connection, delegate)
580 }
581 State::ReconnectFailed {
582 attempts,
583 ssh_connection,
584 delegate,
585 ..
586 } => (attempts, ssh_connection, delegate),
587 State::Connecting
588 | State::Reconnecting
589 | State::ReconnectExhausted
590 | State::ServerNotRunning => unreachable!(),
591 };
592
593 let attempts = attempts + 1;
594 if attempts > MAX_RECONNECT_ATTEMPTS {
595 log::error!(
596 "Failed to reconnect to after {} attempts, giving up",
597 MAX_RECONNECT_ATTEMPTS
598 );
599 drop(lock);
600 self.set_state(State::ReconnectExhausted, cx);
601 return Ok(());
602 }
603 drop(lock);
604
605 self.set_state(State::Reconnecting, cx);
606
607 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
608
609 let identifier = self.unique_identifier.clone();
610 let client = self.client.clone();
611 let reconnect_task = cx.spawn(|this, mut cx| async move {
612 macro_rules! failed {
613 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
614 return State::ReconnectFailed {
615 error: anyhow!($error),
616 attempts: $attempts,
617 ssh_connection: $ssh_connection,
618 delegate: $delegate,
619 };
620 };
621 }
622
623 if let Err(error) = ssh_connection
624 .kill()
625 .await
626 .context("Failed to kill ssh process")
627 {
628 failed!(error, attempts, ssh_connection, delegate);
629 };
630
631 let connection_options = ssh_connection.connection_options();
632
633 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
634 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
635 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
636
637 let (ssh_connection, io_task) = match Self::establish_connection(
638 identifier,
639 true,
640 connection_options,
641 incoming_tx,
642 outgoing_rx,
643 connection_activity_tx,
644 delegate.clone(),
645 &mut cx,
646 )
647 .await
648 {
649 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
650 Err(error) => {
651 failed!(error, attempts, ssh_connection, delegate);
652 }
653 };
654
655 let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
656 client.reconnect(incoming_rx, outgoing_tx, &cx);
657
658 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
659 failed!(error, attempts, ssh_connection, delegate);
660 };
661
662 State::Connected {
663 ssh_connection,
664 delegate,
665 multiplex_task,
666 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
667 }
668 });
669
670 cx.spawn(|this, mut cx| async move {
671 let new_state = reconnect_task.await;
672 this.update(&mut cx, |this, cx| {
673 this.try_set_state(cx, |old_state| {
674 if old_state.is_reconnecting() {
675 match &new_state {
676 State::Connecting
677 | State::Reconnecting { .. }
678 | State::HeartbeatMissed { .. }
679 | State::ServerNotRunning => {}
680 State::Connected { .. } => {
681 log::info!("Successfully reconnected");
682 }
683 State::ReconnectFailed {
684 error, attempts, ..
685 } => {
686 log::error!(
687 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
688 attempts,
689 error
690 );
691 }
692 State::ReconnectExhausted => {
693 log::error!("Reconnect attempt failed and all attempts exhausted");
694 }
695 }
696 Some(new_state)
697 } else {
698 None
699 }
700 });
701
702 if this.state_is(State::is_reconnect_failed) {
703 this.reconnect(cx)
704 } else if this.state_is(State::is_reconnect_exhausted) {
705 cx.emit(SshRemoteEvent::Disconnected);
706 Ok(())
707 } else {
708 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
709 Ok(())
710 }
711 })
712 })
713 .detach_and_log_err(cx);
714
715 Ok(())
716 }
717
718 fn heartbeat(
719 this: WeakModel<Self>,
720 mut connection_activity_rx: mpsc::Receiver<()>,
721 cx: &mut AsyncAppContext,
722 ) -> Task<Result<()>> {
723 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
724 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
725 };
726
727 cx.spawn(|mut cx| {
728 let this = this.clone();
729 async move {
730 let mut missed_heartbeats = 0;
731
732 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
733 futures::pin_mut!(keepalive_timer);
734
735 loop {
736 select_biased! {
737 result = connection_activity_rx.next().fuse() => {
738 if result.is_none() {
739 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
740 return Ok(());
741 }
742
743 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
744
745 if missed_heartbeats != 0 {
746 missed_heartbeats = 0;
747 this.update(&mut cx, |this, mut cx| {
748 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
749 })?;
750 }
751 }
752 _ = keepalive_timer => {
753 log::debug!("Sending heartbeat to server...");
754
755 let result = select_biased! {
756 _ = connection_activity_rx.next().fuse() => {
757 Ok(())
758 }
759 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
760 ping_result
761 }
762 };
763
764 if result.is_err() {
765 missed_heartbeats += 1;
766 log::warn!(
767 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
768 HEARTBEAT_TIMEOUT,
769 missed_heartbeats,
770 MAX_MISSED_HEARTBEATS
771 );
772 } else if missed_heartbeats != 0 {
773 missed_heartbeats = 0;
774 } else {
775 continue;
776 }
777
778 let result = this.update(&mut cx, |this, mut cx| {
779 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
780 })?;
781 if result.is_break() {
782 return Ok(());
783 }
784 }
785 }
786 }
787 }
788 })
789 }
790
791 fn handle_heartbeat_result(
792 &mut self,
793 missed_heartbeats: usize,
794 cx: &mut ModelContext<Self>,
795 ) -> ControlFlow<()> {
796 let state = self.state.lock().take().unwrap();
797 let next_state = if missed_heartbeats > 0 {
798 state.heartbeat_missed()
799 } else {
800 state.heartbeat_recovered()
801 };
802
803 self.set_state(next_state, cx);
804
805 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
806 log::error!(
807 "Missed last {} heartbeats. Reconnecting...",
808 missed_heartbeats
809 );
810
811 self.reconnect(cx)
812 .context("failed to start reconnect process after missing heartbeats")
813 .log_err();
814 ControlFlow::Break(())
815 } else {
816 ControlFlow::Continue(())
817 }
818 }
819
820 fn multiplex(
821 mut ssh_proxy_process: Child,
822 incoming_tx: UnboundedSender<Envelope>,
823 mut outgoing_rx: UnboundedReceiver<Envelope>,
824 mut connection_activity_tx: Sender<()>,
825 cx: &AsyncAppContext,
826 ) -> Task<Result<i32>> {
827 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
828 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
829 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
830
831 let mut stdin_buffer = Vec::new();
832 let mut stdout_buffer = Vec::new();
833 let mut stderr_buffer = Vec::new();
834 let mut stderr_offset = 0;
835
836 let stdin_task = cx.background_executor().spawn(async move {
837 while let Some(outgoing) = outgoing_rx.next().await {
838 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
839 }
840 anyhow::Ok(())
841 });
842
843 let stdout_task = cx.background_executor().spawn({
844 let mut connection_activity_tx = connection_activity_tx.clone();
845 async move {
846 loop {
847 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
848 let len = child_stdout.read(&mut stdout_buffer).await?;
849
850 if len == 0 {
851 return anyhow::Ok(());
852 }
853
854 if len < MESSAGE_LEN_SIZE {
855 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
856 }
857
858 let message_len = message_len_from_buffer(&stdout_buffer);
859 let envelope =
860 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
861 .await?;
862 connection_activity_tx.try_send(()).ok();
863 incoming_tx.unbounded_send(envelope).ok();
864 }
865 }
866 });
867
868 let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
869 loop {
870 stderr_buffer.resize(stderr_offset + 1024, 0);
871
872 let len = child_stderr
873 .read(&mut stderr_buffer[stderr_offset..])
874 .await?;
875
876 stderr_offset += len;
877 let mut start_ix = 0;
878 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
879 .iter()
880 .position(|b| b == &b'\n')
881 {
882 let line_ix = start_ix + ix;
883 let content = &stderr_buffer[start_ix..line_ix];
884 start_ix = line_ix + 1;
885 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
886 record.log(log::logger())
887 } else {
888 eprintln!("(remote) {}", String::from_utf8_lossy(content));
889 }
890 }
891 stderr_buffer.drain(0..start_ix);
892 stderr_offset -= start_ix;
893
894 connection_activity_tx.try_send(()).ok();
895 }
896 });
897
898 cx.spawn(|_| async move {
899 let result = futures::select! {
900 result = stdin_task.fuse() => {
901 result.context("stdin")
902 }
903 result = stdout_task.fuse() => {
904 result.context("stdout")
905 }
906 result = stderr_task.fuse() => {
907 result.context("stderr")
908 }
909 };
910
911 match result {
912 Ok(_) => Ok(ssh_proxy_process.status().await?.code().unwrap_or(1)),
913 Err(error) => Err(error),
914 }
915 })
916 }
917
918 fn monitor(
919 this: WeakModel<Self>,
920 io_task: Task<Result<i32>>,
921 cx: &AsyncAppContext,
922 ) -> Task<Result<()>> {
923 cx.spawn(|mut cx| async move {
924 let result = io_task.await;
925
926 match result {
927 Ok(exit_code) => {
928 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
929 match error {
930 ProxyLaunchError::ServerNotRunning => {
931 log::error!("failed to reconnect because server is not running");
932 this.update(&mut cx, |this, cx| {
933 this.set_state(State::ServerNotRunning, cx);
934 cx.emit(SshRemoteEvent::Disconnected);
935 })?;
936 }
937 }
938 } else if exit_code > 0 {
939 log::error!("proxy process terminated unexpectedly");
940 this.update(&mut cx, |this, cx| {
941 this.reconnect(cx).ok();
942 })?;
943 }
944 }
945 Err(error) => {
946 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
947 this.update(&mut cx, |this, cx| {
948 this.reconnect(cx).ok();
949 })?;
950 }
951 }
952
953 Ok(())
954 })
955 }
956
957 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
958 self.state.lock().as_ref().map_or(false, check)
959 }
960
961 fn try_set_state(
962 &self,
963 cx: &mut ModelContext<Self>,
964 map: impl FnOnce(&State) -> Option<State>,
965 ) {
966 let mut lock = self.state.lock();
967 let new_state = lock.as_ref().and_then(map);
968
969 if let Some(new_state) = new_state {
970 lock.replace(new_state);
971 cx.notify();
972 }
973 }
974
975 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
976 log::info!("setting state to '{}'", &state);
977 self.state.lock().replace(state);
978 cx.notify();
979 }
980
981 #[allow(clippy::too_many_arguments)]
982 async fn establish_connection(
983 unique_identifier: String,
984 reconnect: bool,
985 connection_options: SshConnectionOptions,
986 incoming_tx: UnboundedSender<Envelope>,
987 outgoing_rx: UnboundedReceiver<Envelope>,
988 connection_activity_tx: Sender<()>,
989 delegate: Arc<dyn SshClientDelegate>,
990 cx: &mut AsyncAppContext,
991 ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
992 #[cfg(any(test, feature = "test-support"))]
993 if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
994 let io_task = fake::SshRemoteConnection::multiplex(
995 fake.connection_options(),
996 incoming_tx,
997 outgoing_rx,
998 connection_activity_tx,
999 cx,
1000 )
1001 .await;
1002 return Ok((fake, io_task));
1003 }
1004
1005 let ssh_connection =
1006 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1007
1008 let platform = ssh_connection.query_platform().await?;
1009 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1010 if !reconnect {
1011 ssh_connection
1012 .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1013 .await?;
1014 }
1015
1016 let socket = ssh_connection.socket.clone();
1017 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1018
1019 delegate.set_status(Some("Starting proxy"), cx);
1020
1021 let mut start_proxy_command = format!(
1022 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1023 std::env::var("RUST_LOG").unwrap_or_default(),
1024 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1025 remote_binary_path,
1026 unique_identifier,
1027 );
1028 if reconnect {
1029 start_proxy_command.push_str(" --reconnect");
1030 }
1031
1032 let ssh_proxy_process = socket
1033 .ssh_command(start_proxy_command)
1034 // IMPORTANT: we kill this process when we drop the task that uses it.
1035 .kill_on_drop(true)
1036 .spawn()
1037 .context("failed to spawn remote server")?;
1038
1039 let io_task = Self::multiplex(
1040 ssh_proxy_process,
1041 incoming_tx,
1042 outgoing_rx,
1043 connection_activity_tx,
1044 &cx,
1045 );
1046
1047 Ok((Box::new(ssh_connection), io_task))
1048 }
1049
1050 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1051 self.client.subscribe_to_entity(remote_id, entity);
1052 }
1053
1054 pub fn ssh_args(&self) -> Option<Vec<String>> {
1055 self.state
1056 .lock()
1057 .as_ref()
1058 .and_then(|state| state.ssh_connection())
1059 .map(|ssh_connection| ssh_connection.ssh_args())
1060 }
1061
1062 pub fn proto_client(&self) -> AnyProtoClient {
1063 self.client.clone().into()
1064 }
1065
1066 pub fn connection_string(&self) -> String {
1067 self.connection_options.connection_string()
1068 }
1069
1070 pub fn connection_options(&self) -> SshConnectionOptions {
1071 self.connection_options.clone()
1072 }
1073
1074 pub fn connection_state(&self) -> ConnectionState {
1075 self.state
1076 .lock()
1077 .as_ref()
1078 .map(ConnectionState::from)
1079 .unwrap_or(ConnectionState::Disconnected)
1080 }
1081
1082 pub fn is_disconnected(&self) -> bool {
1083 self.connection_state() == ConnectionState::Disconnected
1084 }
1085
1086 #[cfg(any(test, feature = "test-support"))]
1087 pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
1088 let port = self.connection_options().port.unwrap();
1089 client_cx.spawn(|cx| async move {
1090 let (channel, server_cx) = cx
1091 .update_global(|c: &mut fake::ServerConnections, _| c.get(port))
1092 .unwrap();
1093
1094 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1095 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1096 channel.reconnect(incoming_rx, outgoing_tx, &server_cx);
1097 })
1098 }
1099
1100 #[cfg(any(test, feature = "test-support"))]
1101 pub fn fake_server(
1102 client_cx: &mut gpui::TestAppContext,
1103 server_cx: &mut gpui::TestAppContext,
1104 ) -> (u16, Arc<ChannelClient>) {
1105 use gpui::BorrowAppContext;
1106 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1107 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1108 let server_client =
1109 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1110 let port = client_cx.update(|cx| {
1111 cx.update_default_global(|c: &mut fake::ServerConnections, _| {
1112 c.push(server_client.clone(), server_cx.to_async())
1113 })
1114 });
1115 (port, server_client)
1116 }
1117
1118 #[cfg(any(test, feature = "test-support"))]
1119 pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model<Self> {
1120 client_cx
1121 .update(|cx| {
1122 Self::new(
1123 "fake".to_string(),
1124 SshConnectionOptions {
1125 host: "<fake>".to_string(),
1126 port: Some(port),
1127 ..Default::default()
1128 },
1129 Arc::new(fake::Delegate),
1130 cx,
1131 )
1132 })
1133 .await
1134 .unwrap()
1135 }
1136}
1137
1138impl From<SshRemoteClient> for AnyProtoClient {
1139 fn from(client: SshRemoteClient) -> Self {
1140 AnyProtoClient::new(client.client.clone())
1141 }
1142}
1143
/// The process side of an SSH connection held by the client state.
/// Implemented by `SshRemoteConnection`; `establish_connection` can also
/// return a `fake` implementation in test builds.
#[async_trait]
trait SshRemoteProcess: Send + Sync {
    /// Terminates the connection's process (for `SshRemoteConnection`: kills
    /// the ssh master process and waits for it to exit).
    async fn kill(&mut self) -> Result<()>;
    /// Arguments for running `ssh` against this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
}
1150
/// A real SSH connection backed by an `ssh` master process.
struct SshRemoteConnection {
    // Control socket used to run further commands over the master connection.
    socket: SshSocket,
    // The master `ssh` process; killed when this value is dropped.
    master_process: process::Child,
    // Temporary directory holding the askpass socket; kept alive for the
    // connection's lifetime.
    _temp_dir: TempDir,
}
1156
1157impl Drop for SshRemoteConnection {
1158 fn drop(&mut self) {
1159 if let Err(error) = self.master_process.kill() {
1160 log::error!("failed to kill SSH master process: {}", error);
1161 }
1162 }
1163}
1164
1165#[async_trait]
1166impl SshRemoteProcess for SshRemoteConnection {
1167 async fn kill(&mut self) -> Result<()> {
1168 self.master_process.kill()?;
1169
1170 self.master_process.status().await?;
1171
1172 Ok(())
1173 }
1174
1175 fn ssh_args(&self) -> Vec<String> {
1176 self.socket.ssh_args()
1177 }
1178
1179 fn connection_options(&self) -> SshConnectionOptions {
1180 self.socket.connection_options.clone()
1181 }
1182}
1183
1184impl SshRemoteConnection {
1185 #[cfg(not(unix))]
1186 async fn new(
1187 _connection_options: SshConnectionOptions,
1188 _delegate: Arc<dyn SshClientDelegate>,
1189 _cx: &mut AsyncAppContext,
1190 ) -> Result<Self> {
1191 Err(anyhow!("ssh is not supported on this platform"))
1192 }
1193
1194 #[cfg(unix)]
1195 async fn new(
1196 connection_options: SshConnectionOptions,
1197 delegate: Arc<dyn SshClientDelegate>,
1198 cx: &mut AsyncAppContext,
1199 ) -> Result<Self> {
1200 use futures::AsyncWriteExt as _;
1201 use futures::{io::BufReader, AsyncBufReadExt as _};
1202 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1203 use util::ResultExt as _;
1204
1205 delegate.set_status(Some("Connecting"), cx);
1206
1207 let url = connection_options.ssh_url();
1208 let temp_dir = tempfile::Builder::new()
1209 .prefix("zed-ssh-session")
1210 .tempdir()?;
1211
1212 // Create a domain socket listener to handle requests from the askpass program.
1213 let askpass_socket = temp_dir.path().join("askpass.sock");
1214 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1215 let listener =
1216 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1217
1218 let askpass_task = cx.spawn({
1219 let delegate = delegate.clone();
1220 |mut cx| async move {
1221 let mut askpass_opened_tx = Some(askpass_opened_tx);
1222
1223 while let Ok((mut stream, _)) = listener.accept().await {
1224 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1225 askpass_opened_tx.send(()).ok();
1226 }
1227 let mut buffer = Vec::new();
1228 let mut reader = BufReader::new(&mut stream);
1229 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1230 buffer.clear();
1231 }
1232 let password_prompt = String::from_utf8_lossy(&buffer);
1233 if let Some(password) = delegate
1234 .ask_password(password_prompt.to_string(), &mut cx)
1235 .await
1236 .context("failed to get ssh password")
1237 .and_then(|p| p)
1238 .log_err()
1239 {
1240 stream.write_all(password.as_bytes()).await.log_err();
1241 }
1242 }
1243 }
1244 });
1245
1246 // Create an askpass script that communicates back to this process.
1247 let askpass_script = format!(
1248 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1249 askpass_socket = askpass_socket.display(),
1250 print_args = "printf '%s\\0' \"$@\"",
1251 shebang = "#!/bin/sh",
1252 );
1253 let askpass_script_path = temp_dir.path().join("askpass.sh");
1254 fs::write(&askpass_script_path, askpass_script).await?;
1255 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1256
1257 // Start the master SSH process, which does not do anything except for establish
1258 // the connection and keep it open, allowing other ssh commands to reuse it
1259 // via a control socket.
1260 let socket_path = temp_dir.path().join("ssh.sock");
1261 let mut master_process = process::Command::new("ssh")
1262 .stdin(Stdio::null())
1263 .stdout(Stdio::piped())
1264 .stderr(Stdio::piped())
1265 .env("SSH_ASKPASS_REQUIRE", "force")
1266 .env("SSH_ASKPASS", &askpass_script_path)
1267 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1268 .args([
1269 "-N",
1270 "-o",
1271 "ControlPersist=no",
1272 "-o",
1273 "ControlMaster=yes",
1274 "-o",
1275 ])
1276 .arg(format!("ControlPath={}", socket_path.display()))
1277 .arg(&url)
1278 .spawn()?;
1279
1280 // Wait for this ssh process to close its stdout, indicating that authentication
1281 // has completed.
1282 let stdout = master_process.stdout.as_mut().unwrap();
1283 let mut output = Vec::new();
1284 let connection_timeout = Duration::from_secs(10);
1285
1286 let result = select_biased! {
1287 _ = askpass_opened_rx.fuse() => {
1288 // If the askpass script has opened, that means the user is typing
1289 // their password, in which case we don't want to timeout anymore,
1290 // since we know a connection has been established.
1291 stdout.read_to_end(&mut output).await?;
1292 Ok(())
1293 }
1294 result = stdout.read_to_end(&mut output).fuse() => {
1295 result?;
1296 Ok(())
1297 }
1298 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1299 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1300 }
1301 };
1302
1303 if let Err(e) = result {
1304 let error_message = format!("Failed to connect to host: {}.", e);
1305 delegate.set_error(error_message, cx);
1306 return Err(e);
1307 }
1308
1309 drop(askpass_task);
1310
1311 if master_process.try_status()?.is_some() {
1312 output.clear();
1313 let mut stderr = master_process.stderr.take().unwrap();
1314 stderr.read_to_end(&mut output).await?;
1315
1316 let error_message = format!(
1317 "failed to connect: {}",
1318 String::from_utf8_lossy(&output).trim()
1319 );
1320 delegate.set_error(error_message.clone(), cx);
1321 Err(anyhow!(error_message))?;
1322 }
1323
1324 Ok(Self {
1325 socket: SshSocket {
1326 connection_options,
1327 socket_path,
1328 },
1329 master_process,
1330 _temp_dir: temp_dir,
1331 })
1332 }
1333
1334 async fn ensure_server_binary(
1335 &self,
1336 delegate: &Arc<dyn SshClientDelegate>,
1337 dst_path: &Path,
1338 platform: SshPlatform,
1339 cx: &mut AsyncAppContext,
1340 ) -> Result<()> {
1341 if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1342 if let Ok(installed_version) =
1343 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1344 {
1345 log::info!("using cached server binary version {}", installed_version);
1346 return Ok(());
1347 }
1348 }
1349
1350 let mut dst_path_gz = dst_path.to_path_buf();
1351 dst_path_gz.set_extension("gz");
1352
1353 if let Some(parent) = dst_path.parent() {
1354 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1355 }
1356
1357 let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1358
1359 let mut server_binary_exists = false;
1360 if !server_binary_exists && cfg!(not(debug_assertions)) {
1361 if let Ok(installed_version) =
1362 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1363 {
1364 if installed_version.trim() == version.to_string() {
1365 server_binary_exists = true;
1366 }
1367 }
1368 }
1369
1370 if server_binary_exists {
1371 log::info!("remote development server already present",);
1372 return Ok(());
1373 }
1374
1375 let src_stat = fs::metadata(&src_path).await?;
1376 let size = src_stat.len();
1377 let server_mode = 0o755;
1378
1379 let t0 = Instant::now();
1380 delegate.set_status(Some("Uploading remote development server"), cx);
1381 log::info!("uploading remote development server ({}kb)", size / 1024);
1382 self.upload_file(&src_path, &dst_path_gz)
1383 .await
1384 .context("failed to upload server binary")?;
1385 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1386
1387 delegate.set_status(Some("Extracting remote development server"), cx);
1388 run_cmd(
1389 self.socket
1390 .ssh_command("gunzip")
1391 .arg("--force")
1392 .arg(&dst_path_gz),
1393 )
1394 .await?;
1395
1396 delegate.set_status(Some("Marking remote development server executable"), cx);
1397 run_cmd(
1398 self.socket
1399 .ssh_command("chmod")
1400 .arg(format!("{:o}", server_mode))
1401 .arg(dst_path),
1402 )
1403 .await?;
1404
1405 Ok(())
1406 }
1407
1408 async fn query_platform(&self) -> Result<SshPlatform> {
1409 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1410 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1411
1412 let os = match os.trim() {
1413 "Darwin" => "macos",
1414 "Linux" => "linux",
1415 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1416 };
1417 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1418 "aarch64"
1419 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1420 "x86_64"
1421 } else {
1422 Err(anyhow!("unknown uname architecture {arch:?}"))?
1423 };
1424
1425 Ok(SshPlatform { os, arch })
1426 }
1427
1428 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1429 let mut command = process::Command::new("scp");
1430 let output = self
1431 .socket
1432 .ssh_options(&mut command)
1433 .args(
1434 self.socket
1435 .connection_options
1436 .port
1437 .map(|port| vec!["-P".to_string(), port.to_string()])
1438 .unwrap_or_default(),
1439 )
1440 .arg(src_path)
1441 .arg(format!(
1442 "{}:{}",
1443 self.socket.connection_options.scp_url(),
1444 dest_path.display()
1445 ))
1446 .output()
1447 .await?;
1448
1449 if output.status.success() {
1450 Ok(())
1451 } else {
1452 Err(anyhow!(
1453 "failed to upload file {} -> {}: {}",
1454 src_path.display(),
1455 dest_path.display(),
1456 String::from_utf8_lossy(&output.stderr)
1457 ))
1458 }
1459 }
1460}
1461
1462type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1463
/// A proto client that exchanges [`Envelope`]s over a pair of unbounded
/// channels.
///
/// Outgoing envelopes are kept in `buffer` until the peer acknowledges them
/// (via `ack_id`), so they can be re-sent by `resync` or when the peer sends
/// `FlushBufferedMessages`.
pub struct ChannelClient {
    // Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    // Sink for outgoing envelopes; replaced by `reconnect`.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    // Outgoing envelopes the peer has not acknowledged yet, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    // Pending requests, keyed by the id of the request envelope.
    response_channels: ResponseChannels,
    // Registered message/entity handlers used to dispatch incoming envelopes.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    // Id of the most recently received envelope; sent back as `ack_id`.
    max_received: AtomicU32,
    // Label used as a prefix in log messages.
    name: &'static str,
    // Task running the incoming-message loop; replaced by `reconnect`.
    task: Mutex<Task<Result<()>>>,
}
1474
1475impl ChannelClient {
1476 pub fn new(
1477 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1478 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1479 cx: &AppContext,
1480 name: &'static str,
1481 ) -> Arc<Self> {
1482 Arc::new_cyclic(|this| Self {
1483 outgoing_tx: Mutex::new(outgoing_tx),
1484 next_message_id: AtomicU32::new(0),
1485 max_received: AtomicU32::new(0),
1486 response_channels: ResponseChannels::default(),
1487 message_handlers: Default::default(),
1488 buffer: Mutex::new(VecDeque::new()),
1489 name,
1490 task: Mutex::new(Self::start_handling_messages(
1491 this.clone(),
1492 incoming_rx,
1493 &cx.to_async(),
1494 )),
1495 })
1496 }
1497
1498 fn start_handling_messages(
1499 this: Weak<Self>,
1500 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1501 cx: &AsyncAppContext,
1502 ) -> Task<Result<()>> {
1503 cx.spawn(|cx| {
1504 async move {
1505 let peer_id = PeerId { owner_id: 0, id: 0 };
1506 while let Some(incoming) = incoming_rx.next().await {
1507 let Some(this) = this.upgrade() else {
1508 return anyhow::Ok(());
1509 };
1510 if let Some(ack_id) = incoming.ack_id {
1511 let mut buffer = this.buffer.lock();
1512 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1513 buffer.pop_front();
1514 }
1515 }
1516 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
1517 &incoming.payload
1518 {
1519 log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name);
1520 {
1521 let buffer = this.buffer.lock();
1522 for envelope in buffer.iter() {
1523 this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok();
1524 }
1525 }
1526 let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
1527 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
1528 this.outgoing_tx.lock().unbounded_send(envelope).ok();
1529 continue;
1530 }
1531
1532 this.max_received.store(incoming.id, SeqCst);
1533
1534 if let Some(request_id) = incoming.responding_to {
1535 let request_id = MessageId(request_id);
1536 let sender = this.response_channels.lock().remove(&request_id);
1537 if let Some(sender) = sender {
1538 let (tx, rx) = oneshot::channel();
1539 if incoming.payload.is_some() {
1540 sender.send((incoming, tx)).ok();
1541 }
1542 rx.await.ok();
1543 }
1544 } else if let Some(envelope) =
1545 build_typed_envelope(peer_id, Instant::now(), incoming)
1546 {
1547 let type_name = envelope.payload_type_name();
1548 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1549 &this.message_handlers,
1550 envelope,
1551 this.clone().into(),
1552 cx.clone(),
1553 ) {
1554 log::debug!("{}:ssh message received. name:{type_name}", this.name);
1555 cx.foreground_executor().spawn(async move {
1556 match future.await {
1557 Ok(_) => {
1558 log::debug!("{}:ssh message handled. name:{type_name}", this.name);
1559 }
1560 Err(error) => {
1561 log::error!(
1562 "{}:error handling message. type:{type_name}, error:{error}", this.name,
1563 );
1564 }
1565 }
1566 }).detach()
1567 } else {
1568 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
1569 }
1570 }
1571 }
1572 anyhow::Ok(())
1573 }
1574 })
1575 }
1576
1577 pub fn reconnect(
1578 self: &Arc<Self>,
1579 incoming_rx: UnboundedReceiver<Envelope>,
1580 outgoing_tx: UnboundedSender<Envelope>,
1581 cx: &AsyncAppContext,
1582 ) {
1583 *self.outgoing_tx.lock() = outgoing_tx;
1584 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
1585 }
1586
1587 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1588 let id = (TypeId::of::<E>(), remote_id);
1589
1590 let mut message_handlers = self.message_handlers.lock();
1591 if message_handlers
1592 .entities_by_type_and_remote_id
1593 .contains_key(&id)
1594 {
1595 panic!("already subscribed to entity");
1596 }
1597
1598 message_handlers.entities_by_type_and_remote_id.insert(
1599 id,
1600 EntityMessageSubscriber::Entity {
1601 handle: entity.downgrade().into(),
1602 },
1603 );
1604 }
1605
1606 pub fn request<T: RequestMessage>(
1607 &self,
1608 payload: T,
1609 ) -> impl 'static + Future<Output = Result<T::Response>> {
1610 log::debug!("ssh request start. name:{}", T::NAME);
1611 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1612 async move {
1613 let response = response.await?;
1614 log::debug!("ssh request finish. name:{}", T::NAME);
1615 T::Response::from_envelope(response)
1616 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1617 }
1618 }
1619
1620 pub async fn resync(&self, timeout: Duration) -> Result<()> {
1621 smol::future::or(
1622 async {
1623 self.request(proto::FlushBufferedMessages {}).await?;
1624 for envelope in self.buffer.lock().iter() {
1625 self.outgoing_tx
1626 .lock()
1627 .unbounded_send(envelope.clone())
1628 .ok();
1629 }
1630 Ok(())
1631 },
1632 async {
1633 smol::Timer::after(timeout).await;
1634 Err(anyhow!("Timeout detected"))
1635 },
1636 )
1637 .await
1638 }
1639
1640 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1641 smol::future::or(
1642 async {
1643 self.request(proto::Ping {}).await?;
1644 Ok(())
1645 },
1646 async {
1647 smol::Timer::after(timeout).await;
1648 Err(anyhow!("Timeout detected"))
1649 },
1650 )
1651 .await
1652 }
1653
1654 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1655 log::debug!("ssh send name:{}", T::NAME);
1656 self.send_dynamic(payload.into_envelope(0, None, None))
1657 }
1658
1659 pub fn request_dynamic(
1660 &self,
1661 mut envelope: proto::Envelope,
1662 type_name: &'static str,
1663 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1664 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1665 let (tx, rx) = oneshot::channel();
1666 let mut response_channels_lock = self.response_channels.lock();
1667 response_channels_lock.insert(MessageId(envelope.id), tx);
1668 drop(response_channels_lock);
1669
1670 let result = self.send_buffered(envelope);
1671 async move {
1672 if let Err(error) = &result {
1673 log::error!("failed to send message: {}", error);
1674 return Err(anyhow!("failed to send message: {}", error));
1675 }
1676
1677 let response = rx.await.context("connection lost")?.0;
1678 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1679 return Err(RpcError::from_proto(error, type_name));
1680 }
1681 Ok(response)
1682 }
1683 }
1684
1685 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1686 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1687 self.send_buffered(envelope)
1688 }
1689
1690 pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1691 envelope.ack_id = Some(self.max_received.load(SeqCst));
1692 self.buffer.lock().push_back(envelope.clone());
1693 // ignore errors on send (happen while we're reconnecting)
1694 // assume that the global "disconnected" overlay is sufficient.
1695 self.outgoing_tx.lock().unbounded_send(envelope).ok();
1696 Ok(())
1697 }
1698}
1699
1700impl ProtoClient for ChannelClient {
1701 fn request(
1702 &self,
1703 envelope: proto::Envelope,
1704 request_type: &'static str,
1705 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1706 self.request_dynamic(envelope, request_type).boxed()
1707 }
1708
1709 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1710 self.send_dynamic(envelope)
1711 }
1712
1713 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1714 self.send_dynamic(envelope)
1715 }
1716
1717 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1718 &self.message_handlers
1719 }
1720
1721 fn is_via_collab(&self) -> bool {
1722 false
1723 }
1724}
1725
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
    };

    /// In-process stand-in for an SSH connection, used for the `"<fake>"` host.
    pub(super) struct SshRemoteConnection {
        connection_options: SshConnectionOptions,
    }

    impl SshRemoteConnection {
        /// Returns a fake connection when `connection_options.host` is
        /// `"<fake>"`, and `None` for any other host.
        pub(super) fn new(
            connection_options: &SshConnectionOptions,
        ) -> Option<Box<dyn SshRemoteProcess>> {
            if connection_options.host == "<fake>" {
                Some(Box::new(Self {
                    connection_options: connection_options.clone(),
                }))
            } else {
                None
            }
        }

        /// Connects the client's channels to the fake server registered under
        /// `connection_options.port` in [`ServerConnections`], forwarding
        /// envelopes both ways.
        ///
        /// Each server-to-client envelope also signals
        /// `connection_activity_tx`. The returned task resolves to `Ok(1)` once
        /// either direction's stream ends.
        ///
        /// # Panics
        ///
        /// Panics if `connection_options.port` is `None`, if `cx.update`
        /// fails, or if no fake server is registered for the port.
        pub(super) async fn multiplex(
            connection_options: SshConnectionOptions,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            let (channel, server_cx) = cx
                .update(|cx| {
                    cx.update_global(|conns: &mut ServerConnections, _| {
                        conns.get(connection_options.port.unwrap())
                    })
                })
                .unwrap();
            // Point the fake server at this connection's fresh channel pair.
            channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// There is no real process behind a fake connection, so killing it is a
    /// no-op and it needs no extra `ssh` arguments.
    #[async_trait]
    impl SshRemoteProcess for SshRemoteConnection {
        async fn kill(&mut self) -> Result<()> {
            Ok(())
        }

        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }
    }

    /// Registry of fake servers; a server's index in the list is its fake "port".
    #[derive(Default)]
    pub(super) struct ServerConnections(Vec<(Arc<ChannelClient>, AsyncAppContext)>);
    impl Global for ServerConnections {}

    impl ServerConnections {
        /// Registers `server` and returns the fake port that identifies it.
        ///
        /// # Panics
        ///
        /// Panics if more servers are registered than fit in a `u16` port.
        pub(super) fn push(&mut self, server: Arc<ChannelClient>, cx: AsyncAppContext) -> u16 {
            self.0.push((server, cx));
            u16::try_from(self.0.len() - 1).expect("too many fake servers for u16 ports")
        }

        /// Returns the server registered under `port`.
        ///
        /// # Panics
        ///
        /// Panics if no server was registered under `port`.
        pub(super) fn get(&mut self, port: u16) -> (Arc<ChannelClient>, AsyncAppContext) {
            self.0
                .get(port as usize)
                .expect("no fake server for port")
                .clone()
        }
    }

    /// Delegate for fake connections. Every method panics via `unreachable!`;
    /// presumably the fake connection path never calls them (TODO confirm
    /// against callers).
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }
        fn remote_server_binary_path(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
        fn get_server_binary(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
            unreachable!()
        }
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
            unreachable!()
        }
        fn set_error(&self, _: String, _: &mut AsyncAppContext) {
            unreachable!()
        }
    }
}