Introduce client-side timeout when trying to connect

Antonio Scandurra created

Change summary

crates/client/src/client.rs | 131 +++++++++++++++++++++++++++++++-------
crates/client/src/test.rs   |  10 +-
2 files changed, 112 insertions(+), 29 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -53,6 +53,8 @@ lazy_static! {
 }
 
 pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
+pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
+pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
 
 actions!(client, [Authenticate]);
 
@@ -330,7 +332,7 @@ impl Client {
                 let reconnect_interval = state.reconnect_interval;
                 state._reconnect_task = Some(cx.spawn(|cx| async move {
                     let mut rng = StdRng::from_entropy();
-                    let mut delay = Duration::from_millis(100);
+                    let mut delay = INITIAL_RECONNECTION_DELAY;
                     while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                         log::error!("failed to connect {}", error);
                         if matches!(*this.status().borrow(), Status::ConnectionError) {
@@ -661,33 +663,42 @@ impl Client {
             self.set_status(Status::Reconnecting, cx);
         }
 
-        match self.establish_connection(&credentials, cx).await {
-            Ok(conn) => {
-                self.state.write().credentials = Some(credentials.clone());
-                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
-                    write_credentials_to_keychain(&credentials, cx).log_err();
-                }
-                self.set_connection(conn, cx).await;
-                Ok(())
-            }
-            Err(EstablishConnectionError::Unauthorized) => {
-                self.state.write().credentials.take();
-                if read_from_keychain {
-                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
-                    self.set_status(Status::SignedOut, cx);
-                    self.authenticate_and_connect(false, cx).await
-                } else {
-                    self.set_status(Status::ConnectionError, cx);
-                    Err(EstablishConnectionError::Unauthorized)?
+        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
+        futures::select_biased! {
+            connection = self.establish_connection(&credentials, cx).fuse() => {
+                match connection {
+                    Ok(conn) => {
+                        self.state.write().credentials = Some(credentials.clone());
+                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
+                            write_credentials_to_keychain(&credentials, cx).log_err();
+                        }
+                        self.set_connection(conn, cx).await;
+                        Ok(())
+                    }
+                    Err(EstablishConnectionError::Unauthorized) => {
+                        self.state.write().credentials.take();
+                        if read_from_keychain {
+                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+                            self.set_status(Status::SignedOut, cx);
+                            self.authenticate_and_connect(false, cx).await
+                        } else {
+                            self.set_status(Status::ConnectionError, cx);
+                            Err(EstablishConnectionError::Unauthorized)?
+                        }
+                    }
+                    Err(EstablishConnectionError::UpgradeRequired) => {
+                        self.set_status(Status::UpgradeRequired, cx);
+                        Err(EstablishConnectionError::UpgradeRequired)?
+                    }
+                    Err(error) => {
+                        self.set_status(Status::ConnectionError, cx);
+                        Err(error)?
+                    }
                 }
             }
-            Err(EstablishConnectionError::UpgradeRequired) => {
-                self.set_status(Status::UpgradeRequired, cx);
-                Err(EstablishConnectionError::UpgradeRequired)?
-            }
-            Err(error) => {
+            _ = timeout => {
                 self.set_status(Status::ConnectionError, cx);
-                Err(error)?
+                Err(anyhow!("timed out trying to establish connection"))
             }
         }
     }
@@ -1169,6 +1180,76 @@ mod tests {
         assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+        deterministic.forbid_parking();
+
+        let user_id = 5;
+        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+        let mut status = client.status();
+
+        // Time out when client tries to connect.
+        client.override_authenticate(move |cx| {
+            cx.foreground().spawn(async move {
+                Ok(Credentials {
+                    user_id,
+                    access_token: "token".into(),
+                })
+            })
+        });
+        client.override_establish_connection(|_, cx| {
+            cx.foreground().spawn(async move {
+                future::pending::<()>().await;
+                unreachable!()
+            })
+        });
+        let auth_and_connect = cx.spawn({
+            let client = client.clone();
+            |cx| async move { client.authenticate_and_connect(false, &cx).await }
+        });
+        deterministic.run_until_parked();
+        assert!(matches!(status.next().await, Some(Status::Connecting)));
+
+        deterministic.advance_clock(CONNECTION_TIMEOUT);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::ConnectionError { .. })
+        ));
+        auth_and_connect.await.unwrap_err();
+
+        // Allow the connection to be established.
+        let server = FakeServer::for_client(user_id, &client, cx).await;
+        assert!(matches!(
+            status.next().await,
+            Some(Status::Connected { .. })
+        ));
+
+        // Disconnect client.
+        server.forbid_connections();
+        server.disconnect();
+        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
+
+        // Time out when re-establishing the connection.
+        server.allow_connections();
+        client.override_establish_connection(|_, cx| {
+            cx.foreground().spawn(async move {
+                future::pending::<()>().await;
+                unreachable!()
+            })
+        });
+        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::Reconnecting { .. })
+        ));
+
+        deterministic.advance_clock(CONNECTION_TIMEOUT);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::ReconnectionError { .. })
+        ));
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_authenticating_more_than_once(
         cx: &mut TestAppContext,

crates/client/src/test.rs 🔗

@@ -101,10 +101,12 @@ impl FakeServer {
     }
 
     pub fn disconnect(&self) {
-        self.peer.disconnect(self.connection_id());
-        let mut state = self.state.lock();
-        state.connection_id.take();
-        state.incoming.take();
+        if self.state.lock().connection_id.is_some() {
+            self.peer.disconnect(self.connection_id());
+            let mut state = self.state.lock();
+            state.connection_id.take();
+            state.incoming.take();
+        }
     }
 
     pub fn auth_count(&self) -> usize {