ssh remoting: Treat other message as heartbeat (#19219)

Bennet Bo Fenner, Thorsten, and Antonio created

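This improves the heartbeat detection logic. We now treat any other
incoming message from the SSH remote server as a heartbeat message:
activity on the connection resets the keepalive timer, and a pending ping
counts as answered as soon as any message arrives. This means that we can
detect re-connects earlier.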

It also changes the connection handling so that each incoming message's
handler future is spawned as a detached task instead of being awaited inline.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>

Change summary

crates/remote/src/ssh_session.rs | 107 +++++++++++++++++++++------------
1 file changed, 68 insertions(+), 39 deletions(-)

Detailed changes

crates/remote/src/ssh_session.rs 🔗

@@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result};
 use collections::HashMap;
 use futures::{
     channel::{
-        mpsc::{self, UnboundedReceiver, UnboundedSender},
+        mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
         oneshot,
     },
     future::BoxFuture,
@@ -28,7 +28,6 @@ use rpc::{
 use smol::{
     fs,
     process::{self, Child, Stdio},
-    Timer,
 };
 use std::{
     any::TypeId,
@@ -441,6 +440,7 @@ impl SshRemoteClient {
         cx.spawn(|mut cx| async move {
             let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
             let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
+            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 
             let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
             let this = cx.new_model(|_| Self {
@@ -467,6 +467,7 @@ impl SshRemoteClient {
                 ssh_proxy_process,
                 proxy_incoming_tx,
                 proxy_outgoing_rx,
+                connection_activity_tx,
                 &mut cx,
             );
 
@@ -476,7 +477,7 @@ impl SshRemoteClient {
                 return Err(error);
             }
 
-            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
+            let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
 
             this.update(&mut cx, |this, _| {
                 *this.state.lock() = Some(State::Connected {
@@ -518,7 +519,7 @@ impl SshRemoteClient {
                 // We wait 50ms instead of waiting for a response, because
                 // waiting for a response would require us to wait on the main thread
                 // which we want to avoid in an `on_app_quit` callback.
-                Timer::after(Duration::from_millis(50)).await;
+                smol::Timer::after(Duration::from_millis(50)).await;
             }
 
             // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
@@ -632,6 +633,7 @@ impl SshRemoteClient {
             let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
             let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
                 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+            let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 
             let (ssh_connection, ssh_process) = match Self::establish_connection(
                 identifier,
@@ -653,6 +655,7 @@ impl SshRemoteClient {
                 ssh_process,
                 proxy_incoming_tx,
                 proxy_outgoing_rx,
+                connection_activity_tx,
                 &mut cx,
             );
 
@@ -665,7 +668,7 @@ impl SshRemoteClient {
                 delegate,
                 forwarder,
                 multiplex_task,
-                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
+                heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
             }
         });
 
@@ -717,41 +720,60 @@ impl SshRemoteClient {
         Ok(())
     }
 
-    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
+    fn heartbeat(
+        this: WeakModel<Self>,
+        mut connection_activity_rx: mpsc::Receiver<()>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<()>> {
         let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
             return Task::ready(Err(anyhow!("SshRemoteClient lost")));
         };
+
         cx.spawn(|mut cx| {
             let this = this.clone();
             async move {
                 let mut missed_heartbeats = 0;
 
-                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
+                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
+                futures::pin_mut!(keepalive_timer);
+
                 loop {
-                    timer.next().await;
-
-                    log::debug!("Sending heartbeat to server...");
-
-                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
-                    if result.is_err() {
-                        missed_heartbeats += 1;
-                        log::warn!(
-                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
-                            HEARTBEAT_TIMEOUT,
-                            missed_heartbeats,
-                            MAX_MISSED_HEARTBEATS
-                        );
-                    } else if missed_heartbeats != 0 {
-                        missed_heartbeats = 0;
-                    } else {
-                        continue;
-                    }
+                    select_biased! {
+                        _ = connection_activity_rx.next().fuse() => {
+                            keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
+                        }
+                        _ = keepalive_timer => {
+                            log::debug!("Sending heartbeat to server...");
 
-                    let result = this.update(&mut cx, |this, mut cx| {
-                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
-                    })?;
-                    if result.is_break() {
-                        return Ok(());
+                            let result = select_biased! {
+                                _ = connection_activity_rx.next().fuse() => {
+                                    Ok(())
+                                }
+                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
+                                    ping_result
+                                }
+                            };
+                            if result.is_err() {
+                                missed_heartbeats += 1;
+                                log::warn!(
+                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
+                                    HEARTBEAT_TIMEOUT,
+                                    missed_heartbeats,
+                                    MAX_MISSED_HEARTBEATS
+                                );
+                            } else if missed_heartbeats != 0 {
+                                missed_heartbeats = 0;
+                            } else {
+                                continue;
+                            }
+
+                            let result = this.update(&mut cx, |this, mut cx| {
+                                this.handle_heartbeat_result(missed_heartbeats, &mut cx)
+                            })?;
+                            if result.is_break() {
+                                return Ok(());
+                            }
+                        }
                     }
                 }
             }
@@ -792,6 +814,7 @@ impl SshRemoteClient {
         mut ssh_proxy_process: Child,
         incoming_tx: UnboundedSender<Envelope>,
         mut outgoing_rx: UnboundedReceiver<Envelope>,
+        mut connection_activity_tx: Sender<()>,
         cx: &AsyncAppContext,
     ) -> Task<Result<()>> {
         let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
@@ -833,6 +856,7 @@ impl SshRemoteClient {
                                 let message_len = message_len_from_buffer(&stdout_buffer);
                                 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
                                     Ok(envelope) => {
+                                        connection_activity_tx.try_send(()).ok();
                                         incoming_tx.unbounded_send(envelope).ok();
                                     }
                                     Err(error) => {
@@ -863,6 +887,8 @@ impl SshRemoteClient {
                                 }
                                 stderr_buffer.drain(0..start_ix);
                                 stderr_offset -= start_ix;
+
+                                connection_activity_tx.try_send(()).ok();
                             }
                             Err(error) => {
                                 Err(anyhow!("error reading stderr: {error:?}"))?;
@@ -1392,16 +1418,19 @@ impl ChannelClient {
                             cx.clone(),
                         ) {
                             log::debug!("ssh message received. name:{type_name}");
-                            match future.await {
-                                Ok(_) => {
-                                    log::debug!("ssh message handled. name:{type_name}");
-                                }
-                                Err(error) => {
-                                    log::error!(
-                                        "error handling message. type:{type_name}, error:{error}",
-                                    );
+                            cx.foreground_executor().spawn(async move {
+                                match future.await {
+                                    Ok(_) => {
+                                        log::debug!("ssh message handled. name:{type_name}");
+                                    }
+                                    Err(error) => {
+                                        log::error!(
+                                            "error handling message. type:{type_name}, error:{error}",
+                                        );
+                                    }
                                 }
-                            }
+                            }).detach();
+
                         } else {
                             log::error!("unhandled ssh message name:{type_name}");
                         }