1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
17 StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
21 WeakModel,
22};
23use parking_lot::Mutex;
24use rpc::{
25 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
26 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
27};
28use smol::{
29 fs,
30 process::{self, Child, Stdio},
31};
32use std::{
33 any::TypeId,
34 ffi::OsStr,
35 fmt,
36 ops::ControlFlow,
37 path::{Path, PathBuf},
38 sync::{
39 atomic::{AtomicU32, Ordering::SeqCst},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tempfile::TempDir;
45use util::ResultExt;
46
/// Newtype around a `u64` identifying an SSH project.
///
/// Derives ordering, hashing and serde support so it can be used as a map key and
/// serialized.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]
pub struct SshProjectId(pub u64);
51
/// Handle for running further `ssh` commands over an already-established master
/// connection.
///
/// Commands built from it pass `-o ControlMaster=no -o ControlPath=<socket_path>`, so
/// they reuse the control socket of the master process instead of opening a new
/// connection.
#[derive(Clone)]
pub struct SshSocket {
    /// Host / user / port the commands connect to.
    connection_options: SshConnectionOptions,
    /// Path of the control socket created by the master `ssh` process.
    socket_path: PathBuf,
}
57
/// Parameters describing how to reach a remote host over SSH.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    /// Host name or address of the remote machine.
    pub host: String,
    /// User to log in as; omitted from the generated URLs when `None`.
    pub username: Option<String>,
    /// Port to connect to; omitted from the generated URLs when `None`.
    pub port: Option<u16>,
    /// NOTE(review): not used by the URL/identifier helpers below; confirm where it is
    /// consumed.
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// Returns `ssh://[user@]host[:port]`, the destination argument handed to `ssh`.
    pub fn ssh_url(&self) -> String {
        let mut result = format!("ssh://{}", self.scp_url());
        if let Some(port) = self.port {
            result.push(':');
            result.push_str(&port.to_string());
        }
        result
    }

    /// Returns `[user@]host` (no scheme, no port).
    fn scp_url(&self) -> String {
        match &self.username {
            Some(username) => format!("{}@{}", username, self.host),
            None => self.host.clone(),
        }
    }

    /// Returns the human-readable `[user@]host[:port]` form.
    pub fn connection_string(&self) -> String {
        // Shares the `[user@]host` part with `scp_url` so the two cannot drift apart.
        let host = self.scp_url();
        match self.port {
            Some(port) => format!("{}:{}", host, port),
            None => host,
        }
    }

    /// Uniquely identifies dev server projects on a remote host. Needs to be
    /// stable for the same dev server project.
    ///
    /// The host is rendered with `{:?}` (so it appears quoted); this format must not
    /// change, or previously stored identifiers would no longer match.
    pub fn dev_server_identifier(&self) -> String {
        let mut identifier = format!("dev-server-{:?}", self.host);
        if let Some(username) = &self.username {
            identifier.push('-');
            identifier.push_str(username);
        }
        identifier
    }
}
113
/// Operating system and CPU architecture reported by the remote host.
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}

impl SshPlatform {
    /// Returns the Rust target triple (`<arch>-<vendor-os>`) for this platform, or
    /// `None` when the OS is neither `"linux"` nor `"macos"`.
    pub fn triple(&self) -> Option<String> {
        let vendor_and_os = match self.os {
            "linux" => "unknown-linux-gnu",
            "macos" => "apple-darwin",
            _ => return None,
        };
        Some(format!("{}-{}", self.arch, vendor_and_os))
    }
}
133
/// Hooks the SSH client calls out to for user interaction, server binaries and
/// status reporting.
pub trait SshClientDelegate: Send + Sync {
    /// Asks for the answer to an ssh authentication `prompt` (the text relayed by the
    /// askpass helper); the returned receiver yields the answer or an error.
    fn ask_password(
        &self,
        prompt: String,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<String>>;
    /// Path on the remote host where the server binary is expected to live.
    fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
    /// Provides a local server binary (and its version) built for the remote `platform`.
    fn get_server_binary(
        &self,
        platform: SshPlatform,
        cx: &mut AsyncAppContext,
    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
    /// Reports connection progress (e.g. "connecting", "Starting proxy").
    /// NOTE(review): presumably `None` clears the status — confirm with implementors.
    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
    /// Reports a connection failure message.
    fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
149
150impl SshSocket {
151 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
152 let mut command = process::Command::new("ssh");
153 self.ssh_options(&mut command)
154 .arg(self.connection_options.ssh_url())
155 .arg(program);
156 command
157 }
158
159 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
160 command
161 .stdin(Stdio::piped())
162 .stdout(Stdio::piped())
163 .stderr(Stdio::piped())
164 .args(["-o", "ControlMaster=no", "-o"])
165 .arg(format!("ControlPath={}", self.socket_path.display()))
166 }
167
168 fn ssh_args(&self) -> Vec<String> {
169 vec![
170 "-o".to_string(),
171 "ControlMaster=no".to_string(),
172 "-o".to_string(),
173 format!("ControlPath={}", self.socket_path.display()),
174 self.connection_options.ssh_url(),
175 ]
176 }
177}
178
179async fn run_cmd(command: &mut process::Command) -> Result<String> {
180 let output = command.output().await?;
181 if output.status.success() {
182 Ok(String::from_utf8_lossy(&output.stdout).to_string())
183 } else {
184 Err(anyhow!(
185 "failed to run command: {}",
186 String::from_utf8_lossy(&output.stderr)
187 ))
188 }
189}
190
/// Relays envelopes between the long-lived client channels and a pair of
/// per-connection proxy channels.
///
/// On reconnect the old forwarder is turned back into the original channel ends with
/// `into_channels`, and a new forwarder is built around them. The client's channels
/// therefore survive the swap of the underlying ssh process.
struct ChannelForwarder {
    /// Signals the forwarding task to stop.
    quit_tx: UnboundedSender<()>,
    /// Relay task; resolves to the original `(incoming_tx, outgoing_rx)` pair once it stops.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}

impl ChannelForwarder {
    /// Starts the relay on the background executor.
    ///
    /// Returns the forwarder plus a fresh proxy pair for the multiplexer:
    /// - envelopes sent on the returned `UnboundedSender` are forwarded to `incoming_tx`;
    /// - envelopes arriving on `outgoing_rx` are forwarded to the returned `UnboundedReceiver`.
    ///
    /// The relay stops when a quit signal arrives, when either source channel ends,
    /// or when a destination rejects a send.
    fn new(
        mut incoming_tx: UnboundedSender<Envelope>,
        mut outgoing_rx: UnboundedReceiver<Envelope>,
        cx: &AsyncAppContext,
    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();

        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();

        let forwarding_task = cx.background_executor().spawn(async move {
            loop {
                // Biased: quit is checked first. NOTE(review): envelopes still buffered
                // in `proxy_incoming_rx` when quit is observed are dropped with it;
                // items not yet taken from `outgoing_rx` stay queued and are handed back.
                select_biased! {
                    _ = quit_rx.next().fuse() => {
                        break;
                    },
                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
                        if let Some(envelope) = incoming_envelope {
                            if incoming_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                    outgoing_envelope = outgoing_rx.next().fuse() => {
                        if let Some(envelope) = outgoing_envelope {
                            if proxy_outgoing_tx.send(envelope).await.is_err() {
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
            }

            // Hand the long-lived ends back so a new forwarder can reuse them.
            (incoming_tx, outgoing_rx)
        });

        (
            Self {
                forwarding_task,
                quit_tx,
            },
            proxy_incoming_tx,
            proxy_outgoing_rx,
        )
    }

    /// Stops the relay and waits for it to return the original
    /// `(incoming_tx, outgoing_rx)` channel ends.
    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
        // A failed send only means the task already stopped; its result is still awaited.
        let _ = self.quit_tx.send(()).await;
        self.forwarding_task.await
    }
}
252
/// Consecutive failed heartbeats after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Idle time without observed server traffic before a heartbeat ping is sent.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Time allowed for a ping reply (heartbeats and the post-connect ping).
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);

/// Reconnect attempts made before giving up with `State::ReconnectExhausted`.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
258
/// Internal connection state machine of `SshRemoteClient`.
///
/// NOTE: fields drop in declaration order, so dropping a `Connected` /
/// `HeartbeatMissed` value implicitly kills the master process (`ssh_connection`)
/// before the proxy child owned by `multiplex_task`. Code that needs an orderly
/// teardown must destructure and drop the fields explicitly.
enum State {
    /// Initial state while `SshRemoteClient::new` establishes the first connection.
    Connecting,
    /// Live connection with its I/O (`multiplex_task`) and keep-alive (`heartbeat_task`) tasks.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Like `Connected`, but the last `missed_heartbeats` (>= 1) heartbeats failed.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// The reconnect attempt number `attempts` failed with `error`; the old connection
    /// and forwarder are kept for the next attempt.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        attempts: usize,
    },
    /// Gave up after `MAX_RECONNECT_ATTEMPTS` failed attempts.
    ReconnectExhausted,
    /// The proxy reported that the remote server is not running.
    ServerNotRunning,
}
291
292impl fmt::Display for State {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 match self {
295 Self::Connecting => write!(f, "connecting"),
296 Self::Connected { .. } => write!(f, "connected"),
297 Self::Reconnecting => write!(f, "reconnecting"),
298 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
299 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
300 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
301 Self::ServerNotRunning { .. } => write!(f, "server not running"),
302 }
303 }
304}
305
306impl State {
307 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
308 match self {
309 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
310 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
311 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
312 _ => None,
313 }
314 }
315
316 fn can_reconnect(&self) -> bool {
317 match self {
318 Self::Connected { .. }
319 | Self::HeartbeatMissed { .. }
320 | Self::ReconnectFailed { .. } => true,
321 State::Connecting
322 | State::Reconnecting
323 | State::ReconnectExhausted
324 | State::ServerNotRunning => false,
325 }
326 }
327
328 fn is_reconnect_failed(&self) -> bool {
329 matches!(self, Self::ReconnectFailed { .. })
330 }
331
332 fn is_reconnect_exhausted(&self) -> bool {
333 matches!(self, Self::ReconnectExhausted { .. })
334 }
335
336 fn is_reconnecting(&self) -> bool {
337 matches!(self, Self::Reconnecting { .. })
338 }
339
340 fn heartbeat_recovered(self) -> Self {
341 match self {
342 Self::HeartbeatMissed {
343 ssh_connection,
344 delegate,
345 forwarder,
346 multiplex_task,
347 heartbeat_task,
348 ..
349 } => Self::Connected {
350 ssh_connection,
351 delegate,
352 forwarder,
353 multiplex_task,
354 heartbeat_task,
355 },
356 _ => self,
357 }
358 }
359
360 fn heartbeat_missed(self) -> Self {
361 match self {
362 Self::Connected {
363 ssh_connection,
364 delegate,
365 forwarder,
366 multiplex_task,
367 heartbeat_task,
368 } => Self::HeartbeatMissed {
369 missed_heartbeats: 1,
370 ssh_connection,
371 delegate,
372 forwarder,
373 multiplex_task,
374 heartbeat_task,
375 },
376 Self::HeartbeatMissed {
377 missed_heartbeats,
378 ssh_connection,
379 delegate,
380 forwarder,
381 multiplex_task,
382 heartbeat_task,
383 } => Self::HeartbeatMissed {
384 missed_heartbeats: missed_heartbeats + 1,
385 ssh_connection,
386 delegate,
387 forwarder,
388 multiplex_task,
389 heartbeat_task,
390 },
391 _ => self,
392 }
393 }
394}
395
/// The state of the ssh connection.
///
/// Public, data-free projection of the internal connection state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// The first connection is being established.
    Connecting,
    /// The connection is live.
    Connected,
    /// The connection is live but recent heartbeats have failed.
    HeartbeatMissed,
    /// A reconnect is in progress (or the last attempt failed and another will follow).
    Reconnecting,
    /// The client gave up, the server is not running, or there is no state at all.
    Disconnected,
}
405
406impl From<&State> for ConnectionState {
407 fn from(value: &State) -> Self {
408 match value {
409 State::Connecting => Self::Connecting,
410 State::Connected { .. } => Self::Connected,
411 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
412 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
413 State::ReconnectExhausted => Self::Disconnected,
414 State::ServerNotRunning => Self::Disconnected,
415 }
416 }
417}
418
/// Client side of a connection to a remote server reached over SSH.
///
/// Owns the RPC channel client and the connection state machine, and emits
/// `SshRemoteEvent`s when the connection is given up.
pub struct SshRemoteClient {
    /// RPC client; its channels are kept across reconnects via `ChannelForwarder`.
    client: Arc<ChannelClient>,
    /// Identifier passed to the remote proxy as `--identifier`.
    unique_identifier: String,
    /// Where to connect to.
    connection_options: SshConnectionOptions,
    /// Current connection state; `None` once taken by `shutdown_processes` (and in `fake`).
    state: Arc<Mutex<Option<State>>>,
}
425
/// Events emitted by `SshRemoteClient`.
#[derive(Debug)]
pub enum SshRemoteEvent {
    /// The connection was given up: reconnect attempts were exhausted, or the remote
    /// server is not running.
    Disconnected,
}

impl EventEmitter<SshRemoteEvent> for SshRemoteClient {}
432
433impl SshRemoteClient {
/// Connects to the remote host and returns the client model once the server answers
/// a ping.
///
/// Sets up the client channels and a `ChannelForwarder`, establishes the ssh
/// connection and proxy process, starts the multiplexer, pings the server (bounded by
/// `HEARTBEAT_TIMEOUT`), starts the heartbeat, and finally moves the state from
/// `Connecting` to `Connected`.
///
/// # Errors
/// Fails if the connection cannot be established or the ping fails; a ping failure
/// is also reported through `delegate.set_error`.
pub fn new(
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    delegate: Arc<dyn SshClientDelegate>,
    cx: &AppContext,
) -> Task<Result<Model<Self>>> {
    cx.spawn(|mut cx| async move {
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
        // Capacity-1 channel: the multiplexer only signals "some traffic happened".
        let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);

        let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
        let this = cx.new_model(|_| Self {
            client: client.clone(),
            unique_identifier: unique_identifier.clone(),
            connection_options: connection_options.clone(),
            state: Arc::new(Mutex::new(Some(State::Connecting))),
        })?;

        // The forwarder sits between the long-lived client channels and the
        // per-connection proxy channels so that a reconnect can swap the latter.
        let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
            ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

        let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
            unique_identifier,
            false,
            connection_options,
            delegate.clone(),
            &mut cx,
        )
        .await?;

        // Must be running before the ping below, since it carries the traffic.
        let multiplex_task = Self::multiplex(
            this.downgrade(),
            ssh_proxy_process,
            proxy_incoming_tx,
            proxy_outgoing_rx,
            connection_activity_tx,
            &mut cx,
        );

        if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
            log::error!("failed to establish connection: {}", error);
            delegate.set_error(error.to_string(), &mut cx);
            return Err(error);
        }

        let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);

        this.update(&mut cx, |this, _| {
            *this.state.lock() = Some(State::Connected {
                ssh_connection,
                delegate,
                forwarder: proxy,
                multiplex_task,
                heartbeat_task,
            });
        })?;

        Ok(this)
    })
}
495
496 pub fn shutdown_processes<T: RequestMessage>(
497 &self,
498 shutdown_request: Option<T>,
499 ) -> Option<impl Future<Output = ()>> {
500 let state = self.state.lock().take()?;
501 log::info!("shutting down ssh processes");
502
503 let State::Connected {
504 multiplex_task,
505 heartbeat_task,
506 ssh_connection,
507 delegate,
508 forwarder,
509 } = state
510 else {
511 return None;
512 };
513
514 let client = self.client.clone();
515
516 Some(async move {
517 if let Some(shutdown_request) = shutdown_request {
518 client.send(shutdown_request).log_err();
519 // We wait 50ms instead of waiting for a response, because
520 // waiting for a response would require us to wait on the main thread
521 // which we want to avoid in an `on_app_quit` callback.
522 smol::Timer::after(Duration::from_millis(50)).await;
523 }
524
525 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
526 // child of master_process.
527 drop(multiplex_task);
528 // Now drop the rest of state, which kills master process.
529 drop(heartbeat_task);
530 drop(ssh_connection);
531 drop(delegate);
532 drop(forwarder);
533 })
534 }
535
536 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
537 let mut lock = self.state.lock();
538
539 let can_reconnect = lock
540 .as_ref()
541 .map(|state| state.can_reconnect())
542 .unwrap_or(false);
543 if !can_reconnect {
544 let error = if let Some(state) = lock.as_ref() {
545 format!("invalid state, cannot reconnect while in state {state}")
546 } else {
547 "no state set".to_string()
548 };
549 log::info!("aborting reconnect, because not in state that allows reconnecting");
550 return Err(anyhow!(error));
551 }
552
553 let state = lock.take().unwrap();
554 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
555 State::Connected {
556 ssh_connection,
557 delegate,
558 forwarder,
559 multiplex_task,
560 heartbeat_task,
561 }
562 | State::HeartbeatMissed {
563 ssh_connection,
564 delegate,
565 forwarder,
566 multiplex_task,
567 heartbeat_task,
568 ..
569 } => {
570 drop(multiplex_task);
571 drop(heartbeat_task);
572 (0, ssh_connection, delegate, forwarder)
573 }
574 State::ReconnectFailed {
575 attempts,
576 ssh_connection,
577 delegate,
578 forwarder,
579 ..
580 } => (attempts, ssh_connection, delegate, forwarder),
581 State::Connecting
582 | State::Reconnecting
583 | State::ReconnectExhausted
584 | State::ServerNotRunning => unreachable!(),
585 };
586
587 let attempts = attempts + 1;
588 if attempts > MAX_RECONNECT_ATTEMPTS {
589 log::error!(
590 "Failed to reconnect to after {} attempts, giving up",
591 MAX_RECONNECT_ATTEMPTS
592 );
593 drop(lock);
594 self.set_state(State::ReconnectExhausted, cx);
595 return Ok(());
596 }
597 drop(lock);
598
599 self.set_state(State::Reconnecting, cx);
600
601 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
602
603 let identifier = self.unique_identifier.clone();
604 let client = self.client.clone();
605 let reconnect_task = cx.spawn(|this, mut cx| async move {
606 macro_rules! failed {
607 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
608 return State::ReconnectFailed {
609 error: anyhow!($error),
610 attempts: $attempts,
611 ssh_connection: $ssh_connection,
612 delegate: $delegate,
613 forwarder: $forwarder,
614 };
615 };
616 }
617
618 if let Err(error) = ssh_connection.master_process.kill() {
619 failed!(error, attempts, ssh_connection, delegate, forwarder);
620 };
621
622 if let Err(error) = ssh_connection
623 .master_process
624 .status()
625 .await
626 .context("Failed to kill ssh process")
627 {
628 failed!(error, attempts, ssh_connection, delegate, forwarder);
629 }
630
631 let connection_options = ssh_connection.socket.connection_options.clone();
632
633 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
634 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
635 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
636 let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
637
638 let (ssh_connection, ssh_process) = match Self::establish_connection(
639 identifier,
640 true,
641 connection_options,
642 delegate.clone(),
643 &mut cx,
644 )
645 .await
646 {
647 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
648 Err(error) => {
649 failed!(error, attempts, ssh_connection, delegate, forwarder);
650 }
651 };
652
653 let multiplex_task = Self::multiplex(
654 this.clone(),
655 ssh_process,
656 proxy_incoming_tx,
657 proxy_outgoing_rx,
658 connection_activity_tx,
659 &mut cx,
660 );
661
662 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
663 failed!(error, attempts, ssh_connection, delegate, forwarder);
664 };
665
666 State::Connected {
667 ssh_connection,
668 delegate,
669 forwarder,
670 multiplex_task,
671 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
672 }
673 });
674
675 cx.spawn(|this, mut cx| async move {
676 let new_state = reconnect_task.await;
677 this.update(&mut cx, |this, cx| {
678 this.try_set_state(cx, |old_state| {
679 if old_state.is_reconnecting() {
680 match &new_state {
681 State::Connecting
682 | State::Reconnecting { .. }
683 | State::HeartbeatMissed { .. }
684 | State::ServerNotRunning => {}
685 State::Connected { .. } => {
686 log::info!("Successfully reconnected");
687 }
688 State::ReconnectFailed {
689 error, attempts, ..
690 } => {
691 log::error!(
692 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
693 attempts,
694 error
695 );
696 }
697 State::ReconnectExhausted => {
698 log::error!("Reconnect attempt failed and all attempts exhausted");
699 }
700 }
701 Some(new_state)
702 } else {
703 None
704 }
705 });
706
707 if this.state_is(State::is_reconnect_failed) {
708 this.reconnect(cx)
709 } else if this.state_is(State::is_reconnect_exhausted) {
710 cx.emit(SshRemoteEvent::Disconnected);
711 Ok(())
712 } else {
713 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
714 Ok(())
715 }
716 })
717 })
718 .detach_and_log_err(cx);
719
720 Ok(())
721 }
722
723 fn heartbeat(
724 this: WeakModel<Self>,
725 mut connection_activity_rx: mpsc::Receiver<()>,
726 cx: &mut AsyncAppContext,
727 ) -> Task<Result<()>> {
728 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
729 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
730 };
731
732 cx.spawn(|mut cx| {
733 let this = this.clone();
734 async move {
735 let mut missed_heartbeats = 0;
736
737 let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
738 futures::pin_mut!(keepalive_timer);
739
740 loop {
741 select_biased! {
742 _ = connection_activity_rx.next().fuse() => {
743 keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
744 }
745 _ = keepalive_timer => {
746 log::debug!("Sending heartbeat to server...");
747
748 let result = select_biased! {
749 _ = connection_activity_rx.next().fuse() => {
750 Ok(())
751 }
752 ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
753 ping_result
754 }
755 };
756 if result.is_err() {
757 missed_heartbeats += 1;
758 log::warn!(
759 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
760 HEARTBEAT_TIMEOUT,
761 missed_heartbeats,
762 MAX_MISSED_HEARTBEATS
763 );
764 } else if missed_heartbeats != 0 {
765 missed_heartbeats = 0;
766 } else {
767 continue;
768 }
769
770 let result = this.update(&mut cx, |this, mut cx| {
771 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
772 })?;
773 if result.is_break() {
774 return Ok(());
775 }
776 }
777 }
778 }
779 }
780 })
781 }
782
783 fn handle_heartbeat_result(
784 &mut self,
785 missed_heartbeats: usize,
786 cx: &mut ModelContext<Self>,
787 ) -> ControlFlow<()> {
788 let state = self.state.lock().take().unwrap();
789 let next_state = if missed_heartbeats > 0 {
790 state.heartbeat_missed()
791 } else {
792 state.heartbeat_recovered()
793 };
794
795 self.set_state(next_state, cx);
796
797 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
798 log::error!(
799 "Missed last {} heartbeats. Reconnecting...",
800 missed_heartbeats
801 );
802
803 self.reconnect(cx)
804 .context("failed to start reconnect process after missing heartbeats")
805 .log_err();
806 ControlFlow::Break(())
807 } else {
808 ControlFlow::Continue(())
809 }
810 }
811
812 fn multiplex(
813 this: WeakModel<Self>,
814 mut ssh_proxy_process: Child,
815 incoming_tx: UnboundedSender<Envelope>,
816 mut outgoing_rx: UnboundedReceiver<Envelope>,
817 mut connection_activity_tx: Sender<()>,
818 cx: &AsyncAppContext,
819 ) -> Task<Result<()>> {
820 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
821 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
822 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
823
824 let io_task = cx.background_executor().spawn(async move {
825 let mut stdin_buffer = Vec::new();
826 let mut stdout_buffer = Vec::new();
827 let mut stderr_buffer = Vec::new();
828 let mut stderr_offset = 0;
829
830 loop {
831 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
832 stderr_buffer.resize(stderr_offset + 1024, 0);
833
834 select_biased! {
835 outgoing = outgoing_rx.next().fuse() => {
836 let Some(outgoing) = outgoing else {
837 return anyhow::Ok(None);
838 };
839
840 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
841 }
842
843 result = child_stdout.read(&mut stdout_buffer).fuse() => {
844 match result {
845 Ok(0) => {
846 child_stdin.close().await?;
847 outgoing_rx.close();
848 let status = ssh_proxy_process.status().await?;
849 // If we don't have a code, we assume process
850 // has been killed and treat it as non-zero exit
851 // code
852 return Ok(status.code().or_else(|| Some(1)));
853 }
854 Ok(len) => {
855 if len < stdout_buffer.len() {
856 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
857 }
858
859 let message_len = message_len_from_buffer(&stdout_buffer);
860 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
861 Ok(envelope) => {
862 connection_activity_tx.try_send(()).ok();
863 incoming_tx.unbounded_send(envelope).ok();
864 }
865 Err(error) => {
866 log::error!("error decoding message {error:?}");
867 }
868 }
869 }
870 Err(error) => {
871 Err(anyhow!("error reading stdout: {error:?}"))?;
872 }
873 }
874 }
875
876 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
877 match result {
878 Ok(len) => {
879 stderr_offset += len;
880 let mut start_ix = 0;
881 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
882 let line_ix = start_ix + ix;
883 let content = &stderr_buffer[start_ix..line_ix];
884 start_ix = line_ix + 1;
885 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
886 record.log(log::logger())
887 } else {
888 eprintln!("(remote) {}", String::from_utf8_lossy(content));
889 }
890 }
891 stderr_buffer.drain(0..start_ix);
892 stderr_offset -= start_ix;
893
894 connection_activity_tx.try_send(()).ok();
895 }
896 Err(error) => {
897 Err(anyhow!("error reading stderr: {error:?}"))?;
898 }
899 }
900 }
901 }
902 }
903 });
904
905 cx.spawn(|mut cx| async move {
906 let result = io_task.await;
907
908 match result {
909 Ok(Some(exit_code)) => {
910 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
911 match error {
912 ProxyLaunchError::ServerNotRunning => {
913 log::error!("failed to reconnect because server is not running");
914 this.update(&mut cx, |this, cx| {
915 this.set_state(State::ServerNotRunning, cx);
916 cx.emit(SshRemoteEvent::Disconnected);
917 })?;
918 }
919 }
920 } else if exit_code > 0 {
921 log::error!("proxy process terminated unexpectedly");
922 this.update(&mut cx, |this, cx| {
923 this.reconnect(cx).ok();
924 })?;
925 }
926 }
927 Ok(None) => {}
928 Err(error) => {
929 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
930 this.update(&mut cx, |this, cx| {
931 this.reconnect(cx).ok();
932 })?;
933 }
934 }
935 Ok(())
936 })
937 }
938
939 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
940 self.state.lock().as_ref().map_or(false, check)
941 }
942
943 fn try_set_state(
944 &self,
945 cx: &mut ModelContext<Self>,
946 map: impl FnOnce(&State) -> Option<State>,
947 ) {
948 let mut lock = self.state.lock();
949 let new_state = lock.as_ref().and_then(map);
950
951 if let Some(new_state) = new_state {
952 lock.replace(new_state);
953 cx.notify();
954 }
955 }
956
957 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
958 log::info!("setting state to '{}'", &state);
959 self.state.lock().replace(state);
960 cx.notify();
961 }
962
/// Opens a new SSH master connection and starts the remote proxy over it.
///
/// Steps:
/// 1. Connect (`SshRemoteConnection::new`) and query the remote platform.
/// 2. Obtain the matching local server binary from `delegate`, and pass it to
///    `ensure_server_binary` for the delegate's remote path (presumably uploading it
///    when missing or outdated — confirm).
/// 3. Check that the binary runs (`version`).
/// 4. Spawn `<binary> proxy --identifier <unique_identifier>`, adding `--reconnect`
///    when `reconnect` is set and forwarding the local `RUST_LOG` / `RUST_BACKTRACE`.
///
/// Returns the connection and the proxy child, whose piped stdio carries the RPC
/// traffic. The child is spawned with `kill_on_drop(true)`.
async fn establish_connection(
    unique_identifier: String,
    reconnect: bool,
    connection_options: SshConnectionOptions,
    delegate: Arc<dyn SshClientDelegate>,
    cx: &mut AsyncAppContext,
) -> Result<(SshRemoteConnection, Child)> {
    let ssh_connection =
        SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;

    let platform = ssh_connection.query_platform().await?;
    let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
    let remote_binary_path = delegate.remote_server_binary_path(cx)?;
    ssh_connection
        .ensure_server_binary(
            &delegate,
            &local_binary_path,
            &remote_binary_path,
            version,
            cx,
        )
        .await?;

    let socket = ssh_connection.socket.clone();
    run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;

    delegate.set_status(Some("Starting proxy"), cx);

    // NOTE(review): this string is interpreted by the remote shell. The env values and
    // `unique_identifier` are inserted unquoted, and `{:?}` gives Rust Debug escaping
    // rather than shell quoting. Values with spaces or shell metacharacters would be
    // misparsed — confirm they are always shell-safe.
    let mut start_proxy_command = format!(
        "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
        std::env::var("RUST_LOG").unwrap_or_default(),
        std::env::var("RUST_BACKTRACE").unwrap_or_default(),
        remote_binary_path,
        unique_identifier,
    );
    if reconnect {
        start_proxy_command.push_str(" --reconnect");
    }

    let ssh_proxy_process = socket
        .ssh_command(start_proxy_command)
        // IMPORTANT: we kill this process when we drop the task that uses it.
        .kill_on_drop(true)
        .spawn()
        .context("failed to spawn remote server")?;

    Ok((ssh_connection, ssh_proxy_process))
}
1011
1012 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1013 self.client.subscribe_to_entity(remote_id, entity);
1014 }
1015
1016 pub fn ssh_args(&self) -> Option<Vec<String>> {
1017 self.state
1018 .lock()
1019 .as_ref()
1020 .and_then(|state| state.ssh_connection())
1021 .map(|ssh_connection| ssh_connection.socket.ssh_args())
1022 }
1023
1024 pub fn to_proto_client(&self) -> AnyProtoClient {
1025 self.client.clone().into()
1026 }
1027
1028 pub fn connection_string(&self) -> String {
1029 self.connection_options.connection_string()
1030 }
1031
1032 pub fn connection_options(&self) -> SshConnectionOptions {
1033 self.connection_options.clone()
1034 }
1035
1036 #[cfg(not(any(test, feature = "test-support")))]
1037 pub fn connection_state(&self) -> ConnectionState {
1038 self.state
1039 .lock()
1040 .as_ref()
1041 .map(ConnectionState::from)
1042 .unwrap_or(ConnectionState::Disconnected)
1043 }
1044
1045 #[cfg(any(test, feature = "test-support"))]
1046 pub fn connection_state(&self) -> ConnectionState {
1047 ConnectionState::Connected
1048 }
1049
1050 pub fn is_disconnected(&self) -> bool {
1051 self.connection_state() == ConnectionState::Disconnected
1052 }
1053
/// Test helper: builds a client model and a matching server-side `ChannelClient`
/// that talk over in-memory channels instead of SSH. The client has no state.
#[cfg(any(test, feature = "test-support"))]
pub fn fake(
    client_cx: &mut gpui::TestAppContext,
    server_cx: &mut gpui::TestAppContext,
) -> (Model<Self>, Arc<ChannelClient>) {
    use gpui::Context;

    // One unbounded channel per direction stands in for the ssh transport.
    let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
    let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();

    let remote_client = client_cx.update(|cx| {
        let channel_client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
        cx.new_model(|_| Self {
            client: channel_client,
            unique_identifier: "fake".to_string(),
            connection_options: SshConnectionOptions::default(),
            state: Arc::new(Mutex::new(None)),
        })
    });
    let server_client =
        server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx));

    (remote_client, server_client)
}
1077}
1078
1079impl From<SshRemoteClient> for AnyProtoClient {
1080 fn from(client: SshRemoteClient) -> Self {
1081 AnyProtoClient::new(client.client.clone())
1082 }
1083}
1084
/// An established SSH master connection.
///
/// Dropping it kills the master `ssh` process.
struct SshRemoteConnection {
    /// Runs further commands over the master connection's control socket.
    socket: SshSocket,
    /// The master `ssh -N -o ControlMaster=yes` process holding the connection open.
    master_process: process::Child,
    /// Holds the control socket and askpass helper files; kept alive as long as the
    /// connection is.
    _temp_dir: TempDir,
}
1090
1091impl Drop for SshRemoteConnection {
1092 fn drop(&mut self) {
1093 if let Err(error) = self.master_process.kill() {
1094 log::error!("failed to kill SSH master process: {}", error);
1095 }
1096 }
1097}
1098
1099impl SshRemoteConnection {
/// Fallback for non-Unix targets: SSH connections are not supported there, so this
/// always returns an error.
#[cfg(not(unix))]
async fn new(
    _connection_options: SshConnectionOptions,
    _delegate: Arc<dyn SshClientDelegate>,
    _cx: &mut AsyncAppContext,
) -> Result<Self> {
    Err(anyhow!("ssh is not supported on this platform"))
}
1108
1109 #[cfg(unix)]
1110 async fn new(
1111 connection_options: SshConnectionOptions,
1112 delegate: Arc<dyn SshClientDelegate>,
1113 cx: &mut AsyncAppContext,
1114 ) -> Result<Self> {
1115 use futures::{io::BufReader, AsyncBufReadExt as _};
1116 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1117 use util::ResultExt as _;
1118
1119 delegate.set_status(Some("connecting"), cx);
1120
1121 let url = connection_options.ssh_url();
1122 let temp_dir = tempfile::Builder::new()
1123 .prefix("zed-ssh-session")
1124 .tempdir()?;
1125
1126 // Create a domain socket listener to handle requests from the askpass program.
1127 let askpass_socket = temp_dir.path().join("askpass.sock");
1128 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1129 let listener =
1130 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1131
1132 let askpass_task = cx.spawn({
1133 let delegate = delegate.clone();
1134 |mut cx| async move {
1135 let mut askpass_opened_tx = Some(askpass_opened_tx);
1136
1137 while let Ok((mut stream, _)) = listener.accept().await {
1138 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1139 askpass_opened_tx.send(()).ok();
1140 }
1141 let mut buffer = Vec::new();
1142 let mut reader = BufReader::new(&mut stream);
1143 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1144 buffer.clear();
1145 }
1146 let password_prompt = String::from_utf8_lossy(&buffer);
1147 if let Some(password) = delegate
1148 .ask_password(password_prompt.to_string(), &mut cx)
1149 .await
1150 .context("failed to get ssh password")
1151 .and_then(|p| p)
1152 .log_err()
1153 {
1154 stream.write_all(password.as_bytes()).await.log_err();
1155 }
1156 }
1157 }
1158 });
1159
1160 // Create an askpass script that communicates back to this process.
1161 let askpass_script = format!(
1162 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1163 askpass_socket = askpass_socket.display(),
1164 print_args = "printf '%s\\0' \"$@\"",
1165 shebang = "#!/bin/sh",
1166 );
1167 let askpass_script_path = temp_dir.path().join("askpass.sh");
1168 fs::write(&askpass_script_path, askpass_script).await?;
1169 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1170
1171 // Start the master SSH process, which does not do anything except for establish
1172 // the connection and keep it open, allowing other ssh commands to reuse it
1173 // via a control socket.
1174 let socket_path = temp_dir.path().join("ssh.sock");
1175 let mut master_process = process::Command::new("ssh")
1176 .stdin(Stdio::null())
1177 .stdout(Stdio::piped())
1178 .stderr(Stdio::piped())
1179 .env("SSH_ASKPASS_REQUIRE", "force")
1180 .env("SSH_ASKPASS", &askpass_script_path)
1181 .args(["-N", "-o", "ControlMaster=yes", "-o"])
1182 .arg(format!("ControlPath={}", socket_path.display()))
1183 .arg(&url)
1184 .spawn()?;
1185
1186 // Wait for this ssh process to close its stdout, indicating that authentication
1187 // has completed.
1188 let stdout = master_process.stdout.as_mut().unwrap();
1189 let mut output = Vec::new();
1190 let connection_timeout = Duration::from_secs(10);
1191
1192 let result = select_biased! {
1193 _ = askpass_opened_rx.fuse() => {
1194 // If the askpass script has opened, that means the user is typing
1195 // their password, in which case we don't want to timeout anymore,
1196 // since we know a connection has been established.
1197 stdout.read_to_end(&mut output).await?;
1198 Ok(())
1199 }
1200 result = stdout.read_to_end(&mut output).fuse() => {
1201 result?;
1202 Ok(())
1203 }
1204 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1205 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1206 }
1207 };
1208
1209 if let Err(e) = result {
1210 let error_message = format!("Failed to connect to host: {}.", e);
1211 delegate.set_error(error_message, cx);
1212 return Err(e);
1213 }
1214
1215 drop(askpass_task);
1216
1217 if master_process.try_status()?.is_some() {
1218 output.clear();
1219 let mut stderr = master_process.stderr.take().unwrap();
1220 stderr.read_to_end(&mut output).await?;
1221
1222 let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1223 delegate.set_error(error_message.clone(), cx);
1224 Err(anyhow!(error_message))?;
1225 }
1226
1227 Ok(Self {
1228 socket: SshSocket {
1229 connection_options,
1230 socket_path,
1231 },
1232 master_process,
1233 _temp_dir: temp_dir,
1234 })
1235 }
1236
1237 async fn ensure_server_binary(
1238 &self,
1239 delegate: &Arc<dyn SshClientDelegate>,
1240 src_path: &Path,
1241 dst_path: &Path,
1242 version: SemanticVersion,
1243 cx: &mut AsyncAppContext,
1244 ) -> Result<()> {
1245 let mut dst_path_gz = dst_path.to_path_buf();
1246 dst_path_gz.set_extension("gz");
1247
1248 if let Some(parent) = dst_path.parent() {
1249 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1250 }
1251
1252 let mut server_binary_exists = false;
1253 if cfg!(not(debug_assertions)) {
1254 if let Ok(installed_version) =
1255 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1256 {
1257 if installed_version.trim() == version.to_string() {
1258 server_binary_exists = true;
1259 }
1260 }
1261 }
1262
1263 if server_binary_exists {
1264 log::info!("remote development server already present",);
1265 return Ok(());
1266 }
1267
1268 let src_stat = fs::metadata(src_path).await?;
1269 let size = src_stat.len();
1270 let server_mode = 0o755;
1271
1272 let t0 = Instant::now();
1273 delegate.set_status(Some("uploading remote development server"), cx);
1274 log::info!("uploading remote development server ({}kb)", size / 1024);
1275 self.upload_file(src_path, &dst_path_gz)
1276 .await
1277 .context("failed to upload server binary")?;
1278 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1279
1280 delegate.set_status(Some("extracting remote development server"), cx);
1281 run_cmd(
1282 self.socket
1283 .ssh_command("gunzip")
1284 .arg("--force")
1285 .arg(&dst_path_gz),
1286 )
1287 .await?;
1288
1289 delegate.set_status(Some("unzipping remote development server"), cx);
1290 run_cmd(
1291 self.socket
1292 .ssh_command("chmod")
1293 .arg(format!("{:o}", server_mode))
1294 .arg(dst_path),
1295 )
1296 .await?;
1297
1298 Ok(())
1299 }
1300
1301 async fn query_platform(&self) -> Result<SshPlatform> {
1302 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1303 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1304
1305 let os = match os.trim() {
1306 "Darwin" => "macos",
1307 "Linux" => "linux",
1308 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1309 };
1310 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1311 "aarch64"
1312 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1313 "x86_64"
1314 } else {
1315 Err(anyhow!("unknown uname architecture {arch:?}"))?
1316 };
1317
1318 Ok(SshPlatform { os, arch })
1319 }
1320
1321 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1322 let mut command = process::Command::new("scp");
1323 let output = self
1324 .socket
1325 .ssh_options(&mut command)
1326 .args(
1327 self.socket
1328 .connection_options
1329 .port
1330 .map(|port| vec!["-P".to_string(), port.to_string()])
1331 .unwrap_or_default(),
1332 )
1333 .arg(src_path)
1334 .arg(format!(
1335 "{}:{}",
1336 self.socket.connection_options.scp_url(),
1337 dest_path.display()
1338 ))
1339 .output()
1340 .await?;
1341
1342 if output.status.success() {
1343 Ok(())
1344 } else {
1345 Err(anyhow!(
1346 "failed to upload file {} -> {}: {}",
1347 src_path.display(),
1348 dest_path.display(),
1349 String::from_utf8_lossy(&output.stderr)
1350 ))
1351 }
1352 }
1353}
1354
/// Pending requests awaiting a response, keyed by the id of the outgoing request.
///
/// Each sender delivers the response envelope together with a one-shot "ack"
/// sender. The incoming-message loop waits for that ack sender to be dropped
/// before it reads the next message. The requester drops it as soon as it takes
/// the envelope.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1356
/// A [`ProtoClient`] that exchanges protobuf [`Envelope`]s over a pair of
/// in-process channels: one receiver for incoming envelopes and one sender for
/// outgoing envelopes.
pub struct ChannelClient {
    /// Counter used to stamp an id on every envelope sent, including requests
    /// and fire-and-forget messages.
    next_message_id: AtomicU32,
    /// Destination for outgoing envelopes. NOTE(review): whatever drains the
    /// receiving end presumably forwards them to the remote side; confirm.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests still waiting for their response (guarded by a mutex).
    response_channels: ResponseChannels, // Lock
    /// Registered handlers for incoming, non-response messages (guarded by a mutex).
    message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
}
1363
1364impl ChannelClient {
1365 pub fn new(
1366 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1367 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1368 cx: &AppContext,
1369 ) -> Arc<Self> {
1370 let this = Arc::new(Self {
1371 outgoing_tx,
1372 next_message_id: AtomicU32::new(0),
1373 response_channels: ResponseChannels::default(),
1374 message_handlers: Default::default(),
1375 });
1376
1377 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1378
1379 this
1380 }
1381
1382 fn start_handling_messages(
1383 this: Arc<Self>,
1384 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1385 cx: &AppContext,
1386 ) {
1387 cx.spawn(|cx| {
1388 let this = Arc::downgrade(&this);
1389 async move {
1390 let peer_id = PeerId { owner_id: 0, id: 0 };
1391 while let Some(incoming) = incoming_rx.next().await {
1392 let Some(this) = this.upgrade() else {
1393 return anyhow::Ok(());
1394 };
1395
1396 if let Some(request_id) = incoming.responding_to {
1397 let request_id = MessageId(request_id);
1398 let sender = this.response_channels.lock().remove(&request_id);
1399 if let Some(sender) = sender {
1400 let (tx, rx) = oneshot::channel();
1401 if incoming.payload.is_some() {
1402 sender.send((incoming, tx)).ok();
1403 }
1404 rx.await.ok();
1405 }
1406 } else if let Some(envelope) =
1407 build_typed_envelope(peer_id, Instant::now(), incoming)
1408 {
1409 let type_name = envelope.payload_type_name();
1410 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1411 &this.message_handlers,
1412 envelope,
1413 this.clone().into(),
1414 cx.clone(),
1415 ) {
1416 log::debug!("ssh message received. name:{type_name}");
1417 cx.foreground_executor().spawn(async move {
1418 match future.await {
1419 Ok(_) => {
1420 log::debug!("ssh message handled. name:{type_name}");
1421 }
1422 Err(error) => {
1423 log::error!(
1424 "error handling message. type:{type_name}, error:{error}",
1425 );
1426 }
1427 }
1428 }).detach();
1429
1430 } else {
1431 log::error!("unhandled ssh message name:{type_name}");
1432 }
1433 }
1434 }
1435 anyhow::Ok(())
1436 }
1437 })
1438 .detach();
1439 }
1440
1441 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1442 let id = (TypeId::of::<E>(), remote_id);
1443
1444 let mut message_handlers = self.message_handlers.lock();
1445 if message_handlers
1446 .entities_by_type_and_remote_id
1447 .contains_key(&id)
1448 {
1449 panic!("already subscribed to entity");
1450 }
1451
1452 message_handlers.entities_by_type_and_remote_id.insert(
1453 id,
1454 EntityMessageSubscriber::Entity {
1455 handle: entity.downgrade().into(),
1456 },
1457 );
1458 }
1459
1460 pub fn request<T: RequestMessage>(
1461 &self,
1462 payload: T,
1463 ) -> impl 'static + Future<Output = Result<T::Response>> {
1464 log::debug!("ssh request start. name:{}", T::NAME);
1465 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1466 async move {
1467 let response = response.await?;
1468 log::debug!("ssh request finish. name:{}", T::NAME);
1469 T::Response::from_envelope(response)
1470 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1471 }
1472 }
1473
1474 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1475 smol::future::or(
1476 async {
1477 self.request(proto::Ping {}).await?;
1478 Ok(())
1479 },
1480 async {
1481 smol::Timer::after(timeout).await;
1482 Err(anyhow!("Timeout detected"))
1483 },
1484 )
1485 .await
1486 }
1487
1488 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1489 log::debug!("ssh send name:{}", T::NAME);
1490 self.send_dynamic(payload.into_envelope(0, None, None))
1491 }
1492
1493 pub fn request_dynamic(
1494 &self,
1495 mut envelope: proto::Envelope,
1496 type_name: &'static str,
1497 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1498 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1499 let (tx, rx) = oneshot::channel();
1500 let mut response_channels_lock = self.response_channels.lock();
1501 response_channels_lock.insert(MessageId(envelope.id), tx);
1502 drop(response_channels_lock);
1503 let result = self.outgoing_tx.unbounded_send(envelope);
1504 async move {
1505 if let Err(error) = &result {
1506 log::error!("failed to send message: {}", error);
1507 return Err(anyhow!("failed to send message: {}", error));
1508 }
1509
1510 let response = rx.await.context("connection lost")?.0;
1511 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1512 return Err(RpcError::from_proto(error, type_name));
1513 }
1514 Ok(response)
1515 }
1516 }
1517
1518 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1519 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1520 self.outgoing_tx.unbounded_send(envelope)?;
1521 Ok(())
1522 }
1523}
1524
1525impl ProtoClient for ChannelClient {
1526 fn request(
1527 &self,
1528 envelope: proto::Envelope,
1529 request_type: &'static str,
1530 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1531 self.request_dynamic(envelope, request_type).boxed()
1532 }
1533
1534 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1535 self.send_dynamic(envelope)
1536 }
1537
1538 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1539 self.send_dynamic(envelope)
1540 }
1541
1542 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1543 &self.message_handlers
1544 }
1545
1546 fn is_via_collab(&self) -> bool {
1547 false
1548 }
1549}