Ensure client reconnects after erroring during the handshake (#31278)

Antonio Scandurra created

Release Notes:

- Fixed a bug that prevented Zed from reconnecting after erroring during
the initial handshake with the server.

Change summary

crates/client/src/client.rs                  | 10 ++++++
crates/collab/src/db.rs                      | 17 ++++++++++--
crates/collab/src/db/tests.rs                | 26 ++++++++++++++-----
crates/collab/src/tests/integration_tests.rs | 29 ++++++++++++++++++++++
crates/collab/src/tests/test_server.rs       | 11 ++++++--
5 files changed, 79 insertions(+), 14 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -905,7 +905,15 @@ impl Client {
                         }
 
                         futures::select_biased! {
-                            result = self.set_connection(conn, cx).fuse() => ConnectionResult::Result(result.context("client auth and connect")),
+                            result = self.set_connection(conn, cx).fuse() => {
+                                match result.context("client auth and connect") {
+                                    Ok(()) => ConnectionResult::Result(Ok(())),
+                                    Err(err) => {
+                                        self.set_status(Status::ConnectionError, cx);
+                                        ConnectionResult::Result(Err(err))
+                                    },
+                                }
+                            },
                             _ = timeout => {
                                 self.set_status(Status::ConnectionError, cx);
                                 ConnectionResult::Timeout

crates/collab/src/db.rs 🔗

@@ -56,6 +56,12 @@ pub use sea_orm::ConnectOptions;
 pub use tables::user::Model as User;
 pub use tables::*;
 
+#[cfg(test)]
+pub struct DatabaseTestOptions {
+    pub runtime: tokio::runtime::Runtime,
+    pub query_failure_probability: parking_lot::Mutex<f64>,
+}
+
 /// Database gives you a handle that lets you access the database.
 /// It handles pooling internally.
 pub struct Database {
@@ -68,7 +74,7 @@ pub struct Database {
     notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
     notification_kinds_by_name: HashMap<String, NotificationKindId>,
     #[cfg(test)]
-    runtime: Option<tokio::runtime::Runtime>,
+    test_options: Option<DatabaseTestOptions>,
 }
 
 // The `Database` type has so many methods that its impl blocks are split into
@@ -87,7 +93,7 @@ impl Database {
             notification_kinds_by_name: HashMap::default(),
             executor,
             #[cfg(test)]
-            runtime: None,
+            test_options: None,
         })
     }
 
@@ -355,11 +361,16 @@ impl Database {
     {
         #[cfg(test)]
         {
+            let test_options = self.test_options.as_ref().unwrap();
             if let Executor::Deterministic(executor) = &self.executor {
                 executor.simulate_random_delay().await;
+                let fail_probability = *test_options.query_failure_probability.lock();
+                if executor.rng().gen_bool(fail_probability) {
+                    return Err(anyhow!("simulated query failure"))?;
+                }
             }
 
-            self.runtime.as_ref().unwrap().block_on(future)
+            test_options.runtime.block_on(future)
         }
 
         #[cfg(not(test))]

crates/collab/src/db/tests.rs 🔗

@@ -30,7 +30,7 @@ pub struct TestDb {
 }
 
 impl TestDb {
-    pub fn sqlite(background: BackgroundExecutor) -> Self {
+    pub fn sqlite(executor: BackgroundExecutor) -> Self {
         let url = "sqlite::memory:";
         let runtime = tokio::runtime::Builder::new_current_thread()
             .enable_io()
@@ -41,7 +41,7 @@ impl TestDb {
         let mut db = runtime.block_on(async {
             let mut options = ConnectOptions::new(url);
             options.max_connections(5);
-            let mut db = Database::new(options, Executor::Deterministic(background))
+            let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
                 .await
                 .unwrap();
             let sql = include_str!(concat!(
@@ -59,7 +59,10 @@ impl TestDb {
             db
         });
 
-        db.runtime = Some(runtime);
+        db.test_options = Some(DatabaseTestOptions {
+            runtime,
+            query_failure_probability: parking_lot::Mutex::new(0.0),
+        });
 
         Self {
             db: Some(Arc::new(db)),
@@ -67,7 +70,7 @@ impl TestDb {
         }
     }
 
-    pub fn postgres(background: BackgroundExecutor) -> Self {
+    pub fn postgres(executor: BackgroundExecutor) -> Self {
         static LOCK: Mutex<()> = Mutex::new(());
 
         let _guard = LOCK.lock();
@@ -90,7 +93,7 @@ impl TestDb {
             options
                 .max_connections(5)
                 .idle_timeout(Duration::from_secs(0));
-            let mut db = Database::new(options, Executor::Deterministic(background))
+            let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
                 .await
                 .unwrap();
             let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
@@ -101,7 +104,10 @@ impl TestDb {
             db
         });
 
-        db.runtime = Some(runtime);
+        db.test_options = Some(DatabaseTestOptions {
+            runtime,
+            query_failure_probability: parking_lot::Mutex::new(0.0),
+        });
 
         Self {
             db: Some(Arc::new(db)),
@@ -112,6 +118,12 @@ impl TestDb {
     pub fn db(&self) -> &Arc<Database> {
         self.db.as_ref().unwrap()
     }
+
+    pub fn set_query_failure_probability(&self, probability: f64) {
+        let database = self.db.as_ref().unwrap();
+        let test_options = database.test_options.as_ref().unwrap();
+        *test_options.query_failure_probability.lock() = probability;
+    }
 }
 
 #[macro_export]
@@ -136,7 +148,7 @@ impl Drop for TestDb {
     fn drop(&mut self) {
         let db = self.db.take().unwrap();
         if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
-            db.runtime.as_ref().unwrap().block_on(async {
+            db.test_options.as_ref().unwrap().runtime.block_on(async {
                 use util::ResultExt;
                 let query = "
                         SELECT pg_terminate_backend(pg_stat_activity.pid)

crates/collab/src/tests/integration_tests.rs 🔗

@@ -61,6 +61,35 @@ fn init_logger() {
     }
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_database_failure_during_client_reconnection(
+    executor: BackgroundExecutor,
+    cx: &mut TestAppContext,
+) {
+    let mut server = TestServer::start(executor.clone()).await;
+    let client = server.create_client(cx, "user_a").await;
+
+    // Keep disconnecting the client until a database failure prevents it from
+    // reconnecting.
+    server.test_db.set_query_failure_probability(0.3);
+    loop {
+        server.disconnect_client(client.peer_id().unwrap());
+        executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+        if !client.status().borrow().is_connected() {
+            break;
+        }
+    }
+
+    // Make the database healthy again and ensure the client can finally connect.
+    server.test_db.set_query_failure_probability(0.);
+    executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+    assert!(
+        matches!(*client.status().borrow(), client::Status::Connected { .. }),
+        "status was {:?}",
+        *client.status().borrow()
+    );
+}
+
 #[gpui::test(iterations = 10)]
 async fn test_basic_calls(
     executor: BackgroundExecutor,

crates/collab/src/tests/test_server.rs 🔗

@@ -52,11 +52,11 @@ use livekit_client::test::TestServer as LivekitTestServer;
 pub struct TestServer {
     pub app_state: Arc<AppState>,
     pub test_livekit_server: Arc<LivekitTestServer>,
+    pub test_db: TestDb,
     server: Arc<Server>,
     next_github_user_id: i32,
     connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
     forbid_connections: Arc<AtomicBool>,
-    _test_db: TestDb,
 }
 
 pub struct TestClient {
@@ -117,7 +117,7 @@ impl TestServer {
             connection_killers: Default::default(),
             forbid_connections: Default::default(),
             next_github_user_id: 0,
-            _test_db: test_db,
+            test_db,
             test_livekit_server: livekit_server,
         }
     }
@@ -241,7 +241,12 @@ impl TestServer {
                         let user = db
                             .get_user_by_id(user_id)
                             .await
-                            .expect("retrieving user failed")
+                            .map_err(|e| {
+                                EstablishConnectionError::Other(anyhow!(
+                                    "retrieving user failed: {}",
+                                    e
+                                ))
+                            })?
                             .unwrap();
                         cx.background_spawn(server.handle_connection(
                             server_conn,