1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
13 channel::{
14 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
15 oneshot,
16 },
17 future::{BoxFuture, Shared},
18 select, select_biased,
19};
20use gpui::{
21 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
22 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
23};
24use itertools::Itertools;
25use parking_lot::Mutex;
26
27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
28use rpc::{
29 AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
30 RpcError,
31 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
32};
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use smol::{
36 fs,
37 process::{self, Child, Stdio},
38};
39use std::{
40 any::TypeId,
41 collections::VecDeque,
42 fmt, iter,
43 ops::ControlFlow,
44 path::{Path, PathBuf},
45 sync::{
46 Arc, Weak,
47 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
48 },
49 time::{Duration, Instant},
50};
51use tempfile::TempDir;
52use util::{
53 ResultExt,
54 paths::{PathStyle, RemotePathBuf},
55};
56
/// Newtype around a `u64`.
///
/// NOTE(review): presumably the identifier of an SSH project; no use of this type is
/// visible in this part of the file — confirm against callers.
57#[derive(
58 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
59)]
60pub struct SshProjectId(pub u64);
61
/// Handle used to build `ssh` invocations for one set of connection options.
///
/// Besides the options themselves it carries the platform-specific piece that
/// `ssh_options` adds to every command: on non-Windows targets a path that is passed
/// as `-o ControlPath=…`, on Windows a set of environment variables (the `SSH_ASKPASS`
/// configuration assembled in `SshSocket::new`).
62#[derive(Clone)]
63pub struct SshSocket {
64 connection_options: SshConnectionOptions,
 // Passed to ssh as `ControlPath` (non-Windows only).
65 #[cfg(not(target_os = "windows"))]
66 socket_path: PathBuf,
 // Environment applied to every spawned ssh command (Windows only).
67 #[cfg(target_os = "windows")]
68 envs: HashMap<String, String>,
69}
70
/// One local port forward.
///
/// `SshConnectionOptions::additional_args` renders it as
/// `-L{local_host}:{local_port}:{remote_host}:{remote_port}`, substituting
/// `localhost` for a host that is `None`. Hosts that are `None` are omitted when
/// serializing.
71#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
72pub struct SshPortForwardOption {
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub local_host: Option<String>,
75 pub local_port: u16,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub remote_host: Option<String>,
78 pub remote_port: u16,
79}
80
/// Everything needed to address an SSH host, as produced by
/// `SshConnectionOptions::parse_command_line` or filled in by callers.
81#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
82pub struct SshConnectionOptions {
 // Destination host name (without user or port).
83 pub host: String,
 // Login name; rendered percent-encoded in `ssh_url`.
84 pub username: Option<String>,
85 pub port: Option<u16>,
 // Never set by `parse_command_line` (always `None` there).
86 pub password: Option<String>,
 // Extra ssh flags that `parse_command_line` accepted verbatim.
87 pub args: Option<Vec<String>>,
 // `-L` forwards; `None` when no forward was requested.
88 pub port_forwards: Option<Vec<SshPortForwardOption>>,
89
 // NOTE(review): presumably a user-facing display name — not read in this part of the file.
90 pub nickname: Option<String>,
 // NOTE(review): presumably selects uploading the server binary instead of downloading it
 // on the remote — not read in this part of the file; confirm.
91 pub upload_binary_over_ssh: bool,
92}
93
/// Arguments (and optional environment) for invoking `ssh` against a connection,
/// as produced by `SshSocket::ssh_args`.
94pub struct SshArgs {
 // Command-line arguments; `ssh_args` puts the `ssh://` URL last.
95 pub arguments: Vec<String>,
 // Extra environment variables; `Some` on Windows, `None` elsewhere (see `ssh_args`).
96 pub envs: Option<HashMap<String, String>>,
97}
98
/// Formats `$fmt` with named arguments, shell-quoting every argument value with
/// `shlex::try_quote` before it is substituted.
///
/// At least one `name = value` pair is required. Panics (via `unwrap`) if
/// `shlex::try_quote` rejects a value.
99#[macro_export]
100macro_rules! shell_script {
101 ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
102 format!(
103 $fmt,
104 $(
105 $name = shlex::try_quote($arg).unwrap()
106 ),+
107 )
108 }};
109}
110
111fn parse_port_number(port_str: &str) -> Result<u16> {
112 port_str
113 .parse()
114 .with_context(|| format!("parsing port number: {port_str}"))
115}
116
117fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
118 let parts: Vec<&str> = spec.split(':').collect();
119
120 match parts.len() {
121 4 => {
122 let local_port = parse_port_number(parts[1])?;
123 let remote_port = parse_port_number(parts[3])?;
124
125 Ok(SshPortForwardOption {
126 local_host: Some(parts[0].to_string()),
127 local_port,
128 remote_host: Some(parts[2].to_string()),
129 remote_port,
130 })
131 }
132 3 => {
133 let local_port = parse_port_number(parts[0])?;
134 let remote_port = parse_port_number(parts[2])?;
135
136 Ok(SshPortForwardOption {
137 local_host: None,
138 local_port,
139 remote_host: Some(parts[1].to_string()),
140 remote_port,
141 })
142 }
143 _ => anyhow::bail!("Invalid port forward format"),
144 }
145}
146
147impl SshConnectionOptions {
148 pub fn parse_command_line(input: &str) -> Result<Self> {
149 let input = input.trim_start_matches("ssh ");
150 let mut hostname: Option<String> = None;
151 let mut username: Option<String> = None;
152 let mut port: Option<u16> = None;
153 let mut args = Vec::new();
154 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
155
156 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
157 const ALLOWED_OPTS: &[&str] = &[
158 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
159 ];
160 const ALLOWED_ARGS: &[&str] = &[
161 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
162 "-w",
163 ];
164
165 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
166
167 'outer: while let Some(arg) = tokens.next() {
168 if ALLOWED_OPTS.contains(&(&arg as &str)) {
169 args.push(arg.to_string());
170 continue;
171 }
172 if arg == "-p" {
173 port = tokens.next().and_then(|arg| arg.parse().ok());
174 continue;
175 } else if let Some(p) = arg.strip_prefix("-p") {
176 port = p.parse().ok();
177 continue;
178 }
179 if arg == "-l" {
180 username = tokens.next();
181 continue;
182 } else if let Some(l) = arg.strip_prefix("-l") {
183 username = Some(l.to_string());
184 continue;
185 }
186 if arg == "-L" || arg.starts_with("-L") {
187 let forward_spec = if arg == "-L" {
188 tokens.next()
189 } else {
190 Some(arg.strip_prefix("-L").unwrap().to_string())
191 };
192
193 if let Some(spec) = forward_spec {
194 port_forwards.push(parse_port_forward_spec(&spec)?);
195 } else {
196 anyhow::bail!("Missing port forward format");
197 }
198 }
199
200 for a in ALLOWED_ARGS {
201 if arg == *a {
202 args.push(arg);
203 if let Some(next) = tokens.next() {
204 args.push(next);
205 }
206 continue 'outer;
207 } else if arg.starts_with(a) {
208 args.push(arg);
209 continue 'outer;
210 }
211 }
212 if arg.starts_with("-") || hostname.is_some() {
213 anyhow::bail!("unsupported argument: {:?}", arg);
214 }
215 let mut input = &arg as &str;
216 // Destination might be: username1@username2@ip2@ip1
217 if let Some((u, rest)) = input.rsplit_once('@') {
218 input = rest;
219 username = Some(u.to_string());
220 }
221 if let Some((rest, p)) = input.split_once(':') {
222 input = rest;
223 port = p.parse().ok()
224 }
225 hostname = Some(input.to_string())
226 }
227
228 let Some(hostname) = hostname else {
229 anyhow::bail!("missing hostname");
230 };
231
232 let port_forwards = match port_forwards.len() {
233 0 => None,
234 _ => Some(port_forwards),
235 };
236
237 Ok(Self {
238 host: hostname.to_string(),
239 username: username.clone(),
240 port,
241 port_forwards,
242 args: Some(args),
243 password: None,
244 nickname: None,
245 upload_binary_over_ssh: false,
246 })
247 }
248
249 pub fn ssh_url(&self) -> String {
250 let mut result = String::from("ssh://");
251 if let Some(username) = &self.username {
252 // Username might be: username1@username2@ip2
253 let username = urlencoding::encode(username);
254 result.push_str(&username);
255 result.push('@');
256 }
257 result.push_str(&self.host);
258 if let Some(port) = self.port {
259 result.push(':');
260 result.push_str(&port.to_string());
261 }
262 result
263 }
264
265 pub fn additional_args(&self) -> Vec<String> {
266 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
267
268 if let Some(forwards) = &self.port_forwards {
269 args.extend(forwards.iter().map(|pf| {
270 let local_host = match &pf.local_host {
271 Some(host) => host,
272 None => "localhost",
273 };
274 let remote_host = match &pf.remote_host {
275 Some(host) => host,
276 None => "localhost",
277 };
278
279 format!(
280 "-L{}:{}:{}:{}",
281 local_host, pf.local_port, remote_host, pf.remote_port
282 )
283 }));
284 }
285
286 args
287 }
288
289 fn scp_url(&self) -> String {
290 if let Some(username) = &self.username {
291 format!("{}@{}", username, self.host)
292 } else {
293 self.host.clone()
294 }
295 }
296
297 pub fn connection_string(&self) -> String {
298 let host = if let Some(username) = &self.username {
299 format!("{}@{}", username, self.host)
300 } else {
301 self.host.clone()
302 };
303 if let Some(port) = &self.port {
304 format!("{}:{}", host, port)
305 } else {
306 host
307 }
308 }
309}
310
/// Operating system and CPU architecture of a remote host.
///
/// `SshSocket::platform` fills `os` with `"macos"` or `"linux"` and `arch` with
/// `"aarch64"` or `"x86_64"`.
311#[derive(Copy, Clone, Debug)]
312pub struct SshPlatform {
313 pub os: &'static str,
314 pub arch: &'static str,
315}
316
/// Callbacks through which the SSH client talks to its host application.
///
/// NOTE(review): no implementor or call site is visible in this part of the file;
/// the per-method notes below are read off the signatures and should be confirmed.
317pub trait SshClientDelegate: Send + Sync {
 // Presumably asks the user for a password using `prompt`; the answer is sent on `tx`.
318 fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
 // Resolves to an optional pair of strings for the given platform/channel/version.
 // NOTE(review): the meaning of the two strings is not visible here — confirm.
319 fn get_download_params(
320 &self,
321 platform: SshPlatform,
322 release_channel: ReleaseChannel,
323 version: Option<SemanticVersion>,
324 cx: &mut AsyncApp,
325 ) -> Task<Result<Option<(String, String)>>>;
326
 // Resolves to a local path; by its name, the downloaded server binary.
327 fn download_server_binary_locally(
328 &self,
329 platform: SshPlatform,
330 release_channel: ReleaseChannel,
331 version: Option<SemanticVersion>,
332 cx: &mut AsyncApp,
333 ) -> Task<Result<PathBuf>>;
 // Presumably reports a progress/status message; `None` likely clears it — confirm.
334 fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
335}
336
337impl SshSocket {
338 #[cfg(not(target_os = "windows"))]
339 fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
340 Ok(Self {
341 connection_options: options,
342 socket_path,
343 })
344 }
345
346 #[cfg(target_os = "windows")]
347 fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
348 let askpass_script = temp_dir.path().join("askpass.bat");
349 std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
350 let mut envs = HashMap::default();
351 envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
352 envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
353 envs.insert("ZED_SSH_ASKPASS".into(), secret);
354 Ok(Self {
355 connection_options: options,
356 envs,
357 })
358 }
359
360 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
361 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
362 // and passes -l as an argument to sh, not to ls.
363 // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
364 // into a machine. You must use `cd` to get back to $HOME.
365 // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
366 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
367 let mut command = util::command::new_smol_command("ssh");
368 let to_run = iter::once(&program)
369 .chain(args.iter())
370 .map(|token| {
371 // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
372 debug_assert!(
373 !token.contains('\n'),
374 "multiline arguments do not work in all shells"
375 );
376 shlex::try_quote(token).unwrap()
377 })
378 .join(" ");
379 let to_run = format!("cd; {to_run}");
380 log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
381 self.ssh_options(&mut command)
382 .arg(self.connection_options.ssh_url())
383 .arg(to_run);
384 command
385 }
386
387 async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
388 let output = self.ssh_command(program, args).output().await?;
389 anyhow::ensure!(
390 output.status.success(),
391 "failed to run command: {}",
392 String::from_utf8_lossy(&output.stderr)
393 );
394 Ok(String::from_utf8_lossy(&output.stdout).to_string())
395 }
396
397 #[cfg(not(target_os = "windows"))]
398 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
399 command
400 .stdin(Stdio::piped())
401 .stdout(Stdio::piped())
402 .stderr(Stdio::piped())
403 .args(self.connection_options.additional_args())
404 .args(["-o", "ControlMaster=no", "-o"])
405 .arg(format!("ControlPath={}", self.socket_path.display()))
406 }
407
408 #[cfg(target_os = "windows")]
409 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
410 command
411 .stdin(Stdio::piped())
412 .stdout(Stdio::piped())
413 .stderr(Stdio::piped())
414 .args(self.connection_options.additional_args())
415 .envs(self.envs.clone())
416 }
417
418 // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
419 // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
420 #[cfg(not(target_os = "windows"))]
421 fn ssh_args(&self) -> SshArgs {
422 let mut arguments = self.connection_options.additional_args();
423 arguments.extend(vec![
424 "-o".to_string(),
425 "ControlMaster=no".to_string(),
426 "-o".to_string(),
427 format!("ControlPath={}", self.socket_path.display()),
428 self.connection_options.ssh_url(),
429 ]);
430 SshArgs {
431 arguments,
432 envs: None,
433 }
434 }
435
436 #[cfg(target_os = "windows")]
437 fn ssh_args(&self) -> SshArgs {
438 let mut arguments = self.connection_options.additional_args();
439 arguments.push(self.connection_options.ssh_url());
440 SshArgs {
441 arguments,
442 envs: Some(self.envs.clone()),
443 }
444 }
445
446 async fn platform(&self) -> Result<SshPlatform> {
447 let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
448 let Some((os, arch)) = uname.split_once(" ") else {
449 anyhow::bail!("unknown uname: {uname:?}")
450 };
451
452 let os = match os.trim() {
453 "Darwin" => "macos",
454 "Linux" => "linux",
455 _ => anyhow::bail!(
456 "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
457 ),
458 };
459 // exclude armv5,6,7 as they are 32-bit.
460 let arch = if arch.starts_with("armv8")
461 || arch.starts_with("armv9")
462 || arch.starts_with("arm64")
463 || arch.starts_with("aarch64")
464 {
465 "aarch64"
466 } else if arch.starts_with("x86") {
467 "x86_64"
468 } else {
469 anyhow::bail!(
470 "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
471 )
472 };
473
474 Ok(SshPlatform { os, arch })
475 }
476}
477
// Number of consecutive missed heartbeats at which `handle_heartbeat_result`
// starts a reconnect.
478const MAX_MISSED_HEARTBEATS: usize = 5;
// Delay of the keepalive timer between heartbeat rounds (see `heartbeat`).
479const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
// Timeout handed to the client's `ping` and `resync` calls.
480const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
481
// Number of reconnect attempts after which `reconnect` gives up and moves to
// `State::ReconnectExhausted`.
482const MAX_RECONNECT_ATTEMPTS: usize = 3;
483
/// Internal connection state machine of `SshRemoteClient`.
///
/// The variants that hold an `ssh_connection` (`Connected`, `HeartbeatMissed`,
/// `ReconnectFailed`) are exactly the ones from which `reconnect` may start
/// (see `State::can_reconnect`).
484enum State {
 // Initial state set by `SshRemoteClient::new` until the first ping succeeds.
485 Connecting,
 // Healthy connection; owns the tasks that keep it alive.
486 Connected {
487 ssh_connection: Arc<dyn RemoteConnection>,
488 delegate: Arc<dyn SshClientDelegate>,
489
 // Task returned by `SshRemoteClient::monitor`.
490 multiplex_task: Task<Result<()>>,
 // Task returned by `SshRemoteClient::heartbeat`.
491 heartbeat_task: Task<Result<()>>,
492 },
 // Like `Connected`, but at least one heartbeat went unanswered.
493 HeartbeatMissed {
 // Count maintained by `State::heartbeat_missed` (starts at 1).
494 missed_heartbeats: usize,
495
496 ssh_connection: Arc<dyn RemoteConnection>,
497 delegate: Arc<dyn SshClientDelegate>,
498
499 multiplex_task: Task<Result<()>>,
500 heartbeat_task: Task<Result<()>>,
501 },
 // A reconnect attempt is in flight.
502 Reconnecting,
 // The last reconnect attempt failed; `attempts` is the number made so far.
503 ReconnectFailed {
504 ssh_connection: Arc<dyn RemoteConnection>,
505 delegate: Arc<dyn SshClientDelegate>,
506
507 error: anyhow::Error,
508 attempts: usize,
509 },
 // More than `MAX_RECONNECT_ATTEMPTS` attempts were needed; terminal.
510 ReconnectExhausted,
 // The proxy reported `ProxyLaunchError::ServerNotRunning`; terminal.
511 ServerNotRunning,
512}
513
514impl fmt::Display for State {
515 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
516 match self {
517 Self::Connecting => write!(f, "connecting"),
518 Self::Connected { .. } => write!(f, "connected"),
519 Self::Reconnecting => write!(f, "reconnecting"),
520 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
521 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
522 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
523 Self::ServerNotRunning { .. } => write!(f, "server not running"),
524 }
525 }
526}
527
528impl State {
529 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
530 match self {
531 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
532 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
533 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
534 _ => None,
535 }
536 }
537
538 fn can_reconnect(&self) -> bool {
539 match self {
540 Self::Connected { .. }
541 | Self::HeartbeatMissed { .. }
542 | Self::ReconnectFailed { .. } => true,
543 State::Connecting
544 | State::Reconnecting
545 | State::ReconnectExhausted
546 | State::ServerNotRunning => false,
547 }
548 }
549
550 fn is_reconnect_failed(&self) -> bool {
551 matches!(self, Self::ReconnectFailed { .. })
552 }
553
554 fn is_reconnect_exhausted(&self) -> bool {
555 matches!(self, Self::ReconnectExhausted { .. })
556 }
557
558 fn is_server_not_running(&self) -> bool {
559 matches!(self, Self::ServerNotRunning)
560 }
561
562 fn is_reconnecting(&self) -> bool {
563 matches!(self, Self::Reconnecting { .. })
564 }
565
566 fn heartbeat_recovered(self) -> Self {
567 match self {
568 Self::HeartbeatMissed {
569 ssh_connection,
570 delegate,
571 multiplex_task,
572 heartbeat_task,
573 ..
574 } => Self::Connected {
575 ssh_connection,
576 delegate,
577 multiplex_task,
578 heartbeat_task,
579 },
580 _ => self,
581 }
582 }
583
584 fn heartbeat_missed(self) -> Self {
585 match self {
586 Self::Connected {
587 ssh_connection,
588 delegate,
589 multiplex_task,
590 heartbeat_task,
591 } => Self::HeartbeatMissed {
592 missed_heartbeats: 1,
593 ssh_connection,
594 delegate,
595 multiplex_task,
596 heartbeat_task,
597 },
598 Self::HeartbeatMissed {
599 missed_heartbeats,
600 ssh_connection,
601 delegate,
602 multiplex_task,
603 heartbeat_task,
604 } => Self::HeartbeatMissed {
605 missed_heartbeats: missed_heartbeats + 1,
606 ssh_connection,
607 delegate,
608 multiplex_task,
609 heartbeat_task,
610 },
611 _ => self,
612 }
613 }
614}
615
616/// The state of the ssh connection.
///
/// This is the public, payload-free view of the internal `State`; see the
/// `From<&State>` impl for the mapping.
617#[derive(Clone, Copy, Debug, PartialEq, Eq)]
618pub enum ConnectionState {
619 Connecting,
620 Connected,
621 HeartbeatMissed,
 // Also reported while the internal state is `ReconnectFailed`.
622 Reconnecting,
 // Reported for `ReconnectExhausted`, `ServerNotRunning`, and when no state is set.
623 Disconnected,
624}
625
626impl From<&State> for ConnectionState {
627 fn from(value: &State) -> Self {
628 match value {
629 State::Connecting => Self::Connecting,
630 State::Connected { .. } => Self::Connected,
631 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
632 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
633 State::ReconnectExhausted => Self::Disconnected,
634 State::ServerNotRunning => Self::Disconnected,
635 }
636 }
637}
638
/// Client side of a remote SSH connection: owns the message channel client and
/// the connection state machine.
639pub struct SshRemoteClient {
 // Message channel to the remote; kept across reconnects (see `reconnect`).
640 client: Arc<ChannelClient>,
 // String form of the `ConnectionIdentifier`, passed to `start_proxy`.
641 unique_identifier: String,
642 connection_options: SshConnectionOptions,
 // Path style reported by the connection when it was first established.
643 path_style: PathStyle,
 // `None` only transiently while a method has taken the state, or after
 // `shutdown_processes`.
644 state: Arc<Mutex<Option<State>>>,
645}
646
/// Events emitted by `SshRemoteClient`.
647#[derive(Debug)]
648pub enum SshRemoteEvent {
 // Emitted by `set_state` when the state becomes `ReconnectExhausted` or
 // `ServerNotRunning`.
649 Disconnected,
650}
651
652impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
653
654// Identifies the socket on the remote server so that reconnects
655// can re-join the same project.
656pub enum ConnectionIdentifier {
 // Process-local id handed out by `ConnectionIdentifier::setup`.
657 Setup(u64),
 // NOTE(review): presumably a persisted workspace id — confirm against callers.
658 Workspace(i64),
659}
660
// Source of the ids used by `ConnectionIdentifier::setup`; starts at 1.
661static NEXT_ID: AtomicU64 = AtomicU64::new(1);
662
663impl ConnectionIdentifier {
664 pub fn setup() -> Self {
665 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
666 }
667 // This string gets used in a socket name, and so must be relatively short.
668 // The total length of:
669 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
670 // Must be less than about 100 characters
671 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
672 // So our strings should be at most 20 characters or so.
673 fn to_string(&self, cx: &App) -> String {
674 let identifier_prefix = match ReleaseChannel::global(cx) {
675 ReleaseChannel::Stable => "".to_string(),
676 release_channel => format!("{}-", release_channel.dev_name()),
677 };
678 match self {
679 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
680 Self::Workspace(workspace_id) => {
681 format!("{identifier_prefix}workspace-{workspace_id}",)
682 }
683 }
684 }
685}
686
687impl SshRemoteClient {
/// Connects to the remote described by `connection_options` and returns the
/// resulting client entity.
///
/// The returned task resolves to `Ok(None)` if `cancellation` completes before the
/// connection is established, to `Ok(Some(entity))` on success, and to an error if
/// obtaining the connection from the `ConnectionPool` or the initial ping fails.
688 pub fn new(
689 unique_identifier: ConnectionIdentifier,
690 connection_options: SshConnectionOptions,
691 cancellation: oneshot::Receiver<()>,
692 delegate: Arc<dyn SshClientDelegate>,
693 cx: &mut App,
694 ) -> Task<Result<Option<Entity<Self>>>> {
695 let unique_identifier = unique_identifier.to_string(cx);
696 cx.spawn(async move |cx| {
697 let success = Box::pin(async move {
 // Channels between the `ChannelClient` and the proxy started below, plus a
 // bounded channel the proxy is given for the heartbeat task to listen on.
698 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
699 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
700 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
701
702 let client =
703 cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
704
 // Obtain the connection from the global pool (created on first use).
705 let ssh_connection = cx
706 .update(|cx| {
707 cx.update_default_global(|pool: &mut ConnectionPool, cx| {
708 pool.connect(connection_options.clone(), &delegate, cx)
709 })
710 })?
711 .await
712 .map_err(|e| e.cloned())?;
713
714 let path_style = ssh_connection.path_style();
715 let this = cx.new(|_| Self {
716 client: client.clone(),
717 unique_identifier: unique_identifier.clone(),
718 connection_options,
719 path_style,
720 state: Arc::new(Mutex::new(Some(State::Connecting))),
721 })?;
722
 // The `false` is the second argument of `start_proxy`; `reconnect` passes `true`.
723 let io_task = ssh_connection.start_proxy(
724 unique_identifier,
725 false,
726 incoming_tx,
727 outgoing_rx,
728 connection_activity_tx,
729 delegate.clone(),
730 cx,
731 );
732
733 let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
734
 // Only report success once the remote has answered a ping.
735 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
736 log::error!("failed to establish connection: {}", error);
737 return Err(error);
738 }
739
740 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);
741
742 this.update(cx, |this, _| {
743 *this.state.lock() = Some(State::Connected {
744 ssh_connection,
745 delegate,
746 multiplex_task,
747 heartbeat_task,
748 });
749 })?;
750
751 Ok(Some(this))
752 });
753
 // Whichever finishes first wins: cancellation yields `Ok(None)`.
754 select! {
755 _ = cancellation.fuse() => {
756 Ok(None)
757 }
758 result = success.fuse() => result
759 }
760 })
761 }
762
763 pub fn shutdown_processes<T: RequestMessage>(
764 &self,
765 shutdown_request: Option<T>,
766 executor: BackgroundExecutor,
767 ) -> Option<impl Future<Output = ()> + use<T>> {
768 let state = self.state.lock().take()?;
769 log::info!("shutting down ssh processes");
770
771 let State::Connected {
772 multiplex_task,
773 heartbeat_task,
774 ssh_connection,
775 delegate,
776 } = state
777 else {
778 return None;
779 };
780
781 let client = self.client.clone();
782
783 Some(async move {
784 if let Some(shutdown_request) = shutdown_request {
785 client.send(shutdown_request).log_err();
786 // We wait 50ms instead of waiting for a response, because
787 // waiting for a response would require us to wait on the main thread
788 // which we want to avoid in an `on_app_quit` callback.
789 executor.timer(Duration::from_millis(50)).await;
790 }
791
792 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
793 // child of master_process.
794 drop(multiplex_task);
795 // Now drop the rest of state, which kills master process.
796 drop(heartbeat_task);
797 drop(ssh_connection);
798 drop(delegate);
799 })
800 }
801
/// Starts a reconnect attempt.
///
/// Returns an error, leaving the state untouched, when there is no state or the
/// current state does not allow reconnecting (`State::can_reconnect`). Otherwise
/// the state is taken, any running monitor/heartbeat tasks are dropped, and:
///
/// * if this would exceed `MAX_RECONNECT_ATTEMPTS`, the state becomes
///   `ReconnectExhausted`;
/// * otherwise the state becomes `Reconnecting` and two tasks are spawned. The
///   first kills the old connection, obtains a new one from the `ConnectionPool`,
///   starts the proxy and resyncs the client, producing the next state. The
///   second installs that state (only if the client is still `Reconnecting`) and
///   calls `reconnect` again when the attempt ended in `ReconnectFailed`.
802 fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
803 let mut lock = self.state.lock();
804
805 let can_reconnect = lock
806 .as_ref()
807 .map(|state| state.can_reconnect())
808 .unwrap_or(false);
809 if !can_reconnect {
810 log::info!("aborting reconnect, because not in state that allows reconnecting");
811 let error = if let Some(state) = lock.as_ref() {
812 format!("invalid state, cannot reconnect while in state {state}")
813 } else {
814 "no state set".to_string()
815 };
816 anyhow::bail!(error);
817 }
818
 // `can_reconnect` returned true, so a state is present and is one of the
 // three connection-carrying variants.
819 let state = lock.take().unwrap();
820 let (attempts, ssh_connection, delegate) = match state {
821 State::Connected {
822 ssh_connection,
823 delegate,
824 multiplex_task,
825 heartbeat_task,
826 }
827 | State::HeartbeatMissed {
828 ssh_connection,
829 delegate,
830 multiplex_task,
831 heartbeat_task,
832 ..
833 } => {
 // Coming from a live connection: stop its tasks and start counting from zero.
834 drop(multiplex_task);
835 drop(heartbeat_task);
836 (0, ssh_connection, delegate)
837 }
838 State::ReconnectFailed {
839 attempts,
840 ssh_connection,
841 delegate,
842 ..
843 } => (attempts, ssh_connection, delegate),
844 State::Connecting
845 | State::Reconnecting
846 | State::ReconnectExhausted
847 | State::ServerNotRunning => unreachable!(),
848 };
849
850 let attempts = attempts + 1;
851 if attempts > MAX_RECONNECT_ATTEMPTS {
852 log::error!(
853 "Failed to reconnect to after {} attempts, giving up",
854 MAX_RECONNECT_ATTEMPTS
855 );
 // `set_state` locks the state itself, so release our guard first.
856 drop(lock);
857 self.set_state(State::ReconnectExhausted, cx);
858 return Ok(());
859 }
860 drop(lock);
861
862 self.set_state(State::Reconnecting, cx);
863
864 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
865
866 let unique_identifier = self.unique_identifier.clone();
867 let client = self.client.clone();
 // First task: performs the attempt and resolves to the state it ended in.
868 let reconnect_task = cx.spawn(async move |this, cx| {
 // Early-returns `State::ReconnectFailed` from the enclosing async block.
869 macro_rules! failed {
870 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
871 return State::ReconnectFailed {
872 error: anyhow!($error),
873 attempts: $attempts,
874 ssh_connection: $ssh_connection,
875 delegate: $delegate,
876 };
877 };
878 }
879
880 if let Err(error) = ssh_connection
881 .kill()
882 .await
883 .context("Failed to kill ssh process")
884 {
885 failed!(error, attempts, ssh_connection, delegate);
886 };
887
888 let connection_options = ssh_connection.connection_options();
889
 // Fresh channels; the existing `ChannelClient` is re-attached to them below.
890 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
891 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
892 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
893
894 let (ssh_connection, io_task) = match async {
895 let ssh_connection = cx
896 .update_global(|pool: &mut ConnectionPool, cx| {
897 pool.connect(connection_options, &delegate, cx)
898 })?
899 .await
900 .map_err(|error| error.cloned())?;
901
 // The `true` is the second argument of `start_proxy`; `new` passes `false`.
902 let io_task = ssh_connection.start_proxy(
903 unique_identifier,
904 true,
905 incoming_tx,
906 outgoing_rx,
907 connection_activity_tx,
908 delegate.clone(),
909 cx,
910 );
911 anyhow::Ok((ssh_connection, io_task))
912 }
913 .await
914 {
915 Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
916 Err(error) => {
 // On failure the previous connection handle is kept in the failed state.
917 failed!(error, attempts, ssh_connection, delegate);
918 }
919 };
920
921 let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
922 client.reconnect(incoming_rx, outgoing_tx, &cx);
923
924 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
925 failed!(error, attempts, ssh_connection, delegate);
926 };
927
928 State::Connected {
929 ssh_connection,
930 delegate,
931 multiplex_task,
932 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
933 }
934 });
935
 // Second task: installs the outcome and retries after a failed attempt.
936 cx.spawn(async move |this, cx| {
937 let new_state = reconnect_task.await;
938 this.update(cx, |this, cx| {
939 this.try_set_state(cx, |old_state| {
 // Only overwrite the state if nothing else changed it in the meantime.
940 if old_state.is_reconnecting() {
941 match &new_state {
942 State::Connecting
943 | State::Reconnecting { .. }
944 | State::HeartbeatMissed { .. }
945 | State::ServerNotRunning => {}
946 State::Connected { .. } => {
947 log::info!("Successfully reconnected");
948 }
949 State::ReconnectFailed {
950 error, attempts, ..
951 } => {
952 log::error!(
953 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
954 attempts,
955 error
956 );
957 }
958 State::ReconnectExhausted => {
959 log::error!("Reconnect attempt failed and all attempts exhausted");
960 }
961 }
962 Some(new_state)
963 } else {
964 None
965 }
966 });
967
 // NOTE(review): this closure's `Result` ends up nested inside the `Result` of
 // `update`; `detach_and_log_err` presumably only logs the outer one — confirm.
968 if this.state_is(State::is_reconnect_failed) {
969 this.reconnect(cx)
970 } else if this.state_is(State::is_reconnect_exhausted) {
971 Ok(())
972 } else {
973 log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
974 Ok(())
975 }
976 })
977 })
978 .detach_and_log_err(cx);
979
980 Ok(())
981 }
982
983 fn heartbeat(
984 this: WeakEntity<Self>,
985 mut connection_activity_rx: mpsc::Receiver<()>,
986 cx: &mut AsyncApp,
987 ) -> Task<Result<()>> {
988 let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
989 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
990 };
991
992 cx.spawn(async move |cx| {
993 let mut missed_heartbeats = 0;
994
995 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
996 futures::pin_mut!(keepalive_timer);
997
998 loop {
999 select_biased! {
1000 result = connection_activity_rx.next().fuse() => {
1001 if result.is_none() {
1002 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
1003 return Ok(());
1004 }
1005
1006 if missed_heartbeats != 0 {
1007 missed_heartbeats = 0;
1008 let _ =this.update(cx, |this, mut cx| {
1009 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1010 })?;
1011 }
1012 }
1013 _ = keepalive_timer => {
1014 log::debug!("Sending heartbeat to server...");
1015
1016 let result = select_biased! {
1017 _ = connection_activity_rx.next().fuse() => {
1018 Ok(())
1019 }
1020 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
1021 ping_result
1022 }
1023 };
1024
1025 if result.is_err() {
1026 missed_heartbeats += 1;
1027 log::warn!(
1028 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
1029 HEARTBEAT_TIMEOUT,
1030 missed_heartbeats,
1031 MAX_MISSED_HEARTBEATS
1032 );
1033 } else if missed_heartbeats != 0 {
1034 missed_heartbeats = 0;
1035 } else {
1036 continue;
1037 }
1038
1039 let result = this.update(cx, |this, mut cx| {
1040 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
1041 })?;
1042 if result.is_break() {
1043 return Ok(());
1044 }
1045 }
1046 }
1047
1048 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
1049 }
1050
1051 })
1052 }
1053
1054 fn handle_heartbeat_result(
1055 &mut self,
1056 missed_heartbeats: usize,
1057 cx: &mut Context<Self>,
1058 ) -> ControlFlow<()> {
1059 let state = self.state.lock().take().unwrap();
1060 let next_state = if missed_heartbeats > 0 {
1061 state.heartbeat_missed()
1062 } else {
1063 state.heartbeat_recovered()
1064 };
1065
1066 self.set_state(next_state, cx);
1067
1068 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1069 log::error!(
1070 "Missed last {} heartbeats. Reconnecting...",
1071 missed_heartbeats
1072 );
1073
1074 self.reconnect(cx)
1075 .context("failed to start reconnect process after missing heartbeats")
1076 .log_err();
1077 ControlFlow::Break(())
1078 } else {
1079 ControlFlow::Continue(())
1080 }
1081 }
1082
1083 fn monitor(
1084 this: WeakEntity<Self>,
1085 io_task: Task<Result<i32>>,
1086 cx: &AsyncApp,
1087 ) -> Task<Result<()>> {
1088 cx.spawn(async move |cx| {
1089 let result = io_task.await;
1090
1091 match result {
1092 Ok(exit_code) => {
1093 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1094 match error {
1095 ProxyLaunchError::ServerNotRunning => {
1096 log::error!("failed to reconnect because server is not running");
1097 this.update(cx, |this, cx| {
1098 this.set_state(State::ServerNotRunning, cx);
1099 })?;
1100 }
1101 }
1102 } else if exit_code > 0 {
1103 log::error!("proxy process terminated unexpectedly");
1104 this.update(cx, |this, cx| {
1105 this.reconnect(cx).ok();
1106 })?;
1107 }
1108 }
1109 Err(error) => {
1110 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1111 this.update(cx, |this, cx| {
1112 this.reconnect(cx).ok();
1113 })?;
1114 }
1115 }
1116
1117 Ok(())
1118 })
1119 }
1120
1121 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1122 self.state.lock().as_ref().map_or(false, check)
1123 }
1124
1125 fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1126 let mut lock = self.state.lock();
1127 let new_state = lock.as_ref().and_then(map);
1128
1129 if let Some(new_state) = new_state {
1130 lock.replace(new_state);
1131 cx.notify();
1132 }
1133 }
1134
1135 fn set_state(&self, state: State, cx: &mut Context<Self>) {
1136 log::info!("setting state to '{}'", &state);
1137
1138 let is_reconnect_exhausted = state.is_reconnect_exhausted();
1139 let is_server_not_running = state.is_server_not_running();
1140 self.state.lock().replace(state);
1141
1142 if is_reconnect_exhausted || is_server_not_running {
1143 cx.emit(SshRemoteEvent::Disconnected);
1144 }
1145 cx.notify();
1146 }
1147
/// Forwards to the channel client's `subscribe_to_entity` with the same
/// `remote_id` and `entity`.
1148 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
1149 self.client.subscribe_to_entity(remote_id, entity);
1150 }
1151
1152 pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
1153 self.state
1154 .lock()
1155 .as_ref()
1156 .and_then(|state| state.ssh_connection())
1157 .map(|ssh_connection| (ssh_connection.ssh_args(), ssh_connection.path_style()))
1158 }
1159
1160 pub fn upload_directory(
1161 &self,
1162 src_path: PathBuf,
1163 dest_path: RemotePathBuf,
1164 cx: &App,
1165 ) -> Task<Result<()>> {
1166 let state = self.state.lock();
1167 let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1168 return Task::ready(Err(anyhow!("no ssh connection")));
1169 };
1170 connection.upload_directory(src_path, dest_path, cx)
1171 }
1172
1173 pub fn proto_client(&self) -> AnyProtoClient {
1174 self.client.clone().into()
1175 }
1176
1177 pub fn connection_string(&self) -> String {
1178 self.connection_options.connection_string()
1179 }
1180
1181 pub fn connection_options(&self) -> SshConnectionOptions {
1182 self.connection_options.clone()
1183 }
1184
1185 pub fn connection_state(&self) -> ConnectionState {
1186 self.state
1187 .lock()
1188 .as_ref()
1189 .map(ConnectionState::from)
1190 .unwrap_or(ConnectionState::Disconnected)
1191 }
1192
1193 pub fn is_disconnected(&self) -> bool {
1194 self.connection_state() == ConnectionState::Disconnected
1195 }
1196
1197 pub fn path_style(&self) -> PathStyle {
1198 self.path_style
1199 }
1200
/// Test helper: looks up the pending fake connection registered for this
/// client's options and asks it to simulate a disconnect.
///
/// Panics when the pool has no `Connecting` entry for these options.
#[cfg(any(test, feature = "test-support"))]
pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
    let options = self.connection_options();
    client_cx.spawn(async move |cx| {
        let pending = cx
            .update_global(
                |pool: &mut ConnectionPool, _| match pool.connections.get(&options) {
                    Some(ConnectionPoolEntry::Connecting(task)) => task.clone(),
                    _ => panic!("missing test connection"),
                },
            )
            .unwrap();
        let connection = pending.await.unwrap();

        connection.simulate_disconnect(&cx);
    })
}
1220
/// Test helper: registers a fake remote connection in the client's
/// `ConnectionPool` and returns the options that select it together with the
/// server-side `ChannelClient`.
///
/// The fake port is the current number of pool entries plus one, so each call
/// gets distinct connection options.
#[cfg(any(test, feature = "test-support"))]
pub fn fake_server(
    client_cx: &mut gpui::TestAppContext,
    server_cx: &mut gpui::TestAppContext,
) -> (SshConnectionOptions, Arc<ChannelClient>) {
    let port = client_cx.update(|cx| {
        let pool = cx.default_global::<ConnectionPool>();
        pool.connections.len() as u16 + 1
    });
    let opts = SshConnectionOptions {
        host: "<fake>".to_string(),
        port: Some(port),
        ..Default::default()
    };

    // Only one end of each channel is kept; the opposite ends are dropped right away.
    let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
    let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
    let server_client =
        server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));

    let fake_connection = fake::FakeRemoteConnection {
        connection_options: opts.clone(),
        server_cx: fake::SendableCx::new(server_cx),
        server_channel: server_client.clone(),
    };
    let connection: Arc<dyn RemoteConnection> = Arc::new(fake_connection);

    client_cx.update(|cx| {
        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
            // Register the fake as an already-resolving connection attempt.
            let ready = cx
                .background_spawn({
                    let connection = connection.clone();
                    async move { Ok(connection.clone()) }
                })
                .shared();
            pool.connections
                .insert(opts.clone(), ConnectionPoolEntry::Connecting(ready));
        })
    });

    (opts, server_client)
}
1260
/// Test helper: builds a client entity for `opts` using the fake delegate.
///
/// Panics when construction fails or yields no client.
#[cfg(any(test, feature = "test-support"))]
pub async fn fake_client(
    opts: SshConnectionOptions,
    client_cx: &mut gpui::TestAppContext,
) -> Entity<Self> {
    // `_tx` stays alive until this function returns, so the receiver is not
    // resolved by the sender being dropped while the client is constructed.
    let (_tx, rx) = oneshot::channel();
    let delegate = Arc::new(fake::Delegate);
    let connecting = client_cx
        .update(|cx| Self::new(ConnectionIdentifier::setup(), opts, rx, delegate, cx));
    connecting.await.unwrap().unwrap()
}
1281}
1282
/// An entry in the [`ConnectionPool`], keyed by connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight; the shared task can be awaited by
    /// every caller that asks for the same options.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak handle is stored, so the pool
    /// itself does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1287
/// Global registry of remote connections, keyed by the options used to open them.
#[derive(Default)]
struct ConnectionPool {
    /// In-flight and established connections per set of connection options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Stored as a gpui global so it can be reached through `update_global` / `default_global`.
impl Global for ConnectionPool {}
1294
1295impl ConnectionPool {
1296 pub fn connect(
1297 &mut self,
1298 opts: SshConnectionOptions,
1299 delegate: &Arc<dyn SshClientDelegate>,
1300 cx: &mut App,
1301 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1302 let connection = self.connections.get(&opts);
1303 match connection {
1304 Some(ConnectionPoolEntry::Connecting(task)) => {
1305 let delegate = delegate.clone();
1306 cx.spawn(async move |cx| {
1307 delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1308 })
1309 .detach();
1310 return task.clone();
1311 }
1312 Some(ConnectionPoolEntry::Connected(ssh)) => {
1313 if let Some(ssh) = ssh.upgrade() {
1314 if !ssh.has_been_killed() {
1315 return Task::ready(Ok(ssh)).shared();
1316 }
1317 }
1318 self.connections.remove(&opts);
1319 }
1320 None => {}
1321 }
1322
1323 let task = cx
1324 .spawn({
1325 let opts = opts.clone();
1326 let delegate = delegate.clone();
1327 async move |cx| {
1328 let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1329 .await
1330 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1331
1332 cx.update_global(|pool: &mut Self, _| {
1333 debug_assert!(matches!(
1334 pool.connections.get(&opts),
1335 Some(ConnectionPoolEntry::Connecting(_))
1336 ));
1337 match connection {
1338 Ok(connection) => {
1339 pool.connections.insert(
1340 opts.clone(),
1341 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1342 );
1343 Ok(connection)
1344 }
1345 Err(error) => {
1346 pool.connections.remove(&opts);
1347 Err(Arc::new(error))
1348 }
1349 }
1350 })?
1351 }
1352 })
1353 .shared();
1354
1355 self.connections
1356 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1357 task
1358 }
1359}
1360
1361impl From<SshRemoteClient> for AnyProtoClient {
1362 fn from(client: SshRemoteClient) -> Self {
1363 AnyProtoClient::new(client.client.clone())
1364 }
1365}
1366
/// Abstraction over an established connection to a remote host.
///
/// Implemented by the real ssh connection in this file; the test-support code
/// also provides a fake implementation.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy for this connection and wires it to the given channels.
    ///
    /// The returned task resolves to the proxy's exit code, or to an error if
    /// the proxy could not be started or its I/O failed.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote host.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Terminates the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns `true` once the connection has been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further ssh commands over this connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On other platforms, we use the `ControlPath` option to create a socket file that
    /// ssh can use to reuse the already-established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was opened with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style used on the remote host.
    fn path_style(&self) -> PathStyle;

    /// Test hook for simulating a dropped connection; does nothing by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1396
/// A connection to a remote host backed by a long-running master `ssh` process.
struct SshRemoteConnection {
    /// Connection options plus the platform-specific data needed to run
    /// additional ssh/scp commands for this connection.
    socket: SshSocket,
    /// The master ssh process; `None` once `kill` has taken it.
    master_process: Mutex<Option<Child>>,
    /// Path of the server binary on the remote host. `start_proxy` fails while
    /// this is `None`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported for the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform's OS.
    ssh_path_style: PathStyle,
    /// Temporary directory created for this session (on non-Windows it holds the
    /// control socket); kept here so it lives as long as the connection.
    _temp_dir: TempDir,
}
1405
1406#[async_trait(?Send)]
1407impl RemoteConnection for SshRemoteConnection {
1408 async fn kill(&self) -> Result<()> {
1409 let Some(mut process) = self.master_process.lock().take() else {
1410 return Ok(());
1411 };
1412 process.kill().ok();
1413 process.status().await?;
1414 Ok(())
1415 }
1416
1417 fn has_been_killed(&self) -> bool {
1418 self.master_process.lock().is_none()
1419 }
1420
1421 fn ssh_args(&self) -> SshArgs {
1422 self.socket.ssh_args()
1423 }
1424
1425 fn connection_options(&self) -> SshConnectionOptions {
1426 self.socket.connection_options.clone()
1427 }
1428
1429 fn upload_directory(
1430 &self,
1431 src_path: PathBuf,
1432 dest_path: RemotePathBuf,
1433 cx: &App,
1434 ) -> Task<Result<()>> {
1435 let mut command = util::command::new_smol_command("scp");
1436 let output = self
1437 .socket
1438 .ssh_options(&mut command)
1439 .args(
1440 self.socket
1441 .connection_options
1442 .port
1443 .map(|port| vec!["-P".to_string(), port.to_string()])
1444 .unwrap_or_default(),
1445 )
1446 .arg("-C")
1447 .arg("-r")
1448 .arg(&src_path)
1449 .arg(format!(
1450 "{}:{}",
1451 self.socket.connection_options.scp_url(),
1452 dest_path.to_string()
1453 ))
1454 .output();
1455
1456 cx.background_spawn(async move {
1457 let output = output.await?;
1458
1459 anyhow::ensure!(
1460 output.status.success(),
1461 "failed to upload directory {} -> {}: {}",
1462 src_path.display(),
1463 dest_path.to_string(),
1464 String::from_utf8_lossy(&output.stderr)
1465 );
1466
1467 Ok(())
1468 })
1469 }
1470
1471 fn start_proxy(
1472 &self,
1473 unique_identifier: String,
1474 reconnect: bool,
1475 incoming_tx: UnboundedSender<Envelope>,
1476 outgoing_rx: UnboundedReceiver<Envelope>,
1477 connection_activity_tx: Sender<()>,
1478 delegate: Arc<dyn SshClientDelegate>,
1479 cx: &mut AsyncApp,
1480 ) -> Task<Result<i32>> {
1481 delegate.set_status(Some("Starting proxy"), cx);
1482
1483 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1484 return Task::ready(Err(anyhow!("Remote binary path not set")));
1485 };
1486
1487 let mut start_proxy_command = shell_script!(
1488 "exec {binary_path} proxy --identifier {identifier}",
1489 binary_path = &remote_binary_path.to_string(),
1490 identifier = &unique_identifier,
1491 );
1492
1493 if let Some(rust_log) = std::env::var("RUST_LOG").ok() {
1494 start_proxy_command = format!(
1495 "RUST_LOG={} {}",
1496 shlex::try_quote(&rust_log).unwrap(),
1497 start_proxy_command
1498 )
1499 }
1500 if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() {
1501 start_proxy_command = format!(
1502 "RUST_BACKTRACE={} {}",
1503 shlex::try_quote(&rust_backtrace).unwrap(),
1504 start_proxy_command
1505 )
1506 }
1507 if reconnect {
1508 start_proxy_command.push_str(" --reconnect");
1509 }
1510
1511 let ssh_proxy_process = match self
1512 .socket
1513 .ssh_command("sh", &["-c", &start_proxy_command])
1514 // IMPORTANT: we kill this process when we drop the task that uses it.
1515 .kill_on_drop(true)
1516 .spawn()
1517 {
1518 Ok(process) => process,
1519 Err(error) => {
1520 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1521 }
1522 };
1523
1524 Self::multiplex(
1525 ssh_proxy_process,
1526 incoming_tx,
1527 outgoing_rx,
1528 connection_activity_tx,
1529 &cx,
1530 )
1531 }
1532
1533 fn path_style(&self) -> PathStyle {
1534 self.ssh_path_style
1535 }
1536}
1537
1538impl SshRemoteConnection {
/// Opens a connection to the host described by `connection_options`.
///
/// Steps, in order:
/// 1. Create a temporary directory and an askpass session that routes
///    password prompts to `delegate`.
/// 2. Spawn the master `ssh` process (with `ControlMaster`/`ControlPath` on
///    non-Windows platforms) and wait until it closes stdout or the askpass
///    session ends.
/// 3. Fail with ssh's stderr if the master process has already exited.
/// 4. Query the remote platform and make sure a server binary is present.
///
/// Returns an error when the prompt is cancelled or times out, when ssh
/// exits early, or when any of the later steps fail.
async fn new(
    connection_options: SshConnectionOptions,
    delegate: Arc<dyn SshClientDelegate>,
    cx: &mut AsyncApp,
) -> Result<Self> {
    use askpass::AskPassResult;

    delegate.set_status(Some("Connecting"), cx);

    let url = connection_options.ssh_url();

    let temp_dir = tempfile::Builder::new()
        .prefix("zed-ssh-session")
        .tempdir()?;
    // Password prompts raised by ssh are forwarded to the delegate.
    let askpass_delegate = askpass::AskPassDelegate::new(cx, {
        let delegate = delegate.clone();
        move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
    });

    let mut askpass =
        askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

    // Start the master SSH process, which does not do anything except for establish
    // the connection and keep it open, allowing other ssh commands to reuse it
    // via a control socket.
    #[cfg(not(target_os = "windows"))]
    let socket_path = temp_dir.path().join("ssh.sock");

    let mut master_process = {
        // The trailing "-o" is completed by the `ControlPath=...` argument added below.
        #[cfg(not(target_os = "windows"))]
        let args = [
            "-N",
            "-o",
            "ControlPersist=no",
            "-o",
            "ControlMaster=yes",
            "-o",
        ];
        // On Windows, `ControlMaster` and `ControlPath` are not supported:
        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
        #[cfg(target_os = "windows")]
        let args = ["-N"];
        let mut master_process = util::command::new_smol_command("ssh");
        master_process
            .kill_on_drop(true)
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .env("SSH_ASKPASS_REQUIRE", "force")
            .env("SSH_ASKPASS", askpass.script_path())
            .args(connection_options.additional_args())
            .args(args);
        #[cfg(not(target_os = "windows"))]
        master_process.arg(format!("ControlPath={}", socket_path.display()));
        master_process.arg(&url).spawn()?
    };
    // Wait for this ssh process to close its stdout, indicating that authentication
    // has completed.
    let mut stdout = master_process.stdout.take().unwrap();
    let mut output = Vec::new();

    // Biased select: the askpass outcome is checked before the stdout read.
    let result = select_biased! {
        result = askpass.run().fuse() => {
            match result {
                AskPassResult::CancelledByUser => {
                    master_process.kill().ok();
                    anyhow::bail!("SSH connection canceled")
                }
                // No explicit kill here; the process was spawned with `kill_on_drop(true)`.
                AskPassResult::Timedout => {
                    anyhow::bail!("connecting to host timed out")
                }
            }
        }
        _ = stdout.read_to_end(&mut output).fuse() => {
            anyhow::Ok(())
        }
    };

    // NOTE(review): both `anyhow::bail!` calls above return from this function
    // directly, so `result` appears to always be `Ok(())` here and this context
    // is never attached — confirm whether that is intended.
    if let Err(e) = result {
        return Err(e.context("Failed to connect to host"));
    }

    // If ssh has already exited, report whatever it wrote to stderr.
    if master_process.try_status()?.is_some() {
        output.clear();
        let mut stderr = master_process.stderr.take().unwrap();
        stderr.read_to_end(&mut output).await?;

        let error_message = format!(
            "failed to connect: {}",
            String::from_utf8_lossy(&output).trim()
        );
        anyhow::bail!(error_message);
    }

    #[cfg(not(target_os = "windows"))]
    let socket = SshSocket::new(connection_options, socket_path)?;
    #[cfg(target_os = "windows")]
    let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
    drop(askpass);

    let ssh_platform = socket.platform().await?;
    // Anything that is not reported as "windows" is treated as POSIX.
    let ssh_path_style = match ssh_platform.os {
        "windows" => PathStyle::Windows,
        _ => PathStyle::Posix,
    };

    // The binary path is filled in below, once `ensure_server_binary` can use `this`.
    let mut this = Self {
        socket,
        master_process: Mutex::new(Some(master_process)),
        _temp_dir: temp_dir,
        remote_binary_path: None,
        ssh_path_style,
        ssh_platform,
    };

    let (release_channel, version, commit) = cx.update(|cx| {
        (
            ReleaseChannel::global(cx),
            AppVersion::global(cx),
            AppCommitSha::try_global(cx),
        )
    })?;
    this.remote_binary_path = Some(
        this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
            .await?,
    );

    Ok(this)
}
1669
/// Bridges the stdio of a spawned ssh proxy process with the envelope channels.
///
/// Three background tasks are started:
/// * stdin: writes every envelope received from `outgoing_rx` to the child;
/// * stdout: reads length-prefixed envelopes from the child and forwards them
///   to `incoming_tx`;
/// * stderr: splits the child's stderr into lines and logs each one.
///
/// The returned task waits for the first of these to finish, then waits for
/// the child to exit. It resolves to the child's exit code (`1` when the
/// status carries no code) if that first task succeeded, and to its error
/// (tagged "stdin", "stdout" or "stderr") otherwise.
fn multiplex(
    mut ssh_proxy_process: Child,
    incoming_tx: UnboundedSender<Envelope>,
    mut outgoing_rx: UnboundedReceiver<Envelope>,
    mut connection_activity_tx: Sender<()>,
    cx: &AsyncApp,
) -> Task<Result<i32>> {
    let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
    let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
    let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

    let mut stdin_buffer = Vec::new();
    let mut stdout_buffer = Vec::new();
    let mut stderr_buffer = Vec::new();
    // Number of bytes at the front of `stderr_buffer` that hold not-yet-logged data.
    let mut stderr_offset = 0;

    // Ends successfully once the outgoing channel is closed.
    let stdin_task = cx.background_spawn(async move {
        while let Some(outgoing) = outgoing_rx.next().await {
            write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
        }
        anyhow::Ok(())
    });

    let stdout_task = cx.background_spawn({
        let mut connection_activity_tx = connection_activity_tx.clone();
        async move {
            loop {
                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
                let len = child_stdout.read(&mut stdout_buffer).await?;

                // A zero-length read means the child closed its stdout.
                if len == 0 {
                    return anyhow::Ok(());
                }

                // The first read may return only part of the length prefix.
                if len < MESSAGE_LEN_SIZE {
                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                }

                let message_len = message_len_from_buffer(&stdout_buffer);
                let envelope =
                    read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
                        .await?;
                // Both sends are best-effort; failures are ignored.
                connection_activity_tx.try_send(()).ok();
                incoming_tx.unbounded_send(envelope).ok();
            }
        }
    });

    let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
        loop {
            // Always leave room for another 1024 bytes after the pending data.
            stderr_buffer.resize(stderr_offset + 1024, 0);

            let len = child_stderr
                .read(&mut stderr_buffer[stderr_offset..])
                .await?;
            if len == 0 {
                return anyhow::Ok(());
            }

            stderr_offset += len;
            let mut start_ix = 0;
            // Handle every complete (newline-terminated) line in the buffer.
            while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
                .iter()
                .position(|b| b == &b'\n')
            {
                let line_ix = start_ix + ix;
                let content = &stderr_buffer[start_ix..line_ix];
                start_ix = line_ix + 1;
                // Lines that parse as a JSON log record go to the logger; anything
                // else is printed to this process's stderr.
                if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
                    record.log(log::logger())
                } else {
                    eprintln!("(remote) {}", String::from_utf8_lossy(content));
                }
            }
            // Drop the consumed lines and keep any trailing partial line for the next read.
            stderr_buffer.drain(0..start_ix);
            stderr_offset -= start_ix;

            connection_activity_tx.try_send(()).ok();
        }
    });

    cx.background_spawn(async move {
        // Whichever task finishes first decides the result.
        let result = futures::select! {
            result = stdin_task.fuse() => {
                result.context("stdin")
            }
            result = stdout_task.fuse() => {
                result.context("stdout")
            }
            result = stderr_task.fuse() => {
                result.context("stderr")
            }
        };

        // The exit status is awaited even when `result` is an error.
        let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
        match result {
            Ok(_) => Ok(status),
            Err(error) => Err(error),
        }
    })
}
1771
/// Makes sure a matching server binary exists on the remote host and returns its path.
///
/// The destination is `zed-remote-server-<channel>-<version>` inside the
/// remote server directory. Resolution order:
/// 1. Debug builds with `ZED_BUILD_REMOTE_SERVER` set: build locally, upload, extract.
/// 2. If running `<dst_path> version` on the remote succeeds, reuse that binary.
/// 3. Unless `upload_binary_over_ssh` is set, try to download the binary on the
///    remote host itself; a failed download is logged and falls through.
/// 4. Download the binary locally, upload it, and extract it.
///
/// On the `Dev` channel, step 3 onwards is replaced by an error.
#[allow(unused)]
async fn ensure_server_binary(
    &self,
    delegate: &Arc<dyn SshClientDelegate>,
    release_channel: ReleaseChannel,
    version: SemanticVersion,
    commit: Option<AppCommitSha>,
    cx: &mut AsyncApp,
) -> Result<RemotePathBuf> {
    // Nightly binaries are keyed by version and commit, dev binaries use a fixed
    // name, and every other channel is keyed by version only.
    let version_str = match release_channel {
        ReleaseChannel::Nightly => {
            let commit = commit.map(|s| s.full()).unwrap_or_default();
            format!("{}-{}", version, commit)
        }
        ReleaseChannel::Dev => "build".to_string(),
        _ => version.to_string(),
    };
    let binary_name = format!(
        "zed-remote-server-{}-{}",
        release_channel.dev_name(),
        version_str
    );
    let dst_path = RemotePathBuf::new(
        paths::remote_server_dir_relative().join(binary_name),
        self.ssh_path_style,
    );

    // Only consulted in debug builds (see the `cfg` below).
    let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
    #[cfg(debug_assertions)]
    if let Some(build_remote_server) = build_remote_server {
        let src_path = self.build_local(build_remote_server, delegate, cx).await?;
        // The local process id keeps the temporary upload name distinct per client process.
        let tmp_path = RemotePathBuf::new(
            paths::remote_server_dir_relative().join(format!(
                "download-{}-{}",
                std::process::id(),
                src_path.file_name().unwrap().to_string_lossy()
            )),
            self.ssh_path_style,
        );
        self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
            .await?;
        self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
            .await?;
        return Ok(dst_path);
    }

    // A binary that already answers `version` is taken as usable.
    if self
        .socket
        .run_command(&dst_path.to_string(), &["version"])
        .await
        .is_ok()
    {
        return Ok(dst_path);
    }

    // `None` is passed on for nightly; other channels request the app's own version.
    let wanted_version = cx.update(|cx| match release_channel {
        ReleaseChannel::Nightly => Ok(None),
        ReleaseChannel::Dev => {
            anyhow::bail!(
                "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
                dst_path
            )
        }
        _ => Ok(Some(AppVersion::global(cx))),
    })??;

    let tmp_path_gz = RemotePathBuf::new(
        PathBuf::from(format!(
            "{}-download-{}.gz",
            dst_path.to_string(),
            std::process::id()
        )),
        self.ssh_path_style,
    );
    if !self.socket.connection_options.upload_binary_over_ssh {
        if let Some((url, body)) = delegate
            .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
            .await?
        {
            match self
                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
                .await
            {
                Ok(_) => {
                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
                        .await?;
                    return Ok(dst_path);
                }
                // Not fatal: fall through to the local download + upload below.
                Err(e) => {
                    log::error!(
                        "Failed to download binary on server, attempting to upload server: {}",
                        e
                    )
                }
            }
        }
    }

    let src_path = delegate
        .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
        .await?;
    self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
        .await?;
    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
        .await?;
    return Ok(dst_path);
}
1879
1880 async fn download_binary_on_server(
1881 &self,
1882 url: &str,
1883 body: &str,
1884 tmp_path_gz: &RemotePathBuf,
1885 delegate: &Arc<dyn SshClientDelegate>,
1886 cx: &mut AsyncApp,
1887 ) -> Result<()> {
1888 if let Some(parent) = tmp_path_gz.parent() {
1889 self.socket
1890 .run_command(
1891 "sh",
1892 &[
1893 "-c",
1894 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1895 ],
1896 )
1897 .await?;
1898 }
1899
1900 delegate.set_status(Some("Downloading remote development server on host"), cx);
1901
1902 match self
1903 .socket
1904 .run_command(
1905 "curl",
1906 &[
1907 "-f",
1908 "-L",
1909 "-X",
1910 "GET",
1911 "-H",
1912 "Content-Type: application/json",
1913 "-d",
1914 &body,
1915 &url,
1916 "-o",
1917 &tmp_path_gz.to_string(),
1918 ],
1919 )
1920 .await
1921 {
1922 Ok(_) => {}
1923 Err(e) => {
1924 if self.socket.run_command("which", &["curl"]).await.is_ok() {
1925 return Err(e);
1926 }
1927
1928 match self
1929 .socket
1930 .run_command(
1931 "wget",
1932 &[
1933 "--method=GET",
1934 "--header=Content-Type: application/json",
1935 "--body-data",
1936 &body,
1937 &url,
1938 "-O",
1939 &tmp_path_gz.to_string(),
1940 ],
1941 )
1942 .await
1943 {
1944 Ok(_) => {}
1945 Err(e) => {
1946 if self.socket.run_command("which", &["wget"]).await.is_ok() {
1947 return Err(e);
1948 } else {
1949 anyhow::bail!("Neither curl nor wget is available");
1950 }
1951 }
1952 }
1953 }
1954 }
1955
1956 Ok(())
1957 }
1958
1959 async fn upload_local_server_binary(
1960 &self,
1961 src_path: &Path,
1962 tmp_path_gz: &RemotePathBuf,
1963 delegate: &Arc<dyn SshClientDelegate>,
1964 cx: &mut AsyncApp,
1965 ) -> Result<()> {
1966 if let Some(parent) = tmp_path_gz.parent() {
1967 self.socket
1968 .run_command(
1969 "sh",
1970 &[
1971 "-c",
1972 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1973 ],
1974 )
1975 .await?;
1976 }
1977
1978 let src_stat = fs::metadata(&src_path).await?;
1979 let size = src_stat.len();
1980
1981 let t0 = Instant::now();
1982 delegate.set_status(Some("Uploading remote development server"), cx);
1983 log::info!(
1984 "uploading remote development server to {:?} ({}kb)",
1985 tmp_path_gz,
1986 size / 1024
1987 );
1988 self.upload_file(&src_path, &tmp_path_gz)
1989 .await
1990 .context("failed to upload server binary")?;
1991 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1992 Ok(())
1993 }
1994
1995 async fn extract_server_binary(
1996 &self,
1997 dst_path: &RemotePathBuf,
1998 tmp_path: &RemotePathBuf,
1999 delegate: &Arc<dyn SshClientDelegate>,
2000 cx: &mut AsyncApp,
2001 ) -> Result<()> {
2002 delegate.set_status(Some("Extracting remote development server"), cx);
2003 let server_mode = 0o755;
2004
2005 let orig_tmp_path = tmp_path.to_string();
2006 let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
2007 shell_script!(
2008 "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
2009 server_mode = &format!("{:o}", server_mode),
2010 dst_path = &dst_path.to_string(),
2011 )
2012 } else {
2013 shell_script!(
2014 "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
2015 server_mode = &format!("{:o}", server_mode),
2016 dst_path = &dst_path.to_string()
2017 )
2018 };
2019 self.socket.run_command("sh", &["-c", &script]).await?;
2020 Ok(())
2021 }
2022
2023 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2024 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2025 let mut command = util::command::new_smol_command("scp");
2026 let output = self
2027 .socket
2028 .ssh_options(&mut command)
2029 .args(
2030 self.socket
2031 .connection_options
2032 .port
2033 .map(|port| vec!["-P".to_string(), port.to_string()])
2034 .unwrap_or_default(),
2035 )
2036 .arg(src_path)
2037 .arg(format!(
2038 "{}:{}",
2039 self.socket.connection_options.scp_url(),
2040 dest_path.to_string()
2041 ))
2042 .output()
2043 .await?;
2044
2045 anyhow::ensure!(
2046 output.status.success(),
2047 "failed to upload file {} -> {}: {}",
2048 src_path.display(),
2049 dest_path.to_string(),
2050 String::from_utf8_lossy(&output.stderr)
2051 );
2052 Ok(())
2053 }
2054
/// Builds the `remote_server` binary locally for the remote platform
/// (debug builds only) and returns the path of the artifact to upload.
///
/// `build_remote_server` is the value of `ZED_BUILD_REMOTE_SERVER`; it is
/// searched for these substrings:
/// * `nomusl` — target `unknown-linux-gnu` instead of `unknown-linux-musl`;
/// * `mold` — link with mold;
/// * `cross` — cross-compile with cross.rs instead of cargo-zigbuild;
/// * `nocompress` — skip gzip compression of the result.
///
/// Fails for remote platforms other than linux and macos, when a build
/// command exits unsuccessfully, or when a required tool is missing.
#[cfg(debug_assertions)]
async fn build_local(
    &self,
    build_remote_server: String,
    delegate: &Arc<dyn SshClientDelegate>,
    cx: &mut AsyncApp,
) -> Result<PathBuf> {
    use smol::process::{Command, Stdio};
    use std::env::VarError;

    /// Runs `command` to completion with stderr inherited, failing on a
    /// non-success exit status.
    async fn run_cmd(command: &mut Command) -> Result<()> {
        let output = command
            .kill_on_drop(true)
            .stderr(Stdio::inherit())
            .output()
            .await?;
        anyhow::ensure!(
            output.status.success(),
            "Failed to run command: {command:?}"
        );
        Ok(())
    }

    let use_musl = !build_remote_server.contains("nomusl");
    // Rust target triple for the remote platform.
    let triple = format!(
        "{}-{}",
        self.ssh_platform.arch,
        match self.ssh_platform.os {
            "linux" =>
                if use_musl {
                    "unknown-linux-musl"
                } else {
                    "unknown-linux-gnu"
                },
            "macos" => "apple-darwin",
            _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
        }
    );
    // Start from the caller's RUSTFLAGS (if readable) and append what this build needs.
    let mut rust_flags = match std::env::var("RUSTFLAGS") {
        Ok(val) => val,
        Err(VarError::NotPresent) => String::new(),
        Err(e) => {
            log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
            String::new()
        }
    };
    if self.ssh_platform.os == "linux" && use_musl {
        rust_flags.push_str(" -C target-feature=+crt-static");
    }
    if build_remote_server.contains("mold") {
        rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
    }

    // Same arch and OS as this machine: a plain `cargo build` is enough.
    if self.ssh_platform.arch == std::env::consts::ARCH
        && self.ssh_platform.os == std::env::consts::OS
    {
        delegate.set_status(Some("Building remote server binary from source"), cx);
        log::info!("building remote server binary from source");
        run_cmd(
            Command::new("cargo")
                .args([
                    "build",
                    "--package",
                    "remote_server",
                    "--features",
                    "debug-embed",
                    "--target-dir",
                    "target/remote_server",
                    "--target",
                    &triple,
                ])
                .env("RUSTFLAGS", &rust_flags),
        )
        .await?;
    } else {
        // Cross-compilation: cross.rs when requested, cargo-zigbuild otherwise.
        if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    )
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Look up `zig` on a background thread.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"]))
                .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "zigbuild",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        }
    };
    // Location of the built binary, relative to the current directory.
    let bin_path = Path::new("target")
        .join("remote_server")
        .join(&triple)
        .join("debug")
        .join("remote_server");

    let path = if !build_remote_server.contains("nocompress") {
        delegate.set_status(Some("Compressing binary"), cx);

        #[cfg(not(target_os = "windows"))]
        {
            run_cmd(Command::new("gzip").args(["-9", "-f", &bin_path.to_string_lossy()]))
                .await?;
        }
        #[cfg(target_os = "windows")]
        {
            // On Windows, we use 7z to compress the binary
            let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
            let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
            // Remove any archive left over from an earlier build first.
            if smol::fs::metadata(&gz_path).await.is_ok() {
                smol::fs::remove_file(&gz_path).await?;
            }
            run_cmd(Command::new(seven_zip).args([
                "a",
                "-tgzip",
                &gz_path,
                &bin_path.to_string_lossy(),
            ]))
            .await?;
        }

        // The compressed artifact is returned as a path joined onto the current
        // directory; the uncompressed one (below) stays relative.
        let mut archive_path = bin_path;
        archive_path.set_extension("gz");
        std::env::current_dir()?.join(archive_path)
    } else {
        bin_path
    };

    Ok(path)
}
2273}
2274
/// Pending requests, keyed by message id. Each entry holds the sender used to
/// deliver the response envelope, paired with a second sender passed along
/// with that envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2276
/// Protocol client that exchanges `Envelope`s over a pair of unbounded channels.
pub struct ChannelClient {
    /// Counter used to assign ids to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Sender for envelopes going to the other side.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes retained until acknowledged; they are re-sent when a
    /// `FlushBufferedMessages` message is received.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Senders waiting for the response to an in-flight request.
    response_channels: ResponseChannels,
    /// Registered handlers for incoming messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received (non-flush) envelope.
    max_received: AtomicU32,
    /// Label included in this client's log output.
    name: &'static str,
    /// The task that processes incoming envelopes.
    task: Mutex<Task<Result<()>>>,
}
2287
2288impl ChannelClient {
2289 pub fn new(
2290 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2291 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2292 cx: &App,
2293 name: &'static str,
2294 ) -> Arc<Self> {
2295 Arc::new_cyclic(|this| Self {
2296 outgoing_tx: Mutex::new(outgoing_tx),
2297 next_message_id: AtomicU32::new(0),
2298 max_received: AtomicU32::new(0),
2299 response_channels: ResponseChannels::default(),
2300 message_handlers: Default::default(),
2301 buffer: Mutex::new(VecDeque::new()),
2302 name,
2303 task: Mutex::new(Self::start_handling_messages(
2304 this.clone(),
2305 incoming_rx,
2306 &cx.to_async(),
2307 )),
2308 })
2309 }
2310
2311 fn start_handling_messages(
2312 this: Weak<Self>,
2313 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2314 cx: &AsyncApp,
2315 ) -> Task<Result<()>> {
2316 cx.spawn(async move |cx| {
2317 let peer_id = PeerId { owner_id: 0, id: 0 };
2318 while let Some(incoming) = incoming_rx.next().await {
2319 let Some(this) = this.upgrade() else {
2320 return anyhow::Ok(());
2321 };
2322 if let Some(ack_id) = incoming.ack_id {
2323 let mut buffer = this.buffer.lock();
2324 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2325 buffer.pop_front();
2326 }
2327 }
2328 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2329 {
2330 log::debug!(
2331 "{}:ssh message received. name:FlushBufferedMessages",
2332 this.name
2333 );
2334 {
2335 let buffer = this.buffer.lock();
2336 for envelope in buffer.iter() {
2337 this.outgoing_tx
2338 .lock()
2339 .unbounded_send(envelope.clone())
2340 .ok();
2341 }
2342 }
2343 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2344 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2345 this.outgoing_tx.lock().unbounded_send(envelope).ok();
2346 continue;
2347 }
2348
2349 this.max_received.store(incoming.id, SeqCst);
2350
2351 if let Some(request_id) = incoming.responding_to {
2352 let request_id = MessageId(request_id);
2353 let sender = this.response_channels.lock().remove(&request_id);
2354 if let Some(sender) = sender {
2355 let (tx, rx) = oneshot::channel();
2356 if incoming.payload.is_some() {
2357 sender.send((incoming, tx)).ok();
2358 }
2359 rx.await.ok();
2360 }
2361 } else if let Some(envelope) =
2362 build_typed_envelope(peer_id, Instant::now(), incoming)
2363 {
2364 let type_name = envelope.payload_type_name();
2365 if let Some(future) = ProtoMessageHandlerSet::handle_message(
2366 &this.message_handlers,
2367 envelope,
2368 this.clone().into(),
2369 cx.clone(),
2370 ) {
2371 log::debug!("{}:ssh message received. name:{type_name}", this.name);
2372 cx.foreground_executor()
2373 .spawn(async move {
2374 match future.await {
2375 Ok(_) => {
2376 log::debug!(
2377 "{}:ssh message handled. name:{type_name}",
2378 this.name
2379 );
2380 }
2381 Err(error) => {
2382 log::error!(
2383 "{}:error handling message. type:{}, error:{}",
2384 this.name,
2385 type_name,
2386 format!("{error:#}").lines().fold(
2387 String::new(),
2388 |mut message, line| {
2389 if !message.is_empty() {
2390 message.push(' ');
2391 }
2392 message.push_str(line);
2393 message
2394 }
2395 )
2396 );
2397 }
2398 }
2399 })
2400 .detach()
2401 } else {
2402 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2403 }
2404 }
2405 }
2406 anyhow::Ok(())
2407 })
2408 }
2409
2410 pub fn reconnect(
2411 self: &Arc<Self>,
2412 incoming_rx: UnboundedReceiver<Envelope>,
2413 outgoing_tx: UnboundedSender<Envelope>,
2414 cx: &AsyncApp,
2415 ) {
2416 *self.outgoing_tx.lock() = outgoing_tx;
2417 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2418 }
2419
2420 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
2421 let id = (TypeId::of::<E>(), remote_id);
2422
2423 let mut message_handlers = self.message_handlers.lock();
2424 if message_handlers
2425 .entities_by_type_and_remote_id
2426 .contains_key(&id)
2427 {
2428 panic!("already subscribed to entity");
2429 }
2430
2431 message_handlers.entities_by_type_and_remote_id.insert(
2432 id,
2433 EntityMessageSubscriber::Entity {
2434 handle: entity.downgrade().into(),
2435 },
2436 );
2437 }
2438
2439 pub fn request<T: RequestMessage>(
2440 &self,
2441 payload: T,
2442 ) -> impl 'static + Future<Output = Result<T::Response>> {
2443 self.request_internal(payload, true)
2444 }
2445
2446 fn request_internal<T: RequestMessage>(
2447 &self,
2448 payload: T,
2449 use_buffer: bool,
2450 ) -> impl 'static + Future<Output = Result<T::Response>> {
2451 log::debug!("ssh request start. name:{}", T::NAME);
2452 let response =
2453 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2454 async move {
2455 let response = response.await?;
2456 log::debug!("ssh request finish. name:{}", T::NAME);
2457 T::Response::from_envelope(response).context("received a response of the wrong type")
2458 }
2459 }
2460
2461 pub async fn resync(&self, timeout: Duration) -> Result<()> {
2462 smol::future::or(
2463 async {
2464 self.request_internal(proto::FlushBufferedMessages {}, false)
2465 .await?;
2466
2467 for envelope in self.buffer.lock().iter() {
2468 self.outgoing_tx
2469 .lock()
2470 .unbounded_send(envelope.clone())
2471 .ok();
2472 }
2473 Ok(())
2474 },
2475 async {
2476 smol::Timer::after(timeout).await;
2477 anyhow::bail!("Timeout detected")
2478 },
2479 )
2480 .await
2481 }
2482
2483 pub async fn ping(&self, timeout: Duration) -> Result<()> {
2484 smol::future::or(
2485 async {
2486 self.request(proto::Ping {}).await?;
2487 Ok(())
2488 },
2489 async {
2490 smol::Timer::after(timeout).await;
2491 anyhow::bail!("Timeout detected")
2492 },
2493 )
2494 .await
2495 }
2496
2497 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2498 log::debug!("ssh send name:{}", T::NAME);
2499 self.send_dynamic(payload.into_envelope(0, None, None))
2500 }
2501
2502 fn request_dynamic(
2503 &self,
2504 mut envelope: proto::Envelope,
2505 type_name: &'static str,
2506 use_buffer: bool,
2507 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2508 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2509 let (tx, rx) = oneshot::channel();
2510 let mut response_channels_lock = self.response_channels.lock();
2511 response_channels_lock.insert(MessageId(envelope.id), tx);
2512 drop(response_channels_lock);
2513
2514 let result = if use_buffer {
2515 self.send_buffered(envelope)
2516 } else {
2517 self.send_unbuffered(envelope)
2518 };
2519 async move {
2520 if let Err(error) = &result {
2521 log::error!("failed to send message: {error}");
2522 anyhow::bail!("failed to send message: {error}");
2523 }
2524
2525 let response = rx.await.context("connection lost")?.0;
2526 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2527 return Err(RpcError::from_proto(error, type_name));
2528 }
2529 Ok(response)
2530 }
2531 }
2532
2533 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2534 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2535 self.send_buffered(envelope)
2536 }
2537
2538 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2539 envelope.ack_id = Some(self.max_received.load(SeqCst));
2540 self.buffer.lock().push_back(envelope.clone());
2541 // ignore errors on send (happen while we're reconnecting)
2542 // assume that the global "disconnected" overlay is sufficient.
2543 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2544 Ok(())
2545 }
2546
2547 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2548 envelope.ack_id = Some(self.max_received.load(SeqCst));
2549 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2550 Ok(())
2551 }
2552}
2553
2554impl ProtoClient for ChannelClient {
2555 fn request(
2556 &self,
2557 envelope: proto::Envelope,
2558 request_type: &'static str,
2559 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2560 self.request_dynamic(envelope, request_type, true).boxed()
2561 }
2562
2563 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2564 self.send_dynamic(envelope)
2565 }
2566
2567 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2568 self.send_dynamic(envelope)
2569 }
2570
2571 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2572 &self.message_handlers
2573 }
2574
2575 fn is_via_collab(&self) -> bool {
2576 false
2577 }
2578}
2579
2580#[cfg(any(test, feature = "test-support"))]
2581mod fake {
2582 use std::{path::PathBuf, sync::Arc};
2583
2584 use anyhow::Result;
2585 use async_trait::async_trait;
2586 use futures::{
2587 FutureExt, SinkExt, StreamExt,
2588 channel::{
2589 mpsc::{self, Sender},
2590 oneshot,
2591 },
2592 select_biased,
2593 };
2594 use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
2595 use release_channel::ReleaseChannel;
2596 use rpc::proto::Envelope;
2597 use util::paths::{PathStyle, RemotePathBuf};
2598
2599 use super::{
2600 ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
2601 SshPlatform,
2602 };
2603
/// In-process `RemoteConnection` used in test builds; it forwards envelopes between a
/// client and `server_channel` instead of talking to a real remote.
pub(super) struct FakeRemoteConnection {
    /// Options returned from `connection_options()`.
    pub(super) connection_options: SshConnectionOptions,
    /// The `ChannelClient` whose channels are replaced by `start_proxy` and
    /// `simulate_disconnect`.
    pub(super) server_channel: Arc<ChannelClient>,
    /// Async context passed to `ChannelClient::reconnect` for `server_channel`.
    pub(super) server_cx: SendableCx,
}
2609
2610 pub(super) struct SendableCx(AsyncApp);
2611 impl SendableCx {
2612 // SAFETY: When run in test mode, GPUI is always single threaded.
2613 pub(super) fn new(cx: &TestAppContext) -> Self {
2614 Self(cx.to_async())
2615 }
2616
2617 // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
2618 fn get(&self, _: &AsyncApp) -> AsyncApp {
2619 self.0.clone()
2620 }
2621 }
2622
2623 // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
2624 unsafe impl Send for SendableCx {}
2625 unsafe impl Sync for SendableCx {}
2626
2627 #[async_trait(?Send)]
2628 impl RemoteConnection for FakeRemoteConnection {
2629 async fn kill(&self) -> Result<()> {
2630 Ok(())
2631 }
2632
2633 fn has_been_killed(&self) -> bool {
2634 false
2635 }
2636
2637 fn ssh_args(&self) -> SshArgs {
2638 SshArgs {
2639 arguments: Vec::new(),
2640 envs: None,
2641 }
2642 }
2643
2644 fn upload_directory(
2645 &self,
2646 _src_path: PathBuf,
2647 _dest_path: RemotePathBuf,
2648 _cx: &App,
2649 ) -> Task<Result<()>> {
2650 unreachable!()
2651 }
2652
2653 fn connection_options(&self) -> SshConnectionOptions {
2654 self.connection_options.clone()
2655 }
2656
2657 fn simulate_disconnect(&self, cx: &AsyncApp) {
2658 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
2659 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
2660 self.server_channel
2661 .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
2662 }
2663
2664 fn start_proxy(
2665 &self,
2666 _unique_identifier: String,
2667 _reconnect: bool,
2668 mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
2669 mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
2670 mut connection_activity_tx: Sender<()>,
2671 _delegate: Arc<dyn SshClientDelegate>,
2672 cx: &mut AsyncApp,
2673 ) -> Task<Result<i32>> {
2674 let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
2675 let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
2676
2677 self.server_channel.reconnect(
2678 server_incoming_rx,
2679 server_outgoing_tx,
2680 &self.server_cx.get(cx),
2681 );
2682
2683 cx.background_spawn(async move {
2684 loop {
2685 select_biased! {
2686 server_to_client = server_outgoing_rx.next().fuse() => {
2687 let Some(server_to_client) = server_to_client else {
2688 return Ok(1)
2689 };
2690 connection_activity_tx.try_send(()).ok();
2691 client_incoming_tx.send(server_to_client).await.ok();
2692 }
2693 client_to_server = client_outgoing_rx.next().fuse() => {
2694 let Some(client_to_server) = client_to_server else {
2695 return Ok(1)
2696 };
2697 server_incoming_tx.send(client_to_server).await.ok();
2698 }
2699 }
2700 }
2701 })
2702 }
2703
2704 fn path_style(&self) -> PathStyle {
2705 PathStyle::current()
2706 }
2707 }
2708
/// `SshClientDelegate` for the fake connection: `set_status` is a no-op and every other
/// method is `unreachable!()`.
pub(super) struct Delegate;
2710
2711 impl SshClientDelegate for Delegate {
2712 fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
2713 unreachable!()
2714 }
2715
2716 fn download_server_binary_locally(
2717 &self,
2718 _: SshPlatform,
2719 _: ReleaseChannel,
2720 _: Option<SemanticVersion>,
2721 _: &mut AsyncApp,
2722 ) -> Task<Result<PathBuf>> {
2723 unreachable!()
2724 }
2725
2726 fn get_download_params(
2727 &self,
2728 _platform: SshPlatform,
2729 _release_channel: ReleaseChannel,
2730 _version: Option<SemanticVersion>,
2731 _cx: &mut AsyncApp,
2732 ) -> Task<Result<Option<(String, String)>>> {
2733 unreachable!()
2734 }
2735
2736 fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
2737 }
2738}