1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
17 StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
21 WeakModel,
22};
23use parking_lot::Mutex;
24use rpc::{
25 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
26 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
27};
28use smol::{
29 fs,
30 process::{self, Child, Stdio},
31};
32use std::{
33 any::TypeId,
34 ffi::OsStr,
35 fmt,
36 ops::ControlFlow,
37 path::{Path, PathBuf},
38 sync::{
39 atomic::{AtomicU32, Ordering::SeqCst},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tempfile::TempDir;
45use util::ResultExt;
46
/// Identifier of a project on an ssh remote host (a newtype over `u64`).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);

/// Connection options plus the path of the ssh control socket that commands
/// built from this value connect through (`ControlPath=...`).
#[derive(Clone)]
pub struct SshSocket {
    connection_options: SshConnectionOptions,
    socket_path: PathBuf,
}

/// Everything needed to reach a remote host over ssh.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Remote host name (without user or port).
    pub host: String,
    /// Login user, if one was given.
    pub username: Option<String>,
    /// Remote port, if one was given.
    pub port: Option<u16>,
    /// Password for the connection, if known. `parse_command_line` always
    /// leaves this as `None`.
    pub password: Option<String>,
    /// Extra ssh flags passed through verbatim (filled by `parse_command_line`
    /// from its allow-list).
    pub args: Option<Vec<String>>,
}
66
67impl SshConnectionOptions {
68 pub fn parse_command_line(input: &str) -> Result<Self> {
69 let input = input.trim_start_matches("ssh ");
70 let mut hostname: Option<String> = None;
71 let mut username: Option<String> = None;
72 let mut port: Option<u16> = None;
73 let mut args = Vec::new();
74
75 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
76 const ALLOWED_OPTS: &[&str] = &[
77 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
78 ];
79 const ALLOWED_ARGS: &[&str] = &[
80 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
81 "-w",
82 ];
83
84 let mut tokens = shlex::split(input)
85 .ok_or_else(|| anyhow!("invalid input"))?
86 .into_iter();
87
88 'outer: while let Some(arg) = tokens.next() {
89 if ALLOWED_OPTS.contains(&(&arg as &str)) {
90 args.push(arg.to_string());
91 continue;
92 }
93 if arg == "-p" {
94 port = tokens.next().and_then(|arg| arg.parse().ok());
95 continue;
96 } else if let Some(p) = arg.strip_prefix("-p") {
97 port = p.parse().ok();
98 continue;
99 }
100 if arg == "-l" {
101 username = tokens.next();
102 continue;
103 } else if let Some(l) = arg.strip_prefix("-l") {
104 username = Some(l.to_string());
105 continue;
106 }
107 for a in ALLOWED_ARGS {
108 if arg == *a {
109 args.push(arg);
110 if let Some(next) = tokens.next() {
111 args.push(next);
112 }
113 continue 'outer;
114 } else if arg.starts_with(a) {
115 args.push(arg);
116 continue 'outer;
117 }
118 }
119 if arg.starts_with("-") || hostname.is_some() {
120 anyhow::bail!("unsupported argument: {:?}", arg);
121 }
122 let mut input = &arg as &str;
123 if let Some((u, rest)) = input.split_once('@') {
124 input = rest;
125 username = Some(u.to_string());
126 }
127 if let Some((rest, p)) = input.split_once(':') {
128 input = rest;
129 port = p.parse().ok()
130 }
131 hostname = Some(input.to_string())
132 }
133
134 let Some(hostname) = hostname else {
135 anyhow::bail!("missing hostname");
136 };
137
138 Ok(Self {
139 host: hostname.to_string(),
140 username: username.clone(),
141 port,
142 password: None,
143 args: Some(args),
144 })
145 }
146
147 pub fn ssh_url(&self) -> String {
148 let mut result = String::from("ssh://");
149 if let Some(username) = &self.username {
150 result.push_str(username);
151 result.push('@');
152 }
153 result.push_str(&self.host);
154 if let Some(port) = self.port {
155 result.push(':');
156 result.push_str(&port.to_string());
157 }
158 result
159 }
160
161 pub fn additional_args(&self) -> Option<&Vec<String>> {
162 self.args.as_ref()
163 }
164
165 fn scp_url(&self) -> String {
166 if let Some(username) = &self.username {
167 format!("{}@{}", username, self.host)
168 } else {
169 self.host.clone()
170 }
171 }
172
173 pub fn connection_string(&self) -> String {
174 let host = if let Some(username) = &self.username {
175 format!("{}@{}", username, self.host)
176 } else {
177 self.host.clone()
178 };
179 if let Some(port) = &self.port {
180 format!("{}:{}", host, port)
181 } else {
182 host
183 }
184 }
185
186 // Uniquely identifies dev server projects on a remote host. Needs to be
187 // stable for the same dev server project.
188 pub fn dev_server_identifier(&self) -> String {
189 let mut identifier = format!("dev-server-{:?}", self.host);
190 if let Some(username) = self.username.as_ref() {
191 identifier.push('-');
192 identifier.push_str(&username);
193 }
194 identifier
195 }
196}
197
/// Operating system and CPU architecture reported by a remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}

impl SshPlatform {
    /// Returns the target triple `<arch>-<vendor>-<os>` for this platform.
    ///
    /// Only `"linux"` and `"macos"` are recognised; any other `os` yields
    /// `None`.
    pub fn triple(&self) -> Option<String> {
        let vendor_and_os = match self.os {
            "linux" => Some("unknown-linux-gnu"),
            "macos" => Some("apple-darwin"),
            _ => None,
        };
        vendor_and_os.map(|suffix| format!("{}-{}", self.arch, suffix))
    }
}
217
/// Callbacks through which the ssh client talks to the embedding application:
/// password prompts, locating the server binary, and status/error reporting.
pub trait SshClientDelegate: Send + Sync {
    /// Answers a password `prompt` forwarded by ssh's askpass helper. The
    /// password (or an error) is delivered on the returned channel.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Returns the path where the server binary is expected on a remote host
    /// of the given `platform`.
    fn remote_server_binary_path(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> Result<PathBuf>;
    /// Obtains a server binary for `platform` together with its version.
    /// NOTE(review): presumably a local path to be uploaded; confirm with implementors.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports connection progress (e.g. "connecting", "Starting proxy").
    /// `None` presumably clears the status — TODO confirm.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports a connection error message to the user.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
237
238impl SshSocket {
239 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
240 let mut command = process::Command::new("ssh");
241 self.ssh_options(&mut command)
242 .arg(self.connection_options.ssh_url())
243 .arg(program);
244 command
245 }
246
247 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
248 command
249 .stdin(Stdio::piped())
250 .stdout(Stdio::piped())
251 .stderr(Stdio::piped())
252 .args(["-o", "ControlMaster=no", "-o"])
253 .arg(format!("ControlPath={}", self.socket_path.display()))
254 }
255
256 fn ssh_args(&self) -> Vec<String> {
257 vec![
258 "-o".to_string(),
259 "ControlMaster=no".to_string(),
260 "-o".to_string(),
261 format!("ControlPath={}", self.socket_path.display()),
262 self.connection_options.ssh_url(),
263 ]
264 }
265}
266
267async fn run_cmd(command: &mut process::Command) -> Result<String> {
268 let output = command.output().await?;
269 if output.status.success() {
270 Ok(String::from_utf8_lossy(&output.stdout).to_string())
271 } else {
272 Err(anyhow!(
273 "failed to run command: {}",
274 String::from_utf8_lossy(&output.stderr)
275 ))
276 }
277}
278
/// Relays envelopes between a long-lived pair of channels and a
/// per-connection pair of proxy channels.
///
/// This lets the long-lived ends be handed to a new connection after a
/// reconnect (see `into_channels`).
struct ChannelForwarder {
    // Tells the forwarding task to stop.
    quit_tx: UnboundedSender<()>,
    // Resolves to the original channel ends once forwarding has stopped.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}

impl ChannelForwarder {
    /// Starts forwarding between `incoming_tx`/`outgoing_rx` and a fresh pair
    /// of proxy channels.
    ///
    /// Returns the forwarder together with the proxy ends:
    /// - envelopes sent on the returned sender are forwarded to `incoming_tx`;
    /// - envelopes received from `outgoing_rx` are forwarded to the returned
    ///   receiver.
    ///
    /// Forwarding stops on a quit signal or when any side closes. Envelopes
    /// still queued at that point are not forwarded.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // Biased: a pending quit request is handled before any envelope.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // Give the long-lived channel ends back to `into_channels`.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops forwarding and returns the original `incoming_tx`/`outgoing_rx`
    /// so they can be wired to a new connection.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // Ignore a send failure: it means the task has already stopped.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
340
/// Consecutive failed heartbeats after which the client starts reconnecting.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time without connection activity before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Timeout for a ping (the initial one and each heartbeat).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts allowed before the client gives up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
346
/// Internal connection state of an `SshRemoteClient`.
///
/// Variants that hold a connection (`Connected`, `HeartbeatMissed`,
/// `ReconnectFailed`) own the ssh master connection, the delegate and the
/// channel forwarder. The first two also own the multiplexer and heartbeat
/// tasks.
enum State {
    /// Set while `SshRemoteClient::new` is establishing the first connection.
    Connecting,
    /// The connection is established.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// One or more heartbeat pings have failed; the connection is kept.
    HeartbeatMissed {
        // Number of failures recorded by `State::heartbeat_missed`.
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in progress.
    Reconnecting,
    /// The last reconnect attempt failed. `attempts` counts attempts so far.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Reconnecting was abandoned after `MAX_RECONNECT_ATTEMPTS` attempts.
    ReconnectExhausted,
    /// The proxy exited with the "server not running" launch error.
    ServerNotRunning,
}
379
380impl fmt::Display for State {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 match self {
383 Self::Connecting => write!(f, "connecting"),
384 Self::Connected { .. } => write!(f, "connected"),
385 Self::Reconnecting => write!(f, "reconnecting"),
386 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
387 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
388 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
389 Self::ServerNotRunning { .. } => write!(f, "server not running"),
390 }
391 }
392}
393
394impl State {
395 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
396 match self {
397 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
398 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
399 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
400 _ => None,
401 }
402 }
403
404 fn can_reconnect(&self) -> bool {
405 match self {
406 Self::Connected { .. }
407 | Self::HeartbeatMissed { .. }
408 | Self::ReconnectFailed { .. } => true,
409 State::Connecting
410 | State::Reconnecting
411 | State::ReconnectExhausted
412 | State::ServerNotRunning => false,
413 }
414 }
415
416 fn is_reconnect_failed(&self) -> bool {
417 matches!(self, Self::ReconnectFailed { .. })
418 }
419
420 fn is_reconnect_exhausted(&self) -> bool {
421 matches!(self, Self::ReconnectExhausted { .. })
422 }
423
424 fn is_reconnecting(&self) -> bool {
425 matches!(self, Self::Reconnecting { .. })
426 }
427
428 fn heartbeat_recovered(self) -> Self {
429 match self {
430 Self::HeartbeatMissed {
431 ssh_connection,
432 delegate,
433 forwarder,
434 multiplex_task,
435 heartbeat_task,
436 ..
437 } => Self::Connected {
438 ssh_connection,
439 delegate,
440 forwarder,
441 multiplex_task,
442 heartbeat_task,
443 },
444 _ => self,
445 }
446 }
447
448 fn heartbeat_missed(self) -> Self {
449 match self {
450 Self::Connected {
451 ssh_connection,
452 delegate,
453 forwarder,
454 multiplex_task,
455 heartbeat_task,
456 } => Self::HeartbeatMissed {
457 missed_heartbeats: 1,
458 ssh_connection,
459 delegate,
460 forwarder,
461 multiplex_task,
462 heartbeat_task,
463 },
464 Self::HeartbeatMissed {
465 missed_heartbeats,
466 ssh_connection,
467 delegate,
468 forwarder,
469 multiplex_task,
470 heartbeat_task,
471 } => Self::HeartbeatMissed {
472 missed_heartbeats: missed_heartbeats + 1,
473 ssh_connection,
474 delegate,
475 forwarder,
476 multiplex_task,
477 heartbeat_task,
478 },
479 _ => self,
480 }
481 }
482}
483
484/// The state of the ssh connection.
485#[derive(Clone, Copy, Debug, PartialEq, Eq)]
486pub enum ConnectionState {
487 Connecting,
488 Connected,
489 HeartbeatMissed,
490 Reconnecting,
491 Disconnected,
492}
493
494impl From<&State> for ConnectionState {
495 fn from(value: &State) -> Self {
496 match value {
497 State::Connecting => Self::Connecting,
498 State::Connected { .. } => Self::Connected,
499 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
500 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
501 State::ReconnectExhausted => Self::Disconnected,
502 State::ServerNotRunning => Self::Disconnected,
503 }
504 }
505}
506
/// Client side of a connection to a remote server over ssh.
pub struct SshRemoteClient {
    // Proto client that exchanges envelopes with the remote server.
    client: Arc<ChannelClient>,
    // Passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    // Current connection state. `None` after `shutdown_processes` has taken it,
    // and in clients built by `fake`.
    state: Arc<Mutex<Option<State>>>,
}

/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The client entered a state it will not reconnect from (server not
    /// running, or reconnect attempts exhausted).
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
520
521impl SshRemoteClient {
522 pub fn new(
523 unique_identifier: String,
524 connection_options: SshConnectionOptions,
525 delegate: Arc<dyn SshClientDelegate>,
526 cx: &AppContext,
527 ) -> Task<Result<Model<Self>>> {
528 cx.spawn(|mut cx| async move {
529 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
530 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
531 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
532
533 let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
534 let this = cx.new_model(|_| Self {
535 client: client.clone(),
536 unique_identifier: unique_identifier.clone(),
537 connection_options: connection_options.clone(),
538 state: Arc::new(Mutex::new(Some(State::Connecting))),
539 })?;
540
541 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
542 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
543
544 let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
545 unique_identifier,
546 false,
547 connection_options,
548 delegate.clone(),
549 &mut cx,
550 )
551 .await?;
552
553 let multiplex_task = Self::multiplex(
554 this.downgrade(),
555 ssh_proxy_process,
556 proxy_incoming_tx,
557 proxy_outgoing_rx,
558 connection_activity_tx,
559 &mut cx,
560 );
561
562 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
563 log::error!("failed to establish connection: {}", error);
564 delegate.set_error(error.to_string(), &mut cx);
565 return Err(error);
566 }
567
568 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
569
570 this.update(&mut cx, |this, _| {
571 *this.state.lock() = Some(State::Connected {
572 ssh_connection,
573 delegate,
574 forwarder: proxy,
575 multiplex_task,
576 heartbeat_task,
577 });
578 })?;
579
580 Ok(this)
581 })
582 }
583
584 pub fn shutdown_processes<T: RequestMessage>(
585 &self,
586 shutdown_request: Option<T>,
587 ) -> Option<impl Future<Output = ()>> {
588 let state = self.state.lock().take()?;
589 log::info!("shutting down ssh processes");
590
591 let State::Connected {
592 multiplex_task,
593 heartbeat_task,
594 ssh_connection,
595 delegate,
596 forwarder,
597 } = state
598 else {
599 return None;
600 };
601
602 let client = self.client.clone();
603
604 Some(async move {
605 if let Some(shutdown_request) = shutdown_request {
606 client.send(shutdown_request).log_err();
607 // We wait 50ms instead of waiting for a response, because
608 // waiting for a response would require us to wait on the main thread
609 // which we want to avoid in an `on_app_quit` callback.
610 smol::Timer::after(Duration::from_millis(50)).await;
611 }
612
613 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
614 // child of master_process.
615 drop(multiplex_task);
616 // Now drop the rest of state, which kills master process.
617 drop(heartbeat_task);
618 drop(ssh_connection);
619 drop(delegate);
620 drop(forwarder);
621 })
622 }
623
624 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
625 let mut lock = self.state.lock();
626
627 let can_reconnect = lock
628 .as_ref()
629 .map(|state| state.can_reconnect())
630 .unwrap_or(false);
631 if !can_reconnect {
632 let error = if let Some(state) = lock.as_ref() {
633 format!("invalid state, cannot reconnect while in state {state}")
634 } else {
635 "no state set".to_string()
636 };
637 log::info!("aborting reconnect, because not in state that allows reconnecting");
638 return Err(anyhow!(error));
639 }
640
641 let state = lock.take().unwrap();
642 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
643 State::Connected {
644 ssh_connection,
645 delegate,
646 forwarder,
647 multiplex_task,
648 heartbeat_task,
649 }
650 | State::HeartbeatMissed {
651 ssh_connection,
652 delegate,
653 forwarder,
654 multiplex_task,
655 heartbeat_task,
656 ..
657 } => {
658 drop(multiplex_task);
659 drop(heartbeat_task);
660 (0, ssh_connection, delegate, forwarder)
661 }
662 State::ReconnectFailed {
663 attempts,
664 ssh_connection,
665 delegate,
666 forwarder,
667 ..
668 } => (attempts, ssh_connection, delegate, forwarder),
669 State::Connecting
670 | State::Reconnecting
671 | State::ReconnectExhausted
672 | State::ServerNotRunning => unreachable!(),
673 };
674
675 let attempts = attempts + 1;
676 if attempts > MAX_RECONNECT_ATTEMPTS {
677 log::error!(
678 "Failed to reconnect to after {} attempts, giving up",
679 MAX_RECONNECT_ATTEMPTS
680 );
681 drop(lock);
682 self.set_state(State::ReconnectExhausted, cx);
683 return Ok(());
684 }
685 drop(lock);
686
687 self.set_state(State::Reconnecting, cx);
688
689 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
690
691 let identifier = self.unique_identifier.clone();
692 let client = self.client.clone();
693 let reconnect_task = cx.spawn(|this, mut cx| async move {
694 macro_rules! failed {
695 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
696 return State::ReconnectFailed {
697 error: anyhow!($error),
698 attempts: $attempts,
699 ssh_connection: $ssh_connection,
700 delegate: $delegate,
701 forwarder: $forwarder,
702 };
703 };
704 }
705
706 if let Err(error) = ssh_connection.master_process.kill() {
707 failed!(error, attempts, ssh_connection, delegate, forwarder);
708 };
709
710 if let Err(error) = ssh_connection
711 .master_process
712 .status()
713 .await
714 .context("Failed to kill ssh process")
715 {
716 failed!(error, attempts, ssh_connection, delegate, forwarder);
717 }
718
719 let connection_options = ssh_connection.socket.connection_options.clone();
720
721 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
722 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
723 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
724 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
725
726 let (ssh_connection, ssh_process) = match Self::establish_connection(
727 identifier,
728 true,
729 connection_options,
730 delegate.clone(),
731 &mut cx,
732 )
733 .await
734 {
735 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
736 Err(error) => {
737 failed!(error, attempts, ssh_connection, delegate, forwarder);
738 }
739 };
740
741 let multiplex_task = Self::multiplex(
742 this.clone(),
743 ssh_process,
744 proxy_incoming_tx,
745 proxy_outgoing_rx,
746 connection_activity_tx,
747 &mut cx,
748 );
749
750 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
751 failed!(error, attempts, ssh_connection, delegate, forwarder);
752 };
753
754 State::Connected {
755 ssh_connection,
756 delegate,
757 forwarder,
758 multiplex_task,
759 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
760 }
761 });
762
763 cx.spawn(|this, mut cx| async move {
764 let new_state = reconnect_task.await;
765 this.update(&mut cx, |this, cx| {
766 this.try_set_state(cx, |old_state| {
767 if old_state.is_reconnecting() {
768 match &new_state {
769 State::Connecting
770 | State::Reconnecting { .. }
771 | State::HeartbeatMissed { .. }
772 | State::ServerNotRunning => {}
773 State::Connected { .. } => {
774 log::info!("Successfully reconnected");
775 }
776 State::ReconnectFailed {
777 error, attempts, ..
778 } => {
779 log::error!(
780 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
781 attempts,
782 error
783 );
784 }
785 State::ReconnectExhausted => {
786 log::error!("Reconnect attempt failed and all attempts exhausted");
787 }
788 }
789 Some(new_state)
790 } else {
791 None
792 }
793 });
794
795 if this.state_is(State::is_reconnect_failed) {
796 this.reconnect(cx)
797 } else if this.state_is(State::is_reconnect_exhausted) {
798 cx.emit(SshRemoteEvent::Disconnected);
799 Ok(())
800 } else {
801 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
802 Ok(())
803 }
804 })
805 })
806 .detach_and_log_err(cx);
807
808 Ok(())
809 }
810
811 fn heartbeat(
812 this: WeakModel<Self>,
813 mut connection_activity_rx: mpsc::Receiver<()>,
814 cx: &mut AsyncAppContext,
815 ) -> Task<Result<()>> {
816 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
817 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
818 };
819
820 cx.spawn(|mut cx| {
821 let this = this.clone();
822 async move {
823 let mut missed_heartbeats = 0;
824
825 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
826 futures::pin_mut!(keepalive_timer);
827
828 loop {
829 select_biased! {
830 result = connection_activity_rx.next().fuse() => {
831 if result.is_none() {
832 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
833 return Ok(());
834 }
835 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
836 }
837 _ = keepalive_timer => {
838 log::debug!("Sending heartbeat to server...");
839
840 let result = select_biased! {
841 _ = connection_activity_rx.next().fuse() => {
842 Ok(())
843 }
844 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
845 ping_result
846 }
847 };
848 if result.is_err() {
849 missed_heartbeats += 1;
850 log::warn!(
851 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
852 HEARTBEAT_TIMEOUT,
853 missed_heartbeats,
854 MAX_MISSED_HEARTBEATS
855 );
856 } else if missed_heartbeats != 0 {
857 missed_heartbeats = 0;
858 } else {
859 continue;
860 }
861
862 let result = this.update(&mut cx, |this, mut cx| {
863 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
864 })?;
865 if result.is_break() {
866 return Ok(());
867 }
868 }
869 }
870 }
871 }
872 })
873 }
874
875 fn handle_heartbeat_result(
876 &mut self,
877 missed_heartbeats: usize,
878 cx: &mut ModelContext<Self>,
879 ) -> ControlFlow<()> {
880 let state = self.state.lock().take().unwrap();
881 let next_state = if missed_heartbeats > 0 {
882 state.heartbeat_missed()
883 } else {
884 state.heartbeat_recovered()
885 };
886
887 self.set_state(next_state, cx);
888
889 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
890 log::error!(
891 "Missed last {} heartbeats. Reconnecting...",
892 missed_heartbeats
893 );
894
895 self.reconnect(cx)
896 .context("failed to start reconnect process after missing heartbeats")
897 .log_err();
898 ControlFlow::Break(())
899 } else {
900 ControlFlow::Continue(())
901 }
902 }
903
904 fn multiplex(
905 this: WeakModel<Self>,
906 mut ssh_proxy_process: Child,
907 incoming_tx: UnboundedSender<Envelope>,
908 mut outgoing_rx: UnboundedReceiver<Envelope>,
909 mut connection_activity_tx: Sender<()>,
910 cx: &AsyncAppContext,
911 ) -> Task<Result<()>> {
912 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
913 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
914 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
915
916 let io_task = cx.background_executor().spawn(async move {
917 let mut stdin_buffer = Vec::new();
918 let mut stdout_buffer = Vec::new();
919 let mut stderr_buffer = Vec::new();
920 let mut stderr_offset = 0;
921
922 loop {
923 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
924 stderr_buffer.resize(stderr_offset + 1024, 0);
925
926 select_biased! {
927 outgoing = outgoing_rx.next().fuse() => {
928 let Some(outgoing) = outgoing else {
929 return anyhow::Ok(None);
930 };
931
932 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
933 }
934
935 result = child_stdout.read(&mut stdout_buffer).fuse() => {
936 match result {
937 Ok(0) => {
938 child_stdin.close().await?;
939 outgoing_rx.close();
940 let status = ssh_proxy_process.status().await?;
941 // If we don't have a code, we assume process
942 // has been killed and treat it as non-zero exit
943 // code
944 return Ok(status.code().or_else(|| Some(1)));
945 }
946 Ok(len) => {
947 if len < stdout_buffer.len() {
948 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
949 }
950
951 let message_len = message_len_from_buffer(&stdout_buffer);
952 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
953 Ok(envelope) => {
954 connection_activity_tx.try_send(()).ok();
955 incoming_tx.unbounded_send(envelope).ok();
956 }
957 Err(error) => {
958 log::error!("error decoding message {error:?}");
959 }
960 }
961 }
962 Err(error) => {
963 Err(anyhow!("error reading stdout: {error:?}"))?;
964 }
965 }
966 }
967
968 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
969 match result {
970 Ok(len) => {
971 stderr_offset += len;
972 let mut start_ix = 0;
973 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
974 let line_ix = start_ix + ix;
975 let content = &stderr_buffer[start_ix..line_ix];
976 start_ix = line_ix + 1;
977 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
978 record.log(log::logger())
979 } else {
980 eprintln!("(remote) {}", String::from_utf8_lossy(content));
981 }
982 }
983 stderr_buffer.drain(0..start_ix);
984 stderr_offset -= start_ix;
985
986 connection_activity_tx.try_send(()).ok();
987 }
988 Err(error) => {
989 Err(anyhow!("error reading stderr: {error:?}"))?;
990 }
991 }
992 }
993 }
994 }
995 });
996
997 cx.spawn(|mut cx| async move {
998 let result = io_task.await;
999
1000 match result {
1001 Ok(Some(exit_code)) => {
1002 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1003 match error {
1004 ProxyLaunchError::ServerNotRunning => {
1005 log::error!("failed to reconnect because server is not running");
1006 this.update(&mut cx, |this, cx| {
1007 this.set_state(State::ServerNotRunning, cx);
1008 cx.emit(SshRemoteEvent::Disconnected);
1009 })?;
1010 }
1011 }
1012 } else if exit_code > 0 {
1013 log::error!("proxy process terminated unexpectedly");
1014 this.update(&mut cx, |this, cx| {
1015 this.reconnect(cx).ok();
1016 })?;
1017 }
1018 }
1019 Ok(None) => {}
1020 Err(error) => {
1021 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1022 this.update(&mut cx, |this, cx| {
1023 this.reconnect(cx).ok();
1024 })?;
1025 }
1026 }
1027 Ok(())
1028 })
1029 }
1030
1031 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1032 self.state.lock().as_ref().map_or(false, check)
1033 }
1034
1035 fn try_set_state(
1036 &self,
1037 cx: &mut ModelContext<Self>,
1038 map: impl FnOnce(&State) -> Option<State>,
1039 ) {
1040 let mut lock = self.state.lock();
1041 let new_state = lock.as_ref().and_then(map);
1042
1043 if let Some(new_state) = new_state {
1044 lock.replace(new_state);
1045 cx.notify();
1046 }
1047 }
1048
1049 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1050 log::info!("setting state to '{}'", &state);
1051 self.state.lock().replace(state);
1052 cx.notify();
1053 }
1054
1055 async fn establish_connection(
1056 unique_identifier: String,
1057 reconnect: bool,
1058 connection_options: SshConnectionOptions,
1059 delegate: Arc<dyn SshClientDelegate>,
1060 cx: &mut AsyncAppContext,
1061 ) -> Result<(SshRemoteConnection, Child)> {
1062 let ssh_connection =
1063 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1064
1065 let platform = ssh_connection.query_platform().await?;
1066 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1067 ssh_connection
1068 .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1069 .await?;
1070
1071 let socket = ssh_connection.socket.clone();
1072 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1073
1074 delegate.set_status(Some("Starting proxy"), cx);
1075
1076 let mut start_proxy_command = format!(
1077 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1078 std::env::var("RUST_LOG").unwrap_or_default(),
1079 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1080 remote_binary_path,
1081 unique_identifier,
1082 );
1083 if reconnect {
1084 start_proxy_command.push_str(" --reconnect");
1085 }
1086
1087 let ssh_proxy_process = socket
1088 .ssh_command(start_proxy_command)
1089 // IMPORTANT: we kill this process when we drop the task that uses it.
1090 .kill_on_drop(true)
1091 .spawn()
1092 .context("failed to spawn remote server")?;
1093
1094 Ok((ssh_connection, ssh_proxy_process))
1095 }
1096
1097 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1098 self.client.subscribe_to_entity(remote_id, entity);
1099 }
1100
1101 pub fn ssh_args(&self) -> Option<Vec<String>> {
1102 self.state
1103 .lock()
1104 .as_ref()
1105 .and_then(|state| state.ssh_connection())
1106 .map(|ssh_connection| ssh_connection.socket.ssh_args())
1107 }
1108
1109 pub fn proto_client(&self) -> AnyProtoClient {
1110 self.client.clone().into()
1111 }
1112
1113 pub fn connection_string(&self) -> String {
1114 self.connection_options.connection_string()
1115 }
1116
1117 pub fn connection_options(&self) -> SshConnectionOptions {
1118 self.connection_options.clone()
1119 }
1120
1121 #[cfg(not(any(test, feature = "test-support")))]
1122 pub fn connection_state(&self) -> ConnectionState {
1123 self.state
1124 .lock()
1125 .as_ref()
1126 .map(ConnectionState::from)
1127 .unwrap_or(ConnectionState::Disconnected)
1128 }
1129
1130 #[cfg(any(test, feature = "test-support"))]
1131 pub fn connection_state(&self) -> ConnectionState {
1132 ConnectionState::Connected
1133 }
1134
1135 pub fn is_disconnected(&self) -> bool {
1136 self.connection_state() == ConnectionState::Disconnected
1137 }
1138
1139 #[cfg(any(test, feature = "test-support"))]
1140 pub fn fake(
1141 client_cx: &mut gpui::TestAppContext,
1142 server_cx: &mut gpui::TestAppContext,
1143 ) -> (Model<Self>, Arc<ChannelClient>) {
1144 use gpui::Context;
1145
1146 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1147 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1148
1149 (
1150 client_cx.update(|cx| {
1151 let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1152 cx.new_model(|_| Self {
1153 client,
1154 unique_identifier: "fake".to_string(),
1155 connection_options: SshConnectionOptions::default(),
1156 state: Arc::new(Mutex::new(None)),
1157 })
1158 }),
1159 server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1160 )
1161 }
1162}
1163
1164impl From<SshRemoteClient> for AnyProtoClient {
1165 fn from(client: SshRemoteClient) -> Self {
1166 AnyProtoClient::new(client.client.clone())
1167 }
1168}
1169
1170struct SshRemoteConnection {
1171 socket: SshSocket,
1172 master_process: process::Child,
1173 _temp_dir: TempDir,
1174}
1175
1176impl Drop for SshRemoteConnection {
1177 fn drop(&mut self) {
1178 if let Err(error) = self.master_process.kill() {
1179 log::error!("failed to kill SSH master process: {}", error);
1180 }
1181 }
1182}
1183
1184impl SshRemoteConnection {
1185 #[cfg(not(unix))]
1186 async fn new(
1187 _connection_options: SshConnectionOptions,
1188 _delegate: Arc<dyn SshClientDelegate>,
1189 _cx: &mut AsyncAppContext,
1190 ) -> Result<Self> {
1191 Err(anyhow!("ssh is not supported on this platform"))
1192 }
1193
1194 #[cfg(unix)]
1195 async fn new(
1196 connection_options: SshConnectionOptions,
1197 delegate: Arc<dyn SshClientDelegate>,
1198 cx: &mut AsyncAppContext,
1199 ) -> Result<Self> {
1200 use futures::{io::BufReader, AsyncBufReadExt as _};
1201 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1202 use util::ResultExt as _;
1203
1204 delegate.set_status(Some("connecting"), cx);
1205
1206 let url = connection_options.ssh_url();
1207 let temp_dir = tempfile::Builder::new()
1208 .prefix("zed-ssh-session")
1209 .tempdir()?;
1210
1211 // Create a domain socket listener to handle requests from the askpass program.
1212 let askpass_socket = temp_dir.path().join("askpass.sock");
1213 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1214 let listener =
1215 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1216
1217 let askpass_task = cx.spawn({
1218 let delegate = delegate.clone();
1219 |mut cx| async move {
1220 let mut askpass_opened_tx = Some(askpass_opened_tx);
1221
1222 while let Ok((mut stream, _)) = listener.accept().await {
1223 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1224 askpass_opened_tx.send(()).ok();
1225 }
1226 let mut buffer = Vec::new();
1227 let mut reader = BufReader::new(&mut stream);
1228 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1229 buffer.clear();
1230 }
1231 let password_prompt = String::from_utf8_lossy(&buffer);
1232 if let Some(password) = delegate
1233 .ask_password(password_prompt.to_string(), &mut cx)
1234 .await
1235 .context("failed to get ssh password")
1236 .and_then(|p| p)
1237 .log_err()
1238 {
1239 stream.write_all(password.as_bytes()).await.log_err();
1240 }
1241 }
1242 }
1243 });
1244
1245 // Create an askpass script that communicates back to this process.
1246 let askpass_script = format!(
1247 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1248 askpass_socket = askpass_socket.display(),
1249 print_args = "printf '%s\\0' \"$@\"",
1250 shebang = "#!/bin/sh",
1251 );
1252 let askpass_script_path = temp_dir.path().join("askpass.sh");
1253 fs::write(&askpass_script_path, askpass_script).await?;
1254 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1255
1256 // Start the master SSH process, which does not do anything except for establish
1257 // the connection and keep it open, allowing other ssh commands to reuse it
1258 // via a control socket.
1259 let socket_path = temp_dir.path().join("ssh.sock");
1260 let mut master_process = process::Command::new("ssh")
1261 .stdin(Stdio::null())
1262 .stdout(Stdio::piped())
1263 .stderr(Stdio::piped())
1264 .env("SSH_ASKPASS_REQUIRE", "force")
1265 .env("SSH_ASKPASS", &askpass_script_path)
1266 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1267 .args([
1268 "-N",
1269 "-o",
1270 "ControlPersist=no",
1271 "-o",
1272 "ControlMaster=yes",
1273 "-o",
1274 ])
1275 .arg(format!("ControlPath={}", socket_path.display()))
1276 .arg(&url)
1277 .spawn()?;
1278
1279 // Wait for this ssh process to close its stdout, indicating that authentication
1280 // has completed.
1281 let stdout = master_process.stdout.as_mut().unwrap();
1282 let mut output = Vec::new();
1283 let connection_timeout = Duration::from_secs(10);
1284
1285 let result = select_biased! {
1286 _ = askpass_opened_rx.fuse() => {
1287 // If the askpass script has opened, that means the user is typing
1288 // their password, in which case we don't want to timeout anymore,
1289 // since we know a connection has been established.
1290 stdout.read_to_end(&mut output).await?;
1291 Ok(())
1292 }
1293 result = stdout.read_to_end(&mut output).fuse() => {
1294 result?;
1295 Ok(())
1296 }
1297 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1298 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1299 }
1300 };
1301
1302 if let Err(e) = result {
1303 let error_message = format!("Failed to connect to host: {}.", e);
1304 delegate.set_error(error_message, cx);
1305 return Err(e);
1306 }
1307
1308 drop(askpass_task);
1309
1310 if master_process.try_status()?.is_some() {
1311 output.clear();
1312 let mut stderr = master_process.stderr.take().unwrap();
1313 stderr.read_to_end(&mut output).await?;
1314
1315 let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1316 delegate.set_error(error_message.clone(), cx);
1317 Err(anyhow!(error_message))?;
1318 }
1319
1320 Ok(Self {
1321 socket: SshSocket {
1322 connection_options,
1323 socket_path,
1324 },
1325 master_process,
1326 _temp_dir: temp_dir,
1327 })
1328 }
1329
1330 async fn ensure_server_binary(
1331 &self,
1332 delegate: &Arc<dyn SshClientDelegate>,
1333 dst_path: &Path,
1334 platform: SshPlatform,
1335 cx: &mut AsyncAppContext,
1336 ) -> Result<()> {
1337 if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1338 if let Ok(installed_version) =
1339 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1340 {
1341 log::info!("using cached server binary version {}", installed_version);
1342 return Ok(());
1343 }
1344 }
1345
1346 let mut dst_path_gz = dst_path.to_path_buf();
1347 dst_path_gz.set_extension("gz");
1348
1349 if let Some(parent) = dst_path.parent() {
1350 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1351 }
1352
1353 let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1354
1355 let mut server_binary_exists = false;
1356 if !server_binary_exists && cfg!(not(debug_assertions)) {
1357 if let Ok(installed_version) =
1358 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1359 {
1360 if installed_version.trim() == version.to_string() {
1361 server_binary_exists = true;
1362 }
1363 }
1364 }
1365
1366 if server_binary_exists {
1367 log::info!("remote development server already present",);
1368 return Ok(());
1369 }
1370
1371 let src_stat = fs::metadata(&src_path).await?;
1372 let size = src_stat.len();
1373 let server_mode = 0o755;
1374
1375 let t0 = Instant::now();
1376 delegate.set_status(Some("uploading remote development server"), cx);
1377 log::info!("uploading remote development server ({}kb)", size / 1024);
1378 self.upload_file(&src_path, &dst_path_gz)
1379 .await
1380 .context("failed to upload server binary")?;
1381 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1382
1383 delegate.set_status(Some("extracting remote development server"), cx);
1384 run_cmd(
1385 self.socket
1386 .ssh_command("gunzip")
1387 .arg("--force")
1388 .arg(&dst_path_gz),
1389 )
1390 .await?;
1391
1392 delegate.set_status(Some("unzipping remote development server"), cx);
1393 run_cmd(
1394 self.socket
1395 .ssh_command("chmod")
1396 .arg(format!("{:o}", server_mode))
1397 .arg(dst_path),
1398 )
1399 .await?;
1400
1401 Ok(())
1402 }
1403
1404 async fn query_platform(&self) -> Result<SshPlatform> {
1405 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1406 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1407
1408 let os = match os.trim() {
1409 "Darwin" => "macos",
1410 "Linux" => "linux",
1411 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1412 };
1413 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1414 "aarch64"
1415 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1416 "x86_64"
1417 } else {
1418 Err(anyhow!("unknown uname architecture {arch:?}"))?
1419 };
1420
1421 Ok(SshPlatform { os, arch })
1422 }
1423
1424 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1425 let mut command = process::Command::new("scp");
1426 let output = self
1427 .socket
1428 .ssh_options(&mut command)
1429 .args(
1430 self.socket
1431 .connection_options
1432 .port
1433 .map(|port| vec!["-P".to_string(), port.to_string()])
1434 .unwrap_or_default(),
1435 )
1436 .arg(src_path)
1437 .arg(format!(
1438 "{}:{}",
1439 self.socket.connection_options.scp_url(),
1440 dest_path.display()
1441 ))
1442 .output()
1443 .await?;
1444
1445 if output.status.success() {
1446 Ok(())
1447 } else {
1448 Err(anyhow!(
1449 "failed to upload file {} -> {}: {}",
1450 src_path.display(),
1451 dest_path.display(),
1452 String::from_utf8_lossy(&output.stderr)
1453 ))
1454 }
1455 }
1456}
1457
/// Pending requests keyed by the id of their outgoing envelope.
///
/// Each sender receives the response envelope together with an acknowledgement
/// sender. The incoming-message loop waits for that acknowledgement to be sent
/// or dropped before it handles the next envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1459
/// A [`ProtoClient`] that exchanges protobuf envelopes over a pair of
/// unbounded mpsc channels: one for outgoing envelopes and one for incoming
/// envelopes, the latter drained by a task started in `ChannelClient::new`.
pub struct ChannelClient {
    /// Source of the id assigned to every outgoing envelope.
    next_message_id: AtomicU32,
    /// Outgoing envelopes are pushed here.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests awaiting a response; guarded by a mutex.
    response_channels: ResponseChannels,
    /// Message handlers and entity subscriptions used to dispatch incoming
    /// messages; guarded by a mutex.
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1466
1467impl ChannelClient {
1468 pub fn new(
1469 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1470 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1471 cx: &AppContext,
1472 ) -> Arc<Self> {
1473 let this = Arc::new(Self {
1474 outgoing_tx,
1475 next_message_id: AtomicU32::new(0),
1476 response_channels: ResponseChannels::default(),
1477 message_handlers: Default::default(),
1478 });
1479
1480 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1481
1482 this
1483 }
1484
1485 fn start_handling_messages(
1486 this: Arc<Self>,
1487 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1488 cx: &AppContext,
1489 ) {
1490 cx.spawn(|cx| {
1491 let this = Arc::downgrade(&this);
1492 async move {
1493 let peer_id = PeerId { owner_id: 0, id: 0 };
1494 while let Some(incoming) = incoming_rx.next().await {
1495 let Some(this) = this.upgrade() else {
1496 return anyhow::Ok(());
1497 };
1498
1499 if let Some(request_id) = incoming.responding_to {
1500 let request_id = MessageId(request_id);
1501 let sender = this.response_channels.lock().remove(&request_id);
1502 if let Some(sender) = sender {
1503 let (tx, rx) = oneshot::channel();
1504 if incoming.payload.is_some() {
1505 sender.send((incoming, tx)).ok();
1506 }
1507 rx.await.ok();
1508 }
1509 } else if let Some(envelope) =
1510 build_typed_envelope(peer_id, Instant::now(), incoming)
1511 {
1512 let type_name = envelope.payload_type_name();
1513 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1514 &this.message_handlers,
1515 envelope,
1516 this.clone().into(),
1517 cx.clone(),
1518 ) {
1519 log::debug!("ssh message received. name:{type_name}");
1520 cx.foreground_executor().spawn(async move {
1521 match future.await {
1522 Ok(_) => {
1523 log::debug!("ssh message handled. name:{type_name}");
1524 }
1525 Err(error) => {
1526 log::error!(
1527 "error handling message. type:{type_name}, error:{error}",
1528 );
1529 }
1530 }
1531 }).detach();
1532
1533 } else {
1534 log::error!("unhandled ssh message name:{type_name}");
1535 }
1536 }
1537 }
1538 anyhow::Ok(())
1539 }
1540 })
1541 .detach();
1542 }
1543
1544 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1545 let id = (TypeId::of::<E>(), remote_id);
1546
1547 let mut message_handlers = self.message_handlers.lock();
1548 if message_handlers
1549 .entities_by_type_and_remote_id
1550 .contains_key(&id)
1551 {
1552 panic!("already subscribed to entity");
1553 }
1554
1555 message_handlers.entities_by_type_and_remote_id.insert(
1556 id,
1557 EntityMessageSubscriber::Entity {
1558 handle: entity.downgrade().into(),
1559 },
1560 );
1561 }
1562
1563 pub fn request<T: RequestMessage>(
1564 &self,
1565 payload: T,
1566 ) -> impl 'static + Future<Output = Result<T::Response>> {
1567 log::debug!("ssh request start. name:{}", T::NAME);
1568 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1569 async move {
1570 let response = response.await?;
1571 log::debug!("ssh request finish. name:{}", T::NAME);
1572 T::Response::from_envelope(response)
1573 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1574 }
1575 }
1576
1577 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1578 smol::future::or(
1579 async {
1580 self.request(proto::Ping {}).await?;
1581 Ok(())
1582 },
1583 async {
1584 smol::Timer::after(timeout).await;
1585 Err(anyhow!("Timeout detected"))
1586 },
1587 )
1588 .await
1589 }
1590
1591 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1592 log::debug!("ssh send name:{}", T::NAME);
1593 self.send_dynamic(payload.into_envelope(0, None, None))
1594 }
1595
1596 pub fn request_dynamic(
1597 &self,
1598 mut envelope: proto::Envelope,
1599 type_name: &'static str,
1600 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1601 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1602 let (tx, rx) = oneshot::channel();
1603 let mut response_channels_lock = self.response_channels.lock();
1604 response_channels_lock.insert(MessageId(envelope.id), tx);
1605 drop(response_channels_lock);
1606 let result = self.outgoing_tx.unbounded_send(envelope);
1607 async move {
1608 if let Err(error) = &result {
1609 log::error!("failed to send message: {}", error);
1610 return Err(anyhow!("failed to send message: {}", error));
1611 }
1612
1613 let response = rx.await.context("connection lost")?.0;
1614 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1615 return Err(RpcError::from_proto(error, type_name));
1616 }
1617 Ok(response)
1618 }
1619 }
1620
1621 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1622 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1623 self.outgoing_tx.unbounded_send(envelope)?;
1624 Ok(())
1625 }
1626}
1627
1628impl ProtoClient for ChannelClient {
1629 fn request(
1630 &self,
1631 envelope: proto::Envelope,
1632 request_type: &'static str,
1633 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1634 self.request_dynamic(envelope, request_type).boxed()
1635 }
1636
1637 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1638 self.send_dynamic(envelope)
1639 }
1640
1641 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1642 self.send_dynamic(envelope)
1643 }
1644
1645 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1646 &self.message_handlers
1647 }
1648
1649 fn is_via_collab(&self) -> bool {
1650 false
1651 }
1652}