1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use async_trait::async_trait;
10use collections::HashMap;
11use futures::{
12 channel::{
13 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
14 oneshot,
15 },
16 future::BoxFuture,
17 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
18 StreamExt as _,
19};
20use gpui::{
21 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
22 WeakModel,
23};
24use parking_lot::Mutex;
25use rpc::{
26 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
27 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
28};
29use smol::{
30 fs,
31 process::{self, Child, Stdio},
32};
33use std::{
34 any::TypeId,
35 collections::VecDeque,
36 ffi::OsStr,
37 fmt,
38 ops::ControlFlow,
39 path::{Path, PathBuf},
40 sync::{
41 atomic::{AtomicU32, Ordering::SeqCst},
42 Arc,
43 },
44 time::{Duration, Instant},
45};
46use tempfile::TempDir;
47use util::ResultExt;
48
/// Newtype around a `u64` project id.
///
/// NOTE(review): not referenced in this part of the file; presumably identifies a project
/// hosted on the ssh remote — confirm against its users. It derives `Serialize`/`Deserialize`,
/// so its representation is part of any persisted or wire format it appears in.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
53
/// Handle for running commands on the remote host through an existing ssh control socket.
#[derive(Clone)]
pub struct SshSocket {
    /// Destination (user/host/port) the commands are sent to.
    connection_options: SshConnectionOptions,
    /// Path of the ssh control socket, passed to `ssh` as `-o ControlPath=<path>`.
    socket_path: PathBuf,
}
59
/// Everything needed to reach an ssh host.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Host name or address to connect to.
    pub host: String,
    /// Remote user name; `None` leaves the choice to ssh.
    pub username: Option<String>,
    /// Remote port; `None` leaves the choice to ssh.
    pub port: Option<u16>,
    /// Password, if known. `parse_command_line` always leaves this `None`.
    pub password: Option<String>,
    /// Extra ssh flags (and their values) accepted by `parse_command_line`.
    pub args: Option<Vec<String>>,
}
68
69impl SshConnectionOptions {
70 pub fn parse_command_line(input: &str) -> Result<Self> {
71 let input = input.trim_start_matches("ssh ");
72 let mut hostname: Option<String> = None;
73 let mut username: Option<String> = None;
74 let mut port: Option<u16> = None;
75 let mut args = Vec::new();
76
77 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
78 const ALLOWED_OPTS: &[&str] = &[
79 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
80 ];
81 const ALLOWED_ARGS: &[&str] = &[
82 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
83 "-w",
84 ];
85
86 let mut tokens = shlex::split(input)
87 .ok_or_else(|| anyhow!("invalid input"))?
88 .into_iter();
89
90 'outer: while let Some(arg) = tokens.next() {
91 if ALLOWED_OPTS.contains(&(&arg as &str)) {
92 args.push(arg.to_string());
93 continue;
94 }
95 if arg == "-p" {
96 port = tokens.next().and_then(|arg| arg.parse().ok());
97 continue;
98 } else if let Some(p) = arg.strip_prefix("-p") {
99 port = p.parse().ok();
100 continue;
101 }
102 if arg == "-l" {
103 username = tokens.next();
104 continue;
105 } else if let Some(l) = arg.strip_prefix("-l") {
106 username = Some(l.to_string());
107 continue;
108 }
109 for a in ALLOWED_ARGS {
110 if arg == *a {
111 args.push(arg);
112 if let Some(next) = tokens.next() {
113 args.push(next);
114 }
115 continue 'outer;
116 } else if arg.starts_with(a) {
117 args.push(arg);
118 continue 'outer;
119 }
120 }
121 if arg.starts_with("-") || hostname.is_some() {
122 anyhow::bail!("unsupported argument: {:?}", arg);
123 }
124 let mut input = &arg as &str;
125 if let Some((u, rest)) = input.split_once('@') {
126 input = rest;
127 username = Some(u.to_string());
128 }
129 if let Some((rest, p)) = input.split_once(':') {
130 input = rest;
131 port = p.parse().ok()
132 }
133 hostname = Some(input.to_string())
134 }
135
136 let Some(hostname) = hostname else {
137 anyhow::bail!("missing hostname");
138 };
139
140 Ok(Self {
141 host: hostname.to_string(),
142 username: username.clone(),
143 port,
144 password: None,
145 args: Some(args),
146 })
147 }
148
149 pub fn ssh_url(&self) -> String {
150 let mut result = String::from("ssh://");
151 if let Some(username) = &self.username {
152 result.push_str(username);
153 result.push('@');
154 }
155 result.push_str(&self.host);
156 if let Some(port) = self.port {
157 result.push(':');
158 result.push_str(&port.to_string());
159 }
160 result
161 }
162
163 pub fn additional_args(&self) -> Option<&Vec<String>> {
164 self.args.as_ref()
165 }
166
167 fn scp_url(&self) -> String {
168 if let Some(username) = &self.username {
169 format!("{}@{}", username, self.host)
170 } else {
171 self.host.clone()
172 }
173 }
174
175 pub fn connection_string(&self) -> String {
176 let host = if let Some(username) = &self.username {
177 format!("{}@{}", username, self.host)
178 } else {
179 self.host.clone()
180 };
181 if let Some(port) = &self.port {
182 format!("{}:{}", host, port)
183 } else {
184 host
185 }
186 }
187
188 // Uniquely identifies dev server projects on a remote host. Needs to be
189 // stable for the same dev server project.
190 pub fn dev_server_identifier(&self) -> String {
191 let mut identifier = format!("dev-server-{:?}", self.host);
192 if let Some(username) = self.username.as_ref() {
193 identifier.push('-');
194 identifier.push_str(&username);
195 }
196 identifier
197 }
198}
199
/// Operating system and CPU architecture reported by the remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}

impl SshPlatform {
    /// Returns the Rust target triple (`<arch>-<vendor>-<os>`) for this platform.
    ///
    /// Returns `None` when `os` is neither `"linux"` nor `"macos"`.
    pub fn triple(&self) -> Option<String> {
        let vendor_and_os = match self.os {
            "linux" => "unknown-linux-gnu",
            "macos" => "apple-darwin",
            _ => return None,
        };
        Some(format!("{}-{}", self.arch, vendor_and_os))
    }
}
219
/// Callbacks through which the ssh connection code reaches the embedding application.
pub trait SshClientDelegate: Send + Sync {
    /// Answers an authentication prompt.
    ///
    /// NOTE(review): not called in this part of the file; presumably asks the user for a
    /// password or passphrase for `prompt` — confirm against callers.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path on the remote host where the server binary for `platform` lives.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Provides a local server binary and its version for `platform`.
    ///
    /// NOTE(review): not called in this part of the file; presumably used when the binary
    /// must be uploaded — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports a progress message (e.g. "Starting proxy"); `None` clears it.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message, e.g. when the initial ping to the server fails.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
239
240impl SshSocket {
241 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
242 let mut command = process::Command::new("ssh");
243 self.ssh_options(&mut command)
244 .arg(self.connection_options.ssh_url())
245 .arg(program);
246 command
247 }
248
249 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
250 command
251 .stdin(Stdio::piped())
252 .stdout(Stdio::piped())
253 .stderr(Stdio::piped())
254 .args(["-o", "ControlMaster=no", "-o"])
255 .arg(format!("ControlPath={}", self.socket_path.display()))
256 }
257
258 fn ssh_args(&self) -> Vec<String> {
259 vec![
260 "-o".to_string(),
261 "ControlMaster=no".to_string(),
262 "-o".to_string(),
263 format!("ControlPath={}", self.socket_path.display()),
264 self.connection_options.ssh_url(),
265 ]
266 }
267}
268
269async fn run_cmd(command: &mut process::Command) -> Result<String> {
270 let output = command.output().await?;
271 if output.status.success() {
272 Ok(String::from_utf8_lossy(&output.stdout).to_string())
273 } else {
274 Err(anyhow!(
275 "failed to run command: {}",
276 String::from_utf8_lossy(&output.stderr)
277 ))
278 }
279}
280
/// Background task that relays envelopes between the client-facing channels and the
/// channels of the current ssh proxy, so the proxy side can be swapped on reconnect.
pub struct ChannelForwarder {
    /// Signals the forwarding task to stop (see `into_channels`).
    quit_tx: UnboundedSender<()>,
    /// The forwarding loop; resolves to the client-facing channel ends it owned.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}
285
286impl ChannelForwarder {
287 fn new(
288 mut incoming_tx: UnboundedSender<Envelope>,
289 mut outgoing_rx: UnboundedReceiver<Envelope>,
290 cx: &AsyncAppContext,
291 ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
292 let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
293
294 let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
295 let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
296
297 let forwarding_task = cx.background_executor().spawn(async move {
298 loop {
299 select_biased! {
300 _ = quit_rx.next().fuse() => {
301 break;
302 },
303 incoming_envelope = proxy_incoming_rx.next().fuse() => {
304 if let Some(envelope) = incoming_envelope {
305 if incoming_tx.send(envelope).await.is_err() {
306 break;
307 }
308 } else {
309 break;
310 }
311 }
312 outgoing_envelope = outgoing_rx.next().fuse() => {
313 if let Some(envelope) = outgoing_envelope {
314 if proxy_outgoing_tx.send(envelope).await.is_err() {
315 break;
316 }
317 } else {
318 break;
319 }
320 }
321 }
322 }
323
324 (incoming_tx, outgoing_rx)
325 });
326
327 (
328 Self {
329 forwarding_task,
330 quit_tx,
331 },
332 proxy_incoming_tx,
333 proxy_outgoing_rx,
334 )
335 }
336
337 async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
338 let _ = self.quit_tx.send(()).await;
339 self.forwarding_task.await
340 }
341}
342
/// Consecutive missed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time (no traffic from the server) before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// How long a ping, or the resync after a reconnect, may take before it counts as failed.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts made before giving up and reporting a disconnect.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
348
/// Internal lifecycle of the connection owned by `SshRemoteClient`.
///
/// Some variants own live resources: `Connected`, `HeartbeatMissed` and `ReconnectFailed`.
/// Replacing the state drops those resources, which cancels the tasks and tears down the
/// ssh processes.
enum State {
    /// Initial state while `SshRemoteClient::new` sets up the connection.
    Connecting,
    /// Proxy is running and the last heartbeat (if any) succeeded.
    Connected {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        /// Monitors the proxy's io task; dropping it kills the proxy process.
        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but the most recent heartbeats went unanswered.
    HeartbeatMissed {
        /// Number of consecutive missed heartbeats.
        missed_heartbeats: usize,

        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The last reconnect attempt failed; the resources needed for another try are kept.
    ReconnectFailed {
        ssh_connection: Box<dyn SshRemoteProcess>,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        /// Number of attempts made so far.
        attempts: usize,
    },
    /// `MAX_RECONNECT_ATTEMPTS` was exceeded; no further attempts are made.
    ReconnectExhausted,
    /// The proxy exited with the "server not running" exit code.
    ServerNotRunning,
}
381
382impl fmt::Display for State {
383 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384 match self {
385 Self::Connecting => write!(f, "connecting"),
386 Self::Connected { .. } => write!(f, "connected"),
387 Self::Reconnecting => write!(f, "reconnecting"),
388 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
389 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
390 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
391 Self::ServerNotRunning { .. } => write!(f, "server not running"),
392 }
393 }
394}
395
396impl State {
397 fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
398 match self {
399 Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
400 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
401 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
402 _ => None,
403 }
404 }
405
406 fn can_reconnect(&self) -> bool {
407 match self {
408 Self::Connected { .. }
409 | Self::HeartbeatMissed { .. }
410 | Self::ReconnectFailed { .. } => true,
411 State::Connecting
412 | State::Reconnecting
413 | State::ReconnectExhausted
414 | State::ServerNotRunning => false,
415 }
416 }
417
418 fn is_reconnect_failed(&self) -> bool {
419 matches!(self, Self::ReconnectFailed { .. })
420 }
421
422 fn is_reconnect_exhausted(&self) -> bool {
423 matches!(self, Self::ReconnectExhausted { .. })
424 }
425
426 fn is_reconnecting(&self) -> bool {
427 matches!(self, Self::Reconnecting { .. })
428 }
429
430 fn heartbeat_recovered(self) -> Self {
431 match self {
432 Self::HeartbeatMissed {
433 ssh_connection,
434 delegate,
435 forwarder,
436 multiplex_task,
437 heartbeat_task,
438 ..
439 } => Self::Connected {
440 ssh_connection,
441 delegate,
442 forwarder,
443 multiplex_task,
444 heartbeat_task,
445 },
446 _ => self,
447 }
448 }
449
450 fn heartbeat_missed(self) -> Self {
451 match self {
452 Self::Connected {
453 ssh_connection,
454 delegate,
455 forwarder,
456 multiplex_task,
457 heartbeat_task,
458 } => Self::HeartbeatMissed {
459 missed_heartbeats: 1,
460 ssh_connection,
461 delegate,
462 forwarder,
463 multiplex_task,
464 heartbeat_task,
465 },
466 Self::HeartbeatMissed {
467 missed_heartbeats,
468 ssh_connection,
469 delegate,
470 forwarder,
471 multiplex_task,
472 heartbeat_task,
473 } => Self::HeartbeatMissed {
474 missed_heartbeats: missed_heartbeats + 1,
475 ssh_connection,
476 delegate,
477 forwarder,
478 multiplex_task,
479 heartbeat_task,
480 },
481 _ => self,
482 }
483 }
484}
485
/// The state of the ssh connection, as exposed to callers.
///
/// This is a data-free summary of the internal `State`, produced by `From<&State>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// Connected and heartbeats are being answered.
    Connected,
    /// Connected, but recent heartbeats went unanswered.
    HeartbeatMissed,
    /// A reconnect attempt is running or about to be retried.
    Reconnecting,
    /// Reconnects were exhausted, the server is not running, or there is no state at all.
    Disconnected,
}
495
496impl From<&State> for ConnectionState {
497 fn from(value: &State) -> Self {
498 match value {
499 State::Connecting => Self::Connecting,
500 State::Connected { .. } => Self::Connected,
501 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
502 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
503 State::ReconnectExhausted => Self::Disconnected,
504 State::ServerNotRunning => Self::Disconnected,
505 }
506 }
507}
508
/// Client side of a connection to a server process on a remote host, reached over ssh.
pub struct SshRemoteClient {
    /// Protocol client; its channels are bridged to the ssh proxy by a `ChannelForwarder`.
    client: Arc<ChannelClient>,
    /// Passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Current lifecycle state; `None` once `shutdown_processes` has taken it.
    state: Arc<Mutex<Option<State>>>,
}
515
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection is gone for good: the server is not running, or reconnect attempts
    /// were exhausted.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
522
523impl SshRemoteClient {
524 pub fn new(
525 unique_identifier: String,
526 connection_options: SshConnectionOptions,
527 delegate: Arc<dyn SshClientDelegate>,
528 cx: &AppContext,
529 ) -> Task<Result<Model<Self>>> {
530 cx.spawn(|mut cx| async move {
531 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
532 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
533 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
534
535 let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
536 let this = cx.new_model(|_| Self {
537 client: client.clone(),
538 unique_identifier: unique_identifier.clone(),
539 connection_options: connection_options.clone(),
540 state: Arc::new(Mutex::new(Some(State::Connecting))),
541 })?;
542
543 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
544 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
545
546 let (ssh_connection, io_task) = Self::establish_connection(
547 unique_identifier,
548 false,
549 connection_options,
550 proxy_incoming_tx,
551 proxy_outgoing_rx,
552 connection_activity_tx,
553 delegate.clone(),
554 &mut cx,
555 )
556 .await?;
557
558 let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
559
560 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
561 log::error!("failed to establish connection: {}", error);
562 delegate.set_error(error.to_string(), &mut cx);
563 return Err(error);
564 }
565
566 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
567
568 this.update(&mut cx, |this, _| {
569 *this.state.lock() = Some(State::Connected {
570 ssh_connection,
571 delegate,
572 forwarder: proxy,
573 multiplex_task,
574 heartbeat_task,
575 });
576 })?;
577
578 Ok(this)
579 })
580 }
581
582 pub fn shutdown_processes<T: RequestMessage>(
583 &self,
584 shutdown_request: Option<T>,
585 ) -> Option<impl Future<Output = ()>> {
586 let state = self.state.lock().take()?;
587 log::info!("shutting down ssh processes");
588
589 let State::Connected {
590 multiplex_task,
591 heartbeat_task,
592 ssh_connection,
593 delegate,
594 forwarder,
595 } = state
596 else {
597 return None;
598 };
599
600 let client = self.client.clone();
601
602 Some(async move {
603 if let Some(shutdown_request) = shutdown_request {
604 client.send(shutdown_request).log_err();
605 // We wait 50ms instead of waiting for a response, because
606 // waiting for a response would require us to wait on the main thread
607 // which we want to avoid in an `on_app_quit` callback.
608 smol::Timer::after(Duration::from_millis(50)).await;
609 }
610
611 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
612 // child of master_process.
613 drop(multiplex_task);
614 // Now drop the rest of state, which kills master process.
615 drop(heartbeat_task);
616 drop(ssh_connection);
617 drop(delegate);
618 drop(forwarder);
619 })
620 }
621
622 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
623 let mut lock = self.state.lock();
624
625 let can_reconnect = lock
626 .as_ref()
627 .map(|state| state.can_reconnect())
628 .unwrap_or(false);
629 if !can_reconnect {
630 let error = if let Some(state) = lock.as_ref() {
631 format!("invalid state, cannot reconnect while in state {state}")
632 } else {
633 "no state set".to_string()
634 };
635 log::info!("aborting reconnect, because not in state that allows reconnecting");
636 return Err(anyhow!(error));
637 }
638
639 let state = lock.take().unwrap();
640 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
641 State::Connected {
642 ssh_connection,
643 delegate,
644 forwarder,
645 multiplex_task,
646 heartbeat_task,
647 }
648 | State::HeartbeatMissed {
649 ssh_connection,
650 delegate,
651 forwarder,
652 multiplex_task,
653 heartbeat_task,
654 ..
655 } => {
656 drop(multiplex_task);
657 drop(heartbeat_task);
658 (0, ssh_connection, delegate, forwarder)
659 }
660 State::ReconnectFailed {
661 attempts,
662 ssh_connection,
663 delegate,
664 forwarder,
665 ..
666 } => (attempts, ssh_connection, delegate, forwarder),
667 State::Connecting
668 | State::Reconnecting
669 | State::ReconnectExhausted
670 | State::ServerNotRunning => unreachable!(),
671 };
672
673 let attempts = attempts + 1;
674 if attempts > MAX_RECONNECT_ATTEMPTS {
675 log::error!(
676 "Failed to reconnect to after {} attempts, giving up",
677 MAX_RECONNECT_ATTEMPTS
678 );
679 drop(lock);
680 self.set_state(State::ReconnectExhausted, cx);
681 return Ok(());
682 }
683 drop(lock);
684
685 self.set_state(State::Reconnecting, cx);
686
687 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
688
689 let identifier = self.unique_identifier.clone();
690 let client = self.client.clone();
691 let reconnect_task = cx.spawn(|this, mut cx| async move {
692 macro_rules! failed {
693 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
694 return State::ReconnectFailed {
695 error: anyhow!($error),
696 attempts: $attempts,
697 ssh_connection: $ssh_connection,
698 delegate: $delegate,
699 forwarder: $forwarder,
700 };
701 };
702 }
703
704 if let Err(error) = ssh_connection.kill().await.context("Failed to kill ssh process") {
705 failed!(error, attempts, ssh_connection, delegate, forwarder);
706 };
707
708 let connection_options = ssh_connection.connection_options();
709
710 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
711 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
712 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
713 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
714
715 let (ssh_connection, io_task) = match Self::establish_connection(
716 identifier,
717 true,
718 connection_options,
719 proxy_incoming_tx,
720 proxy_outgoing_rx,
721 connection_activity_tx,
722 delegate.clone(),
723 &mut cx,
724 )
725 .await
726 {
727 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
728 Err(error) => {
729 failed!(error, attempts, ssh_connection, delegate, forwarder);
730 }
731 };
732
733 let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
734
735 if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
736 failed!(error, attempts, ssh_connection, delegate, forwarder);
737 };
738
739 State::Connected {
740 ssh_connection,
741 delegate,
742 forwarder,
743 multiplex_task,
744 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
745 }
746 });
747
748 cx.spawn(|this, mut cx| async move {
749 let new_state = reconnect_task.await;
750 this.update(&mut cx, |this, cx| {
751 this.try_set_state(cx, |old_state| {
752 if old_state.is_reconnecting() {
753 match &new_state {
754 State::Connecting
755 | State::Reconnecting { .. }
756 | State::HeartbeatMissed { .. }
757 | State::ServerNotRunning => {}
758 State::Connected { .. } => {
759 log::info!("Successfully reconnected");
760 }
761 State::ReconnectFailed {
762 error, attempts, ..
763 } => {
764 log::error!(
765 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
766 attempts,
767 error
768 );
769 }
770 State::ReconnectExhausted => {
771 log::error!("Reconnect attempt failed and all attempts exhausted");
772 }
773 }
774 Some(new_state)
775 } else {
776 None
777 }
778 });
779
780 if this.state_is(State::is_reconnect_failed) {
781 this.reconnect(cx)
782 } else if this.state_is(State::is_reconnect_exhausted) {
783 cx.emit(SshRemoteEvent::Disconnected);
784 Ok(())
785 } else {
786 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
787 Ok(())
788 }
789 })
790 })
791 .detach_and_log_err(cx);
792
793 Ok(())
794 }
795
796 fn heartbeat(
797 this: WeakModel<Self>,
798 mut connection_activity_rx: mpsc::Receiver<()>,
799 cx: &mut AsyncAppContext,
800 ) -> Task<Result<()>> {
801 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
802 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
803 };
804
805 cx.spawn(|mut cx| {
806 let this = this.clone();
807 async move {
808 let mut missed_heartbeats = 0;
809
810 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
811 futures::pin_mut!(keepalive_timer);
812
813 loop {
814 select_biased! {
815 result = connection_activity_rx.next().fuse() => {
816 if result.is_none() {
817 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
818 return Ok(());
819 }
820
821 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
822
823 if missed_heartbeats != 0 {
824 missed_heartbeats = 0;
825 this.update(&mut cx, |this, mut cx| {
826 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
827 })?;
828 }
829 }
830 _ = keepalive_timer => {
831 log::debug!("Sending heartbeat to server...");
832
833 let result = select_biased! {
834 _ = connection_activity_rx.next().fuse() => {
835 Ok(())
836 }
837 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
838 ping_result
839 }
840 };
841
842 if result.is_err() {
843 missed_heartbeats += 1;
844 log::warn!(
845 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
846 HEARTBEAT_TIMEOUT,
847 missed_heartbeats,
848 MAX_MISSED_HEARTBEATS
849 );
850 } else if missed_heartbeats != 0 {
851 missed_heartbeats = 0;
852 } else {
853 continue;
854 }
855
856 let result = this.update(&mut cx, |this, mut cx| {
857 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
858 })?;
859 if result.is_break() {
860 return Ok(());
861 }
862 }
863 }
864 }
865 }
866 })
867 }
868
869 fn handle_heartbeat_result(
870 &mut self,
871 missed_heartbeats: usize,
872 cx: &mut ModelContext<Self>,
873 ) -> ControlFlow<()> {
874 let state = self.state.lock().take().unwrap();
875 let next_state = if missed_heartbeats > 0 {
876 state.heartbeat_missed()
877 } else {
878 state.heartbeat_recovered()
879 };
880
881 self.set_state(next_state, cx);
882
883 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
884 log::error!(
885 "Missed last {} heartbeats. Reconnecting...",
886 missed_heartbeats
887 );
888
889 self.reconnect(cx)
890 .context("failed to start reconnect process after missing heartbeats")
891 .log_err();
892 ControlFlow::Break(())
893 } else {
894 ControlFlow::Continue(())
895 }
896 }
897
898 fn multiplex(
899 mut ssh_proxy_process: Child,
900 incoming_tx: UnboundedSender<Envelope>,
901 mut outgoing_rx: UnboundedReceiver<Envelope>,
902 mut connection_activity_tx: Sender<()>,
903 cx: &AsyncAppContext,
904 ) -> Task<Result<Option<i32>>> {
905 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
906 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
907 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
908
909 cx.background_executor().spawn(async move {
910 let mut stdin_buffer = Vec::new();
911 let mut stdout_buffer = Vec::new();
912 let mut stderr_buffer = Vec::new();
913 let mut stderr_offset = 0;
914
915 loop {
916 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
917 stderr_buffer.resize(stderr_offset + 1024, 0);
918
919 select_biased! {
920 outgoing = outgoing_rx.next().fuse() => {
921 let Some(outgoing) = outgoing else {
922 return anyhow::Ok(None);
923 };
924
925 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
926 }
927
928 result = child_stdout.read(&mut stdout_buffer).fuse() => {
929 match result {
930 Ok(0) => {
931 child_stdin.close().await?;
932 outgoing_rx.close();
933 let status = ssh_proxy_process.status().await?;
934 // If we don't have a code, we assume process
935 // has been killed and treat it as non-zero exit
936 // code
937 return Ok(status.code().or_else(|| Some(1)));
938 }
939 Ok(len) => {
940 if len < stdout_buffer.len() {
941 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
942 }
943
944 let message_len = message_len_from_buffer(&stdout_buffer);
945 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
946 Ok(envelope) => {
947 connection_activity_tx.try_send(()).ok();
948 incoming_tx.unbounded_send(envelope).ok();
949 }
950 Err(error) => {
951 log::error!("error decoding message {error:?}");
952 }
953 }
954 }
955 Err(error) => {
956 Err(anyhow!("error reading stdout: {error:?}"))?;
957 }
958 }
959 }
960
961 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
962 match result {
963 Ok(len) => {
964 stderr_offset += len;
965 let mut start_ix = 0;
966 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
967 let line_ix = start_ix + ix;
968 let content = &stderr_buffer[start_ix..line_ix];
969 start_ix = line_ix + 1;
970 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
971 record.log(log::logger())
972 } else {
973 eprintln!("(remote) {}", String::from_utf8_lossy(content));
974 }
975 }
976 stderr_buffer.drain(0..start_ix);
977 stderr_offset -= start_ix;
978
979 connection_activity_tx.try_send(()).ok();
980 }
981 Err(error) => {
982 Err(anyhow!("error reading stderr: {error:?}"))?;
983 }
984 }
985 }
986 }
987 }
988 })
989 }
990
991 fn monitor(
992 this: WeakModel<Self>,
993 io_task: Task<Result<Option<i32>>>,
994 cx: &AsyncAppContext,
995 ) -> Task<Result<()>> {
996 cx.spawn(|mut cx| async move {
997 let result = io_task.await;
998
999 match result {
1000 Ok(Some(exit_code)) => {
1001 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1002 match error {
1003 ProxyLaunchError::ServerNotRunning => {
1004 log::error!("failed to reconnect because server is not running");
1005 this.update(&mut cx, |this, cx| {
1006 this.set_state(State::ServerNotRunning, cx);
1007 cx.emit(SshRemoteEvent::Disconnected);
1008 })?;
1009 }
1010 }
1011 } else if exit_code > 0 {
1012 log::error!("proxy process terminated unexpectedly");
1013 this.update(&mut cx, |this, cx| {
1014 this.reconnect(cx).ok();
1015 })?;
1016 }
1017 }
1018 Ok(None) => {}
1019 Err(error) => {
1020 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1021 this.update(&mut cx, |this, cx| {
1022 this.reconnect(cx).ok();
1023 })?;
1024 }
1025 }
1026 Ok(())
1027 })
1028 }
1029
1030 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1031 self.state.lock().as_ref().map_or(false, check)
1032 }
1033
1034 fn try_set_state(
1035 &self,
1036 cx: &mut ModelContext<Self>,
1037 map: impl FnOnce(&State) -> Option<State>,
1038 ) {
1039 let mut lock = self.state.lock();
1040 let new_state = lock.as_ref().and_then(map);
1041
1042 if let Some(new_state) = new_state {
1043 lock.replace(new_state);
1044 cx.notify();
1045 }
1046 }
1047
1048 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1049 log::info!("setting state to '{}'", &state);
1050 self.state.lock().replace(state);
1051 cx.notify();
1052 }
1053
1054 #[allow(clippy::too_many_arguments)]
1055 async fn establish_connection(
1056 unique_identifier: String,
1057 reconnect: bool,
1058 connection_options: SshConnectionOptions,
1059 proxy_incoming_tx: UnboundedSender<Envelope>,
1060 proxy_outgoing_rx: UnboundedReceiver<Envelope>,
1061 connection_activity_tx: Sender<()>,
1062 delegate: Arc<dyn SshClientDelegate>,
1063 cx: &mut AsyncAppContext,
1064 ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<Option<i32>>>)> {
1065 #[cfg(any(test, feature = "test-support"))]
1066 if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
1067 let io_task = fake::SshRemoteConnection::multiplex(
1068 fake.connection_options(),
1069 proxy_incoming_tx,
1070 proxy_outgoing_rx,
1071 connection_activity_tx,
1072 cx,
1073 )
1074 .await;
1075 return Ok((fake, io_task));
1076 }
1077
1078 let ssh_connection =
1079 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1080
1081 let platform = ssh_connection.query_platform().await?;
1082 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1083 if !reconnect {
1084 ssh_connection
1085 .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1086 .await?;
1087 }
1088
1089 let socket = ssh_connection.socket.clone();
1090 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1091
1092 delegate.set_status(Some("Starting proxy"), cx);
1093
1094 let mut start_proxy_command = format!(
1095 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1096 std::env::var("RUST_LOG").unwrap_or_default(),
1097 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1098 remote_binary_path,
1099 unique_identifier,
1100 );
1101 if reconnect {
1102 start_proxy_command.push_str(" --reconnect");
1103 }
1104
1105 let ssh_proxy_process = socket
1106 .ssh_command(start_proxy_command)
1107 // IMPORTANT: we kill this process when we drop the task that uses it.
1108 .kill_on_drop(true)
1109 .spawn()
1110 .context("failed to spawn remote server")?;
1111
1112 let io_task = Self::multiplex(
1113 ssh_proxy_process,
1114 proxy_incoming_tx,
1115 proxy_outgoing_rx,
1116 connection_activity_tx,
1117 &cx,
1118 );
1119
1120 Ok((Box::new(ssh_connection), io_task))
1121 }
1122
1123 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1124 self.client.subscribe_to_entity(remote_id, entity);
1125 }
1126
1127 pub fn ssh_args(&self) -> Option<Vec<String>> {
1128 self.state
1129 .lock()
1130 .as_ref()
1131 .and_then(|state| state.ssh_connection())
1132 .map(|ssh_connection| ssh_connection.ssh_args())
1133 }
1134
1135 pub fn proto_client(&self) -> AnyProtoClient {
1136 self.client.clone().into()
1137 }
1138
1139 pub fn connection_string(&self) -> String {
1140 self.connection_options.connection_string()
1141 }
1142
1143 pub fn connection_options(&self) -> SshConnectionOptions {
1144 self.connection_options.clone()
1145 }
1146
1147 pub fn connection_state(&self) -> ConnectionState {
1148 self.state
1149 .lock()
1150 .as_ref()
1151 .map(ConnectionState::from)
1152 .unwrap_or(ConnectionState::Disconnected)
1153 }
1154
1155 pub fn is_disconnected(&self) -> bool {
1156 self.connection_state() == ConnectionState::Disconnected
1157 }
1158
1159 #[cfg(any(test, feature = "test-support"))]
1160 pub fn simulate_disconnect(&self, cx: &mut AppContext) -> Task<()> {
1161 use gpui::BorrowAppContext;
1162
1163 let port = self.connection_options().port.unwrap();
1164
1165 let disconnect =
1166 cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.take(port).into_channels());
1167 cx.spawn(|mut cx| async move {
1168 let (input_rx, output_tx) = disconnect.await;
1169 let (forwarder, _, _) = ChannelForwarder::new(input_rx, output_tx, &mut cx);
1170 cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.replace(port, forwarder))
1171 .unwrap()
1172 })
1173 }
1174
1175 #[cfg(any(test, feature = "test-support"))]
1176 pub fn fake_server(
1177 server_cx: &mut gpui::TestAppContext,
1178 ) -> (ChannelForwarder, Arc<ChannelClient>) {
1179 server_cx.update(|cx| {
1180 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
1181 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
1182
1183 // We use the forwarder on the server side (in production we only use one on the client side)
1184 // the idea is that we can simulate a disconnect/reconnect by just messing with the forwarder.
1185 let (forwarder, _, _) =
1186 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx.to_async());
1187
1188 let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
1189 (forwarder, client)
1190 })
1191 }
1192
1193 #[cfg(any(test, feature = "test-support"))]
1194 pub async fn fake_client(
1195 forwarder: ChannelForwarder,
1196 client_cx: &mut gpui::TestAppContext,
1197 ) -> Model<Self> {
1198 use gpui::BorrowAppContext;
1199 client_cx
1200 .update(|cx| {
1201 let port = cx.update_default_global(|c: &mut fake::GlobalConnections, _cx| {
1202 c.push(forwarder)
1203 });
1204
1205 Self::new(
1206 "fake".to_string(),
1207 SshConnectionOptions {
1208 host: "<fake>".to_string(),
1209 port: Some(port),
1210 ..Default::default()
1211 },
1212 Arc::new(fake::Delegate),
1213 cx,
1214 )
1215 })
1216 .await
1217 .unwrap()
1218 }
1219}
1220
1221impl From<SshRemoteClient> for AnyProtoClient {
1222 fn from(client: SshRemoteClient) -> Self {
1223 AnyProtoClient::new(client.client.clone())
1224 }
1225}
1226
/// Abstraction over the process/connection backing an SSH session.
///
/// Implemented in this file by the real `SshRemoteConnection` (an `ssh` ControlMaster
/// process) and by `fake::SshRemoteConnection` for tests.
#[async_trait]
trait SshRemoteProcess: Send + Sync {
    /// Terminates the underlying connection. The real implementation kills the master
    /// `ssh` process and waits for it to exit; the fake one does nothing.
    async fn kill(&mut self) -> Result<()>;
    /// Arguments to pass to `ssh` so that it reuses this connection.
    fn ssh_args(&self) -> Vec<String>;
    /// The options this connection was established with.
    fn connection_options(&self) -> SshConnectionOptions;
}
1233
/// A live SSH connection built on an `ssh -N` ControlMaster process.
///
/// Field order is drop order: the `Drop` impl kills `master_process` first, and
/// `_temp_dir` (which holds the control socket and askpass files) is dropped last.
struct SshRemoteConnection {
    /// Connection options plus the path of the ControlMaster socket.
    socket: SshSocket,
    /// The master `ssh` process that keeps the connection open.
    master_process: process::Child,
    // Kept only so the temporary directory is not deleted while the connection is alive.
    _temp_dir: TempDir,
}
1239
1240impl Drop for SshRemoteConnection {
1241 fn drop(&mut self) {
1242 if let Err(error) = self.master_process.kill() {
1243 log::error!("failed to kill SSH master process: {}", error);
1244 }
1245 }
1246}
1247
1248#[async_trait]
1249impl SshRemoteProcess for SshRemoteConnection {
1250 async fn kill(&mut self) -> Result<()> {
1251 self.master_process.kill()?;
1252
1253 self.master_process.status().await?;
1254
1255 Ok(())
1256 }
1257
1258 fn ssh_args(&self) -> Vec<String> {
1259 self.socket.ssh_args()
1260 }
1261
1262 fn connection_options(&self) -> SshConnectionOptions {
1263 self.socket.connection_options.clone()
1264 }
1265}
1266
1267impl SshRemoteConnection {
1268 #[cfg(not(unix))]
1269 async fn new(
1270 _connection_options: SshConnectionOptions,
1271 _delegate: Arc<dyn SshClientDelegate>,
1272 _cx: &mut AsyncAppContext,
1273 ) -> Result<Self> {
1274 Err(anyhow!("ssh is not supported on this platform"))
1275 }
1276
1277 #[cfg(unix)]
1278 async fn new(
1279 connection_options: SshConnectionOptions,
1280 delegate: Arc<dyn SshClientDelegate>,
1281 cx: &mut AsyncAppContext,
1282 ) -> Result<Self> {
1283 use futures::{io::BufReader, AsyncBufReadExt as _};
1284 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1285 use util::ResultExt as _;
1286
1287 delegate.set_status(Some("connecting"), cx);
1288
1289 let url = connection_options.ssh_url();
1290 let temp_dir = tempfile::Builder::new()
1291 .prefix("zed-ssh-session")
1292 .tempdir()?;
1293
1294 // Create a domain socket listener to handle requests from the askpass program.
1295 let askpass_socket = temp_dir.path().join("askpass.sock");
1296 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1297 let listener =
1298 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1299
1300 let askpass_task = cx.spawn({
1301 let delegate = delegate.clone();
1302 |mut cx| async move {
1303 let mut askpass_opened_tx = Some(askpass_opened_tx);
1304
1305 while let Ok((mut stream, _)) = listener.accept().await {
1306 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1307 askpass_opened_tx.send(()).ok();
1308 }
1309 let mut buffer = Vec::new();
1310 let mut reader = BufReader::new(&mut stream);
1311 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1312 buffer.clear();
1313 }
1314 let password_prompt = String::from_utf8_lossy(&buffer);
1315 if let Some(password) = delegate
1316 .ask_password(password_prompt.to_string(), &mut cx)
1317 .await
1318 .context("failed to get ssh password")
1319 .and_then(|p| p)
1320 .log_err()
1321 {
1322 stream.write_all(password.as_bytes()).await.log_err();
1323 }
1324 }
1325 }
1326 });
1327
1328 // Create an askpass script that communicates back to this process.
1329 let askpass_script = format!(
1330 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1331 askpass_socket = askpass_socket.display(),
1332 print_args = "printf '%s\\0' \"$@\"",
1333 shebang = "#!/bin/sh",
1334 );
1335 let askpass_script_path = temp_dir.path().join("askpass.sh");
1336 fs::write(&askpass_script_path, askpass_script).await?;
1337 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1338
1339 // Start the master SSH process, which does not do anything except for establish
1340 // the connection and keep it open, allowing other ssh commands to reuse it
1341 // via a control socket.
1342 let socket_path = temp_dir.path().join("ssh.sock");
1343 let mut master_process = process::Command::new("ssh")
1344 .stdin(Stdio::null())
1345 .stdout(Stdio::piped())
1346 .stderr(Stdio::piped())
1347 .env("SSH_ASKPASS_REQUIRE", "force")
1348 .env("SSH_ASKPASS", &askpass_script_path)
1349 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1350 .args([
1351 "-N",
1352 "-o",
1353 "ControlPersist=no",
1354 "-o",
1355 "ControlMaster=yes",
1356 "-o",
1357 ])
1358 .arg(format!("ControlPath={}", socket_path.display()))
1359 .arg(&url)
1360 .spawn()?;
1361
1362 // Wait for this ssh process to close its stdout, indicating that authentication
1363 // has completed.
1364 let stdout = master_process.stdout.as_mut().unwrap();
1365 let mut output = Vec::new();
1366 let connection_timeout = Duration::from_secs(10);
1367
1368 let result = select_biased! {
1369 _ = askpass_opened_rx.fuse() => {
1370 // If the askpass script has opened, that means the user is typing
1371 // their password, in which case we don't want to timeout anymore,
1372 // since we know a connection has been established.
1373 stdout.read_to_end(&mut output).await?;
1374 Ok(())
1375 }
1376 result = stdout.read_to_end(&mut output).fuse() => {
1377 result?;
1378 Ok(())
1379 }
1380 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1381 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1382 }
1383 };
1384
1385 if let Err(e) = result {
1386 let error_message = format!("Failed to connect to host: {}.", e);
1387 delegate.set_error(error_message, cx);
1388 return Err(e);
1389 }
1390
1391 drop(askpass_task);
1392
1393 if master_process.try_status()?.is_some() {
1394 output.clear();
1395 let mut stderr = master_process.stderr.take().unwrap();
1396 stderr.read_to_end(&mut output).await?;
1397
1398 let error_message = format!(
1399 "failed to connect: {}",
1400 String::from_utf8_lossy(&output).trim()
1401 );
1402 delegate.set_error(error_message.clone(), cx);
1403 Err(anyhow!(error_message))?;
1404 }
1405
1406 Ok(Self {
1407 socket: SshSocket {
1408 connection_options,
1409 socket_path,
1410 },
1411 master_process,
1412 _temp_dir: temp_dir,
1413 })
1414 }
1415
1416 async fn ensure_server_binary(
1417 &self,
1418 delegate: &Arc<dyn SshClientDelegate>,
1419 dst_path: &Path,
1420 platform: SshPlatform,
1421 cx: &mut AsyncAppContext,
1422 ) -> Result<()> {
1423 if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1424 if let Ok(installed_version) =
1425 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1426 {
1427 log::info!("using cached server binary version {}", installed_version);
1428 return Ok(());
1429 }
1430 }
1431
1432 let mut dst_path_gz = dst_path.to_path_buf();
1433 dst_path_gz.set_extension("gz");
1434
1435 if let Some(parent) = dst_path.parent() {
1436 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1437 }
1438
1439 let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1440
1441 let mut server_binary_exists = false;
1442 if !server_binary_exists && cfg!(not(debug_assertions)) {
1443 if let Ok(installed_version) =
1444 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1445 {
1446 if installed_version.trim() == version.to_string() {
1447 server_binary_exists = true;
1448 }
1449 }
1450 }
1451
1452 if server_binary_exists {
1453 log::info!("remote development server already present",);
1454 return Ok(());
1455 }
1456
1457 let src_stat = fs::metadata(&src_path).await?;
1458 let size = src_stat.len();
1459 let server_mode = 0o755;
1460
1461 let t0 = Instant::now();
1462 delegate.set_status(Some("Uploading remote development server"), cx);
1463 log::info!("uploading remote development server ({}kb)", size / 1024);
1464 self.upload_file(&src_path, &dst_path_gz)
1465 .await
1466 .context("failed to upload server binary")?;
1467 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1468
1469 delegate.set_status(Some("Extracting remote development server"), cx);
1470 run_cmd(
1471 self.socket
1472 .ssh_command("gunzip")
1473 .arg("--force")
1474 .arg(&dst_path_gz),
1475 )
1476 .await?;
1477
1478 delegate.set_status(Some("Marking remote development server executable"), cx);
1479 run_cmd(
1480 self.socket
1481 .ssh_command("chmod")
1482 .arg(format!("{:o}", server_mode))
1483 .arg(dst_path),
1484 )
1485 .await?;
1486
1487 Ok(())
1488 }
1489
1490 async fn query_platform(&self) -> Result<SshPlatform> {
1491 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1492 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1493
1494 let os = match os.trim() {
1495 "Darwin" => "macos",
1496 "Linux" => "linux",
1497 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1498 };
1499 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1500 "aarch64"
1501 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1502 "x86_64"
1503 } else {
1504 Err(anyhow!("unknown uname architecture {arch:?}"))?
1505 };
1506
1507 Ok(SshPlatform { os, arch })
1508 }
1509
1510 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1511 let mut command = process::Command::new("scp");
1512 let output = self
1513 .socket
1514 .ssh_options(&mut command)
1515 .args(
1516 self.socket
1517 .connection_options
1518 .port
1519 .map(|port| vec!["-P".to_string(), port.to_string()])
1520 .unwrap_or_default(),
1521 )
1522 .arg(src_path)
1523 .arg(format!(
1524 "{}:{}",
1525 self.socket.connection_options.scp_url(),
1526 dest_path.display()
1527 ))
1528 .output()
1529 .await?;
1530
1531 if output.status.success() {
1532 Ok(())
1533 } else {
1534 Err(anyhow!(
1535 "failed to upload file {} -> {}: {}",
1536 src_path.display(),
1537 dest_path.display(),
1538 String::from_utf8_lossy(&output.stderr)
1539 ))
1540 }
1541 }
1542}
1543
/// Pending requests, keyed by the id of the request envelope.
///
/// Each sender delivers the response envelope together with a `oneshot::Sender<()>`. The
/// message-handling loop of `ChannelClient` awaits the matching receiver, which resolves
/// once the requester drops that sender. So the loop does not move on to the next incoming
/// message until the requester has taken its response.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1545
/// A protocol client that exchanges `Envelope`s over in-process channels.
///
/// Sent messages are retained until the peer acknowledges them, so they can be replayed
/// after a reconnect.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes (incremented on every send).
    next_message_id: AtomicU32,
    /// Channel on which outgoing envelopes are written.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Sent envelopes not yet acknowledged by the peer, oldest first.
    buffer: Mutex<VecDeque<Envelope>>,
    /// Requests awaiting a response.
    response_channels: ResponseChannels,
    /// Handlers for incoming messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
    /// Id of the last processed incoming message; sent to the peer as `ack_id`.
    max_received: AtomicU32,
}
1554
1555impl ChannelClient {
1556 pub fn new(
1557 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1558 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1559 cx: &AppContext,
1560 ) -> Arc<Self> {
1561 let this = Arc::new(Self {
1562 outgoing_tx,
1563 next_message_id: AtomicU32::new(0),
1564 max_received: AtomicU32::new(0),
1565 response_channels: ResponseChannels::default(),
1566 message_handlers: Default::default(),
1567 buffer: Mutex::new(VecDeque::new()),
1568 });
1569
1570 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1571
1572 this
1573 }
1574
1575 fn start_handling_messages(
1576 this: Arc<Self>,
1577 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1578 cx: &AppContext,
1579 ) {
1580 cx.spawn(|cx| {
1581 let this = Arc::downgrade(&this);
1582 async move {
1583 let peer_id = PeerId { owner_id: 0, id: 0 };
1584 while let Some(incoming) = incoming_rx.next().await {
1585 let Some(this) = this.upgrade() else {
1586 return anyhow::Ok(());
1587 };
1588 if let Some(ack_id) = incoming.ack_id {
1589 let mut buffer = this.buffer.lock();
1590 while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
1591 buffer.pop_front();
1592 }
1593 }
1594 if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) = &incoming.payload {
1595 {
1596 let buffer = this.buffer.lock();
1597 for envelope in buffer.iter() {
1598 this.outgoing_tx.unbounded_send(envelope.clone()).ok();
1599 }
1600 }
1601 let response = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
1602 this.send_dynamic(response).ok();
1603 continue;
1604 }
1605
1606 this.max_received.store(incoming.id, SeqCst);
1607
1608 if let Some(request_id) = incoming.responding_to {
1609 let request_id = MessageId(request_id);
1610 let sender = this.response_channels.lock().remove(&request_id);
1611 if let Some(sender) = sender {
1612 let (tx, rx) = oneshot::channel();
1613 if incoming.payload.is_some() {
1614 sender.send((incoming, tx)).ok();
1615 }
1616 rx.await.ok();
1617 }
1618 } else if let Some(envelope) =
1619 build_typed_envelope(peer_id, Instant::now(), incoming)
1620 {
1621 let type_name = envelope.payload_type_name();
1622 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1623 &this.message_handlers,
1624 envelope,
1625 this.clone().into(),
1626 cx.clone(),
1627 ) {
1628 log::debug!("ssh message received. name:{type_name}");
1629 match future.await {
1630 Ok(_) => {
1631 log::debug!("ssh message handled. name:{type_name}");
1632 }
1633 Err(error) => {
1634 log::error!(
1635 "error handling message. type:{type_name}, error:{error}",
1636 );
1637 }
1638 }
1639 } else {
1640 log::error!("unhandled ssh message name:{type_name}");
1641 }
1642 }
1643 }
1644 anyhow::Ok(())
1645 }
1646 })
1647 .detach();
1648 }
1649
1650 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1651 let id = (TypeId::of::<E>(), remote_id);
1652
1653 let mut message_handlers = self.message_handlers.lock();
1654 if message_handlers
1655 .entities_by_type_and_remote_id
1656 .contains_key(&id)
1657 {
1658 panic!("already subscribed to entity");
1659 }
1660
1661 message_handlers.entities_by_type_and_remote_id.insert(
1662 id,
1663 EntityMessageSubscriber::Entity {
1664 handle: entity.downgrade().into(),
1665 },
1666 );
1667 }
1668
1669 pub fn request<T: RequestMessage>(
1670 &self,
1671 payload: T,
1672 ) -> impl 'static + Future<Output = Result<T::Response>> {
1673 log::debug!("ssh request start. name:{}", T::NAME);
1674 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1675 async move {
1676 let response = response.await?;
1677 log::debug!("ssh request finish. name:{}", T::NAME);
1678 T::Response::from_envelope(response)
1679 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1680 }
1681 }
1682
1683 pub async fn resync(&self, timeout: Duration) -> Result<()> {
1684 smol::future::or(
1685 async {
1686 self.request(proto::FlushBufferedMessages {}).await?;
1687 for envelope in self.buffer.lock().iter() {
1688 self.outgoing_tx.unbounded_send(envelope.clone()).ok();
1689 }
1690 Ok(())
1691 },
1692 async {
1693 smol::Timer::after(timeout).await;
1694 Err(anyhow!("Timeout detected"))
1695 },
1696 )
1697 .await
1698 }
1699
1700 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1701 smol::future::or(
1702 async {
1703 self.request(proto::Ping {}).await?;
1704 Ok(())
1705 },
1706 async {
1707 smol::Timer::after(timeout).await;
1708 Err(anyhow!("Timeout detected"))
1709 },
1710 )
1711 .await
1712 }
1713
1714 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1715 log::debug!("ssh send name:{}", T::NAME);
1716 self.send_dynamic(payload.into_envelope(0, None, None))
1717 }
1718
1719 pub fn request_dynamic(
1720 &self,
1721 mut envelope: proto::Envelope,
1722 type_name: &'static str,
1723 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1724 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1725 let (tx, rx) = oneshot::channel();
1726 let mut response_channels_lock = self.response_channels.lock();
1727 response_channels_lock.insert(MessageId(envelope.id), tx);
1728 drop(response_channels_lock);
1729
1730 let result = self.send_buffered(envelope);
1731 async move {
1732 if let Err(error) = &result {
1733 log::error!("failed to send message: {}", error);
1734 return Err(anyhow!("failed to send message: {}", error));
1735 }
1736
1737 let response = rx.await.context("connection lost")?.0;
1738 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1739 return Err(RpcError::from_proto(error, type_name));
1740 }
1741 Ok(response)
1742 }
1743 }
1744
1745 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1746 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1747 self.send_buffered(envelope)
1748 }
1749
1750 pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
1751 envelope.ack_id = Some(self.max_received.load(SeqCst));
1752 self.buffer.lock().push_back(envelope.clone());
1753 self.outgoing_tx.unbounded_send(envelope)?;
1754 Ok(())
1755 }
1756}
1757
1758impl ProtoClient for ChannelClient {
1759 fn request(
1760 &self,
1761 envelope: proto::Envelope,
1762 request_type: &'static str,
1763 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1764 self.request_dynamic(envelope, request_type).boxed()
1765 }
1766
1767 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1768 self.send_dynamic(envelope)
1769 }
1770
1771 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1772 self.send_dynamic(envelope)
1773 }
1774
1775 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1776 &self.message_handlers
1777 }
1778
1779 fn is_via_collab(&self) -> bool {
1780 false
1781 }
1782}
1783
/// In-process stand-ins for SSH connections, used by tests.
///
/// Connections are identified by a "port", which is simply an index into the
/// `GlobalConnections` global.
#[cfg(any(test, feature = "test-support"))]
mod fake {
    use std::path::PathBuf;

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::{
        channel::{
            mpsc::{self, Sender},
            oneshot,
        },
        select_biased, FutureExt, SinkExt, StreamExt,
    };
    use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
    use rpc::proto::Envelope;

    use super::{
        ChannelForwarder, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
    };

    /// Fake SSH connection: it has no process and no ssh arguments, and only records its
    /// connection options.
    pub(super) struct SshRemoteConnection {
        connection_options: SshConnectionOptions,
    }

    impl SshRemoteConnection {
        /// Returns a fake connection when `connection_options.host` is the sentinel
        /// `"<fake>"`, and `None` otherwise.
        pub(super) fn new(
            connection_options: &SshConnectionOptions,
        ) -> Option<Box<dyn SshRemoteProcess>> {
            if connection_options.host != "<fake>" {
                return None;
            }
            Some(Box::new(Self {
                connection_options: connection_options.clone(),
            }))
        }

        /// Connects `client_tx`/`client_rx` to the fake server registered for
        /// `connection_options.port`.
        ///
        /// The registered forwarder is taken from `GlobalConnections`. Its channels are
        /// wrapped in a new `ChannelForwarder`, which is put back under the same port. A
        /// background task then relays envelopes in both directions, signalling
        /// `connection_activity_tx` for each server-to-client envelope. The task resolves to
        /// `Ok(Some(1))` when the server side closes, and to `Ok(None)` when the client side
        /// closes.
        ///
        /// # Panics
        ///
        /// Panics if the options carry no port or no fake server is registered for it.
        pub(super) async fn multiplex(
            connection_options: SshConnectionOptions,
            mut client_tx: mpsc::UnboundedSender<Envelope>,
            mut client_rx: mpsc::UnboundedReceiver<Envelope>,
            mut connection_activity_tx: Sender<()>,
            cx: &mut AsyncAppContext,
        ) -> Task<Result<Option<i32>>> {
            let (server_tx, server_rx) = cx
                .update(|cx| {
                    cx.update_global(|conns: &mut GlobalConnections, _| {
                        conns.take(connection_options.port.unwrap())
                    })
                })
                .unwrap()
                .into_channels()
                .await;

            let (forwarder, mut proxy_tx, mut proxy_rx) =
                ChannelForwarder::new(server_tx, server_rx, cx);

            cx.update(|cx| {
                cx.update_global(|conns: &mut GlobalConnections, _| {
                    conns.replace(connection_options.port.unwrap(), forwarder)
                })
            })
            .unwrap();

            cx.background_executor().spawn(async move {
                loop {
                    select_biased! {
                        server_to_client = proxy_rx.next().fuse() => {
                            let Some(server_to_client) = server_to_client else {
                                return Ok(Some(1))
                            };
                            connection_activity_tx.try_send(()).ok();
                            client_tx.send(server_to_client).await.ok();
                        }
                        client_to_server = client_rx.next().fuse() => {
                            let Some(client_to_server) = client_to_server else {
                                return Ok(None)
                            };
                            proxy_tx.send(client_to_server).await.ok();
                        }
                    }
                }
            })
        }
    }

    #[async_trait]
    impl SshRemoteProcess for SshRemoteConnection {
        /// Nothing to kill for a fake connection.
        async fn kill(&mut self) -> Result<()> {
            Ok(())
        }

        fn ssh_args(&self) -> Vec<String> {
            Vec::new()
        }

        fn connection_options(&self) -> SshConnectionOptions {
            self.connection_options.clone()
        }
    }

    /// Registry of fake servers; a connection's "port" is its index in this vector.
    /// A slot is `None` while its forwarder is taken out via [`GlobalConnections::take`].
    #[derive(Default)]
    pub(super) struct GlobalConnections(Vec<Option<ChannelForwarder>>);
    impl Global for GlobalConnections {}

    impl GlobalConnections {
        /// Registers `forwarder` and returns its port (its index in the registry).
        ///
        /// # Panics
        ///
        /// Panics if the index does not fit in a `u16`, instead of silently wrapping to
        /// another connection's port.
        pub(super) fn push(&mut self, forwarder: ChannelForwarder) -> u16 {
            self.0.push(Some(forwarder));
            u16::try_from(self.0.len() - 1).expect("too many fake connections for a u16 port")
        }

        /// Removes and returns the forwarder registered for `port`.
        ///
        /// # Panics
        ///
        /// Panics if `port` was never registered or its forwarder is already taken.
        pub(super) fn take(&mut self, port: u16) -> ChannelForwarder {
            self.0
                .get_mut(port as usize)
                .expect("no fake server for port")
                .take()
                .expect("fake server is already borrowed")
        }

        /// Puts `forwarder` back into the (currently empty) slot for `port`.
        ///
        /// # Panics
        ///
        /// Panics if `port` was never registered or its slot is still occupied.
        pub(super) fn replace(&mut self, port: u16, forwarder: ChannelForwarder) {
            let ret = self
                .0
                .get_mut(port as usize)
                .expect("no fake server for port")
                .replace(forwarder);
            if ret.is_some() {
                panic!("fake server is already replaced");
            }
        }
    }

    /// Delegate for fake connections. Every method panics if called.
    pub(super) struct Delegate;

    impl SshClientDelegate for Delegate {
        fn ask_password(
            &self,
            _: String,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<String>> {
            unreachable!()
        }
        fn remote_server_binary_path(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> Result<PathBuf> {
            unreachable!()
        }
        fn get_server_binary(
            &self,
            _: SshPlatform,
            _: &mut AsyncAppContext,
        ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
            unreachable!()
        }
        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
            unreachable!()
        }
        fn set_error(&self, _: String, _: &mut AsyncAppContext) {
            unreachable!()
        }
    }
}