1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
13 channel::{
14 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
15 oneshot,
16 },
17 future::{BoxFuture, Shared},
18 select, select_biased,
19};
20use gpui::{
21 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
22 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
23};
24use itertools::Itertools;
25use parking_lot::Mutex;
26
27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
28use rpc::{
29 AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
30 RpcError,
31 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
32};
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use smol::{
36 fs,
37 process::{self, Child, Stdio},
38};
39use std::{
40 any::TypeId,
41 collections::VecDeque,
42 fmt, iter,
43 ops::ControlFlow,
44 path::{Path, PathBuf},
45 sync::{
46 Arc, Weak,
47 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
48 },
49 time::{Duration, Instant},
50};
51use tempfile::TempDir;
52use util::{
53 ResultExt,
54 paths::{PathStyle, RemotePathBuf},
55};
56
/// Newtype wrapper around a `u64` identifying a project reached over SSH.
///
/// NOTE(review): who assigns these ids (client or remote server) is not visible
/// here — confirm against the callers before relying on uniqueness guarantees.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
61
/// Everything needed to spawn `ssh` subprocesses against one remote host.
///
/// On non-Windows platforms every spawned `ssh` is pointed at the control
/// socket at `socket_path` (`ControlMaster=no`, `ControlPath=...`). On Windows
/// there is no control socket; instead `envs` (askpass configuration built in
/// `SshSocket::new`) is passed to every spawned `ssh`.
#[derive(Clone)]
pub struct SshSocket {
    /// Destination and extra options used to build each `ssh` invocation.
    connection_options: SshConnectionOptions,
    /// Path of the ssh control socket passed via `ControlPath`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables applied to each spawned `ssh` process.
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
70
/// A local port forward (`ssh -L [local_host:]local_port:remote_host:remote_port`).
///
/// When rendered by `SshConnectionOptions::additional_args`, a missing host is
/// replaced by `"localhost"`. `None` hosts are omitted when serialized.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind address; `None` means `localhost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port to listen on.
    pub local_port: u16,
    /// Host to connect to from the remote side; `None` means `localhost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Port to connect to from the remote side.
    pub remote_port: u16,
}
80
/// Options describing how to reach a remote host over SSH.
///
/// Used as the key of the connection pool (hence `Eq`/`Hash`) and as the
/// source of the destination/extra arguments for every `ssh` invocation.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Destination host name or address.
    pub host: String,
    /// Login user; may itself contain `@` for jump-style destinations.
    pub username: Option<String>,
    /// Remote ssh port, if not the default.
    pub port: Option<u16>,
    /// Password, if known up front (left `None` by `parse_command_line`).
    pub password: Option<String>,
    /// Extra ssh command-line arguments passed through verbatim.
    pub args: Option<Vec<String>>,
    /// Local port forwards, rendered as `-L` arguments by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// Optional user-facing name for this connection — presumably for display; TODO confirm.
    pub nickname: Option<String>,
    /// Flag whose effect is not visible here — presumably controls whether the
    /// server binary is uploaded over the ssh connection; TODO confirm.
    pub upload_binary_over_ssh: bool,
}
93
/// Arguments (and, on Windows, environment variables) needed to launch an
/// `ssh` process that reaches the same host as an existing connection.
pub struct SshArgs {
    /// Command-line arguments for `ssh`, ending with the destination URL.
    pub arguments: Vec<String>,
    /// Environment variables to set on the `ssh` process (`Some` only on Windows).
    pub envs: Option<HashMap<String, String>>,
}
98
/// Formats a shell snippet, shell-quoting every named argument.
///
/// Expands to `format!($fmt, name = shlex::try_quote(arg).unwrap(), ...)`, so
/// each `{name}` placeholder in `$fmt` receives a quoted copy of its value.
/// At least one named argument is required.
///
/// # Panics
///
/// Panics if `shlex::try_quote` rejects one of the arguments.
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
110
111fn parse_port_number(port_str: &str) -> Result<u16> {
112 port_str
113 .parse()
114 .with_context(|| format!("parsing port number: {port_str}"))
115}
116
117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
118 let parts: Vec<&str> = spec.split(':').collect();
119
120 match parts.len() {
121 4 => {
122 let local_port = parse_port_number(parts[1])?;
123 let remote_port = parse_port_number(parts[3])?;
124
125 Ok(SshPortForwardOption {
126 local_host: Some(parts[0].to_string()),
127 local_port,
128 remote_host: Some(parts[2].to_string()),
129 remote_port,
130 })
131 }
132 3 => {
133 let local_port = parse_port_number(parts[0])?;
134 let remote_port = parse_port_number(parts[2])?;
135
136 Ok(SshPortForwardOption {
137 local_host: None,
138 local_port,
139 remote_host: Some(parts[1].to_string()),
140 remote_port,
141 })
142 }
143 _ => anyhow::bail!("Invalid port forward format"),
144 }
145}
146
147impl SshConnectionOptions {
148 pub fn parse_command_line(input: &str) -> Result<Self> {
149 let input = input.trim_start_matches("ssh ");
150 let mut hostname: Option<String> = None;
151 let mut username: Option<String> = None;
152 let mut port: Option<u16> = None;
153 let mut args = Vec::new();
154 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
155
156 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
157 const ALLOWED_OPTS: &[&str] = &[
158 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
159 ];
160 const ALLOWED_ARGS: &[&str] = &[
161 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
162 "-w",
163 ];
164
165 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
166
167 'outer: while let Some(arg) = tokens.next() {
168 if ALLOWED_OPTS.contains(&(&arg as &str)) {
169 args.push(arg.to_string());
170 continue;
171 }
172 if arg == "-p" {
173 port = tokens.next().and_then(|arg| arg.parse().ok());
174 continue;
175 } else if let Some(p) = arg.strip_prefix("-p") {
176 port = p.parse().ok();
177 continue;
178 }
179 if arg == "-l" {
180 username = tokens.next();
181 continue;
182 } else if let Some(l) = arg.strip_prefix("-l") {
183 username = Some(l.to_string());
184 continue;
185 }
186 if arg == "-L" || arg.starts_with("-L") {
187 let forward_spec = if arg == "-L" {
188 tokens.next()
189 } else {
190 Some(arg.strip_prefix("-L").unwrap().to_string())
191 };
192
193 if let Some(spec) = forward_spec {
194 port_forwards.push(parse_port_forward_spec(&spec)?);
195 } else {
196 anyhow::bail!("Missing port forward format");
197 }
198 }
199
200 for a in ALLOWED_ARGS {
201 if arg == *a {
202 args.push(arg);
203 if let Some(next) = tokens.next() {
204 args.push(next);
205 }
206 continue 'outer;
207 } else if arg.starts_with(a) {
208 args.push(arg);
209 continue 'outer;
210 }
211 }
212 if arg.starts_with("-") || hostname.is_some() {
213 anyhow::bail!("unsupported argument: {:?}", arg);
214 }
215 let mut input = &arg as &str;
216 // Destination might be: username1@username2@ip2@ip1
217 if let Some((u, rest)) = input.rsplit_once('@') {
218 input = rest;
219 username = Some(u.to_string());
220 }
221 if let Some((rest, p)) = input.split_once(':') {
222 input = rest;
223 port = p.parse().ok()
224 }
225 hostname = Some(input.to_string())
226 }
227
228 let Some(hostname) = hostname else {
229 anyhow::bail!("missing hostname");
230 };
231
232 let port_forwards = match port_forwards.len() {
233 0 => None,
234 _ => Some(port_forwards),
235 };
236
237 Ok(Self {
238 host: hostname.to_string(),
239 username: username.clone(),
240 port,
241 port_forwards,
242 args: Some(args),
243 password: None,
244 nickname: None,
245 upload_binary_over_ssh: false,
246 })
247 }
248
249 pub fn ssh_url(&self) -> String {
250 let mut result = String::from("ssh://");
251 if let Some(username) = &self.username {
252 // Username might be: username1@username2@ip2
253 let username = urlencoding::encode(username);
254 result.push_str(&username);
255 result.push('@');
256 }
257 result.push_str(&self.host);
258 if let Some(port) = self.port {
259 result.push(':');
260 result.push_str(&port.to_string());
261 }
262 result
263 }
264
265 pub fn additional_args(&self) -> Vec<String> {
266 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
267
268 if let Some(forwards) = &self.port_forwards {
269 args.extend(forwards.iter().map(|pf| {
270 let local_host = match &pf.local_host {
271 Some(host) => host,
272 None => "localhost",
273 };
274 let remote_host = match &pf.remote_host {
275 Some(host) => host,
276 None => "localhost",
277 };
278
279 format!(
280 "-L{}:{}:{}:{}",
281 local_host, pf.local_port, remote_host, pf.remote_port
282 )
283 }));
284 }
285
286 args
287 }
288
289 fn scp_url(&self) -> String {
290 if let Some(username) = &self.username {
291 format!("{}@{}", username, self.host)
292 } else {
293 self.host.clone()
294 }
295 }
296
297 pub fn connection_string(&self) -> String {
298 let host = if let Some(username) = &self.username {
299 format!("{}@{}", username, self.host)
300 } else {
301 self.host.clone()
302 };
303 if let Some(port) = &self.port {
304 format!("{}:{}", host, port)
305 } else {
306 host
307 }
308 }
309}
310
/// Operating system and CPU architecture of a remote host, as normalized by
/// `SshSocket::platform` (`os` is `"linux"`/`"macos"`, `arch` is
/// `"aarch64"`/`"x86_64"` there).
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
316
317impl SshPlatform {
318 pub fn triple(&self) -> Option<String> {
319 Some(format!(
320 "{}-{}",
321 self.arch,
322 match self.os {
323 "linux" => "unknown-linux-gnu",
324 "macos" => "apple-darwin",
325 _ => return None,
326 }
327 ))
328 }
329}
330
/// Callbacks the SSH client uses to interact with the user and to obtain the
/// remote server binary.
pub trait SshClientDelegate: Send + Sync {
    /// Asks the user for a password, showing `prompt`.
    ///
    /// The answer is expected to be delivered through `tx` — presumably dropping
    /// `tx` signals cancellation; TODO confirm with implementors.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Returns parameters for downloading the server binary for the given
    /// platform, release channel and optional version, or `None`.
    ///
    /// NOTE(review): the meaning of the two returned strings is defined by the
    /// implementation (possibly a URL and request body) — confirm before use.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Downloads the server binary for the given platform to the local
    /// machine and returns its local path.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a progress/status message; `None` presumably clears it — TODO confirm.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
350
impl SshSocket {
    /// Creates a socket that routes `ssh` invocations through the control
    /// socket at `socket_path`.
    #[cfg(not(target_os = "windows"))]
    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
        Ok(Self {
            connection_options: options,
            socket_path,
        })
    }

    /// Creates a socket that feeds `secret` to ssh through `SSH_ASKPASS`.
    ///
    /// Writes an `askpass.bat` into `temp_dir` that echoes `%ZED_SSH_ASKPASS%`,
    /// and records the env vars (`SSH_ASKPASS_REQUIRE=force`, `SSH_ASKPASS`,
    /// `ZED_SSH_ASKPASS`) that are later applied to every spawned `ssh`.
    ///
    /// # Errors
    ///
    /// Fails if the askpass script cannot be written.
    #[cfg(target_os = "windows")]
    fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
        let askpass_script = temp_dir.path().join("askpass.bat");
        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
        let mut envs = HashMap::default();
        envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
        envs.insert("ZED_SSH_ASKPASS".into(), secret);
        Ok(Self {
            connection_options: options,
            envs,
        })
    }

    // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
    // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
    // and passes -l as an argument to sh, not to ls.
    // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
    // into a machine. You must use `cd` to get back to $HOME.
    // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
    /// Builds (but does not spawn) an `ssh` command that runs `program args...`
    /// on the remote host.
    ///
    /// Each token is shell-quoted individually, and the command is prefixed
    /// with `cd;` (see the warning above).
    fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
        let mut command = util::command::new_smol_command("ssh");
        let to_run = iter::once(&program)
            .chain(args.iter())
            .map(|token| {
                // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
                debug_assert!(
                    !token.contains('\n'),
                    "multiline arguments do not work in all shells"
                );
                shlex::try_quote(token).unwrap()
            })
            .join(" ");
        let to_run = format!("cd; {to_run}");
        log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
        self.ssh_options(&mut command)
            .arg(self.connection_options.ssh_url())
            .arg(to_run);
        command
    }

    /// Runs `program args...` on the remote host and returns its stdout.
    ///
    /// Stdout is decoded lossily as UTF-8.
    ///
    /// # Errors
    ///
    /// Fails if ssh cannot be spawned, or if the command exits unsuccessfully;
    /// in the latter case the error includes the remote stderr.
    async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
        let output = self.ssh_command(program, args).output().await?;
        anyhow::ensure!(
            output.status.success(),
            "failed to run command: {}",
            String::from_utf8_lossy(&output.stderr)
        );
        Ok(String::from_utf8_lossy(&output.stdout).to_string())
    }

    /// Pipes stdio and adds the control-socket options to `command`.
    #[cfg(not(target_os = "windows"))]
    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
        command
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .args(["-o", "ControlMaster=no", "-o"])
            .arg(format!("ControlPath={}", self.socket_path.display()))
    }

    /// Pipes stdio and applies the askpass environment to `command`.
    #[cfg(target_os = "windows")]
    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
        command
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .envs(self.envs.clone())
    }

    // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    // Elsewhere, we pass `ControlMaster=no` and `ControlPath=<socket_path>`, so
    // ssh multiplexes over the existing control socket at `socket_path` instead
    // of opening (and authenticating) a new connection. The master process that
    // owns that socket is created outside this type.
    /// Returns the arguments needed to launch an `ssh` that reaches this host,
    /// ending with the destination URL.
    #[cfg(not(target_os = "windows"))]
    fn ssh_args(&self) -> SshArgs {
        SshArgs {
            arguments: vec![
                "-o".to_string(),
                "ControlMaster=no".to_string(),
                "-o".to_string(),
                format!("ControlPath={}", self.socket_path.display()),
                self.connection_options.ssh_url(),
            ],
            envs: None,
        }
    }

    /// Returns the destination URL plus the askpass environment variables.
    #[cfg(target_os = "windows")]
    fn ssh_args(&self) -> SshArgs {
        SshArgs {
            arguments: vec![self.connection_options.ssh_url()],
            envs: Some(self.envs.clone()),
        }
    }

    /// Detects the remote OS and architecture by running `uname -sm`.
    ///
    /// `Darwin`/`Linux` map to `macos`/`linux`. `armv8`/`armv9`/`arm64`/
    /// `aarch64` map to `aarch64`, and `x86*` maps to `x86_64`.
    ///
    /// # Errors
    ///
    /// Fails for unparsable output or an OS/architecture without prebuilt
    /// servers.
    async fn platform(&self) -> Result<SshPlatform> {
        let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
        let Some((os, arch)) = uname.split_once(" ") else {
            anyhow::bail!("unknown uname: {uname:?}")
        };

        let os = match os.trim() {
            "Darwin" => "macos",
            "Linux" => "linux",
            _ => anyhow::bail!(
                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
            ),
        };
        // exclude armv5,6,7 as they are 32-bit.
        let arch = if arch.starts_with("armv8")
            || arch.starts_with("armv9")
            || arch.starts_with("arm64")
            || arch.starts_with("aarch64")
        {
            "aarch64"
        } else if arch.starts_with("x86") {
            "x86_64"
        } else {
            anyhow::bail!(
                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
            )
        };

        Ok(SshPlatform { os, arch })
    }
}
485
/// Consecutive missed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time after which a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// How long to wait for a ping (or resync) response.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reconnect attempts before giving up and reporting a disconnect.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
491
/// Internal lifecycle state of an `SshRemoteClient`'s connection.
///
/// Variants that still hold a connection also own its background tasks, so
/// replacing the state drops (and thereby stops) those tasks.
enum State {
    /// Set by `SshRemoteClient::new` before the first ping succeeds.
    Connecting,
    /// Connection established; heartbeats are being answered.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Task watching the proxy IO task (see `SshRemoteClient::monitor`).
        multiplex_task: Task<Result<()>>,
        /// Task sending periodic pings (see `SshRemoteClient::heartbeat`).
        heartbeat_task: Task<Result<()>>,
    },
    /// Connected, but at least one heartbeat has gone unanswered.
    HeartbeatMissed {
        /// Number of consecutive heartbeats missed so far.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; another attempt may follow.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Why the attempt failed.
        error: anyhow::Error,
        /// How many attempts have been made so far.
        attempts: usize,
    },
    /// `MAX_RECONNECT_ATTEMPTS` was exceeded; no further attempts are made.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
521
522impl fmt::Display for State {
523 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
524 match self {
525 Self::Connecting => write!(f, "connecting"),
526 Self::Connected { .. } => write!(f, "connected"),
527 Self::Reconnecting => write!(f, "reconnecting"),
528 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
529 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
530 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
531 Self::ServerNotRunning { .. } => write!(f, "server not running"),
532 }
533 }
534}
535
536impl State {
537 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
538 match self {
539 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
540 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
541 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
542 _ => None,
543 }
544 }
545
546 fn can_reconnect(&self) -> bool {
547 match self {
548 Self::Connected { .. }
549 | Self::HeartbeatMissed { .. }
550 | Self::ReconnectFailed { .. } => true,
551 State::Connecting
552 | State::Reconnecting
553 | State::ReconnectExhausted
554 | State::ServerNotRunning => false,
555 }
556 }
557
558 fn is_reconnect_failed(&self) -> bool {
559 matches!(self, Self::ReconnectFailed { .. })
560 }
561
562 fn is_reconnect_exhausted(&self) -> bool {
563 matches!(self, Self::ReconnectExhausted { .. })
564 }
565
566 fn is_server_not_running(&self) -> bool {
567 matches!(self, Self::ServerNotRunning)
568 }
569
570 fn is_reconnecting(&self) -> bool {
571 matches!(self, Self::Reconnecting { .. })
572 }
573
574 fn heartbeat_recovered(self) -> Self {
575 match self {
576 Self::HeartbeatMissed {
577 ssh_connection,
578 delegate,
579 multiplex_task,
580 heartbeat_task,
581 ..
582 } => Self::Connected {
583 ssh_connection,
584 delegate,
585 multiplex_task,
586 heartbeat_task,
587 },
588 _ => self,
589 }
590 }
591
592 fn heartbeat_missed(self) -> Self {
593 match self {
594 Self::Connected {
595 ssh_connection,
596 delegate,
597 multiplex_task,
598 heartbeat_task,
599 } => Self::HeartbeatMissed {
600 missed_heartbeats: 1,
601 ssh_connection,
602 delegate,
603 multiplex_task,
604 heartbeat_task,
605 },
606 Self::HeartbeatMissed {
607 missed_heartbeats,
608 ssh_connection,
609 delegate,
610 multiplex_task,
611 heartbeat_task,
612 } => Self::HeartbeatMissed {
613 missed_heartbeats: missed_heartbeats + 1,
614 ssh_connection,
615 delegate,
616 multiplex_task,
617 heartbeat_task,
618 },
619 _ => self,
620 }
621 }
622}
623
/// The state of the ssh connection.
///
/// A coarse, public view of the internal `State`, produced via
/// `From<&State>`. `SshRemoteClient::connection_state` also reports
/// `Disconnected` when no state is set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    /// Covers both an in-flight reconnect and a failed attempt awaiting retry.
    Reconnecting,
    /// Reconnects exhausted, or the remote server is not running.
    Disconnected,
}
633
634impl From<&State> for ConnectionState {
635 fn from(value: &State) -> Self {
636 match value {
637 State::Connecting => Self::Connecting,
638 State::Connected { .. } => Self::Connected,
639 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
640 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
641 State::ReconnectExhausted => Self::Disconnected,
642 State::ServerNotRunning => Self::Disconnected,
643 }
644 }
645}
646
/// Client side of a remote-development session over SSH.
///
/// Owns the RPC channel to the remote server and a state machine that tracks
/// heartbeats and reconnects.
pub struct SshRemoteClient {
    /// RPC channel to the remote server; exposed via `proto_client`.
    client: Arc<ChannelClient>,
    /// Identifier string (from `ConnectionIdentifier`) passed to the proxy
    /// so that reconnects can re-join the same remote session.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection at creation time.
    path_style: PathStyle,
    /// Current connection state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
654
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted by `set_state` on entering `ReconnectExhausted` or `ServerNotRunning`.
    Disconnected,
}

// Lets observers subscribe to `SshRemoteEvent`s on an `SshRemoteClient` entity.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
661
// Identifies the socket on the remote server so that reconnects
// can re-join the same project.
pub enum ConnectionIdentifier {
    /// A temporary id for connections made during setup, allocated by `setup()`.
    Setup(u64),
    /// An id tied to a workspace.
    Workspace(i64),
}
668
669static NEXT_ID: AtomicU64 = AtomicU64::new(1);
670
671impl ConnectionIdentifier {
672 pub fn setup() -> Self {
673 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
674 }
675 // This string gets used in a socket name, and so must be relatively short.
676 // The total length of:
677 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
678 // Must be less than about 100 characters
679 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
680 // So our strings should be at most 20 characters or so.
681 fn to_string(&self, cx: &App) -> String {
682 let identifier_prefix = match ReleaseChannel::global(cx) {
683 ReleaseChannel::Stable => "".to_string(),
684 release_channel => format!("{}-", release_channel.dev_name()),
685 };
686 match self {
687 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
688 Self::Workspace(workspace_id) => {
689 format!("{identifier_prefix}workspace-{workspace_id}",)
690 }
691 }
692 }
693}
694
695impl SshRemoteClient {
    /// Connects to the remote host and returns the client entity.
    ///
    /// Steps, in order:
    /// 1. Create the RPC channel.
    /// 2. Obtain an ssh connection from the global `ConnectionPool`.
    /// 3. Create the entity in `State::Connecting`.
    /// 4. Start the proxy and its monitor.
    /// 5. Ping the server once.
    /// 6. Start the heartbeat task and switch to `State::Connected`.
    ///
    /// Resolves to `Ok(None)` if `cancellation` resolves first. That happens
    /// when a value is sent or the sender is dropped, since the oneshot
    /// receiver result is ignored.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be made, the entity can no longer be
    /// updated, or the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Capacity 1: the proxy signals activity; the heartbeat only needs
                // to know that *some* activity happened.
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // `false` here vs `true` in `reconnect` — presumably tells the proxy
                // whether this is a reconnect; TODO confirm against `start_proxy`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);

                // Only consider the connection established once the server answers.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() => result
            }
        })
    }
770
771 pub fn shutdown_processes<T: RequestMessage>(
772 &self,
773 shutdown_request: Option<T>,
774 executor: BackgroundExecutor,
775 ) -> Option<impl Future<Output = ()> + use<T>> {
776 let state = self.state.lock().take()?;
777 log::info!("shutting down ssh processes");
778
779 let State::Connected {
780 multiplex_task,
781 heartbeat_task,
782 ssh_connection,
783 delegate,
784 } = state
785 else {
786 return None;
787 };
788
789 let client = self.client.clone();
790
791 Some(async move {
792 if let Some(shutdown_request) = shutdown_request {
793 client.send(shutdown_request).log_err();
794 // We wait 50ms instead of waiting for a response, because
795 // waiting for a response would require us to wait on the main thread
796 // which we want to avoid in an `on_app_quit` callback.
797 executor.timer(Duration::from_millis(50)).await;
798 }
799
800 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
801 // child of master_process.
802 drop(multiplex_task);
803 // Now drop the rest of state, which kills master process.
804 drop(heartbeat_task);
805 drop(ssh_connection);
806 drop(delegate);
807 })
808 }
809
    /// Starts an asynchronous reconnect attempt.
    ///
    /// Allowed only from `Connected`, `HeartbeatMissed` or `ReconnectFailed`
    /// (see `State::can_reconnect`).
    ///
    /// Attempt counting:
    /// - The count restarts at 1 when coming from a connected state and is
    ///   carried over from `ReconnectFailed`.
    /// - Once it exceeds `MAX_RECONNECT_ATTEMPTS`, the state becomes
    ///   `ReconnectExhausted` and no attempt is made.
    ///
    /// Otherwise the state becomes `Reconnecting`, and a task:
    /// 1. kills the old connection;
    /// 2. obtains a new one from the `ConnectionPool`;
    /// 3. restarts the proxy;
    /// 4. reattaches and resyncs the channel client.
    ///
    /// A follow-up task applies the resulting state only if the client is still
    /// `Reconnecting`. If that state is `ReconnectFailed`, it recurses into
    /// another attempt.
    ///
    /// # Errors
    ///
    /// Returns an error only when the current state does not allow reconnecting.
    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            anyhow::bail!(error);
        }

        // `can_reconnect` guarantees a state is present.
        let state = lock.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Stop monitoring/heartbeating the old connection before replacing it.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // Release the lock first: `set_state` locks it again.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        drop(lock);

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        let reconnect_task = cx.spawn(async move |this, cx| {
            // Early-returns a `ReconnectFailed` state that keeps the connection and
            // delegate, so a later attempt can continue from them.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
            // Point the existing channel client at the new transport.
            client.reconnect(incoming_rx, outgoing_tx, &cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
            }
        });

        cx.spawn(async move |this, cx| {
            let new_state = reconnect_task.await;
            this.update(cx, |this, cx| {
                // Only apply the result if nothing else changed the state meanwhile.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting { .. }
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
990
991 fn heartbeat(
992 this: WeakEntity<Self>,
993 mut connection_activity_rx: mpsc::Receiver<()>,
994 cx: &mut AsyncApp,
995 ) -> Task<Result<()>> {
996 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
997 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
998 };
999
1000 cx.spawn(async move |cx| {
1001 let mut missed_heartbeats = 0;
1002
1003 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1004 futures::pin_mut!(keepalive_timer);
1005
1006 loop {
1007 select_biased! {
1008 result = connection_activity_rx.next().fuse() => {
1009 if result.is_none() {
1010 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1011 return Ok(());
1012 }
1013
1014 if missed_heartbeats != 0 {
1015 missed_heartbeats = 0;
1016 let _ =this.update(cx, |this, mut cx| {
1017 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1018 })?;
1019 }
1020 }
1021 _ = keepalive_timer => {
1022 log::debug!("Sending heartbeat to server...");
1023
1024 let result = select_biased! {
1025 _ = connection_activity_rx.next().fuse() => {
1026 Ok(())
1027 }
1028 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1029 ping_result
1030 }
1031 };
1032
1033 if result.is_err() {
1034 missed_heartbeats += 1;
1035 log::warn!(
1036 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1037 HEARTBEAT_TIMEOUT,
1038 missed_heartbeats,
1039 MAX_MISSED_HEARTBEATS
1040 );
1041 } else if missed_heartbeats != 0 {
1042 missed_heartbeats = 0;
1043 } else {
1044 continue;
1045 }
1046
1047 let result = this.update(cx, |this, mut cx| {
1048 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1049 })?;
1050 if result.is_break() {
1051 return Ok(());
1052 }
1053 }
1054 }
1055
1056 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1057 }
1058
1059 })
1060 }
1061
1062 fn handle_heartbeat_result(
1063 &mut self,
1064 missed_heartbeats: usize,
1065 cx: &mut Context<Self>,
1066 ) -> ControlFlow<()> {
1067 let state = self.state.lock().take().unwrap();
1068 let next_state = if missed_heartbeats > 0 {
1069 state.heartbeat_missed()
1070 } else {
1071 state.heartbeat_recovered()
1072 };
1073
1074 self.set_state(next_state, cx);
1075
1076 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1077 log::error!(
1078 "Missed last {} heartbeats. Reconnecting...",
1079 missed_heartbeats
1080 );
1081
1082 self.reconnect(cx)
1083 .context("failed to start reconnect process after missing heartbeats")
1084 .log_err();
1085 ControlFlow::Break(())
1086 } else {
1087 ControlFlow::Continue(())
1088 }
1089 }
1090
1091 fn monitor(
1092 this: WeakEntity<Self>,
1093 io_task: Task<Result<i32>>,
1094 cx: &AsyncApp,
1095 ) -> Task<Result<()>> {
1096 cx.spawn(async move |cx| {
1097 let result = io_task.await;
1098
1099 match result {
1100 Ok(exit_code) => {
1101 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1102 match error {
1103 ProxyLaunchError::ServerNotRunning => {
1104 log::error!("failed to reconnect because server is not running");
1105 this.update(cx, |this, cx| {
1106 this.set_state(State::ServerNotRunning, cx);
1107 })?;
1108 }
1109 }
1110 } else if exit_code > 0 {
1111 log::error!("proxy process terminated unexpectedly");
1112 this.update(cx, |this, cx| {
1113 this.reconnect(cx).ok();
1114 })?;
1115 }
1116 }
1117 Err(error) => {
1118 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1119 this.update(cx, |this, cx| {
1120 this.reconnect(cx).ok();
1121 })?;
1122 }
1123 }
1124
1125 Ok(())
1126 })
1127 }
1128
1129 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1130 self.state.lock().as_ref().map_or(false, check)
1131 }
1132
1133 fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1134 let mut lock = self.state.lock();
1135 let new_state = lock.as_ref().and_then(map);
1136
1137 if let Some(new_state) = new_state {
1138 lock.replace(new_state);
1139 cx.notify();
1140 }
1141 }
1142
1143 fn set_state(&self, state: State, cx: &mut Context<Self>) {
1144 log::info!("setting state to '{}'", &state);
1145
1146 let is_reconnect_exhausted = state.is_reconnect_exhausted();
1147 let is_server_not_running = state.is_server_not_running();
1148 self.state.lock().replace(state);
1149
1150 if is_reconnect_exhausted || is_server_not_running {
1151 cx.emit(SshRemoteEvent::Disconnected);
1152 }
1153 cx.notify();
1154 }
1155
1156 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
1157 self.client.subscribe_to_entity(remote_id, entity);
1158 }
1159
1160 pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1161 self.state
1162 .lock()
1163 .as_ref()
1164 .and_then(|state| state.ssh_connection())
1165 .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1166 }
1167
1168 pub fn upload_directory(
1169 &self,
1170 src_path: PathBuf,
1171 dest_path: RemotePathBuf,
1172 cx: &App,
1173 ) -> Task<Result<()>> {
1174 let state = self.state.lock();
1175 let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1176 return Task::ready(Err(anyhow!("no ssh connection")));
1177 };
1178 connection.upload_directory(src_path, dest_path, cx)
1179 }
1180
    /// Returns the underlying channel client converted into a type-erased
    /// [`AnyProtoClient`].
    pub fn proto_client(&self) -> AnyProtoClient {
        self.client.clone().into()
    }
1184
    /// Returns the connection string derived from this client's connection
    /// options.
    pub fn connection_string(&self) -> String {
        self.connection_options.connection_string()
    }
1188
1189 pub fn connection_options(&self) -> SshConnectionOptions {
1190 self.connection_options.clone()
1191 }
1192
1193 pub fn connection_state(&self) -> ConnectionState {
1194 self.state
1195 .lock()
1196 .as_ref()
1197 .map(ConnectionState::from)
1198 .unwrap_or(ConnectionState::Disconnected)
1199 }
1200
    /// Returns `true` when [`Self::connection_state`] equals
    /// `ConnectionState::Disconnected`.
    pub fn is_disconnected(&self) -> bool {
        self.connection_state() == ConnectionState::Disconnected
    }
1204
    /// Returns the path style stored on this client.
    pub fn path_style(&self) -> PathStyle {
        self.path_style
    }
1208
1209 #[cfg(any(test, feature = "test-support"))]
1210 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1211 let opts = self.connection_options();
1212 client_cx.spawn(async move |cx| {
1213 let connection = cx
1214 .update_global(|c: &mut ConnectionPool, _| {
1215 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1216 c.clone()
1217 } else {
1218 panic!("missing test connection")
1219 }
1220 })
1221 .unwrap()
1222 .await
1223 .unwrap();
1224
1225 connection.simulate_disconnect(&cx);
1226 })
1227 }
1228
1229 #[cfg(any(test, feature = "test-support"))]
1230 pub fn fake_server(
1231 client_cx: &mut gpui::TestAppContext,
1232 server_cx: &mut gpui::TestAppContext,
1233 ) -> (SshConnectionOptions, Arc<ChannelClient>) {
1234 let port = client_cx
1235 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1236 let opts = SshConnectionOptions {
1237 host: "<fake>".to_string(),
1238 port: Some(port),
1239 ..Default::default()
1240 };
1241 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1242 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1243 let server_client =
1244 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1245 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1246 connection_options: opts.clone(),
1247 server_cx: fake::SendableCx::new(server_cx),
1248 server_channel: server_client.clone(),
1249 });
1250
1251 client_cx.update(|cx| {
1252 cx.update_default_global(|c: &mut ConnectionPool, cx| {
1253 c.connections.insert(
1254 opts.clone(),
1255 ConnectionPoolEntry::Connecting(
1256 cx.background_spawn({
1257 let connection = connection.clone();
1258 async move { Ok(connection.clone()) }
1259 })
1260 .shared(),
1261 ),
1262 );
1263 })
1264 });
1265
1266 (opts, server_client)
1267 }
1268
1269 #[cfg(any(test, feature = "test-support"))]
1270 pub async fn fake_client(
1271 opts: SshConnectionOptions,
1272 client_cx: &mut gpui::TestAppContext,
1273 ) -> Entity<Self> {
1274 let (_tx, rx) = oneshot::channel();
1275 client_cx
1276 .update(|cx| {
1277 Self::new(
1278 ConnectionIdentifier::setup(),
1279 opts,
1280 rx,
1281 Arc::new(fake::Delegate),
1282 cx,
1283 )
1284 })
1285 .await
1286 .unwrap()
1287 .unwrap()
1288 }
1289}
1290
/// State of a single [`ConnectionPool`] entry, keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. The task is `Shared` so every caller
    /// asking for the same options awaits the same attempt. The error is wrapped
    /// in `Arc` because a shared future's output must be cloneable.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is kept, so the pool
    /// by itself does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1295
/// App-global registry of SSH connections. [`ConnectionPool::connect`] uses it
/// to share in-flight connection attempts and reuse live connections opened with
/// identical options.
#[derive(Default)]
struct ConnectionPool {
    /// Pool entries keyed by the options the connection was opened with.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}
1300
// Lets the pool be stored as an app-wide gpui global, reachable through
// `update_global`, `update_default_global` and `default_global`.
impl Global for ConnectionPool {}
1302
1303impl ConnectionPool {
1304 pub fn connect(
1305 &mut self,
1306 opts: SshConnectionOptions,
1307 delegate: &Arc<dyn SshClientDelegate>,
1308 cx: &mut App,
1309 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1310 let connection = self.connections.get(&opts);
1311 match connection {
1312 Some(ConnectionPoolEntry::Connecting(task)) => {
1313 let delegate = delegate.clone();
1314 cx.spawn(async move |cx| {
1315 delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1316 })
1317 .detach();
1318 return task.clone();
1319 }
1320 Some(ConnectionPoolEntry::Connected(ssh)) => {
1321 if let Some(ssh) = ssh.upgrade() {
1322 if !ssh.has_been_killed() {
1323 return Task::ready(Ok(ssh)).shared();
1324 }
1325 }
1326 self.connections.remove(&opts);
1327 }
1328 None => {}
1329 }
1330
1331 let task = cx
1332 .spawn({
1333 let opts = opts.clone();
1334 let delegate = delegate.clone();
1335 async move |cx| {
1336 let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1337 .await
1338 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1339
1340 cx.update_global(|pool: &mut Self, _| {
1341 debug_assert!(matches!(
1342 pool.connections.get(&opts),
1343 Some(ConnectionPoolEntry::Connecting(_))
1344 ));
1345 match connection {
1346 Ok(connection) => {
1347 pool.connections.insert(
1348 opts.clone(),
1349 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1350 );
1351 Ok(connection)
1352 }
1353 Err(error) => {
1354 pool.connections.remove(&opts);
1355 Err(Arc::new(error))
1356 }
1357 }
1358 })?
1359 }
1360 })
1361 .shared();
1362
1363 self.connections
1364 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1365 task
1366 }
1367}
1368
1369impl From<SshRemoteClient> for AnyProtoClient {
1370 fn from(client: SshRemoteClient) -> Self {
1371 AnyProtoClient::new(client.client.clone())
1372 }
1373}
1374
/// Transport-level operations on a connection to a remote host.
///
/// Implemented by [`SshRemoteConnection`] and, under `test-support`, by a fake
/// connection. With `#[async_trait(?Send)]`, futures returned by async methods
/// (such as `kill`) are not required to be `Send`.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the remote proxy for `unique_identifier` and wires it to the given
    /// channels.
    ///
    /// Envelopes from `outgoing_rx` are sent to the remote side. Envelopes
    /// received from it are delivered on `incoming_tx`. `connection_activity_tx`
    /// is pinged on remote activity. In the SSH implementation, `reconnect` adds
    /// `--reconnect` to the proxy command, and the task resolves to the proxy's
    /// exit code.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Copies the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Tears down the underlying connection.
    async fn kill(&self) -> Result<()>;
    /// Returns `true` once [`RemoteConnection::kill`] has taken effect.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further `ssh` commands over this connection.
    ///
    /// On Windows, `ControlMaster`/`ControlPath` are not supported, so we need to use
    /// `SSH_ASKPASS` to provide the password to ssh.
    /// On other platforms, the `ControlPath` option points ssh at the control socket
    /// created by the master process, so that subsequent ssh commands can reuse the
    /// already-authenticated connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was opened with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style of the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1404
/// An SSH-backed [`RemoteConnection`] built around a long-lived master `ssh`
/// process.
struct SshRemoteConnection {
    /// Used to run commands (including `scp`) against the remote host.
    socket: SshSocket,
    /// The master `ssh -N` process. `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Location of the server binary on the host. It is set by `new` once
    /// `ensure_server_binary` succeeds.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform queried from the remote host via `SshSocket::platform`.
    ssh_platform: SshPlatform,
    /// Path style derived from `ssh_platform.os` (Windows or POSIX).
    ssh_path_style: PathStyle,
    /// Session temp dir (holds the control socket on non-Windows). Kept so it
    /// is removed when the connection is dropped.
    _temp_dir: TempDir,
}
1413
1414#[async_trait(?Send)]
1415impl RemoteConnection for SshRemoteConnection {
1416 async fn kill(&self) -> Result<()> {
1417 let Some(mut process) = self.master_process.lock().take() else {
1418 return Ok(());
1419 };
1420 process.kill().ok();
1421 process.status().await?;
1422 Ok(())
1423 }
1424
1425 fn has_been_killed(&self) -> bool {
1426 self.master_process.lock().is_none()
1427 }
1428
    /// Delegates to the socket, which knows how to reach the master connection.
    fn ssh_args(&self) -> SshArgs {
        self.socket.ssh_args()
    }
1432
1433 fn connection_options(&self) -> SshConnectionOptions {
1434 self.socket.connection_options.clone()
1435 }
1436
1437 fn upload_directory(
1438 &self,
1439 src_path: PathBuf,
1440 dest_path: RemotePathBuf,
1441 cx: &App,
1442 ) -> Task<Result<()>> {
1443 let mut command = util::command::new_smol_command("scp");
1444 let output = self
1445 .socket
1446 .ssh_options(&mut command)
1447 .args(
1448 self.socket
1449 .connection_options
1450 .port
1451 .map(|port| vec!["-P".to_string(), port.to_string()])
1452 .unwrap_or_default(),
1453 )
1454 .arg("-C")
1455 .arg("-r")
1456 .arg(&src_path)
1457 .arg(format!(
1458 "{}:{}",
1459 self.socket.connection_options.scp_url(),
1460 dest_path.to_string()
1461 ))
1462 .output();
1463
1464 cx.background_spawn(async move {
1465 let output = output.await?;
1466
1467 anyhow::ensure!(
1468 output.status.success(),
1469 "failed to upload directory {} -> {}: {}",
1470 src_path.display(),
1471 dest_path.to_string(),
1472 String::from_utf8_lossy(&output.stderr)
1473 );
1474
1475 Ok(())
1476 })
1477 }
1478
1479 fn start_proxy(
1480 &self,
1481 unique_identifier: String,
1482 reconnect: bool,
1483 incoming_tx: UnboundedSender<Envelope>,
1484 outgoing_rx: UnboundedReceiver<Envelope>,
1485 connection_activity_tx: Sender<()>,
1486 delegate: Arc<dyn SshClientDelegate>,
1487 cx: &mut AsyncApp,
1488 ) -> Task<Result<i32>> {
1489 delegate.set_status(Some("Starting proxy"), cx);
1490
1491 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1492 return Task::ready(Err(anyhow!("Remote binary path not set")));
1493 };
1494
1495 let mut start_proxy_command = shell_script!(
1496 "exec {binary_path} proxy --identifier {identifier}",
1497 binary_path = &remote_binary_path.to_string(),
1498 identifier = &unique_identifier,
1499 );
1500
1501 if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1502 start_proxy_command = format!(
1503 "RUST_LOG={} {}",
1504 shlex::try_quote(&rust_log).unwrap(),
1505 start_proxy_command
1506 )
1507 }
1508 if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1509 start_proxy_command = format!(
1510 "RUST_BACKTRACE={} {}",
1511 shlex::try_quote(&rust_backtrace).unwrap(),
1512 start_proxy_command
1513 )
1514 }
1515 if reconnect {
1516 start_proxy_command.push_str(" --reconnect");
1517 }
1518
1519 let ssh_proxy_process = match self
1520 .socket
1521 .ssh_command("sh", &["-c", &start_proxy_command])
1522 // IMPORTANT: we kill this process when we drop the task that uses it.
1523 .kill_on_drop(true)
1524 .spawn()
1525 {
1526 Ok(process) => process,
1527 Err(error) => {
1528 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1529 }
1530 };
1531
1532 Self::multiplex(
1533 ssh_proxy_process,
1534 incoming_tx,
1535 outgoing_rx,
1536 connection_activity_tx,
1537 &cx,
1538 )
1539 }
1540
    /// Path style derived from the remote OS when the connection was created.
    fn path_style(&self) -> PathStyle {
        self.ssh_path_style
    }
1544}
1545
1546impl SshRemoteConnection {
    /// Opens an SSH connection described by `connection_options` and makes sure
    /// a matching remote server binary is available on the host.
    ///
    /// Outline:
    /// 1. Creates a session temp dir and an askpass session that forwards
    ///    password prompts to `delegate.ask_password`.
    /// 2. Spawns a master `ssh -N` process. On non-Windows it uses
    ///    `ControlMaster=yes` and a `ControlPath` socket inside the temp dir, so
    ///    later ssh commands can reuse the connection.
    /// 3. Waits until either the askpass session ends (user cancelled or timed
    ///    out, which is an error) or the master process closes its stdout
    ///    (authentication completed).
    /// 4. If the master process has already exited, fails with its stderr.
    /// 5. Queries the remote platform, derives the path style, and resolves the
    ///    server binary via `ensure_server_binary`.
    ///
    /// # Errors
    /// Returns an error if any of the steps above fails.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        // Session temp dir. It holds the control socket on non-Windows and is
        // passed to `SshSocket::new` on Windows.
        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Route ssh password/passphrase prompts to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument
            // appended after these args below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = process::Command::new("ssh");
            // `kill_on_drop` ensures the child is killed on any early return below.
            // SSH_ASKPASS_REQUIRE=force makes ssh use the askpass script for prompts.
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        // stdout was configured as piped above, so `take()` yields `Some`.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // Biased: the askpass future is polled before stdout each time. A read
        // error on stdout is treated the same as EOF (its result is discarded).
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        // NOTE(review): select arm bodies run in this function's context, so the
        // `bail!` calls above return from `new` directly. `result` therefore
        // appears to always be `Ok` here and this context is never attached.
        // Confirm whether it was meant to wrap the cancel/timeout errors.
        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, the connection failed;
        // surface its stderr as the error message.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        // The askpass session is no longer needed once the socket exists.
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };

        // `remote_binary_path` is filled in below because `ensure_server_binary`
        // needs `&self`.
        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1677
    /// Pumps messages between the local channels and a running proxy process.
    ///
    /// Spawns three background tasks:
    /// - stdin: writes every envelope from `outgoing_rx` to the proxy's stdin.
    /// - stdout: reads length-prefixed envelopes from the proxy's stdout and
    ///   forwards them to `incoming_tx`.
    /// - stderr: splits the proxy's stderr into lines. Lines that parse as a JSON
    ///   [`LogRecord`] are logged locally; other lines are echoed to local stderr
    ///   with a `(remote)` prefix.
    ///
    /// The stdout and stderr tasks ping `connection_activity_tx` after reading.
    /// The returned task finishes once any of the three tasks ends; it then
    /// waits for the proxy to exit. It returns the error from waiting on the
    /// process, if any; otherwise the error from the finished task, if any;
    /// otherwise the exit code (1 if the process has none).
    ///
    /// # Panics
    /// Panics if the child's stdin, stdout or stderr is missing. This assumes
    /// the caller spawned it with piped stdio (presumably `SshSocket::ssh_command`
    /// does this; TODO confirm).
    fn multiplex(
        mut ssh_proxy_process: Child,
        incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        mut connection_activity_tx: Sender<()>,
        cx: &AsyncApp,
    ) -> Task<Result<i32>> {
        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Number of bytes in `stderr_buffer` that hold unprocessed data.
        let mut stderr_offset = 0;

        let stdin_task = cx.background_spawn(async move {
            while let Some(outgoing) = outgoing_rx.next().await {
                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
            }
            anyhow::Ok(())
        });

        let stdout_task = cx.background_spawn({
            let mut connection_activity_tx = connection_activity_tx.clone();
            async move {
                loop {
                    // Read the fixed-size length prefix; a short read is completed
                    // with `read_exact` below.
                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                    let len = child_stdout.read(&mut stdout_buffer).await?;

                    // A zero-length read means the proxy closed its stdout.
                    if len == 0 {
                        return anyhow::Ok(());
                    }

                    if len < MESSAGE_LEN_SIZE {
                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                    }

                    let message_len = message_len_from_buffer(&stdout_buffer);
                    let envelope =
                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                            .await?;
                    // Non-blocking; if the activity channel is full the ping is
                    // simply dropped.
                    connection_activity_tx.try_send(()).ok();
                    incoming_tx.unbounded_send(envelope).ok();
                }
            }
        });

        let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
            loop {
                // Make room for up to 1024 more bytes after the pending data.
                stderr_buffer.resize(stderr_offset + 1024, 0);

                let len = child_stderr
                    .read(&mut stderr_buffer[stderr_offset..])
                    .await?;
                if len == 0 {
                    return anyhow::Ok(());
                }

                stderr_offset += len;
                // Emit every complete line; a trailing partial line is kept for
                // the next read.
                let mut start_ix = 0;
                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                    .iter()
                    .position(|b| b == &b'\n')
                {
                    let line_ix = start_ix + ix;
                    let content = &stderr_buffer[start_ix..line_ix];
                    start_ix = line_ix + 1;
                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                        record.log(log::logger())
                    } else {
                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
                    }
                }
                // Move the unprocessed partial line to the front of the buffer.
                stderr_buffer.drain(0..start_ix);
                stderr_offset -= start_ix;

                connection_activity_tx.try_send(()).ok();
            }
        });

        cx.spawn(async move |_| {
            // Finish as soon as any direction ends, tagging its error with which.
            let result = futures::select! {
                result = stdin_task.fuse() => {
                    result.context("stdin")
                }
                result = stdout_task.fuse() => {
                    result.context("stdout")
                }
                result = stderr_task.fuse() => {
                    result.context("stderr")
                }
            };

            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
            match result {
                Ok(_) => Ok(status),
                Err(error) => Err(error),
            }
        })
    }
1779
1780 #[allow(unused)]
1781 async fn ensure_server_binary(
1782 &self,
1783 delegate: &Arc<dyn SshClientDelegate>,
1784 release_channel: ReleaseChannel,
1785 version: SemanticVersion,
1786 commit: Option<AppCommitSha>,
1787 cx: &mut AsyncApp,
1788 ) -> Result<RemotePathBuf> {
1789 let version_str = match release_channel {
1790 ReleaseChannel::Nightly => {
1791 let commit = commit.map(|s| s.full()).unwrap_or_default();
1792 format!("{}-{}", version, commit)
1793 }
1794 ReleaseChannel::Dev => "build".to_string(),
1795 _ => version.to_string(),
1796 };
1797 let binary_name = format!(
1798 "zed-remote-server-{}-{}",
1799 release_channel.dev_name(),
1800 version_str
1801 );
1802 let dst_path = RemotePathBuf::new(
1803 paths::remote_server_dir_relative().join(binary_name),
1804 self.ssh_path_style,
1805 );
1806
1807 let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1808 #[cfg(debug_assertions)]
1809 if let Some(build_remote_server) = build_remote_server {
1810 let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1811 let tmp_path = RemotePathBuf::new(
1812 paths::remote_server_dir_relative().join(format!(
1813 "download-{}-{}",
1814 std::process::id(),
1815 src_path.file_name().unwrap().to_string_lossy()
1816 )),
1817 self.ssh_path_style,
1818 );
1819 self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1820 .await?;
1821 self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1822 .await?;
1823 return Ok(dst_path);
1824 }
1825
1826 if self
1827 .socket
1828 .run_command(&dst_path.to_string(), &["version"])
1829 .await
1830 .is_ok()
1831 {
1832 return Ok(dst_path);
1833 }
1834
1835 let wanted_version = cx.update(|cx| match release_channel {
1836 ReleaseChannel::Nightly => Ok(None),
1837 ReleaseChannel::Dev => {
1838 anyhow::bail!(
1839 "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1840 dst_path
1841 )
1842 }
1843 _ => Ok(Some(AppVersion::global(cx))),
1844 })??;
1845
1846 let tmp_path_gz = RemotePathBuf::new(
1847 PathBuf::from(format!(
1848 "{}-download-{}.gz",
1849 dst_path.to_string(),
1850 std::process::id()
1851 )),
1852 self.ssh_path_style,
1853 );
1854 if !self.socket.connection_options.upload_binary_over_ssh {
1855 if let Some((url, body)) = delegate
1856 .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1857 .await?
1858 {
1859 match self
1860 .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1861 .await
1862 {
1863 Ok(_) => {
1864 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1865 .await?;
1866 return Ok(dst_path);
1867 }
1868 Err(e) => {
1869 log::error!(
1870 "Failed to download binary on server, attempting to upload server: {}",
1871 e
1872 )
1873 }
1874 }
1875 }
1876 }
1877
1878 let src_path = delegate
1879 .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1880 .await?;
1881 self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1882 .await?;
1883 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1884 .await?;
1885 return Ok(dst_path);
1886 }
1887
1888 async fn download_binary_on_server(
1889 &self,
1890 url: &str,
1891 body: &str,
1892 tmp_path_gz: &RemotePathBuf,
1893 delegate: &Arc<dyn SshClientDelegate>,
1894 cx: &mut AsyncApp,
1895 ) -> Result<()> {
1896 if let Some(parent) = tmp_path_gz.parent() {
1897 self.socket
1898 .run_command(
1899 "sh",
1900 &[
1901 "-c",
1902 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1903 ],
1904 )
1905 .await?;
1906 }
1907
1908 delegate.set_status(Some("Downloading remote development server on host"), cx);
1909
1910 match self
1911 .socket
1912 .run_command(
1913 "curl",
1914 &[
1915 "-f",
1916 "-L",
1917 "-X",
1918 "GET",
1919 "-H",
1920 "Content-Type: application/json",
1921 "-d",
1922 &body,
1923 &url,
1924 "-o",
1925 &tmp_path_gz.to_string(),
1926 ],
1927 )
1928 .await
1929 {
1930 Ok(_) => {}
1931 Err(e) => {
1932 if self.socket.run_command("which", &["curl"]).await.is_ok() {
1933 return Err(e);
1934 }
1935
1936 match self
1937 .socket
1938 .run_command(
1939 "wget",
1940 &[
1941 "--method=GET",
1942 "--header=Content-Type: application/json",
1943 "--body-data",
1944 &body,
1945 &url,
1946 "-O",
1947 &tmp_path_gz.to_string(),
1948 ],
1949 )
1950 .await
1951 {
1952 Ok(_) => {}
1953 Err(e) => {
1954 if self.socket.run_command("which", &["wget"]).await.is_ok() {
1955 return Err(e);
1956 } else {
1957 anyhow::bail!("Neither curl nor wget is available");
1958 }
1959 }
1960 }
1961 }
1962 }
1963
1964 Ok(())
1965 }
1966
1967 async fn upload_local_server_binary(
1968 &self,
1969 src_path: &Path,
1970 tmp_path_gz: &RemotePathBuf,
1971 delegate: &Arc<dyn SshClientDelegate>,
1972 cx: &mut AsyncApp,
1973 ) -> Result<()> {
1974 if let Some(parent) = tmp_path_gz.parent() {
1975 self.socket
1976 .run_command(
1977 "sh",
1978 &[
1979 "-c",
1980 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1981 ],
1982 )
1983 .await?;
1984 }
1985
1986 let src_stat = fs::metadata(&src_path).await?;
1987 let size = src_stat.len();
1988
1989 let t0 = Instant::now();
1990 delegate.set_status(Some("Uploading remote development server"), cx);
1991 log::info!(
1992 "uploading remote development server to {:?} ({}kb)",
1993 tmp_path_gz,
1994 size / 1024
1995 );
1996 self.upload_file(&src_path, &tmp_path_gz)
1997 .await
1998 .context("failed to upload server binary")?;
1999 log::info!("uploaded remote development server in {:?}", t0.elapsed());
2000 Ok(())
2001 }
2002
2003 async fn extract_server_binary(
2004 &self,
2005 dst_path: &RemotePathBuf,
2006 tmp_path: &RemotePathBuf,
2007 delegate: &Arc<dyn SshClientDelegate>,
2008 cx: &mut AsyncApp,
2009 ) -> Result<()> {
2010 delegate.set_status(Some("Extracting remote development server"), cx);
2011 let server_mode = 0o755;
2012
2013 let orig_tmp_path = tmp_path.to_string();
2014 let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2015 shell_script!(
2016 "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2017 server_mode = &format!("{:o}", server_mode),
2018 dst_path = &dst_path.to_string(),
2019 )
2020 } else {
2021 shell_script!(
2022 "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2023 server_mode = &format!("{:o}", server_mode),
2024 dst_path = &dst_path.to_string()
2025 )
2026 };
2027 self.socket.run_command("sh", &["-c", &script]).await?;
2028 Ok(())
2029 }
2030
2031 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2032 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2033 let mut command = util::command::new_smol_command("scp");
2034 let output = self
2035 .socket
2036 .ssh_options(&mut command)
2037 .args(
2038 self.socket
2039 .connection_options
2040 .port
2041 .map(|port| vec!["-P".to_string(), port.to_string()])
2042 .unwrap_or_default(),
2043 )
2044 .arg(src_path)
2045 .arg(format!(
2046 "{}:{}",
2047 self.socket.connection_options.scp_url(),
2048 dest_path.to_string()
2049 ))
2050 .output()
2051 .await?;
2052
2053 anyhow::ensure!(
2054 output.status.success(),
2055 "failed to upload file {} -> {}: {}",
2056 src_path.display(),
2057 dest_path.to_string(),
2058 String::from_utf8_lossy(&output.stderr)
2059 );
2060 Ok(())
2061 }
2062
    /// Debug builds only: builds the remote server binary from source for the
    /// remote platform and returns the local path of the artifact.
    ///
    /// - Same OS and arch as this machine: runs `cargo build` into
    ///   `target/remote_server`, then gzips the result (unconditionally).
    /// - Otherwise, cross-compiles for the remote target triple. It uses `cross`
    ///   (Docker) when `build_remote_server` contains "cross", and `cargo
    ///   zigbuild` (requires `zig` on `$PATH`) otherwise. The result is
    ///   compressed (gzip, or 7-Zip on Windows) unless `build_remote_server`
    ///   contains "nocompress".
    ///
    /// NOTE(review): the returned path is absolute except in the cross-compile +
    /// "nocompress" case, where it stays relative. The native branch ignores
    /// "nocompress". All relative `target/...` paths assume the current
    /// working directory is the workspace root; TODO confirm.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};

        /// Runs `command` to completion (stderr inherited) and fails if it
        /// exits unsuccessfully.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        // Native build: the remote host matches the local OS/arch.
        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(Command::new("cargo").args([
                "build",
                "--package",
                "remote_server",
                "--features",
                "debug-embed",
                "--target-dir",
                "target/remote_server",
            ]))
            .await?;

            delegate.set_status(Some("Compressing binary"), cx);

            run_cmd(Command::new("gzip").args([
                "-9",
                "-f",
                "target/remote_server/debug/remote_server",
            ]))
            .await?;

            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
            return Ok(path);
        }
        // Cross build: requires a known target triple for the remote platform.
        let Some(triple) = self.ssh_platform.triple() else {
            anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform);
        };
        smol::fs::create_dir_all("target/remote_server").await?;

        if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            // Bind-mount the local target dir into the container so the build
            // output lands in `target/remote_server`.
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    ),
            )
            .await?;
        } else {
            // `which::which` may touch the filesystem, so run it off the main thread.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(Command::new("cargo").args([
                "zigbuild",
                "--package",
                "remote_server",
                "--features",
                "debug-embed",
                "--target-dir",
                "target/remote_server",
                "--target",
                &triple,
            ]))
            .await?;
        };

        // Uncompressed artifact (relative path). It is replaced by the absolute
        // `.gz` path below unless "nocompress" was requested.
        let mut path = format!("target/remote_server/{triple}/debug/remote_server").into();
        if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args([
                    "-9",
                    "-f",
                    &format!("target/remote_server/{}/debug/remote_server", triple),
                ]))
                .await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove a stale archive first; this assumes `7z a` would
                // otherwise add to it rather than replace it. TODO confirm.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &format!("target/remote_server/{}/debug/remote_server", triple),
                ]))
                .await?;
            }

            path = std::env::current_dir()?.join(format!(
                "target/remote_server/{triple}/debug/remote_server.gz"
            ));
        }

        return Ok(path);
    }
2252}
2253
/// Pending requests keyed by message id.
///
/// Each sender receives the response envelope plus a oneshot `Sender<()>`. The
/// incoming-message loop awaits the matching receiver before processing the
/// next message, so the requester controls when processing resumes (by
/// sending or by dropping it).
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2255
/// A message client that exchanges [`Envelope`]s over a pair of unbounded
/// channels. Outgoing messages are kept buffered until the peer acknowledges
/// them.
pub struct ChannelClient {
    /// Id assigned to the next outgoing envelope.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Unacknowledged outgoing envelopes. Entries with `id <= ack_id` are
    /// dropped when an ack arrives, and the rest are re-sent when the peer asks
    /// for `FlushBufferedMessages`.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Responders for in-flight requests.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received envelope.
    max_received: AtomicU32,
    /// Label used in log messages (e.g. "fake-server").
    name: &'static str,
    /// Task that processes incoming envelopes; owned so it lives as long as the client.
    task: Mutex<Task<Result<()>>>,
}
2266
2267impl ChannelClient {
2268 pub fn new(
2269 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2270 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2271 cx: &App,
2272 name: &'static str,
2273 ) -> Arc<Self> {
2274 Arc::new_cyclic(|this| Self {
2275 outgoing_tx: Mutex::new(outgoing_tx),
2276 next_message_id: AtomicU32::new(0),
2277 max_received: AtomicU32::new(0),
2278 response_channels: ResponseChannels::default(),
2279 message_handlers: Default::default(),
2280 buffer: Mutex::new(VecDeque::new()),
2281 name,
2282 task: Mutex::new(Self::start_handling_messages(
2283 this.clone(),
2284 incoming_rx,
2285 &cx.to_async(),
2286 )),
2287 })
2288 }
2289
2290 fn start_handling_messages(
2291 this: Weak<Self>,
2292 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2293 cx: &AsyncApp,
2294 ) -> Task<Result<()>> {
2295 cx.spawn(async move |cx| {
2296 let peer_id = PeerId { owner_id: 0, id: 0 };
2297 while let Some(incoming) = incoming_rx.next().await {
2298 let Some(this) = this.upgrade() else {
2299 return anyhow::Ok(());
2300 };
2301 if let Some(ack_id) = incoming.ack_id {
2302 let mut buffer = this.buffer.lock();
2303 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2304 buffer.pop_front();
2305 }
2306 }
2307 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2308 {
2309 log::debug!(
2310 "{}:ssh message received. name:FlushBufferedMessages",
2311 this.name
2312 );
2313 {
2314 let buffer = this.buffer.lock();
2315 for envelope in buffer.iter() {
2316 this.outgoing_tx
2317 .lock()
2318 .unbounded_send(envelope.clone())
2319 .ok();
2320 }
2321 }
2322 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2323 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2324 this.outgoing_tx.lock().unbounded_send(envelope).ok();
2325 continue;
2326 }
2327
2328 this.max_received.store(incoming.id, SeqCst);
2329
2330 if let Some(request_id) = incoming.responding_to {
2331 let request_id = MessageId(request_id);
2332 let sender = this.response_channels.lock().remove(&request_id);
2333 if let Some(sender) = sender {
2334 let (tx, rx) = oneshot::channel();
2335 if incoming.payload.is_some() {
2336 sender.send((incoming, tx)).ok();
2337 }
2338 rx.await.ok();
2339 }
2340 } else if let Some(envelope) =
2341 build_typed_envelope(peer_id, Instant::now(), incoming)
2342 {
2343 let type_name = envelope.payload_type_name();
2344 if let Some(future) = ProtoMessageHandlerSet::handle_message(
2345 &this.message_handlers,
2346 envelope,
2347 this.clone().into(),
2348 cx.clone(),
2349 ) {
2350 log::debug!("{}:ssh message received. name:{type_name}", this.name);
2351 cx.foreground_executor()
2352 .spawn(async move {
2353 match future.await {
2354 Ok(_) => {
2355 log::debug!(
2356 "{}:ssh message handled. name:{type_name}",
2357 this.name
2358 );
2359 }
2360 Err(error) => {
2361 log::error!(
2362 "{}:error handling message. type:{}, error:{}",
2363 this.name,
2364 type_name,
2365 format!("{error:#}").lines().fold(
2366 String::new(),
2367 |mut message, line| {
2368 if !message.is_empty() {
2369 message.push(' ');
2370 }
2371 message.push_str(line);
2372 message
2373 }
2374 )
2375 );
2376 }
2377 }
2378 })
2379 .detach()
2380 } else {
2381 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2382 }
2383 }
2384 }
2385 anyhow::Ok(())
2386 })
2387 }
2388
2389 pub fn reconnect(
2390 self: &Arc<Self>,
2391 incoming_rx: UnboundedReceiver<Envelope>,
2392 outgoing_tx: UnboundedSender<Envelope>,
2393 cx: &AsyncApp,
2394 ) {
2395 *self.outgoing_tx.lock() = outgoing_tx;
2396 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2397 }
2398
2399 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2400 let id = (TypeId::of::<E>(), remote_id);
2401
2402 let mut message_handlers = self.message_handlers.lock();
2403 if message_handlers
2404 .entities_by_type_and_remote_id
2405 .contains_key(&id)
2406 {
2407 panic!("already subscribed to entity");
2408 }
2409
2410 message_handlers.entities_by_type_and_remote_id.insert(
2411 id,
2412 EntityMessageSubscriber::Entity {
2413 handle: entity.downgrade().into(),
2414 },
2415 );
2416 }
2417
2418 pub fn request<T: RequestMessage>(
2419 &self,
2420 payload: T,
2421 ) -> impl 'static + Future<Output = Result<T::Response>> {
2422 self.request_internal(payload, true)
2423 }
2424
2425 fn request_internal<T: RequestMessage>(
2426 &self,
2427 payload: T,
2428 use_buffer: bool,
2429 ) -> impl 'static + Future<Output = Result<T::Response>> {
2430 log::debug!("ssh request start. name:{}", T::NAME);
2431 let response =
2432 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2433 async move {
2434 let response = response.await?;
2435 log::debug!("ssh request finish. name:{}", T::NAME);
2436 T::Response::from_envelope(response).context("received a response of the wrong type")
2437 }
2438 }
2439
2440 pub async fn resync(&self, timeout: Duration) -> Result<()> {
2441 smol::future::or(
2442 async {
2443 self.request_internal(proto::FlushBufferedMessages {}, false)
2444 .await?;
2445
2446 for envelope in self.buffer.lock().iter() {
2447 self.outgoing_tx
2448 .lock()
2449 .unbounded_send(envelope.clone())
2450 .ok();
2451 }
2452 Ok(())
2453 },
2454 async {
2455 smol::Timer::after(timeout).await;
2456 anyhow::bail!("Timeout detected")
2457 },
2458 )
2459 .await
2460 }
2461
2462 pub async fn ping(&self, timeout: Duration) -> Result<()> {
2463 smol::future::or(
2464 async {
2465 self.request(proto::Ping {}).await?;
2466 Ok(())
2467 },
2468 async {
2469 smol::Timer::after(timeout).await;
2470 anyhow::bail!("Timeout detected")
2471 },
2472 )
2473 .await
2474 }
2475
2476 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2477 log::debug!("ssh send name:{}", T::NAME);
2478 self.send_dynamic(payload.into_envelope(0, None, None))
2479 }
2480
2481 fn request_dynamic(
2482 &self,
2483 mut envelope: proto::Envelope,
2484 type_name: &'static str,
2485 use_buffer: bool,
2486 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2487 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2488 let (tx, rx) = oneshot::channel();
2489 let mut response_channels_lock = self.response_channels.lock();
2490 response_channels_lock.insert(MessageId(envelope.id), tx);
2491 drop(response_channels_lock);
2492
2493 let result = if use_buffer {
2494 self.send_buffered(envelope)
2495 } else {
2496 self.send_unbuffered(envelope)
2497 };
2498 async move {
2499 if let Err(error) = &result {
2500 log::error!("failed to send message: {error}");
2501 anyhow::bail!("failed to send message: {error}");
2502 }
2503
2504 let response = rx.await.context("connection lost")?.0;
2505 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2506 return Err(RpcError::from_proto(error, type_name));
2507 }
2508 Ok(response)
2509 }
2510 }
2511
2512 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2513 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2514 self.send_buffered(envelope)
2515 }
2516
2517 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2518 envelope.ack_id = Some(self.max_received.load(SeqCst));
2519 self.buffer.lock().push_back(envelope.clone());
2520 // ignore errors on send (happen while we're reconnecting)
2521 // assume that the global "disconnected" overlay is sufficient.
2522 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2523 Ok(())
2524 }
2525
2526 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2527 envelope.ack_id = Some(self.max_received.load(SeqCst));
2528 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2529 Ok(())
2530 }
2531}
2532
2533impl ProtoClient for ChannelClient {
2534 fn request(
2535 &self,
2536 envelope: proto::Envelope,
2537 request_type: &'static str,
2538 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2539 self.request_dynamic(envelope, request_type, true).boxed()
2540 }
2541
2542 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2543 self.send_dynamic(envelope)
2544 }
2545
2546 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2547 self.send_dynamic(envelope)
2548 }
2549
2550 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2551 &self.message_handlers
2552 }
2553
2554 fn is_via_collab(&self) -> bool {
2555 false
2556 }
2557}
2558
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// In-process stand-in for a remote connection used by tests.
    ///
    /// Instead of spawning a process, it wires the client side directly to
    /// `server_channel` through in-memory channels.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        /// Server side of the fake link; re-wired by `start_proxy` and `simulate_disconnect`.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Server-side app context, used whenever `server_channel` is re-wired.
        pub(super) server_cx: SendableCx,
    }

    /// Wrapper that lets the server-side `AsyncApp` be stored in a `Send + Sync` type.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to kill for an in-memory connection.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        fn has_been_killed(&self) -> bool {
            false
        }

        /// No real SSH invocation backs this connection, so there are no arguments.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        /// Not supported by the fake connection; reaching this is a test bug.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Re-wires the server channel onto channels whose other halves are already dropped.
        ///
        /// The new incoming stream is immediately closed, which ends the server's message
        /// loop. Outgoing sends fail, and `ChannelClient` ignores those failures.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Connects the server channel to fresh in-memory channels and spawns a background
        /// pump that forwards envelopes in both directions.
        ///
        /// Every server-to-client envelope also signals `connection_activity_tx`
        /// (best-effort). The pump resolves to `Ok(1)` as soon as either direction's
        /// stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    // Biased: drain server output before forwarding new client input.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Delegate for the fake connection.
    ///
    /// Only `set_status` is expected to be called; every other method is unreachable
    /// in tests.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}