1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
13 channel::{
14 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
15 oneshot,
16 },
17 future::{BoxFuture, Shared},
18 select, select_biased,
19};
20use gpui::{
21 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
22 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
23};
24use itertools::Itertools;
25use parking_lot::Mutex;
26
27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
28use rpc::{
29 AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
30 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use smol::{
35 fs,
36 process::{self, Child, Stdio},
37};
38use std::{
39 collections::VecDeque,
40 fmt, iter,
41 ops::ControlFlow,
42 path::{Path, PathBuf},
43 sync::{
44 Arc, Weak,
45 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
46 },
47 time::{Duration, Instant},
48};
49use tempfile::TempDir;
50use util::{
51 ResultExt,
52 paths::{PathStyle, RemotePathBuf},
53};
54
/// Identifier of a project shared over an SSH connection.
///
/// A `Copy` newtype around `u64` that derives ordering, hashing and the
/// serde `Serialize`/`Deserialize` traits.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
59
/// Handle used to run commands on the remote host through the `ssh` binary.
///
/// On non-Windows platforms it carries the path of the control socket that every
/// spawned `ssh` process is pointed at via `-o ControlPath=...`. On Windows it
/// carries the environment variables (the `SSH_ASKPASS` configuration) that are
/// applied to every spawned `ssh` process.
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
68
/// A local port forward, as passed to `ssh -L`.
///
/// When a host is `None`, `SshConnectionOptions::additional_args` substitutes
/// `localhost` for it. `None` hosts are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind address.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port to listen on.
    pub local_port: u16,
    /// Host to connect to, as seen from the remote side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Port on `remote_host` to connect to.
    pub remote_port: u16,
}
78
/// Everything needed to connect to a remote host with `ssh`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Destination host name or address.
    pub host: String,
    /// Remote user; rendered as `user@host` by `ssh_url` / `connection_string`.
    pub username: Option<String>,
    /// Remote port; omitted from `ssh_url` / `connection_string` when `None`.
    pub port: Option<u16>,
    /// Password, if known up front (`parse_command_line` always leaves this `None`).
    pub password: Option<String>,
    /// Extra command-line arguments passed through to `ssh` (see `additional_args`).
    pub args: Option<Vec<String>>,
    /// Local port forwards, emitted as `-L` arguments by `additional_args`.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// Optional display name for this connection (presumably user-facing — confirm with callers).
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary over ssh rather than
    /// having the remote download it — confirm with the code that reads this flag.
    pub upload_binary_over_ssh: bool,
}
91
/// Arguments, plus optional environment variables, for spawning an `ssh` process
/// against the remote host, as produced by `SshSocket::ssh_args`.
pub struct SshArgs {
    /// Full `ssh` argument list, ending with the destination URL.
    pub arguments: Vec<String>,
    /// Environment to set on the process (only `Some` on Windows, where it holds
    /// the `SSH_ASKPASS` configuration).
    pub envs: Option<HashMap<String, String>>,
}
96
/// Formats a shell command string, shell-quoting every named argument.
///
/// Each `name = value` pair is passed through `shlex::try_quote` before being
/// substituted for `{name}` in `$fmt`, so values containing spaces or quotes
/// stay a single shell word.
///
/// # Panics
/// Panics if `shlex::try_quote` rejects a value (the result is `unwrap`ped;
/// presumably this happens for values containing NUL bytes — see shlex docs).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
108
109fn parse_port_number(port_str: &str) -> Result<u16> {
110 port_str
111 .parse()
112 .with_context(|| format!("parsing port number: {port_str}"))
113}
114
115fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
116 let parts: Vec<&str> = spec.split(':').collect();
117
118 match parts.len() {
119 4 => {
120 let local_port = parse_port_number(parts[1])?;
121 let remote_port = parse_port_number(parts[3])?;
122
123 Ok(SshPortForwardOption {
124 local_host: Some(parts[0].to_string()),
125 local_port,
126 remote_host: Some(parts[2].to_string()),
127 remote_port,
128 })
129 }
130 3 => {
131 let local_port = parse_port_number(parts[0])?;
132 let remote_port = parse_port_number(parts[2])?;
133
134 Ok(SshPortForwardOption {
135 local_host: None,
136 local_port,
137 remote_host: Some(parts[1].to_string()),
138 remote_port,
139 })
140 }
141 _ => anyhow::bail!("Invalid port forward format"),
142 }
143}
144
145impl SshConnectionOptions {
146 pub fn parse_command_line(input: &str) -> Result<Self> {
147 let input = input.trim_start_matches("ssh ");
148 let mut hostname: Option<String> = None;
149 let mut username: Option<String> = None;
150 let mut port: Option<u16> = None;
151 let mut args = Vec::new();
152 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
153
154 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
155 const ALLOWED_OPTS: &[&str] = &[
156 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
157 ];
158 const ALLOWED_ARGS: &[&str] = &[
159 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
160 "-w",
161 ];
162
163 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
164
165 'outer: while let Some(arg) = tokens.next() {
166 if ALLOWED_OPTS.contains(&(&arg as &str)) {
167 args.push(arg.to_string());
168 continue;
169 }
170 if arg == "-p" {
171 port = tokens.next().and_then(|arg| arg.parse().ok());
172 continue;
173 } else if let Some(p) = arg.strip_prefix("-p") {
174 port = p.parse().ok();
175 continue;
176 }
177 if arg == "-l" {
178 username = tokens.next();
179 continue;
180 } else if let Some(l) = arg.strip_prefix("-l") {
181 username = Some(l.to_string());
182 continue;
183 }
184 if arg == "-L" || arg.starts_with("-L") {
185 let forward_spec = if arg == "-L" {
186 tokens.next()
187 } else {
188 Some(arg.strip_prefix("-L").unwrap().to_string())
189 };
190
191 if let Some(spec) = forward_spec {
192 port_forwards.push(parse_port_forward_spec(&spec)?);
193 } else {
194 anyhow::bail!("Missing port forward format");
195 }
196 }
197
198 for a in ALLOWED_ARGS {
199 if arg == *a {
200 args.push(arg);
201 if let Some(next) = tokens.next() {
202 args.push(next);
203 }
204 continue 'outer;
205 } else if arg.starts_with(a) {
206 args.push(arg);
207 continue 'outer;
208 }
209 }
210 if arg.starts_with("-") || hostname.is_some() {
211 anyhow::bail!("unsupported argument: {:?}", arg);
212 }
213 let mut input = &arg as &str;
214 // Destination might be: username1@username2@ip2@ip1
215 if let Some((u, rest)) = input.rsplit_once('@') {
216 input = rest;
217 username = Some(u.to_string());
218 }
219 if let Some((rest, p)) = input.split_once(':') {
220 input = rest;
221 port = p.parse().ok()
222 }
223 hostname = Some(input.to_string())
224 }
225
226 let Some(hostname) = hostname else {
227 anyhow::bail!("missing hostname");
228 };
229
230 let port_forwards = match port_forwards.len() {
231 0 => None,
232 _ => Some(port_forwards),
233 };
234
235 Ok(Self {
236 host: hostname.to_string(),
237 username: username.clone(),
238 port,
239 port_forwards,
240 args: Some(args),
241 password: None,
242 nickname: None,
243 upload_binary_over_ssh: false,
244 })
245 }
246
247 pub fn ssh_url(&self) -> String {
248 let mut result = String::from("ssh://");
249 if let Some(username) = &self.username {
250 // Username might be: username1@username2@ip2
251 let username = urlencoding::encode(username);
252 result.push_str(&username);
253 result.push('@');
254 }
255 result.push_str(&self.host);
256 if let Some(port) = self.port {
257 result.push(':');
258 result.push_str(&port.to_string());
259 }
260 result
261 }
262
263 pub fn additional_args(&self) -> Vec<String> {
264 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
265
266 if let Some(forwards) = &self.port_forwards {
267 args.extend(forwards.iter().map(|pf| {
268 let local_host = match &pf.local_host {
269 Some(host) => host,
270 None => "localhost",
271 };
272 let remote_host = match &pf.remote_host {
273 Some(host) => host,
274 None => "localhost",
275 };
276
277 format!(
278 "-L{}:{}:{}:{}",
279 local_host, pf.local_port, remote_host, pf.remote_port
280 )
281 }));
282 }
283
284 args
285 }
286
287 fn scp_url(&self) -> String {
288 if let Some(username) = &self.username {
289 format!("{}@{}", username, self.host)
290 } else {
291 self.host.clone()
292 }
293 }
294
295 pub fn connection_string(&self) -> String {
296 let host = if let Some(username) = &self.username {
297 format!("{}@{}", username, self.host)
298 } else {
299 self.host.clone()
300 };
301 if let Some(port) = &self.port {
302 format!("{}:{}", host, port)
303 } else {
304 host
305 }
306 }
307}
308
/// Operating system and CPU architecture of the remote host, as detected by
/// `SshSocket::platform`.
///
/// `os` is `"macos"` or `"linux"`, and `arch` is `"aarch64"` or `"x86_64"`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
314
/// Hooks that the SSH connection code calls back into while connecting.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password in response to `prompt`. The answer is presumably
    /// sent back through `tx` — confirm with implementors.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Returns download parameters for the remote server binary for `platform` and
    /// `release_channel`. `version`, when given, presumably selects a specific release.
    /// NOTE(review): the meaning of the `(String, String)` pair is not visible here —
    /// confirm with implementors.
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Downloads the remote server binary for `platform` onto the local machine
    /// and resolves to the path of the downloaded file.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message, or clears it with `None`.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
334
335impl SshSocket {
336 #[cfg(not(target_os = "windows"))]
337 fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
338 Ok(Self {
339 connection_options: options,
340 socket_path,
341 })
342 }
343
344 #[cfg(target_os = "windows")]
345 fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
346 let askpass_script = temp_dir.path().join("askpass.bat");
347 std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
348 let mut envs = HashMap::default();
349 envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
350 envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
351 envs.insert("ZED_SSH_ASKPASS".into(), secret);
352 Ok(Self {
353 connection_options: options,
354 envs,
355 })
356 }
357
358 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
359 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
360 // and passes -l as an argument to sh, not to ls.
361 // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
362 // into a machine. You must use `cd` to get back to $HOME.
363 // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
364 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
365 let mut command = util::command::new_smol_command("ssh");
366 let to_run = iter::once(&program)
367 .chain(args.iter())
368 .map(|token| {
369 // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
370 debug_assert!(
371 !token.contains('\n'),
372 "multiline arguments do not work in all shells"
373 );
374 shlex::try_quote(token).unwrap()
375 })
376 .join(" ");
377 let to_run = format!("cd; {to_run}");
378 log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
379 self.ssh_options(&mut command)
380 .arg(self.connection_options.ssh_url())
381 .arg(to_run);
382 command
383 }
384
385 async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
386 let output = self.ssh_command(program, args).output().await?;
387 anyhow::ensure!(
388 output.status.success(),
389 "failed to run command: {}",
390 String::from_utf8_lossy(&output.stderr)
391 );
392 Ok(String::from_utf8_lossy(&output.stdout).to_string())
393 }
394
395 #[cfg(not(target_os = "windows"))]
396 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
397 command
398 .stdin(Stdio::piped())
399 .stdout(Stdio::piped())
400 .stderr(Stdio::piped())
401 .args(self.connection_options.additional_args())
402 .args(["-o", "ControlMaster=no", "-o"])
403 .arg(format!("ControlPath={}", self.socket_path.display()))
404 }
405
406 #[cfg(target_os = "windows")]
407 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
408 command
409 .stdin(Stdio::piped())
410 .stdout(Stdio::piped())
411 .stderr(Stdio::piped())
412 .args(self.connection_options.additional_args())
413 .envs(self.envs.clone())
414 }
415
416 // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
417 // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
418 #[cfg(not(target_os = "windows"))]
419 fn ssh_args(&self) -> SshArgs {
420 let mut arguments = self.connection_options.additional_args();
421 arguments.extend(vec![
422 "-o".to_string(),
423 "ControlMaster=no".to_string(),
424 "-o".to_string(),
425 format!("ControlPath={}", self.socket_path.display()),
426 self.connection_options.ssh_url(),
427 ]);
428 SshArgs {
429 arguments,
430 envs: None,
431 }
432 }
433
434 #[cfg(target_os = "windows")]
435 fn ssh_args(&self) -> SshArgs {
436 let mut arguments = self.connection_options.additional_args();
437 arguments.push(self.connection_options.ssh_url());
438 SshArgs {
439 arguments,
440 envs: Some(self.envs.clone()),
441 }
442 }
443
444 async fn platform(&self) -> Result<SshPlatform> {
445 let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
446 let Some((os, arch)) = uname.split_once(" ") else {
447 anyhow::bail!("unknown uname: {uname:?}")
448 };
449
450 let os = match os.trim() {
451 "Darwin" => "macos",
452 "Linux" => "linux",
453 _ => anyhow::bail!(
454 "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
455 ),
456 };
457 // exclude armv5,6,7 as they are 32-bit.
458 let arch = if arch.starts_with("armv8")
459 || arch.starts_with("armv9")
460 || arch.starts_with("arm64")
461 || arch.starts_with("aarch64")
462 {
463 "aarch64"
464 } else if arch.starts_with("x86") {
465 "x86_64"
466 } else {
467 anyhow::bail!(
468 "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
469 )
470 };
471
472 Ok(SshPlatform { os, arch })
473 }
474}
475
/// Number of consecutive missed heartbeats at which `handle_heartbeat_result`
/// starts a reconnect.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time after which the heartbeat loop pings the server.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a single ping (also used for the initial ping and for resync).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Number of reconnect attempts after which the client gives up
/// (`State::ReconnectExhausted`).
const MAX_RECONNECT_ATTEMPTS: usize = 3;
481
/// Internal connection state machine of an `SshRemoteClient`.
///
/// The variants that hold a live connection also own the tasks servicing it,
/// so replacing the state drops those tasks.
enum State {
    /// Set by `SshRemoteClient::new` until the initial ping succeeds.
    Connecting,
    /// The connection is up. `multiplex_task` monitors the proxy io task, and
    /// `heartbeat_task` runs the heartbeat loop.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but `missed_heartbeats` consecutive pings have failed.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect task is in flight.
    Reconnecting,
    /// Reconnect attempt number `attempts` failed with `error`. The old
    /// connection handle is kept for the next attempt.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        error: anyhow::Error,
        attempts: usize,
    },
    /// All `MAX_RECONNECT_ATTEMPTS` attempts failed; the client is disconnected.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
511
512impl fmt::Display for State {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 match self {
515 Self::Connecting => write!(f, "connecting"),
516 Self::Connected { .. } => write!(f, "connected"),
517 Self::Reconnecting => write!(f, "reconnecting"),
518 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
519 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
520 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
521 Self::ServerNotRunning { .. } => write!(f, "server not running"),
522 }
523 }
524}
525
526impl State {
527 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
528 match self {
529 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
530 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
531 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
532 _ => None,
533 }
534 }
535
536 fn can_reconnect(&self) -> bool {
537 match self {
538 Self::Connected { .. }
539 | Self::HeartbeatMissed { .. }
540 | Self::ReconnectFailed { .. } => true,
541 State::Connecting
542 | State::Reconnecting
543 | State::ReconnectExhausted
544 | State::ServerNotRunning => false,
545 }
546 }
547
548 fn is_reconnect_failed(&self) -> bool {
549 matches!(self, Self::ReconnectFailed { .. })
550 }
551
552 fn is_reconnect_exhausted(&self) -> bool {
553 matches!(self, Self::ReconnectExhausted { .. })
554 }
555
556 fn is_server_not_running(&self) -> bool {
557 matches!(self, Self::ServerNotRunning)
558 }
559
560 fn is_reconnecting(&self) -> bool {
561 matches!(self, Self::Reconnecting { .. })
562 }
563
564 fn heartbeat_recovered(self) -> Self {
565 match self {
566 Self::HeartbeatMissed {
567 ssh_connection,
568 delegate,
569 multiplex_task,
570 heartbeat_task,
571 ..
572 } => Self::Connected {
573 ssh_connection,
574 delegate,
575 multiplex_task,
576 heartbeat_task,
577 },
578 _ => self,
579 }
580 }
581
582 fn heartbeat_missed(self) -> Self {
583 match self {
584 Self::Connected {
585 ssh_connection,
586 delegate,
587 multiplex_task,
588 heartbeat_task,
589 } => Self::HeartbeatMissed {
590 missed_heartbeats: 1,
591 ssh_connection,
592 delegate,
593 multiplex_task,
594 heartbeat_task,
595 },
596 Self::HeartbeatMissed {
597 missed_heartbeats,
598 ssh_connection,
599 delegate,
600 multiplex_task,
601 heartbeat_task,
602 } => Self::HeartbeatMissed {
603 missed_heartbeats: missed_heartbeats + 1,
604 ssh_connection,
605 delegate,
606 multiplex_task,
607 heartbeat_task,
608 },
609 _ => self,
610 }
611 }
612}
613
/// The state of the ssh connection.
///
/// A public, data-free view of the internal `State` (see its `From<&State>` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    Connecting,
    Connected,
    HeartbeatMissed,
    /// Also reported after a failed reconnect attempt that will be retried.
    Reconnecting,
    /// Reconnect attempts are exhausted, the server is not running, or no state is set.
    Disconnected,
}
623
624impl From<&State> for ConnectionState {
625 fn from(value: &State) -> Self {
626 match value {
627 State::Connecting => Self::Connecting,
628 State::Connected { .. } => Self::Connected,
629 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
630 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
631 State::ReconnectExhausted => Self::Disconnected,
632 State::ServerNotRunning => Self::Disconnected,
633 }
634 }
635}
636
/// Client side of a remote-development connection over SSH.
///
/// Holds the RPC `ChannelClient` and the connection state machine, which is
/// replaced as the connection is established, lost and re-established.
pub struct SshRemoteClient {
    /// RPC client. It is kept across reconnects; `reconnect` points it at the new channels.
    client: Arc<ChannelClient>,
    /// Passed to `start_proxy` so that reconnects can re-join the same remote socket.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Path style of the remote host, captured from the initial connection.
    path_style: PathStyle,
    /// `None` once `shutdown_processes` has taken it, and transiently while a
    /// transition moves the state out.
    state: Arc<Mutex<Option<State>>>,
}
644
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted by `set_state` when the client enters `ReconnectExhausted` or
    /// `ServerNotRunning`.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
651
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// A setup connection. Ids come from a process-wide counter (see `setup()`).
    Setup(u64),
    /// A workspace connection; presumably keyed by the workspace's id — confirm with callers.
    Workspace(i64),
}
658
659static NEXT_ID: AtomicU64 = AtomicU64::new(1);
660
661impl ConnectionIdentifier {
662 pub fn setup() -> Self {
663 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
664 }
665
666 // This string gets used in a socket name, and so must be relatively short.
667 // The total length of:
668 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
669 // Must be less than about 100 characters
670 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
671 // So our strings should be at most 20 characters or so.
672 fn to_string(&self, cx: &App) -> String {
673 let identifier_prefix = match ReleaseChannel::global(cx) {
674 ReleaseChannel::Stable => "".to_string(),
675 release_channel => format!("{}-", release_channel.dev_name()),
676 };
677 match self {
678 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
679 Self::Workspace(workspace_id) => {
680 format!("{identifier_prefix}workspace-{workspace_id}",)
681 }
682 }
683 }
684}
685
686impl SshRemoteClient {
    /// Connects to the host described by `connection_options` and creates the client entity.
    ///
    /// The connection is obtained through the global `ConnectionPool`. The remote
    /// proxy is then started for `unique_identifier` and the server is pinged once.
    /// Only after that ping succeeds is the state set to `Connected` and the heartbeat
    /// task started.
    ///
    /// Resolves to `Ok(None)` once `cancellation` completes (including when its sender
    /// is dropped) if that happens before the connection finishes.
    ///
    /// # Errors
    /// Fails if connecting, creating the entity, or the initial ping fails.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                // Activity notifications feed the heartbeat loop (see `heartbeat`).
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // `false` here vs `true` in `reconnect`: presumably marks this as a
                // fresh (non-reconnect) proxy start — confirm in `start_proxy`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() => result
            }
        })
    }
761
762 pub fn proto_client_from_channels(
763 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
764 outgoing_tx: mpsc::UnboundedSender<Envelope>,
765 cx: &App,
766 name: &'static str,
767 ) -> AnyProtoClient {
768 ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
769 }
770
    /// Takes the connection state and, if it was `Connected`, returns a future that
    /// shuts the connection down.
    ///
    /// The state is taken (left as `None`) in every case where one was set. Returns
    /// `None` when no state was set, or when the state was not `Connected`; in the
    /// latter case the taken state is dropped immediately.
    ///
    /// The returned future optionally sends `shutdown_request`, waits 50ms, and then
    /// drops the connection's tasks and handles in a fixed order.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
809
    /// Tears down the current connection and starts an asynchronous reconnect.
    ///
    /// Only allowed from `Connected`, `HeartbeatMissed` or `ReconnectFailed` (see
    /// `State::can_reconnect`); otherwise it returns an error and leaves the state
    /// untouched. The attempt counter starts from 0 in the connected states and
    /// carries over from `ReconnectFailed`. If the next attempt would exceed
    /// `MAX_RECONNECT_ATTEMPTS`, the state becomes `ReconnectExhausted` and `Ok(())`
    /// is returned without spawning anything.
    ///
    /// Otherwise the state becomes `Reconnecting` and a task is spawned to do the
    /// reconnect. It kills the old connection, connects again through the
    /// `ConnectionPool`, restarts the proxy and resyncs the `ChannelClient`. The
    /// task's result is applied only if the client is still `Reconnecting`. A
    /// `ReconnectFailed` result immediately triggers the next `reconnect`.
    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            anyhow::bail!(error);
        }

        // `can_reconnect` was true, so a state is present.
        let state = lock.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Drop the old monitor/heartbeat tasks; the connection handle and
                // delegate are reused by the reconnect task below.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // Release the lock first: `set_state` locks `self.state` again.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        // Release the lock first: `set_state` locks `self.state` again.
        drop(lock);

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        let reconnect_task = cx.spawn(async move |this, cx| {
            // Ends the reconnect task with a `ReconnectFailed` state that keeps the
            // connection handle and delegate for the next attempt.
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                // `true` here vs `false` in `new`: presumably marks this proxy start
                // as a reconnect — confirm in `start_proxy`.
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    // Still the *old* connection handle here (the new one was never bound).
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
            // Re-point the existing ChannelClient at the new channels, then resync it.
            client.reconnect(incoming_rx, outgoing_tx, cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
            }
        });

        cx.spawn(async move |this, cx| {
            let new_state = reconnect_task.await;
            this.update(cx, |this, cx| {
                // Apply the result only if nothing else moved the state away from
                // `Reconnecting` while the attempt was running.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting { .. }
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                // A failed attempt immediately schedules the next one.
                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
990
991 fn heartbeat(
992 this: WeakEntity<Self>,
993 mut connection_activity_rx: mpsc::Receiver<()>,
994 cx: &mut AsyncApp,
995 ) -> Task<Result<()>> {
996 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
997 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
998 };
999
1000 cx.spawn(async move |cx| {
1001 let mut missed_heartbeats = 0;
1002
1003 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
1004 futures::pin_mut!(keepalive_timer);
1005
1006 loop {
1007 select_biased! {
1008 result = connection_activity_rx.next().fuse() => {
1009 if result.is_none() {
1010 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1011 return Ok(());
1012 }
1013
1014 if missed_heartbeats != 0 {
1015 missed_heartbeats = 0;
1016 let _ =this.update(cx, |this, cx| {
1017 this.handle_heartbeat_result(missed_heartbeats, cx)
1018 })?;
1019 }
1020 }
1021 _ = keepalive_timer => {
1022 log::debug!("Sending heartbeat to server...");
1023
1024 let result = select_biased! {
1025 _ = connection_activity_rx.next().fuse() => {
1026 Ok(())
1027 }
1028 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1029 ping_result
1030 }
1031 };
1032
1033 if result.is_err() {
1034 missed_heartbeats += 1;
1035 log::warn!(
1036 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1037 HEARTBEAT_TIMEOUT,
1038 missed_heartbeats,
1039 MAX_MISSED_HEARTBEATS
1040 );
1041 } else if missed_heartbeats != 0 {
1042 missed_heartbeats = 0;
1043 } else {
1044 continue;
1045 }
1046
1047 let result = this.update(cx, |this, cx| {
1048 this.handle_heartbeat_result(missed_heartbeats, cx)
1049 })?;
1050 if result.is_break() {
1051 return Ok(());
1052 }
1053 }
1054 }
1055
1056 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1057 }
1058 })
1059 }
1060
1061 fn handle_heartbeat_result(
1062 &mut self,
1063 missed_heartbeats: usize,
1064 cx: &mut Context<Self>,
1065 ) -> ControlFlow<()> {
1066 let state = self.state.lock().take().unwrap();
1067 let next_state = if missed_heartbeats > 0 {
1068 state.heartbeat_missed()
1069 } else {
1070 state.heartbeat_recovered()
1071 };
1072
1073 self.set_state(next_state, cx);
1074
1075 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1076 log::error!(
1077 "Missed last {} heartbeats. Reconnecting...",
1078 missed_heartbeats
1079 );
1080
1081 self.reconnect(cx)
1082 .context("failed to start reconnect process after missing heartbeats")
1083 .log_err();
1084 ControlFlow::Break(())
1085 } else {
1086 ControlFlow::Continue(())
1087 }
1088 }
1089
1090 fn monitor(
1091 this: WeakEntity<Self>,
1092 io_task: Task<Result<i32>>,
1093 cx: &AsyncApp,
1094 ) -> Task<Result<()>> {
1095 cx.spawn(async move |cx| {
1096 let result = io_task.await;
1097
1098 match result {
1099 Ok(exit_code) => {
1100 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1101 match error {
1102 ProxyLaunchError::ServerNotRunning => {
1103 log::error!("failed to reconnect because server is not running");
1104 this.update(cx, |this, cx| {
1105 this.set_state(State::ServerNotRunning, cx);
1106 })?;
1107 }
1108 }
1109 } else if exit_code > 0 {
1110 log::error!("proxy process terminated unexpectedly");
1111 this.update(cx, |this, cx| {
1112 this.reconnect(cx).ok();
1113 })?;
1114 }
1115 }
1116 Err(error) => {
1117 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1118 this.update(cx, |this, cx| {
1119 this.reconnect(cx).ok();
1120 })?;
1121 }
1122 }
1123
1124 Ok(())
1125 })
1126 }
1127
1128 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1129 self.state.lock().as_ref().is_some_and(check)
1130 }
1131
1132 fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1133 let mut lock = self.state.lock();
1134 let new_state = lock.as_ref().and_then(map);
1135
1136 if let Some(new_state) = new_state {
1137 lock.replace(new_state);
1138 cx.notify();
1139 }
1140 }
1141
1142 fn set_state(&self, state: State, cx: &mut Context<Self>) {
1143 log::info!("setting state to '{}'", &state);
1144
1145 let is_reconnect_exhausted = state.is_reconnect_exhausted();
1146 let is_server_not_running = state.is_server_not_running();
1147 self.state.lock().replace(state);
1148
1149 if is_reconnect_exhausted || is_server_not_running {
1150 cx.emit(SshRemoteEvent::Disconnected);
1151 }
1152 cx.notify();
1153 }
1154
1155 pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1156 self.state
1157 .lock()
1158 .as_ref()
1159 .and_then(|state| state.ssh_connection())
1160 .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1161 }
1162
1163 pub fn upload_directory(
1164 &self,
1165 src_path: PathBuf,
1166 dest_path: RemotePathBuf,
1167 cx: &App,
1168 ) -> Task<Result<()>> {
1169 let state = self.state.lock();
1170 let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1171 return Task::ready(Err(anyhow!("no ssh connection")));
1172 };
1173 connection.upload_directory(src_path, dest_path, cx)
1174 }
1175
1176 pub fn proto_client(&self) -> AnyProtoClient {
1177 self.client.clone().into()
1178 }
1179
1180 pub fn connection_string(&self) -> String {
1181 self.connection_options.connection_string()
1182 }
1183
1184 pub fn connection_options(&self) -> SshConnectionOptions {
1185 self.connection_options.clone()
1186 }
1187
1188 pub fn connection_state(&self) -> ConnectionState {
1189 self.state
1190 .lock()
1191 .as_ref()
1192 .map(ConnectionState::from)
1193 .unwrap_or(ConnectionState::Disconnected)
1194 }
1195
1196 pub fn is_disconnected(&self) -> bool {
1197 self.connection_state() == ConnectionState::Disconnected
1198 }
1199
1200 pub fn path_style(&self) -> PathStyle {
1201 self.path_style
1202 }
1203
1204 #[cfg(any(test, feature = "test-support"))]
1205 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1206 let opts = self.connection_options();
1207 client_cx.spawn(async move |cx| {
1208 let connection = cx
1209 .update_global(|c: &mut ConnectionPool, _| {
1210 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1211 c.clone()
1212 } else {
1213 panic!("missing test connection")
1214 }
1215 })
1216 .unwrap()
1217 .await
1218 .unwrap();
1219
1220 connection.simulate_disconnect(cx);
1221 })
1222 }
1223
1224 #[cfg(any(test, feature = "test-support"))]
1225 pub fn fake_server(
1226 client_cx: &mut gpui::TestAppContext,
1227 server_cx: &mut gpui::TestAppContext,
1228 ) -> (SshConnectionOptions, AnyProtoClient) {
1229 let port = client_cx
1230 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1231 let opts = SshConnectionOptions {
1232 host: "<fake>".to_string(),
1233 port: Some(port),
1234 ..Default::default()
1235 };
1236 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1237 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1238 let server_client =
1239 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1240 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1241 connection_options: opts.clone(),
1242 server_cx: fake::SendableCx::new(server_cx),
1243 server_channel: server_client.clone(),
1244 });
1245
1246 client_cx.update(|cx| {
1247 cx.update_default_global(|c: &mut ConnectionPool, cx| {
1248 c.connections.insert(
1249 opts.clone(),
1250 ConnectionPoolEntry::Connecting(
1251 cx.background_spawn({
1252 let connection = connection.clone();
1253 async move { Ok(connection.clone()) }
1254 })
1255 .shared(),
1256 ),
1257 );
1258 })
1259 });
1260
1261 (opts, server_client.into())
1262 }
1263
1264 #[cfg(any(test, feature = "test-support"))]
1265 pub async fn fake_client(
1266 opts: SshConnectionOptions,
1267 client_cx: &mut gpui::TestAppContext,
1268 ) -> Entity<Self> {
1269 let (_tx, rx) = oneshot::channel();
1270 client_cx
1271 .update(|cx| {
1272 Self::new(
1273 ConnectionIdentifier::setup(),
1274 opts,
1275 rx,
1276 Arc::new(fake::Delegate),
1277 cx,
1278 )
1279 })
1280 .await
1281 .unwrap()
1282 .unwrap()
1283 }
1284}
1285
/// Pool slot for a single set of [`SshConnectionOptions`].
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. The task is `Shared` so that
    /// concurrent callers can await the same attempt; errors are wrapped in
    /// `Arc` so they can be handed to every waiter.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// An established connection, held weakly so the pool alone does not keep
    /// it alive.
    Connected(Weak<dyn RemoteConnection>),
}
1290
/// App-wide registry of SSH connections keyed by their options. Used to share
/// in-flight connection attempts and to reuse live connections.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Makes the pool a gpui global, accessed via `update_global`,
// `default_global`, and `update_default_global`.
impl Global for ConnectionPool {}
1297
1298impl ConnectionPool {
1299 pub fn connect(
1300 &mut self,
1301 opts: SshConnectionOptions,
1302 delegate: &Arc<dyn SshClientDelegate>,
1303 cx: &mut App,
1304 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1305 let connection = self.connections.get(&opts);
1306 match connection {
1307 Some(ConnectionPoolEntry::Connecting(task)) => {
1308 let delegate = delegate.clone();
1309 cx.spawn(async move |cx| {
1310 delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1311 })
1312 .detach();
1313 return task.clone();
1314 }
1315 Some(ConnectionPoolEntry::Connected(ssh)) => {
1316 if let Some(ssh) = ssh.upgrade()
1317 && !ssh.has_been_killed()
1318 {
1319 return Task::ready(Ok(ssh)).shared();
1320 }
1321 self.connections.remove(&opts);
1322 }
1323 None => {}
1324 }
1325
1326 let task = cx
1327 .spawn({
1328 let opts = opts.clone();
1329 let delegate = delegate.clone();
1330 async move |cx| {
1331 let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1332 .await
1333 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1334
1335 cx.update_global(|pool: &mut Self, _| {
1336 debug_assert!(matches!(
1337 pool.connections.get(&opts),
1338 Some(ConnectionPoolEntry::Connecting(_))
1339 ));
1340 match connection {
1341 Ok(connection) => {
1342 pool.connections.insert(
1343 opts.clone(),
1344 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1345 );
1346 Ok(connection)
1347 }
1348 Err(error) => {
1349 pool.connections.remove(&opts);
1350 Err(Arc::new(error))
1351 }
1352 }
1353 })?
1354 }
1355 })
1356 .shared();
1357
1358 self.connections
1359 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1360 task
1361 }
1362}
1363
1364impl From<SshRemoteClient> for AnyProtoClient {
1365 fn from(client: SshRemoteClient) -> Self {
1366 AnyProtoClient::new(client.client.clone())
1367 }
1368}
1369
1370#[async_trait(?Send)]
1371trait RemoteConnection: Send + Sync {
1372 fn start_proxy(
1373 &self,
1374 unique_identifier: String,
1375 reconnect: bool,
1376 incoming_tx: UnboundedSender<Envelope>,
1377 outgoing_rx: UnboundedReceiver<Envelope>,
1378 connection_activity_tx: Sender<()>,
1379 delegate: Arc<dyn SshClientDelegate>,
1380 cx: &mut AsyncApp,
1381 ) -> Task<Result<i32>>;
1382 fn upload_directory(
1383 &self,
1384 src_path: PathBuf,
1385 dest_path: RemotePathBuf,
1386 cx: &App,
1387 ) -> Task<Result<()>>;
1388 async fn kill(&self) -> Result<()>;
1389 fn has_been_killed(&self) -> bool;
1390 /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
1391 /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
1392 fn ssh_args(&self) -> SshArgs;
1393 fn connection_options(&self) -> SshConnectionOptions;
1394 fn path_style(&self) -> PathStyle;
1395
1396 #[cfg(any(test, feature = "test-support"))]
1397 fn simulate_disconnect(&self, _: &AsyncApp) {}
1398}
1399
/// A connection to a remote host backed by a long-lived master `ssh` process.
///
/// NOTE: fields are dropped in declaration order. The master process (spawned
/// with `kill_on_drop(true)` in `new`) is therefore dropped before the
/// temporary directory that holds its control socket on non-Windows
/// platforms. Keep this order when editing.
struct SshRemoteConnection {
    socket: SshSocket,
    /// The master `ssh -N` process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Remote path of the server binary, filled in by `new` via
    /// `ensure_server_binary`.
    remote_binary_path: Option<RemotePathBuf>,
    ssh_platform: SshPlatform,
    ssh_path_style: PathStyle,
    /// Held only to keep the per-session temporary directory alive.
    _temp_dir: TempDir,
}
1408
1409#[async_trait(?Send)]
1410impl RemoteConnection for SshRemoteConnection {
1411 async fn kill(&self) -> Result<()> {
1412 let Some(mut process) = self.master_process.lock().take() else {
1413 return Ok(());
1414 };
1415 process.kill().ok();
1416 process.status().await?;
1417 Ok(())
1418 }
1419
1420 fn has_been_killed(&self) -> bool {
1421 self.master_process.lock().is_none()
1422 }
1423
1424 fn ssh_args(&self) -> SshArgs {
1425 self.socket.ssh_args()
1426 }
1427
1428 fn connection_options(&self) -> SshConnectionOptions {
1429 self.socket.connection_options.clone()
1430 }
1431
1432 fn upload_directory(
1433 &self,
1434 src_path: PathBuf,
1435 dest_path: RemotePathBuf,
1436 cx: &App,
1437 ) -> Task<Result<()>> {
1438 let mut command = util::command::new_smol_command("scp");
1439 let output = self
1440 .socket
1441 .ssh_options(&mut command)
1442 .args(
1443 self.socket
1444 .connection_options
1445 .port
1446 .map(|port| vec!["-P".to_string(), port.to_string()])
1447 .unwrap_or_default(),
1448 )
1449 .arg("-C")
1450 .arg("-r")
1451 .arg(&src_path)
1452 .arg(format!(
1453 "{}:{}",
1454 self.socket.connection_options.scp_url(),
1455 dest_path
1456 ))
1457 .output();
1458
1459 cx.background_spawn(async move {
1460 let output = output.await?;
1461
1462 anyhow::ensure!(
1463 output.status.success(),
1464 "failed to upload directory {} -> {}: {}",
1465 src_path.display(),
1466 dest_path.to_string(),
1467 String::from_utf8_lossy(&output.stderr)
1468 );
1469
1470 Ok(())
1471 })
1472 }
1473
1474 fn start_proxy(
1475 &self,
1476 unique_identifier: String,
1477 reconnect: bool,
1478 incoming_tx: UnboundedSender<Envelope>,
1479 outgoing_rx: UnboundedReceiver<Envelope>,
1480 connection_activity_tx: Sender<()>,
1481 delegate: Arc<dyn SshClientDelegate>,
1482 cx: &mut AsyncApp,
1483 ) -> Task<Result<i32>> {
1484 delegate.set_status(Some("Starting proxy"), cx);
1485
1486 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1487 return Task::ready(Err(anyhow!("Remote binary path not set")));
1488 };
1489
1490 let mut start_proxy_command = shell_script!(
1491 "exec {binary_path} proxy --identifier {identifier}",
1492 binary_path = &remote_binary_path.to_string(),
1493 identifier = &unique_identifier,
1494 );
1495
1496 for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1497 if let Some(value) = std::env::var(env_var).ok() {
1498 start_proxy_command = format!(
1499 "{}={} {} ",
1500 env_var,
1501 shlex::try_quote(&value).unwrap(),
1502 start_proxy_command,
1503 );
1504 }
1505 }
1506
1507 if reconnect {
1508 start_proxy_command.push_str(" --reconnect");
1509 }
1510
1511 let ssh_proxy_process = match self
1512 .socket
1513 .ssh_command("sh", &["-c", &start_proxy_command])
1514 // IMPORTANT: we kill this process when we drop the task that uses it.
1515 .kill_on_drop(true)
1516 .spawn()
1517 {
1518 Ok(process) => process,
1519 Err(error) => {
1520 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1521 }
1522 };
1523
1524 Self::multiplex(
1525 ssh_proxy_process,
1526 incoming_tx,
1527 outgoing_rx,
1528 connection_activity_tx,
1529 cx,
1530 )
1531 }
1532
1533 fn path_style(&self) -> PathStyle {
1534 self.ssh_path_style
1535 }
1536}
1537
1538impl SshRemoteConnection {
1539 async fn new(
1540 connection_options: SshConnectionOptions,
1541 delegate: Arc<dyn SshClientDelegate>,
1542 cx: &mut AsyncApp,
1543 ) -> Result<Self> {
1544 use askpass::AskPassResult;
1545
1546 delegate.set_status(Some("Connecting"), cx);
1547
1548 let url = connection_options.ssh_url();
1549
1550 let temp_dir = tempfile::Builder::new()
1551 .prefix("zed-ssh-session")
1552 .tempdir()?;
1553 let askpass_delegate = askpass::AskPassDelegate::new(cx, {
1554 let delegate = delegate.clone();
1555 move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
1556 });
1557
1558 let mut askpass =
1559 askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
1560
1561 // Start the master SSH process, which does not do anything except for establish
1562 // the connection and keep it open, allowing other ssh commands to reuse it
1563 // via a control socket.
1564 #[cfg(not(target_os = "windows"))]
1565 let socket_path = temp_dir.path().join("ssh.sock");
1566
1567 let mut master_process = {
1568 #[cfg(not(target_os = "windows"))]
1569 let args = [
1570 "-N",
1571 "-o",
1572 "ControlPersist=no",
1573 "-o",
1574 "ControlMaster=yes",
1575 "-o",
1576 ];
1577 // On Windows, `ControlMaster` and `ControlPath` are not supported:
1578 // https://github.com/PowerShell/Win32-OpenSSH/issues/405
1579 // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
1580 #[cfg(target_os = "windows")]
1581 let args = ["-N"];
1582 let mut master_process = util::command::new_smol_command("ssh");
1583 master_process
1584 .kill_on_drop(true)
1585 .stdin(Stdio::null())
1586 .stdout(Stdio::piped())
1587 .stderr(Stdio::piped())
1588 .env("SSH_ASKPASS_REQUIRE", "force")
1589 .env("SSH_ASKPASS", askpass.script_path())
1590 .args(connection_options.additional_args())
1591 .args(args);
1592 #[cfg(not(target_os = "windows"))]
1593 master_process.arg(format!("ControlPath={}", socket_path.display()));
1594 master_process.arg(&url).spawn()?
1595 };
1596 // Wait for this ssh process to close its stdout, indicating that authentication
1597 // has completed.
1598 let mut stdout = master_process.stdout.take().unwrap();
1599 let mut output = Vec::new();
1600
1601 let result = select_biased! {
1602 result = askpass.run().fuse() => {
1603 match result {
1604 AskPassResult::CancelledByUser => {
1605 master_process.kill().ok();
1606 anyhow::bail!("SSH connection canceled")
1607 }
1608 AskPassResult::Timedout => {
1609 anyhow::bail!("connecting to host timed out")
1610 }
1611 }
1612 }
1613 _ = stdout.read_to_end(&mut output).fuse() => {
1614 anyhow::Ok(())
1615 }
1616 };
1617
1618 if let Err(e) = result {
1619 return Err(e.context("Failed to connect to host"));
1620 }
1621
1622 if master_process.try_status()?.is_some() {
1623 output.clear();
1624 let mut stderr = master_process.stderr.take().unwrap();
1625 stderr.read_to_end(&mut output).await?;
1626
1627 let error_message = format!(
1628 "failed to connect: {}",
1629 String::from_utf8_lossy(&output).trim()
1630 );
1631 anyhow::bail!(error_message);
1632 }
1633
1634 #[cfg(not(target_os = "windows"))]
1635 let socket = SshSocket::new(connection_options, socket_path)?;
1636 #[cfg(target_os = "windows")]
1637 let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
1638 drop(askpass);
1639
1640 let ssh_platform = socket.platform().await?;
1641 let ssh_path_style = match ssh_platform.os {
1642 "windows" => PathStyle::Windows,
1643 _ => PathStyle::Posix,
1644 };
1645
1646 let mut this = Self {
1647 socket,
1648 master_process: Mutex::new(Some(master_process)),
1649 _temp_dir: temp_dir,
1650 remote_binary_path: None,
1651 ssh_path_style,
1652 ssh_platform,
1653 };
1654
1655 let (release_channel, version, commit) = cx.update(|cx| {
1656 (
1657 ReleaseChannel::global(cx),
1658 AppVersion::global(cx),
1659 AppCommitSha::try_global(cx),
1660 )
1661 })?;
1662 this.remote_binary_path = Some(
1663 this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
1664 .await?,
1665 );
1666
1667 Ok(this)
1668 }
1669
1670 fn multiplex(
1671 mut ssh_proxy_process: Child,
1672 incoming_tx: UnboundedSender<Envelope>,
1673 mut outgoing_rx: UnboundedReceiver<Envelope>,
1674 mut connection_activity_tx: Sender<()>,
1675 cx: &AsyncApp,
1676 ) -> Task<Result<i32>> {
1677 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1678 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1679 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1680
1681 let mut stdin_buffer = Vec::new();
1682 let mut stdout_buffer = Vec::new();
1683 let mut stderr_buffer = Vec::new();
1684 let mut stderr_offset = 0;
1685
1686 let stdin_task = cx.background_spawn(async move {
1687 while let Some(outgoing) = outgoing_rx.next().await {
1688 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1689 }
1690 anyhow::Ok(())
1691 });
1692
1693 let stdout_task = cx.background_spawn({
1694 let mut connection_activity_tx = connection_activity_tx.clone();
1695 async move {
1696 loop {
1697 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1698 let len = child_stdout.read(&mut stdout_buffer).await?;
1699
1700 if len == 0 {
1701 return anyhow::Ok(());
1702 }
1703
1704 if len < MESSAGE_LEN_SIZE {
1705 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1706 }
1707
1708 let message_len = message_len_from_buffer(&stdout_buffer);
1709 let envelope =
1710 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1711 .await?;
1712 connection_activity_tx.try_send(()).ok();
1713 incoming_tx.unbounded_send(envelope).ok();
1714 }
1715 }
1716 });
1717
1718 let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1719 loop {
1720 stderr_buffer.resize(stderr_offset + 1024, 0);
1721
1722 let len = child_stderr
1723 .read(&mut stderr_buffer[stderr_offset..])
1724 .await?;
1725 if len == 0 {
1726 return anyhow::Ok(());
1727 }
1728
1729 stderr_offset += len;
1730 let mut start_ix = 0;
1731 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1732 .iter()
1733 .position(|b| b == &b'\n')
1734 {
1735 let line_ix = start_ix + ix;
1736 let content = &stderr_buffer[start_ix..line_ix];
1737 start_ix = line_ix + 1;
1738 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1739 record.log(log::logger())
1740 } else {
1741 eprintln!("(remote) {}", String::from_utf8_lossy(content));
1742 }
1743 }
1744 stderr_buffer.drain(0..start_ix);
1745 stderr_offset -= start_ix;
1746
1747 connection_activity_tx.try_send(()).ok();
1748 }
1749 });
1750
1751 cx.background_spawn(async move {
1752 let result = futures::select! {
1753 result = stdin_task.fuse() => {
1754 result.context("stdin")
1755 }
1756 result = stdout_task.fuse() => {
1757 result.context("stdout")
1758 }
1759 result = stderr_task.fuse() => {
1760 result.context("stderr")
1761 }
1762 };
1763
1764 let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1765 match result {
1766 Ok(_) => Ok(status),
1767 Err(error) => Err(error),
1768 }
1769 })
1770 }
1771
1772 #[allow(unused)]
1773 async fn ensure_server_binary(
1774 &self,
1775 delegate: &Arc<dyn SshClientDelegate>,
1776 release_channel: ReleaseChannel,
1777 version: SemanticVersion,
1778 commit: Option<AppCommitSha>,
1779 cx: &mut AsyncApp,
1780 ) -> Result<RemotePathBuf> {
1781 let version_str = match release_channel {
1782 ReleaseChannel::Nightly => {
1783 let commit = commit.map(|s| s.full()).unwrap_or_default();
1784 format!("{}-{}", version, commit)
1785 }
1786 ReleaseChannel::Dev => "build".to_string(),
1787 _ => version.to_string(),
1788 };
1789 let binary_name = format!(
1790 "zed-remote-server-{}-{}",
1791 release_channel.dev_name(),
1792 version_str
1793 );
1794 let dst_path = RemotePathBuf::new(
1795 paths::remote_server_dir_relative().join(binary_name),
1796 self.ssh_path_style,
1797 );
1798
1799 let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
1800 #[cfg(debug_assertions)]
1801 if let Some(build_remote_server) = build_remote_server {
1802 let src_path = self.build_local(build_remote_server, delegate, cx).await?;
1803 let tmp_path = RemotePathBuf::new(
1804 paths::remote_server_dir_relative().join(format!(
1805 "download-{}-{}",
1806 std::process::id(),
1807 src_path.file_name().unwrap().to_string_lossy()
1808 )),
1809 self.ssh_path_style,
1810 );
1811 self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
1812 .await?;
1813 self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
1814 .await?;
1815 return Ok(dst_path);
1816 }
1817
1818 if self
1819 .socket
1820 .run_command(&dst_path.to_string(), &["version"])
1821 .await
1822 .is_ok()
1823 {
1824 return Ok(dst_path);
1825 }
1826
1827 let wanted_version = cx.update(|cx| match release_channel {
1828 ReleaseChannel::Nightly => Ok(None),
1829 ReleaseChannel::Dev => {
1830 anyhow::bail!(
1831 "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
1832 dst_path
1833 )
1834 }
1835 _ => Ok(Some(AppVersion::global(cx))),
1836 })??;
1837
1838 let tmp_path_gz = RemotePathBuf::new(
1839 PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
1840 self.ssh_path_style,
1841 );
1842 if !self.socket.connection_options.upload_binary_over_ssh
1843 && let Some((url, body)) = delegate
1844 .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
1845 .await?
1846 {
1847 match self
1848 .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
1849 .await
1850 {
1851 Ok(_) => {
1852 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1853 .await?;
1854 return Ok(dst_path);
1855 }
1856 Err(e) => {
1857 log::error!(
1858 "Failed to download binary on server, attempting to upload server: {}",
1859 e
1860 )
1861 }
1862 }
1863 }
1864
1865 let src_path = delegate
1866 .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
1867 .await?;
1868 self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
1869 .await?;
1870 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
1871 .await?;
1872 Ok(dst_path)
1873 }
1874
1875 async fn download_binary_on_server(
1876 &self,
1877 url: &str,
1878 body: &str,
1879 tmp_path_gz: &RemotePathBuf,
1880 delegate: &Arc<dyn SshClientDelegate>,
1881 cx: &mut AsyncApp,
1882 ) -> Result<()> {
1883 if let Some(parent) = tmp_path_gz.parent() {
1884 self.socket
1885 .run_command(
1886 "sh",
1887 &[
1888 "-c",
1889 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1890 ],
1891 )
1892 .await?;
1893 }
1894
1895 delegate.set_status(Some("Downloading remote development server on host"), cx);
1896
1897 match self
1898 .socket
1899 .run_command(
1900 "curl",
1901 &[
1902 "-f",
1903 "-L",
1904 "-X",
1905 "GET",
1906 "-H",
1907 "Content-Type: application/json",
1908 "-d",
1909 body,
1910 url,
1911 "-o",
1912 &tmp_path_gz.to_string(),
1913 ],
1914 )
1915 .await
1916 {
1917 Ok(_) => {}
1918 Err(e) => {
1919 if self.socket.run_command("which", &["curl"]).await.is_ok() {
1920 return Err(e);
1921 }
1922
1923 match self
1924 .socket
1925 .run_command(
1926 "wget",
1927 &[
1928 "--method=GET",
1929 "--header=Content-Type: application/json",
1930 "--body-data",
1931 body,
1932 url,
1933 "-O",
1934 &tmp_path_gz.to_string(),
1935 ],
1936 )
1937 .await
1938 {
1939 Ok(_) => {}
1940 Err(e) => {
1941 if self.socket.run_command("which", &["wget"]).await.is_ok() {
1942 return Err(e);
1943 } else {
1944 anyhow::bail!("Neither curl nor wget is available");
1945 }
1946 }
1947 }
1948 }
1949 }
1950
1951 Ok(())
1952 }
1953
1954 async fn upload_local_server_binary(
1955 &self,
1956 src_path: &Path,
1957 tmp_path_gz: &RemotePathBuf,
1958 delegate: &Arc<dyn SshClientDelegate>,
1959 cx: &mut AsyncApp,
1960 ) -> Result<()> {
1961 if let Some(parent) = tmp_path_gz.parent() {
1962 self.socket
1963 .run_command(
1964 "sh",
1965 &[
1966 "-c",
1967 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1968 ],
1969 )
1970 .await?;
1971 }
1972
1973 let src_stat = fs::metadata(&src_path).await?;
1974 let size = src_stat.len();
1975
1976 let t0 = Instant::now();
1977 delegate.set_status(Some("Uploading remote development server"), cx);
1978 log::info!(
1979 "uploading remote development server to {:?} ({}kb)",
1980 tmp_path_gz,
1981 size / 1024
1982 );
1983 self.upload_file(src_path, tmp_path_gz)
1984 .await
1985 .context("failed to upload server binary")?;
1986 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1987 Ok(())
1988 }
1989
1990 async fn extract_server_binary(
1991 &self,
1992 dst_path: &RemotePathBuf,
1993 tmp_path: &RemotePathBuf,
1994 delegate: &Arc<dyn SshClientDelegate>,
1995 cx: &mut AsyncApp,
1996 ) -> Result<()> {
1997 delegate.set_status(Some("Extracting remote development server"), cx);
1998 let server_mode = 0o755;
1999
2000 let orig_tmp_path = tmp_path.to_string();
2001 let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2002 shell_script!(
2003 "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2004 server_mode = &format!("{:o}", server_mode),
2005 dst_path = &dst_path.to_string(),
2006 )
2007 } else {
2008 shell_script!(
2009 "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2010 server_mode = &format!("{:o}", server_mode),
2011 dst_path = &dst_path.to_string()
2012 )
2013 };
2014 self.socket.run_command("sh", &["-c", &script]).await?;
2015 Ok(())
2016 }
2017
2018 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2019 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2020 let mut command = util::command::new_smol_command("scp");
2021 let output = self
2022 .socket
2023 .ssh_options(&mut command)
2024 .args(
2025 self.socket
2026 .connection_options
2027 .port
2028 .map(|port| vec!["-P".to_string(), port.to_string()])
2029 .unwrap_or_default(),
2030 )
2031 .arg(src_path)
2032 .arg(format!(
2033 "{}:{}",
2034 self.socket.connection_options.scp_url(),
2035 dest_path
2036 ))
2037 .output()
2038 .await?;
2039
2040 anyhow::ensure!(
2041 output.status.success(),
2042 "failed to upload file {} -> {}: {}",
2043 src_path.display(),
2044 dest_path.to_string(),
2045 String::from_utf8_lossy(&output.stderr)
2046 );
2047 Ok(())
2048 }
2049
2050 #[cfg(debug_assertions)]
2051 async fn build_local(
2052 &self,
2053 build_remote_server: String,
2054 delegate: &Arc<dyn SshClientDelegate>,
2055 cx: &mut AsyncApp,
2056 ) -> Result<PathBuf> {
2057 use smol::process::{Command, Stdio};
2058 use std::env::VarError;
2059
2060 async fn run_cmd(command: &mut Command) -> Result<()> {
2061 let output = command
2062 .kill_on_drop(true)
2063 .stderr(Stdio::inherit())
2064 .output()
2065 .await?;
2066 anyhow::ensure!(
2067 output.status.success(),
2068 "Failed to run command: {command:?}"
2069 );
2070 Ok(())
2071 }
2072
2073 let use_musl = !build_remote_server.contains("nomusl");
2074 let triple = format!(
2075 "{}-{}",
2076 self.ssh_platform.arch,
2077 match self.ssh_platform.os {
2078 "linux" =>
2079 if use_musl {
2080 "unknown-linux-musl"
2081 } else {
2082 "unknown-linux-gnu"
2083 },
2084 "macos" => "apple-darwin",
2085 _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
2086 }
2087 );
2088 let mut rust_flags = match std::env::var("RUSTFLAGS") {
2089 Ok(val) => val,
2090 Err(VarError::NotPresent) => String::new(),
2091 Err(e) => {
2092 log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
2093 String::new()
2094 }
2095 };
2096 if self.ssh_platform.os == "linux" && use_musl {
2097 rust_flags.push_str(" -C target-feature=+crt-static");
2098 }
2099 if build_remote_server.contains("mold") {
2100 rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
2101 }
2102
2103 if self.ssh_platform.arch == std::env::consts::ARCH
2104 && self.ssh_platform.os == std::env::consts::OS
2105 {
2106 delegate.set_status(Some("Building remote server binary from source"), cx);
2107 log::info!("building remote server binary from source");
2108 run_cmd(
2109 Command::new("cargo")
2110 .args([
2111 "build",
2112 "--package",
2113 "remote_server",
2114 "--features",
2115 "debug-embed",
2116 "--target-dir",
2117 "target/remote_server",
2118 "--target",
2119 &triple,
2120 ])
2121 .env("RUSTFLAGS", &rust_flags),
2122 )
2123 .await?;
2124 } else if build_remote_server.contains("cross") {
2125 #[cfg(target_os = "windows")]
2126 use util::paths::SanitizedPath;
2127
2128 delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
2129 log::info!("installing cross");
2130 run_cmd(Command::new("cargo").args([
2131 "install",
2132 "cross",
2133 "--git",
2134 "https://github.com/cross-rs/cross",
2135 ]))
2136 .await?;
2137
2138 delegate.set_status(
2139 Some(&format!(
2140 "Building remote server binary from source for {} with Docker",
2141 &triple
2142 )),
2143 cx,
2144 );
2145 log::info!("building remote server binary from source for {}", &triple);
2146
2147 // On Windows, the binding needs to be set to the canonical path
2148 #[cfg(target_os = "windows")]
2149 let src =
2150 SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
2151 #[cfg(not(target_os = "windows"))]
2152 let src = "./target";
2153 run_cmd(
2154 Command::new("cross")
2155 .args([
2156 "build",
2157 "--package",
2158 "remote_server",
2159 "--features",
2160 "debug-embed",
2161 "--target-dir",
2162 "target/remote_server",
2163 "--target",
2164 &triple,
2165 ])
2166 .env(
2167 "CROSS_CONTAINER_OPTS",
2168 format!("--mount type=bind,src={src},dst=/app/target"),
2169 )
2170 .env("RUSTFLAGS", &rust_flags),
2171 )
2172 .await?;
2173 } else {
2174 let which = cx
2175 .background_spawn(async move { which::which("zig") })
2176 .await;
2177
2178 if which.is_err() {
2179 #[cfg(not(target_os = "windows"))]
2180 {
2181 anyhow::bail!(
2182 "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2183 )
2184 }
2185 #[cfg(target_os = "windows")]
2186 {
2187 anyhow::bail!(
2188 "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
2189 )
2190 }
2191 }
2192
2193 delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
2194 log::info!("adding rustup target");
2195 run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;
2196
2197 delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
2198 log::info!("installing cargo-zigbuild");
2199 run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
2200
2201 delegate.set_status(
2202 Some(&format!(
2203 "Building remote binary from source for {triple} with Zig"
2204 )),
2205 cx,
2206 );
2207 log::info!("building remote binary from source for {triple} with Zig");
2208 run_cmd(
2209 Command::new("cargo")
2210 .args([
2211 "zigbuild",
2212 "--package",
2213 "remote_server",
2214 "--features",
2215 "debug-embed",
2216 "--target-dir",
2217 "target/remote_server",
2218 "--target",
2219 &triple,
2220 ])
2221 .env("RUSTFLAGS", &rust_flags),
2222 )
2223 .await?;
2224 };
2225 let bin_path = Path::new("target")
2226 .join("remote_server")
2227 .join(&triple)
2228 .join("debug")
2229 .join("remote_server");
2230
2231 let path = if !build_remote_server.contains("nocompress") {
2232 delegate.set_status(Some("Compressing binary"), cx);
2233
2234 #[cfg(not(target_os = "windows"))]
2235 {
2236 run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
2237 }
2238 #[cfg(target_os = "windows")]
2239 {
2240 // On Windows, we use 7z to compress the binary
2241 let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
2242 let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
2243 if smol::fs::metadata(&gz_path).await.is_ok() {
2244 smol::fs::remove_file(&gz_path).await?;
2245 }
2246 run_cmd(Command::new(seven_zip).args([
2247 "a",
2248 "-tgzip",
2249 &gz_path,
2250 &bin_path.to_string_lossy(),
2251 ]))
2252 .await?;
2253 }
2254
2255 let mut archive_path = bin_path;
2256 archive_path.set_extension("gz");
2257 std::env::current_dir()?.join(archive_path)
2258 } else {
2259 bin_path
2260 };
2261
2262 Ok(path)
2263 }
2264}
2265
/// Pending responses, keyed by the id of the request awaiting them.
///
/// When a response arrives, the message loop sends the envelope together with
/// a fresh oneshot sender. It then waits on the paired receiver (until that
/// sender is completed or dropped) before processing the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2267
/// Proto client that exchanges [`Envelope`]s over in-process mpsc channels.
// NOTE(review): fields drop in declaration order; keep the order unless that
// is known not to matter.
struct ChannelClient {
    /// Source of ids for outgoing envelopes (incremented per message).
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Outgoing envelopes not yet acknowledged: pruned when an incoming
    /// `ack_id` covers them, and re-sent on `FlushBufferedMessages`.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Pending request/response correlation (see [`ResponseChannels`]).
    response_channels: ResponseChannels,
    /// Handlers for incoming (non-response) messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received (non-flush) envelope.
    max_received: AtomicU32,
    /// Label used in log messages.
    name: &'static str,
    /// The incoming-message handling task.
    task: Mutex<Task<Result<()>>>,
}
2278
2279impl ChannelClient {
2280 fn new(
2281 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2282 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2283 cx: &App,
2284 name: &'static str,
2285 ) -> Arc<Self> {
2286 Arc::new_cyclic(|this| Self {
2287 outgoing_tx: Mutex::new(outgoing_tx),
2288 next_message_id: AtomicU32::new(0),
2289 max_received: AtomicU32::new(0),
2290 response_channels: ResponseChannels::default(),
2291 message_handlers: Default::default(),
2292 buffer: Mutex::new(VecDeque::new()),
2293 name,
2294 task: Mutex::new(Self::start_handling_messages(
2295 this.clone(),
2296 incoming_rx,
2297 &cx.to_async(),
2298 )),
2299 })
2300 }
2301
2302 fn start_handling_messages(
2303 this: Weak<Self>,
2304 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2305 cx: &AsyncApp,
2306 ) -> Task<Result<()>> {
2307 cx.spawn(async move |cx| {
2308 let peer_id = PeerId { owner_id: 0, id: 0 };
2309 while let Some(incoming) = incoming_rx.next().await {
2310 let Some(this) = this.upgrade() else {
2311 return anyhow::Ok(());
2312 };
2313 if let Some(ack_id) = incoming.ack_id {
2314 let mut buffer = this.buffer.lock();
2315 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2316 buffer.pop_front();
2317 }
2318 }
2319 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2320 {
2321 log::debug!(
2322 "{}:ssh message received. name:FlushBufferedMessages",
2323 this.name
2324 );
2325 {
2326 let buffer = this.buffer.lock();
2327 for envelope in buffer.iter() {
2328 this.outgoing_tx
2329 .lock()
2330 .unbounded_send(envelope.clone())
2331 .ok();
2332 }
2333 }
2334 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2335 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2336 this.outgoing_tx.lock().unbounded_send(envelope).ok();
2337 continue;
2338 }
2339
2340 this.max_received.store(incoming.id, SeqCst);
2341
2342 if let Some(request_id) = incoming.responding_to {
2343 let request_id = MessageId(request_id);
2344 let sender = this.response_channels.lock().remove(&request_id);
2345 if let Some(sender) = sender {
2346 let (tx, rx) = oneshot::channel();
2347 if incoming.payload.is_some() {
2348 sender.send((incoming, tx)).ok();
2349 }
2350 rx.await.ok();
2351 }
2352 } else if let Some(envelope) =
2353 build_typed_envelope(peer_id, Instant::now(), incoming)
2354 {
2355 let type_name = envelope.payload_type_name();
2356 if let Some(future) = ProtoMessageHandlerSet::handle_message(
2357 &this.message_handlers,
2358 envelope,
2359 this.clone().into(),
2360 cx.clone(),
2361 ) {
2362 log::debug!("{}:ssh message received. name:{type_name}", this.name);
2363 cx.foreground_executor()
2364 .spawn(async move {
2365 match future.await {
2366 Ok(_) => {
2367 log::debug!(
2368 "{}:ssh message handled. name:{type_name}",
2369 this.name
2370 );
2371 }
2372 Err(error) => {
2373 log::error!(
2374 "{}:error handling message. type:{}, error:{}",
2375 this.name,
2376 type_name,
2377 format!("{error:#}").lines().fold(
2378 String::new(),
2379 |mut message, line| {
2380 if !message.is_empty() {
2381 message.push(' ');
2382 }
2383 message.push_str(line);
2384 message
2385 }
2386 )
2387 );
2388 }
2389 }
2390 })
2391 .detach()
2392 } else {
2393 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2394 }
2395 }
2396 }
2397 anyhow::Ok(())
2398 })
2399 }
2400
2401 fn reconnect(
2402 self: &Arc<Self>,
2403 incoming_rx: UnboundedReceiver<Envelope>,
2404 outgoing_tx: UnboundedSender<Envelope>,
2405 cx: &AsyncApp,
2406 ) {
2407 *self.outgoing_tx.lock() = outgoing_tx;
2408 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2409 }
2410
2411 fn request<T: RequestMessage>(
2412 &self,
2413 payload: T,
2414 ) -> impl 'static + Future<Output = Result<T::Response>> {
2415 self.request_internal(payload, true)
2416 }
2417
2418 fn request_internal<T: RequestMessage>(
2419 &self,
2420 payload: T,
2421 use_buffer: bool,
2422 ) -> impl 'static + Future<Output = Result<T::Response>> {
2423 log::debug!("ssh request start. name:{}", T::NAME);
2424 let response =
2425 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2426 async move {
2427 let response = response.await?;
2428 log::debug!("ssh request finish. name:{}", T::NAME);
2429 T::Response::from_envelope(response).context("received a response of the wrong type")
2430 }
2431 }
2432
2433 async fn resync(&self, timeout: Duration) -> Result<()> {
2434 smol::future::or(
2435 async {
2436 self.request_internal(proto::FlushBufferedMessages {}, false)
2437 .await?;
2438
2439 for envelope in self.buffer.lock().iter() {
2440 self.outgoing_tx
2441 .lock()
2442 .unbounded_send(envelope.clone())
2443 .ok();
2444 }
2445 Ok(())
2446 },
2447 async {
2448 smol::Timer::after(timeout).await;
2449 anyhow::bail!("Timed out resyncing remote client")
2450 },
2451 )
2452 .await
2453 }
2454
2455 async fn ping(&self, timeout: Duration) -> Result<()> {
2456 smol::future::or(
2457 async {
2458 self.request(proto::Ping {}).await?;
2459 Ok(())
2460 },
2461 async {
2462 smol::Timer::after(timeout).await;
2463 anyhow::bail!("Timed out pinging remote client")
2464 },
2465 )
2466 .await
2467 }
2468
2469 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2470 log::debug!("ssh send name:{}", T::NAME);
2471 self.send_dynamic(payload.into_envelope(0, None, None))
2472 }
2473
2474 fn request_dynamic(
2475 &self,
2476 mut envelope: proto::Envelope,
2477 type_name: &'static str,
2478 use_buffer: bool,
2479 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2480 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2481 let (tx, rx) = oneshot::channel();
2482 let mut response_channels_lock = self.response_channels.lock();
2483 response_channels_lock.insert(MessageId(envelope.id), tx);
2484 drop(response_channels_lock);
2485
2486 let result = if use_buffer {
2487 self.send_buffered(envelope)
2488 } else {
2489 self.send_unbuffered(envelope)
2490 };
2491 async move {
2492 if let Err(error) = &result {
2493 log::error!("failed to send message: {error}");
2494 anyhow::bail!("failed to send message: {error}");
2495 }
2496
2497 let response = rx.await.context("connection lost")?.0;
2498 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2499 return Err(RpcError::from_proto(error, type_name));
2500 }
2501 Ok(response)
2502 }
2503 }
2504
2505 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2506 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2507 self.send_buffered(envelope)
2508 }
2509
2510 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2511 envelope.ack_id = Some(self.max_received.load(SeqCst));
2512 self.buffer.lock().push_back(envelope.clone());
2513 // ignore errors on send (happen while we're reconnecting)
2514 // assume that the global "disconnected" overlay is sufficient.
2515 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2516 Ok(())
2517 }
2518
2519 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2520 envelope.ack_id = Some(self.max_received.load(SeqCst));
2521 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2522 Ok(())
2523 }
2524}
2525
2526impl ProtoClient for ChannelClient {
2527 fn request(
2528 &self,
2529 envelope: proto::Envelope,
2530 request_type: &'static str,
2531 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2532 self.request_dynamic(envelope, request_type, true).boxed()
2533 }
2534
2535 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2536 self.send_dynamic(envelope)
2537 }
2538
2539 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2540 self.send_dynamic(envelope)
2541 }
2542
2543 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2544 &self.message_handlers
2545 }
2546
2547 fn is_via_collab(&self) -> bool {
2548 false
2549 }
2550}
2551
/// In-process stand-in for an SSH connection, used by tests: instead of spawning `ssh`, it
/// pipes envelopes directly between the client and a server-side [`ChannelClient`].
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::{path::PathBuf, sync::Arc};

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        FutureExt, SinkExt, StreamExt,
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased,
    };
    use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
    use release_channel::ReleaseChannel;
    use rpc::proto::Envelope;
    use util::paths::{PathStyle, RemotePathBuf};

    use super::{
        ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
        SshPlatform,
    };

    /// A [`RemoteConnection`] whose "remote" side is `server_channel`, living in the same
    /// process.
    pub(super) struct FakeRemoteConnection {
        pub(super) connection_options: SshConnectionOptions,
        /// Server end of the fake link; rewired by `start_proxy` and `simulate_disconnect`.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Server-side app context, used when restarting `server_channel`'s message loop.
        pub(super) server_cx: SendableCx,
    }

    /// Wraps an `AsyncApp` and, through the `unsafe impl`s below, declares it `Send` and
    /// `Sync`. Sound only under the single-threaded test assumption stated in the SAFETY
    /// comments.
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        /// Returns a clone of the wrapped context. The `&AsyncApp` argument is unused and
        /// only serves as proof that the caller already holds an app context.
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    // SAFETY: same reasoning as the `Send` impl directly above.
    unsafe impl Sync for SendableCx {}

    #[async_trait(?Send)]
    impl RemoteConnection for FakeRemoteConnection {
        /// Nothing to tear down for the in-process fake.
        async fn kill(&self) -> Result<()> {
            Ok(())
        }

        /// The fake connection never reports itself as killed.
        fn has_been_killed(&self) -> bool {
            false
        }

        /// No real `ssh` process is involved, so there are no arguments or env vars.
        fn ssh_args(&self) -> SshArgs {
            SshArgs {
                arguments: Vec::new(),
                envs: None,
            }
        }

        /// Not supported by the fake; panics if called.
        fn upload_directory(
            &self,
            _src_path: PathBuf,
            _dest_path: RemotePathBuf,
            _cx: &App,
        ) -> Task<Result<()>> {
            unreachable!()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }

        /// Rewires the server channel to channels whose other ends are dropped immediately:
        /// the server's incoming stream is already closed (so its message loop ends) and
        /// anything it sends is discarded.
        fn simulate_disconnect(&self, cx: &AsyncApp) {
            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
            self.server_channel
                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
        }

        /// Rewires the server channel to fresh in-memory channels and spawns a background
        /// task that shuttles envelopes between client and server.
        ///
        /// Each server-to-client envelope also signals `connection_activity_tx`
        /// (best-effort). The task resolves to `Ok(1)` as soon as either direction's
        /// stream ends.
        fn start_proxy(
            &self,
            _unique_identifier: String,
            _reconnect: bool,
            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            _delegate: Arc<dyn SshClientDelegate>,
            cx: &mut AsyncApp,
        ) -> Task<Result<i32>> {
            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();

            self.server_channel.reconnect(
                server_incoming_rx,
                server_outgoing_tx,
                &self.server_cx.get(cx),
            );

            cx.background_spawn(async move {
                loop {
                    // Biased select: server-to-client traffic is polled before
                    // client-to-server traffic whenever both are ready.
                    select_biased! {
                        server_to_client = server_outgoing_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(1)
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_incoming_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_outgoing_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(1)
                            };
                            server_incoming_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }

        fn path_style(&self) -> PathStyle {
            PathStyle::current()
        }
    }

    /// Stub delegate for the fake connection: `set_status` is a no-op and every other
    /// method panics if called.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
}