1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _,
17};
18use gpui::{
19 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
20 WeakModel,
21};
22use parking_lot::Mutex;
23use rpc::{
24 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
25 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
26};
27use smol::{
28 fs,
29 process::{self, Child, Stdio},
30};
31use std::{
32 any::TypeId,
33 ffi::OsStr,
34 fmt,
35 ops::ControlFlow,
36 path::{Path, PathBuf},
37 sync::{
38 atomic::{AtomicU32, Ordering::SeqCst},
39 Arc,
40 },
41 time::{Duration, Instant},
42};
43use tempfile::TempDir;
44use util::ResultExt;
45
46#[derive(
47 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
48)]
49pub struct SshProjectId(pub u64);
50
51#[derive(Clone)]
52pub struct SshSocket {
53 connection_options: SshConnectionOptions,
54 socket_path: PathBuf,
55}
56
/// Options describing how to reach a remote host over ssh.
///
/// `Debug` is implemented by hand so that a stored password is never written
/// to logs or panic messages.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Remote host name or address.
    pub host: String,
    /// Login user; `None` leaves the choice to ssh.
    pub username: Option<String>,
    /// Remote port; `None` leaves the choice to ssh.
    pub port: Option<u16>,
    /// Password for the connection, if one is known. Redacted in `Debug` output.
    pub password: Option<String>,
    /// Extra ssh command-line arguments (e.g. those collected by
    /// `parse_command_line`).
    pub args: Option<Vec<String>>,
}

impl fmt::Debug for SshConnectionOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SshConnectionOptions")
            .field("host", &self.host)
            .field("username", &self.username)
            .field("port", &self.port)
            // Show whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("args", &self.args)
            .finish()
    }
}
65
66impl SshConnectionOptions {
67 pub fn parse_command_line(input: &str) -> Result<Self> {
68 let input = input.trim_start_matches("ssh ");
69 let mut hostname: Option<String> = None;
70 let mut username: Option<String> = None;
71 let mut port: Option<u16> = None;
72 let mut args = Vec::new();
73
74 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
75 const ALLOWED_OPTS: &[&str] = &[
76 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
77 ];
78 const ALLOWED_ARGS: &[&str] = &[
79 "-B", "-b", "-c", "-D", "-I", "-i", "-J", "-L", "-l", "-m", "-o", "-P", "-p", "-R",
80 "-w",
81 ];
82
83 let mut tokens = shlex::split(input)
84 .ok_or_else(|| anyhow!("invalid input"))?
85 .into_iter();
86
87 'outer: while let Some(arg) = tokens.next() {
88 if ALLOWED_OPTS.contains(&(&arg as &str)) {
89 args.push(arg.to_string());
90 continue;
91 }
92 if arg == "-p" {
93 port = tokens.next().and_then(|arg| arg.parse().ok());
94 continue;
95 } else if let Some(p) = arg.strip_prefix("-p") {
96 port = p.parse().ok();
97 continue;
98 }
99 if arg == "-l" {
100 username = tokens.next();
101 continue;
102 } else if let Some(l) = arg.strip_prefix("-l") {
103 username = Some(l.to_string());
104 continue;
105 }
106 for a in ALLOWED_ARGS {
107 if arg == *a {
108 args.push(arg);
109 if let Some(next) = tokens.next() {
110 args.push(next);
111 }
112 continue 'outer;
113 } else if arg.starts_with(a) {
114 args.push(arg);
115 continue 'outer;
116 }
117 }
118 if arg.starts_with("-") || hostname.is_some() {
119 anyhow::bail!("unsupported argument: {:?}", arg);
120 }
121 let mut input = &arg as &str;
122 if let Some((u, rest)) = input.split_once('@') {
123 input = rest;
124 username = Some(u.to_string());
125 }
126 if let Some((rest, p)) = input.split_once(':') {
127 input = rest;
128 port = p.parse().ok()
129 }
130 hostname = Some(input.to_string())
131 }
132
133 let Some(hostname) = hostname else {
134 anyhow::bail!("missing hostname");
135 };
136
137 Ok(Self {
138 host: hostname.to_string(),
139 username: username.clone(),
140 port,
141 password: None,
142 args: Some(args),
143 })
144 }
145
146 pub fn ssh_url(&self) -> String {
147 let mut result = String::from("ssh://");
148 if let Some(username) = &self.username {
149 result.push_str(username);
150 result.push('@');
151 }
152 result.push_str(&self.host);
153 if let Some(port) = self.port {
154 result.push(':');
155 result.push_str(&port.to_string());
156 }
157 result
158 }
159
160 pub fn additional_args(&self) -> Option<&Vec<String>> {
161 self.args.as_ref()
162 }
163
164 fn scp_url(&self) -> String {
165 if let Some(username) = &self.username {
166 format!("{}@{}", username, self.host)
167 } else {
168 self.host.clone()
169 }
170 }
171
172 pub fn connection_string(&self) -> String {
173 let host = if let Some(username) = &self.username {
174 format!("{}@{}", username, self.host)
175 } else {
176 self.host.clone()
177 };
178 if let Some(port) = &self.port {
179 format!("{}:{}", host, port)
180 } else {
181 host
182 }
183 }
184
185 // Uniquely identifies dev server projects on a remote host. Needs to be
186 // stable for the same dev server project.
187 pub fn dev_server_identifier(&self) -> String {
188 let mut identifier = format!("dev-server-{:?}", self.host);
189 if let Some(username) = self.username.as_ref() {
190 identifier.push('-');
191 identifier.push_str(&username);
192 }
193 identifier
194 }
195}
196
/// Operating system and CPU architecture of a remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    /// OS name, e.g. `"linux"` or `"macos"`.
    pub os: &'static str,
    /// CPU architecture, e.g. `"x86_64"` or `"aarch64"`.
    pub arch: &'static str,
}

impl SshPlatform {
    /// Builds a Rust target triple such as `x86_64-unknown-linux-gnu`.
    ///
    /// Returns `None` for operating systems other than Linux and macOS.
    pub fn triple(&self) -> Option<String> {
        let vendor_and_os = match self.os {
            "linux" => "unknown-linux-gnu",
            "macos" => "apple-darwin",
            _ => return None,
        };
        Some(format!("{}-{}", self.arch, vendor_and_os))
    }
}
216
217pub trait SshClientDelegate: Send + Sync {
218 fn ask_password(
219 &self,
220 prompt: String,
221 cx: &mut AsyncAppContext,
222 ) -> oneshot::Receiver<Result<String>>;
223 fn remote_server_binary_path(
224 &self,
225 platform: SshPlatform,
226 cx: &mut AsyncAppContext,
227 ) -> Result<PathBuf>;
228 fn get_server_binary(
229 &self,
230 platform: SshPlatform,
231 cx: &mut AsyncAppContext,
232 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
233 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
234 fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
235}
236
237impl SshSocket {
238 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
239 let mut command = process::Command::new("ssh");
240 self.ssh_options(&mut command)
241 .arg(self.connection_options.ssh_url())
242 .arg(program);
243 command
244 }
245
246 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
247 command
248 .stdin(Stdio::piped())
249 .stdout(Stdio::piped())
250 .stderr(Stdio::piped())
251 .args(["-o", "ControlMaster=no", "-o"])
252 .arg(format!("ControlPath={}", self.socket_path.display()))
253 }
254
255 fn ssh_args(&self) -> Vec<String> {
256 vec![
257 "-o".to_string(),
258 "ControlMaster=no".to_string(),
259 "-o".to_string(),
260 format!("ControlPath={}", self.socket_path.display()),
261 self.connection_options.ssh_url(),
262 ]
263 }
264}
265
266async fn run_cmd(command: &mut process::Command) -> Result<String> {
267 let output = command.output().await?;
268 if output.status.success() {
269 Ok(String::from_utf8_lossy(&output.stdout).to_string())
270 } else {
271 Err(anyhow!(
272 "failed to run command: {}",
273 String::from_utf8_lossy(&output.stderr)
274 ))
275 }
276}
277
278struct ChannelForwarder {
279 quit_tx: UnboundedSender<()>,
280 forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
281}
282
283impl ChannelForwarder {
284 fn new(
285 mut incoming_tx: UnboundedSender<Envelope>,
286 mut outgoing_rx: UnboundedReceiver<Envelope>,
287 cx: &AsyncAppContext,
288 ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
289 let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
290
291 let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
292 let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
293
294 let forwarding_task = cx.background_executor().spawn(async move {
295 loop {
296 select_biased! {
297 _ = quit_rx.next().fuse() => {
298 break;
299 },
300 incoming_envelope = proxy_incoming_rx.next().fuse() => {
301 if let Some(envelope) = incoming_envelope {
302 if incoming_tx.send(envelope).await.is_err() {
303 break;
304 }
305 } else {
306 break;
307 }
308 }
309 outgoing_envelope = outgoing_rx.next().fuse() => {
310 if let Some(envelope) = outgoing_envelope {
311 if proxy_outgoing_tx.send(envelope).await.is_err() {
312 break;
313 }
314 } else {
315 break;
316 }
317 }
318 }
319 }
320
321 (incoming_tx, outgoing_rx)
322 });
323
324 (
325 Self {
326 forwarding_task,
327 quit_tx,
328 },
329 proxy_incoming_tx,
330 proxy_outgoing_rx,
331 )
332 }
333
334 async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
335 let _ = self.quit_tx.send(()).await;
336 self.forwarding_task.await
337 }
338}
339
/// Consecutive failed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time (no output from the proxy) before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// How long to wait for a ping response.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reconnect attempts made before giving up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
345
346enum State {
347 Connecting,
348 Connected {
349 ssh_connection: SshRemoteConnection,
350 delegate: Arc<dyn SshClientDelegate>,
351 forwarder: ChannelForwarder,
352
353 multiplex_task: Task<Result<()>>,
354 heartbeat_task: Task<Result<()>>,
355 },
356 HeartbeatMissed {
357 missed_heartbeats: usize,
358
359 ssh_connection: SshRemoteConnection,
360 delegate: Arc<dyn SshClientDelegate>,
361 forwarder: ChannelForwarder,
362
363 multiplex_task: Task<Result<()>>,
364 heartbeat_task: Task<Result<()>>,
365 },
366 Reconnecting,
367 ReconnectFailed {
368 ssh_connection: SshRemoteConnection,
369 delegate: Arc<dyn SshClientDelegate>,
370 forwarder: ChannelForwarder,
371
372 error: anyhow::Error,
373 attempts: usize,
374 },
375 ReconnectExhausted,
376 ServerNotRunning,
377}
378
379impl fmt::Display for State {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 match self {
382 Self::Connecting => write!(f, "connecting"),
383 Self::Connected { .. } => write!(f, "connected"),
384 Self::Reconnecting => write!(f, "reconnecting"),
385 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
386 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
387 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
388 Self::ServerNotRunning { .. } => write!(f, "server not running"),
389 }
390 }
391}
392
393impl State {
394 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
395 match self {
396 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
397 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
398 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
399 _ => None,
400 }
401 }
402
403 fn can_reconnect(&self) -> bool {
404 match self {
405 Self::Connected { .. }
406 | Self::HeartbeatMissed { .. }
407 | Self::ReconnectFailed { .. } => true,
408 State::Connecting
409 | State::Reconnecting
410 | State::ReconnectExhausted
411 | State::ServerNotRunning => false,
412 }
413 }
414
415 fn is_reconnect_failed(&self) -> bool {
416 matches!(self, Self::ReconnectFailed { .. })
417 }
418
419 fn is_reconnect_exhausted(&self) -> bool {
420 matches!(self, Self::ReconnectExhausted { .. })
421 }
422
423 fn is_reconnecting(&self) -> bool {
424 matches!(self, Self::Reconnecting { .. })
425 }
426
427 fn heartbeat_recovered(self) -> Self {
428 match self {
429 Self::HeartbeatMissed {
430 ssh_connection,
431 delegate,
432 forwarder,
433 multiplex_task,
434 heartbeat_task,
435 ..
436 } => Self::Connected {
437 ssh_connection,
438 delegate,
439 forwarder,
440 multiplex_task,
441 heartbeat_task,
442 },
443 _ => self,
444 }
445 }
446
447 fn heartbeat_missed(self) -> Self {
448 match self {
449 Self::Connected {
450 ssh_connection,
451 delegate,
452 forwarder,
453 multiplex_task,
454 heartbeat_task,
455 } => Self::HeartbeatMissed {
456 missed_heartbeats: 1,
457 ssh_connection,
458 delegate,
459 forwarder,
460 multiplex_task,
461 heartbeat_task,
462 },
463 Self::HeartbeatMissed {
464 missed_heartbeats,
465 ssh_connection,
466 delegate,
467 forwarder,
468 multiplex_task,
469 heartbeat_task,
470 } => Self::HeartbeatMissed {
471 missed_heartbeats: missed_heartbeats + 1,
472 ssh_connection,
473 delegate,
474 forwarder,
475 multiplex_task,
476 heartbeat_task,
477 },
478 _ => self,
479 }
480 }
481}
482
/// Public, simplified view of the ssh connection's state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is up but heartbeats are going unanswered.
    HeartbeatMissed,
    /// A reconnect is in progress (or about to be retried).
    Reconnecting,
    /// The connection is gone and will not be retried automatically.
    Disconnected,
}
492
493impl From<&State> for ConnectionState {
494 fn from(value: &State) -> Self {
495 match value {
496 State::Connecting => Self::Connecting,
497 State::Connected { .. } => Self::Connected,
498 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
499 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
500 State::ReconnectExhausted => Self::Disconnected,
501 State::ServerNotRunning => Self::Disconnected,
502 }
503 }
504}
505
506pub struct SshRemoteClient {
507 client: Arc<ChannelClient>,
508 unique_identifier: String,
509 connection_options: SshConnectionOptions,
510 state: Arc<Mutex<Option<State>>>,
511}
512
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection is gone and will not be re-established automatically.
    Disconnected,
}
517
518impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
519
520impl SshRemoteClient {
521 pub fn new(
522 unique_identifier: String,
523 connection_options: SshConnectionOptions,
524 delegate: Arc<dyn SshClientDelegate>,
525 cx: &AppContext,
526 ) -> Task<Result<Model<Self>>> {
527 cx.spawn(|mut cx| async move {
528 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
529 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
530 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
531
532 let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
533 let this = cx.new_model(|_| Self {
534 client: client.clone(),
535 unique_identifier: unique_identifier.clone(),
536 connection_options: connection_options.clone(),
537 state: Arc::new(Mutex::new(Some(State::Connecting))),
538 })?;
539
540 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
541 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
542
543 let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
544 unique_identifier,
545 false,
546 connection_options,
547 delegate.clone(),
548 &mut cx,
549 )
550 .await?;
551
552 let multiplex_task = Self::multiplex(
553 this.downgrade(),
554 ssh_proxy_process,
555 proxy_incoming_tx,
556 proxy_outgoing_rx,
557 connection_activity_tx,
558 &mut cx,
559 );
560
561 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
562 log::error!("failed to establish connection: {}", error);
563 delegate.set_error(error.to_string(), &mut cx);
564 return Err(error);
565 }
566
567 let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
568
569 this.update(&mut cx, |this, _| {
570 *this.state.lock() = Some(State::Connected {
571 ssh_connection,
572 delegate,
573 forwarder: proxy,
574 multiplex_task,
575 heartbeat_task,
576 });
577 })?;
578
579 Ok(this)
580 })
581 }
582
583 pub fn shutdown_processes<T: RequestMessage>(
584 &self,
585 shutdown_request: Option<T>,
586 ) -> Option<impl Future<Output = ()>> {
587 let state = self.state.lock().take()?;
588 log::info!("shutting down ssh processes");
589
590 let State::Connected {
591 multiplex_task,
592 heartbeat_task,
593 ssh_connection,
594 delegate,
595 forwarder,
596 } = state
597 else {
598 return None;
599 };
600
601 let client = self.client.clone();
602
603 Some(async move {
604 if let Some(shutdown_request) = shutdown_request {
605 client.send(shutdown_request).log_err();
606 // We wait 50ms instead of waiting for a response, because
607 // waiting for a response would require us to wait on the main thread
608 // which we want to avoid in an `on_app_quit` callback.
609 smol::Timer::after(Duration::from_millis(50)).await;
610 }
611
612 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
613 // child of master_process.
614 drop(multiplex_task);
615 // Now drop the rest of state, which kills master process.
616 drop(heartbeat_task);
617 drop(ssh_connection);
618 drop(delegate);
619 drop(forwarder);
620 })
621 }
622
623 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
624 let mut lock = self.state.lock();
625
626 let can_reconnect = lock
627 .as_ref()
628 .map(|state| state.can_reconnect())
629 .unwrap_or(false);
630 if !can_reconnect {
631 let error = if let Some(state) = lock.as_ref() {
632 format!("invalid state, cannot reconnect while in state {state}")
633 } else {
634 "no state set".to_string()
635 };
636 log::info!("aborting reconnect, because not in state that allows reconnecting");
637 return Err(anyhow!(error));
638 }
639
640 let state = lock.take().unwrap();
641 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
642 State::Connected {
643 ssh_connection,
644 delegate,
645 forwarder,
646 multiplex_task,
647 heartbeat_task,
648 }
649 | State::HeartbeatMissed {
650 ssh_connection,
651 delegate,
652 forwarder,
653 multiplex_task,
654 heartbeat_task,
655 ..
656 } => {
657 drop(multiplex_task);
658 drop(heartbeat_task);
659 (0, ssh_connection, delegate, forwarder)
660 }
661 State::ReconnectFailed {
662 attempts,
663 ssh_connection,
664 delegate,
665 forwarder,
666 ..
667 } => (attempts, ssh_connection, delegate, forwarder),
668 State::Connecting
669 | State::Reconnecting
670 | State::ReconnectExhausted
671 | State::ServerNotRunning => unreachable!(),
672 };
673
674 let attempts = attempts + 1;
675 if attempts > MAX_RECONNECT_ATTEMPTS {
676 log::error!(
677 "Failed to reconnect to after {} attempts, giving up",
678 MAX_RECONNECT_ATTEMPTS
679 );
680 drop(lock);
681 self.set_state(State::ReconnectExhausted, cx);
682 return Ok(());
683 }
684 drop(lock);
685
686 self.set_state(State::Reconnecting, cx);
687
688 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
689
690 let identifier = self.unique_identifier.clone();
691 let client = self.client.clone();
692 let reconnect_task = cx.spawn(|this, mut cx| async move {
693 macro_rules! failed {
694 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
695 return State::ReconnectFailed {
696 error: anyhow!($error),
697 attempts: $attempts,
698 ssh_connection: $ssh_connection,
699 delegate: $delegate,
700 forwarder: $forwarder,
701 };
702 };
703 }
704
705 if let Err(error) = ssh_connection.master_process.kill() {
706 failed!(error, attempts, ssh_connection, delegate, forwarder);
707 };
708
709 if let Err(error) = ssh_connection
710 .master_process
711 .status()
712 .await
713 .context("Failed to kill ssh process")
714 {
715 failed!(error, attempts, ssh_connection, delegate, forwarder);
716 }
717
718 let connection_options = ssh_connection.socket.connection_options.clone();
719
720 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
721 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
722 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
723 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
724
725 let (ssh_connection, ssh_process) = match Self::establish_connection(
726 identifier,
727 true,
728 connection_options,
729 delegate.clone(),
730 &mut cx,
731 )
732 .await
733 {
734 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
735 Err(error) => {
736 failed!(error, attempts, ssh_connection, delegate, forwarder);
737 }
738 };
739
740 let multiplex_task = Self::multiplex(
741 this.clone(),
742 ssh_process,
743 proxy_incoming_tx,
744 proxy_outgoing_rx,
745 connection_activity_tx,
746 &mut cx,
747 );
748
749 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
750 failed!(error, attempts, ssh_connection, delegate, forwarder);
751 };
752
753 State::Connected {
754 ssh_connection,
755 delegate,
756 forwarder,
757 multiplex_task,
758 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
759 }
760 });
761
762 cx.spawn(|this, mut cx| async move {
763 let new_state = reconnect_task.await;
764 this.update(&mut cx, |this, cx| {
765 this.try_set_state(cx, |old_state| {
766 if old_state.is_reconnecting() {
767 match &new_state {
768 State::Connecting
769 | State::Reconnecting { .. }
770 | State::HeartbeatMissed { .. }
771 | State::ServerNotRunning => {}
772 State::Connected { .. } => {
773 log::info!("Successfully reconnected");
774 }
775 State::ReconnectFailed {
776 error, attempts, ..
777 } => {
778 log::error!(
779 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
780 attempts,
781 error
782 );
783 }
784 State::ReconnectExhausted => {
785 log::error!("Reconnect attempt failed and all attempts exhausted");
786 }
787 }
788 Some(new_state)
789 } else {
790 None
791 }
792 });
793
794 if this.state_is(State::is_reconnect_failed) {
795 this.reconnect(cx)
796 } else if this.state_is(State::is_reconnect_exhausted) {
797 cx.emit(SshRemoteEvent::Disconnected);
798 Ok(())
799 } else {
800 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
801 Ok(())
802 }
803 })
804 })
805 .detach_and_log_err(cx);
806
807 Ok(())
808 }
809
810 fn heartbeat(
811 this: WeakModel<Self>,
812 mut connection_activity_rx: mpsc::Receiver<()>,
813 cx: &mut AsyncAppContext,
814 ) -> Task<Result<()>> {
815 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
816 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
817 };
818
819 cx.spawn(|mut cx| {
820 let this = this.clone();
821 async move {
822 let mut missed_heartbeats = 0;
823
824 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
825 futures::pin_mut!(keepalive_timer);
826
827 loop {
828 select_biased! {
829 result = connection_activity_rx.next().fuse() => {
830 if result.is_none() {
831 log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
832 return Ok(());
833 }
834
835 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
836
837 if missed_heartbeats != 0 {
838 missed_heartbeats = 0;
839 this.update(&mut cx, |this, mut cx| {
840 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
841 })?;
842 }
843 }
844 _ = keepalive_timer => {
845 log::debug!("Sending heartbeat to server...");
846
847 let result = select_biased! {
848 _ = connection_activity_rx.next().fuse() => {
849 Ok(())
850 }
851 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
852 ping_result
853 }
854 };
855
856 if result.is_err() {
857 missed_heartbeats += 1;
858 log::warn!(
859 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
860 HEARTBEAT_TIMEOUT,
861 missed_heartbeats,
862 MAX_MISSED_HEARTBEATS
863 );
864 } else if missed_heartbeats != 0 {
865 missed_heartbeats = 0;
866 } else {
867 continue;
868 }
869
870 let result = this.update(&mut cx, |this, mut cx| {
871 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
872 })?;
873 if result.is_break() {
874 return Ok(());
875 }
876 }
877 }
878 }
879 }
880 })
881 }
882
883 fn handle_heartbeat_result(
884 &mut self,
885 missed_heartbeats: usize,
886 cx: &mut ModelContext<Self>,
887 ) -> ControlFlow<()> {
888 let state = self.state.lock().take().unwrap();
889 let next_state = if missed_heartbeats > 0 {
890 state.heartbeat_missed()
891 } else {
892 state.heartbeat_recovered()
893 };
894
895 self.set_state(next_state, cx);
896
897 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
898 log::error!(
899 "Missed last {} heartbeats. Reconnecting...",
900 missed_heartbeats
901 );
902
903 self.reconnect(cx)
904 .context("failed to start reconnect process after missing heartbeats")
905 .log_err();
906 ControlFlow::Break(())
907 } else {
908 ControlFlow::Continue(())
909 }
910 }
911
912 fn multiplex(
913 this: WeakModel<Self>,
914 mut ssh_proxy_process: Child,
915 incoming_tx: UnboundedSender<Envelope>,
916 mut outgoing_rx: UnboundedReceiver<Envelope>,
917 mut connection_activity_tx: Sender<()>,
918 cx: &AsyncAppContext,
919 ) -> Task<Result<()>> {
920 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
921 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
922 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
923
924 let mut stdin_buffer = Vec::new();
925 let mut stdout_buffer = Vec::new();
926 let mut stderr_buffer = Vec::new();
927 let mut stderr_offset = 0;
928
929 let stdin_task = cx.background_executor().spawn(async move {
930 while let Some(outgoing) = outgoing_rx.next().await {
931 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
932 }
933 anyhow::Ok(())
934 });
935
936 let stdout_task = cx.background_executor().spawn({
937 let mut connection_activity_tx = connection_activity_tx.clone();
938 async move {
939 loop {
940 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
941 let len = child_stdout.read(&mut stdout_buffer).await?;
942
943 if len == 0 {
944 return anyhow::Ok(());
945 }
946
947 if len < MESSAGE_LEN_SIZE {
948 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
949 }
950
951 let message_len = message_len_from_buffer(&stdout_buffer);
952 let envelope =
953 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
954 .await?;
955 connection_activity_tx.try_send(()).ok();
956 incoming_tx.unbounded_send(envelope).ok();
957 }
958 }
959 });
960
961 let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
962 loop {
963 stderr_buffer.resize(stderr_offset + 1024, 0);
964
965 let len = child_stderr
966 .read(&mut stderr_buffer[stderr_offset..])
967 .await?;
968
969 stderr_offset += len;
970 let mut start_ix = 0;
971 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
972 .iter()
973 .position(|b| b == &b'\n')
974 {
975 let line_ix = start_ix + ix;
976 let content = &stderr_buffer[start_ix..line_ix];
977 start_ix = line_ix + 1;
978 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
979 record.log(log::logger())
980 } else {
981 eprintln!("(remote) {}", String::from_utf8_lossy(content));
982 }
983 }
984 stderr_buffer.drain(0..start_ix);
985 stderr_offset -= start_ix;
986
987 connection_activity_tx.try_send(()).ok();
988 }
989 });
990
991 cx.spawn(|mut cx| async move {
992 let result = futures::select! {
993 result = stdin_task.fuse() => {
994 result.context("stdin")
995 }
996 result = stdout_task.fuse() => {
997 result.context("stdout")
998 }
999 result = stderr_task.fuse() => {
1000 result.context("stderr")
1001 }
1002 };
1003
1004 match result {
1005 Ok(_) => {
1006 let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1);
1007
1008 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
1009 match error {
1010 ProxyLaunchError::ServerNotRunning => {
1011 log::error!("failed to reconnect because server is not running");
1012 this.update(&mut cx, |this, cx| {
1013 this.set_state(State::ServerNotRunning, cx);
1014 cx.emit(SshRemoteEvent::Disconnected);
1015 })?;
1016 }
1017 }
1018 } else if exit_code > 0 {
1019 log::error!("proxy process terminated unexpectedly");
1020 this.update(&mut cx, |this, cx| {
1021 this.reconnect(cx).ok();
1022 })?;
1023 }
1024 }
1025 Err(error) => {
1026 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
1027 this.update(&mut cx, |this, cx| {
1028 this.reconnect(cx).ok();
1029 })?;
1030 }
1031 }
1032
1033 Ok(())
1034 })
1035 }
1036
1037 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
1038 self.state.lock().as_ref().map_or(false, check)
1039 }
1040
1041 fn try_set_state(
1042 &self,
1043 cx: &mut ModelContext<Self>,
1044 map: impl FnOnce(&State) -> Option<State>,
1045 ) {
1046 let mut lock = self.state.lock();
1047 let new_state = lock.as_ref().and_then(map);
1048
1049 if let Some(new_state) = new_state {
1050 lock.replace(new_state);
1051 cx.notify();
1052 }
1053 }
1054
1055 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
1056 log::info!("setting state to '{}'", &state);
1057 self.state.lock().replace(state);
1058 cx.notify();
1059 }
1060
1061 async fn establish_connection(
1062 unique_identifier: String,
1063 reconnect: bool,
1064 connection_options: SshConnectionOptions,
1065 delegate: Arc<dyn SshClientDelegate>,
1066 cx: &mut AsyncAppContext,
1067 ) -> Result<(SshRemoteConnection, Child)> {
1068 let ssh_connection =
1069 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
1070
1071 let platform = ssh_connection.query_platform().await?;
1072 let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
1073 ssh_connection
1074 .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
1075 .await?;
1076
1077 let socket = ssh_connection.socket.clone();
1078 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
1079
1080 delegate.set_status(Some("Starting proxy"), cx);
1081
1082 let mut start_proxy_command = format!(
1083 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
1084 std::env::var("RUST_LOG").unwrap_or_default(),
1085 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
1086 remote_binary_path,
1087 unique_identifier,
1088 );
1089 if reconnect {
1090 start_proxy_command.push_str(" --reconnect");
1091 }
1092
1093 let ssh_proxy_process = socket
1094 .ssh_command(start_proxy_command)
1095 // IMPORTANT: we kill this process when we drop the task that uses it.
1096 .kill_on_drop(true)
1097 .spawn()
1098 .context("failed to spawn remote server")?;
1099
1100 Ok((ssh_connection, ssh_proxy_process))
1101 }
1102
1103 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1104 self.client.subscribe_to_entity(remote_id, entity);
1105 }
1106
1107 pub fn ssh_args(&self) -> Option<Vec<String>> {
1108 self.state
1109 .lock()
1110 .as_ref()
1111 .and_then(|state| state.ssh_connection())
1112 .map(|ssh_connection| ssh_connection.socket.ssh_args())
1113 }
1114
1115 pub fn proto_client(&self) -> AnyProtoClient {
1116 self.client.clone().into()
1117 }
1118
1119 pub fn connection_string(&self) -> String {
1120 self.connection_options.connection_string()
1121 }
1122
1123 pub fn connection_options(&self) -> SshConnectionOptions {
1124 self.connection_options.clone()
1125 }
1126
1127 #[cfg(not(any(test, feature = "test-support")))]
1128 pub fn connection_state(&self) -> ConnectionState {
1129 self.state
1130 .lock()
1131 .as_ref()
1132 .map(ConnectionState::from)
1133 .unwrap_or(ConnectionState::Disconnected)
1134 }
1135
1136 #[cfg(any(test, feature = "test-support"))]
1137 pub fn connection_state(&self) -> ConnectionState {
1138 ConnectionState::Connected
1139 }
1140
1141 pub fn is_disconnected(&self) -> bool {
1142 self.connection_state() == ConnectionState::Disconnected
1143 }
1144
1145 #[cfg(any(test, feature = "test-support"))]
1146 pub fn fake(
1147 client_cx: &mut gpui::TestAppContext,
1148 server_cx: &mut gpui::TestAppContext,
1149 ) -> (Model<Self>, Arc<ChannelClient>) {
1150 use gpui::Context;
1151
1152 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
1153 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
1154
1155 (
1156 client_cx.update(|cx| {
1157 let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
1158 cx.new_model(|_| Self {
1159 client,
1160 unique_identifier: "fake".to_string(),
1161 connection_options: SshConnectionOptions::default(),
1162 state: Arc::new(Mutex::new(None)),
1163 })
1164 }),
1165 server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
1166 )
1167 }
1168}
1169
1170impl From<SshRemoteClient> for AnyProtoClient {
1171 fn from(client: SshRemoteClient) -> Self {
1172 AnyProtoClient::new(client.client.clone())
1173 }
1174}
1175
1176struct SshRemoteConnection {
1177 socket: SshSocket,
1178 master_process: process::Child,
1179 _temp_dir: TempDir,
1180}
1181
1182impl Drop for SshRemoteConnection {
1183 fn drop(&mut self) {
1184 if let Err(error) = self.master_process.kill() {
1185 log::error!("failed to kill SSH master process: {}", error);
1186 }
1187 }
1188}
1189
1190impl SshRemoteConnection {
1191 #[cfg(not(unix))]
1192 async fn new(
1193 _connection_options: SshConnectionOptions,
1194 _delegate: Arc<dyn SshClientDelegate>,
1195 _cx: &mut AsyncAppContext,
1196 ) -> Result<Self> {
1197 Err(anyhow!("ssh is not supported on this platform"))
1198 }
1199
1200 #[cfg(unix)]
1201 async fn new(
1202 connection_options: SshConnectionOptions,
1203 delegate: Arc<dyn SshClientDelegate>,
1204 cx: &mut AsyncAppContext,
1205 ) -> Result<Self> {
1206 use futures::AsyncWriteExt as _;
1207 use futures::{io::BufReader, AsyncBufReadExt as _};
1208 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1209 use util::ResultExt as _;
1210
1211 delegate.set_status(Some("connecting"), cx);
1212
1213 let url = connection_options.ssh_url();
1214 let temp_dir = tempfile::Builder::new()
1215 .prefix("zed-ssh-session")
1216 .tempdir()?;
1217
1218 // Create a domain socket listener to handle requests from the askpass program.
1219 let askpass_socket = temp_dir.path().join("askpass.sock");
1220 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1221 let listener =
1222 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1223
1224 let askpass_task = cx.spawn({
1225 let delegate = delegate.clone();
1226 |mut cx| async move {
1227 let mut askpass_opened_tx = Some(askpass_opened_tx);
1228
1229 while let Ok((mut stream, _)) = listener.accept().await {
1230 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1231 askpass_opened_tx.send(()).ok();
1232 }
1233 let mut buffer = Vec::new();
1234 let mut reader = BufReader::new(&mut stream);
1235 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1236 buffer.clear();
1237 }
1238 let password_prompt = String::from_utf8_lossy(&buffer);
1239 if let Some(password) = delegate
1240 .ask_password(password_prompt.to_string(), &mut cx)
1241 .await
1242 .context("failed to get ssh password")
1243 .and_then(|p| p)
1244 .log_err()
1245 {
1246 stream.write_all(password.as_bytes()).await.log_err();
1247 }
1248 }
1249 }
1250 });
1251
1252 // Create an askpass script that communicates back to this process.
1253 let askpass_script = format!(
1254 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1255 askpass_socket = askpass_socket.display(),
1256 print_args = "printf '%s\\0' \"$@\"",
1257 shebang = "#!/bin/sh",
1258 );
1259 let askpass_script_path = temp_dir.path().join("askpass.sh");
1260 fs::write(&askpass_script_path, askpass_script).await?;
1261 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1262
1263 // Start the master SSH process, which does not do anything except for establish
1264 // the connection and keep it open, allowing other ssh commands to reuse it
1265 // via a control socket.
1266 let socket_path = temp_dir.path().join("ssh.sock");
1267 let mut master_process = process::Command::new("ssh")
1268 .stdin(Stdio::null())
1269 .stdout(Stdio::piped())
1270 .stderr(Stdio::piped())
1271 .env("SSH_ASKPASS_REQUIRE", "force")
1272 .env("SSH_ASKPASS", &askpass_script_path)
1273 .args(connection_options.additional_args().unwrap_or(&Vec::new()))
1274 .args([
1275 "-N",
1276 "-o",
1277 "ControlPersist=no",
1278 "-o",
1279 "ControlMaster=yes",
1280 "-o",
1281 ])
1282 .arg(format!("ControlPath={}", socket_path.display()))
1283 .arg(&url)
1284 .spawn()?;
1285
1286 // Wait for this ssh process to close its stdout, indicating that authentication
1287 // has completed.
1288 let stdout = master_process.stdout.as_mut().unwrap();
1289 let mut output = Vec::new();
1290 let connection_timeout = Duration::from_secs(10);
1291
1292 let result = select_biased! {
1293 _ = askpass_opened_rx.fuse() => {
1294 // If the askpass script has opened, that means the user is typing
1295 // their password, in which case we don't want to timeout anymore,
1296 // since we know a connection has been established.
1297 stdout.read_to_end(&mut output).await?;
1298 Ok(())
1299 }
1300 result = stdout.read_to_end(&mut output).fuse() => {
1301 result?;
1302 Ok(())
1303 }
1304 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1305 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1306 }
1307 };
1308
1309 if let Err(e) = result {
1310 let error_message = format!("Failed to connect to host: {}.", e);
1311 delegate.set_error(error_message, cx);
1312 return Err(e);
1313 }
1314
1315 drop(askpass_task);
1316
1317 if master_process.try_status()?.is_some() {
1318 output.clear();
1319 let mut stderr = master_process.stderr.take().unwrap();
1320 stderr.read_to_end(&mut output).await?;
1321
1322 let error_message = format!(
1323 "failed to connect: {}",
1324 String::from_utf8_lossy(&output).trim()
1325 );
1326 delegate.set_error(error_message.clone(), cx);
1327 Err(anyhow!(error_message))?;
1328 }
1329
1330 Ok(Self {
1331 socket: SshSocket {
1332 connection_options,
1333 socket_path,
1334 },
1335 master_process,
1336 _temp_dir: temp_dir,
1337 })
1338 }
1339
1340 async fn ensure_server_binary(
1341 &self,
1342 delegate: &Arc<dyn SshClientDelegate>,
1343 dst_path: &Path,
1344 platform: SshPlatform,
1345 cx: &mut AsyncAppContext,
1346 ) -> Result<()> {
1347 if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
1348 if let Ok(installed_version) =
1349 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1350 {
1351 log::info!("using cached server binary version {}", installed_version);
1352 return Ok(());
1353 }
1354 }
1355
1356 let mut dst_path_gz = dst_path.to_path_buf();
1357 dst_path_gz.set_extension("gz");
1358
1359 if let Some(parent) = dst_path.parent() {
1360 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1361 }
1362
1363 let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
1364
1365 let mut server_binary_exists = false;
1366 if !server_binary_exists && cfg!(not(debug_assertions)) {
1367 if let Ok(installed_version) =
1368 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1369 {
1370 if installed_version.trim() == version.to_string() {
1371 server_binary_exists = true;
1372 }
1373 }
1374 }
1375
1376 if server_binary_exists {
1377 log::info!("remote development server already present",);
1378 return Ok(());
1379 }
1380
1381 let src_stat = fs::metadata(&src_path).await?;
1382 let size = src_stat.len();
1383 let server_mode = 0o755;
1384
1385 let t0 = Instant::now();
1386 delegate.set_status(Some("Uploading remote development server"), cx);
1387 log::info!("uploading remote development server ({}kb)", size / 1024);
1388 self.upload_file(&src_path, &dst_path_gz)
1389 .await
1390 .context("failed to upload server binary")?;
1391 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1392
1393 delegate.set_status(Some("Extracting remote development server"), cx);
1394 run_cmd(
1395 self.socket
1396 .ssh_command("gunzip")
1397 .arg("--force")
1398 .arg(&dst_path_gz),
1399 )
1400 .await?;
1401
1402 delegate.set_status(Some("Marking remote development server executable"), cx);
1403 run_cmd(
1404 self.socket
1405 .ssh_command("chmod")
1406 .arg(format!("{:o}", server_mode))
1407 .arg(dst_path),
1408 )
1409 .await?;
1410
1411 Ok(())
1412 }
1413
1414 async fn query_platform(&self) -> Result<SshPlatform> {
1415 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1416 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1417
1418 let os = match os.trim() {
1419 "Darwin" => "macos",
1420 "Linux" => "linux",
1421 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1422 };
1423 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1424 "aarch64"
1425 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1426 "x86_64"
1427 } else {
1428 Err(anyhow!("unknown uname architecture {arch:?}"))?
1429 };
1430
1431 Ok(SshPlatform { os, arch })
1432 }
1433
1434 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1435 let mut command = process::Command::new("scp");
1436 let output = self
1437 .socket
1438 .ssh_options(&mut command)
1439 .args(
1440 self.socket
1441 .connection_options
1442 .port
1443 .map(|port| vec!["-P".to_string(), port.to_string()])
1444 .unwrap_or_default(),
1445 )
1446 .arg(src_path)
1447 .arg(format!(
1448 "{}:{}",
1449 self.socket.connection_options.scp_url(),
1450 dest_path.display()
1451 ))
1452 .output()
1453 .await?;
1454
1455 if output.status.success() {
1456 Ok(())
1457 } else {
1458 Err(anyhow!(
1459 "failed to upload file {} -> {}: {}",
1460 src_path.display(),
1461 dest_path.display(),
1462 String::from_utf8_lossy(&output.stderr)
1463 ))
1464 }
1465 }
1466}
1467
/// Pending requests, keyed by the id of the request envelope.
///
/// Each sender is completed with the response envelope plus a `oneshot::Sender<()>`.
/// The incoming-message loop waits for that second sender to be dropped before it
/// handles the next message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// An RPC client that exchanges protobuf [`Envelope`]s over a pair of unbounded
/// channels: one for outgoing envelopes and one for incoming envelopes.
pub struct ChannelClient {
    /// Source of ids assigned to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Sending half of the outgoing channel.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests awaiting a response (guarded by a parking_lot mutex).
    response_channels: ResponseChannels,
    /// Handlers for incoming, non-response messages (guarded by a parking_lot mutex).
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1476
1477impl ChannelClient {
1478 pub fn new(
1479 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1480 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1481 cx: &AppContext,
1482 ) -> Arc<Self> {
1483 let this = Arc::new(Self {
1484 outgoing_tx,
1485 next_message_id: AtomicU32::new(0),
1486 response_channels: ResponseChannels::default(),
1487 message_handlers: Default::default(),
1488 });
1489
1490 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1491
1492 this
1493 }
1494
1495 fn start_handling_messages(
1496 this: Arc<Self>,
1497 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1498 cx: &AppContext,
1499 ) {
1500 cx.spawn(|cx| {
1501 let this = Arc::downgrade(&this);
1502 async move {
1503 let peer_id = PeerId { owner_id: 0, id: 0 };
1504 while let Some(incoming) = incoming_rx.next().await {
1505 let Some(this) = this.upgrade() else {
1506 return anyhow::Ok(());
1507 };
1508
1509 if let Some(request_id) = incoming.responding_to {
1510 let request_id = MessageId(request_id);
1511 let sender = this.response_channels.lock().remove(&request_id);
1512 if let Some(sender) = sender {
1513 let (tx, rx) = oneshot::channel();
1514 if incoming.payload.is_some() {
1515 sender.send((incoming, tx)).ok();
1516 }
1517 rx.await.ok();
1518 }
1519 } else if let Some(envelope) =
1520 build_typed_envelope(peer_id, Instant::now(), incoming)
1521 {
1522 let type_name = envelope.payload_type_name();
1523 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1524 &this.message_handlers,
1525 envelope,
1526 this.clone().into(),
1527 cx.clone(),
1528 ) {
1529 log::debug!("ssh message received. name:{type_name}");
1530 match future.await {
1531 Ok(_) => {
1532 log::debug!("ssh message handled. name:{type_name}");
1533 }
1534 Err(error) => {
1535 log::error!(
1536 "error handling message. type:{type_name}, error:{error}",
1537 );
1538 }
1539 }
1540 } else {
1541 log::error!("unhandled ssh message name:{type_name}");
1542 }
1543 }
1544 }
1545 anyhow::Ok(())
1546 }
1547 })
1548 .detach();
1549 }
1550
1551 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1552 let id = (TypeId::of::<E>(), remote_id);
1553
1554 let mut message_handlers = self.message_handlers.lock();
1555 if message_handlers
1556 .entities_by_type_and_remote_id
1557 .contains_key(&id)
1558 {
1559 panic!("already subscribed to entity");
1560 }
1561
1562 message_handlers.entities_by_type_and_remote_id.insert(
1563 id,
1564 EntityMessageSubscriber::Entity {
1565 handle: entity.downgrade().into(),
1566 },
1567 );
1568 }
1569
1570 pub fn request<T: RequestMessage>(
1571 &self,
1572 payload: T,
1573 ) -> impl 'static + Future<Output = Result<T::Response>> {
1574 log::debug!("ssh request start. name:{}", T::NAME);
1575 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1576 async move {
1577 let response = response.await?;
1578 log::debug!("ssh request finish. name:{}", T::NAME);
1579 T::Response::from_envelope(response)
1580 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1581 }
1582 }
1583
1584 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1585 smol::future::or(
1586 async {
1587 self.request(proto::Ping {}).await?;
1588 Ok(())
1589 },
1590 async {
1591 smol::Timer::after(timeout).await;
1592 Err(anyhow!("Timeout detected"))
1593 },
1594 )
1595 .await
1596 }
1597
1598 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1599 log::debug!("ssh send name:{}", T::NAME);
1600 self.send_dynamic(payload.into_envelope(0, None, None))
1601 }
1602
1603 pub fn request_dynamic(
1604 &self,
1605 mut envelope: proto::Envelope,
1606 type_name: &'static str,
1607 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1608 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1609 let (tx, rx) = oneshot::channel();
1610 let mut response_channels_lock = self.response_channels.lock();
1611 response_channels_lock.insert(MessageId(envelope.id), tx);
1612 drop(response_channels_lock);
1613 let result = self.outgoing_tx.unbounded_send(envelope);
1614 async move {
1615 if let Err(error) = &result {
1616 log::error!("failed to send message: {}", error);
1617 return Err(anyhow!("failed to send message: {}", error));
1618 }
1619
1620 let response = rx.await.context("connection lost")?.0;
1621 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1622 return Err(RpcError::from_proto(error, type_name));
1623 }
1624 Ok(response)
1625 }
1626 }
1627
1628 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1629 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1630 self.outgoing_tx.unbounded_send(envelope)?;
1631 Ok(())
1632 }
1633}
1634
1635impl ProtoClient for ChannelClient {
1636 fn request(
1637 &self,
1638 envelope: proto::Envelope,
1639 request_type: &'static str,
1640 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1641 self.request_dynamic(envelope, request_type).boxed()
1642 }
1643
1644 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1645 self.send_dynamic(envelope)
1646 }
1647
1648 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1649 self.send_dynamic(envelope)
1650 }
1651
1652 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1653 &self.message_handlers
1654 }
1655
1656 fn is_via_collab(&self) -> bool {
1657 false
1658 }
1659}