1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 MESSAGE_LEN_SIZE, MessageId, message_len_from_buffer, read_message_with_len, write_message,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{Context as _, Result, anyhow};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
13 channel::{
14 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
15 oneshot,
16 },
17 future::{BoxFuture, Shared},
18 select, select_biased,
19};
20use gpui::{
21 App, AppContext as _, AsyncApp, BackgroundExecutor, BorrowAppContext, Context, Entity,
22 EventEmitter, Global, SemanticVersion, Task, WeakEntity,
23};
24use itertools::Itertools;
25use parking_lot::Mutex;
26
27use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
28use rpc::{
29 AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
30 proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use smol::{
35 fs,
36 process::{self, Child, Stdio},
37};
38use std::{
39 collections::VecDeque,
40 fmt, iter,
41 ops::ControlFlow,
42 path::{Path, PathBuf},
43 sync::{
44 Arc, Weak,
45 atomic::{AtomicU32, AtomicU64, Ordering::SeqCst},
46 },
47 time::{Duration, Instant},
48};
49use tempfile::TempDir;
50use util::{
51 ResultExt,
52 paths::{PathStyle, RemotePathBuf},
53};
54
/// Newtype wrapper around a `u64` identifier.
///
/// NOTE(review): judging by the name this identifies an SSH project; the
/// producers/consumers of the id are not visible here — confirm.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
59
/// Everything needed to build `ssh` invocations for one set of connection options.
///
/// The platform-specific part differs: non-Windows builds carry the path that is
/// passed to ssh as `ControlPath`, Windows builds carry a set of environment
/// variables that is applied to every spawned command.
#[derive(Clone)]
pub struct SshSocket {
    /// Options (host, user, port, extra args, ...) used for every command.
    connection_options: SshConnectionOptions,
    /// Path handed to ssh via `-o ControlPath=...`.
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Environment variables applied to spawned ssh commands (askpass setup).
    #[cfg(target_os = "windows")]
    envs: HashMap<String, String>,
}
68
/// A single port forward, rendered by `SshConnectionOptions::additional_args`
/// as `-L{local_host}:{local_port}:{remote_host}:{remote_port}`.
///
/// Hosts that are `None` are omitted when serializing and are rendered as
/// `localhost` when the ssh argument is built.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    /// Local bind host; `None` means "use the default".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    /// Local port of the forward.
    pub local_port: u16,
    /// Remote target host; `None` means "use the default".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    /// Remote port of the forward.
    pub remote_port: u16,
}
78
/// Describes how to reach a remote host over ssh.
///
/// Instances are either built directly or parsed from an `ssh ...` command
/// line via `SshConnectionOptions::parse_command_line`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address of the remote machine.
    pub host: String,
    /// Optional login name; URL-encoded when rendered by `ssh_url`.
    pub username: Option<String>,
    /// Optional ssh port.
    pub port: Option<u16>,
    /// Optional password. NOTE(review): how it is consumed is not visible in
    /// this part of the file.
    pub password: Option<String>,
    /// Extra ssh arguments passed through verbatim.
    pub args: Option<Vec<String>>,
    /// Port forwards, rendered as `-L...` arguments.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// Optional display name. NOTE(review): presumably user-facing — confirm.
    pub nickname: Option<String>,
    /// NOTE(review): presumably selects uploading the server binary over ssh
    /// instead of downloading it on the remote — confirm against the connect code.
    pub upload_binary_over_ssh: bool,
}
91
/// Arguments (and, on Windows, environment variables) for invoking `ssh`
/// against an established connection; produced by `SshSocket::ssh_args`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshArgs {
    /// Command-line arguments, ending with the `ssh://` destination URL.
    pub arguments: Vec<String>,
    /// Environment variables to set for the process; `None` on non-Windows builds.
    pub envs: Option<HashMap<String, String>>,
}
97
/// Snapshot of a live connection as returned by `SshRemoteClient::ssh_info`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshInfo {
    /// Arguments for invoking `ssh` against the connection.
    pub args: SshArgs,
    /// Path style reported by the connection.
    pub path_style: PathStyle,
    /// Shell reported by the connection.
    pub shell: String,
}
104
/// Formats `$fmt` with the given named arguments, shell-quoting each argument
/// with `shlex::try_quote` before it is substituted.
///
/// Accepts one or more `name = expr` pairs with an optional trailing comma.
///
/// # Panics
///
/// Panics if `shlex::try_quote` returns an error for any argument (the result
/// is unwrapped).
#[macro_export]
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                // Every substituted value is quoted, never inserted raw.
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
116
117fn parse_port_number(port_str: &str) -> Result<u16> {
118 port_str
119 .parse()
120 .with_context(|| format!("parsing port number: {port_str}"))
121}
122
123fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
124 let parts: Vec<&str> = spec.split(':').collect();
125
126 match parts.len() {
127 4 => {
128 let local_port = parse_port_number(parts[1])?;
129 let remote_port = parse_port_number(parts[3])?;
130
131 Ok(SshPortForwardOption {
132 local_host: Some(parts[0].to_string()),
133 local_port,
134 remote_host: Some(parts[2].to_string()),
135 remote_port,
136 })
137 }
138 3 => {
139 let local_port = parse_port_number(parts[0])?;
140 let remote_port = parse_port_number(parts[2])?;
141
142 Ok(SshPortForwardOption {
143 local_host: None,
144 local_port,
145 remote_host: Some(parts[1].to_string()),
146 remote_port,
147 })
148 }
149 _ => anyhow::bail!("Invalid port forward format"),
150 }
151}
152
153impl SshConnectionOptions {
154 pub fn parse_command_line(input: &str) -> Result<Self> {
155 let input = input.trim_start_matches("ssh ");
156 let mut hostname: Option<String> = None;
157 let mut username: Option<String> = None;
158 let mut port: Option<u16> = None;
159 let mut args = Vec::new();
160 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
161
162 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
163 const ALLOWED_OPTS: &[&str] = &[
164 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
165 ];
166 const ALLOWED_ARGS: &[&str] = &[
167 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
168 "-w",
169 ];
170
171 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
172
173 'outer: while let Some(arg) = tokens.next() {
174 if ALLOWED_OPTS.contains(&(&arg as &str)) {
175 args.push(arg.to_string());
176 continue;
177 }
178 if arg == "-p" {
179 port = tokens.next().and_then(|arg| arg.parse().ok());
180 continue;
181 } else if let Some(p) = arg.strip_prefix("-p") {
182 port = p.parse().ok();
183 continue;
184 }
185 if arg == "-l" {
186 username = tokens.next();
187 continue;
188 } else if let Some(l) = arg.strip_prefix("-l") {
189 username = Some(l.to_string());
190 continue;
191 }
192 if arg == "-L" || arg.starts_with("-L") {
193 let forward_spec = if arg == "-L" {
194 tokens.next()
195 } else {
196 Some(arg.strip_prefix("-L").unwrap().to_string())
197 };
198
199 if let Some(spec) = forward_spec {
200 port_forwards.push(parse_port_forward_spec(&spec)?);
201 } else {
202 anyhow::bail!("Missing port forward format");
203 }
204 }
205
206 for a in ALLOWED_ARGS {
207 if arg == *a {
208 args.push(arg);
209 if let Some(next) = tokens.next() {
210 args.push(next);
211 }
212 continue 'outer;
213 } else if arg.starts_with(a) {
214 args.push(arg);
215 continue 'outer;
216 }
217 }
218 if arg.starts_with("-") || hostname.is_some() {
219 anyhow::bail!("unsupported argument: {:?}", arg);
220 }
221 let mut input = &arg as &str;
222 // Destination might be: username1@username2@ip2@ip1
223 if let Some((u, rest)) = input.rsplit_once('@') {
224 input = rest;
225 username = Some(u.to_string());
226 }
227 if let Some((rest, p)) = input.split_once(':') {
228 input = rest;
229 port = p.parse().ok()
230 }
231 hostname = Some(input.to_string())
232 }
233
234 let Some(hostname) = hostname else {
235 anyhow::bail!("missing hostname");
236 };
237
238 let port_forwards = match port_forwards.len() {
239 0 => None,
240 _ => Some(port_forwards),
241 };
242
243 Ok(Self {
244 host: hostname,
245 username,
246 port,
247 port_forwards,
248 args: Some(args),
249 password: None,
250 nickname: None,
251 upload_binary_over_ssh: false,
252 })
253 }
254
255 pub fn ssh_url(&self) -> String {
256 let mut result = String::from("ssh://");
257 if let Some(username) = &self.username {
258 // Username might be: username1@username2@ip2
259 let username = urlencoding::encode(username);
260 result.push_str(&username);
261 result.push('@');
262 }
263 result.push_str(&self.host);
264 if let Some(port) = self.port {
265 result.push(':');
266 result.push_str(&port.to_string());
267 }
268 result
269 }
270
271 pub fn additional_args(&self) -> Vec<String> {
272 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
273
274 if let Some(forwards) = &self.port_forwards {
275 args.extend(forwards.iter().map(|pf| {
276 let local_host = match &pf.local_host {
277 Some(host) => host,
278 None => "localhost",
279 };
280 let remote_host = match &pf.remote_host {
281 Some(host) => host,
282 None => "localhost",
283 };
284
285 format!(
286 "-L{}:{}:{}:{}",
287 local_host, pf.local_port, remote_host, pf.remote_port
288 )
289 }));
290 }
291
292 args
293 }
294
295 fn scp_url(&self) -> String {
296 if let Some(username) = &self.username {
297 format!("{}@{}", username, self.host)
298 } else {
299 self.host.clone()
300 }
301 }
302
303 pub fn connection_string(&self) -> String {
304 let host = if let Some(username) = &self.username {
305 format!("{}@{}", username, self.host)
306 } else {
307 self.host.clone()
308 };
309 if let Some(port) = &self.port {
310 format!("{}:{}", host, port)
311 } else {
312 host
313 }
314 }
315}
316
/// Operating system and CPU architecture of the remote host, as detected by
/// `SshSocket::platform` from `uname -sm`.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// `"macos"` or `"linux"`.
    pub os: &'static str,
    /// `"aarch64"` or `"x86_64"`.
    pub arch: &'static str,
}
322
/// Callbacks the ssh client uses to interact with its host application.
///
/// NOTE(review): the descriptions below are derived from the signatures only;
/// no implementation is visible in this part of the file — confirm against
/// the implementors.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password using `prompt`; the answer is expected on `tx`.
    fn ask_password(&self, prompt: String, tx: oneshot::Sender<String>, cx: &mut AsyncApp);
    /// Resolves download parameters for the server binary matching `platform`,
    /// `release_channel` and `version`. The meaning of the two returned strings
    /// is not visible here (presumably a URL and a request body — TODO confirm).
    fn get_download_params(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Option<(String, String)>>>;

    /// Produces a local path to the server binary for the given platform,
    /// release channel and version.
    fn download_server_binary_locally(
        &self,
        platform: SshPlatform,
        release_channel: ReleaseChannel,
        version: Option<SemanticVersion>,
        cx: &mut AsyncApp,
    ) -> Task<Result<PathBuf>>;
    /// Reports a status message; `None` presumably clears it — confirm.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncApp);
}
342
343impl SshSocket {
344 #[cfg(not(target_os = "windows"))]
345 fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
346 Ok(Self {
347 connection_options: options,
348 socket_path,
349 })
350 }
351
352 #[cfg(target_os = "windows")]
353 fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
354 let askpass_script = temp_dir.path().join("askpass.bat");
355 std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
356 let mut envs = HashMap::default();
357 envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
358 envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
359 envs.insert("ZED_SSH_ASKPASS".into(), secret);
360 Ok(Self {
361 connection_options: options,
362 envs,
363 })
364 }
365
366 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
367 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
368 // and passes -l as an argument to sh, not to ls.
369 // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
370 // into a machine. You must use `cd` to get back to $HOME.
371 // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
372 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
373 let mut command = util::command::new_smol_command("ssh");
374 let to_run = iter::once(&program)
375 .chain(args.iter())
376 .map(|token| {
377 // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
378 debug_assert!(
379 !token.contains('\n'),
380 "multiline arguments do not work in all shells"
381 );
382 shlex::try_quote(token).unwrap()
383 })
384 .join(" ");
385 let to_run = format!("cd; {to_run}");
386 log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
387 self.ssh_options(&mut command)
388 .arg(self.connection_options.ssh_url())
389 .arg(to_run);
390 command
391 }
392
393 async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
394 let output = self.ssh_command(program, args).output().await?;
395 anyhow::ensure!(
396 output.status.success(),
397 "failed to run command: {}",
398 String::from_utf8_lossy(&output.stderr)
399 );
400 Ok(String::from_utf8_lossy(&output.stdout).to_string())
401 }
402
403 #[cfg(not(target_os = "windows"))]
404 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
405 command
406 .stdin(Stdio::piped())
407 .stdout(Stdio::piped())
408 .stderr(Stdio::piped())
409 .args(self.connection_options.additional_args())
410 .args(["-o", "ControlMaster=no", "-o"])
411 .arg(format!("ControlPath={}", self.socket_path.display()))
412 }
413
414 #[cfg(target_os = "windows")]
415 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
416 command
417 .stdin(Stdio::piped())
418 .stdout(Stdio::piped())
419 .stderr(Stdio::piped())
420 .args(self.connection_options.additional_args())
421 .envs(self.envs.clone())
422 }
423
424 // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
425 // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
426 #[cfg(not(target_os = "windows"))]
427 fn ssh_args(&self) -> SshArgs {
428 let mut arguments = self.connection_options.additional_args();
429 arguments.extend(vec![
430 "-o".to_string(),
431 "ControlMaster=no".to_string(),
432 "-o".to_string(),
433 format!("ControlPath={}", self.socket_path.display()),
434 self.connection_options.ssh_url(),
435 ]);
436 SshArgs {
437 arguments,
438 envs: None,
439 }
440 }
441
442 #[cfg(target_os = "windows")]
443 fn ssh_args(&self) -> SshArgs {
444 let mut arguments = self.connection_options.additional_args();
445 arguments.push(self.connection_options.ssh_url());
446 SshArgs {
447 arguments,
448 envs: Some(self.envs.clone()),
449 }
450 }
451
452 async fn platform(&self) -> Result<SshPlatform> {
453 let uname = self.run_command("sh", &["-c", "uname -sm"]).await?;
454 let Some((os, arch)) = uname.split_once(" ") else {
455 anyhow::bail!("unknown uname: {uname:?}")
456 };
457
458 let os = match os.trim() {
459 "Darwin" => "macos",
460 "Linux" => "linux",
461 _ => anyhow::bail!(
462 "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
463 ),
464 };
465 // exclude armv5,6,7 as they are 32-bit.
466 let arch = if arch.starts_with("armv8")
467 || arch.starts_with("armv9")
468 || arch.starts_with("arm64")
469 || arch.starts_with("aarch64")
470 {
471 "aarch64"
472 } else if arch.starts_with("x86") {
473 "x86_64"
474 } else {
475 anyhow::bail!(
476 "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
477 )
478 };
479
480 Ok(SshPlatform { os, arch })
481 }
482
483 async fn shell(&self) -> String {
484 match self.run_command("sh", &["-c", "echo $SHELL"]).await {
485 Ok(shell) => shell.trim().to_owned(),
486 Err(e) => {
487 log::error!("Failed to get shell: {e}");
488 "sh".to_owned()
489 }
490 }
491 }
492}
493
/// Number of missed heartbeats at which `handle_heartbeat_result` starts a reconnect.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Delay of the keepalive timer between heartbeat checks.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout passed to `ping`/`resync` when checking that the server responds.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Number of reconnect attempts after which the client gives up and moves to
/// `State::ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
499
/// Internal connection state machine of `SshRemoteClient`.
///
/// Variants that own a connection also own the delegate; the "live" variants
/// additionally own the background tasks, so replacing the state drops them.
enum State {
    /// Initial state while the first connection is being set up.
    Connecting,
    /// Connection established and responding.
    Connected {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Task returned by `SshRemoteClient::monitor`.
        multiplex_task: Task<Result<()>>,
        /// Task returned by `SshRemoteClient::heartbeat`.
        heartbeat_task: Task<Result<()>>,
    },
    /// Connection still held, but one or more heartbeats were missed.
    HeartbeatMissed {
        /// Count of missed heartbeats, maintained by `State::heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; keeps what is needed to retry.
    ReconnectFailed {
        ssh_connection: Arc<dyn RemoteConnection>,
        delegate: Arc<dyn SshClientDelegate>,

        /// Error of the failed attempt.
        error: anyhow::Error,
        /// Number of attempts made so far.
        attempts: usize,
    },
    /// All reconnect attempts were used up.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
529
530impl fmt::Display for State {
531 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
532 match self {
533 Self::Connecting => write!(f, "connecting"),
534 Self::Connected { .. } => write!(f, "connected"),
535 Self::Reconnecting => write!(f, "reconnecting"),
536 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
537 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
538 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
539 Self::ServerNotRunning { .. } => write!(f, "server not running"),
540 }
541 }
542}
543
544impl State {
545 fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
546 match self {
547 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
548 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
549 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
550 _ => None,
551 }
552 }
553
554 fn can_reconnect(&self) -> bool {
555 match self {
556 Self::Connected { .. }
557 | Self::HeartbeatMissed { .. }
558 | Self::ReconnectFailed { .. } => true,
559 State::Connecting
560 | State::Reconnecting
561 | State::ReconnectExhausted
562 | State::ServerNotRunning => false,
563 }
564 }
565
566 fn is_reconnect_failed(&self) -> bool {
567 matches!(self, Self::ReconnectFailed { .. })
568 }
569
570 fn is_reconnect_exhausted(&self) -> bool {
571 matches!(self, Self::ReconnectExhausted { .. })
572 }
573
574 fn is_server_not_running(&self) -> bool {
575 matches!(self, Self::ServerNotRunning)
576 }
577
578 fn is_reconnecting(&self) -> bool {
579 matches!(self, Self::Reconnecting { .. })
580 }
581
582 fn heartbeat_recovered(self) -> Self {
583 match self {
584 Self::HeartbeatMissed {
585 ssh_connection,
586 delegate,
587 multiplex_task,
588 heartbeat_task,
589 ..
590 } => Self::Connected {
591 ssh_connection,
592 delegate,
593 multiplex_task,
594 heartbeat_task,
595 },
596 _ => self,
597 }
598 }
599
600 fn heartbeat_missed(self) -> Self {
601 match self {
602 Self::Connected {
603 ssh_connection,
604 delegate,
605 multiplex_task,
606 heartbeat_task,
607 } => Self::HeartbeatMissed {
608 missed_heartbeats: 1,
609 ssh_connection,
610 delegate,
611 multiplex_task,
612 heartbeat_task,
613 },
614 Self::HeartbeatMissed {
615 missed_heartbeats,
616 ssh_connection,
617 delegate,
618 multiplex_task,
619 heartbeat_task,
620 } => Self::HeartbeatMissed {
621 missed_heartbeats: missed_heartbeats + 1,
622 ssh_connection,
623 delegate,
624 multiplex_task,
625 heartbeat_task,
626 },
627 _ => self,
628 }
629 }
630}
631
/// The state of the ssh connection.
///
/// This is the public, data-free view of the internal `State`; see the
/// `From<&State>` impl for the mapping.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is established.
    Connected,
    /// The connection is held but heartbeats have been missed.
    HeartbeatMissed,
    /// A reconnect is in progress, or failed and has not yet been given up on.
    Reconnecting,
    /// There is no connection (attempts exhausted, server not running, or no state).
    Disconnected,
}
641
642impl From<&State> for ConnectionState {
643 fn from(value: &State) -> Self {
644 match value {
645 State::Connecting => Self::Connecting,
646 State::Connected { .. } => Self::Connected,
647 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
648 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
649 State::ReconnectExhausted => Self::Disconnected,
650 State::ServerNotRunning => Self::Disconnected,
651 }
652 }
653}
654
/// Client side of an ssh remote connection, including its heartbeat and
/// reconnect state machine.
pub struct SshRemoteClient {
    /// Channel-based RPC client used for pings and requests.
    client: Arc<ChannelClient>,
    /// Identifier string passed to `start_proxy` on connect and reconnect.
    unique_identifier: String,
    /// Options this client was created with.
    connection_options: SshConnectionOptions,
    /// Path style reported by the connection at creation time.
    path_style: PathStyle,
    /// Current state; `None` after the state has been taken (e.g. by
    /// `shutdown_processes`).
    state: Arc<Mutex<Option<State>>>,
}
662
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted by `set_state` when the state becomes `ReconnectExhausted` or
    /// `ServerNotRunning`.
    Disconnected,
}

// Lets `SshRemoteClient` emit `SshRemoteEvent`s through its gpui context.
impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
669
/// Identifies the socket on the remote server so that reconnects
/// can re-join the same project.
pub enum ConnectionIdentifier {
    /// Identifier drawn from a process-wide counter (see `ConnectionIdentifier::setup`).
    Setup(u64),
    /// Identifier derived from a workspace id.
    Workspace(i64),
}
676
/// Process-wide counter backing `ConnectionIdentifier::setup`; starts at 1.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
678
679impl ConnectionIdentifier {
680 pub fn setup() -> Self {
681 Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
682 }
683
684 // This string gets used in a socket name, and so must be relatively short.
685 // The total length of:
686 // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
687 // Must be less than about 100 characters
688 // https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars
689 // So our strings should be at most 20 characters or so.
690 fn to_string(&self, cx: &App) -> String {
691 let identifier_prefix = match ReleaseChannel::global(cx) {
692 ReleaseChannel::Stable => "".to_string(),
693 release_channel => format!("{}-", release_channel.dev_name()),
694 };
695 match self {
696 Self::Setup(setup_id) => format!("{identifier_prefix}setup-{setup_id}"),
697 Self::Workspace(workspace_id) => {
698 format!("{identifier_prefix}workspace-{workspace_id}",)
699 }
700 }
701 }
702}
703
704impl SshRemoteClient {
    /// Connects to the remote described by `connection_options` and returns
    /// the client entity once the server has answered a first ping.
    ///
    /// The returned task resolves to:
    /// * `Ok(Some(entity))` when the connection was established,
    /// * `Ok(None)` when the `cancellation` receiver resolved first,
    /// * `Err(..)` when connecting, starting the proxy, or the initial ping failed.
    pub fn new(
        unique_identifier: ConnectionIdentifier,
        connection_options: SshConnectionOptions,
        cancellation: oneshot::Receiver<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut App,
    ) -> Task<Result<Option<Entity<Self>>>> {
        let unique_identifier = unique_identifier.to_string(cx);
        cx.spawn(async move |cx| {
            let success = Box::pin(async move {
                // Channels wiring the RPC client to the proxy I/O task.
                let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
                let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
                let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

                let client =
                    cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;

                // Obtain a connection from the global pool (created on demand).
                let ssh_connection = cx
                    .update(|cx| {
                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
                            pool.connect(connection_options.clone(), &delegate, cx)
                        })
                    })?
                    .await
                    .map_err(|e| e.cloned())?;

                let path_style = ssh_connection.path_style();
                let this = cx.new(|_| Self {
                    client: client.clone(),
                    unique_identifier: unique_identifier.clone(),
                    connection_options,
                    path_style,
                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                })?;

                // `false`: this is an initial connection, not a reconnect
                // (the reconnect path passes `true` in the same position).
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    false,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );

                let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);

                // Only consider the connection established once the server
                // answers a ping.
                if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                    log::error!("failed to establish connection: {}", error);
                    return Err(error);
                }

                let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, cx);

                this.update(cx, |this, _| {
                    *this.state.lock() = Some(State::Connected {
                        ssh_connection,
                        delegate,
                        multiplex_task,
                        heartbeat_task,
                    });
                })?;

                Ok(Some(this))
            });

            // Race connection setup against cancellation; whichever finishes
            // first decides the result and the other future is dropped.
            select! {
                _ = cancellation.fuse() => {
                    Ok(None)
                }
                result = success.fuse() => result
            }
        })
    }
779
780 pub fn proto_client_from_channels(
781 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
782 outgoing_tx: mpsc::UnboundedSender<Envelope>,
783 cx: &App,
784 name: &'static str,
785 ) -> AnyProtoClient {
786 ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
787 }
788
    /// Tears down the connection, optionally sending `shutdown_request` first.
    ///
    /// The current state is taken out of the client (leaving `None`)
    /// immediately. Returns `None` when there was no state or when the state
    /// was anything other than `Connected` (that state is dropped). Otherwise
    /// returns a future that performs the actual shutdown when awaited.
    pub fn shutdown_processes<T: RequestMessage>(
        &self,
        shutdown_request: Option<T>,
        executor: BackgroundExecutor,
    ) -> Option<impl Future<Output = ()> + use<T>> {
        let state = self.state.lock().take()?;
        log::info!("shutting down ssh processes");

        let State::Connected {
            multiplex_task,
            heartbeat_task,
            ssh_connection,
            delegate,
        } = state
        else {
            return None;
        };

        let client = self.client.clone();

        Some(async move {
            if let Some(shutdown_request) = shutdown_request {
                client.send(shutdown_request).log_err();
                // We wait 50ms instead of waiting for a response, because
                // waiting for a response would require us to wait on the main thread
                // which we want to avoid in an `on_app_quit` callback.
                executor.timer(Duration::from_millis(50)).await;
            }

            // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
            // child of master_process.
            drop(multiplex_task);
            // Now drop the rest of state, which kills master process.
            drop(heartbeat_task);
            drop(ssh_connection);
            drop(delegate);
        })
    }
827
    /// Starts a reconnect attempt in the background.
    ///
    /// Only allowed from states for which `State::can_reconnect` is true.
    /// The current state is replaced by `Reconnecting`, the old connection is
    /// killed, a new connection and proxy are set up, and the client is
    /// resynced. On failure the state becomes `ReconnectFailed` and another
    /// attempt is started, until `MAX_RECONNECT_ATTEMPTS` is exceeded and the
    /// state becomes `ReconnectExhausted`.
    ///
    /// # Errors
    ///
    /// Returns an error immediately when the current state does not allow
    /// reconnecting (or no state is set). Failures of the attempt itself are
    /// reported through the state, not through the return value.
    fn reconnect(&mut self, cx: &mut Context<Self>) -> Result<()> {
        let mut lock = self.state.lock();

        let can_reconnect = lock
            .as_ref()
            .map(|state| state.can_reconnect())
            .unwrap_or(false);
        if !can_reconnect {
            log::info!("aborting reconnect, because not in state that allows reconnecting");
            let error = if let Some(state) = lock.as_ref() {
                format!("invalid state, cannot reconnect while in state {state}")
            } else {
                "no state set".to_string()
            };
            anyhow::bail!(error);
        }

        // `can_reconnect` was true, so a state is present.
        let state = lock.take().unwrap();
        let (attempts, ssh_connection, delegate) = match state {
            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
            }
            | State::HeartbeatMissed {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task,
                ..
            } => {
                // Coming from a live connection: stop its background tasks and
                // start counting attempts from zero.
                drop(multiplex_task);
                drop(heartbeat_task);
                (0, ssh_connection, delegate)
            }
            State::ReconnectFailed {
                attempts,
                ssh_connection,
                delegate,
                ..
            } => (attempts, ssh_connection, delegate),
            // Excluded by the `can_reconnect` check above.
            State::Connecting
            | State::Reconnecting
            | State::ReconnectExhausted
            | State::ServerNotRunning => unreachable!(),
        };

        let attempts = attempts + 1;
        if attempts > MAX_RECONNECT_ATTEMPTS {
            log::error!(
                "Failed to reconnect to after {} attempts, giving up",
                MAX_RECONNECT_ATTEMPTS
            );
            // `set_state` takes the lock itself, so release it first.
            drop(lock);
            self.set_state(State::ReconnectExhausted, cx);
            return Ok(());
        }
        drop(lock);

        self.set_state(State::Reconnecting, cx);

        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

        let unique_identifier = self.unique_identifier.clone();
        let client = self.client.clone();
        // Performs one attempt and resolves to the state it ended in
        // (`Connected` or `ReconnectFailed`).
        let reconnect_task = cx.spawn(async move |this, cx| {
            macro_rules! failed {
                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                    return State::ReconnectFailed {
                        error: anyhow!($error),
                        attempts: $attempts,
                        ssh_connection: $ssh_connection,
                        delegate: $delegate,
                    };
                };
            }

            if let Err(error) = ssh_connection
                .kill()
                .await
                .context("Failed to kill ssh process")
            {
                failed!(error, attempts, ssh_connection, delegate);
            };

            let connection_options = ssh_connection.connection_options();

            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

            // On failure the *old* connection is kept in `ReconnectFailed` so
            // the next attempt can start from it.
            let (ssh_connection, io_task) = match async {
                let ssh_connection = cx
                    .update_global(|pool: &mut ConnectionPool, cx| {
                        pool.connect(connection_options, &delegate, cx)
                    })?
                    .await
                    .map_err(|error| error.cloned())?;

                // `true`: this is a reconnect (the initial connection passes
                // `false` in the same position).
                let io_task = ssh_connection.start_proxy(
                    unique_identifier,
                    true,
                    incoming_tx,
                    outgoing_rx,
                    connection_activity_tx,
                    delegate.clone(),
                    cx,
                );
                anyhow::Ok((ssh_connection, io_task))
            }
            .await
            {
                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                Err(error) => {
                    failed!(error, attempts, ssh_connection, delegate);
                }
            };

            let multiplex_task = Self::monitor(this.clone(), io_task, cx);
            client.reconnect(incoming_rx, outgoing_tx, cx);

            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
                failed!(error, attempts, ssh_connection, delegate);
            };

            State::Connected {
                ssh_connection,
                delegate,
                multiplex_task,
                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, cx),
            }
        });

        cx.spawn(async move |this, cx| {
            let new_state = reconnect_task.await;
            this.update(cx, |this, cx| {
                // Only apply the result if we are still in `Reconnecting`;
                // otherwise something else changed the state in the meantime.
                this.try_set_state(cx, |old_state| {
                    if old_state.is_reconnecting() {
                        match &new_state {
                            State::Connecting
                            | State::Reconnecting
                            | State::HeartbeatMissed { .. }
                            | State::ServerNotRunning => {}
                            State::Connected { .. } => {
                                log::info!("Successfully reconnected");
                            }
                            State::ReconnectFailed {
                                error, attempts, ..
                            } => {
                                log::error!(
                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                    attempts,
                                    error
                                );
                            }
                            State::ReconnectExhausted => {
                                log::error!("Reconnect attempt failed and all attempts exhausted");
                            }
                        }
                        Some(new_state)
                    } else {
                        None
                    }
                });

                // A failed attempt immediately triggers the next one.
                if this.state_is(State::is_reconnect_failed) {
                    this.reconnect(cx)
                } else if this.state_is(State::is_reconnect_exhausted) {
                    Ok(())
                } else {
                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                    Ok(())
                }
            })
        })
        .detach_and_log_err(cx);

        Ok(())
    }
1008
    /// Spawns the heartbeat task for an established connection.
    ///
    /// The task waits for either connection activity or the keepalive timer.
    /// Activity counts as proof of life; when the timer fires, the server is
    /// pinged. Changes in the missed-heartbeat count are reported through
    /// `handle_heartbeat_result`. The task ends when the activity channel is
    /// closed, when `handle_heartbeat_result` asks to stop, or when the client
    /// entity is gone.
    ///
    /// Returns an already-failed task if the client entity cannot be read.
    fn heartbeat(
        this: WeakEntity<Self>,
        mut connection_activity_rx: mpsc::Receiver<()>,
        cx: &mut AsyncApp,
    ) -> Task<Result<()>> {
        let Ok(client) = this.read_with(cx, |this, _| this.client.clone()) else {
            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
        };

        cx.spawn(async move |cx| {
            let mut missed_heartbeats = 0;

            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            loop {
                // Biased: connection activity is checked before the timer.
                select_biased! {
                    result = connection_activity_rx.next().fuse() => {
                        if result.is_none() {
                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
                            return Ok(());
                        }

                        // Activity after missed heartbeats means we recovered.
                        if missed_heartbeats != 0 {
                            missed_heartbeats = 0;
                            let _ =this.update(cx, |this, cx| {
                                this.handle_heartbeat_result(missed_heartbeats, cx)
                            })?;
                        }
                    }
                    _ = keepalive_timer => {
                        log::debug!("Sending heartbeat to server...");

                        // Activity arriving while the ping is in flight also
                        // counts as a successful heartbeat.
                        let result = select_biased! {
                            _ = connection_activity_rx.next().fuse() => {
                                Ok(())
                            }
                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
                                ping_result
                            }
                        };

                        if result.is_err() {
                            missed_heartbeats += 1;
                            log::warn!(
                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
                                HEARTBEAT_TIMEOUT,
                                missed_heartbeats,
                                MAX_MISSED_HEARTBEATS
                            );
                        } else if missed_heartbeats != 0 {
                            missed_heartbeats = 0;
                        } else {
                            // Healthy before and healthy now: nothing to report.
                            continue;
                        }

                        let result = this.update(cx, |this, cx| {
                            this.handle_heartbeat_result(missed_heartbeats, cx)
                        })?;
                        if result.is_break() {
                            return Ok(());
                        }
                    }
                }

                // Restart the keepalive interval after every handled event.
                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
            }
        })
    }
1078
1079 fn handle_heartbeat_result(
1080 &mut self,
1081 missed_heartbeats: usize,
1082 cx: &mut Context<Self>,
1083 ) -> ControlFlow<()> {
1084 let state = self.state.lock().take().unwrap();
1085 let next_state = if missed_heartbeats > 0 {
1086 state.heartbeat_missed()
1087 } else {
1088 state.heartbeat_recovered()
1089 };
1090
1091 self.set_state(next_state, cx);
1092
1093 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
1094 log::error!(
1095 "Missed last {} heartbeats. Reconnecting...",
1096 missed_heartbeats
1097 );
1098
1099 self.reconnect(cx)
1100 .context("failed to start reconnect process after missing heartbeats")
1101 .log_err();
1102 ControlFlow::Break(())
1103 } else {
1104 ControlFlow::Continue(())
1105 }
1106 }
1107
1108 fn monitor(
1109 this: WeakEntity<Self>,
1110 io_task: Task<Result<i32>>,
1111 cx: &AsyncApp,
1112 ) -> Task<Result<()>> {
1113 cx.spawn(async move |cx| {
1114 let result = io_task.await;
1115
1116 match result {
1117 Ok(exit_code) => {
1118 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1119 match error {
1120 ProxyLaunchError::ServerNotRunning => {
1121 log::error!("failed to reconnect because server is not running");
1122 this.update(cx, |this, cx| {
1123 this.set_state(State::ServerNotRunning, cx);
1124 })?;
1125 }
1126 }
1127 } else if exit_code > 0 {
1128 log::error!("proxy process terminated unexpectedly");
1129 this.update(cx, |this, cx| {
1130 this.reconnect(cx).ok();
1131 })?;
1132 }
1133 }
1134 Err(error) => {
1135 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1136 this.update(cx, |this, cx| {
1137 this.reconnect(cx).ok();
1138 })?;
1139 }
1140 }
1141
1142 Ok(())
1143 })
1144 }
1145
1146 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1147 self.state.lock().as_ref().is_some_and(check)
1148 }
1149
1150 fn try_set_state(&self, cx: &mut Context<Self>, map: impl FnOnce(&State) -> Option<State>) {
1151 let mut lock = self.state.lock();
1152 let new_state = lock.as_ref().and_then(map);
1153
1154 if let Some(new_state) = new_state {
1155 lock.replace(new_state);
1156 cx.notify();
1157 }
1158 }
1159
1160 fn set_state(&self, state: State, cx: &mut Context<Self>) {
1161 log::info!("setting state to '{}'", &state);
1162
1163 let is_reconnect_exhausted = state.is_reconnect_exhausted();
1164 let is_server_not_running = state.is_server_not_running();
1165 self.state.lock().replace(state);
1166
1167 if is_reconnect_exhausted || is_server_not_running {
1168 cx.emit(SshRemoteEvent::Disconnected);
1169 }
1170 cx.notify();
1171 }
1172
1173 pub fn ssh_info(&self) -> Option<SshInfo> {
1174 self.state
1175 .lock()
1176 .as_ref()
1177 .and_then(|state| state.ssh_connection())
1178 .map(|ssh_connection| SshInfo {
1179 args: ssh_connection.ssh_args(),
1180 path_style: ssh_connection.path_style(),
1181 shell: ssh_connection.shell(),
1182 })
1183 }
1184
1185 pub fn upload_directory(
1186 &self,
1187 src_path: PathBuf,
1188 dest_path: RemotePathBuf,
1189 cx: &App,
1190 ) -> Task<Result<()>> {
1191 let state = self.state.lock();
1192 let Some(connection) = state.as_ref().and_then(|state| state.ssh_connection()) else {
1193 return Task::ready(Err(anyhow!("no ssh connection")));
1194 };
1195 connection.upload_directory(src_path, dest_path, cx)
1196 }
1197
1198 pub fn proto_client(&self) -> AnyProtoClient {
1199 self.client.clone().into()
1200 }
1201
1202 pub fn connection_string(&self) -> String {
1203 self.connection_options.connection_string()
1204 }
1205
1206 pub fn connection_options(&self) -> SshConnectionOptions {
1207 self.connection_options.clone()
1208 }
1209
1210 pub fn connection_state(&self) -> ConnectionState {
1211 self.state
1212 .lock()
1213 .as_ref()
1214 .map(ConnectionState::from)
1215 .unwrap_or(ConnectionState::Disconnected)
1216 }
1217
1218 pub fn is_disconnected(&self) -> bool {
1219 self.connection_state() == ConnectionState::Disconnected
1220 }
1221
1222 pub fn path_style(&self) -> PathStyle {
1223 self.path_style
1224 }
1225
1226 #[cfg(any(test, feature = "test-support"))]
1227 pub fn simulate_disconnect(&self, client_cx: &mut App) -> Task<()> {
1228 let opts = self.connection_options();
1229 client_cx.spawn(async move |cx| {
1230 let connection = cx
1231 .update_global(|c: &mut ConnectionPool, _| {
1232 if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
1233 c.clone()
1234 } else {
1235 panic!("missing test connection")
1236 }
1237 })
1238 .unwrap()
1239 .await
1240 .unwrap();
1241
1242 connection.simulate_disconnect(cx);
1243 })
1244 }
1245
1246 #[cfg(any(test, feature = "test-support"))]
1247 pub fn fake_server(
1248 client_cx: &mut gpui::TestAppContext,
1249 server_cx: &mut gpui::TestAppContext,
1250 ) -> (SshConnectionOptions, AnyProtoClient) {
1251 let port = client_cx
1252 .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
1253 let opts = SshConnectionOptions {
1254 host: "<fake>".to_string(),
1255 port: Some(port),
1256 ..Default::default()
1257 };
1258 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
1259 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
1260 let server_client =
1261 server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
1262 let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
1263 connection_options: opts.clone(),
1264 server_cx: fake::SendableCx::new(server_cx),
1265 server_channel: server_client.clone(),
1266 });
1267
1268 client_cx.update(|cx| {
1269 cx.update_default_global(|c: &mut ConnectionPool, cx| {
1270 c.connections.insert(
1271 opts.clone(),
1272 ConnectionPoolEntry::Connecting(
1273 cx.background_spawn({
1274 let connection = connection.clone();
1275 async move { Ok(connection.clone()) }
1276 })
1277 .shared(),
1278 ),
1279 );
1280 })
1281 });
1282
1283 (opts, server_client.into())
1284 }
1285
1286 #[cfg(any(test, feature = "test-support"))]
1287 pub async fn fake_client(
1288 opts: SshConnectionOptions,
1289 client_cx: &mut gpui::TestAppContext,
1290 ) -> Entity<Self> {
1291 let (_tx, rx) = oneshot::channel();
1292 client_cx
1293 .update(|cx| {
1294 Self::new(
1295 ConnectionIdentifier::setup(),
1296 opts,
1297 rx,
1298 Arc::new(fake::Delegate),
1299 cx,
1300 )
1301 })
1302 .await
1303 .unwrap()
1304 .unwrap()
1305 }
1306}
1307
/// State of the pooled connection for one set of connection options.
enum ConnectionPoolEntry {
    /// A connection attempt is in flight. The task is shared so that several
    /// callers can await the same attempt.
    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
    /// A connection was established. Only a weak reference is stored, so the
    /// pool by itself does not keep the connection alive.
    Connected(Weak<dyn RemoteConnection>),
}
1312
/// Registry of remote connections, keyed by the options used to create them.
#[derive(Default)]
struct ConnectionPool {
    /// One entry per distinct set of connection options.
    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
}

// Marker impl so the pool can be stored and looked up as a gpui global.
impl Global for ConnectionPool {}
1319
1320impl ConnectionPool {
1321 pub fn connect(
1322 &mut self,
1323 opts: SshConnectionOptions,
1324 delegate: &Arc<dyn SshClientDelegate>,
1325 cx: &mut App,
1326 ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
1327 let connection = self.connections.get(&opts);
1328 match connection {
1329 Some(ConnectionPoolEntry::Connecting(task)) => {
1330 let delegate = delegate.clone();
1331 cx.spawn(async move |cx| {
1332 delegate.set_status(Some("Waiting for existing connection attempt"), cx);
1333 })
1334 .detach();
1335 return task.clone();
1336 }
1337 Some(ConnectionPoolEntry::Connected(ssh)) => {
1338 if let Some(ssh) = ssh.upgrade()
1339 && !ssh.has_been_killed()
1340 {
1341 return Task::ready(Ok(ssh)).shared();
1342 }
1343 self.connections.remove(&opts);
1344 }
1345 None => {}
1346 }
1347
1348 let task = cx
1349 .spawn({
1350 let opts = opts.clone();
1351 let delegate = delegate.clone();
1352 async move |cx| {
1353 let connection = SshRemoteConnection::new(opts.clone(), delegate, cx)
1354 .await
1355 .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
1356
1357 cx.update_global(|pool: &mut Self, _| {
1358 debug_assert!(matches!(
1359 pool.connections.get(&opts),
1360 Some(ConnectionPoolEntry::Connecting(_))
1361 ));
1362 match connection {
1363 Ok(connection) => {
1364 pool.connections.insert(
1365 opts.clone(),
1366 ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
1367 );
1368 Ok(connection)
1369 }
1370 Err(error) => {
1371 pool.connections.remove(&opts);
1372 Err(Arc::new(error))
1373 }
1374 }
1375 })?
1376 }
1377 })
1378 .shared();
1379
1380 self.connections
1381 .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
1382 task
1383 }
1384}
1385
1386impl From<SshRemoteClient> for AnyProtoClient {
1387 fn from(client: SshRemoteClient) -> Self {
1388 AnyProtoClient::new(client.client)
1389 }
1390}
1391
/// Abstraction over an established connection to a remote host, implemented
/// by the real SSH connection and (in tests) by a fake.
#[async_trait(?Send)]
trait RemoteConnection: Send + Sync {
    /// Starts the proxy for this connection and returns a task resolving to
    /// the proxy's exit code. Incoming envelopes are delivered on
    /// `incoming_tx`, outgoing ones are taken from `outgoing_rx`, and
    /// `connection_activity_tx` is signalled when traffic is observed.
    fn start_proxy(
        &self,
        unique_identifier: String,
        reconnect: bool,
        incoming_tx: UnboundedSender<Envelope>,
        outgoing_rx: UnboundedReceiver<Envelope>,
        connection_activity_tx: Sender<()>,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Task<Result<i32>>;
    /// Uploads the local directory `src_path` to `dest_path` on the remote.
    fn upload_directory(
        &self,
        src_path: PathBuf,
        dest_path: RemotePathBuf,
        cx: &App,
    ) -> Task<Result<()>>;
    /// Tears down the connection.
    async fn kill(&self) -> Result<()>;
    /// Returns `true` once the connection has been killed.
    fn has_been_killed(&self) -> bool;
    /// Returns the arguments needed to run further `ssh` commands over this
    /// connection.
    ///
    /// On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
    /// On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
    /// reuse the already-established master connection.
    fn ssh_args(&self) -> SshArgs;
    /// Returns the options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
    /// Returns the path style used by the remote host.
    fn path_style(&self) -> PathStyle;
    /// Returns the shell reported for the remote host.
    fn shell(&self) -> String;

    /// Test hook for simulating a dropped connection; a no-op by default.
    #[cfg(any(test, feature = "test-support"))]
    fn simulate_disconnect(&self, _: &AsyncApp) {}
}
1422
/// A real SSH connection backed by a long-lived master `ssh` process.
struct SshRemoteConnection {
    /// Handle used to run further commands over the connection.
    socket: SshSocket,
    /// The master `ssh` process; `None` once the connection has been killed.
    master_process: Mutex<Option<Child>>,
    /// Location of the server binary on the remote; set at the end of `new`.
    remote_binary_path: Option<RemotePathBuf>,
    /// Platform reported by the remote host.
    ssh_platform: SshPlatform,
    /// Path style derived from the remote platform's OS.
    ssh_path_style: PathStyle,
    /// Shell reported by the remote host.
    ssh_shell: String,
    /// Temporary directory created for this session (on non-Windows hosts it
    /// contains the control socket); kept so it lives as long as the connection.
    _temp_dir: TempDir,
}
1432
1433#[async_trait(?Send)]
1434impl RemoteConnection for SshRemoteConnection {
1435 async fn kill(&self) -> Result<()> {
1436 let Some(mut process) = self.master_process.lock().take() else {
1437 return Ok(());
1438 };
1439 process.kill().ok();
1440 process.status().await?;
1441 Ok(())
1442 }
1443
1444 fn has_been_killed(&self) -> bool {
1445 self.master_process.lock().is_none()
1446 }
1447
1448 fn ssh_args(&self) -> SshArgs {
1449 self.socket.ssh_args()
1450 }
1451
1452 fn connection_options(&self) -> SshConnectionOptions {
1453 self.socket.connection_options.clone()
1454 }
1455
1456 fn shell(&self) -> String {
1457 self.ssh_shell.clone()
1458 }
1459
1460 fn upload_directory(
1461 &self,
1462 src_path: PathBuf,
1463 dest_path: RemotePathBuf,
1464 cx: &App,
1465 ) -> Task<Result<()>> {
1466 let mut command = util::command::new_smol_command("scp");
1467 let output = self
1468 .socket
1469 .ssh_options(&mut command)
1470 .args(
1471 self.socket
1472 .connection_options
1473 .port
1474 .map(|port| vec!["-P".to_string(), port.to_string()])
1475 .unwrap_or_default(),
1476 )
1477 .arg("-C")
1478 .arg("-r")
1479 .arg(&src_path)
1480 .arg(format!(
1481 "{}:{}",
1482 self.socket.connection_options.scp_url(),
1483 dest_path
1484 ))
1485 .output();
1486
1487 cx.background_spawn(async move {
1488 let output = output.await?;
1489
1490 anyhow::ensure!(
1491 output.status.success(),
1492 "failed to upload directory {} -> {}: {}",
1493 src_path.display(),
1494 dest_path.to_string(),
1495 String::from_utf8_lossy(&output.stderr)
1496 );
1497
1498 Ok(())
1499 })
1500 }
1501
1502 fn start_proxy(
1503 &self,
1504 unique_identifier: String,
1505 reconnect: bool,
1506 incoming_tx: UnboundedSender<Envelope>,
1507 outgoing_rx: UnboundedReceiver<Envelope>,
1508 connection_activity_tx: Sender<()>,
1509 delegate: Arc<dyn SshClientDelegate>,
1510 cx: &mut AsyncApp,
1511 ) -> Task<Result<i32>> {
1512 delegate.set_status(Some("Starting proxy"), cx);
1513
1514 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
1515 return Task::ready(Err(anyhow!("Remote binary path not set")));
1516 };
1517
1518 let mut start_proxy_command = shell_script!(
1519 "exec {binary_path} proxy --identifier {identifier}",
1520 binary_path = &remote_binary_path.to_string(),
1521 identifier = &unique_identifier,
1522 );
1523
1524 for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
1525 if let Some(value) = std::env::var(env_var).ok() {
1526 start_proxy_command = format!(
1527 "{}={} {} ",
1528 env_var,
1529 shlex::try_quote(&value).unwrap(),
1530 start_proxy_command,
1531 );
1532 }
1533 }
1534
1535 if reconnect {
1536 start_proxy_command.push_str(" --reconnect");
1537 }
1538
1539 let ssh_proxy_process = match self
1540 .socket
1541 .ssh_command("sh", &["-c", &start_proxy_command])
1542 // IMPORTANT: we kill this process when we drop the task that uses it.
1543 .kill_on_drop(true)
1544 .spawn()
1545 {
1546 Ok(process) => process,
1547 Err(error) => {
1548 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
1549 }
1550 };
1551
1552 Self::multiplex(
1553 ssh_proxy_process,
1554 incoming_tx,
1555 outgoing_rx,
1556 connection_activity_tx,
1557 cx,
1558 )
1559 }
1560
1561 fn path_style(&self) -> PathStyle {
1562 self.ssh_path_style
1563 }
1564}
1565
1566impl SshRemoteConnection {
    /// Establishes a new SSH connection described by `connection_options`.
    ///
    /// Steps, in order:
    /// 1. Create a temporary session directory and an askpass session that
    ///    forwards password prompts to `delegate`.
    /// 2. Spawn the master `ssh` process and wait until it closes stdout (or
    ///    until the askpass session ends first, which is treated as failure).
    /// 3. Fail with the process's stderr if the master process already exited.
    /// 4. Query the remote platform and shell, then make sure the remote
    ///    server binary is available.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the steps above fails, including user
    /// cancellation and askpass timeout.
    async fn new(
        connection_options: SshConnectionOptions,
        delegate: Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        use askpass::AskPassResult;

        delegate.set_status(Some("Connecting"), cx);

        let url = connection_options.ssh_url();

        let temp_dir = tempfile::Builder::new()
            .prefix("zed-ssh-session")
            .tempdir()?;
        // Password prompts from ssh are routed to the delegate.
        let askpass_delegate = askpass::AskPassDelegate::new(cx, {
            let delegate = delegate.clone();
            move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
        });

        let mut askpass =
            askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;

        // Start the master SSH process, which does not do anything except for establish
        // the connection and keep it open, allowing other ssh commands to reuse it
        // via a control socket.
        #[cfg(not(target_os = "windows"))]
        let socket_path = temp_dir.path().join("ssh.sock");

        let mut master_process = {
            // The trailing "-o" is completed by the `ControlPath=...` argument
            // appended below.
            #[cfg(not(target_os = "windows"))]
            let args = [
                "-N",
                "-o",
                "ControlPersist=no",
                "-o",
                "ControlMaster=yes",
                "-o",
            ];
            // On Windows, `ControlMaster` and `ControlPath` are not supported:
            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
            #[cfg(target_os = "windows")]
            let args = ["-N"];
            let mut master_process = util::command::new_smol_command("ssh");
            master_process
                .kill_on_drop(true)
                .stdin(Stdio::null())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .env("SSH_ASKPASS_REQUIRE", "force")
                .env("SSH_ASKPASS", askpass.script_path())
                .args(connection_options.additional_args())
                .args(args);
            #[cfg(not(target_os = "windows"))]
            master_process.arg(format!("ControlPath={}", socket_path.display()));
            master_process.arg(&url).spawn()?
        };
        // Wait for this ssh process to close its stdout, indicating that authentication
        // has completed.
        let mut stdout = master_process.stdout.take().unwrap();
        let mut output = Vec::new();

        // `select_biased!` polls the askpass branch first, so an askpass
        // outcome wins over stdout closing when both are ready.
        let result = select_biased! {
            result = askpass.run().fuse() => {
                match result {
                    AskPassResult::CancelledByUser => {
                        master_process.kill().ok();
                        anyhow::bail!("SSH connection canceled")
                    }
                    AskPassResult::Timedout => {
                        anyhow::bail!("connecting to host timed out")
                    }
                }
            }
            _ = stdout.read_to_end(&mut output).fuse() => {
                anyhow::Ok(())
            }
        };

        if let Err(e) = result {
            return Err(e.context("Failed to connect to host"));
        }

        // If the master process has already exited, the connection failed;
        // surface whatever it wrote to stderr.
        if master_process.try_status()?.is_some() {
            output.clear();
            let mut stderr = master_process.stderr.take().unwrap();
            stderr.read_to_end(&mut output).await?;

            let error_message = format!(
                "failed to connect: {}",
                String::from_utf8_lossy(&output).trim()
            );
            anyhow::bail!(error_message);
        }

        #[cfg(not(target_os = "windows"))]
        let socket = SshSocket::new(connection_options, socket_path)?;
        #[cfg(target_os = "windows")]
        let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
        drop(askpass);

        let ssh_platform = socket.platform().await?;
        let ssh_path_style = match ssh_platform.os {
            "windows" => PathStyle::Windows,
            _ => PathStyle::Posix,
        };
        let ssh_shell = socket.shell().await;

        // `remote_binary_path` is filled in below, once the binary is known
        // to exist on the remote.
        let mut this = Self {
            socket,
            master_process: Mutex::new(Some(master_process)),
            _temp_dir: temp_dir,
            remote_binary_path: None,
            ssh_path_style,
            ssh_platform,
            ssh_shell,
        };

        let (release_channel, version, commit) = cx.update(|cx| {
            (
                ReleaseChannel::global(cx),
                AppVersion::global(cx),
                AppCommitSha::try_global(cx),
            )
        })?;
        this.remote_binary_path = Some(
            this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
                .await?,
        );

        Ok(this)
    }
1699
1700 fn multiplex(
1701 mut ssh_proxy_process: Child,
1702 incoming_tx: UnboundedSender<Envelope>,
1703 mut outgoing_rx: UnboundedReceiver<Envelope>,
1704 mut connection_activity_tx: Sender<()>,
1705 cx: &AsyncApp,
1706 ) -> Task<Result<i32>> {
1707 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
1708 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
1709 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
1710
1711 let mut stdin_buffer = Vec::new();
1712 let mut stdout_buffer = Vec::new();
1713 let mut stderr_buffer = Vec::new();
1714 let mut stderr_offset = 0;
1715
1716 let stdin_task = cx.background_spawn(async move {
1717 while let Some(outgoing) = outgoing_rx.next().await {
1718 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
1719 }
1720 anyhow::Ok(())
1721 });
1722
1723 let stdout_task = cx.background_spawn({
1724 let mut connection_activity_tx = connection_activity_tx.clone();
1725 async move {
1726 loop {
1727 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
1728 let len = child_stdout.read(&mut stdout_buffer).await?;
1729
1730 if len == 0 {
1731 return anyhow::Ok(());
1732 }
1733
1734 if len < MESSAGE_LEN_SIZE {
1735 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
1736 }
1737
1738 let message_len = message_len_from_buffer(&stdout_buffer);
1739 let envelope =
1740 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
1741 .await?;
1742 connection_activity_tx.try_send(()).ok();
1743 incoming_tx.unbounded_send(envelope).ok();
1744 }
1745 }
1746 });
1747
1748 let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
1749 loop {
1750 stderr_buffer.resize(stderr_offset + 1024, 0);
1751
1752 let len = child_stderr
1753 .read(&mut stderr_buffer[stderr_offset..])
1754 .await?;
1755 if len == 0 {
1756 return anyhow::Ok(());
1757 }
1758
1759 stderr_offset += len;
1760 let mut start_ix = 0;
1761 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
1762 .iter()
1763 .position(|b| b == &b'\n')
1764 {
1765 let line_ix = start_ix + ix;
1766 let content = &stderr_buffer[start_ix..line_ix];
1767 start_ix = line_ix + 1;
1768 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
1769 record.log(log::logger())
1770 } else {
1771 eprintln!("(remote) {}", String::from_utf8_lossy(content));
1772 }
1773 }
1774 stderr_buffer.drain(0..start_ix);
1775 stderr_offset -= start_ix;
1776
1777 connection_activity_tx.try_send(()).ok();
1778 }
1779 });
1780
1781 cx.background_spawn(async move {
1782 let result = futures::select! {
1783 result = stdin_task.fuse() => {
1784 result.context("stdin")
1785 }
1786 result = stdout_task.fuse() => {
1787 result.context("stdout")
1788 }
1789 result = stderr_task.fuse() => {
1790 result.context("stderr")
1791 }
1792 };
1793
1794 let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
1795 match result {
1796 Ok(_) => Ok(status),
1797 Err(error) => Err(error),
1798 }
1799 })
1800 }
1801
    /// Makes sure a remote server binary matching this client is present on
    /// the remote host and returns its path.
    ///
    /// Strategies, tried in order:
    /// 1. Debug builds only: when `ZED_BUILD_REMOTE_SERVER` is set, build the
    ///    server locally, upload it and extract it.
    /// 2. Reuse an existing binary if running it with `version` succeeds.
    /// 3. Unless `upload_binary_over_ssh` is set, have the remote download the
    ///    binary itself; a failure here is logged and falls through.
    /// 4. Download the binary locally, upload it and extract it.
    ///
    /// # Errors
    ///
    /// Fails on the `Dev` channel when no binary exists and none is built
    /// locally, and whenever the final download/upload/extract step fails.
    // `unused` is allowed because some bindings are only read in debug builds.
    #[allow(unused)]
    async fn ensure_server_binary(
        &self,
        delegate: &Arc<dyn SshClientDelegate>,
        release_channel: ReleaseChannel,
        version: SemanticVersion,
        commit: Option<AppCommitSha>,
        cx: &mut AsyncApp,
    ) -> Result<RemotePathBuf> {
        // The version component of the binary name depends on the channel.
        let version_str = match release_channel {
            ReleaseChannel::Nightly => {
                let commit = commit.map(|s| s.full()).unwrap_or_default();
                format!("{}-{}", version, commit)
            }
            ReleaseChannel::Dev => "build".to_string(),
            _ => version.to_string(),
        };
        let binary_name = format!(
            "zed-remote-server-{}-{}",
            release_channel.dev_name(),
            version_str
        );
        let dst_path = RemotePathBuf::new(
            paths::remote_server_dir_relative().join(binary_name),
            self.ssh_path_style,
        );

        let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
        #[cfg(debug_assertions)]
        if let Some(build_remote_server) = build_remote_server {
            let src_path = self.build_local(build_remote_server, delegate, cx).await?;
            // The temporary name includes the local process id.
            let tmp_path = RemotePathBuf::new(
                paths::remote_server_dir_relative().join(format!(
                    "download-{}-{}",
                    std::process::id(),
                    src_path.file_name().unwrap().to_string_lossy()
                )),
                self.ssh_path_style,
            );
            self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
                .await?;
            self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
                .await?;
            return Ok(dst_path);
        }

        // An existing binary that can report its version is reused as-is.
        if self
            .socket
            .run_command(&dst_path.to_string(), &["version"])
            .await
            .is_ok()
        {
            return Ok(dst_path);
        }

        let wanted_version = cx.update(|cx| match release_channel {
            ReleaseChannel::Nightly => Ok(None),
            ReleaseChannel::Dev => {
                anyhow::bail!(
                    "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
                    dst_path
                )
            }
            _ => Ok(Some(AppVersion::global(cx))),
        })??;

        let tmp_path_gz = RemotePathBuf::new(
            PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
            self.ssh_path_style,
        );
        if !self.socket.connection_options.upload_binary_over_ssh
            && let Some((url, body)) = delegate
                .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
                .await?
        {
            match self
                .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
                .await
            {
                Ok(_) => {
                    self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
                        .await?;
                    return Ok(dst_path);
                }
                // A failed remote download is not fatal; fall back to uploading.
                Err(e) => {
                    log::error!(
                        "Failed to download binary on server, attempting to upload server: {}",
                        e
                    )
                }
            }
        }

        let src_path = delegate
            .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
            .await?;
        self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
            .await?;
        self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
            .await?;
        Ok(dst_path)
    }
1904
1905 async fn download_binary_on_server(
1906 &self,
1907 url: &str,
1908 body: &str,
1909 tmp_path_gz: &RemotePathBuf,
1910 delegate: &Arc<dyn SshClientDelegate>,
1911 cx: &mut AsyncApp,
1912 ) -> Result<()> {
1913 if let Some(parent) = tmp_path_gz.parent() {
1914 self.socket
1915 .run_command(
1916 "sh",
1917 &[
1918 "-c",
1919 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1920 ],
1921 )
1922 .await?;
1923 }
1924
1925 delegate.set_status(Some("Downloading remote development server on host"), cx);
1926
1927 match self
1928 .socket
1929 .run_command(
1930 "curl",
1931 &[
1932 "-f",
1933 "-L",
1934 "-X",
1935 "GET",
1936 "-H",
1937 "Content-Type: application/json",
1938 "-d",
1939 body,
1940 url,
1941 "-o",
1942 &tmp_path_gz.to_string(),
1943 ],
1944 )
1945 .await
1946 {
1947 Ok(_) => {}
1948 Err(e) => {
1949 if self.socket.run_command("which", &["curl"]).await.is_ok() {
1950 return Err(e);
1951 }
1952
1953 match self
1954 .socket
1955 .run_command(
1956 "wget",
1957 &[
1958 "--method=GET",
1959 "--header=Content-Type: application/json",
1960 "--body-data",
1961 body,
1962 url,
1963 "-O",
1964 &tmp_path_gz.to_string(),
1965 ],
1966 )
1967 .await
1968 {
1969 Ok(_) => {}
1970 Err(e) => {
1971 if self.socket.run_command("which", &["wget"]).await.is_ok() {
1972 return Err(e);
1973 } else {
1974 anyhow::bail!("Neither curl nor wget is available");
1975 }
1976 }
1977 }
1978 }
1979 }
1980
1981 Ok(())
1982 }
1983
1984 async fn upload_local_server_binary(
1985 &self,
1986 src_path: &Path,
1987 tmp_path_gz: &RemotePathBuf,
1988 delegate: &Arc<dyn SshClientDelegate>,
1989 cx: &mut AsyncApp,
1990 ) -> Result<()> {
1991 if let Some(parent) = tmp_path_gz.parent() {
1992 self.socket
1993 .run_command(
1994 "sh",
1995 &[
1996 "-c",
1997 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
1998 ],
1999 )
2000 .await?;
2001 }
2002
2003 let src_stat = fs::metadata(&src_path).await?;
2004 let size = src_stat.len();
2005
2006 let t0 = Instant::now();
2007 delegate.set_status(Some("Uploading remote development server"), cx);
2008 log::info!(
2009 "uploading remote development server to {:?} ({}kb)",
2010 tmp_path_gz,
2011 size / 1024
2012 );
2013 self.upload_file(src_path, tmp_path_gz)
2014 .await
2015 .context("failed to upload server binary")?;
2016 log::info!("uploaded remote development server in {:?}", t0.elapsed());
2017 Ok(())
2018 }
2019
    /// Moves the uploaded or downloaded file at `tmp_path` into place at
    /// `dst_path` on the remote, making it executable.
    ///
    /// When `tmp_path` ends in `.gz`, the file is first decompressed with
    /// `gunzip -f`; otherwise it is moved as-is.
    async fn extract_server_binary(
        &self,
        dst_path: &RemotePathBuf,
        tmp_path: &RemotePathBuf,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        delegate.set_status(Some("Extracting remote development server"), cx);
        let server_mode = 0o755;

        let orig_tmp_path = tmp_path.to_string();
        // `{orig_tmp_path}` and `{tmp_path}` are not passed explicitly below;
        // presumably `shell_script!` captures them from scope (confirm against
        // the macro definition).
        let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
            shell_script!(
                "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
                server_mode = &format!("{:o}", server_mode),
                dst_path = &dst_path.to_string(),
            )
        } else {
            shell_script!(
                "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
                server_mode = &format!("{:o}", server_mode),
                dst_path = &dst_path.to_string()
            )
        };
        self.socket.run_command("sh", &["-c", &script]).await?;
        Ok(())
    }
2047
2048 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
2049 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
2050 let mut command = util::command::new_smol_command("scp");
2051 let output = self
2052 .socket
2053 .ssh_options(&mut command)
2054 .args(
2055 self.socket
2056 .connection_options
2057 .port
2058 .map(|port| vec!["-P".to_string(), port.to_string()])
2059 .unwrap_or_default(),
2060 )
2061 .arg(src_path)
2062 .arg(format!(
2063 "{}:{}",
2064 self.socket.connection_options.scp_url(),
2065 dest_path
2066 ))
2067 .output()
2068 .await?;
2069
2070 anyhow::ensure!(
2071 output.status.success(),
2072 "failed to upload file {} -> {}: {}",
2073 src_path.display(),
2074 dest_path.to_string(),
2075 String::from_utf8_lossy(&output.stderr)
2076 );
2077 Ok(())
2078 }
2079
    /// Debug builds only: builds the `remote_server` binary from local
    /// sources for the remote platform and returns the path to the artifact.
    ///
    /// `build_remote_server` is the value of `ZED_BUILD_REMOTE_SERVER`; it is
    /// searched for these substrings:
    /// * `nomusl` — target `unknown-linux-gnu` instead of `unknown-linux-musl`.
    /// * `mold` — link with mold.
    /// * `cross` — cross-compile with `cross` instead of `cargo zigbuild`.
    /// * `nocompress` — return the raw binary instead of a `.gz` archive.
    ///
    /// When the remote architecture and OS equal the local ones, a plain
    /// `cargo build` is used.
    ///
    /// # Errors
    ///
    /// Fails for remote operating systems other than Linux and macOS, when a
    /// required tool is missing, or when any build command fails.
    #[cfg(debug_assertions)]
    async fn build_local(
        &self,
        build_remote_server: String,
        delegate: &Arc<dyn SshClientDelegate>,
        cx: &mut AsyncApp,
    ) -> Result<PathBuf> {
        use smol::process::{Command, Stdio};
        use std::env::VarError;

        // Runs `command` to completion with stderr inherited and fails on a
        // non-success exit status.
        async fn run_cmd(command: &mut Command) -> Result<()> {
            let output = command
                .kill_on_drop(true)
                .stderr(Stdio::inherit())
                .output()
                .await?;
            anyhow::ensure!(
                output.status.success(),
                "Failed to run command: {command:?}"
            );
            Ok(())
        }

        let use_musl = !build_remote_server.contains("nomusl");
        // Target triple for the remote platform.
        let triple = format!(
            "{}-{}",
            self.ssh_platform.arch,
            match self.ssh_platform.os {
                "linux" =>
                    if use_musl {
                        "unknown-linux-musl"
                    } else {
                        "unknown-linux-gnu"
                    },
                "macos" => "apple-darwin",
                _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
            }
        );
        // Start from the caller's RUSTFLAGS (if readable) and append to it.
        let mut rust_flags = match std::env::var("RUSTFLAGS") {
            Ok(val) => val,
            Err(VarError::NotPresent) => String::new(),
            Err(e) => {
                log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
                String::new()
            }
        };
        if self.ssh_platform.os == "linux" && use_musl {
            rust_flags.push_str(" -C target-feature=+crt-static");
        }
        if build_remote_server.contains("mold") {
            rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
        }

        if self.ssh_platform.arch == std::env::consts::ARCH
            && self.ssh_platform.os == std::env::consts::OS
        {
            // Same architecture and OS as the local machine: plain cargo build.
            delegate.set_status(Some("Building remote server binary from source"), cx);
            log::info!("building remote server binary from source");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else if build_remote_server.contains("cross") {
            #[cfg(target_os = "windows")]
            use util::paths::SanitizedPath;

            delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
            log::info!("installing cross");
            run_cmd(Command::new("cargo").args([
                "install",
                "cross",
                "--git",
                "https://github.com/cross-rs/cross",
            ]))
            .await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote server binary from source for {} with Docker",
                    &triple
                )),
                cx,
            );
            log::info!("building remote server binary from source for {}", &triple);

            // On Windows, the binding needs to be set to the canonical path
            #[cfg(target_os = "windows")]
            let src =
                SanitizedPath::from(smol::fs::canonicalize("./target").await?).to_glob_string();
            #[cfg(not(target_os = "windows"))]
            let src = "./target";
            run_cmd(
                Command::new("cross")
                    .args([
                        "build",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env(
                        "CROSS_CONTAINER_OPTS",
                        format!("--mount type=bind,src={src},dst=/app/target"),
                    )
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        } else {
            // Default cross-compilation path: cargo-zigbuild, which needs zig.
            let which = cx
                .background_spawn(async move { which::which("zig") })
                .await;

            if which.is_err() {
                #[cfg(not(target_os = "windows"))]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
                #[cfg(target_os = "windows")]
                {
                    anyhow::bail!(
                        "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
                    )
                }
            }

            delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
            log::info!("adding rustup target");
            run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;

            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
            log::info!("installing cargo-zigbuild");
            run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;

            delegate.set_status(
                Some(&format!(
                    "Building remote binary from source for {triple} with Zig"
                )),
                cx,
            );
            log::info!("building remote binary from source for {triple} with Zig");
            run_cmd(
                Command::new("cargo")
                    .args([
                        "zigbuild",
                        "--package",
                        "remote_server",
                        "--features",
                        "debug-embed",
                        "--target-dir",
                        "target/remote_server",
                        "--target",
                        &triple,
                    ])
                    .env("RUSTFLAGS", &rust_flags),
            )
            .await?;
        };
        // All three build paths write to the same target directory layout.
        let bin_path = Path::new("target")
            .join("remote_server")
            .join(&triple)
            .join("debug")
            .join("remote_server");

        let path = if !build_remote_server.contains("nocompress") {
            delegate.set_status(Some("Compressing binary"), cx);

            #[cfg(not(target_os = "windows"))]
            {
                run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
            }
            #[cfg(target_os = "windows")]
            {
                // On Windows, we use 7z to compress the binary
                let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
                let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
                // Remove a stale archive before creating a new one.
                if smol::fs::metadata(&gz_path).await.is_ok() {
                    smol::fs::remove_file(&gz_path).await?;
                }
                run_cmd(Command::new(seven_zip).args([
                    "a",
                    "-tgzip",
                    &gz_path,
                    &bin_path.to_string_lossy(),
                ]))
                .await?;
            }

            // The compressed path is made absolute against the current
            // directory; the uncompressed path below stays relative.
            let mut archive_path = bin_path;
            archive_path.set_extension("gz");
            std::env::current_dir()?.join(archive_path)
        } else {
            bin_path
        };

        Ok(path)
    }
2294}
2295
/// Map from a message id to a one-shot sender that receives an envelope
/// together with a second one-shot sender.
// NOTE(review): presumably keyed by the id of the request awaiting a response — confirm.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
2297
/// Protocol client that exchanges envelopes over a pair of unbounded channels.
struct ChannelClient {
    /// Source of ids for envelopes created by this client.
    next_message_id: AtomicU32,
    /// Sender for outgoing envelopes.
    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
    /// Envelopes retained until the peer acknowledges them; they are re-sent
    /// when the peer asks for buffered messages to be flushed.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Pending response senders keyed by message id.
    response_channels: ResponseChannels,
    /// Registered handlers for incoming messages.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the most recently received incoming envelope.
    max_received: AtomicU32,
    /// Label included in this client's log messages.
    name: &'static str,
    /// Task that processes incoming envelopes.
    task: Mutex<Task<Result<()>>>,
}
2308
2309impl ChannelClient {
2310 fn new(
2311 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2312 outgoing_tx: mpsc::UnboundedSender<Envelope>,
2313 cx: &App,
2314 name: &'static str,
2315 ) -> Arc<Self> {
2316 Arc::new_cyclic(|this| Self {
2317 outgoing_tx: Mutex::new(outgoing_tx),
2318 next_message_id: AtomicU32::new(0),
2319 max_received: AtomicU32::new(0),
2320 response_channels: ResponseChannels::default(),
2321 message_handlers: Default::default(),
2322 buffer: Mutex::new(VecDeque::new()),
2323 name,
2324 task: Mutex::new(Self::start_handling_messages(
2325 this.clone(),
2326 incoming_rx,
2327 &cx.to_async(),
2328 )),
2329 })
2330 }
2331
2332 fn start_handling_messages(
2333 this: Weak<Self>,
2334 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
2335 cx: &AsyncApp,
2336 ) -> Task<Result<()>> {
2337 cx.spawn(async move |cx| {
2338 let peer_id = PeerId { owner_id: 0, id: 0 };
2339 while let Some(incoming) = incoming_rx.next().await {
2340 let Some(this) = this.upgrade() else {
2341 return anyhow::Ok(());
2342 };
2343 if let Some(ack_id) = incoming.ack_id {
2344 let mut buffer = this.buffer.lock();
2345 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
2346 buffer.pop_front();
2347 }
2348 }
2349 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload
2350 {
2351 log::debug!(
2352 "{}:ssh message received. name:FlushBufferedMessages",
2353 this.name
2354 );
2355 {
2356 let buffer = this.buffer.lock();
2357 for envelope in buffer.iter() {
2358 this.outgoing_tx
2359 .lock()
2360 .unbounded_send(envelope.clone())
2361 .ok();
2362 }
2363 }
2364 let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
2365 envelope.id = this.next_message_id.fetch_add(1, SeqCst);
2366 this.outgoing_tx.lock().unbounded_send(envelope).ok();
2367 continue;
2368 }
2369
2370 this.max_received.store(incoming.id, SeqCst);
2371
2372 if let Some(request_id) = incoming.responding_to {
2373 let request_id = MessageId(request_id);
2374 let sender = this.response_channels.lock().remove(&request_id);
2375 if let Some(sender) = sender {
2376 let (tx, rx) = oneshot::channel();
2377 if incoming.payload.is_some() {
2378 sender.send((incoming, tx)).ok();
2379 }
2380 rx.await.ok();
2381 }
2382 } else if let Some(envelope) =
2383 build_typed_envelope(peer_id, Instant::now(), incoming)
2384 {
2385 let type_name = envelope.payload_type_name();
2386 let message_id = envelope.message_id();
2387 if let Some(future) = ProtoMessageHandlerSet::handle_message(
2388 &this.message_handlers,
2389 envelope,
2390 this.clone().into(),
2391 cx.clone(),
2392 ) {
2393 log::debug!("{}:ssh message received. name:{type_name}", this.name);
2394 cx.foreground_executor()
2395 .spawn(async move {
2396 match future.await {
2397 Ok(_) => {
2398 log::debug!(
2399 "{}:ssh message handled. name:{type_name}",
2400 this.name
2401 );
2402 }
2403 Err(error) => {
2404 log::error!(
2405 "{}:error handling message. type:{}, error:{}",
2406 this.name,
2407 type_name,
2408 format!("{error:#}").lines().fold(
2409 String::new(),
2410 |mut message, line| {
2411 if !message.is_empty() {
2412 message.push(' ');
2413 }
2414 message.push_str(line);
2415 message
2416 }
2417 )
2418 );
2419 }
2420 }
2421 })
2422 .detach()
2423 } else {
2424 log::error!("{}:unhandled ssh message name:{type_name}", this.name);
2425 if let Err(e) = AnyProtoClient::from(this.clone()).send_response(
2426 message_id,
2427 anyhow::anyhow!("no handler registered for {type_name}").to_proto(),
2428 ) {
2429 log::error!(
2430 "{}:error sending error response for {type_name}:{e:#}",
2431 this.name
2432 );
2433 }
2434 }
2435 }
2436 }
2437 anyhow::Ok(())
2438 })
2439 }
2440
2441 fn reconnect(
2442 self: &Arc<Self>,
2443 incoming_rx: UnboundedReceiver<Envelope>,
2444 outgoing_tx: UnboundedSender<Envelope>,
2445 cx: &AsyncApp,
2446 ) {
2447 *self.outgoing_tx.lock() = outgoing_tx;
2448 *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
2449 }
2450
    /// Sends `payload` as a request through the buffered path and returns a future
    /// that resolves to the typed response.
    fn request<T: RequestMessage>(
        &self,
        payload: T,
    ) -> impl 'static + Future<Output = Result<T::Response>> {
        self.request_internal(payload, true)
    }
2457
2458 fn request_internal<T: RequestMessage>(
2459 &self,
2460 payload: T,
2461 use_buffer: bool,
2462 ) -> impl 'static + Future<Output = Result<T::Response>> {
2463 log::debug!("ssh request start. name:{}", T::NAME);
2464 let response =
2465 self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
2466 async move {
2467 let response = response.await?;
2468 log::debug!("ssh request finish. name:{}", T::NAME);
2469 T::Response::from_envelope(response).context("received a response of the wrong type")
2470 }
2471 }
2472
2473 async fn resync(&self, timeout: Duration) -> Result<()> {
2474 smol::future::or(
2475 async {
2476 self.request_internal(proto::FlushBufferedMessages {}, false)
2477 .await?;
2478
2479 for envelope in self.buffer.lock().iter() {
2480 self.outgoing_tx
2481 .lock()
2482 .unbounded_send(envelope.clone())
2483 .ok();
2484 }
2485 Ok(())
2486 },
2487 async {
2488 smol::Timer::after(timeout).await;
2489 anyhow::bail!("Timed out resyncing remote client")
2490 },
2491 )
2492 .await
2493 }
2494
2495 async fn ping(&self, timeout: Duration) -> Result<()> {
2496 smol::future::or(
2497 async {
2498 self.request(proto::Ping {}).await?;
2499 Ok(())
2500 },
2501 async {
2502 smol::Timer::after(timeout).await;
2503 anyhow::bail!("Timed out pinging remote client")
2504 },
2505 )
2506 .await
2507 }
2508
2509 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
2510 log::debug!("ssh send name:{}", T::NAME);
2511 self.send_dynamic(payload.into_envelope(0, None, None))
2512 }
2513
2514 fn request_dynamic(
2515 &self,
2516 mut envelope: proto::Envelope,
2517 type_name: &'static str,
2518 use_buffer: bool,
2519 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
2520 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2521 let (tx, rx) = oneshot::channel();
2522 let mut response_channels_lock = self.response_channels.lock();
2523 response_channels_lock.insert(MessageId(envelope.id), tx);
2524 drop(response_channels_lock);
2525
2526 let result = if use_buffer {
2527 self.send_buffered(envelope)
2528 } else {
2529 self.send_unbuffered(envelope)
2530 };
2531 async move {
2532 if let Err(error) = &result {
2533 log::error!("failed to send message: {error}");
2534 anyhow::bail!("failed to send message: {error}");
2535 }
2536
2537 let response = rx.await.context("connection lost")?.0;
2538 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
2539 return Err(RpcError::from_proto(error, type_name));
2540 }
2541 Ok(response)
2542 }
2543 }
2544
2545 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
2546 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
2547 self.send_buffered(envelope)
2548 }
2549
2550 fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2551 envelope.ack_id = Some(self.max_received.load(SeqCst));
2552 self.buffer.lock().push_back(envelope.clone());
2553 // ignore errors on send (happen while we're reconnecting)
2554 // assume that the global "disconnected" overlay is sufficient.
2555 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2556 Ok(())
2557 }
2558
2559 fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
2560 envelope.ack_id = Some(self.max_received.load(SeqCst));
2561 self.outgoing_tx.lock().unbounded_send(envelope).ok();
2562 Ok(())
2563 }
2564}
2565
2566impl ProtoClient for ChannelClient {
2567 fn request(
2568 &self,
2569 envelope: proto::Envelope,
2570 request_type: &'static str,
2571 ) -> BoxFuture<'static, Result<proto::Envelope>> {
2572 self.request_dynamic(envelope, request_type, true).boxed()
2573 }
2574
2575 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
2576 self.send_dynamic(envelope)
2577 }
2578
2579 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
2580 self.send_dynamic(envelope)
2581 }
2582
2583 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
2584 &self.message_handlers
2585 }
2586
2587 fn is_via_collab(&self) -> bool {
2588 false
2589 }
2590}
2591
2592#[cfg(any(test, feature = "test-support"))]
2593mod fake {
2594 use std::{path::PathBuf, sync::Arc};
2595
2596 use anyhow::Result;
2597 use async_trait::async_trait;
2598 use futures::{
2599 FutureExt, SinkExt, StreamExt,
2600 channel::{
2601 mpsc::{self, Sender},
2602 oneshot,
2603 },
2604 select_biased,
2605 };
2606 use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task, TestAppContext};
2607 use release_channel::ReleaseChannel;
2608 use rpc::proto::Envelope;
2609 use util::paths::{PathStyle, RemotePathBuf};
2610
2611 use super::{
2612 ChannelClient, RemoteConnection, SshArgs, SshClientDelegate, SshConnectionOptions,
2613 SshPlatform,
2614 };
2615
    /// Test-only `RemoteConnection` that forwards envelopes between a client and
    /// an in-process server `ChannelClient` instead of talking to a real remote.
    pub(super) struct FakeRemoteConnection {
        /// Options handed back unchanged by `connection_options`.
        pub(super) connection_options: SshConnectionOptions,
        /// The server-side channel that gets reconnected whenever a proxy is
        /// started or a disconnect is simulated.
        pub(super) server_channel: Arc<ChannelClient>,
        /// Async context used when reconnecting `server_channel`.
        pub(super) server_cx: SendableCx,
    }
2621
    /// Wrapper around an `AsyncApp` that is manually marked `Send` and `Sync`
    /// (presumably because `AsyncApp` itself is not; see the safety notes below).
    pub(super) struct SendableCx(AsyncApp);
    impl SendableCx {
        /// Captures the async context of a test app.
        // SAFETY: When run in test mode, GPUI is always single threaded.
        pub(super) fn new(cx: &TestAppContext) -> Self {
            Self(cx.to_async())
        }

        /// Returns a clone of the wrapped context. The `&AsyncApp` argument is
        /// not used; it only serves as evidence for the safety argument below.
        // SAFETY: Enforce that we're on the main thread by requiring a valid AsyncApp
        fn get(&self, _: &AsyncApp) -> AsyncApp {
            self.0.clone()
        }
    }

    // SAFETY: There is no way to access a SendableCx from a different thread, see [`SendableCx::new`] and [`SendableCx::get`]
    unsafe impl Send for SendableCx {}
    // NOTE(review): no separate justification is given for `Sync`; presumably the
    // reasoning stated for `Send` above is meant to cover it as well. Confirm.
    unsafe impl Sync for SendableCx {}
2638
2639 #[async_trait(?Send)]
2640 impl RemoteConnection for FakeRemoteConnection {
2641 async fn kill(&self) -> Result<()> {
2642 Ok(())
2643 }
2644
2645 fn has_been_killed(&self) -> bool {
2646 false
2647 }
2648
2649 fn ssh_args(&self) -> SshArgs {
2650 SshArgs {
2651 arguments: Vec::new(),
2652 envs: None,
2653 }
2654 }
2655
2656 fn upload_directory(
2657 &self,
2658 _src_path: PathBuf,
2659 _dest_path: RemotePathBuf,
2660 _cx: &App,
2661 ) -> Task<Result<()>> {
2662 unreachable!()
2663 }
2664
2665 fn connection_options(&self) -> SshConnectionOptions {
2666 self.connection_options.clone()
2667 }
2668
2669 fn simulate_disconnect(&self, cx: &AsyncApp) {
2670 let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
2671 let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
2672 self.server_channel
2673 .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(cx));
2674 }
2675
2676 fn start_proxy(
2677 &self,
2678 _unique_identifier: String,
2679 _reconnect: bool,
2680 mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
2681 mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
2682 mut connection_activity_tx: Sender<()>,
2683 _delegate: Arc<dyn SshClientDelegate>,
2684 cx: &mut AsyncApp,
2685 ) -> Task<Result<i32>> {
2686 let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
2687 let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
2688
2689 self.server_channel.reconnect(
2690 server_incoming_rx,
2691 server_outgoing_tx,
2692 &self.server_cx.get(cx),
2693 );
2694
2695 cx.background_spawn(async move {
2696 loop {
2697 select_biased! {
2698 server_to_client = server_outgoing_rx.next().fuse() => {
2699 let Some(server_to_client) = server_to_client else {
2700 return Ok(1)
2701 };
2702 connection_activity_tx.try_send(()).ok();
2703 client_incoming_tx.send(server_to_client).await.ok();
2704 }
2705 client_to_server = client_outgoing_rx.next().fuse() => {
2706 let Some(client_to_server) = client_to_server else {
2707 return Ok(1)
2708 };
2709 server_incoming_tx.send(client_to_server).await.ok();
2710 }
2711 }
2712 }
2713 })
2714 }
2715
2716 fn path_style(&self) -> PathStyle {
2717 PathStyle::current()
2718 }
2719
2720 fn shell(&self) -> String {
2721 "sh".to_owned()
2722 }
2723 }
2724
    /// Test-only `SshClientDelegate`. Every method except `set_status` panics via
    /// `unreachable!()` when called.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        /// Panics if called.
        fn ask_password(&self, _: String, _: oneshot::Sender<String>, _: &mut AsyncApp) {
            unreachable!()
        }

        /// Panics if called.
        fn download_server_binary_locally(
            &self,
            _: SshPlatform,
            _: ReleaseChannel,
            _: Option<SemanticVersion>,
            _: &mut AsyncApp,
        ) -> Task<Result<PathBuf>> {
            unreachable!()
        }

        /// Panics if called.
        fn get_download_params(
            &self,
            _platform: SshPlatform,
            _release_channel: ReleaseChannel,
            _version: Option<SemanticVersion>,
            _cx: &mut AsyncApp,
        ) -> Task<Result<Option<(String, String)>>> {
            unreachable!()
        }

        /// Status updates are ignored.
        fn set_status(&self, _: Option<&str>, _: &mut AsyncApp) {}
    }
2754}