Merge pull request #1729 from zed-industries/connection-timeout

Antonio Scandurra created

Introduce client-side timeout when trying to connect

Change summary

crates/client/src/client.rs | 135 ++++++++++++++++++++++++++++++--------
crates/client/src/test.rs   |  12 ++-
crates/collab/src/rpc.rs    |   3 
crates/rpc/src/peer.rs      |  55 ++++++---------
4 files changed, 137 insertions(+), 68 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -53,6 +53,8 @@ lazy_static! {
 }
 
 pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
+pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
+pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
 
 actions!(client, [Authenticate]);
 
@@ -330,7 +332,7 @@ impl Client {
                 let reconnect_interval = state.reconnect_interval;
                 state._reconnect_task = Some(cx.spawn(|cx| async move {
                     let mut rng = StdRng::from_entropy();
-                    let mut delay = Duration::from_millis(100);
+                    let mut delay = INITIAL_RECONNECTION_DELAY;
                     while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                         log::error!("failed to connect {}", error);
                         if matches!(*this.status().borrow(), Status::ConnectionError) {
@@ -661,44 +663,51 @@ impl Client {
             self.set_status(Status::Reconnecting, cx);
         }
 
-        match self.establish_connection(&credentials, cx).await {
-            Ok(conn) => {
-                self.state.write().credentials = Some(credentials.clone());
-                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
-                    write_credentials_to_keychain(&credentials, cx).log_err();
-                }
-                self.set_connection(conn, cx).await;
-                Ok(())
-            }
-            Err(EstablishConnectionError::Unauthorized) => {
-                self.state.write().credentials.take();
-                if read_from_keychain {
-                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
-                    self.set_status(Status::SignedOut, cx);
-                    self.authenticate_and_connect(false, cx).await
-                } else {
-                    self.set_status(Status::ConnectionError, cx);
-                    Err(EstablishConnectionError::Unauthorized)?
+        futures::select_biased! {
+            connection = self.establish_connection(&credentials, cx).fuse() => {
+                match connection {
+                    Ok(conn) => {
+                        self.state.write().credentials = Some(credentials.clone());
+                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
+                            write_credentials_to_keychain(&credentials, cx).log_err();
+                        }
+                        self.set_connection(conn, cx);
+                        Ok(())
+                    }
+                    Err(EstablishConnectionError::Unauthorized) => {
+                        self.state.write().credentials.take();
+                        if read_from_keychain {
+                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+                            self.set_status(Status::SignedOut, cx);
+                            self.authenticate_and_connect(false, cx).await
+                        } else {
+                            self.set_status(Status::ConnectionError, cx);
+                            Err(EstablishConnectionError::Unauthorized)?
+                        }
+                    }
+                    Err(EstablishConnectionError::UpgradeRequired) => {
+                        self.set_status(Status::UpgradeRequired, cx);
+                        Err(EstablishConnectionError::UpgradeRequired)?
+                    }
+                    Err(error) => {
+                        self.set_status(Status::ConnectionError, cx);
+                        Err(error)?
+                    }
                 }
             }
-            Err(EstablishConnectionError::UpgradeRequired) => {
-                self.set_status(Status::UpgradeRequired, cx);
-                Err(EstablishConnectionError::UpgradeRequired)?
-            }
-            Err(error) => {
+            _ = cx.background().timer(CONNECTION_TIMEOUT).fuse() => {
                 self.set_status(Status::ConnectionError, cx);
-                Err(error)?
+                Err(anyhow!("timed out trying to establish connection"))
             }
         }
     }
 
-    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
+    fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
         let executor = cx.background();
         log::info!("add connection to peer");
         let (connection_id, handle_io, mut incoming) = self
             .peer
-            .add_connection(conn, move |duration| executor.timer(duration))
-            .await;
+            .add_connection(conn, move |duration| executor.timer(duration));
         log::info!("set status to connected {}", connection_id);
         self.set_status(Status::Connected { connection_id }, cx);
         cx.foreground()
@@ -1169,6 +1178,76 @@ mod tests {
         assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+        deterministic.forbid_parking();
+
+        let user_id = 5;
+        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+        let mut status = client.status();
+
+        // Time out when client tries to connect.
+        client.override_authenticate(move |cx| {
+            cx.foreground().spawn(async move {
+                Ok(Credentials {
+                    user_id,
+                    access_token: "token".into(),
+                })
+            })
+        });
+        client.override_establish_connection(|_, cx| {
+            cx.foreground().spawn(async move {
+                future::pending::<()>().await;
+                unreachable!()
+            })
+        });
+        let auth_and_connect = cx.spawn({
+            let client = client.clone();
+            |cx| async move { client.authenticate_and_connect(false, &cx).await }
+        });
+        deterministic.run_until_parked();
+        assert!(matches!(status.next().await, Some(Status::Connecting)));
+
+        deterministic.advance_clock(CONNECTION_TIMEOUT);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::ConnectionError { .. })
+        ));
+        auth_and_connect.await.unwrap_err();
+
+        // Allow the connection to be established.
+        let server = FakeServer::for_client(user_id, &client, cx).await;
+        assert!(matches!(
+            status.next().await,
+            Some(Status::Connected { .. })
+        ));
+
+        // Disconnect client.
+        server.forbid_connections();
+        server.disconnect();
+        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
+
+        // Time out when re-establishing the connection.
+        server.allow_connections();
+        client.override_establish_connection(|_, cx| {
+            cx.foreground().spawn(async move {
+                future::pending::<()>().await;
+                unreachable!()
+            })
+        });
+        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::Reconnecting { .. })
+        ));
+
+        deterministic.advance_clock(CONNECTION_TIMEOUT);
+        assert!(matches!(
+            status.next().await,
+            Some(Status::ReconnectionError { .. })
+        ));
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_authenticating_more_than_once(
         cx: &mut TestAppContext,

crates/client/src/test.rs 🔗

@@ -82,7 +82,7 @@ impl FakeServer {
 
                         let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
                         let (connection_id, io, incoming) =
-                            peer.add_test_connection(server_conn, cx.background()).await;
+                            peer.add_test_connection(server_conn, cx.background());
                         cx.background().spawn(io).detach();
                         let mut state = state.lock();
                         state.connection_id = Some(connection_id);
@@ -101,10 +101,12 @@ impl FakeServer {
     }
 
     pub fn disconnect(&self) {
-        self.peer.disconnect(self.connection_id());
-        let mut state = self.state.lock();
-        state.connection_id.take();
-        state.incoming.take();
+        if self.state.lock().connection_id.is_some() {
+            self.peer.disconnect(self.connection_id());
+            let mut state = self.state.lock();
+            state.connection_id.take();
+            state.incoming.take();
+        }
     }
 
     pub fn auth_count(&self) -> usize {

crates/collab/src/rpc.rs 🔗

@@ -365,8 +365,7 @@ impl Server {
                             timer.await;
                         }
                     }
-                })
-                .await;
+                });
 
             tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 

crates/rpc/src/peer.rs 🔗

@@ -113,7 +113,7 @@ impl Peer {
     }
 
     #[instrument(skip_all)]
-    pub async fn add_connection<F, Fut, Out>(
+    pub fn add_connection<F, Fut, Out>(
         self: &Arc<Self>,
         connection: Connection,
         create_timer: F,
@@ -326,7 +326,7 @@ impl Peer {
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub async fn add_test_connection(
+    pub fn add_test_connection(
         self: &Arc<Self>,
         connection: Connection,
         executor: Arc<gpui::executor::Background>,
@@ -337,7 +337,6 @@ impl Peer {
     ) {
         let executor = executor.clone();
         self.add_connection(connection, move |duration| executor.timer(duration))
-            .await
     }
 
     pub fn disconnect(&self, connection_id: ConnectionId) {
@@ -522,21 +521,17 @@ mod tests {
 
         let (client1_to_server_conn, server_to_client_1_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client1_conn_id, io_task1, client1_incoming) = client1
-            .add_test_connection(client1_to_server_conn, cx.background())
-            .await;
-        let (_, io_task2, server_incoming1) = server
-            .add_test_connection(server_to_client_1_conn, cx.background())
-            .await;
+        let (client1_conn_id, io_task1, client1_incoming) =
+            client1.add_test_connection(client1_to_server_conn, cx.background());
+        let (_, io_task2, server_incoming1) =
+            server.add_test_connection(server_to_client_1_conn, cx.background());
 
         let (client2_to_server_conn, server_to_client_2_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client2_conn_id, io_task3, client2_incoming) = client2
-            .add_test_connection(client2_to_server_conn, cx.background())
-            .await;
-        let (_, io_task4, server_incoming2) = server
-            .add_test_connection(server_to_client_2_conn, cx.background())
-            .await;
+        let (client2_conn_id, io_task3, client2_incoming) =
+            client2.add_test_connection(client2_to_server_conn, cx.background());
+        let (_, io_task4, server_incoming2) =
+            server.add_test_connection(server_to_client_2_conn, cx.background());
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -619,12 +614,10 @@ mod tests {
 
         let (client_to_server_conn, server_to_client_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
-            .add_test_connection(client_to_server_conn, cx.background())
-            .await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
-            .add_test_connection(server_to_client_conn, cx.background())
-            .await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) =
+            client.add_test_connection(client_to_server_conn, cx.background());
+        let (server_to_client_conn_id, io_task2, mut server_incoming) =
+            server.add_test_connection(server_to_client_conn, cx.background());
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -719,12 +712,10 @@ mod tests {
 
         let (client_to_server_conn, server_to_client_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
-            .add_test_connection(client_to_server_conn, cx.background())
-            .await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
-            .add_test_connection(server_to_client_conn, cx.background())
-            .await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) =
+            client.add_test_connection(client_to_server_conn, cx.background());
+        let (server_to_client_conn_id, io_task2, mut server_incoming) =
+            server.add_test_connection(server_to_client_conn, cx.background());
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -832,9 +823,8 @@ mod tests {
         let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client
-            .add_test_connection(client_conn, cx.background())
-            .await;
+        let (connection_id, io_handler, mut incoming) =
+            client.add_test_connection(client_conn, cx.background());
 
         let (io_ended_tx, io_ended_rx) = oneshot::channel();
         executor
@@ -868,9 +858,8 @@ mod tests {
         let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client
-            .add_test_connection(client_conn, cx.background())
-            .await;
+        let (connection_id, io_handler, mut incoming) =
+            client.add_test_connection(client_conn, cx.background());
         executor.spawn(io_handler).detach();
         executor
             .spawn(async move { incoming.next().await })