1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 channel::{
13 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
14 oneshot,
15 },
16 future::{BoxFuture, Shared},
17 select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
21 ModelContext, SemanticVersion, Task, WeakModel,
22};
23use itertools::Itertools;
24use parking_lot::Mutex;
25use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
26use rpc::{
27 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
28 AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
29 RpcError,
30};
31use smol::{
32 fs,
33 process::{self, Child, Stdio},
34};
35use std::{
36 any::TypeId,
37 collections::VecDeque,
38 fmt, iter,
39 ops::ControlFlow,
40 path::{Path, PathBuf},
41 sync::{
42 atomic::{AtomicU32, Ordering::SeqCst},
43 Arc, Weak,
44 },
45 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
46};
47use tempfile::TempDir;
48use util::ResultExt;
49
/// Newtype around a `u64` id for a project on an SSH remote.
///
/// The derives make it usable as a hash/ordered map key and serializable with serde.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
54
/// Everything needed to run further `ssh` commands over an existing control connection.
///
/// `socket_path` is passed to `ssh` as `-o ControlPath=...` together with
/// `-o ControlMaster=no` (see `ssh_options` / `ssh_args`). Commands therefore reuse
/// the master connection at that path instead of becoming a master themselves.
#[derive(Clone)]
pub struct SshSocket {
    /// Destination (user/host/port) the commands are addressed to.
    connection_options: SshConnectionOptions,
    /// Path of the control-master socket.
    socket_path: PathBuf,
}
60
/// Options describing how to reach an SSH remote.
///
/// `Eq + Hash` is required because these options are the key of the global
/// `ConnectionPool`. Equal options share one underlying connection.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the remote.
    pub host: String,
    /// Login user (`user@host`); `None` lets `ssh` pick its default.
    pub username: Option<String>,
    /// Port; `None` lets `ssh` pick its default.
    pub port: Option<u16>,
    /// Password, if known up front (`parse_command_line` always sets `None`).
    pub password: Option<String>,
    /// Extra `ssh` arguments, e.g. the whitelisted options kept by `parse_command_line`.
    pub args: Option<Vec<String>>,

    /// Optional display name for this connection (presumably user-facing only — confirm).
    pub nickname: Option<String>,
    /// NOTE(review): by name, selects uploading the server binary over SSH rather than
    /// downloading it on the remote — confirm against the upload code.
    pub upload_binary_over_ssh: bool,
}
72
/// Builds a shell script string from a `format!` template, shell-quoting every named
/// argument with `shlex::try_quote` (POSIX `sh` quoting) before interpolation.
///
/// Example: `shell_script!("cd {dir} && ls", dir = some_str)`. Each value is handed to
/// `shlex::try_quote`, so it must be a `&str`.
///
/// # Panics
/// Panics if a value cannot be quoted (`shlex::try_quote` rejects NUL bytes).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
84
85impl SshConnectionOptions {
86 pub fn parse_command_line(input: &str) -> Result<Self> {
87 let input = input.trim_start_matches("ssh ");
88 let mut hostname: Option<String> = None;
89 let mut username: Option<String> = None;
90 let mut port: Option<u16> = None;
91 let mut args = Vec::new();
92
93 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
94 const ALLOWED_OPTS: &[&str] = &[
95 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
96 ];
97 const ALLOWED_ARGS: &[&str] = &[
98 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
99 "-w",
100 ];
101
102 let mut tokens = shlex::split(input)
103 .ok_or_else(|| anyhow!("invalid input"))?
104 .into_iter();
105
106 'outer: while let Some(arg) = tokens.next() {
107 if ALLOWED_OPTS.contains(&(&arg as &str)) {
108 args.push(arg.to_string());
109 continue;
110 }
111 if arg == "-p" {
112 port = tokens.next().and_then(|arg| arg.parse().ok());
113 continue;
114 } else if let Some(p) = arg.strip_prefix("-p") {
115 port = p.parse().ok();
116 continue;
117 }
118 if arg == "-l" {
119 username = tokens.next();
120 continue;
121 } else if let Some(l) = arg.strip_prefix("-l") {
122 username = Some(l.to_string());
123 continue;
124 }
125 for a in ALLOWED_ARGS {
126 if arg == *a {
127 args.push(arg);
128 if let Some(next) = tokens.next() {
129 args.push(next);
130 }
131 continue 'outer;
132 } else if arg.starts_with(a) {
133 args.push(arg);
134 continue 'outer;
135 }
136 }
137 if arg.starts_with("-") || hostname.is_some() {
138 anyhow::bail!("unsupported argument: {:?}", arg);
139 }
140 let mut input = &arg as &str;
141 if let Some((u, rest)) = input.split_once('@') {
142 input = rest;
143 username = Some(u.to_string());
144 }
145 if let Some((rest, p)) = input.split_once(':') {
146 input = rest;
147 port = p.parse().ok()
148 }
149 hostname = Some(input.to_string())
150 }
151
152 let Some(hostname) = hostname else {
153 anyhow::bail!("missing hostname");
154 };
155
156 Ok(Self {
157 host: hostname.to_string(),
158 username: username.clone(),
159 port,
160 args: Some(args),
161 password: None,
162 nickname: None,
163 upload_binary_over_ssh: false,
164 })
165 }
166
167 pub fn ssh_url(&self) -> String {
168 let mut result = String::from("ssh://");
169 if let Some(username) = &self.username {
170 result.push_str(username);
171 result.push('@');
172 }
173 result.push_str(&self.host);
174 if let Some(port) = self.port {
175 result.push(':');
176 result.push_str(&port.to_string());
177 }
178 result
179 }
180
181 pub fn additional_args(&self) -> Option<&Vec<String>> {
182 self.args.as_ref()
183 }
184
185 fn scp_url(&self) -> String {
186 if let Some(username) = &self.username {
187 format!("{}@{}", username, self.host)
188 } else {
189 self.host.clone()
190 }
191 }
192
193 pub fn connection_string(&self) -> String {
194 let host = if let Some(username) = &self.username {
195 format!("{}@{}", username, self.host)
196 } else {
197 self.host.clone()
198 };
199 if let Some(port) = &self.port {
200 format!("{}:{}", host, port)
201 } else {
202 host
203 }
204 }
205
206 // Uniquely identifies dev server projects on a remote host. Needs to be
207 // stable for the same dev server project.
208 pub fn remote_server_identifier(&self) -> String {
209 let mut identifier = format!("dev-server-{:?}", self.host);
210 if let Some(username) = self.username.as_ref() {
211 identifier.push('-');
212 identifier.push_str(&username);
213 }
214 identifier
215 }
216}
217
/// OS and CPU architecture of a remote host.
///
/// `triple()` recognizes `os` values `"linux"` and `"macos"`. `arch` is used verbatim
/// as the first component of the target triple.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
223
224impl SshPlatform {
225 pub fn triple(&self) -> Option<String> {
226 Some(format!(
227 "{}-{}",
228 self.arch,
229 match self.os {
230 "linux" => "unknown-linux-gnu",
231 "macos" => "apple-darwin",
232 _ => return None,
233 }
234 ))
235 }
236}
237
/// Where the remote server binary comes from.
pub enum ServerBinary {
    /// A binary already present on the local machine at this path.
    LocalBinary(PathBuf),
    /// A release to fetch from `url`. NOTE(review): `body` is presumably the request
    /// body sent with the download request — confirm against the download code.
    ReleaseUrl { url: String, body: String },
}
242
/// Version of a server binary: either a semantic version or a build identified by a
/// commit string (presumably a git SHA — confirm).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServerVersion {
    Semantic(SemanticVersion),
    Commit(String),
}
248impl ServerVersion {
249 pub fn semantic_version(&self) -> Option<SemanticVersion> {
250 match self {
251 Self::Semantic(version) => Some(*version),
252 _ => None,
253 }
254 }
255}
256
257impl std::fmt::Display for ServerVersion {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 match self {
260 Self::Semantic(version) => write!(f, "{}", version),
261 Self::Commit(commit) => write!(f, "{}", commit),
262 }
263 }
264}
265
/// Callbacks through which the SSH client talks to the embedding application.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password/secret for `prompt` (presumably by asking the user — confirm).
    /// The answer arrives on the returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path of the server binary on a remote with the given `platform`.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Resolves download parameters for the server binary. NOTE(review): the pair is
    /// presumably `(url, body)` as in `ServerBinary::ReleaseUrl` — confirm.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<(String, String)>>;

    /// Downloads the server binary for `platform` to the local machine and resolves to
    /// its local path. (`version: None` presumably means "latest" — confirm.)
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<PathBuf>>;
    /// Reports a human-readable progress message, e.g.
    /// "Waiting for existing connection attempt".
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
}
294
295impl SshSocket {
296 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
297 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
298 // and passes -l as an argument to sh, not to ls.
299 // You need to do it like this: $ ssh host "sh -c 'ls -l /tmp'"
300 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
301 let mut command = process::Command::new("ssh");
302 let to_run = iter::once(&program)
303 .chain(args.iter())
304 .map(|token| shlex::try_quote(token).unwrap())
305 .join(" ");
306 self.ssh_options(&mut command)
307 .arg(self.connection_options.ssh_url())
308 .arg(to_run);
309 command
310 }
311
312 fn shell_script(&self, script: impl AsRef<str>) -> process::Command {
313 return self.ssh_command("sh", &["-c", script.as_ref()]);
314 }
315
316 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
317 command
318 .stdin(Stdio::piped())
319 .stdout(Stdio::piped())
320 .stderr(Stdio::piped())
321 .args(["-o", "ControlMaster=no", "-o"])
322 .arg(format!("ControlPath={}", self.socket_path.display()))
323 }
324
325 fn ssh_args(&self) -> Vec<String> {
326 vec![
327 "-o".to_string(),
328 "ControlMaster=no".to_string(),
329 "-o".to_string(),
330 format!("ControlPath={}", self.socket_path.display()),
331 self.connection_options.ssh_url(),
332 ]
333 }
334}
335
336async fn run_cmd(mut command: process::Command) -> Result<String> {
337 let output = command.output().await?;
338 if output.status.success() {
339 Ok(String::from_utf8_lossy(&output.stdout).to_string())
340 } else {
341 Err(anyhow!(
342 "failed to run command: {}",
343 String::from_utf8_lossy(&output.stderr)
344 ))
345 }
346}
347
/// Consecutive failed heartbeats after which the client starts reconnecting.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time (no connection activity) after which a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a single ping: heartbeats, the initial connection check, and resync.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts allowed before the state becomes `ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
353
/// Internal connection state machine of `SshRemoteClient`.
///
/// Usual flow: `Connecting -> Connected <-> HeartbeatMissed`. A reconnect goes
/// `-> Reconnecting -> Connected | ReconnectFailed`, and `ReconnectFailed` retries
/// until `ReconnectExhausted`. `ServerNotRunning` is entered when the proxy exits with
/// `ProxyLaunchError::ServerNotRunning`.
enum State {
    /// `SshRemoteClient::new` is still establishing the first connection.
    Connecting,
    /// Healthy connection. The tasks are owned here so that dropping the state stops
    /// them (the multiplex task owns the proxy process).
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but the last `missed_heartbeats` heartbeats went unanswered.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight; the previous tasks have been dropped.
    Reconnecting,
    /// The last reconnect attempt failed with `error`. `attempts` counts attempts so far,
    /// and the connection and delegate are kept for the next attempt.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Gave up after `MAX_RECONNECT_ATTEMPTS`; `set_state` emits `Disconnected`.
    ReconnectExhausted,
    /// The proxy reported that the server is not running; `set_state` emits `Disconnected`.
    ServerNotRunning,
}
383
384impl fmt::Display for State {
385 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 match self {
387 Self::Connecting => write!(f, "connecting"),
388 Self::Connected { .. } => write!(f, "connected"),
389 Self::Reconnecting => write!(f, "reconnecting"),
390 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
391 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
392 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
393 Self::ServerNotRunning { .. } => write!(f, "server not running"),
394 }
395 }
396}
397
398impl State {
399 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
400 match self {
401 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
402 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
403 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
404 _ => None,
405 }
406 }
407
408 fn can_reconnect(&self) -> bool {
409 match self {
410 Self::Connected { .. }
411 | Self::HeartbeatMissed { .. }
412 | Self::ReconnectFailed { .. } => true,
413 State::Connecting
414 | State::Reconnecting
415 | State::ReconnectExhausted
416 | State::ServerNotRunning => false,
417 }
418 }
419
420 fn is_reconnect_failed(&self) -> bool {
421 matches!(self, Self::ReconnectFailed { .. })
422 }
423
424 fn is_reconnect_exhausted(&self) -> bool {
425 matches!(self, Self::ReconnectExhausted { .. })
426 }
427
428 fn is_server_not_running(&self) -> bool {
429 matches!(self, Self::ServerNotRunning)
430 }
431
432 fn is_reconnecting(&self) -> bool {
433 matches!(self, Self::Reconnecting { .. })
434 }
435
436 fn heartbeat_recovered(self) -> Self {
437 match self {
438 Self::HeartbeatMissed {
439 ssh_connection,
440 delegate,
441 multiplex_task,
442 heartbeat_task,
443 ..
444 } => Self::Connected {
445 ssh_connection,
446 delegate,
447 multiplex_task,
448 heartbeat_task,
449 },
450 _ => self,
451 }
452 }
453
454 fn heartbeat_missed(self) -> Self {
455 match self {
456 Self::Connected {
457 ssh_connection,
458 delegate,
459 multiplex_task,
460 heartbeat_task,
461 } => Self::HeartbeatMissed {
462 missed_heartbeats: 1,
463 ssh_connection,
464 delegate,
465 multiplex_task,
466 heartbeat_task,
467 },
468 Self::HeartbeatMissed {
469 missed_heartbeats,
470 ssh_connection,
471 delegate,
472 multiplex_task,
473 heartbeat_task,
474 } => Self::HeartbeatMissed {
475 missed_heartbeats: missed_heartbeats + 1,
476 ssh_connection,
477 delegate,
478 multiplex_task,
479 heartbeat_task,
480 },
481 _ => self,
482 }
483 }
484}
485
/// The state of the ssh connection.
///
/// A public, coarse-grained view of the internal `State`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    /// A reconnect is in flight or the last attempt failed and another will follow.
    Reconnecting,
    /// Reconnects are exhausted, the server is not running, or no state is set.
    Disconnected,
}
495
496impl From<&State> for ConnectionState {
497 fn from(value: &State) -> Self {
498 match value {
499 State::Connecting => Self::Connecting,
500 State::Connected { .. } => Self::Connected,
501 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
502 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
503 State::ReconnectExhausted => Self::Disconnected,
504 State::ServerNotRunning => Self::Disconnected,
505 }
506 }
507}
508
/// Client side of an SSH remote session: a message channel to the remote server plus
/// the connection state machine that keeps it alive.
pub struct SshRemoteClient {
    /// Message channel to the server. It is kept across reconnects and re-pointed via
    /// `ChannelClient::reconnect`.
    client: Arc<ChannelClient>,
    /// Identifier passed to the remote proxy on every (re)connect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Current `State`; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
515
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted by `set_state` on entering `ReconnectExhausted` or `ServerNotRunning`.
    Disconnected,
}

// Lets models subscribe to `SshRemoteEvent`s from an `SshRemoteClient` model.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
522
523impl SshRemoteClient {
524 pub fn new(
525 unique_identifier: String,
526 connection_options: SshConnectionOptions,
527 cancellation: oneshot::Receiver<()>,
528 delegate: Arc<dyn SshClientDelegate>,
529 cx: &mut AppContext,
530 ) -> Task<Result<Option<Model<Self>>>> {
531 cx.spawn(|mut cx| async move {
532 let success = Box::pin(async move {
533 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
534 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
535 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
536
537 let client =
538 cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
539 let this = cx.new_model(|_| Self {
540 client: client.clone(),
541 unique_identifier: unique_identifier.clone(),
542 connection_options: connection_options.clone(),
543 state: Arc::new(Mutex::new(Some(State::Connecting))),
544 })?;
545
546 let ssh_connection = cx
547 .update(|cx| {
548 cx.update_default_global(|pool: &mut ConnectionPool, cx| {
549 pool.connect(connection_options, &delegate, cx)
550 })
551 })?
552 .await
553 .map_err(|e| e.cloned())?;
554 let remote_binary_path = ssh_connection
555 .get_remote_binary_path(&delegate, false, &mut cx)
556 .await?;
557
558 let io_task = ssh_connection.start_proxy(
559 remote_binary_path,
560 unique_identifier,
561 false,
562 incoming_tx,
563 outgoing_rx,
564 connection_activity_tx,
565 delegate.clone(),
566 &mut cx,
567 );
568
569 let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
570
571 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
572 log::error!("failed to establish connection: {}", error);
573 return Err(error);
574 }
575
576 let heartbeat_task =
577 Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
578
579 this.update(&mut cx, |this, _| {
580 *this.state.lock() = Some(State::Connected {
581 ssh_connection,
582 delegate,
583 multiplex_task,
584 heartbeat_task,
585 });
586 })?;
587
588 Ok(Some(this))
589 });
590
591 select! {
592 _ = cancellation.fuse() => {
593 Ok(None)
594 }
595 result = success.fuse() => result
596 }
597 })
598 }
599
600 pub fn shutdown_processes<T: RequestMessage>(
601 &self,
602 shutdown_request: Option<T>,
603 ) -> Option<impl Future<Output = ()>> {
604 let state = self.state.lock().take()?;
605 log::info!("shutting down ssh processes");
606
607 let State::Connected {
608 multiplex_task,
609 heartbeat_task,
610 ssh_connection,
611 delegate,
612 } = state
613 else {
614 return None;
615 };
616
617 let client = self.client.clone();
618
619 Some(async move {
620 if let Some(shutdown_request) = shutdown_request {
621 client.send(shutdown_request).log_err();
622 // We wait 50ms instead of waiting for a response, because
623 // waiting for a response would require us to wait on the main thread
624 // which we want to avoid in an `on_app_quit` callback.
625 smol::Timer::after(Duration::from_millis(50)).await;
626 }
627
628 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
629 // child of master_process.
630 drop(multiplex_task);
631 // Now drop the rest of state, which kills master process.
632 drop(heartbeat_task);
633 drop(ssh_connection);
634 drop(delegate);
635 })
636 }
637
638 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
639 let mut lock = self.state.lock();
640
641 let can_reconnect = lock
642 .as_ref()
643 .map(|state| state.can_reconnect())
644 .unwrap_or(false);
645 if !can_reconnect {
646 let error = if let Some(state) = lock.as_ref() {
647 format!("invalid state, cannot reconnect while in state {state}")
648 } else {
649 "no state set".to_string()
650 };
651 log::info!("aborting reconnect, because not in state that allows reconnecting");
652 return Err(anyhow!(error));
653 }
654
655 let state = lock.take().unwrap();
656 let (attempts, ssh_connection, delegate) = match state {
657 State::Connected {
658 ssh_connection,
659 delegate,
660 multiplex_task,
661 heartbeat_task,
662 }
663 | State::HeartbeatMissed {
664 ssh_connection,
665 delegate,
666 multiplex_task,
667 heartbeat_task,
668 ..
669 } => {
670 drop(multiplex_task);
671 drop(heartbeat_task);
672 (0, ssh_connection, delegate)
673 }
674 State::ReconnectFailed {
675 attempts,
676 ssh_connection,
677 delegate,
678 ..
679 } => (attempts, ssh_connection, delegate),
680 State::Connecting
681 | State::Reconnecting
682 | State::ReconnectExhausted
683 | State::ServerNotRunning => unreachable!(),
684 };
685
686 let attempts = attempts + 1;
687 if attempts > MAX_RECONNECT_ATTEMPTS {
688 log::error!(
689 "Failed to reconnect to after {} attempts, giving up",
690 MAX_RECONNECT_ATTEMPTS
691 );
692 drop(lock);
693 self.set_state(State::ReconnectExhausted, cx);
694 return Ok(());
695 }
696 drop(lock);
697
698 self.set_state(State::Reconnecting, cx);
699
700 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
701
702 let unique_identifier = self.unique_identifier.clone();
703 let client = self.client.clone();
704 let reconnect_task = cx.spawn(|this, mut cx| async move {
705 macro_rules! failed {
706 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
707 return State::ReconnectFailed {
708 error: anyhow!($error),
709 attempts: $attempts,
710 ssh_connection: $ssh_connection,
711 delegate: $delegate,
712 };
713 };
714 }
715
716 if let Err(error) = ssh_connection
717 .kill()
718 .await
719 .context("Failed to kill ssh process")
720 {
721 failed!(error, attempts, ssh_connection, delegate);
722 };
723
724 let connection_options = ssh_connection.connection_options();
725
726 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
727 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
728 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
729
730 let (ssh_connection, io_task) = match async {
731 let ssh_connection = cx
732 .update_global(|pool: &mut ConnectionPool, cx| {
733 pool.connect(connection_options, &delegate, cx)
734 })?
735 .await
736 .map_err(|error| error.cloned())?;
737
738 let remote_binary_path = ssh_connection
739 .get_remote_binary_path(&delegate, true, &mut cx)
740 .await?;
741
742 let io_task = ssh_connection.start_proxy(
743 remote_binary_path,
744 unique_identifier,
745 true,
746 incoming_tx,
747 outgoing_rx,
748 connection_activity_tx,
749 delegate.clone(),
750 &mut cx,
751 );
752 anyhow::Ok((ssh_connection, io_task))
753 }
754 .await
755 {
756 Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
757 Err(error) => {
758 failed!(error, attempts, ssh_connection, delegate);
759 }
760 };
761
762 let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
763 client.reconnect(incoming_rx, outgoing_tx, &cx);
764
765 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
766 failed!(error, attempts, ssh_connection, delegate);
767 };
768
769 State::Connected {
770 ssh_connection,
771 delegate,
772 multiplex_task,
773 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
774 }
775 });
776
777 cx.spawn(|this, mut cx| async move {
778 let new_state = reconnect_task.await;
779 this.update(&mut cx, |this, cx| {
780 this.try_set_state(cx, |old_state| {
781 if old_state.is_reconnecting() {
782 match &new_state {
783 State::Connecting
784 | State::Reconnecting { .. }
785 | State::HeartbeatMissed { .. }
786 | State::ServerNotRunning => {}
787 State::Connected { .. } => {
788 log::info!("Successfully reconnected");
789 }
790 State::ReconnectFailed {
791 error, attempts, ..
792 } => {
793 log::error!(
794 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
795 attempts,
796 error
797 );
798 }
799 State::ReconnectExhausted => {
800 log::error!("Reconnect attempt failed and all attempts exhausted");
801 }
802 }
803 Some(new_state)
804 } else {
805 None
806 }
807 });
808
809 if this.state_is(State::is_reconnect_failed) {
810 this.reconnect(cx)
811 } else if this.state_is(State::is_reconnect_exhausted) {
812 Ok(())
813 } else {
814 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
815 Ok(())
816 }
817 })
818 })
819 .detach_and_log_err(cx);
820
821 Ok(())
822 }
823
824 fn heartbeat(
825 this: WeakModel<Self>,
826 mut connection_activity_rx: mpsc::Receiver<()>,
827 cx: &mut AsyncAppContext,
828 ) -> Task<Result<()>> {
829 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
830 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
831 };
832
833 cx.spawn(|mut cx| {
834 let this = this.clone();
835 async move {
836 let mut missed_heartbeats = 0;
837
838 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
839 futures::pin_mut!(keepalive_timer);
840
841 loop {
842 select_biased! {
843 result = connection_activity_rx.next().fuse() => {
844 if result.is_none() {
845 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
846 return Ok(());
847 }
848
849 if missed_heartbeats != 0 {
850 missed_heartbeats = 0;
851 this.update(&mut cx, |this, mut cx| {
852 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
853 })?;
854 }
855 }
856 _ = keepalive_timer => {
857 log::debug!("Sending heartbeat to server...");
858
859 let result = select_biased! {
860 _ = connection_activity_rx.next().fuse() => {
861 Ok(())
862 }
863 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
864 ping_result
865 }
866 };
867
868 if result.is_err() {
869 missed_heartbeats += 1;
870 log::warn!(
871 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
872 HEARTBEAT_TIMEOUT,
873 missed_heartbeats,
874 MAX_MISSED_HEARTBEATS
875 );
876 } else if missed_heartbeats != 0 {
877 missed_heartbeats = 0;
878 } else {
879 continue;
880 }
881
882 let result = this.update(&mut cx, |this, mut cx| {
883 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
884 })?;
885 if result.is_break() {
886 return Ok(());
887 }
888 }
889 }
890
891 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
892 }
893 }
894 })
895 }
896
897 fn handle_heartbeat_result(
898 &mut self,
899 missed_heartbeats: usize,
900 cx: &mut ModelContext<Self>,
901 ) -> ControlFlow<()> {
902 let state = self.state.lock().take().unwrap();
903 let next_state = if missed_heartbeats > 0 {
904 state.heartbeat_missed()
905 } else {
906 state.heartbeat_recovered()
907 };
908
909 self.set_state(next_state, cx);
910
911 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
912 log::error!(
913 "Missed last {} heartbeats. Reconnecting...",
914 missed_heartbeats
915 );
916
917 self.reconnect(cx)
918 .context("failed to start reconnect process after missing heartbeats")
919 .log_err();
920 ControlFlow::Break(())
921 } else {
922 ControlFlow::Continue(())
923 }
924 }
925
926 fn monitor(
927 this: WeakModel<Self>,
928 io_task: Task<Result<i32>>,
929 cx: &AsyncAppContext,
930 ) -> Task<Result<()>> {
931 cx.spawn(|mut cx| async move {
932 let result = io_task.await;
933
934 match result {
935 Ok(exit_code) => {
936 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
937 match error {
938 ProxyLaunchError::ServerNotRunning => {
939 log::error!("failed to reconnect because server is not running");
940 this.update(&mut cx, |this, cx| {
941 this.set_state(State::ServerNotRunning, cx);
942 })?;
943 }
944 }
945 } else if exit_code > 0 {
946 log::error!("proxy process terminated unexpectedly");
947 this.update(&mut cx, |this, cx| {
948 this.reconnect(cx).ok();
949 })?;
950 }
951 }
952 Err(error) => {
953 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
954 this.update(&mut cx, |this, cx| {
955 this.reconnect(cx).ok();
956 })?;
957 }
958 }
959
960 Ok(())
961 })
962 }
963
964 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
965 self.state.lock().as_ref().map_or(false, check)
966 }
967
968 fn try_set_state(
969 &self,
970 cx: &mut ModelContext<Self>,
971 map: impl FnOnce(&State) -> Option<State>,
972 ) {
973 let mut lock = self.state.lock();
974 let new_state = lock.as_ref().and_then(map);
975
976 if let Some(new_state) = new_state {
977 lock.replace(new_state);
978 cx.notify();
979 }
980 }
981
982 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
983 log::info!("setting state to '{}'", &state);
984
985 let is_reconnect_exhausted = state.is_reconnect_exhausted();
986 let is_server_not_running = state.is_server_not_running();
987 self.state.lock().replace(state);
988
989 if is_reconnect_exhausted || is_server_not_running {
990 cx.emit(SshRemoteEvent::Disconnected);
991 }
992 cx.notify();
993 }
994
995 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
996 self.client.subscribe_to_entity(remote_id, entity);
997 }
998
999 pub fn ssh_args(&self) -> Option<Vec<String>> {
1000 self.state
1001 .lock()
1002 .as_ref()
1003 .and_then(|state| state.ssh_connection())
1004 .map(|ssh_connection| ssh_connection.ssh_args())
1005 }
1006
1007 pub fn proto_client(&self) -> AnyProtoClient {
1008 self.client.clone().into()
1009 }
1010
1011 pub fn connection_string(&self) -> String {
1012 self.connection_options.connection_string()
1013 }
1014
1015 pub fn connection_options(&self) -> SshConnectionOptions {
1016 self.connection_options.clone()
1017 }
1018
1019 pub fn connection_state(&self) -> ConnectionState {
1020 self.state
1021 .lock()
1022 .as_ref()
1023 .map(ConnectionState::from)
1024 .unwrap_or(ConnectionState::Disconnected)
1025 }
1026
1027 pub fn is_disconnected(&self) -> bool {
1028 self.connection_state() == ConnectionState::Disconnected
1029 }
1030
1031 #[cfg(any(test, feature = "test-support"))]
1032 pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
1033 let opts = self.connection_options();
1034 client_cx.spawn(|cx| async move {
1035 let connection = cx
1036 .update_global(|c: &mut ConnectionPool, _| {
1037 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1038 c.clone()
1039 } else {
1040 panic!("missing test connection")
1041 }
1042 })
1043 .unwrap()
1044 .await
1045 .unwrap();
1046
1047 connection.simulate_disconnect(&cx);
1048 })
1049 }
1050
1051 #[cfg(any(test, feature = "test-support"))]
1052 pub fn fake_server(
1053 client_cx: &mut gpui::TestAppContext,
1054 server_cx: &mut gpui::TestAppContext,
1055 ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1056 let port = client_cx
1057 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1058 let opts = SshConnectionOptions {
1059 host: "<fake>".to_string(),
1060 port: Some(port),
1061 ..Default::default()
1062 };
1063 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1064 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1065 let server_client =
1066 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1067 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1068 connection_options: opts.clone(),
1069 server_cx: fake::SendableCx::new(server_cx),
1070 server_channel: server_client.clone(),
1071 });
1072
1073 client_cx.update(|cx| {
1074 cx.update_default_global(|c: &mut ConnectionPool, cx| {
1075 c.connections.insert(
1076 opts.clone(),
1077 ConnectionPoolEntry::Connecting(
1078 cx.foreground_executor()
1079 .spawn({
1080 let connection = connection.clone();
1081 async move { Ok(connection.clone()) }
1082 })
1083 .shared(),
1084 ),
1085 );
1086 })
1087 });
1088
1089 (opts, server_client)
1090 }
1091
1092 #[cfg(any(test, feature = "test-support"))]
1093 pub async fn fake_client(
1094 opts: SshConnectionOptions,
1095 client_cx: &mut gpui::TestAppContext,
1096 ) -> Model<Self> {
1097 let (_tx, rx) = oneshot::channel();
1098 client_cx
1099 .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx))
1100 .await
1101 .unwrap()
1102 .unwrap()
1103 }
1104}
1105
/// A `ConnectionPool` slot for one set of connection options.
enum ConnectionPoolEntry {
    /// A connection attempt in flight, shared by every caller that asks meanwhile.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection. It is held weakly, so the pool does not keep it
    /// alive; dead or killed entries are dropped on the next `connect`.
    Connected(Weak<dyn RemoteConnection>),
}
1110
/// App-global registry that deduplicates SSH connections by their options.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Registered as a gpui global so it can be reached through
// `update_global` / `update_default_global`.
impl Global for ConnectionPool {}
1117
1118impl ConnectionPool {
1119 pub fn connect(
1120 &mut self,
1121 opts: SshConnectionOptions,
1122 delegate: &Arc<dyn SshClientDelegate>,
1123 cx: &mut AppContext,
1124 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1125 let connection = self.connections.get(&opts);
1126 match connection {
1127 Some(ConnectionPoolEntry::Connecting(task)) => {
1128 let delegate = delegate.clone();
1129 cx.spawn(|mut cx| async move {
1130 delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
1131 })
1132 .detach();
1133 return task.clone();
1134 }
1135 Some(ConnectionPoolEntry::Connected(ssh)) => {
1136 if let Some(ssh) = ssh.upgrade() {
1137 if !ssh.has_been_killed() {
1138 return Task::ready(Ok(ssh)).shared();
1139 }
1140 }
1141 self.connections.remove(&opts);
1142 }
1143 None => {}
1144 }
1145
1146 let task = cx
1147 .spawn({
1148 let opts = opts.clone();
1149 let delegate = delegate.clone();
1150 |mut cx| async move {
1151 let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
1152 .await
1153 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1154
1155 cx.update_global(|pool: &mut Self, _| {
1156 debug_assert!(matches!(
1157 pool.connections.get(&opts),
1158 Some(ConnectionPoolEntry::Connecting(_))
1159 ));
1160 match connection {
1161 Ok(connection) => {
1162 pool.connections.insert(
1163 opts.clone(),
1164 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1165 );
1166 Ok(connection)
1167 }
1168 Err(error) => {
1169 pool.connections.remove(&opts);
1170 Err(Arc::new(error))
1171 }
1172 }
1173 })?
1174 }
1175 })
1176 .shared();
1177
1178 self.connections
1179 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1180 task
1181 }
1182}
1183
1184impl From<SshRemoteClient> for AnyProtoClient {
1185 fn from(client: SshRemoteClient) -> Self {
1186 AnyProtoClient::new(client.client.clone())
1187 }
1188}
1189
/// Transport-level connection to a remote host that can run the server proxy.
///
/// Implemented by `SshRemoteConnection`; tests use a fake.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy at `remote_binary_path` for `unique_identifier`. It should
    /// forward outgoing envelopes from `outgoing_rx`, deliver incoming ones to
    /// `incoming_tx`, and signal traffic on `connection_activity_tx`.
    ///
    /// The task resolves to the proxy's exit code, which `SshRemoteClient::monitor`
    /// interprets via `ProxyLaunchError::from_exit_code`.
    #[allow(clippy::too_many_arguments)]
    fn start_proxy(
        &self,
        remote_binary_path: PathBuf,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncAppContext,
    ) -> Task<Result<i32>>;
    /// Resolves the server binary path on the remote. When `reconnect` is false, the
    /// SSH implementation also makes sure the binary is present.
    async fn get_remote_binary_path(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        reconnect: bool,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Terminates the underlying connection (for SSH: kills and reaps the master process).
    async fn kill(&self) -> Result<()>;
    /// Whether `kill` has already run; `ConnectionPool` will not reuse such a connection.
    fn has_been_killed(&self) -> bool;
    /// Arguments for running further `ssh` commands over this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was opened with.
    fn connection_options(&self) -> SshConnectionOptions;

    /// Test hook for simulating a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncAppContext) {}
}
1218
1219struct SshRemoteConnection {
1220 socket: SshSocket,
1221 master_process: Mutex<Option<process::Child>>,
1222 platform: SshPlatform,
1223 _temp_dir: TempDir,
1224}
1225
1226#[async_trait(?Send)]
1227impl RemoteConnection for SshRemoteConnection {
1228 async fn kill(&self) -> Result<()> {
1229 let Some(mut process) = self.master_process.lock().take() else {
1230 return Ok(());
1231 };
1232 process.kill().ok();
1233 process.status().await?;
1234 Ok(())
1235 }
1236
1237 fn has_been_killed(&self) -> bool {
1238 self.master_process.lock().is_none()
1239 }
1240
1241 fn ssh_args(&self) -> Vec<String> {
1242 self.socket.ssh_args()
1243 }
1244
1245 fn connection_options(&self) -> SshConnectionOptions {
1246 self.socket.connection_options.clone()
1247 }
1248
1249 async fn get_remote_binary_path(
1250 &self,
1251 delegate: &Arc<dyn SshClientDelegate>,
1252 reconnect: bool,
1253 cx: &mut AsyncAppContext,
1254 ) -> Result<PathBuf> {
1255 let platform = self.platform;
1256 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1257 if !reconnect {
1258 self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1259 .await?;
1260 }
1261
1262 let socket = self.socket.clone();
1263 run_cmd(socket.ssh_command(&remote_binary_path.to_string_lossy(), &["version"])).await?;
1264 Ok(remote_binary_path)
1265 }
1266
1267 fn start_proxy(
1268 &self,
1269 remote_binary_path: PathBuf,
1270 unique_identifier: String,
1271 reconnect: bool,
1272 incoming_tx: UnboundedSender<Envelope>,
1273 outgoing_rx: UnboundedReceiver<Envelope>,
1274 connection_activity_tx: Sender<()>,
1275 delegate: Arc<dyn SshClientDelegate>,
1276 cx: &mut AsyncAppContext,
1277 ) -> Task<Result<i32>> {
1278 delegate.set_status(Some("Starting proxy"), cx);
1279
1280 let mut start_proxy_command = shell_script!(
1281 "exec {binary_path} proxy --identifier {identifier}",
1282 binary_path = &remote_binary_path.to_string_lossy(),
1283 identifier = &unique_identifier,
1284 );
1285
1286 if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1287 start_proxy_command = format!(
1288 "RUST_LOG={} {}",
1289 shlex::try_quote(&rust_log).unwrap(),
1290 start_proxy_command
1291 )
1292 }
1293 if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1294 start_proxy_command = format!(
1295 "RUST_BACKTRACE={} {}",
1296 shlex::try_quote(&rust_backtrace).unwrap(),
1297 start_proxy_command
1298 )
1299 }
1300 if reconnect {
1301 start_proxy_command.push_str(" --reconnect");
1302 }
1303
1304 let ssh_proxy_process = match self
1305 .socket
1306 .shell_script(start_proxy_command)
1307 // IMPORTANT: we kill this process when we drop the task that uses it.
1308 .kill_on_drop(true)
1309 .spawn()
1310 {
1311 Ok(process) => process,
1312 Err(error) => {
1313 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
1314 }
1315 };
1316
1317 Self::multiplex(
1318 ssh_proxy_process,
1319 incoming_tx,
1320 outgoing_rx,
1321 connection_activity_tx,
1322 &cx,
1323 )
1324 }
1325}
1326
1327impl SshRemoteConnection {
#[cfg(not(unix))]
/// SSH connections rely on Unix domain sockets; on other platforms this always fails.
async fn new(
    _connection_options: SshConnectionOptions,
    _delegate: Arc<dyn SshClientDelegate>,
    _cx: &mut AsyncAppContext,
) -> Result<Self> {
    anyhow::bail!("ssh is not supported on this platform")
}
1336
1337 #[cfg(unix)]
1338 async fn new(
1339 connection_options: SshConnectionOptions,
1340 delegate: Arc<dyn SshClientDelegate>,
1341 cx: &mut AsyncAppContext,
1342 ) -> Result<Self> {
1343 use futures::AsyncWriteExt as _;
1344 use futures::{io::BufReader, AsyncBufReadExt as _};
1345 use smol::net::unix::UnixStream;
1346 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1347 use util::ResultExt as _;
1348
1349 delegate.set_status(Some("Connecting"), cx);
1350
1351 let url = connection_options.ssh_url();
1352 let temp_dir = tempfile::Builder::new()
1353 .prefix("zed-ssh-session")
1354 .tempdir()?;
1355
1356 // Create a domain socket listener to handle requests from the askpass program.
1357 let askpass_socket = temp_dir.path().join("askpass.sock");
1358 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1359 let listener =
1360 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1361
1362 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<UnixStream>();
1363 let mut kill_tx = Some(askpass_kill_master_tx);
1364
1365 let askpass_task = cx.spawn({
1366 let delegate = delegate.clone();
1367 |mut cx| async move {
1368 let mut askpass_opened_tx = Some(askpass_opened_tx);
1369
1370 while let Ok((mut stream, _)) = listener.accept().await {
1371 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1372 askpass_opened_tx.send(()).ok();
1373 }
1374 let mut buffer = Vec::new();
1375 let mut reader = BufReader::new(&mut stream);
1376 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1377 buffer.clear();
1378 }
1379 let password_prompt = String::from_utf8_lossy(&buffer);
1380 if let Some(password) = delegate
1381 .ask_password(password_prompt.to_string(), &mut cx)
1382 .await
1383 .context("failed to get ssh password")
1384 .and_then(|p| p)
1385 .log_err()
1386 {
1387 stream.write_all(password.as_bytes()).await.log_err();
1388 } else {
1389 if let Some(kill_tx) = kill_tx.take() {
1390 kill_tx.send(stream).log_err();
1391 break;
1392 }
1393 }
1394 }
1395 }
1396 });
1397
1398 // Create an askpass script that communicates back to this process.
1399 let askpass_script = format!(
1400 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1401 askpass_socket = askpass_socket.display(),
1402 print_args = "printf '%s\\0' \"$@\"",
1403 shebang = "#!/bin/sh",
1404 );
1405 let askpass_script_path = temp_dir.path().join("askpass.sh");
1406 fs::write(&askpass_script_path, askpass_script).await?;
1407 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1408
1409 // Start the master SSH process, which does not do anything except for establish
1410 // the connection and keep it open, allowing other ssh commands to reuse it
1411 // via a control socket.
1412 let socket_path = temp_dir.path().join("ssh.sock");
1413
1414 let mut master_process = process::Command::new("ssh")
1415 .stdin(Stdio::null())
1416 .stdout(Stdio::piped())
1417 .stderr(Stdio::piped())
1418 .env("SSH_ASKPASS_REQUIRE", "force")
1419 .env("SSH_ASKPASS", &askpass_script_path)
1420 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1421 .args([
1422 "-N",
1423 "-o",
1424 "ControlPersist=no",
1425 "-o",
1426 "ControlMaster=yes",
1427 "-o",
1428 ])
1429 .arg(format!("ControlPath={}", socket_path.display()))
1430 .arg(&url)
1431 .kill_on_drop(true)
1432 .spawn()?;
1433
1434 // Wait for this ssh process to close its stdout, indicating that authentication
1435 // has completed.
1436 let mut stdout = master_process.stdout.take().unwrap();
1437 let mut output = Vec::new();
1438 let connection_timeout = Duration::from_secs(10);
1439
1440 let result = select_biased! {
1441 _ = askpass_opened_rx.fuse() => {
1442 select_biased! {
1443 stream = askpass_kill_master_rx.fuse() => {
1444 master_process.kill().ok();
1445 drop(stream);
1446 Err(anyhow!("SSH connection canceled"))
1447 }
1448 // If the askpass script has opened, that means the user is typing
1449 // their password, in which case we don't want to timeout anymore,
1450 // since we know a connection has been established.
1451 result = stdout.read_to_end(&mut output).fuse() => {
1452 result?;
1453 Ok(())
1454 }
1455 }
1456 }
1457 _ = stdout.read_to_end(&mut output).fuse() => {
1458 Ok(())
1459 }
1460 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1461 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1462 }
1463 };
1464
1465 if let Err(e) = result {
1466 return Err(e.context("Failed to connect to host"));
1467 }
1468
1469 drop(askpass_task);
1470
1471 if master_process.try_status()?.is_some() {
1472 output.clear();
1473 let mut stderr = master_process.stderr.take().unwrap();
1474 stderr.read_to_end(&mut output).await?;
1475
1476 let error_message = format!(
1477 "failed to connect: {}",
1478 String::from_utf8_lossy(&output).trim()
1479 );
1480 Err(anyhow!(error_message))?;
1481 }
1482
1483 let socket = SshSocket {
1484 connection_options,
1485 socket_path,
1486 };
1487
1488 let os = run_cmd(socket.ssh_command("uname", &["-s"])).await?;
1489 let arch = run_cmd(socket.ssh_command("uname", &["-m"])).await?;
1490
1491 let os = match os.trim() {
1492 "Darwin" => "macos",
1493 "Linux" => "linux",
1494 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1495 };
1496 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1497 "aarch64"
1498 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1499 "x86_64"
1500 } else {
1501 Err(anyhow!("unknown uname architecture {arch:?}"))?
1502 };
1503
1504 let platform = SshPlatform { os, arch };
1505
1506 Ok(Self {
1507 socket,
1508 master_process: Mutex::new(Some(master_process)),
1509 platform,
1510 _temp_dir: temp_dir,
1511 })
1512 }
1513
1514 fn multiplex(
1515 mut ssh_proxy_process: Child,
1516 incoming_tx: UnboundedSender<Envelope>,
1517 mut outgoing_rx: UnboundedReceiver<Envelope>,
1518 mut connection_activity_tx: Sender<()>,
1519 cx: &AsyncAppContext,
1520 ) -> Task<Result<i32>> {
1521 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1522 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1523 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1524
1525 let mut stdin_buffer = Vec::new();
1526 let mut stdout_buffer = Vec::new();
1527 let mut stderr_buffer = Vec::new();
1528 let mut stderr_offset = 0;
1529
1530 let stdin_task = cx.background_executor().spawn(async move {
1531 while let Some(outgoing) = outgoing_rx.next().await {
1532 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1533 }
1534 anyhow::Ok(())
1535 });
1536
1537 let stdout_task = cx.background_executor().spawn({
1538 let mut connection_activity_tx = connection_activity_tx.clone();
1539 async move {
1540 loop {
1541 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1542 let len = child_stdout.read(&mut stdout_buffer).await?;
1543
1544 if len == 0 {
1545 return anyhow::Ok(());
1546 }
1547
1548 if len < MESSAGE_LEN_SIZE {
1549 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1550 }
1551
1552 let message_len = message_len_from_buffer(&stdout_buffer);
1553 let envelope =
1554 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1555 .await?;
1556 connection_activity_tx.try_send(()).ok();
1557 incoming_tx.unbounded_send(envelope).ok();
1558 }
1559 }
1560 });
1561
1562 let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
1563 loop {
1564 stderr_buffer.resize(stderr_offset + 1024, 0);
1565
1566 let len = child_stderr
1567 .read(&mut stderr_buffer[stderr_offset..])
1568 .await?;
1569 if len == 0 {
1570 return anyhow::Ok(());
1571 }
1572
1573 stderr_offset += len;
1574 let mut start_ix = 0;
1575 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1576 .iter()
1577 .position(|b| b == &b'\n')
1578 {
1579 let line_ix = start_ix + ix;
1580 let content = &stderr_buffer[start_ix..line_ix];
1581 start_ix = line_ix + 1;
1582 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1583 record.log(log::logger())
1584 } else {
1585 eprintln!("(remote) {}", String::from_utf8_lossy(content));
1586 }
1587 }
1588 stderr_buffer.drain(0..start_ix);
1589 stderr_offset -= start_ix;
1590
1591 connection_activity_tx.try_send(()).ok();
1592 }
1593 });
1594
1595 cx.spawn(|_| async move {
1596 let result = futures::select! {
1597 result = stdin_task.fuse() => {
1598 result.context("stdin")
1599 }
1600 result = stdout_task.fuse() => {
1601 result.context("stdout")
1602 }
1603 result = stderr_task.fuse() => {
1604 result.context("stderr")
1605 }
1606 };
1607
1608 let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1609 match result {
1610 Ok(_) => Ok(status),
1611 Err(error) => Err(error),
1612 }
1613 })
1614 }
1615
1616 async fn ensure_server_binary(
1617 &self,
1618 delegate: &Arc<dyn SshClientDelegate>,
1619 dst_path: &Path,
1620 platform: SshPlatform,
1621 cx: &mut AsyncAppContext,
1622 ) -> Result<()> {
1623 let lock_file = dst_path.with_extension("lock");
1624 let lock_content = {
1625 let timestamp = SystemTime::now()
1626 .duration_since(UNIX_EPOCH)
1627 .context("failed to get timestamp")?
1628 .as_secs();
1629 let source_port = self.get_ssh_source_port().await?;
1630 format!("{} {}", source_port, timestamp)
1631 };
1632
1633 let lock_stale_age = Duration::from_secs(10 * 60);
1634 let max_wait_time = Duration::from_secs(10 * 60);
1635 let check_interval = Duration::from_secs(5);
1636 let start_time = Instant::now();
1637
1638 loop {
1639 let lock_acquired = self.create_lock_file(&lock_file, &lock_content).await?;
1640 if lock_acquired {
1641 delegate.set_status(Some("Acquired lock file on host"), cx);
1642 let result = self
1643 .update_server_binary_if_needed(delegate, dst_path, platform, cx)
1644 .await;
1645
1646 self.remove_lock_file(&lock_file).await.ok();
1647
1648 return result;
1649 } else {
1650 if let Ok(is_stale) = self.is_lock_stale(&lock_file, &lock_stale_age).await {
1651 if is_stale {
1652 delegate.set_status(
1653 Some("Detected lock file on host being stale. Removing"),
1654 cx,
1655 );
1656 self.remove_lock_file(&lock_file).await?;
1657 continue;
1658 } else {
1659 if start_time.elapsed() > max_wait_time {
1660 return Err(anyhow!("Timeout waiting for lock to be released"));
1661 }
1662 log::info!(
1663 "Found lockfile: {:?}. Will check again in {:?}",
1664 lock_file,
1665 check_interval
1666 );
1667 delegate.set_status(
1668 Some("Waiting for another Zed instance to finish uploading binary"),
1669 cx,
1670 );
1671 smol::Timer::after(check_interval).await;
1672 continue;
1673 }
1674 } else {
1675 // Unable to check lock, assume it's valid and wait
1676 if start_time.elapsed() > max_wait_time {
1677 return Err(anyhow!("Timeout waiting for lock to be released"));
1678 }
1679 smol::Timer::after(check_interval).await;
1680 continue;
1681 }
1682 }
1683 }
1684 }
1685
1686 async fn get_ssh_source_port(&self) -> Result<String> {
1687 let output = run_cmd(self.socket.shell_script("echo $SSH_CLIENT | cut -d' ' -f2"))
1688 .await
1689 .context("failed to get source port from SSH_CLIENT on host")?;
1690
1691 Ok(output.trim().to_string())
1692 }
1693
1694 async fn create_lock_file(&self, lock_file: &Path, content: &str) -> Result<bool> {
1695 let parent_dir = lock_file
1696 .parent()
1697 .ok_or_else(|| anyhow!("Lock file path has no parent directory"))?;
1698
1699 let script = format!(
1700 r#"mkdir -p "{parent_dir}" && [ ! -f "{lock_file}" ] && echo "{content}" > "{lock_file}" && echo "created" || echo "exists""#,
1701 parent_dir = parent_dir.display(),
1702 lock_file = lock_file.display(),
1703 content = content,
1704 );
1705
1706 let output = run_cmd(self.socket.shell_script(&script))
1707 .await
1708 .with_context(|| format!("failed to create a lock file at {:?}", lock_file))?;
1709
1710 Ok(output.trim() == "created")
1711 }
1712
1713 fn generate_stale_check_script(lock_file: &Path, max_age: u64) -> String {
1714 shell_script!(
1715 r#"
1716 if [ ! -f "{lock_file}" ]; then
1717 echo "lock file does not exist"
1718 exit 0
1719 fi
1720
1721 read -r port timestamp < "{lock_file}"
1722
1723 # Check if port is still active
1724 if command -v ss >/dev/null 2>&1; then
1725 if ! ss -n | grep -q ":$port[[:space:]]"; then
1726 echo "ss reports port $port is not open"
1727 exit 0
1728 fi
1729 elif command -v netstat >/dev/null 2>&1; then
1730 if ! netstat -n | grep -q ":$port[[:space:]]"; then
1731 echo "netstat reports port $port is not open"
1732 exit 0
1733 fi
1734 fi
1735
1736 # Check timestamp
1737 if [ $(( $(date +%s) - timestamp )) -gt {max_age} ]; then
1738 echo "timestamp in lockfile is too old"
1739 else
1740 echo "recent"
1741 fi"#,
1742 lock_file = &lock_file.to_string_lossy(),
1743 max_age = &max_age.to_string()
1744 )
1745 }
1746
1747 async fn is_lock_stale(&self, lock_file: &Path, max_age: &Duration) -> Result<bool> {
1748 let script = Self::generate_stale_check_script(lock_file, max_age.as_secs());
1749
1750 let output = run_cmd(self.socket.shell_script(script))
1751 .await
1752 .with_context(|| {
1753 format!("failed to check whether lock file {:?} is stale", lock_file)
1754 })?;
1755
1756 let trimmed = output.trim();
1757 let is_stale = trimmed != "recent";
1758 log::info!("checked lockfile for staleness. stale: {is_stale}, output: {trimmed:?}");
1759 Ok(is_stale)
1760 }
1761
1762 async fn remove_lock_file(&self, lock_file: &Path) -> Result<()> {
1763 run_cmd(
1764 self.socket
1765 .ssh_command("rm", &["-f", &lock_file.to_string_lossy()]),
1766 )
1767 .await
1768 .context("failed to remove lock file")?;
1769 Ok(())
1770 }
1771
1772 async fn update_server_binary_if_needed(
1773 &self,
1774 delegate: &Arc<dyn SshClientDelegate>,
1775 dst_path: &Path,
1776 platform: SshPlatform,
1777 cx: &mut AsyncAppContext,
1778 ) -> Result<()> {
1779 let current_version = match run_cmd(
1780 self.socket
1781 .ssh_command(&dst_path.to_string_lossy(), &["version"]),
1782 )
1783 .await
1784 {
1785 Ok(version_output) => {
1786 if let Ok(version) = version_output.trim().parse::<SemanticVersion>() {
1787 Some(ServerVersion::Semantic(version))
1788 } else {
1789 Some(ServerVersion::Commit(version_output.trim().to_string()))
1790 }
1791 }
1792 Err(_) => None,
1793 };
1794 let (release_channel, wanted_version) = cx.update(|cx| {
1795 let release_channel = ReleaseChannel::global(cx);
1796 let wanted_version = match release_channel {
1797 ReleaseChannel::Nightly => {
1798 AppCommitSha::try_global(cx).map(|sha| ServerVersion::Commit(sha.0))
1799 }
1800 ReleaseChannel::Dev => None,
1801 _ => Some(ServerVersion::Semantic(AppVersion::global(cx))),
1802 };
1803 (release_channel, wanted_version)
1804 })?;
1805
1806 match (¤t_version, &wanted_version) {
1807 (Some(current), Some(wanted)) if current == wanted => {
1808 log::info!("remote development server present and matching client version");
1809 return Ok(());
1810 }
1811 (Some(ServerVersion::Semantic(current)), Some(ServerVersion::Semantic(wanted)))
1812 if current > wanted =>
1813 {
1814 anyhow::bail!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", current, wanted);
1815 }
1816 _ => {
1817 log::info!("Installing remote development server");
1818 }
1819 }
1820
1821 if self.is_binary_in_use(dst_path).await? {
1822 // When we're not in dev mode, we don't want to switch out the binary if it's
1823 // still open.
1824 // In dev mode, that's fine, since we often kill Zed processes with Ctrl-C and want
1825 // to still replace the binary.
1826 if cfg!(not(debug_assertions)) {
1827 anyhow::bail!("The remote server version ({:?}) does not match the wanted version ({:?}), but is in use by another Zed client so cannot be upgraded.", ¤t_version, &wanted_version)
1828 } else {
1829 log::info!("Binary is currently in use, ignoring because this is a dev build")
1830 }
1831 }
1832
1833 if wanted_version.is_none() {
1834 if std::env::var("ZED_BUILD_REMOTE_SERVER").is_err() {
1835 if let Some(current_version) = current_version {
1836 log::warn!(
1837 "In development, using cached remote server binary version ({})",
1838 current_version
1839 );
1840
1841 return Ok(());
1842 } else {
1843 anyhow::bail!(
1844 "ZED_BUILD_REMOTE_SERVER is not set, but no remote server exists at ({:?})",
1845 dst_path
1846 )
1847 }
1848 }
1849
1850 #[cfg(debug_assertions)]
1851 {
1852 let src_path = self.build_local(platform, delegate, cx).await?;
1853
1854 return self
1855 .upload_local_server_binary(&src_path, dst_path, delegate, cx)
1856 .await;
1857 }
1858
1859 #[cfg(not(debug_assertions))]
1860 anyhow::bail!("Running development build in release mode, cannot cross compile (unset ZED_BUILD_REMOTE_SERVER)")
1861 }
1862
1863 let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;
1864
1865 if !upload_binary_over_ssh {
1866 let (url, body) = delegate
1867 .get_download_params(
1868 platform,
1869 release_channel,
1870 wanted_version.clone().and_then(|v| v.semantic_version()),
1871 cx,
1872 )
1873 .await?;
1874
1875 match self
1876 .download_binary_on_server(&url, &body, dst_path, delegate, cx)
1877 .await
1878 {
1879 Ok(_) => return Ok(()),
1880 Err(e) => {
1881 log::error!(
1882 "Failed to download binary on server, attempting to upload server: {}",
1883 e
1884 )
1885 }
1886 }
1887 }
1888
1889 let src_path = delegate
1890 .download_server_binary_locally(
1891 platform,
1892 release_channel,
1893 wanted_version.and_then(|v| v.semantic_version()),
1894 cx,
1895 )
1896 .await?;
1897
1898 self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
1899 .await
1900 }
1901
1902 async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> {
1903 let script = shell_script!(
1904 r#"
1905 if command -v lsof >/dev/null 2>&1; then
1906 if lsof "{binary_path}" >/dev/null 2>&1; then
1907 echo "in_use"
1908 exit 0
1909 fi
1910 elif command -v fuser >/dev/null 2>&1; then
1911 if fuser "{binary_path}" >/dev/null 2>&1; then
1912 echo "in_use"
1913 exit 0
1914 fi
1915 fi
1916 echo "not_in_use"
1917 "#,
1918 binary_path = &binary_path.to_string_lossy(),
1919 );
1920
1921 let output = run_cmd(self.socket.shell_script(script))
1922 .await
1923 .context("failed to check if binary is in use")?;
1924
1925 Ok(output.trim() == "in_use")
1926 }
1927
1928 async fn download_binary_on_server(
1929 &self,
1930 url: &str,
1931 body: &str,
1932 dst_path: &Path,
1933 delegate: &Arc<dyn SshClientDelegate>,
1934 cx: &mut AsyncAppContext,
1935 ) -> Result<()> {
1936 let mut dst_path_gz = dst_path.to_path_buf();
1937 dst_path_gz.set_extension("gz");
1938
1939 if let Some(parent) = dst_path.parent() {
1940 run_cmd(
1941 self.socket
1942 .ssh_command("mkdir", &["-p", &parent.to_string_lossy()]),
1943 )
1944 .await?;
1945 }
1946
1947 delegate.set_status(Some("Downloading remote development server on host"), cx);
1948
1949 let script = shell_script!(
1950 r#"
1951 if command -v curl >/dev/null 2>&1; then
1952 curl -f -L -X GET -H "Content-Type: application/json" -d {body} {url} -o {dst_path} && echo "curl"
1953 elif command -v wget >/dev/null 2>&1; then
1954 wget --max-redirect=5 --method=GET --header="Content-Type: application/json" --body-data={body} {url} -O {dst_path} && echo "wget"
1955 else
1956 echo "Neither curl nor wget is available" >&2
1957 exit 1
1958 fi
1959 "#,
1960 body = body,
1961 url = url,
1962 dst_path = &dst_path_gz.to_string_lossy(),
1963 );
1964
1965 let output = run_cmd(self.socket.shell_script(script))
1966 .await
1967 .context("Failed to download server binary")?;
1968
1969 if !output.contains("curl") && !output.contains("wget") {
1970 return Err(anyhow!("Failed to download server binary: {}", output));
1971 }
1972
1973 self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
1974 .await
1975 }
1976
1977 async fn upload_local_server_binary(
1978 &self,
1979 src_path: &Path,
1980 dst_path: &Path,
1981 delegate: &Arc<dyn SshClientDelegate>,
1982 cx: &mut AsyncAppContext,
1983 ) -> Result<()> {
1984 let mut dst_path_gz = dst_path.to_path_buf();
1985 dst_path_gz.set_extension("gz");
1986
1987 if let Some(parent) = dst_path.parent() {
1988 run_cmd(
1989 self.socket
1990 .ssh_command("mkdir", &["-p", &parent.to_string_lossy()]),
1991 )
1992 .await?;
1993 }
1994
1995 let src_stat = fs::metadata(&src_path).await?;
1996 let size = src_stat.len();
1997
1998 let t0 = Instant::now();
1999 delegate.set_status(Some("Uploading remote development server"), cx);
2000 log::info!("uploading remote development server ({}kb)", size / 1024);
2001 self.upload_file(&src_path, &dst_path_gz)
2002 .await
2003 .context("failed to upload server binary")?;
2004 log::info!("uploaded remote development server in {:?}", t0.elapsed());
2005
2006 self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
2007 .await
2008 }
2009
2010 async fn extract_server_binary(
2011 &self,
2012 dst_path: &Path,
2013 dst_path_gz: &Path,
2014 delegate: &Arc<dyn SshClientDelegate>,
2015 cx: &mut AsyncAppContext,
2016 ) -> Result<()> {
2017 delegate.set_status(Some("Extracting remote development server"), cx);
2018 run_cmd(
2019 self.socket
2020 .ssh_command("gunzip", &["-f", &dst_path_gz.to_string_lossy()]),
2021 )
2022 .await?;
2023
2024 let server_mode = 0o755;
2025 delegate.set_status(Some("Marking remote development server executable"), cx);
2026 run_cmd(self.socket.ssh_command(
2027 "chmod",
2028 &[&format!("{:o}", server_mode), &dst_path.to_string_lossy()],
2029 ))
2030 .await?;
2031
2032 Ok(())
2033 }
2034
2035 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
2036 let mut command = process::Command::new("scp");
2037 let output = self
2038 .socket
2039 .ssh_options(&mut command)
2040 .args(
2041 self.socket
2042 .connection_options
2043 .port
2044 .map(|port| vec!["-P".to_string(), port.to_string()])
2045 .unwrap_or_default(),
2046 )
2047 .arg(src_path)
2048 .arg(format!(
2049 "{}:{}",
2050 self.socket.connection_options.scp_url(),
2051 dest_path.display()
2052 ))
2053 .output()
2054 .await?;
2055
2056 if output.status.success() {
2057 Ok(())
2058 } else {
2059 Err(anyhow!(
2060 "failed to upload file {} -> {}: {}",
2061 src_path.display(),
2062 dest_path.display(),
2063 String::from_utf8_lossy(&output.stderr)
2064 ))
2065 }
2066 }
2067
2068 #[cfg(debug_assertions)]
2069 async fn build_local(
2070 &self,
2071 platform: SshPlatform,
2072 delegate: &Arc<dyn SshClientDelegate>,
2073 cx: &mut AsyncAppContext,
2074 ) -> Result<PathBuf> {
2075 use smol::process::{Command, Stdio};
2076
2077 async fn run_cmd(command: &mut Command) -> Result<()> {
2078 let output = command
2079 .kill_on_drop(true)
2080 .stderr(Stdio::inherit())
2081 .output()
2082 .await?;
2083 if !output.status.success() {
2084 Err(anyhow!("Failed to run command: {:?}", command))?;
2085 }
2086 Ok(())
2087 }
2088
2089 if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
2090 delegate.set_status(Some("Building remote server binary from source"), cx);
2091 log::info!("building remote server binary from source");
2092 run_cmd(Command::new("cargo").args([
2093 "build",
2094 "--package",
2095 "remote_server",
2096 "--features",
2097 "debug-embed",
2098 "--target-dir",
2099 "target/remote_server",
2100 ]))
2101 .await?;
2102
2103 delegate.set_status(Some("Compressing binary"), cx);
2104
2105 run_cmd(Command::new("gzip").args([
2106 "-9",
2107 "-f",
2108 "target/remote_server/debug/remote_server",
2109 ]))
2110 .await?;
2111
2112 let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
2113 return Ok(path);
2114 }
2115 let Some(triple) = platform.triple() else {
2116 anyhow::bail!("can't cross compile for: {:?}", platform);
2117 };
2118 smol::fs::create_dir_all("target/remote_server").await?;
2119
2120 delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
2121 log::info!("installing cross");
2122 run_cmd(Command::new("cargo").args([
2123 "install",
2124 "cross",
2125 "--git",
2126 "https://github.com/cross-rs/cross",
2127 ]))
2128 .await?;
2129
2130 delegate.set_status(
2131 Some(&format!(
2132 "Building remote server binary from source for {} with Docker",
2133 &triple
2134 )),
2135 cx,
2136 );
2137 log::info!("building remote server binary from source for {}", &triple);
2138 run_cmd(
2139 Command::new("cross")
2140 .args([
2141 "build",
2142 "--package",
2143 "remote_server",
2144 "--features",
2145 "debug-embed",
2146 "--target-dir",
2147 "target/remote_server",
2148 "--target",
2149 &triple,
2150 ])
2151 .env(
2152 "CROSS_CONTAINER_OPTS",
2153 "--mount type=bind,src=./target,dst=/app/target",
2154 ),
2155 )
2156 .await?;
2157
2158 delegate.set_status(Some("Compressing binary"), cx);
2159
2160 run_cmd(Command::new("gzip").args([
2161 "-9",
2162 "-f",
2163 &format!("target/remote_server/{}/debug/remote_server", triple),
2164 ]))
2165 .await?;
2166
2167 let path = std::env::current_dir()?.join(format!(
2168 "target/remote_server/{}/debug/remote_server.gz",
2169 triple
2170 ));
2171
2172 return Ok(path);
2173 }
2174}
2175
2176type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2177
2178pub struct ChannelClient {
2179 next_message_id: AtomicU32,
2180 outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
2181 buffer: Mutex<VecDeque<Envelope>>,
2182 response_channels: ResponseChannels,
2183 message_handlers: Mutex<ProtoMessageHandlerSet>,
2184 max_received: AtomicU32,
2185 name: &'static str,
2186 task: Mutex<Task<Result<()>>>,
2187}
2188
2189impl ChannelClient {
2190 pub fn new(
2191 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2192 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2193 cx: &AppContext,
2194 name: &'static str,
2195 ) -> Arc<Self> {
2196 Arc::new_cyclic(|this| Self {
2197 outgoing_tx: Mutex::new(outgoing_tx),
2198 next_message_id: AtomicU32::new(0),
2199 max_received: AtomicU32::new(0),
2200 response_channels: ResponseChannels::default(),
2201 message_handlers: Default::default(),
2202 buffer: Mutex::new(VecDeque::new()),
2203 name,
2204 task: Mutex::new(Self::start_handling_messages(
2205 this.clone(),
2206 incoming_rx,
2207 &cx.to_async(),
2208 )),
2209 })
2210 }
2211
    /// Spawns the task that consumes envelopes arriving from the remote side.
    ///
    /// For every incoming envelope the loop:
    /// 1. stops (returning `Ok(())`) once the owning `ChannelClient` has been dropped;
    /// 2. prunes buffered outgoing envelopes whose id is `<= ack_id`, since the peer
    ///    has acknowledged them;
    /// 3. answers a `FlushBufferedMessages` request by re-sending every still-buffered
    ///    envelope followed by an `Ack` that responds to the flush request;
    /// 4. otherwise records the envelope id in `max_received`, then either routes a
    ///    response to the waiting request or dispatches a request/message to the
    ///    registered handlers.
    ///
    /// The returned task resolves to `Ok(())` when `incoming_rx` ends.
    fn start_handling_messages(
        this: Weak<Self>,
        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> Task<Result<()>> {
        cx.spawn(|cx| async move {
            // Fixed placeholder peer id used when building typed envelopes.
            let peer_id = PeerId { owner_id: 0, id: 0 };
            while let Some(incoming) = incoming_rx.next().await {
                // Only a weak reference is held between messages so this task does not
                // keep the client alive on its own.
                let Some(this) = this.upgrade() else {
                    return anyhow::Ok(());
                };
                // Drop every buffered outgoing envelope the peer has now acknowledged.
                if let Some(ack_id) = incoming.ack_id {
                    let mut buffer = this.buffer.lock();
                    while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
                        buffer.pop_front();
                    }
                }
                // The peer asked us to replay everything it has not acknowledged yet.
                if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
                {
                    log::debug!(
                        "{}:ssh message received. name:FlushBufferedMessages",
                        this.name
                    );
                    {
                        // Lock order here is buffer -> outgoing_tx.
                        let buffer = this.buffer.lock();
                        for envelope in buffer.iter() {
                            this.outgoing_tx
                                .lock()
                                .unbounded_send(envelope.clone())
                                .ok();
                        }
                    }
                    // Reply to the flush request itself; this Ack is sent directly and is
                    // not added to the buffer.
                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
                    // Flush requests are not recorded in `max_received`.
                    continue;
                }

                // Highest id seen so far; attached as `ack_id` to our outgoing envelopes.
                this.max_received.store(incoming.id, SeqCst);

                if let Some(request_id) = incoming.responding_to {
                    // A response: hand it to the request future registered under this id.
                    let request_id = MessageId(request_id);
                    let sender = this.response_channels.lock().remove(&request_id);
                    if let Some(sender) = sender {
                        let (tx, rx) = oneshot::channel();
                        // A response without a payload is not forwarded; dropping
                        // `sender` makes the waiting request fail instead.
                        if incoming.payload.is_some() {
                            sender.send((incoming, tx)).ok();
                        }
                        // Pause incoming processing until the requester has consumed the
                        // response (which drops `tx`), or immediately if `tx` was never sent.
                        rx.await.ok();
                    }
                } else if let Some(envelope) =
                    build_typed_envelope(peer_id, Instant::now(), incoming)
                {
                    // An unsolicited message/request: dispatch to registered handlers.
                    let type_name = envelope.payload_type_name();
                    if let Some(future) = ProtoMessageHandlerSet::handle_message(
                        &this.message_handlers,
                        envelope,
                        this.clone().into(),
                        cx.clone(),
                    ) {
                        log::debug!("{}:ssh message received. name:{type_name}", this.name);
                        // Handlers run on the foreground executor as detached tasks so a
                        // slow handler does not block this receive loop.
                        cx.foreground_executor()
                            .spawn(async move {
                                match future.await {
                                    Ok(_) => {
                                        log::debug!(
                                            "{}:ssh message handled. name:{type_name}",
                                            this.name
                                        );
                                    }
                                    Err(error) => {
                                        // Collapse the multi-line `{:#}` error chain onto a
                                        // single space-separated log line.
                                        log::error!(
                                            "{}:error handling message. type:{}, error:{}",
                                            this.name,
                                            type_name,
                                            format!("{error:#}").lines().fold(
                                                String::new(),
                                                |mut message, line| {
                                                    if !message.is_empty() {
                                                        message.push(' ');
                                                    }
                                                    message.push_str(line);
                                                    message
                                                }
                                            )
                                        );
                                    }
                                }
                            })
                            .detach()
                    } else {
                        log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                    }
                }
            }
            anyhow::Ok(())
        })
    }
2310
2311 pub fn reconnect(
2312 self: &Arc<Self>,
2313 incoming_rx: UnboundedReceiver<Envelope>,
2314 outgoing_tx: UnboundedSender<Envelope>,
2315 cx: &AsyncAppContext,
2316 ) {
2317 *self.outgoing_tx.lock() = outgoing_tx;
2318 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2319 }
2320
2321 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
2322 let id = (TypeId::of::<E>(), remote_id);
2323
2324 let mut message_handlers = self.message_handlers.lock();
2325 if message_handlers
2326 .entities_by_type_and_remote_id
2327 .contains_key(&id)
2328 {
2329 panic!("already subscribed to entity");
2330 }
2331
2332 message_handlers.entities_by_type_and_remote_id.insert(
2333 id,
2334 EntityMessageSubscriber::Entity {
2335 handle: entity.downgrade().into(),
2336 },
2337 );
2338 }
2339
2340 pub fn request<T: RequestMessage>(
2341 &self,
2342 payload: T,
2343 ) -> impl 'static + Future<Output = Result<T::Response>> {
2344 self.request_internal(payload, true)
2345 }
2346
2347 fn request_internal<T: RequestMessage>(
2348 &self,
2349 payload: T,
2350 use_buffer: bool,
2351 ) -> impl 'static + Future<Output = Result<T::Response>> {
2352 log::debug!("ssh request start. name:{}", T::NAME);
2353 let response =
2354 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2355 async move {
2356 let response = response.await?;
2357 log::debug!("ssh request finish. name:{}", T::NAME);
2358 T::Response::from_envelope(response)
2359 .ok_or_else(|| anyhow!("received a response of the wrong type"))
2360 }
2361 }
2362
2363 pub async fn resync(&self, timeout: Duration) -> Result<()> {
2364 smol::future::or(
2365 async {
2366 self.request_internal(proto::FlushBufferedMessages {}, false)
2367 .await?;
2368
2369 for envelope in self.buffer.lock().iter() {
2370 self.outgoing_tx
2371 .lock()
2372 .unbounded_send(envelope.clone())
2373 .ok();
2374 }
2375 Ok(())
2376 },
2377 async {
2378 smol::Timer::after(timeout).await;
2379 Err(anyhow!("Timeout detected"))
2380 },
2381 )
2382 .await
2383 }
2384
2385 pub async fn ping(&self, timeout: Duration) -> Result<()> {
2386 smol::future::or(
2387 async {
2388 self.request(proto::Ping {}).await?;
2389 Ok(())
2390 },
2391 async {
2392 smol::Timer::after(timeout).await;
2393 Err(anyhow!("Timeout detected"))
2394 },
2395 )
2396 .await
2397 }
2398
2399 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2400 log::debug!("ssh send name:{}", T::NAME);
2401 self.send_dynamic(payload.into_envelope(0, None, None))
2402 }
2403
2404 fn request_dynamic(
2405 &self,
2406 mut envelope: proto::Envelope,
2407 type_name: &'static str,
2408 use_buffer: bool,
2409 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2410 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2411 let (tx, rx) = oneshot::channel();
2412 let mut response_channels_lock = self.response_channels.lock();
2413 response_channels_lock.insert(MessageId(envelope.id), tx);
2414 drop(response_channels_lock);
2415
2416 let result = if use_buffer {
2417 self.send_buffered(envelope)
2418 } else {
2419 self.send_unbuffered(envelope)
2420 };
2421 async move {
2422 if let Err(error) = &result {
2423 log::error!("failed to send message: {}", error);
2424 return Err(anyhow!("failed to send message: {}", error));
2425 }
2426
2427 let response = rx.await.context("connection lost")?.0;
2428 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2429 return Err(RpcError::from_proto(error, type_name));
2430 }
2431 Ok(response)
2432 }
2433 }
2434
2435 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2436 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2437 self.send_buffered(envelope)
2438 }
2439
2440 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2441 envelope.ack_id = Some(self.max_received.load(SeqCst));
2442 self.buffer.lock().push_back(envelope.clone());
2443 // ignore errors on send (happen while we're reconnecting)
2444 // assume that the global "disconnected" overlay is sufficient.
2445 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2446 Ok(())
2447 }
2448
2449 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2450 envelope.ack_id = Some(self.max_received.load(SeqCst));
2451 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2452 Ok(())
2453 }
2454}
2455
2456impl ProtoClient for ChannelClient {
2457 fn request(
2458 &self,
2459 envelope: proto::Envelope,
2460 request_type: &'static str,
2461 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2462 self.request_dynamic(envelope, request_type, true).boxed()
2463 }
2464
2465 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2466 self.send_dynamic(envelope)
2467 }
2468
2469 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2470 self.send_dynamic(envelope)
2471 }
2472
2473 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2474 &self.message_handlers
2475 }
2476
2477 fn is_via_collab(&self) -> bool {
2478 false
2479 }
2480}
2481
#[cfg(any(test, feature = "test-support"))]
mod fake {
    //! In-memory stand-ins for an SSH connection, compiled only for tests and the
    //! `test-support` feature.
    //!
    //! `FakeRemoteConnection` pipes the client's channels straight into a server-side
    //! `ChannelClient` instead of launching a real proxy process.

    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;

    use super::{
        ChannelClient, RemoteConnection, SshClientDelegate, SshConnectionOptions, SshPlatform,
    };

    /// A `RemoteConnection` whose "remote" side is an in-process `ChannelClient`.
    pub(super) struct FakeRemoteConnection {
        /// Options reported back by `connection_options()`.
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side channel that `start_proxy` and `simulate_disconnect` rewire.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Context handed to `ChannelClient::reconnect` when rewiring the server channel.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an `AsyncAppContext` be stored in a type marked `Send`/`Sync`
    /// via the `unsafe impl`s that follow it.
    pub(super) struct SendableCx(AsyncAppContext);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncAppContext
        fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to kill; always succeeds.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// No real ssh process, so there are no ssh arguments.
        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Rewires the server channel to a pair of channels whose other ends are
        /// dropped immediately: its incoming stream ends right away and its outgoing
        /// sends are silently discarded.
        fn simulate_disconnect(&self, cx: &AsyncAppContext) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// The fake never uploads a binary; an empty path is returned.
        async fn get_remote_binary_path(
            &self,
            _delegate: &Arc<dyn SshClientDelegate>,
            _reconnect: bool,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            Ok(PathBuf::new())
        }

        /// Connects the client's channels to the in-process server channel.
        ///
        /// The server channel is rewired to fresh in-memory channels. A background
        /// task then forwards envelopes in both directions, preferring server→client
        /// traffic (`select_biased!`) and signalling `connection_activity_tx` for each
        /// server→client envelope. The task resolves to `Ok(1)` as soon as either side's
        /// stream ends.
        fn start_proxy(
            &self,
            _remote_binary_path: PathBuf,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    /// A delegate for tests: every method except `set_status` is `unreachable!()`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncAppContext,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncAppContext,
        ) -> Task<Result<(String, String)>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}

        fn remote_server_binary_path(
            &self,
            _platform: SshPlatform,
            _cx: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
    }
}
2644
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Maximum lock age handed to the stale-check script in every scenario.
    const MAX_AGE: Duration = Duration::from_secs(600);
    /// Port recorded in the lock files written by these tests.
    const LOCK_PORT: u16 = 54321;

    /// Runs the generated stale-lock check script under `bash`, with `ss`/`netstat`
    /// replaced by shell functions, and returns its trimmed stdout.
    ///
    /// With `simulate_port_open = Some(port)`, the mocked `ss -n` prints one
    /// established connection on that port; with `None` it prints nothing.
    fn run_stale_check_script(
        lock_file: &Path,
        max_age: Duration,
        simulate_port_open: Option<&str>,
    ) -> Result<String> {
        let simulated_port = simulate_port_open.unwrap_or("");
        let script = SshRemoteConnection::generate_stale_check_script(lock_file, max_age.as_secs());
        let wrapper = format!(
            r#"
            # Mock ss/netstat commands
            ss() {{
                # Only handle the -n argument
                if [ "$1" = "-n" ]; then
                    # If we're simulating an open port, output a line containing that port
                    if [ "{simulated_port}" != "" ]; then
                        echo "ESTAB 0 0 1.2.3.4:{simulated_port} 5.6.7.8:12345"
                    fi
                fi
            }}
            netstat() {{
                ss "$@"
            }}
            export -f ss netstat

            # Real script starts here
            {script}"#
        );

        let output = std::process::Command::new("bash")
            .args(["-c", wrapper.as_str()])
            .output()?;

        if !output.stderr.is_empty() {
            eprintln!("Script stderr: {}", String::from_utf8_lossy(&output.stderr));
        }

        let stdout = String::from_utf8(output.stdout)?;
        Ok(stdout.trim().to_owned())
    }

    /// Current Unix time in whole seconds.
    fn unix_now_secs() -> Result<u64> {
        Ok(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs())
    }

    /// Writes a lock file in the `"<port> <timestamp>"` layout used by these tests.
    fn write_lock(lock_file: &Path, port: u16, timestamp: u64) -> Result<()> {
        fs::write(lock_file, format!("{port} {timestamp}"))?;
        Ok(())
    }

    #[test]
    fn test_lock_staleness() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let lock_file = temp_dir.path().join("test.lock");

        // No lock file at all.
        assert_eq!(
            run_stale_check_script(&lock_file, MAX_AGE, None)?,
            "lock file does not exist"
        );

        // Lock file present, but `ss` only reports an unrelated port.
        write_lock(&lock_file, LOCK_PORT, 1234567890)?;
        assert_eq!(
            run_stale_check_script(&lock_file, MAX_AGE, Some("98765"))?,
            "ss reports port 54321 is not open"
        );

        // Port is open, but the timestamp is 700 seconds old.
        write_lock(&lock_file, LOCK_PORT, unix_now_secs()? - 700)?;
        assert_eq!(
            run_stale_check_script(&lock_file, MAX_AGE, Some("54321"))?,
            "timestamp in lockfile is too old"
        );

        // Port is open and the timestamp is only a minute old.
        write_lock(&lock_file, LOCK_PORT, unix_now_secs()? - 60)?;
        assert_eq!(
            run_stale_check_script(&lock_file, MAX_AGE, Some("54321"))?,
            "recent"
        );

        Ok(())
    }
}