1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
13 channel::{
14 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
15 oneshot,
16 },
17 future::{BoxFuture, Shared},
18 select, select_biased,
19};
20use gpui::{
21 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
22 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
23};
24use itertools::Itertools;
25use parking_lot::Mutex;
26
27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
28use rpc::{
29 AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
30 RpcError,
31 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
32};
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use smol::{
36 fs,
37 process::{self, Child, Stdio},
38};
39use std::{
40 any::TypeId,
41 collections::VecDeque,
42 fmt, iter,
43 ops::ControlFlow,
44 path::{Path, PathBuf},
45 sync::{
46 Arc, Weak,
47 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
48 },
49 time::{Duration, Instant},
50};
51use tempfile::TempDir;
52use util::{
53 ResultExt,
54 paths::{PathStyle, RemotePathBuf},
55};
56
57#[derive(
58 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
59)]
60pub struct SshProjectId(pub u64);
61
/// Handle used to spawn `ssh` commands against a single remote host.
///
/// On non-Windows platforms every spawned command is pointed at the ssh
/// control socket in `socket_path`; on Windows the askpass environment in
/// `envs` is applied to every spawned command instead.
#[derive(Clone)]
pub struct SshSocket {
    /// Options describing the remote host (user, host, port, extra args).
    connection_options: SshConnectionOptions,
    /// Path passed to ssh as `ControlPath=` for every spawned command.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables (askpass configuration) set on every spawned command.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
70
/// A single local port forward, the equivalent of one `ssh -L` argument.
///
/// Hosts that are `None` are omitted when serializing and are rendered as
/// `localhost` when the forward is turned back into an ssh argument.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind address; `None` when the spec had only three fields.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port to listen on.
    pub local_port: u16,
    /// Host to connect to from the remote side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Port to connect to from the remote side.
    pub remote_port: u16,
}
80
/// Parameters describing how to reach a remote host over ssh.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Remote host name or address.
    pub host: String,
    /// User to log in as; rendered as `user@` in the ssh URL and connection string.
    pub username: Option<String>,
    /// Remote ssh port; omitted from the ssh URL when `None`.
    pub port: Option<u16>,
    /// Password, if known. `parse_command_line` always leaves this `None`.
    pub password: Option<String>,
    /// Extra ssh command-line arguments, emitted first by `additional_args`.
    pub args: Option<Vec<String>>,
    /// Local port forwards, emitted as `-L` arguments by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// NOTE(review): presumably a user-facing display name for the
    /// connection; not read in this part of the file — confirm.
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary over the
    /// ssh connection instead of downloading it remotely — confirm.
    pub upload_binary_over_ssh: bool,
}
93
/// Arguments (and optional environment) needed to invoke `ssh` for an
/// existing connection.
///
/// As built by `SshSocket::ssh_args`, the last argument is the `ssh://`
/// destination URL.
pub struct SshArgs {
    /// Command-line arguments to pass to `ssh`.
    pub arguments: Vec<String>,
    /// Extra environment variables for the `ssh` process, if any.
    pub envs: Option<HashMap<String, String>>,
}
98
/// Formats a shell script with `format!`, shell-quoting every named argument
/// via `shlex::try_quote` before it is substituted.
///
/// Invoke it as `shell_script!(fmt, name = value, ...)`; each `{name}` in
/// `fmt` is replaced by the quoted value.
///
/// # Panics
///
/// Panics if `shlex::try_quote` rejects an argument (per the shlex docs, for
/// example input containing a NUL byte).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                // Quote each value so it is passed to the shell as a single word.
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
110
111fn parse_port_number(port_str: &str) -> Result<u16> {
112 port_str
113 .parse()
114 .with_context(|| format!("parsing port number: {port_str}"))
115}
116
117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
118 let parts: Vec<&str> = spec.split(':').collect();
119
120 match parts.len() {
121 4 => {
122 let local_port = parse_port_number(parts[1])?;
123 let remote_port = parse_port_number(parts[3])?;
124
125 Ok(SshPortForwardOption {
126 local_host: Some(parts[0].to_string()),
127 local_port,
128 remote_host: Some(parts[2].to_string()),
129 remote_port,
130 })
131 }
132 3 => {
133 let local_port = parse_port_number(parts[0])?;
134 let remote_port = parse_port_number(parts[2])?;
135
136 Ok(SshPortForwardOption {
137 local_host: None,
138 local_port,
139 remote_host: Some(parts[1].to_string()),
140 remote_port,
141 })
142 }
143 _ => anyhow::bail!("Invalid port forward format"),
144 }
145}
146
147impl SshConnectionOptions {
148 pub fn parse_command_line(input: &str) -> Result<Self> {
149 let input = input.trim_start_matches("ssh ");
150 let mut hostname: Option<String> = None;
151 let mut username: Option<String> = None;
152 let mut port: Option<u16> = None;
153 let mut args = Vec::new();
154 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
155
156 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
157 const ALLOWED_OPTS: &[&str] = &[
158 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
159 ];
160 const ALLOWED_ARGS: &[&str] = &[
161 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
162 "-w",
163 ];
164
165 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
166
167 'outer: while let Some(arg) = tokens.next() {
168 if ALLOWED_OPTS.contains(&(&arg as &str)) {
169 args.push(arg.to_string());
170 continue;
171 }
172 if arg == "-p" {
173 port = tokens.next().and_then(|arg| arg.parse().ok());
174 continue;
175 } else if let Some(p) = arg.strip_prefix("-p") {
176 port = p.parse().ok();
177 continue;
178 }
179 if arg == "-l" {
180 username = tokens.next();
181 continue;
182 } else if let Some(l) = arg.strip_prefix("-l") {
183 username = Some(l.to_string());
184 continue;
185 }
186 if arg == "-L" || arg.starts_with("-L") {
187 let forward_spec = if arg == "-L" {
188 tokens.next()
189 } else {
190 Some(arg.strip_prefix("-L").unwrap().to_string())
191 };
192
193 if let Some(spec) = forward_spec {
194 port_forwards.push(parse_port_forward_spec(&spec)?);
195 } else {
196 anyhow::bail!("Missing port forward format");
197 }
198 }
199
200 for a in ALLOWED_ARGS {
201 if arg == *a {
202 args.push(arg);
203 if let Some(next) = tokens.next() {
204 args.push(next);
205 }
206 continue 'outer;
207 } else if arg.starts_with(a) {
208 args.push(arg);
209 continue 'outer;
210 }
211 }
212 if arg.starts_with("-") || hostname.is_some() {
213 anyhow::bail!("unsupported argument: {:?}", arg);
214 }
215 let mut input = &arg as &str;
216 // Destination might be: username1@username2@ip2@ip1
217 if let Some((u, rest)) = input.rsplit_once('@') {
218 input = rest;
219 username = Some(u.to_string());
220 }
221 if let Some((rest, p)) = input.split_once(':') {
222 input = rest;
223 port = p.parse().ok()
224 }
225 hostname = Some(input.to_string())
226 }
227
228 let Some(hostname) = hostname else {
229 anyhow::bail!("missing hostname");
230 };
231
232 let port_forwards = match port_forwards.len() {
233 0 => None,
234 _ => Some(port_forwards),
235 };
236
237 Ok(Self {
238 host: hostname.to_string(),
239 username: username.clone(),
240 port,
241 port_forwards,
242 args: Some(args),
243 password: None,
244 nickname: None,
245 upload_binary_over_ssh: false,
246 })
247 }
248
249 pub fn ssh_url(&self) -> String {
250 let mut result = String::from("ssh://");
251 if let Some(username) = &self.username {
252 // Username might be: username1@username2@ip2
253 let username = urlencoding::encode(username);
254 result.push_str(&username);
255 result.push('@');
256 }
257 result.push_str(&self.host);
258 if let Some(port) = self.port {
259 result.push(':');
260 result.push_str(&port.to_string());
261 }
262 result
263 }
264
265 pub fn additional_args(&self) -> Vec<String> {
266 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
267
268 if let Some(forwards) = &self.port_forwards {
269 args.extend(forwards.iter().map(|pf| {
270 let local_host = match &pf.local_host {
271 Some(host) => host,
272 None => "localhost",
273 };
274 let remote_host = match &pf.remote_host {
275 Some(host) => host,
276 None => "localhost",
277 };
278
279 format!(
280 "-L{}:{}:{}:{}",
281 local_host, pf.local_port, remote_host, pf.remote_port
282 )
283 }));
284 }
285
286 args
287 }
288
289 fn scp_url(&self) -> String {
290 if let Some(username) = &self.username {
291 format!("{}@{}", username, self.host)
292 } else {
293 self.host.clone()
294 }
295 }
296
297 pub fn connection_string(&self) -> String {
298 let host = if let Some(username) = &self.username {
299 format!("{}@{}", username, self.host)
300 } else {
301 self.host.clone()
302 };
303 if let Some(port) = &self.port {
304 format!("{}:{}", host, port)
305 } else {
306 host
307 }
308 }
309}
310
/// Operating system and CPU architecture of a remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// OS name; `SshSocket::platform` reports `"macos"` or `"linux"`.
    pub os: &'static str,
    /// CPU architecture; `SshSocket::platform` reports `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
316
/// Callbacks through which the ssh connection code interacts with its host
/// application: password prompts, server-binary downloads and status updates.
pub trait SshClientDelegate: Send + Sync {
    /// Requests a password for `prompt`; the answer is expected to be sent on
    /// `tx` (presumably after prompting the user — TODO confirm).
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Returns download parameters for a server binary matching `platform`,
    /// `release_channel` and `version`.
    ///
    /// NOTE(review): the meaning of the returned string pair, of `Ok(None)`,
    /// and of a `None` version is defined by implementors — confirm.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Obtains a server binary for `platform`/`release_channel`/`version` on
    /// the local machine and resolves to its path (inferred from the name and
    /// the `PathBuf` return type — confirm against implementors).
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message; `None` presumably clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
336
337impl SshSocket {
338 #[cfg(not(target_os = "windows"))]
339 fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
340 Ok(Self {
341 connection_options: options,
342 socket_path,
343 })
344 }
345
346 #[cfg(target_os = "windows")]
347 fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
348 let askpass_script = temp_dir.path().join("askpass.bat");
349 std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
350 let mut envs = HashMap::default();
351 envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
352 envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
353 envs.insert("ZED_SSH_ASKPASS".into(), secret);
354 Ok(Self {
355 connection_options: options,
356 envs,
357 })
358 }
359
360 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
361 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
362 // and passes -l as an argument to sh, not to ls.
363 // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
364 // into a machine. You must use `cd` to get back to $HOME.
365 // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
366 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
367 let mut command = util::command::new_smol_command("ssh");
368 let to_run = iter::once(&program)
369 .chain(args.iter())
370 .map(|token| {
371 // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
372 debug_assert!(
373 !token.contains('\n'),
374 "multiline arguments do not work in all shells"
375 );
376 shlex::try_quote(token).unwrap()
377 })
378 .join(" ");
379 let to_run = format!("cd; {to_run}");
380 log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
381 self.ssh_options(&mut command)
382 .arg(self.connection_options.ssh_url())
383 .arg(to_run);
384 command
385 }
386
387 async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
388 let output = self.ssh_command(program, args).output().await?;
389 anyhow::ensure!(
390 output.status.success(),
391 "failed to run command: {}",
392 String::from_utf8_lossy(&output.stderr)
393 );
394 Ok(String::from_utf8_lossy(&output.stdout).to_string())
395 }
396
397 #[cfg(not(target_os = "windows"))]
398 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
399 command
400 .stdin(Stdio::piped())
401 .stdout(Stdio::piped())
402 .stderr(Stdio::piped())
403 .args(["-o", "ControlMaster=no", "-o"])
404 .arg(format!("ControlPath={}", self.socket_path.display()))
405 }
406
407 #[cfg(target_os = "windows")]
408 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
409 command
410 .stdin(Stdio::piped())
411 .stdout(Stdio::piped())
412 .stderr(Stdio::piped())
413 .envs(self.envs.clone())
414 }
415
416 // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
417 // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
418 #[cfg(not(target_os = "windows"))]
419 fn ssh_args(&self) -> SshArgs {
420 SshArgs {
421 arguments: vec![
422 "-o".to_string(),
423 "ControlMaster=no".to_string(),
424 "-o".to_string(),
425 format!("ControlPath={}", self.socket_path.display()),
426 self.connection_options.ssh_url(),
427 ],
428 envs: None,
429 }
430 }
431
432 #[cfg(target_os = "windows")]
433 fn ssh_args(&self) -> SshArgs {
434 SshArgs {
435 arguments: vec![self.connection_options.ssh_url()],
436 envs: Some(self.envs.clone()),
437 }
438 }
439
440 async fn platform(&self) -> Result<SshPlatform> {
441 let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
442 let Some((os, arch)) = uname.split_once(" ") else {
443 anyhow::bail!("unknown uname: {uname:?}")
444 };
445
446 let os = match os.trim() {
447 "Darwin" => "macos",
448 "Linux" => "linux",
449 _ => anyhow::bail!(
450 "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
451 ),
452 };
453 // exclude armv5,6,7 as they are 32-bit.
454 let arch = if arch.starts_with("armv8")
455 || arch.starts_with("armv9")
456 || arch.starts_with("arm64")
457 || arch.starts_with("aarch64")
458 {
459 "aarch64"
460 } else if arch.starts_with("x86") {
461 "x86_64"
462 } else {
463 anyhow::bail!(
464 "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
465 )
466 };
467
468 Ok(SshPlatform { os, arch })
469 }
470}
471
/// Number of consecutive missed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// How long the connection may be idle before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a heartbeat ping; also used for the initial ping and for resync.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Number of reconnect attempts made before the client gives up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
477
/// Internal state machine of an `SshRemoteClient` connection.
enum State {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up and heartbeats are being answered.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Watches the proxy I/O task and reacts when it ends.
        multiplex_task: Task<Result<()>>,
        /// Keep-alive loop that pings the server when the connection is idle.
        heartbeat_task: Task<Result<()>>,
    },
    /// The connection is up, but `missed_heartbeats` consecutive heartbeats failed.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The most recent reconnect attempt failed with `error`; `attempts`
    /// counts the attempts made so far.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts failed; the client is disconnected.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
507
508impl fmt::Display for State {
509 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510 match self {
511 Self::Connecting => write!(f, "connecting"),
512 Self::Connected { .. } => write!(f, "connected"),
513 Self::Reconnecting => write!(f, "reconnecting"),
514 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
515 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
516 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
517 Self::ServerNotRunning { .. } => write!(f, "server not running"),
518 }
519 }
520}
521
522impl State {
523 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
524 match self {
525 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
526 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
527 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
528 _ => None,
529 }
530 }
531
532 fn can_reconnect(&self) -> bool {
533 match self {
534 Self::Connected { .. }
535 | Self::HeartbeatMissed { .. }
536 | Self::ReconnectFailed { .. } => true,
537 State::Connecting
538 | State::Reconnecting
539 | State::ReconnectExhausted
540 | State::ServerNotRunning => false,
541 }
542 }
543
544 fn is_reconnect_failed(&self) -> bool {
545 matches!(self, Self::ReconnectFailed { .. })
546 }
547
548 fn is_reconnect_exhausted(&self) -> bool {
549 matches!(self, Self::ReconnectExhausted { .. })
550 }
551
552 fn is_server_not_running(&self) -> bool {
553 matches!(self, Self::ServerNotRunning)
554 }
555
556 fn is_reconnecting(&self) -> bool {
557 matches!(self, Self::Reconnecting { .. })
558 }
559
560 fn heartbeat_recovered(self) -> Self {
561 match self {
562 Self::HeartbeatMissed {
563 ssh_connection,
564 delegate,
565 multiplex_task,
566 heartbeat_task,
567 ..
568 } => Self::Connected {
569 ssh_connection,
570 delegate,
571 multiplex_task,
572 heartbeat_task,
573 },
574 _ => self,
575 }
576 }
577
578 fn heartbeat_missed(self) -> Self {
579 match self {
580 Self::Connected {
581 ssh_connection,
582 delegate,
583 multiplex_task,
584 heartbeat_task,
585 } => Self::HeartbeatMissed {
586 missed_heartbeats: 1,
587 ssh_connection,
588 delegate,
589 multiplex_task,
590 heartbeat_task,
591 },
592 Self::HeartbeatMissed {
593 missed_heartbeats,
594 ssh_connection,
595 delegate,
596 multiplex_task,
597 heartbeat_task,
598 } => Self::HeartbeatMissed {
599 missed_heartbeats: missed_heartbeats + 1,
600 ssh_connection,
601 delegate,
602 multiplex_task,
603 heartbeat_task,
604 },
605 _ => self,
606 }
607 }
608}
609
/// The state of the ssh connection.
///
/// A coarse, public view of the client's internal connection state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up and healthy.
    Connected,
    /// The connection is up, but recent heartbeats went unanswered.
    HeartbeatMissed,
    /// A reconnect is in progress (including between failed attempts).
    Reconnecting,
    /// The connection is gone and will not be re-established automatically.
    Disconnected,
}
619
620impl From<&State> for ConnectionState {
621 fn from(value: &State) -> Self {
622 match value {
623 State::Connecting => Self::Connecting,
624 State::Connected { .. } => Self::Connected,
625 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
626 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
627 State::ReconnectExhausted => Self::Disconnected,
628 State::ServerNotRunning => Self::Disconnected,
629 }
630 }
631}
632
/// Client for a remote server reached over ssh.
///
/// Owns the RPC channel client and the connection state machine. It
/// reconnects when heartbeats are missed or the proxy I/O task fails.
pub struct SshRemoteClient {
    /// RPC client; its transport is replaced on every reconnect.
    client: Arc<ChannelClient>,
    /// Socket identifier (from `ConnectionIdentifier`) handed to the proxy so
    /// reconnects can re-join the same project.
    unique_identifier: String,
    /// Options the connection was created from.
    connection_options: SshConnectionOptions,
    /// Path style of the remote, captured when the client was created.
    path_style: PathStyle,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
640
/// Events emitted by an `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection is gone for good: reconnect attempts were exhausted or
    /// the remote server is not running.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
647
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Identifier allocated by `ConnectionIdentifier::setup` from `NEXT_ID`.
    Setup(u64),
    /// Identifier derived from a workspace id (presumably the persisted
    /// workspace's id — confirm against callers).
    Workspace(i64),
}

/// Process-wide counter for `ConnectionIdentifier::Setup` ids, starting at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
656
657impl ConnectionIdentifier {
658 pub fn setup() -> Self {
659 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
660 }
661 // This string gets used in a socket name, and so must be relatively short.
662 // The total length of:
663 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
664 // Must be less than about 100 characters
665 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
666 // So our strings should be at most 20 characters or so.
667 fn to_string(&self, cx: &App) -> String {
668 let identifier_prefix = match ReleaseChannel::global(cx) {
669 ReleaseChannel::Stable => "".to_string(),
670 release_channel => format!("{}-", release_channel.dev_name()),
671 };
672 match self {
673 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
674 Self::Workspace(workspace_id) => {
675 format!("{identifier_prefix}workspace-{workspace_id}",)
676 }
677 }
678 }
679}
680
681impl SshRemoteClient {
/// Connects to the remote described by `connection_options`. Once the
/// initial ping succeeds, resolves to the new client entity.
///
/// The ssh connection is obtained from the global `ConnectionPool`. A proxy
/// is started for `unique_identifier`, its I/O task is watched by `monitor`,
/// and a heartbeat loop is spawned. The state goes from `Connecting` to
/// `Connected`.
///
/// Resolves to `Ok(None)` if `cancellation` completes first. Because any
/// completion counts, dropping the sender also cancels.
///
/// # Errors
///
/// Fails if the connection cannot be established, the entity cannot be
/// created, or the first ping fails within `HEARTBEAT_TIMEOUT`.
pub fn new(
    unique_identifier: ConnectionIdentifier,
    connection_options: SshConnectionOptions,
    cancellation: oneshot::Receiver<()>,
    delegate: Arc<dyn SshClientDelegate>,
    cx: &mut App,
) -> Task<Result<Option<Entity<Self>>>> {
    let unique_identifier = unique_identifier.to_string(cx);
    cx.spawn(async move |cx| {
        let success = Box::pin(async move {
            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let client =
                cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

            let ssh_connection = cx
                .update(|cx| {
                    cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options.clone(), &delegate, cx)
                    })
                })?
                .await
                .map_err(|e| e.cloned())?;

            let path_style = ssh_connection.path_style();
            let this = cx.new(|_| Self {
                client: client.clone(),
                unique_identifier: unique_identifier.clone(),
                connection_options,
                path_style,
                state: Arc::new(Mutex::new(Some(State::Connecting))),
            })?;

            // `false` here vs `true` in `reconnect` — presumably the
            // "is reconnect" flag of `start_proxy`; confirm.
            let io_task = ssh_connection.start_proxy(
                unique_identifier,
                false,
                incoming_tx,
                outgoing_rx,
                connection_activity_tx,
                delegate.clone(),
                cx,
            );

            let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

            // Verify end-to-end connectivity before declaring the client connected.
            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                log::error!("failed to establish connection: {}", error);
                return Err(error);
            }

            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

            this.update(cx, |this, _| {
                *this.state.lock() = Some(State::Connected {
                    ssh_connection,
                    delegate,
                    multiplex_task,
                    heartbeat_task,
                });
            })?;

            Ok(Some(this))
        });

        select! {
            _ = cancellation.fuse() => {
                Ok(None)
            }
            result = success.fuse() => result
        }
    })
}
756
/// Takes the connection state and returns a future that shuts the
/// connection down.
///
/// The state is taken unconditionally. If it was not `Connected`, it is
/// dropped immediately, no shutdown request is sent, and `None` is returned.
/// If the state was already taken, `None` is returned as well.
///
/// The returned future optionally sends `shutdown_request` (without waiting
/// for a reply) and waits 50ms. It then drops the multiplex task, the
/// heartbeat task and the connection, in that order.
pub fn shutdown_processes<T: RequestMessage>(
    &self,
    shutdown_request: Option<T>,
    executor: BackgroundExecutor,
) -> Option<impl Future<Output = ()> + use<T>> {
    let state = self.state.lock().take()?;
    log::info!("shutting down ssh processes");

    let State::Connected {
        multiplex_task,
        heartbeat_task,
        ssh_connection,
        delegate,
    } = state
    else {
        return None;
    };

    let client = self.client.clone();

    Some(async move {
        if let Some(shutdown_request) = shutdown_request {
            client.send(shutdown_request).log_err();
            // We wait 50ms instead of waiting for a response, because
            // waiting for a response would require us to wait on the main thread
            // which we want to avoid in an `on_app_quit` callback.
            executor.timer(Duration::from_millis(50)).await;
        }

        // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
        // child of master_process.
        drop(multiplex_task);
        // Now drop the rest of state, which kills master process.
        drop(heartbeat_task);
        drop(ssh_connection);
        drop(delegate);
    })
}
795
796 fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
797 let mut lock = self.state.lock();
798
799 let can_reconnect = lock
800 .as_ref()
801 .map(|state| state.can_reconnect())
802 .unwrap_or(false);
803 if !can_reconnect {
804 log::info!("aborting reconnect, because not in state that allows reconnecting");
805 let error = if let Some(state) = lock.as_ref() {
806 format!("invalid state, cannot reconnect while in state {state}")
807 } else {
808 "no state set".to_string()
809 };
810 anyhow::bail!(error);
811 }
812
813 let state = lock.take().unwrap();
814 let (attempts, ssh_connection, delegate) = match state {
815 State::Connected {
816 ssh_connection,
817 delegate,
818 multiplex_task,
819 heartbeat_task,
820 }
821 | State::HeartbeatMissed {
822 ssh_connection,
823 delegate,
824 multiplex_task,
825 heartbeat_task,
826 ..
827 } => {
828 drop(multiplex_task);
829 drop(heartbeat_task);
830 (0, ssh_connection, delegate)
831 }
832 State::ReconnectFailed {
833 attempts,
834 ssh_connection,
835 delegate,
836 ..
837 } => (attempts, ssh_connection, delegate),
838 State::Connecting
839 | State::Reconnecting
840 | State::ReconnectExhausted
841 | State::ServerNotRunning => unreachable!(),
842 };
843
844 let attempts = attempts + 1;
845 if attempts > MAX_RECONNECT_ATTEMPTS {
846 log::error!(
847 "Failed to reconnect to after {} attempts, giving up",
848 MAX_RECONNECT_ATTEMPTS
849 );
850 drop(lock);
851 self.set_state(State::ReconnectExhausted, cx);
852 return Ok(());
853 }
854 drop(lock);
855
856 self.set_state(State::Reconnecting, cx);
857
858 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
859
860 let unique_identifier = self.unique_identifier.clone();
861 let client = self.client.clone();
862 let reconnect_task = cx.spawn(async move |this, cx| {
863 macro_rules! failed {
864 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
865 return State::ReconnectFailed {
866 error: anyhow!($error),
867 attempts: $attempts,
868 ssh_connection: $ssh_connection,
869 delegate: $delegate,
870 };
871 };
872 }
873
874 if let Err(error) = ssh_connection
875 .kill()
876 .await
877 .context("Failed to kill ssh process")
878 {
879 failed!(error, attempts, ssh_connection, delegate);
880 };
881
882 let connection_options = ssh_connection.connection_options();
883
884 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
885 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
886 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
887
888 let (ssh_connection, io_task) = match async {
889 let ssh_connection = cx
890 .update_global(|pool: &mut ConnectionPool, cx| {
891 pool.connect(connection_options, &delegate, cx)
892 })?
893 .await
894 .map_err(|error| error.cloned())?;
895
896 let io_task = ssh_connection.start_proxy(
897 unique_identifier,
898 true,
899 incoming_tx,
900 outgoing_rx,
901 connection_activity_tx,
902 delegate.clone(),
903 cx,
904 );
905 anyhow::Ok((ssh_connection, io_task))
906 }
907 .await
908 {
909 Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
910 Err(error) => {
911 failed!(error, attempts, ssh_connection, delegate);
912 }
913 };
914
915 let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
916 client.reconnect(incoming_rx, outgoing_tx, &cx);
917
918 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
919 failed!(error, attempts, ssh_connection, delegate);
920 };
921
922 State::Connected {
923 ssh_connection,
924 delegate,
925 multiplex_task,
926 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
927 }
928 });
929
930 cx.spawn(async move |this, cx| {
931 let new_state = reconnect_task.await;
932 this.update(cx, |this, cx| {
933 this.try_set_state(cx, |old_state| {
934 if old_state.is_reconnecting() {
935 match &new_state {
936 State::Connecting
937 | State::Reconnecting { .. }
938 | State::HeartbeatMissed { .. }
939 | State::ServerNotRunning => {}
940 State::Connected { .. } => {
941 log::info!("Successfully reconnected");
942 }
943 State::ReconnectFailed {
944 error, attempts, ..
945 } => {
946 log::error!(
947 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
948 attempts,
949 error
950 );
951 }
952 State::ReconnectExhausted => {
953 log::error!("Reconnect attempt failed and all attempts exhausted");
954 }
955 }
956 Some(new_state)
957 } else {
958 None
959 }
960 });
961
962 if this.state_is(State::is_reconnect_failed) {
963 this.reconnect(cx)
964 } else if this.state_is(State::is_reconnect_exhausted) {
965 Ok(())
966 } else {
967 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
968 Ok(())
969 }
970 })
971 })
972 .detach_and_log_err(cx);
973
974 Ok(())
975 }
976
977 fn heartbeat(
978 this: WeakEntity<Self>,
979 mut connection_activity_rx: mpsc::Receiver<()>,
980 cx: &mut AsyncApp,
981 ) -> Task<Result<()>> {
982 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
983 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
984 };
985
986 cx.spawn(async move |cx| {
987 let mut missed_heartbeats = 0;
988
989 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
990 futures::pin_mut!(keepalive_timer);
991
992 loop {
993 select_biased! {
994 result = connection_activity_rx.next().fuse() => {
995 if result.is_none() {
996 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
997 return Ok(());
998 }
999
1000 if missed_heartbeats != 0 {
1001 missed_heartbeats = 0;
1002 let _ =this.update(cx, |this, mut cx| {
1003 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1004 })?;
1005 }
1006 }
1007 _ = keepalive_timer => {
1008 log::debug!("Sending heartbeat to server...");
1009
1010 let result = select_biased! {
1011 _ = connection_activity_rx.next().fuse() => {
1012 Ok(())
1013 }
1014 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1015 ping_result
1016 }
1017 };
1018
1019 if result.is_err() {
1020 missed_heartbeats += 1;
1021 log::warn!(
1022 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1023 HEARTBEAT_TIMEOUT,
1024 missed_heartbeats,
1025 MAX_MISSED_HEARTBEATS
1026 );
1027 } else if missed_heartbeats != 0 {
1028 missed_heartbeats = 0;
1029 } else {
1030 continue;
1031 }
1032
1033 let result = this.update(cx, |this, mut cx| {
1034 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1035 })?;
1036 if result.is_break() {
1037 return Ok(());
1038 }
1039 }
1040 }
1041
1042 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1043 }
1044
1045 })
1046 }
1047
1048 fn handle_heartbeat_result(
1049 &mut self,
1050 missed_heartbeats: usize,
1051 cx: &mut Context<Self>,
1052 ) -> ControlFlow<()> {
1053 let state = self.state.lock().take().unwrap();
1054 let next_state = if missed_heartbeats > 0 {
1055 state.heartbeat_missed()
1056 } else {
1057 state.heartbeat_recovered()
1058 };
1059
1060 self.set_state(next_state, cx);
1061
1062 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1063 log::error!(
1064 "Missed last {} heartbeats. Reconnecting...",
1065 missed_heartbeats
1066 );
1067
1068 self.reconnect(cx)
1069 .context("failed to start reconnect process after missing heartbeats")
1070 .log_err();
1071 ControlFlow::Break(())
1072 } else {
1073 ControlFlow::Continue(())
1074 }
1075 }
1076
/// Spawns a task that waits for the proxy I/O task to finish and reacts to
/// how it ended.
///
/// - An exit code recognized by `ProxyLaunchError::from_exit_code`
///   (`ServerNotRunning`) moves the client to `State::ServerNotRunning`.
/// - Any other positive exit code, or an error from the I/O task, triggers
///   `reconnect`; an error from `reconnect` itself is discarded.
/// - Any other exit code (zero or negative) causes no action.
///
/// The task fails only if the client entity has been released.
fn monitor(
    this: WeakEntity<Self>,
    io_task: Task<Result<i32>>,
    cx: &AsyncApp,
) -> Task<Result<()>> {
    cx.spawn(async move |cx| {
        let result = io_task.await;

        match result {
            Ok(exit_code) => {
                if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
                    match error {
                        ProxyLaunchError::ServerNotRunning => {
                            log::error!("failed to reconnect because server is not running");
                            this.update(cx, |this, cx| {
                                this.set_state(State::ServerNotRunning, cx);
                            })?;
                        }
                    }
                } else if exit_code > 0 {
                    log::error!("proxy process terminated unexpectedly");
                    this.update(cx, |this, cx| {
                        this.reconnect(cx).ok();
                    })?;
                }
            }
            Err(error) => {
                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
                this.update(cx, |this, cx| {
                    this.reconnect(cx).ok();
                })?;
            }
        }

        Ok(())
    })
}
1114
1115 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1116 self.state.lock().as_ref().map_or(false, check)
1117 }
1118
1119 fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1120 let mut lock = self.state.lock();
1121 let new_state = lock.as_ref().and_then(map);
1122
1123 if let Some(new_state) = new_state {
1124 lock.replace(new_state);
1125 cx.notify();
1126 }
1127 }
1128
1129 fn set_state(&self, state: State, cx: &mut Context<Self>) {
1130 log::info!("setting state to '{}'", &state);
1131
1132 let is_reconnect_exhausted = state.is_reconnect_exhausted();
1133 let is_server_not_running = state.is_server_not_running();
1134 self.state.lock().replace(state);
1135
1136 if is_reconnect_exhausted || is_server_not_running {
1137 cx.emit(SshRemoteEvent::Disconnected);
1138 }
1139 cx.notify();
1140 }
1141
/// Routes messages for `remote_id` to `entity` by delegating to the
/// underlying channel client's `subscribe_to_entity`.
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
    self.client.subscribe_to_entity(remote_id, entity);
}
1145
1146 pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1147 self.state
1148 .lock()
1149 .as_ref()
1150 .and_then(|state| state.ssh_connection())
1151 .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1152 }
1153
1154 pub fn upload_directory(
1155 &self,
1156 src_path: PathBuf,
1157 dest_path: RemotePathBuf,
1158 cx: &App,
1159 ) -> Task<Result<()>> {
1160 let state = self.state.lock();
1161 let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1162 return Task::ready(Err(anyhow!("no ssh connection")));
1163 };
1164 connection.upload_directory(src_path, dest_path, cx)
1165 }
1166
1167 pub fn proto_client(&self) -> AnyProtoClient {
1168 self.client.clone().into()
1169 }
1170
1171 pub fn connection_string(&self) -> String {
1172 self.connection_options.connection_string()
1173 }
1174
1175 pub fn connection_options(&self) -> SshConnectionOptions {
1176 self.connection_options.clone()
1177 }
1178
1179 pub fn connection_state(&self) -> ConnectionState {
1180 self.state
1181 .lock()
1182 .as_ref()
1183 .map(ConnectionState::from)
1184 .unwrap_or(ConnectionState::Disconnected)
1185 }
1186
1187 pub fn is_disconnected(&self) -> bool {
1188 self.connection_state() == ConnectionState::Disconnected
1189 }
1190
1191 pub fn path_style(&self) -> PathStyle {
1192 self.path_style
1193 }
1194
1195 #[cfg(any(test, feature = "test-support"))]
1196 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1197 let opts = self.connection_options();
1198 client_cx.spawn(async move |cx| {
1199 let connection = cx
1200 .update_global(|c: &mut ConnectionPool, _| {
1201 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1202 c.clone()
1203 } else {
1204 panic!("missing test connection")
1205 }
1206 })
1207 .unwrap()
1208 .await
1209 .unwrap();
1210
1211 connection.simulate_disconnect(&cx);
1212 })
1213 }
1214
1215 #[cfg(any(test, feature = "test-support"))]
1216 pub fn fake_server(
1217 client_cx: &mut gpui::TestAppContext,
1218 server_cx: &mut gpui::TestAppContext,
1219 ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1220 let port = client_cx
1221 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1222 let opts = SshConnectionOptions {
1223 host: "<fake>".to_string(),
1224 port: Some(port),
1225 ..Default::default()
1226 };
1227 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1228 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1229 let server_client =
1230 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1231 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1232 connection_options: opts.clone(),
1233 server_cx: fake::SendableCx::new(server_cx),
1234 server_channel: server_client.clone(),
1235 });
1236
1237 client_cx.update(|cx| {
1238 cx.update_default_global(|c: &mut ConnectionPool, cx| {
1239 c.connections.insert(
1240 opts.clone(),
1241 ConnectionPoolEntry::Connecting(
1242 cx.background_spawn({
1243 let connection = connection.clone();
1244 async move { Ok(connection.clone()) }
1245 })
1246 .shared(),
1247 ),
1248 );
1249 })
1250 });
1251
1252 (opts, server_client)
1253 }
1254
1255 #[cfg(any(test, feature = "test-support"))]
1256 pub async fn fake_client(
1257 opts: SshConnectionOptions,
1258 client_cx: &mut gpui::TestAppContext,
1259 ) -> Entity<Self> {
1260 let (_tx, rx) = oneshot::channel();
1261 client_cx
1262 .update(|cx| {
1263 Self::new(
1264 ConnectionIdentifier::setup(),
1265 opts,
1266 rx,
1267 Arc::new(fake::Delegate),
1268 cx,
1269 )
1270 })
1271 .await
1272 .unwrap()
1273 .unwrap()
1274 }
1275}
1276
/// A `ConnectionPool` slot for one set of `SshConnectionOptions`.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; every caller for the same options shares its result.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool alone does not keep it alive.
    Connected(Weak<dyn RemoteConnection>),
}

/// Registry of SSH connections keyed by their options, stored as a gpui `Global` so that
/// clients with identical options can share one connection (see `ConnectionPool::connect`).
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

impl Global for ConnectionPool {}
1288
impl ConnectionPool {
    /// Returns a shared task resolving to a connection for `opts`, reusing existing work.
    ///
    /// - If an attempt for `opts` is already in flight, that attempt's shared task is
    ///   returned and `delegate`'s status is updated to say we are waiting on it.
    /// - If a live (upgradable) connection exists that has not been killed, it is returned
    ///   in an already-completed task.
    /// - Otherwise (no entry, or the previous connection was dropped or killed) a new
    ///   `SshRemoteConnection` is started. A `Connecting` entry is recorded immediately; when
    ///   the attempt finishes, the entry becomes a weak `Connected` reference on success or
    ///   is removed on failure.
    pub fn connect(
        &mut self,
        opts: SshConnectionOptions,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
        let connection = self.connections.get(&opts);
        match connection {
            Some(ConnectionPoolEntry::Connecting(task)) => {
                let delegate = delegate.clone();
                cx.spawn(async move |cx| {
                    delegate.set_status(Some("Waiting for existing connection attempt"), cx);
                })
                .detach();
                return task.clone();
            }
            Some(ConnectionPoolEntry::Connected(ssh)) => {
                if let Some(ssh) = ssh.upgrade() {
                    if !ssh.has_been_killed() {
                        return Task::ready(Ok(ssh)).shared();
                    }
                }
                // Stale entry (dropped or killed): forget it and connect afresh below.
                self.connections.remove(&opts);
            }
            None => {}
        }

        let task = cx
            .spawn({
                let opts = opts.clone();
                let delegate = delegate.clone();
                async move |cx| {
                    let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
                        .await
                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);

                    // Replace the `Connecting` entry inserted below with the outcome.
                    cx.update_global(|pool: &mut Self, _| {
                        debug_assert!(matches!(
                            pool.connections.get(&opts),
                            Some(ConnectionPoolEntry::Connecting(_))
                        ));
                        match connection {
                            Ok(connection) => {
                                pool.connections.insert(
                                    opts.clone(),
                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
                                );
                                Ok(connection)
                            }
                            Err(error) => {
                                pool.connections.remove(&opts);
                                Err(Arc::new(error))
                            }
                        }
                    })?
                }
            })
            .shared();

        // Record the in-flight attempt so concurrent callers join it instead of reconnecting.
        self.connections
            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
        task
    }
}
1354
1355impl From<SshRemoteClient> for AnyProtoClient {
1356 fn from(client: SshRemoteClient) -> Self {
1357 AnyProtoClient::new(client.client.clone())
1358 }
1359}
1360
1361#[async_trait(?Send)]
1362trait RemoteConnection: Send + Sync {
1363 fn start_proxy(
1364 &self,
1365 unique_identifier: String,
1366 reconnect: bool,
1367 incoming_tx: UnboundedSender<Envelope>,
1368 outgoing_rx: UnboundedReceiver<Envelope>,
1369 connection_activity_tx: Sender<()>,
1370 delegate: Arc<dyn SshClientDelegate>,
1371 cx: &mut AsyncApp,
1372 ) -> Task<Result<i32>>;
1373 fn upload_directory(
1374 &self,
1375 src_path: PathBuf,
1376 dest_path: RemotePathBuf,
1377 cx: &App,
1378 ) -> Task<Result<()>>;
1379 async fn kill(&self) -> Result<()>;
1380 fn has_been_killed(&self) -> bool;
1381 /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
1382 /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
1383 fn ssh_args(&self) -> SshArgs;
1384 fn connection_options(&self) -> SshConnectionOptions;
1385 fn path_style(&self) -> PathStyle;
1386
1387 #[cfg(any(test, feature = "test-support"))]
1388 fn simulate_disconnect(&self, _: &AsyncApp) {}
1389}
1390
/// An SSH connection kept open by a long-lived master `ssh` process.
///
/// NOTE: fields are dropped in declaration order, so `_temp_dir` is removed only after
/// `master_process` (and everything else) has been dropped.
struct SshRemoteConnection {
    /// Used to run further ssh/scp commands against the host.
    socket: SshSocket,
    /// The master `ssh -N` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Path of the installed remote server binary; filled in by `new` after
    /// `ensure_server_binary` succeeds.
    remote_binary_path: Option<RemotePathBuf>,
    /// Remote platform as reported by `SshSocket::platform`.
    ssh_platform: SshPlatform,
    /// Path style derived from `ssh_platform.os`.
    ssh_path_style: PathStyle,
    // Keeps the session's temp dir alive for the connection's lifetime; on non-Windows
    // platforms it contains the ssh control socket.
    _temp_dir: TempDir,
}
1399
1400#[async_trait(?Send)]
1401impl RemoteConnection for SshRemoteConnection {
1402 async fn kill(&self) -> Result<()> {
1403 let Some(mut process) = self.master_process.lock().take() else {
1404 return Ok(());
1405 };
1406 process.kill().ok();
1407 process.status().await?;
1408 Ok(())
1409 }
1410
1411 fn has_been_killed(&self) -> bool {
1412 self.master_process.lock().is_none()
1413 }
1414
1415 fn ssh_args(&self) -> SshArgs {
1416 self.socket.ssh_args()
1417 }
1418
1419 fn connection_options(&self) -> SshConnectionOptions {
1420 self.socket.connection_options.clone()
1421 }
1422
1423 fn upload_directory(
1424 &self,
1425 src_path: PathBuf,
1426 dest_path: RemotePathBuf,
1427 cx: &App,
1428 ) -> Task<Result<()>> {
1429 let mut command = util::command::new_smol_command("scp");
1430 let output = self
1431 .socket
1432 .ssh_options(&mut command)
1433 .args(
1434 self.socket
1435 .connection_options
1436 .port
1437 .map(|port| vec!["-P".to_string(), port.to_string()])
1438 .unwrap_or_default(),
1439 )
1440 .arg("-C")
1441 .arg("-r")
1442 .arg(&src_path)
1443 .arg(format!(
1444 "{}:{}",
1445 self.socket.connection_options.scp_url(),
1446 dest_path.to_string()
1447 ))
1448 .output();
1449
1450 cx.background_spawn(async move {
1451 let output = output.await?;
1452
1453 anyhow::ensure!(
1454 output.status.success(),
1455 "failed to upload directory {} -> {}: {}",
1456 src_path.display(),
1457 dest_path.to_string(),
1458 String::from_utf8_lossy(&output.stderr)
1459 );
1460
1461 Ok(())
1462 })
1463 }
1464
1465 fn start_proxy(
1466 &self,
1467 unique_identifier: String,
1468 reconnect: bool,
1469 incoming_tx: UnboundedSender<Envelope>,
1470 outgoing_rx: UnboundedReceiver<Envelope>,
1471 connection_activity_tx: Sender<()>,
1472 delegate: Arc<dyn SshClientDelegate>,
1473 cx: &mut AsyncApp,
1474 ) -> Task<Result<i32>> {
1475 delegate.set_status(Some("Starting proxy"), cx);
1476
1477 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1478 return Task::ready(Err(anyhow!("Remote binary path not set")));
1479 };
1480
1481 let mut start_proxy_command = shell_script!(
1482 "exec {binary_path} proxy --identifier {identifier}",
1483 binary_path = &remote_binary_path.to_string(),
1484 identifier = &unique_identifier,
1485 );
1486
1487 if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1488 start_proxy_command = format!(
1489 "RUST_LOG={} {}",
1490 shlex::try_quote(&rust_log).unwrap(),
1491 start_proxy_command
1492 )
1493 }
1494 if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1495 start_proxy_command = format!(
1496 "RUST_BACKTRACE={} {}",
1497 shlex::try_quote(&rust_backtrace).unwrap(),
1498 start_proxy_command
1499 )
1500 }
1501 if reconnect {
1502 start_proxy_command.push_str(" --reconnect");
1503 }
1504
1505 let ssh_proxy_process = match self
1506 .socket
1507 .ssh_command("sh", &["-c", &start_proxy_command])
1508 // IMPORTANT: we kill this process when we drop the task that uses it.
1509 .kill_on_drop(true)
1510 .spawn()
1511 {
1512 Ok(process) => process,
1513 Err(error) => {
1514 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1515 }
1516 };
1517
1518 Self::multiplex(
1519 ssh_proxy_process,
1520 incoming_tx,
1521 outgoing_rx,
1522 connection_activity_tx,
1523 &cx,
1524 )
1525 }
1526
1527 fn path_style(&self) -> PathStyle {
1528 self.ssh_path_style
1529 }
1530}
1531
1532impl SshRemoteConnection {
    /// Establishes an SSH connection to the host described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. Creates a temp dir (stored in `_temp_dir`) and an askpass session that forwards
    ///    password prompts to `delegate.ask_password`.
    /// 2. Spawns the master `ssh -N` process. Off Windows it uses `ControlMaster=yes` and a
    ///    `ControlPath` socket inside the temp dir so later ssh commands can reuse it; Windows
    ///    OpenSSH does not support those options.
    /// 3. Waits until either the askpass session ends (user cancellation or timeout, both
    ///    errors) or the master process closes its stdout.
    /// 4. Fails with the master's stderr if it has already exited.
    /// 5. Builds the `SshSocket`, probes the remote platform to choose a `PathStyle`, and
    ///    ensures the remote server binary is installed.
    ///
    /// # Errors
    /// Returns an error if the temp dir or askpass session cannot be created, ssh cannot be
    /// spawned, authentication is cancelled or times out, the master process exits early,
    /// or the platform probe or server-binary installation fails.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        // Lives as long as the returned connection (moved into `_temp_dir` below).
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts raised by ssh are forwarded to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" pairs with the `ControlPath=...` argument appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            // `SSH_ASKPASS_REQUIRE=force` makes ssh use the askpass script for prompts.
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Biased towards askpass: if it finishes first, the attempt fails. Returning early
        // drops `master_process`, which was spawned with `kill_on_drop(true)`.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, connecting failed; surface its stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        // Off Windows the socket reuses the master's control socket; on Windows it is built
        // from the temp dir and the password collected by askpass instead.
        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        // Probe the remote OS to decide how remote paths are formatted.
        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        // `remote_binary_path` is filled in below once the server binary is installed.
        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1663
    /// Pumps protocol messages between the local channels and a running proxy process.
    ///
    /// Spawns three background tasks:
    /// - stdin: writes every envelope from `outgoing_rx` to the child's stdin;
    /// - stdout: reads length-prefixed envelopes from the child's stdout into `incoming_tx`,
    ///   signalling `connection_activity_tx` after each message;
    /// - stderr: splits the child's stderr into lines, re-logging lines that parse as
    ///   `LogRecord` JSON and printing the rest prefixed with `(remote)`, signalling
    ///   `connection_activity_tx` after each read.
    ///
    /// The returned task finishes as soon as any of the three tasks finishes (the other task
    /// handles are dropped), then waits for the proxy process to exit. It resolves to the
    /// proxy's exit code (1 when the status carries no code) if the finished task succeeded,
    /// or to that task's error with context naming the stream.
    ///
    /// # Panics
    /// Panics if the child was spawned without piped stdin, stdout, and stderr.
    fn multiplex(
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        mut connection_activity_tx: Sender<()>,
        cx: &AsyncApp,
    ) -> Task<Result<i32>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Number of valid bytes currently held in `stderr_buffer`.
        let mut stderr_offset = 0;

        let stdin_task = cx.background_spawn(async move {
            while let Some(outgoing) = outgoing_rx.next().await {
                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
            }
            anyhow::Ok(())
        });

        let stdout_task = cx.background_spawn({
            let mut connection_activity_tx = connection_activity_tx.clone();
            async move {
                loop {
                    // Read the fixed-size length prefix: a 0-byte read means EOF, and a
                    // short read is completed with `read_exact`.
                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                    let len = child_stdout.read(&mut stdout_buffer).await?;

                    if len == 0 {
                        return anyhow::Ok(());
                    }

                    if len < MESSAGE_LEN_SIZE {
                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                    }

                    let message_len = message_len_from_buffer(&stdout_buffer);
                    let envelope =
                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                            .await?;
                    connection_activity_tx.try_send(()).ok();
                    incoming_tx.unbounded_send(envelope).ok();
                }
            }
        });

        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
            loop {
                stderr_buffer.resize(stderr_offset + 1024, 0);

                let len = child_stderr
                    .read(&mut stderr_buffer[stderr_offset..])
                    .await?;
                if len == 0 {
                    return anyhow::Ok(());
                }

                stderr_offset += len;
                // Emit every complete line; an incomplete trailing line stays in the
                // buffer until more bytes arrive.
                let mut start_ix = 0;
                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                    .iter()
                    .position(|b| b == &b'\n')
                {
                    let line_ix = start_ix + ix;
                    let content = &stderr_buffer[start_ix..line_ix];
                    start_ix = line_ix + 1;
                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                        record.log(log::logger())
                    } else {
                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                    }
                }
                stderr_buffer.drain(0..start_ix);
                stderr_offset -= start_ix;

                connection_activity_tx.try_send(()).ok();
            }
        });

        cx.background_spawn(async move {
            let result = futures::select! {
                result = stdin_task.fuse() => {
                    result.context("stdin")
                }
                result = stdout_task.fuse() => {
                    result.context("stdout")
                }
                result = stderr_task.fuse() => {
                    result.context("stderr")
                }
            };

            // `code()` is `None` when the process has no exit code (e.g. killed by a signal).
            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
            match result {
                Ok(_) => Ok(status),
                Err(error) => Err(error),
            }
        })
    }
1765
1766 #[allow(unused)]
1767 async fn ensure_server_binary(
1768 &self,
1769 delegate: &Arc<dyn SshClientDelegate>,
1770 release_channel: ReleaseChannel,
1771 version: SemanticVersion,
1772 commit: Option<AppCommitSha>,
1773 cx: &mut AsyncApp,
1774 ) -> Result<RemotePathBuf> {
1775 let version_str = match release_channel {
1776 ReleaseChannel::Nightly => {
1777 let commit = commit.map(|s| s.full()).unwrap_or_default();
1778 format!("{}-{}", version, commit)
1779 }
1780 ReleaseChannel::Dev => "build".to_string(),
1781 _ => version.to_string(),
1782 };
1783 let binary_name = format!(
1784 "zed-remote-server-{}-{}",
1785 release_channel.dev_name(),
1786 version_str
1787 );
1788 let dst_path = RemotePathBuf::new(
1789 paths::remote_server_dir_relative().join(binary_name),
1790 self.ssh_path_style,
1791 );
1792
1793 let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1794 #[cfg(debug_assertions)]
1795 if let Some(build_remote_server) = build_remote_server {
1796 let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1797 let tmp_path = RemotePathBuf::new(
1798 paths::remote_server_dir_relative().join(format!(
1799 "download-{}-{}",
1800 std::process::id(),
1801 src_path.file_name().unwrap().to_string_lossy()
1802 )),
1803 self.ssh_path_style,
1804 );
1805 self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1806 .await?;
1807 self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1808 .await?;
1809 return Ok(dst_path);
1810 }
1811
1812 if self
1813 .socket
1814 .run_command(&dst_path.to_string(), &["version"])
1815 .await
1816 .is_ok()
1817 {
1818 return Ok(dst_path);
1819 }
1820
1821 let wanted_version = cx.update(|cx| match release_channel {
1822 ReleaseChannel::Nightly => Ok(None),
1823 ReleaseChannel::Dev => {
1824 anyhow::bail!(
1825 "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1826 dst_path
1827 )
1828 }
1829 _ => Ok(Some(AppVersion::global(cx))),
1830 })??;
1831
1832 let tmp_path_gz = RemotePathBuf::new(
1833 PathBuf::from(format!(
1834 "{}-download-{}.gz",
1835 dst_path.to_string(),
1836 std::process::id()
1837 )),
1838 self.ssh_path_style,
1839 );
1840 if !self.socket.connection_options.upload_binary_over_ssh {
1841 if let Some((url, body)) = delegate
1842 .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1843 .await?
1844 {
1845 match self
1846 .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1847 .await
1848 {
1849 Ok(_) => {
1850 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1851 .await?;
1852 return Ok(dst_path);
1853 }
1854 Err(e) => {
1855 log::error!(
1856 "Failed to download binary on server, attempting to upload server: {}",
1857 e
1858 )
1859 }
1860 }
1861 }
1862 }
1863
1864 let src_path = delegate
1865 .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1866 .await?;
1867 self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1868 .await?;
1869 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1870 .await?;
1871 return Ok(dst_path);
1872 }
1873
1874 async fn download_binary_on_server(
1875 &self,
1876 url: &str,
1877 body: &str,
1878 tmp_path_gz: &RemotePathBuf,
1879 delegate: &Arc<dyn SshClientDelegate>,
1880 cx: &mut AsyncApp,
1881 ) -> Result<()> {
1882 if let Some(parent) = tmp_path_gz.parent() {
1883 self.socket
1884 .run_command(
1885 "sh",
1886 &[
1887 "-c",
1888 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1889 ],
1890 )
1891 .await?;
1892 }
1893
1894 delegate.set_status(Some("Downloading remote development server on host"), cx);
1895
1896 match self
1897 .socket
1898 .run_command(
1899 "curl",
1900 &[
1901 "-f",
1902 "-L",
1903 "-X",
1904 "GET",
1905 "-H",
1906 "Content-Type: application/json",
1907 "-d",
1908 &body,
1909 &url,
1910 "-o",
1911 &tmp_path_gz.to_string(),
1912 ],
1913 )
1914 .await
1915 {
1916 Ok(_) => {}
1917 Err(e) => {
1918 if self.socket.run_command("which", &["curl"]).await.is_ok() {
1919 return Err(e);
1920 }
1921
1922 match self
1923 .socket
1924 .run_command(
1925 "wget",
1926 &[
1927 "--method=GET",
1928 "--header=Content-Type: application/json",
1929 "--body-data",
1930 &body,
1931 &url,
1932 "-O",
1933 &tmp_path_gz.to_string(),
1934 ],
1935 )
1936 .await
1937 {
1938 Ok(_) => {}
1939 Err(e) => {
1940 if self.socket.run_command("which", &["wget"]).await.is_ok() {
1941 return Err(e);
1942 } else {
1943 anyhow::bail!("Neither curl nor wget is available");
1944 }
1945 }
1946 }
1947 }
1948 }
1949
1950 Ok(())
1951 }
1952
1953 async fn upload_local_server_binary(
1954 &self,
1955 src_path: &Path,
1956 tmp_path_gz: &RemotePathBuf,
1957 delegate: &Arc<dyn SshClientDelegate>,
1958 cx: &mut AsyncApp,
1959 ) -> Result<()> {
1960 if let Some(parent) = tmp_path_gz.parent() {
1961 self.socket
1962 .run_command(
1963 "sh",
1964 &[
1965 "-c",
1966 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1967 ],
1968 )
1969 .await?;
1970 }
1971
1972 let src_stat = fs::metadata(&src_path).await?;
1973 let size = src_stat.len();
1974
1975 let t0 = Instant::now();
1976 delegate.set_status(Some("Uploading remote development server"), cx);
1977 log::info!(
1978 "uploading remote development server to {:?} ({}kb)",
1979 tmp_path_gz,
1980 size / 1024
1981 );
1982 self.upload_file(&src_path, &tmp_path_gz)
1983 .await
1984 .context("failed to upload server binary")?;
1985 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1986 Ok(())
1987 }
1988
1989 async fn extract_server_binary(
1990 &self,
1991 dst_path: &RemotePathBuf,
1992 tmp_path: &RemotePathBuf,
1993 delegate: &Arc<dyn SshClientDelegate>,
1994 cx: &mut AsyncApp,
1995 ) -> Result<()> {
1996 delegate.set_status(Some("Extracting remote development server"), cx);
1997 let server_mode = 0o755;
1998
1999 let orig_tmp_path = tmp_path.to_string();
2000 let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2001 shell_script!(
2002 "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2003 server_mode = &format!("{:o}", server_mode),
2004 dst_path = &dst_path.to_string(),
2005 )
2006 } else {
2007 shell_script!(
2008 "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2009 server_mode = &format!("{:o}", server_mode),
2010 dst_path = &dst_path.to_string()
2011 )
2012 };
2013 self.socket.run_command("sh", &["-c", &script]).await?;
2014 Ok(())
2015 }
2016
2017 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2018 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2019 let mut command = util::command::new_smol_command("scp");
2020 let output = self
2021 .socket
2022 .ssh_options(&mut command)
2023 .args(
2024 self.socket
2025 .connection_options
2026 .port
2027 .map(|port| vec!["-P".to_string(), port.to_string()])
2028 .unwrap_or_default(),
2029 )
2030 .arg(src_path)
2031 .arg(format!(
2032 "{}:{}",
2033 self.socket.connection_options.scp_url(),
2034 dest_path.to_string()
2035 ))
2036 .output()
2037 .await?;
2038
2039 anyhow::ensure!(
2040 output.status.success(),
2041 "failed to upload file {} -> {}: {}",
2042 src_path.display(),
2043 dest_path.to_string(),
2044 String::from_utf8_lossy(&output.stderr)
2045 );
2046 Ok(())
2047 }
2048
    /// Debug builds only: builds the remote server from source on this machine for the
    /// remote platform and returns the path of the artifact to upload.
    ///
    /// `build_remote_server` is checked for the substrings `mold` (link with mold), `cross`
    /// (cross-compile with cross.rs in Docker) and `nocompress` (skip gzip). When the remote
    /// arch/OS equal the local ones a plain `cargo build` is used; otherwise `cross` or
    /// `cargo zigbuild` (which requires `zig` on `$PATH`).
    ///
    /// Returns an absolute path to `remote_server.gz` when compressing, or the relative path
    /// `target/remote_server/<triple>/debug/remote_server` with `nocompress`.
    /// NOTE(review): the two cases differ in absolute vs. relative paths.
    ///
    /// # Errors
    /// Fails for remote OSes other than `linux`/`macos`, when a required tool is missing,
    /// or when any spawned command exits unsuccessfully.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion (stderr inherited) and fails on a non-zero exit.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" => "unknown-linux-musl",
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS; an unreadable value is logged and ignored.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        // Linux targets are musl and linked statically.
        if self.ssh_platform.os == "linux" {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            // Native build: no cross-compilation tooling needed.
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            if build_remote_server.contains("cross") {
                #[cfg(target_os = "windows")]
                use util::paths::SanitizedPath;

                delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
                log::info!("installing cross");
                run_cmd(Command::new("cargo").args([
                    "install",
                    "cross",
                    "--git",
                    "https://github.com/cross-rs/cross",
                ]))
                .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote server binary from source for {} with Docker",
                        &triple
                    )),
                    cx,
                );
                log::info!("building remote server binary from source for {}", &triple);

                // On Windows, the binding needs to be set to the canonical path
                #[cfg(target_os = "windows")]
                let src =
                    SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
                #[cfg(not(target_os = "windows"))]
                let src = "./target";
                // Bind-mount the local target dir into the container so build output lands
                // under ./target.
                run_cmd(
                    Command::new("cross")
                        .args([
                            "build",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env(
                            "CROSS_CONTAINER_OPTS",
                            format!("--mount type=bind,src={src},dst=/app/target"),
                        )
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            } else {
                // Default cross-compilation path: cargo-zigbuild, which needs `zig` installed.
                let which = cx
                    .background_spawn(async move { which::which("zig") })
                    .await;

                if which.is_err() {
                    #[cfg(not(target_os = "windows"))]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                    #[cfg(target_os = "windows")]
                    {
                        anyhow::bail!(
                            "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                        )
                    }
                }

                delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
                log::info!("adding rustup target");
                run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

                delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
                log::info!("installing cargo-zigbuild");
                run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
                    .await?;

                delegate.set_status(
                    Some(&format!(
                        "Building remote binary from source for {triple} with Zig"
                    )),
                    cx,
                );
                log::info!("building remote binary from source for {triple} with Zig");
                run_cmd(
                    Command::new("cargo")
                        .args([
                            "zigbuild",
                            "--package",
                            "remote_server",
                            "--features",
                            "debug-embed",
                            "--target-dir",
                            "target/remote_server",
                            "--target",
                            &triple,
                        ])
                        .env("RUSTFLAGS", &rust_flags),
                )
                .await?;
            }
        };
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            // Both branches produce `remote_server.gz` next to the binary.
            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-9", "-f", &bin_path.to_string_lossy()]))
                    .await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2261}
2262
/// Pending outgoing requests keyed by request id. On a response, the message loop sends the
/// response envelope plus a fresh oneshot sender, then waits until that sender is used or
/// dropped before handling the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// Proto client that exchanges `Envelope`s over a pair of unbounded channels.
pub struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Channel carrying outgoing envelopes to the transport.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes kept for replay: pruned up to each received `ack_id` and re-sent when a
    /// `FlushBufferedMessages` arrives.
    buffer: Mutex<VecDeque<Envelope>>,
    response_channels: ResponseChannels,
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received (non-flush) envelope.
    max_received: AtomicU32,
    /// Label used in log messages.
    name: &'static str,
    /// Background task processing incoming envelopes.
    task: Mutex<Task<Result<()>>>,
}
2275
2276impl ChannelClient {
2277 pub fn new(
2278 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2279 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2280 cx: &App,
2281 name: &'static str,
2282 ) -> Arc<Self> {
2283 Arc::new_cyclic(|this| Self {
2284 outgoing_tx: Mutex::new(outgoing_tx),
2285 next_message_id: AtomicU32::new(0),
2286 max_received: AtomicU32::new(0),
2287 response_channels: ResponseChannels::default(),
2288 message_handlers: Default::default(),
2289 buffer: Mutex::new(VecDeque::new()),
2290 name,
2291 task: Mutex::new(Self::start_handling_messages(
2292 this.clone(),
2293 incoming_rx,
2294 &cx.to_async(),
2295 )),
2296 })
2297 }
2298
2299 fn start_handling_messages(
2300 this: Weak<Self>,
2301 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2302 cx: &AsyncApp,
2303 ) -> Task<Result<()>> {
2304 cx.spawn(async move |cx| {
2305 let peer_id = PeerId { owner_id: 0, id: 0 };
2306 while let Some(incoming) = incoming_rx.next().await {
2307 let Some(this) = this.upgrade() else {
2308 return anyhow::Ok(());
2309 };
2310 if let Some(ack_id) = incoming.ack_id {
2311 let mut buffer = this.buffer.lock();
2312 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2313 buffer.pop_front();
2314 }
2315 }
2316 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2317 {
2318 log::debug!(
2319 "{}:ssh message received. name:FlushBufferedMessages",
2320 this.name
2321 );
2322 {
2323 let buffer = this.buffer.lock();
2324 for envelope in buffer.iter() {
2325 this.outgoing_tx
2326 .lock()
2327 .unbounded_send(envelope.clone())
2328 .ok();
2329 }
2330 }
2331 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2332 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2333 this.outgoing_tx.lock().unbounded_send(envelope).ok();
2334 continue;
2335 }
2336
2337 this.max_received.store(incoming.id, SeqCst);
2338
2339 if let Some(request_id) = incoming.responding_to {
2340 let request_id = MessageId(request_id);
2341 let sender = this.response_channels.lock().remove(&request_id);
2342 if let Some(sender) = sender {
2343 let (tx, rx) = oneshot::channel();
2344 if incoming.payload.is_some() {
2345 sender.send((incoming, tx)).ok();
2346 }
2347 rx.await.ok();
2348 }
2349 } else if let Some(envelope) =
2350 build_typed_envelope(peer_id, Instant::now(), incoming)
2351 {
2352 let type_name = envelope.payload_type_name();
2353 if let Some(future) = ProtoMessageHandlerSet::handle_message(
2354 &this.message_handlers,
2355 envelope,
2356 this.clone().into(),
2357 cx.clone(),
2358 ) {
2359 log::debug!("{}:ssh message received. name:{type_name}", this.name);
2360 cx.foreground_executor()
2361 .spawn(async move {
2362 match future.await {
2363 Ok(_) => {
2364 log::debug!(
2365 "{}:ssh message handled. name:{type_name}",
2366 this.name
2367 );
2368 }
2369 Err(error) => {
2370 log::error!(
2371 "{}:error handling message. type:{}, error:{}",
2372 this.name,
2373 type_name,
2374 format!("{error:#}").lines().fold(
2375 String::new(),
2376 |mut message, line| {
2377 if !message.is_empty() {
2378 message.push(' ');
2379 }
2380 message.push_str(line);
2381 message
2382 }
2383 )
2384 );
2385 }
2386 }
2387 })
2388 .detach()
2389 } else {
2390 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2391 }
2392 }
2393 }
2394 anyhow::Ok(())
2395 })
2396 }
2397
2398 pub fn reconnect(
2399 self: &Arc<Self>,
2400 incoming_rx: UnboundedReceiver<Envelope>,
2401 outgoing_tx: UnboundedSender<Envelope>,
2402 cx: &AsyncApp,
2403 ) {
2404 *self.outgoing_tx.lock() = outgoing_tx;
2405 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2406 }
2407
2408 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2409 let id = (TypeId::of::<E>(), remote_id);
2410
2411 let mut message_handlers = self.message_handlers.lock();
2412 if message_handlers
2413 .entities_by_type_and_remote_id
2414 .contains_key(&id)
2415 {
2416 panic!("already subscribed to entity");
2417 }
2418
2419 message_handlers.entities_by_type_and_remote_id.insert(
2420 id,
2421 EntityMessageSubscriber::Entity {
2422 handle: entity.downgrade().into(),
2423 },
2424 );
2425 }
2426
2427 pub fn request<T: RequestMessage>(
2428 &self,
2429 payload: T,
2430 ) -> impl 'static + Future<Output = Result<T::Response>> {
2431 self.request_internal(payload, true)
2432 }
2433
2434 fn request_internal<T: RequestMessage>(
2435 &self,
2436 payload: T,
2437 use_buffer: bool,
2438 ) -> impl 'static + Future<Output = Result<T::Response>> {
2439 log::debug!("ssh request start. name:{}", T::NAME);
2440 let response =
2441 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2442 async move {
2443 let response = response.await?;
2444 log::debug!("ssh request finish. name:{}", T::NAME);
2445 T::Response::from_envelope(response).context("received a response of the wrong type")
2446 }
2447 }
2448
2449 pub async fn resync(&self, timeout: Duration) -> Result<()> {
2450 smol::future::or(
2451 async {
2452 self.request_internal(proto::FlushBufferedMessages {}, false)
2453 .await?;
2454
2455 for envelope in self.buffer.lock().iter() {
2456 self.outgoing_tx
2457 .lock()
2458 .unbounded_send(envelope.clone())
2459 .ok();
2460 }
2461 Ok(())
2462 },
2463 async {
2464 smol::Timer::after(timeout).await;
2465 anyhow::bail!("Timeout detected")
2466 },
2467 )
2468 .await
2469 }
2470
2471 pub async fn ping(&self, timeout: Duration) -> Result<()> {
2472 smol::future::or(
2473 async {
2474 self.request(proto::Ping {}).await?;
2475 Ok(())
2476 },
2477 async {
2478 smol::Timer::after(timeout).await;
2479 anyhow::bail!("Timeout detected")
2480 },
2481 )
2482 .await
2483 }
2484
2485 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2486 log::debug!("ssh send name:{}", T::NAME);
2487 self.send_dynamic(payload.into_envelope(0, None, None))
2488 }
2489
2490 fn request_dynamic(
2491 &self,
2492 mut envelope: proto::Envelope,
2493 type_name: &'static str,
2494 use_buffer: bool,
2495 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2496 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2497 let (tx, rx) = oneshot::channel();
2498 let mut response_channels_lock = self.response_channels.lock();
2499 response_channels_lock.insert(MessageId(envelope.id), tx);
2500 drop(response_channels_lock);
2501
2502 let result = if use_buffer {
2503 self.send_buffered(envelope)
2504 } else {
2505 self.send_unbuffered(envelope)
2506 };
2507 async move {
2508 if let Err(error) = &result {
2509 log::error!("failed to send message: {error}");
2510 anyhow::bail!("failed to send message: {error}");
2511 }
2512
2513 let response = rx.await.context("connection lost")?.0;
2514 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2515 return Err(RpcError::from_proto(error, type_name));
2516 }
2517 Ok(response)
2518 }
2519 }
2520
2521 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2522 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2523 self.send_buffered(envelope)
2524 }
2525
2526 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2527 envelope.ack_id = Some(self.max_received.load(SeqCst));
2528 self.buffer.lock().push_back(envelope.clone());
2529 // ignore errors on send (happen while we're reconnecting)
2530 // assume that the global "disconnected" overlay is sufficient.
2531 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2532 Ok(())
2533 }
2534
2535 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2536 envelope.ack_id = Some(self.max_received.load(SeqCst));
2537 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2538 Ok(())
2539 }
2540}
2541
2542impl ProtoClient for ChannelClient {
2543 fn request(
2544 &self,
2545 envelope: proto::Envelope,
2546 request_type: &'static str,
2547 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2548 self.request_dynamic(envelope, request_type, true).boxed()
2549 }
2550
2551 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2552 self.send_dynamic(envelope)
2553 }
2554
2555 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2556 self.send_dynamic(envelope)
2557 }
2558
2559 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2560 &self.message_handlers
2561 }
2562
2563 fn is_via_collab(&self) -> bool {
2564 false
2565 }
2566}
2567
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// In-process stand-in for a remote connection, used in tests.
    ///
    /// Envelopes are relayed between the client channels passed to `start_proxy`
    /// and `server_channel`, with no SSH process involved.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        /// Server-side endpoint that client traffic is relayed to.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Server-side app context, used when (re)starting the server's message loop.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets an `AsyncApp` be stored in a `Send + Sync` connection.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Reconnects the server endpoint to fresh channels whose other halves are
        /// dropped immediately.
        ///
        /// As a result, the server's incoming stream ends at once and its outgoing
        /// messages go nowhere.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Wires the server endpoint to fresh channels and spawns a background relay
        /// between them and the client channels.
        ///
        /// Server-to-client traffic also signals `connection_activity_tx`. The relay
        /// returns `Ok(1)` as soon as either side's outgoing stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Test delegate. Only `set_status` is expected to be called (as a no-op); every
    /// other method is `unreachable!`.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}