1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
17 StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
21 WeakModel,
22};
23use parking_lot::Mutex;
24use rpc::{
25 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
26 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
27};
28use smol::{
29 fs,
30 process::{self, Child, Stdio},
31};
32use std::{
33 any::TypeId,
34 ffi::OsStr,
35 fmt,
36 ops::ControlFlow,
37 path::{Path, PathBuf},
38 sync::{
39 atomic::{AtomicU32, Ordering::SeqCst},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tempfile::TempDir;
45use util::ResultExt;
46
/// Newtype wrapper around a `u64` used as the id of a project on an SSH remote.
///
/// It is ordered, hashable, copyable, and (de)serializable via serde, so it can be used as a
/// map key and persisted.
/// NOTE(review): who allocates these ids is not visible here — confirm against callers.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
51
/// Runs commands on a remote host by reusing an existing SSH control (master) socket.
#[derive(Clone)]
pub struct SshSocket {
    /// Options describing the remote host; used to build the `ssh://` destination URL.
    connection_options: SshConnectionOptions,
    /// Path of the control socket, passed to `ssh` as `-o ControlPath=<path>`.
    socket_path: PathBuf,
}
57
58#[derive(Debug, Default, Clone, PartialEq, Eq)]
59pub struct SshConnectionOptions {
60 pub host: String,
61 pub username: Option<String>,
62 pub port: Option<u16>,
63 pub password: Option<String>,
64 pub args: Option<Vec<String>>,
65}
66
67impl SshConnectionOptions {
68 pub fn parse_command_line(input: &str) -> Result<Self> {
69 let input = input.trim_start_matches("ssh ");
70 let mut hostname: Option<String> = None;
71 let mut username: Option<String> = None;
72 let mut port: Option<u16> = None;
73 let mut args = Vec::new();
74
75 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
76 const ALLOWED_OPTS: &[&str] = &[
77 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
78 ];
79 const ALLOWED_ARGS: &[&str] = &[
80 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
81 "-w",
82 ];
83
84 let mut tokens = shlex::split(input)
85 .ok_or_else(|| anyhow!("invalid input"))?
86 .into_iter();
87
88 'outer: while let Some(arg) = tokens.next() {
89 if ALLOWED_OPTS.contains(&(&arg as &str)) {
90 args.push(arg.to_string());
91 continue;
92 }
93 if arg == "-p" {
94 port = tokens.next().and_then(|arg| arg.parse().ok());
95 continue;
96 } else if let Some(p) = arg.strip_prefix("-p") {
97 port = p.parse().ok();
98 continue;
99 }
100 if arg == "-l" {
101 username = tokens.next();
102 continue;
103 } else if let Some(l) = arg.strip_prefix("-l") {
104 username = Some(l.to_string());
105 continue;
106 }
107 for a in ALLOWED_ARGS {
108 if arg == *a {
109 args.push(arg);
110 if let Some(next) = tokens.next() {
111 args.push(next);
112 }
113 continue 'outer;
114 } else if arg.starts_with(a) {
115 args.push(arg);
116 continue 'outer;
117 }
118 }
119 if arg.starts_with("-") || hostname.is_some() {
120 anyhow::bail!("unsupported argument: {:?}", arg);
121 }
122 let mut input = &arg as &str;
123 if let Some((u, rest)) = input.split_once('@') {
124 input = rest;
125 username = Some(u.to_string());
126 }
127 if let Some((rest, p)) = input.split_once(':') {
128 input = rest;
129 port = p.parse().ok()
130 }
131 hostname = Some(input.to_string())
132 }
133
134 let Some(hostname) = hostname else {
135 anyhow::bail!("missing hostname");
136 };
137
138 Ok(Self {
139 host: hostname.to_string(),
140 username: username.clone(),
141 port,
142 password: None,
143 args: Some(args),
144 })
145 }
146
147 pub fn ssh_url(&self) -> String {
148 let mut result = String::from("ssh://");
149 if let Some(username) = &self.username {
150 result.push_str(username);
151 result.push('@');
152 }
153 result.push_str(&self.host);
154 if let Some(port) = self.port {
155 result.push(':');
156 result.push_str(&port.to_string());
157 }
158 result
159 }
160
161 pub fn additional_args(&self) -> Option<&Vec<String>> {
162 self.args.as_ref()
163 }
164
165 fn scp_url(&self) -> String {
166 if let Some(username) = &self.username {
167 format!("{}@{}", username, self.host)
168 } else {
169 self.host.clone()
170 }
171 }
172
173 pub fn connection_string(&self) -> String {
174 let host = if let Some(username) = &self.username {
175 format!("{}@{}", username, self.host)
176 } else {
177 self.host.clone()
178 };
179 if let Some(port) = &self.port {
180 format!("{}:{}", host, port)
181 } else {
182 host
183 }
184 }
185
186 // Uniquely identifies dev server projects on a remote host. Needs to be
187 // stable for the same dev server project.
188 pub fn dev_server_identifier(&self) -> String {
189 let mut identifier = format!("dev-server-{:?}", self.host);
190 if let Some(username) = self.username.as_ref() {
191 identifier.push('-');
192 identifier.push_str(&username);
193 }
194 identifier
195 }
196}
197
/// Operating system and CPU architecture of a remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// Operating system name; `triple` recognizes `"linux"` and `"macos"`.
    pub os: &'static str,
    /// CPU architecture, used verbatim as the first component of the target triple.
    pub arch: &'static str,
}

impl SshPlatform {
    /// Returns the Rust target triple (`<arch>-<vendor>-<os>`) for this platform, or `None`
    /// when the operating system is not one of the supported ones.
    pub fn triple(&self) -> Option<String> {
        let os_suffix = match self.os {
            "linux" => "unknown-linux-gnu",
            "macos" => "apple-darwin",
            _ => return None,
        };
        Some(format!("{}-{}", self.arch, os_suffix))
    }
}
217
/// Callbacks through which the SSH client interacts with the embedding application:
/// password prompts, status/error reporting, and locating the server binary.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for a password using `prompt`; the answer (or an error) is delivered through the
    /// returned receiver.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path on the remote host where the server binary for `platform` lives.
    /// The client runs this path's `version` and `proxy` subcommands over SSH.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Obtains a server binary for `platform`, delivering a path and its version through the
    /// returned receiver.
    /// NOTE(review): the path is presumably on the local machine (for upload) — confirm.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports connection progress (e.g. `Some("connecting")`, `Some("Starting proxy")`).
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports an error message, e.g. when the initial ping to the server fails.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
237
238impl SshSocket {
239 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
240 let mut command = process::Command::new("ssh");
241 self.ssh_options(&mut command)
242 .arg(self.connection_options.ssh_url())
243 .arg(program);
244 command
245 }
246
247 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
248 command
249 .stdin(Stdio::piped())
250 .stdout(Stdio::piped())
251 .stderr(Stdio::piped())
252 .args(["-o", "ControlMaster=no", "-o"])
253 .arg(format!("ControlPath={}", self.socket_path.display()))
254 }
255
256 fn ssh_args(&self) -> Vec<String> {
257 vec![
258 "-o".to_string(),
259 "ControlMaster=no".to_string(),
260 "-o".to_string(),
261 format!("ControlPath={}", self.socket_path.display()),
262 self.connection_options.ssh_url(),
263 ]
264 }
265}
266
267async fn run_cmd(command: &mut process::Command) -> Result<String> {
268 let output = command.output().await?;
269 if output.status.success() {
270 Ok(String::from_utf8_lossy(&output.stdout).to_string())
271 } else {
272 Err(anyhow!(
273 "failed to run command: {}",
274 String::from_utf8_lossy(&output.stderr)
275 ))
276 }
277}
278
/// Relays envelopes between the long-lived client-side channels and a per-connection set of
/// "proxy" channels, so the proxy side can be replaced (as `reconnect` does) while the
/// client side stays the same.
struct ChannelForwarder {
    /// Tells the forwarding task to stop.
    quit_tx: UnboundedSender<()>,
    /// The forwarding task; when it stops it yields the client-side channels back so they
    /// can be handed to a new forwarder.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}

impl ChannelForwarder {
    /// Spawns the forwarding task on the background executor.
    ///
    /// Envelopes sent on the returned `UnboundedSender` are forwarded to `incoming_tx`, and
    /// envelopes received from `outgoing_rx` are forwarded to the returned
    /// `UnboundedReceiver`. The task stops on a quit signal, when either source channel
    /// ends, or when forwarding to a destination fails.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // Biased: a pending quit request is handled before any further forwarding.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // Return the client-side channels to whoever awaits this task.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops the forwarding task and returns the client-side channels it was using.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // The send result is ignored; NOTE(review): presumably it only fails when the task has
        // already stopped and dropped its receiver.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
340
/// Consecutive missed heartbeats after which the client starts reconnecting.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// How long the connection may be quiet before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// How long to wait for a ping response (also used for the first ping after connecting).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reconnect attempts made before giving up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
346
/// Internal lifecycle of the SSH connection, together with the resources owned in each phase.
enum State {
    /// The initial connection is being set up.
    Connecting,
    /// The connection is up and no heartbeat miss is being tracked.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// The connection is up, but one or more heartbeats went unanswered.
    HeartbeatMissed {
        /// Count of misses, incremented by `State::heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The last reconnect attempt failed; keeps what is needed for another attempt.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        /// Number of reconnect attempts made so far.
        attempts: usize,
    },
    /// All reconnect attempts were used up.
    ReconnectExhausted,
    /// The remote proxy exited with the "server not running" code.
    ServerNotRunning,
}
379
380impl fmt::Display for State {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 match self {
383 Self::Connecting => write!(f, "connecting"),
384 Self::Connected { .. } => write!(f, "connected"),
385 Self::Reconnecting => write!(f, "reconnecting"),
386 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
387 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
388 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
389 Self::ServerNotRunning { .. } => write!(f, "server not running"),
390 }
391 }
392}
393
394impl State {
395 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
396 match self {
397 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
398 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
399 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
400 _ => None,
401 }
402 }
403
404 fn can_reconnect(&self) -> bool {
405 match self {
406 Self::Connected { .. }
407 | Self::HeartbeatMissed { .. }
408 | Self::ReconnectFailed { .. } => true,
409 State::Connecting
410 | State::Reconnecting
411 | State::ReconnectExhausted
412 | State::ServerNotRunning => false,
413 }
414 }
415
416 fn is_reconnect_failed(&self) -> bool {
417 matches!(self, Self::ReconnectFailed { .. })
418 }
419
420 fn is_reconnect_exhausted(&self) -> bool {
421 matches!(self, Self::ReconnectExhausted { .. })
422 }
423
424 fn is_reconnecting(&self) -> bool {
425 matches!(self, Self::Reconnecting { .. })
426 }
427
428 fn heartbeat_recovered(self) -> Self {
429 match self {
430 Self::HeartbeatMissed {
431 ssh_connection,
432 delegate,
433 forwarder,
434 multiplex_task,
435 heartbeat_task,
436 ..
437 } => Self::Connected {
438 ssh_connection,
439 delegate,
440 forwarder,
441 multiplex_task,
442 heartbeat_task,
443 },
444 _ => self,
445 }
446 }
447
448 fn heartbeat_missed(self) -> Self {
449 match self {
450 Self::Connected {
451 ssh_connection,
452 delegate,
453 forwarder,
454 multiplex_task,
455 heartbeat_task,
456 } => Self::HeartbeatMissed {
457 missed_heartbeats: 1,
458 ssh_connection,
459 delegate,
460 forwarder,
461 multiplex_task,
462 heartbeat_task,
463 },
464 Self::HeartbeatMissed {
465 missed_heartbeats,
466 ssh_connection,
467 delegate,
468 forwarder,
469 multiplex_task,
470 heartbeat_task,
471 } => Self::HeartbeatMissed {
472 missed_heartbeats: missed_heartbeats + 1,
473 ssh_connection,
474 delegate,
475 forwarder,
476 multiplex_task,
477 heartbeat_task,
478 },
479 _ => self,
480 }
481 }
482}
483
/// The state of the ssh connection.
///
/// This is the public, resource-free view of the internal connection state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is up, but at least one heartbeat went unanswered.
    HeartbeatMissed,
    /// A reconnect is in progress, or the last attempt failed and another will follow.
    Reconnecting,
    /// The connection is gone: reconnects were exhausted, the server is not running, or
    /// there is no state at all.
    Disconnected,
}
493
494impl From<&State> for ConnectionState {
495 fn from(value: &State) -> Self {
496 match value {
497 State::Connecting => Self::Connecting,
498 State::Connected { .. } => Self::Connected,
499 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
500 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
501 State::ReconnectExhausted => Self::Disconnected,
502 State::ServerNotRunning => Self::Disconnected,
503 }
504 }
505}
506
/// Client side of a connection to a remote server over SSH.
pub struct SshRemoteClient {
    /// Protocol client for exchanging messages with the remote server. It is not replaced
    /// on reconnect; only the transport underneath it is.
    client: Arc<ChannelClient>,
    /// Identifier passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    /// Options used to establish the connection.
    connection_options: SshConnectionOptions,
    /// Current connection state. `None` after `shutdown_processes` took it, and in the fake
    /// test client.
    state: Arc<Mutex<Option<State>>>,
}
513
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// Emitted when reconnect attempts are exhausted or the remote server is not running.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
520
521impl SshRemoteClient {
522 pub fn new(
523 unique_identifier: String,
524 connection_options: SshConnectionOptions,
525 delegate: Arc<dyn SshClientDelegate>,
526 cx: &AppContext,
527 ) -> Task<Result<Model<Self>>> {
528 cx.spawn(|mut cx| async move {
529 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
530 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
531 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
532
533 let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
534 let this = cx.new_model(|_| Self {
535 client: client.clone(),
536 unique_identifier: unique_identifier.clone(),
537 connection_options: connection_options.clone(),
538 state: Arc::new(Mutex::new(Some(State::Connecting))),
539 })?;
540
541 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
542 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
543
544 let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
545 unique_identifier,
546 false,
547 connection_options,
548 delegate.clone(),
549 &mut cx,
550 )
551 .await?;
552
553 let multiplex_task = Self::multiplex(
554 this.downgrade(),
555 ssh_proxy_process,
556 proxy_incoming_tx,
557 proxy_outgoing_rx,
558 connection_activity_tx,
559 &mut cx,
560 );
561
562 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
563 log::error!("failed to establish connection: {}", error);
564 delegate.set_error(error.to_string(), &mut cx);
565 return Err(error);
566 }
567
568 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
569
570 this.update(&mut cx, |this, _| {
571 *this.state.lock() = Some(State::Connected {
572 ssh_connection,
573 delegate,
574 forwarder: proxy,
575 multiplex_task,
576 heartbeat_task,
577 });
578 })?;
579
580 Ok(this)
581 })
582 }
583
584 pub fn shutdown_processes<T: RequestMessage>(
585 &self,
586 shutdown_request: Option<T>,
587 ) -> Option<impl Future<Output = ()>> {
588 let state = self.state.lock().take()?;
589 log::info!("shutting down ssh processes");
590
591 let State::Connected {
592 multiplex_task,
593 heartbeat_task,
594 ssh_connection,
595 delegate,
596 forwarder,
597 } = state
598 else {
599 return None;
600 };
601
602 let client = self.client.clone();
603
604 Some(async move {
605 if let Some(shutdown_request) = shutdown_request {
606 client.send(shutdown_request).log_err();
607 // We wait 50ms instead of waiting for a response, because
608 // waiting for a response would require us to wait on the main thread
609 // which we want to avoid in an `on_app_quit` callback.
610 smol::Timer::after(Duration::from_millis(50)).await;
611 }
612
613 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
614 // child of master_process.
615 drop(multiplex_task);
616 // Now drop the rest of state, which kills master process.
617 drop(heartbeat_task);
618 drop(ssh_connection);
619 drop(delegate);
620 drop(forwarder);
621 })
622 }
623
624 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
625 let mut lock = self.state.lock();
626
627 let can_reconnect = lock
628 .as_ref()
629 .map(|state| state.can_reconnect())
630 .unwrap_or(false);
631 if !can_reconnect {
632 let error = if let Some(state) = lock.as_ref() {
633 format!("invalid state, cannot reconnect while in state {state}")
634 } else {
635 "no state set".to_string()
636 };
637 log::info!("aborting reconnect, because not in state that allows reconnecting");
638 return Err(anyhow!(error));
639 }
640
641 let state = lock.take().unwrap();
642 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
643 State::Connected {
644 ssh_connection,
645 delegate,
646 forwarder,
647 multiplex_task,
648 heartbeat_task,
649 }
650 | State::HeartbeatMissed {
651 ssh_connection,
652 delegate,
653 forwarder,
654 multiplex_task,
655 heartbeat_task,
656 ..
657 } => {
658 drop(multiplex_task);
659 drop(heartbeat_task);
660 (0, ssh_connection, delegate, forwarder)
661 }
662 State::ReconnectFailed {
663 attempts,
664 ssh_connection,
665 delegate,
666 forwarder,
667 ..
668 } => (attempts, ssh_connection, delegate, forwarder),
669 State::Connecting
670 | State::Reconnecting
671 | State::ReconnectExhausted
672 | State::ServerNotRunning => unreachable!(),
673 };
674
675 let attempts = attempts + 1;
676 if attempts > MAX_RECONNECT_ATTEMPTS {
677 log::error!(
678 "Failed to reconnect to after {} attempts, giving up",
679 MAX_RECONNECT_ATTEMPTS
680 );
681 drop(lock);
682 self.set_state(State::ReconnectExhausted, cx);
683 return Ok(());
684 }
685 drop(lock);
686
687 self.set_state(State::Reconnecting, cx);
688
689 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
690
691 let identifier = self.unique_identifier.clone();
692 let client = self.client.clone();
693 let reconnect_task = cx.spawn(|this, mut cx| async move {
694 macro_rules! failed {
695 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
696 return State::ReconnectFailed {
697 error: anyhow!($error),
698 attempts: $attempts,
699 ssh_connection: $ssh_connection,
700 delegate: $delegate,
701 forwarder: $forwarder,
702 };
703 };
704 }
705
706 if let Err(error) = ssh_connection.master_process.kill() {
707 failed!(error, attempts, ssh_connection, delegate, forwarder);
708 };
709
710 if let Err(error) = ssh_connection
711 .master_process
712 .status()
713 .await
714 .context("Failed to kill ssh process")
715 {
716 failed!(error, attempts, ssh_connection, delegate, forwarder);
717 }
718
719 let connection_options = ssh_connection.socket.connection_options.clone();
720
721 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
722 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
723 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
724 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
725
726 let (ssh_connection, ssh_process) = match Self::establish_connection(
727 identifier,
728 true,
729 connection_options,
730 delegate.clone(),
731 &mut cx,
732 )
733 .await
734 {
735 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
736 Err(error) => {
737 failed!(error, attempts, ssh_connection, delegate, forwarder);
738 }
739 };
740
741 let multiplex_task = Self::multiplex(
742 this.clone(),
743 ssh_process,
744 proxy_incoming_tx,
745 proxy_outgoing_rx,
746 connection_activity_tx,
747 &mut cx,
748 );
749
750 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
751 failed!(error, attempts, ssh_connection, delegate, forwarder);
752 };
753
754 State::Connected {
755 ssh_connection,
756 delegate,
757 forwarder,
758 multiplex_task,
759 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
760 }
761 });
762
763 cx.spawn(|this, mut cx| async move {
764 let new_state = reconnect_task.await;
765 this.update(&mut cx, |this, cx| {
766 this.try_set_state(cx, |old_state| {
767 if old_state.is_reconnecting() {
768 match &new_state {
769 State::Connecting
770 | State::Reconnecting { .. }
771 | State::HeartbeatMissed { .. }
772 | State::ServerNotRunning => {}
773 State::Connected { .. } => {
774 log::info!("Successfully reconnected");
775 }
776 State::ReconnectFailed {
777 error, attempts, ..
778 } => {
779 log::error!(
780 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
781 attempts,
782 error
783 );
784 }
785 State::ReconnectExhausted => {
786 log::error!("Reconnect attempt failed and all attempts exhausted");
787 }
788 }
789 Some(new_state)
790 } else {
791 None
792 }
793 });
794
795 if this.state_is(State::is_reconnect_failed) {
796 this.reconnect(cx)
797 } else if this.state_is(State::is_reconnect_exhausted) {
798 cx.emit(SshRemoteEvent::Disconnected);
799 Ok(())
800 } else {
801 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
802 Ok(())
803 }
804 })
805 })
806 .detach_and_log_err(cx);
807
808 Ok(())
809 }
810
811 fn heartbeat(
812 this: WeakModel<Self>,
813 mut connection_activity_rx: mpsc::Receiver<()>,
814 cx: &mut AsyncAppContext,
815 ) -> Task<Result<()>> {
816 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
817 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
818 };
819
820 cx.spawn(|mut cx| {
821 let this = this.clone();
822 async move {
823 let mut missed_heartbeats = 0;
824
825 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
826 futures::pin_mut!(keepalive_timer);
827
828 loop {
829 select_biased! {
830 result = connection_activity_rx.next().fuse() => {
831 if result.is_none() {
832 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
833 return Ok(());
834 }
835
836 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
837
838 if missed_heartbeats != 0 {
839 missed_heartbeats = 0;
840 this.update(&mut cx, |this, mut cx| {
841 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
842 })?;
843 }
844 }
845 _ = keepalive_timer => {
846 log::debug!("Sending heartbeat to server...");
847
848 let result = select_biased! {
849 _ = connection_activity_rx.next().fuse() => {
850 Ok(())
851 }
852 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
853 ping_result
854 }
855 };
856
857 if result.is_err() {
858 missed_heartbeats += 1;
859 log::warn!(
860 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
861 HEARTBEAT_TIMEOUT,
862 missed_heartbeats,
863 MAX_MISSED_HEARTBEATS
864 );
865 } else if missed_heartbeats != 0 {
866 missed_heartbeats = 0;
867 } else {
868 continue;
869 }
870
871 let result = this.update(&mut cx, |this, mut cx| {
872 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
873 })?;
874 if result.is_break() {
875 return Ok(());
876 }
877 }
878 }
879 }
880 }
881 })
882 }
883
884 fn handle_heartbeat_result(
885 &mut self,
886 missed_heartbeats: usize,
887 cx: &mut ModelContext<Self>,
888 ) -> ControlFlow<()> {
889 let state = self.state.lock().take().unwrap();
890 let next_state = if missed_heartbeats > 0 {
891 state.heartbeat_missed()
892 } else {
893 state.heartbeat_recovered()
894 };
895
896 self.set_state(next_state, cx);
897
898 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
899 log::error!(
900 "Missed last {} heartbeats. Reconnecting...",
901 missed_heartbeats
902 );
903
904 self.reconnect(cx)
905 .context("failed to start reconnect process after missing heartbeats")
906 .log_err();
907 ControlFlow::Break(())
908 } else {
909 ControlFlow::Continue(())
910 }
911 }
912
913 fn multiplex(
914 this: WeakModel<Self>,
915 mut ssh_proxy_process: Child,
916 incoming_tx: UnboundedSender<Envelope>,
917 mut outgoing_rx: UnboundedReceiver<Envelope>,
918 mut connection_activity_tx: Sender<()>,
919 cx: &AsyncAppContext,
920 ) -> Task<Result<()>> {
921 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
922 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
923 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
924
925 let io_task = cx.background_executor().spawn(async move {
926 let mut stdin_buffer = Vec::new();
927 let mut stdout_buffer = Vec::new();
928 let mut stderr_buffer = Vec::new();
929 let mut stderr_offset = 0;
930
931 loop {
932 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
933 stderr_buffer.resize(stderr_offset + 1024, 0);
934
935 select_biased! {
936 outgoing = outgoing_rx.next().fuse() => {
937 let Some(outgoing) = outgoing else {
938 return anyhow::Ok(None);
939 };
940
941 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
942 }
943
944 result = child_stdout.read(&mut stdout_buffer).fuse() => {
945 match result {
946 Ok(0) => {
947 child_stdin.close().await?;
948 outgoing_rx.close();
949 let status = ssh_proxy_process.status().await?;
950 // If we don't have a code, we assume process
951 // has been killed and treat it as non-zero exit
952 // code
953 return Ok(status.code().or_else(|| Some(1)));
954 }
955 Ok(len) => {
956 if len < stdout_buffer.len() {
957 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
958 }
959
960 let message_len = message_len_from_buffer(&stdout_buffer);
961 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
962 Ok(envelope) => {
963 connection_activity_tx.try_send(()).ok();
964 incoming_tx.unbounded_send(envelope).ok();
965 }
966 Err(error) => {
967 log::error!("error decoding message {error:?}");
968 }
969 }
970 }
971 Err(error) => {
972 Err(anyhow!("error reading stdout: {error:?}"))?;
973 }
974 }
975 }
976
977 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
978 match result {
979 Ok(len) => {
980 stderr_offset += len;
981 let mut start_ix = 0;
982 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
983 let line_ix = start_ix + ix;
984 let content = &stderr_buffer[start_ix..line_ix];
985 start_ix = line_ix + 1;
986 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
987 record.log(log::logger())
988 } else {
989 eprintln!("(remote) {}", String::from_utf8_lossy(content));
990 }
991 }
992 stderr_buffer.drain(0..start_ix);
993 stderr_offset -= start_ix;
994
995 connection_activity_tx.try_send(()).ok();
996 }
997 Err(error) => {
998 Err(anyhow!("error reading stderr: {error:?}"))?;
999 }
1000 }
1001 }
1002 }
1003 }
1004 });
1005
1006 cx.spawn(|mut cx| async move {
1007 let result = io_task.await;
1008
1009 match result {
1010 Ok(Some(exit_code)) => {
1011 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1012 match error {
1013 ProxyLaunchError::ServerNotRunning => {
1014 log::error!("failed to reconnect because server is not running");
1015 this.update(&mut cx, |this, cx| {
1016 this.set_state(State::ServerNotRunning, cx);
1017 cx.emit(SshRemoteEvent::Disconnected);
1018 })?;
1019 }
1020 }
1021 } else if exit_code > 0 {
1022 log::error!("proxy process terminated unexpectedly");
1023 this.update(&mut cx, |this, cx| {
1024 this.reconnect(cx).ok();
1025 })?;
1026 }
1027 }
1028 Ok(None) => {}
1029 Err(error) => {
1030 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1031 this.update(&mut cx, |this, cx| {
1032 this.reconnect(cx).ok();
1033 })?;
1034 }
1035 }
1036 Ok(())
1037 })
1038 }
1039
1040 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1041 self.state.lock().as_ref().map_or(false, check)
1042 }
1043
1044 fn try_set_state(
1045 &self,
1046 cx: &mut ModelContext<Self>,
1047 map: impl FnOnce(&State) -> Option<State>,
1048 ) {
1049 let mut lock = self.state.lock();
1050 let new_state = lock.as_ref().and_then(map);
1051
1052 if let Some(new_state) = new_state {
1053 lock.replace(new_state);
1054 cx.notify();
1055 }
1056 }
1057
1058 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1059 log::info!("setting state to '{}'", &state);
1060 self.state.lock().replace(state);
1061 cx.notify();
1062 }
1063
1064 async fn establish_connection(
1065 unique_identifier: String,
1066 reconnect: bool,
1067 connection_options: SshConnectionOptions,
1068 delegate: Arc<dyn SshClientDelegate>,
1069 cx: &mut AsyncAppContext,
1070 ) -> Result<(SshRemoteConnection, Child)> {
1071 let ssh_connection =
1072 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1073
1074 let platform = ssh_connection.query_platform().await?;
1075 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1076 ssh_connection
1077 .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1078 .await?;
1079
1080 let socket = ssh_connection.socket.clone();
1081 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1082
1083 delegate.set_status(Some("Starting proxy"), cx);
1084
1085 let mut start_proxy_command = format!(
1086 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1087 std::env::var("RUST_LOG").unwrap_or_default(),
1088 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1089 remote_binary_path,
1090 unique_identifier,
1091 );
1092 if reconnect {
1093 start_proxy_command.push_str(" --reconnect");
1094 }
1095
1096 let ssh_proxy_process = socket
1097 .ssh_command(start_proxy_command)
1098 // IMPORTANT: we kill this process when we drop the task that uses it.
1099 .kill_on_drop(true)
1100 .spawn()
1101 .context("failed to spawn remote server")?;
1102
1103 Ok((ssh_connection, ssh_proxy_process))
1104 }
1105
1106 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1107 self.client.subscribe_to_entity(remote_id, entity);
1108 }
1109
1110 pub fn ssh_args(&self) -> Option<Vec<String>> {
1111 self.state
1112 .lock()
1113 .as_ref()
1114 .and_then(|state| state.ssh_connection())
1115 .map(|ssh_connection| ssh_connection.socket.ssh_args())
1116 }
1117
1118 pub fn proto_client(&self) -> AnyProtoClient {
1119 self.client.clone().into()
1120 }
1121
1122 pub fn connection_string(&self) -> String {
1123 self.connection_options.connection_string()
1124 }
1125
1126 pub fn connection_options(&self) -> SshConnectionOptions {
1127 self.connection_options.clone()
1128 }
1129
1130 #[cfg(not(any(test, feature = "test-support")))]
1131 pub fn connection_state(&self) -> ConnectionState {
1132 self.state
1133 .lock()
1134 .as_ref()
1135 .map(ConnectionState::from)
1136 .unwrap_or(ConnectionState::Disconnected)
1137 }
1138
1139 #[cfg(any(test, feature = "test-support"))]
1140 pub fn connection_state(&self) -> ConnectionState {
1141 ConnectionState::Connected
1142 }
1143
1144 pub fn is_disconnected(&self) -> bool {
1145 self.connection_state() == ConnectionState::Disconnected
1146 }
1147
1148 #[cfg(any(test, feature = "test-support"))]
1149 pub fn fake(
1150 client_cx: &mut gpui::TestAppContext,
1151 server_cx: &mut gpui::TestAppContext,
1152 ) -> (Model<Self>, Arc<ChannelClient>) {
1153 use gpui::Context;
1154
1155 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1156 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1157
1158 (
1159 client_cx.update(|cx| {
1160 let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1161 cx.new_model(|_| Self {
1162 client,
1163 unique_identifier: "fake".to_string(),
1164 connection_options: SshConnectionOptions::default(),
1165 state: Arc::new(Mutex::new(None)),
1166 })
1167 }),
1168 server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1169 )
1170 }
1171}
1172
1173impl From<SshRemoteClient> for AnyProtoClient {
1174 fn from(client: SshRemoteClient) -> Self {
1175 AnyProtoClient::new(client.client.clone())
1176 }
1177}
1178
/// An established SSH master connection and the socket used to run commands over it.
struct SshRemoteConnection {
    /// Runs commands through the master connection's control socket.
    socket: SshSocket,
    /// The SSH master process; killed when this value is dropped.
    master_process: process::Child,
    /// Temporary directory kept alive for the connection's lifetime.
    /// NOTE(review): presumably the directory `new` creates for the askpass socket — confirm.
    _temp_dir: TempDir,
}
1184
1185impl Drop for SshRemoteConnection {
1186 fn drop(&mut self) {
1187 if let Err(error) = self.master_process.kill() {
1188 log::error!("failed to kill SSH master process: {}", error);
1189 }
1190 }
1191}
1192
1193impl SshRemoteConnection {
1194 #[cfg(not(unix))]
1195 async fn new(
1196 _connection_options: SshConnectionOptions,
1197 _delegate: Arc<dyn SshClientDelegate>,
1198 _cx: &mut AsyncAppContext,
1199 ) -> Result<Self> {
1200 Err(anyhow!("ssh is not supported on this platform"))
1201 }
1202
1203 #[cfg(unix)]
1204 async fn new(
1205 connection_options: SshConnectionOptions,
1206 delegate: Arc<dyn SshClientDelegate>,
1207 cx: &mut AsyncAppContext,
1208 ) -> Result<Self> {
1209 use futures::{io::BufReader, AsyncBufReadExt as _};
1210 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1211 use util::ResultExt as _;
1212
1213 delegate.set_status(Some("connecting"), cx);
1214
1215 let url = connection_options.ssh_url();
1216 let temp_dir = tempfile::Builder::new()
1217 .prefix("zed-ssh-session")
1218 .tempdir()?;
1219
1220 // Create a domain socket listener to handle requests from the askpass program.
1221 let askpass_socket = temp_dir.path().join("askpass.sock");
1222 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1223 let listener =
1224 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1225
1226 let askpass_task = cx.spawn({
1227 let delegate = delegate.clone();
1228 |mut cx| async move {
1229 let mut askpass_opened_tx = Some(askpass_opened_tx);
1230
1231 while let Ok((mut stream, _)) = listener.accept().await {
1232 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1233 askpass_opened_tx.send(()).ok();
1234 }
1235 let mut buffer = Vec::new();
1236 let mut reader = BufReader::new(&mut stream);
1237 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1238 buffer.clear();
1239 }
1240 let password_prompt = String::from_utf8_lossy(&buffer);
1241 if let Some(password) = delegate
1242 .ask_password(password_prompt.to_string(), &mut cx)
1243 .await
1244 .context("failed to get ssh password")
1245 .and_then(|p| p)
1246 .log_err()
1247 {
1248 stream.write_all(password.as_bytes()).await.log_err();
1249 }
1250 }
1251 }
1252 });
1253
1254 // Create an askpass script that communicates back to this process.
1255 let askpass_script = format!(
1256 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1257 askpass_socket = askpass_socket.display(),
1258 print_args = "printf '%s\\0' \"$@\"",
1259 shebang = "#!/bin/sh",
1260 );
1261 let askpass_script_path = temp_dir.path().join("askpass.sh");
1262 fs::write(&askpass_script_path, askpass_script).await?;
1263 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1264
1265 // Start the master SSH process, which does not do anything except for establish
1266 // the connection and keep it open, allowing other ssh commands to reuse it
1267 // via a control socket.
1268 let socket_path = temp_dir.path().join("ssh.sock");
1269 let mut master_process = process::Command::new("ssh")
1270 .stdin(Stdio::null())
1271 .stdout(Stdio::piped())
1272 .stderr(Stdio::piped())
1273 .env("SSH_ASKPASS_REQUIRE", "force")
1274 .env("SSH_ASKPASS", &askpass_script_path)
1275 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1276 .args([
1277 "-N",
1278 "-o",
1279 "ControlPersist=no",
1280 "-o",
1281 "ControlMaster=yes",
1282 "-o",
1283 ])
1284 .arg(format!("ControlPath={}", socket_path.display()))
1285 .arg(&url)
1286 .spawn()?;
1287
1288 // Wait for this ssh process to close its stdout, indicating that authentication
1289 // has completed.
1290 let stdout = master_process.stdout.as_mut().unwrap();
1291 let mut output = Vec::new();
1292 let connection_timeout = Duration::from_secs(10);
1293
1294 let result = select_biased! {
1295 _ = askpass_opened_rx.fuse() => {
1296 // If the askpass script has opened, that means the user is typing
1297 // their password, in which case we don't want to timeout anymore,
1298 // since we know a connection has been established.
1299 stdout.read_to_end(&mut output).await?;
1300 Ok(())
1301 }
1302 result = stdout.read_to_end(&mut output).fuse() => {
1303 result?;
1304 Ok(())
1305 }
1306 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1307 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1308 }
1309 };
1310
1311 if let Err(e) = result {
1312 let error_message = format!("Failed to connect to host: {}.", e);
1313 delegate.set_error(error_message, cx);
1314 return Err(e);
1315 }
1316
1317 drop(askpass_task);
1318
1319 if master_process.try_status()?.is_some() {
1320 output.clear();
1321 let mut stderr = master_process.stderr.take().unwrap();
1322 stderr.read_to_end(&mut output).await?;
1323
1324 let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1325 delegate.set_error(error_message.clone(), cx);
1326 Err(anyhow!(error_message))?;
1327 }
1328
1329 Ok(Self {
1330 socket: SshSocket {
1331 connection_options,
1332 socket_path,
1333 },
1334 master_process,
1335 _temp_dir: temp_dir,
1336 })
1337 }
1338
1339 async fn ensure_server_binary(
1340 &self,
1341 delegate: &Arc<dyn SshClientDelegate>,
1342 dst_path: &Path,
1343 platform: SshPlatform,
1344 cx: &mut AsyncAppContext,
1345 ) -> Result<()> {
1346 if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1347 if let Ok(installed_version) =
1348 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1349 {
1350 log::info!("using cached server binary version {}", installed_version);
1351 return Ok(());
1352 }
1353 }
1354
1355 let mut dst_path_gz = dst_path.to_path_buf();
1356 dst_path_gz.set_extension("gz");
1357
1358 if let Some(parent) = dst_path.parent() {
1359 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1360 }
1361
1362 let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1363
1364 let mut server_binary_exists = false;
1365 if !server_binary_exists && cfg!(not(debug_assertions)) {
1366 if let Ok(installed_version) =
1367 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1368 {
1369 if installed_version.trim() == version.to_string() {
1370 server_binary_exists = true;
1371 }
1372 }
1373 }
1374
1375 if server_binary_exists {
1376 log::info!("remote development server already present",);
1377 return Ok(());
1378 }
1379
1380 let src_stat = fs::metadata(&src_path).await?;
1381 let size = src_stat.len();
1382 let server_mode = 0o755;
1383
1384 let t0 = Instant::now();
1385 delegate.set_status(Some("uploading remote development server"), cx);
1386 log::info!("uploading remote development server ({}kb)", size / 1024);
1387 self.upload_file(&src_path, &dst_path_gz)
1388 .await
1389 .context("failed to upload server binary")?;
1390 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1391
1392 delegate.set_status(Some("extracting remote development server"), cx);
1393 run_cmd(
1394 self.socket
1395 .ssh_command("gunzip")
1396 .arg("--force")
1397 .arg(&dst_path_gz),
1398 )
1399 .await?;
1400
1401 delegate.set_status(Some("unzipping remote development server"), cx);
1402 run_cmd(
1403 self.socket
1404 .ssh_command("chmod")
1405 .arg(format!("{:o}", server_mode))
1406 .arg(dst_path),
1407 )
1408 .await?;
1409
1410 Ok(())
1411 }
1412
1413 async fn query_platform(&self) -> Result<SshPlatform> {
1414 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1415 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1416
1417 let os = match os.trim() {
1418 "Darwin" => "macos",
1419 "Linux" => "linux",
1420 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1421 };
1422 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1423 "aarch64"
1424 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1425 "x86_64"
1426 } else {
1427 Err(anyhow!("unknown uname architecture {arch:?}"))?
1428 };
1429
1430 Ok(SshPlatform { os, arch })
1431 }
1432
1433 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1434 let mut command = process::Command::new("scp");
1435 let output = self
1436 .socket
1437 .ssh_options(&mut command)
1438 .args(
1439 self.socket
1440 .connection_options
1441 .port
1442 .map(|port| vec!["-P".to_string(), port.to_string()])
1443 .unwrap_or_default(),
1444 )
1445 .arg(src_path)
1446 .arg(format!(
1447 "{}:{}",
1448 self.socket.connection_options.scp_url(),
1449 dest_path.display()
1450 ))
1451 .output()
1452 .await?;
1453
1454 if output.status.success() {
1455 Ok(())
1456 } else {
1457 Err(anyhow!(
1458 "failed to upload file {} -> {}: {}",
1459 src_path.display(),
1460 dest_path.display(),
1461 String::from_utf8_lossy(&output.stderr)
1462 ))
1463 }
1464 }
1465}
1466
/// Pending requests, keyed by the id of the outgoing envelope.
///
/// Each sender is completed with the response envelope plus a
/// `oneshot::Sender<()>`. Dropping that sender tells the receive loop the
/// response has been taken, and the loop then moves on to the next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// Protocol client that exchanges [`Envelope`]s over a pair of in-process
/// channels. It matches responses to pending requests and dispatches all
/// other incoming messages to registered handlers.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes.
    next_message_id: AtomicU32,
    /// Channel that carries every outgoing envelope.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    response_channels: ResponseChannels, // Mutex-guarded map of requests awaiting a response
    message_handlers: Mutex<ProtoMessageHandlerSet>, // Mutex-guarded handler/entity registry
}
1475
1476impl ChannelClient {
1477 pub fn new(
1478 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1479 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1480 cx: &AppContext,
1481 ) -> Arc<Self> {
1482 let this = Arc::new(Self {
1483 outgoing_tx,
1484 next_message_id: AtomicU32::new(0),
1485 response_channels: ResponseChannels::default(),
1486 message_handlers: Default::default(),
1487 });
1488
1489 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1490
1491 this
1492 }
1493
1494 fn start_handling_messages(
1495 this: Arc<Self>,
1496 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1497 cx: &AppContext,
1498 ) {
1499 cx.spawn(|cx| {
1500 let this = Arc::downgrade(&this);
1501 async move {
1502 let peer_id = PeerId { owner_id: 0, id: 0 };
1503 while let Some(incoming) = incoming_rx.next().await {
1504 let Some(this) = this.upgrade() else {
1505 return anyhow::Ok(());
1506 };
1507
1508 if let Some(request_id) = incoming.responding_to {
1509 let request_id = MessageId(request_id);
1510 let sender = this.response_channels.lock().remove(&request_id);
1511 if let Some(sender) = sender {
1512 let (tx, rx) = oneshot::channel();
1513 if incoming.payload.is_some() {
1514 sender.send((incoming, tx)).ok();
1515 }
1516 rx.await.ok();
1517 }
1518 } else if let Some(envelope) =
1519 build_typed_envelope(peer_id, Instant::now(), incoming)
1520 {
1521 let type_name = envelope.payload_type_name();
1522 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1523 &this.message_handlers,
1524 envelope,
1525 this.clone().into(),
1526 cx.clone(),
1527 ) {
1528 log::debug!("ssh message received. name:{type_name}");
1529 match future.await {
1530 Ok(_) => {
1531 log::debug!("ssh message handled. name:{type_name}");
1532 }
1533 Err(error) => {
1534 log::error!(
1535 "error handling message. type:{type_name}, error:{error}",
1536 );
1537 }
1538 }
1539 } else {
1540 log::error!("unhandled ssh message name:{type_name}");
1541 }
1542 }
1543 }
1544 anyhow::Ok(())
1545 }
1546 })
1547 .detach();
1548 }
1549
1550 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1551 let id = (TypeId::of::<E>(), remote_id);
1552
1553 let mut message_handlers = self.message_handlers.lock();
1554 if message_handlers
1555 .entities_by_type_and_remote_id
1556 .contains_key(&id)
1557 {
1558 panic!("already subscribed to entity");
1559 }
1560
1561 message_handlers.entities_by_type_and_remote_id.insert(
1562 id,
1563 EntityMessageSubscriber::Entity {
1564 handle: entity.downgrade().into(),
1565 },
1566 );
1567 }
1568
1569 pub fn request<T: RequestMessage>(
1570 &self,
1571 payload: T,
1572 ) -> impl 'static + Future<Output = Result<T::Response>> {
1573 log::debug!("ssh request start. name:{}", T::NAME);
1574 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1575 async move {
1576 let response = response.await?;
1577 log::debug!("ssh request finish. name:{}", T::NAME);
1578 T::Response::from_envelope(response)
1579 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1580 }
1581 }
1582
1583 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1584 smol::future::or(
1585 async {
1586 self.request(proto::Ping {}).await?;
1587 Ok(())
1588 },
1589 async {
1590 smol::Timer::after(timeout).await;
1591 Err(anyhow!("Timeout detected"))
1592 },
1593 )
1594 .await
1595 }
1596
1597 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1598 log::debug!("ssh send name:{}", T::NAME);
1599 self.send_dynamic(payload.into_envelope(0, None, None))
1600 }
1601
1602 pub fn request_dynamic(
1603 &self,
1604 mut envelope: proto::Envelope,
1605 type_name: &'static str,
1606 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1607 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1608 let (tx, rx) = oneshot::channel();
1609 let mut response_channels_lock = self.response_channels.lock();
1610 response_channels_lock.insert(MessageId(envelope.id), tx);
1611 drop(response_channels_lock);
1612 let result = self.outgoing_tx.unbounded_send(envelope);
1613 async move {
1614 if let Err(error) = &result {
1615 log::error!("failed to send message: {}", error);
1616 return Err(anyhow!("failed to send message: {}", error));
1617 }
1618
1619 let response = rx.await.context("connection lost")?.0;
1620 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1621 return Err(RpcError::from_proto(error, type_name));
1622 }
1623 Ok(response)
1624 }
1625 }
1626
1627 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1628 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1629 self.outgoing_tx.unbounded_send(envelope)?;
1630 Ok(())
1631 }
1632}
1633
1634impl ProtoClient for ChannelClient {
1635 fn request(
1636 &self,
1637 envelope: proto::Envelope,
1638 request_type: &'static str,
1639 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1640 self.request_dynamic(envelope, request_type).boxed()
1641 }
1642
1643 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1644 self.send_dynamic(envelope)
1645 }
1646
1647 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1648 self.send_dynamic(envelope)
1649 }
1650
1651 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1652 &self.message_handlers
1653 }
1654
1655 fn is_via_collab(&self) -> bool {
1656 false
1657 }
1658}