ssh: Limit amount of reconnect attempts (#18819)

Bennet Bo Fenner and Thorsten created

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

Cargo.lock                        |   1 
crates/project/src/project.rs     |   6 
crates/remote/src/remote.rs       |   4 
crates/remote/src/ssh_session.rs  | 514 +++++++++++++++++++++++++-------
crates/title_bar/Cargo.toml       |   1 
crates/title_bar/src/title_bar.rs |  10 
6 files changed, 412 insertions(+), 124 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -11885,6 +11885,7 @@ dependencies = [
  "pretty_assertions",
  "project",
  "recent_projects",
+ "remote",
  "rpc",
  "serde",
  "settings",

crates/project/src/project.rs 🔗

@@ -1263,8 +1263,10 @@ impl Project {
             .clone()
     }
 
-    pub fn ssh_is_connected(&self, cx: &AppContext) -> Option<bool> {
-        Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway())
+    pub fn ssh_connection_state(&self, cx: &AppContext) -> Option<remote::ConnectionState> {
+        self.ssh_client
+            .as_ref()
+            .map(|ssh| ssh.read(cx).connection_state())
     }
 
     pub fn replica_id(&self) -> ReplicaId {

crates/remote/src/remote.rs 🔗

@@ -2,4 +2,6 @@ pub mod json_log;
 pub mod protocol;
 pub mod ssh_session;
 
-pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient};
+pub use ssh_session::{
+    ConnectionState, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient,
+};

crates/remote/src/ssh_session.rs 🔗

@@ -31,7 +31,8 @@ use smol::{
 use std::{
     any::TypeId,
     ffi::OsStr,
-    mem,
+    fmt,
+    ops::ControlFlow,
     path::{Path, PathBuf},
     sync::{
         atomic::{AtomicU32, Ordering::SeqCst},
@@ -40,7 +41,7 @@ use std::{
     time::{Duration, Instant},
 };
 use tempfile::TempDir;
-use util::maybe;
+use util::ResultExt;
 
 #[derive(
     Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
@@ -234,19 +235,157 @@ impl ChannelForwarder {
     }
 }
 
-struct SshRemoteClientState {
-    ssh_connection: SshRemoteConnection,
-    delegate: Arc<dyn SshClientDelegate>,
-    forwarder: ChannelForwarder,
-    multiplex_task: Task<Result<()>>,
-    heartbeat_task: Task<Result<()>>,
+const MAX_MISSED_HEARTBEATS: usize = 5;
+const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
+const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
+
+const MAX_RECONNECT_ATTEMPTS: usize = 3;
+
+enum State {
+    Connecting,
+    Connected {
+        ssh_connection: SshRemoteConnection,
+        delegate: Arc<dyn SshClientDelegate>,
+        forwarder: ChannelForwarder,
+
+        multiplex_task: Task<Result<()>>,
+        heartbeat_task: Task<Result<()>>,
+    },
+    HeartbeatMissed {
+        missed_heartbeats: usize,
+
+        ssh_connection: SshRemoteConnection,
+        delegate: Arc<dyn SshClientDelegate>,
+        forwarder: ChannelForwarder,
+
+        multiplex_task: Task<Result<()>>,
+        heartbeat_task: Task<Result<()>>,
+    },
+    Reconnecting,
+    ReconnectFailed {
+        ssh_connection: SshRemoteConnection,
+        delegate: Arc<dyn SshClientDelegate>,
+        forwarder: ChannelForwarder,
+
+        error: anyhow::Error,
+        attempts: usize,
+    },
+    ReconnectExhausted,
+}
+
+impl fmt::Display for State {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Connecting => write!(f, "connecting"),
+            Self::Connected { .. } => write!(f, "connected"),
+            Self::Reconnecting => write!(f, "reconnecting"),
+            Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
+            Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
+            Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
+        }
+    }
+}
+
+impl State {
+    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
+        match self {
+            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
+            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
+            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
+            _ => None,
+        }
+    }
+
+    fn can_reconnect(&self) -> bool {
+        matches!(
+            self,
+            Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
+        )
+    }
+
+    fn heartbeat_recovered(self) -> Self {
+        match self {
+            Self::HeartbeatMissed {
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+                ..
+            } => Self::Connected {
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            },
+            _ => self,
+        }
+    }
+
+    fn heartbeat_missed(self) -> Self {
+        match self {
+            Self::Connected {
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            } => Self::HeartbeatMissed {
+                missed_heartbeats: 1,
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            },
+            Self::HeartbeatMissed {
+                missed_heartbeats,
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            } => Self::HeartbeatMissed {
+                missed_heartbeats: missed_heartbeats + 1,
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            },
+            _ => self,
+        }
+    }
+}
+
+/// The state of the ssh connection.
+#[derive(Clone, Copy, Debug)]
+pub enum ConnectionState {
+    Connecting,
+    Connected,
+    HeartbeatMissed,
+    Reconnecting,
+    Disconnected,
+}
+
+impl From<&State> for ConnectionState {
+    fn from(value: &State) -> Self {
+        match value {
+            State::Connecting => Self::Connecting,
+            State::Connected { .. } => Self::Connected,
+            State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
+            State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
+            State::ReconnectExhausted => Self::Disconnected,
+        }
+    }
 }
 
 pub struct SshRemoteClient {
     client: Arc<ChannelClient>,
     unique_identifier: String,
     connection_options: SshConnectionOptions,
-    inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
+    state: Arc<Mutex<Option<State>>>,
 }
 
 impl Drop for SshRemoteClient {
@@ -266,6 +405,7 @@ impl SshRemoteClient {
             let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
             let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 
+            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
             let this = cx.new_model(|cx| {
                 cx.on_app_quit(|this: &mut Self, _| {
                     this.shutdown_processes();
@@ -273,47 +413,49 @@ impl SshRemoteClient {
                 })
                 .detach();
 
-                let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
                 Self {
-                    client,
+                    client: client.clone(),
                     unique_identifier: unique_identifier.clone(),
-                    connection_options: SshConnectionOptions::default(),
-                    inner_state: Arc::new(Mutex::new(None)),
+                    connection_options: connection_options.clone(),
+                    state: Arc::new(Mutex::new(Some(State::Connecting))),
                 }
             })?;
 
-            let inner_state = {
-                let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
-                    ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
-
-                let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
-                    unique_identifier,
-                    connection_options,
-                    delegate.clone(),
-                    &mut cx,
-                )
-                .await?;
-
-                let multiplex_task = Self::multiplex(
-                    this.downgrade(),
-                    ssh_proxy_process,
-                    proxy_incoming_tx,
-                    proxy_outgoing_rx,
-                    &mut cx,
-                );
-
-                SshRemoteClientState {
+            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
+                unique_identifier,
+                connection_options,
+                delegate.clone(),
+                &mut cx,
+            )
+            .await?;
+
+            let multiplex_task = Self::multiplex(
+                this.downgrade(),
+                ssh_proxy_process,
+                proxy_incoming_tx,
+                proxy_outgoing_rx,
+                &mut cx,
+            );
+
+            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
+                log::error!("failed to establish connection: {}", error);
+                delegate.set_error(error.to_string(), &mut cx);
+                return Err(error);
+            }
+
+            let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
+
+            this.update(&mut cx, |this, _| {
+                *this.state.lock() = Some(State::Connected {
                     ssh_connection,
                     delegate,
                     forwarder: proxy,
                     multiplex_task,
-                    heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
-                }
-            };
-
-            this.update(&mut cx, |this, cx| {
-                this.inner_state.lock().replace(inner_state);
-                cx.notify();
+                    heartbeat_task,
+                });
             })?;
 
             Ok(this)
@@ -321,78 +463,192 @@ impl SshRemoteClient {
     }
 
     fn shutdown_processes(&self) {
-        let Some(mut state) = self.inner_state.lock().take() else {
+        let Some(state) = self.state.lock().take() else {
             return;
         };
         log::info!("shutting down ssh processes");
+
+        let State::Connected {
+            multiplex_task,
+            heartbeat_task,
+            ..
+        } = state
+        else {
+            return;
+        };
         // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
         // child of master_process.
-        let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
-        drop(task);
+        drop(multiplex_task);
         // Now drop the rest of state, which kills master process.
-        drop(state);
+        drop(heartbeat_task);
     }
 
-    fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
-        log::info!("Trying to reconnect to ssh server...");
-        let Some(state) = self.inner_state.lock().take() else {
-            return Err(anyhow!("reconnect is already in progress"));
-        };
+    fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
+        let mut lock = self.state.lock();
 
-        let workspace_identifier = self.unique_identifier.clone();
+        let can_reconnect = lock
+            .as_ref()
+            .map(|state| state.can_reconnect())
+            .unwrap_or(false);
+        if !can_reconnect {
+            let error = if let Some(state) = lock.as_ref() {
+                format!("invalid state, cannot reconnect while in state {state}")
+            } else {
+                "no state set".to_string()
+            };
+            return Err(anyhow!(error));
+        }
 
-        let SshRemoteClientState {
-            mut ssh_connection,
-            delegate,
-            forwarder: proxy,
-            multiplex_task,
-            heartbeat_task,
-        } = state;
-        drop(multiplex_task);
-        drop(heartbeat_task);
+        let state = lock.take().unwrap();
+        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
+            State::Connected {
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+            }
+            | State::HeartbeatMissed {
+                ssh_connection,
+                delegate,
+                forwarder,
+                multiplex_task,
+                heartbeat_task,
+                ..
+            } => {
+                drop(multiplex_task);
+                drop(heartbeat_task);
+                (0, ssh_connection, delegate, forwarder)
+            }
+            State::ReconnectFailed {
+                attempts,
+                ssh_connection,
+                delegate,
+                forwarder,
+                ..
+            } => (attempts, ssh_connection, delegate, forwarder),
+            State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
+        };
 
-        cx.spawn(|this, mut cx| async move {
-            let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
+        let attempts = attempts + 1;
+        if attempts > MAX_RECONNECT_ATTEMPTS {
+            log::error!(
+                "Failed to reconnect to after {} attempts, giving up",
+                MAX_RECONNECT_ATTEMPTS
+            );
+            *lock = Some(State::ReconnectExhausted);
+            return Ok(());
+        }
+        *lock = Some(State::Reconnecting);
+        drop(lock);
+
+        log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
+
+        let identifier = self.unique_identifier.clone();
+        let client = self.client.clone();
+        let reconnect_task = cx.spawn(|this, mut cx| async move {
+            macro_rules! failed {
+                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
+                    return State::ReconnectFailed {
+                        error: anyhow!($error),
+                        attempts: $attempts,
+                        ssh_connection: $ssh_connection,
+                        delegate: $delegate,
+                        forwarder: $forwarder,
+                    };
+                };
+            }
 
-            ssh_connection.master_process.kill()?;
-            ssh_connection
+            if let Err(error) = ssh_connection.master_process.kill() {
+                failed!(error, attempts, ssh_connection, delegate, forwarder);
+            };
+
+            if let Err(error) = ssh_connection
                 .master_process
                 .status()
                 .await
-                .context("Failed to kill ssh process")?;
+                .context("Failed to kill ssh process")
+            {
+                failed!(error, attempts, ssh_connection, delegate, forwarder);
+            }
 
             let connection_options = ssh_connection.socket.connection_options.clone();
 
-            let (ssh_connection, ssh_process) = Self::establish_connection(
-                workspace_identifier,
+            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
+            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
+                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+            let (ssh_connection, ssh_process) = match Self::establish_connection(
+                identifier,
                 connection_options,
                 delegate.clone(),
                 &mut cx,
             )
-            .await?;
+            .await
+            {
+                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
+                Err(error) => {
+                    failed!(error, attempts, ssh_connection, delegate, forwarder);
+                }
+            };
 
-            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
-                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+            let multiplex_task = Self::multiplex(
+                this.clone(),
+                ssh_process,
+                proxy_incoming_tx,
+                proxy_outgoing_rx,
+                &mut cx,
+            );
 
-            let inner_state = SshRemoteClientState {
+            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
+                failed!(error, attempts, ssh_connection, delegate, forwarder);
+            };
+
+            State::Connected {
                 ssh_connection,
                 delegate,
-                forwarder: proxy,
-                multiplex_task: Self::multiplex(
-                    this.clone(),
-                    ssh_process,
-                    proxy_incoming_tx,
-                    proxy_outgoing_rx,
-                    &mut cx,
-                ),
+                forwarder,
+                multiplex_task,
                 heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
-            };
+            }
+        });
 
-            this.update(&mut cx, |this, _| {
-                this.inner_state.lock().replace(inner_state);
+        cx.spawn(|this, mut cx| async move {
+            let new_state = reconnect_task.await;
+            this.update(&mut cx, |this, cx| {
+                match &new_state {
+                    State::Connecting
+                    | State::Reconnecting { .. }
+                    | State::HeartbeatMissed { .. } => {}
+                    State::Connected { .. } => {
+                        log::info!("Successfully reconnected");
+                    }
+                    State::ReconnectFailed {
+                        error, attempts, ..
+                    } => {
+                        log::error!(
+                            "Reconnect attempt {} failed: {:?}. Starting new attempt...",
+                            attempts,
+                            error
+                        );
+                    }
+                    State::ReconnectExhausted => {
+                        log::error!("Reconnect attempt failed and all attempts exhausted");
+                    }
+                }
+
+                let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
+                *this.state.lock() = Some(new_state);
+                cx.notify();
+                if reconnect_failed {
+                    this.reconnect(cx)
+                } else {
+                    Ok(())
+                }
             })
         })
-        .detach();
+        .detach_and_log_err(cx);
+
         Ok(())
     }
 
@@ -403,10 +659,6 @@ impl SshRemoteClient {
         cx.spawn(|mut cx| {
             let this = this.clone();
             async move {
-                const MAX_MISSED_HEARTBEATS: usize = 5;
-                const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
-                const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
-
                 let mut missed_heartbeats = 0;
 
                 let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
@@ -415,19 +667,7 @@ impl SshRemoteClient {
 
                     log::info!("Sending heartbeat to server...");
 
-                    let result = smol::future::or(
-                        async {
-                            client.request(proto::Ping {}).await?;
-                            Ok(())
-                        },
-                        async {
-                            smol::Timer::after(HEARTBEAT_TIMEOUT).await;
-
-                            Err(anyhow!("Timeout detected"))
-                        },
-                    )
-                    .await;
-
+                    let result = client.ping(HEARTBEAT_TIMEOUT).await;
                     if result.is_err() {
                         missed_heartbeats += 1;
                         log::warn!(
@@ -440,17 +680,10 @@ impl SshRemoteClient {
                         missed_heartbeats = 0;
                     }
 
-                    if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
-                        log::error!(
-                            "Missed last {} hearbeats. Reconnecting...",
-                            missed_heartbeats
-                        );
-
-                        this.update(&mut cx, |this, cx| {
-                            this.reconnect(cx)
-                                .context("failed to reconnect after missing heartbeats")
-                        })
-                        .context("failed to update weak reference, SshRemoteClient lost?")??;
+                    let result = this.update(&mut cx, |this, mut cx| {
+                        this.handle_heartbeat_result(missed_heartbeats, &mut cx)
+                    })?;
+                    if result.is_break() {
                         return Ok(());
                     }
                 }
@@ -458,6 +691,34 @@ impl SshRemoteClient {
         })
     }
 
+    fn handle_heartbeat_result(
+        &mut self,
+        missed_heartbeats: usize,
+        cx: &mut ModelContext<Self>,
+    ) -> ControlFlow<()> {
+        let state = self.state.lock().take().unwrap();
+        self.state.lock().replace(if missed_heartbeats > 0 {
+            state.heartbeat_missed()
+        } else {
+            state.heartbeat_recovered()
+        });
+        cx.notify();
+
+        if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
+            log::error!(
+                "Missed last {} heartbeats. Reconnecting...",
+                missed_heartbeats
+            );
+
+            self.reconnect(cx)
+                .context("failed to start reconnect process after missing heartbeats")
+                .log_err();
+            ControlFlow::Break(())
+        } else {
+            ControlFlow::Continue(())
+        }
+    }
+
     fn multiplex(
         this: WeakModel<Self>,
         mut ssh_proxy_process: Child,
@@ -611,10 +872,11 @@ impl SshRemoteClient {
     }
 
     pub fn ssh_args(&self) -> Option<Vec<String>> {
-        let state = self.inner_state.lock();
-        state
+        self.state
+            .lock()
             .as_ref()
-            .map(|state| state.ssh_connection.socket.ssh_args())
+            .and_then(|state| state.ssh_connection())
+            .map(|ssh_connection| ssh_connection.socket.ssh_args())
     }
 
     pub fn to_proto_client(&self) -> AnyProtoClient {
@@ -625,8 +887,12 @@ impl SshRemoteClient {
         self.connection_options.connection_string()
     }
 
-    pub fn is_reconnect_underway(&self) -> bool {
-        maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
+    pub fn connection_state(&self) -> ConnectionState {
+        self.state
+            .lock()
+            .as_ref()
+            .map(ConnectionState::from)
+            .unwrap_or(ConnectionState::Disconnected)
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -646,7 +912,7 @@ impl SshRemoteClient {
                     client,
                     unique_identifier: "fake".to_string(),
                     connection_options: SshConnectionOptions::default(),
-                    inner_state: Arc::new(Mutex::new(None)),
+                    state: Arc::new(Mutex::new(None)),
                 })
             }),
             server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
@@ -1046,6 +1312,20 @@ impl ChannelClient {
         }
     }
 
+    pub async fn ping(&self, timeout: Duration) -> Result<()> {
+        smol::future::or(
+            async {
+                self.request(proto::Ping {}).await?;
+                Ok(())
+            },
+            async {
+                smol::Timer::after(timeout).await;
+                Err(anyhow!("Timeout detected"))
+            },
+        )
+        .await
+    }
+
     pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
         log::debug!("ssh send name:{}", T::NAME);
         self.send_dynamic(payload.into_envelope(0, None, None))

crates/title_bar/Cargo.toml 🔗

@@ -41,6 +41,7 @@ gpui.workspace = true
 notifications.workspace = true
 project.workspace = true
 recent_projects.workspace = true
+remote.workspace = true
 rpc.workspace = true
 serde.workspace = true
 smallvec.workspace = true

crates/title_bar/src/title_bar.rs 🔗

@@ -265,10 +265,12 @@ impl TitleBar {
     fn render_ssh_project_host(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
         let host = self.project.read(cx).ssh_connection_string(cx)?;
         let meta = SharedString::from(format!("Connected to: {host}"));
-        let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? {
-            Color::Success
-        } else {
-            Color::Warning
+        let indicator_color = match self.project.read(cx).ssh_connection_state(cx)? {
+            remote::ConnectionState::Connecting => Color::Info,
+            remote::ConnectionState::Connected => Color::Success,
+            remote::ConnectionState::HeartbeatMissed => Color::Warning,
+            remote::ConnectionState::Reconnecting => Color::Warning,
+            remote::ConnectionState::Disconnected => Color::Error,
         };
         let indicator = div()
             .absolute()