ssh: Detect timeouts when server is unresponsive (#18808)

Bennet Bo Fenner and Thorsten created

To detect connection timeouts we ping the remote server every X seconds
and attempt to reconnect if the server failed to respond.
Next up is showing some feedback in the UI to make this visible to the
user, and stop reconnecting after X amount of retries.

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

crates/remote/src/ssh_session.rs             | 75 +++++++++++++++++++++
crates/remote_server/src/headless_project.rs | 10 ++
2 files changed, 82 insertions(+), 3 deletions(-)

Detailed changes

crates/remote/src/ssh_session.rs 🔗

@@ -26,6 +26,7 @@ use rpc::{
 use smol::{
     fs,
     process::{self, Child, Stdio},
+    Timer,
 };
 use std::{
     any::TypeId,
@@ -36,7 +37,7 @@ use std::{
         atomic::{AtomicU32, Ordering::SeqCst},
         Arc,
     },
-    time::Instant,
+    time::{Duration, Instant},
 };
 use tempfile::TempDir;
 use util::maybe;
@@ -173,7 +174,7 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
 #[cfg(unix)]
 async fn read_with_timeout(
     stdout: &mut process::ChildStdout,
-    timeout: std::time::Duration,
+    timeout: Duration,
     output: &mut Vec<u8>,
 ) -> Result<(), std::io::Error> {
     smol::future::or(
@@ -260,6 +261,7 @@ struct SshRemoteClientState {
     delegate: Arc<dyn SshClientDelegate>,
     forwarder: ChannelForwarder,
     multiplex_task: Task<Result<()>>,
+    heartbeat_task: Task<Result<()>>,
 }
 
 pub struct SshRemoteClient {
@@ -327,6 +329,7 @@ impl SshRemoteClient {
                     delegate,
                     forwarder: proxy,
                     multiplex_task,
+                    heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
                 }
             };
 
@@ -353,6 +356,7 @@ impl SshRemoteClient {
     }
 
     fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
+        log::info!("Trying to reconnect to ssh server...");
         let Some(state) = self.inner_state.lock().take() else {
             return Err(anyhow!("reconnect is already in progress"));
         };
@@ -364,8 +368,10 @@ impl SshRemoteClient {
             delegate,
             forwarder: proxy,
             multiplex_task,
+            heartbeat_task,
         } = state;
         drop(multiplex_task);
+        drop(heartbeat_task);
 
         cx.spawn(|this, mut cx| async move {
             let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
@@ -401,6 +407,7 @@ impl SshRemoteClient {
                     proxy_outgoing_rx,
                     &mut cx,
                 ),
+                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
             };
 
             this.update(&mut cx, |this, _| {
@@ -411,6 +418,68 @@ impl SshRemoteClient {
         Ok(())
     }
 
+    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
+        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
+            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
+        };
+        cx.spawn(|mut cx| {
+            let this = this.clone();
+            async move {
+                const MAX_MISSED_HEARTBEATS: usize = 5;
+                const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
+                const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
+
+                let mut missed_heartbeats = 0;
+
+                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
+                loop {
+                    timer.next().await;
+
+                    log::info!("Sending heartbeat to server...");
+
+                    let result = smol::future::or(
+                        async {
+                            client.request(proto::Ping {}).await?;
+                            Ok(())
+                        },
+                        async {
+                            smol::Timer::after(HEARTBEAT_TIMEOUT).await;
+
+                            Err(anyhow!("Timeout detected"))
+                        },
+                    )
+                    .await;
+
+                    if result.is_err() {
+                        missed_heartbeats += 1;
+                        log::warn!(
+                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
+                            HEARTBEAT_TIMEOUT,
+                            missed_heartbeats,
+                            MAX_MISSED_HEARTBEATS
+                        );
+                    } else {
+                        missed_heartbeats = 0;
+                    }
+
+                    if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
+                        log::error!(
+                            "Missed last {} hearbeats. Reconnecting...",
+                            missed_heartbeats
+                        );
+
+                        this.update(&mut cx, |this, cx| {
+                            this.reconnect(cx)
+                                .context("failed to reconnect after missing heartbeats")
+                        })
+                        .context("failed to update weak reference, SshRemoteClient lost?")??;
+                        return Ok(());
+                    }
+                }
+            }
+        })
+    }
+
     fn multiplex(
         this: WeakModel<Self>,
         mut ssh_proxy_process: Child,
@@ -712,7 +781,7 @@ impl SshRemoteConnection {
         // has completed.
         let stdout = master_process.stdout.as_mut().unwrap();
         let mut output = Vec::new();
-        let connection_timeout = std::time::Duration::from_secs(10);
+        let connection_timeout = Duration::from_secs(10);
         let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
         if let Err(e) = result {
             let error_message = if e.kind() == std::io::ErrorKind::TimedOut {

crates/remote_server/src/headless_project.rs 🔗

@@ -113,6 +113,7 @@ impl HeadlessProject {
         client.add_request_handler(cx.weak_model(), Self::handle_list_remote_directory);
         client.add_request_handler(cx.weak_model(), Self::handle_check_file_exists);
         client.add_request_handler(cx.weak_model(), Self::handle_shutdown_remote_server);
+        client.add_request_handler(cx.weak_model(), Self::handle_ping);
 
         client.add_model_request_handler(Self::handle_add_worktree);
         client.add_model_request_handler(Self::handle_open_buffer_by_path);
@@ -354,4 +355,13 @@ impl HeadlessProject {
 
         Ok(proto::Ack {})
     }
+
+    pub async fn handle_ping(
+        _this: Model<Self>,
+        _envelope: TypedEnvelope<proto::Ping>,
+        _cx: AsyncAppContext,
+    ) -> Result<proto::Ack> {
+        log::debug!("Received ping from client");
+        Ok(proto::Ack {})
+    }
 }