Allow re-joining room after server restarts

Antonio Scandurra created

Change summary

crates/call/src/room.rs                |  4 
crates/collab/src/db.rs                | 70 +++++++++++++++-------
crates/collab/src/integration_tests.rs | 85 ++++++++++++++++++++++++++-
crates/collab/src/main.rs              |  2 
crates/collab/src/rpc.rs               | 11 +++
5 files changed, 138 insertions(+), 34 deletions(-)

Detailed changes

crates/call/src/room.rs 🔗

@@ -15,7 +15,7 @@ use project::Project;
 use std::{mem, sync::Arc, time::Duration};
 use util::{post_inc, ResultExt};
 
-pub const RECONNECTION_TIMEOUT: Duration = client::RECEIVE_TIMEOUT;
+pub const RECONNECT_TIMEOUT: Duration = client::RECEIVE_TIMEOUT;
 
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum Event {
@@ -262,7 +262,7 @@ impl Room {
                     });
 
                 // Wait for client to re-establish a connection to the server.
-                let mut reconnection_timeout = cx.background().timer(RECONNECTION_TIMEOUT).fuse();
+                let mut reconnection_timeout = cx.background().timer(RECONNECT_TIMEOUT).fuse();
                 let client_reconnection = async {
                     loop {
                         if let Some(status) = client_status.next().await {

crates/collab/src/db.rs 🔗

@@ -21,6 +21,7 @@ use dashmap::DashMap;
 use futures::StreamExt;
 use hyper::StatusCode;
 use rpc::{proto, ConnectionId};
+use sea_orm::Condition;
 pub use sea_orm::ConnectOptions;
 use sea_orm::{
     entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
@@ -47,7 +48,7 @@ pub struct Database {
     background: Option<std::sync::Arc<gpui::executor::Background>>,
     #[cfg(test)]
     runtime: Option<tokio::runtime::Runtime>,
-    epoch: Uuid,
+    epoch: parking_lot::RwLock<Uuid>,
 }
 
 impl Database {
@@ -60,10 +61,20 @@ impl Database {
             background: None,
             #[cfg(test)]
             runtime: None,
-            epoch: Uuid::new_v4(),
+            epoch: parking_lot::RwLock::new(Uuid::new_v4()),
         })
     }
 
+    #[cfg(test)]
+    pub fn reset(&self) {
+        self.rooms.clear();
+        *self.epoch.write() = Uuid::new_v4();
+    }
+
+    fn epoch(&self) -> Uuid {
+        *self.epoch.read()
+    }
+
     pub async fn migrate(
         &self,
         migrations_path: &Path,
@@ -105,24 +116,31 @@ impl Database {
         Ok(new_migrations)
     }
 
-    pub async fn clear_stale_data(&self) -> Result<()> {
+    pub async fn delete_stale_projects(&self) -> Result<()> {
         self.transaction(|tx| async move {
             project_collaborator::Entity::delete_many()
-                .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch))
+                .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch()))
                 .exec(&*tx)
                 .await?;
+            project::Entity::delete_many()
+                .filter(project::Column::HostConnectionEpoch.ne(self.epoch()))
+                .exec(&*tx)
+                .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn delete_stale_rooms(&self) -> Result<()> {
+        self.transaction(|tx| async move {
             room_participant::Entity::delete_many()
                 .filter(
                     room_participant::Column::AnsweringConnectionEpoch
-                        .ne(self.epoch)
-                        .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)),
+                        .ne(self.epoch())
+                        .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch())),
                 )
                 .exec(&*tx)
                 .await?;
-            project::Entity::delete_many()
-                .filter(project::Column::HostConnectionEpoch.ne(self.epoch))
-                .exec(&*tx)
-                .await?;
             room::Entity::delete_many()
                 .filter(
                     room::Column::Id.not_in_subquery(
@@ -1033,11 +1051,11 @@ impl Database {
                 room_id: ActiveValue::set(room_id),
                 user_id: ActiveValue::set(user_id),
                 answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
-                answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                answering_connection_epoch: ActiveValue::set(Some(self.epoch())),
                 answering_connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(user_id),
                 calling_connection_id: ActiveValue::set(connection_id.0 as i32),
-                calling_connection_epoch: ActiveValue::set(self.epoch),
+                calling_connection_epoch: ActiveValue::set(self.epoch()),
                 ..Default::default()
             }
             .insert(&*tx)
@@ -1064,7 +1082,7 @@ impl Database {
                 answering_connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(calling_user_id),
                 calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32),
-                calling_connection_epoch: ActiveValue::set(self.epoch),
+                calling_connection_epoch: ActiveValue::set(self.epoch()),
                 initial_project_id: ActiveValue::set(initial_project_id),
                 ..Default::default()
             }
@@ -1174,18 +1192,22 @@ impl Database {
         self.room_transaction(|tx| async move {
             let result = room_participant::Entity::update_many()
                 .filter(
-                    room_participant::Column::RoomId
-                        .eq(room_id)
-                        .and(room_participant::Column::UserId.eq(user_id))
-                        .and(
-                            room_participant::Column::AnsweringConnectionId
-                                .is_null()
-                                .or(room_participant::Column::AnsweringConnectionLost.eq(true)),
+                    Condition::all()
+                        .add(room_participant::Column::RoomId.eq(room_id))
+                        .add(room_participant::Column::UserId.eq(user_id))
+                        .add(
+                            Condition::any()
+                                .add(room_participant::Column::AnsweringConnectionId.is_null())
+                                .add(room_participant::Column::AnsweringConnectionLost.eq(true))
+                                .add(
+                                    room_participant::Column::AnsweringConnectionEpoch
+                                        .ne(self.epoch()),
+                                ),
                         ),
                 )
                 .set(room_participant::ActiveModel {
                     answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
-                    answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                    answering_connection_epoch: ActiveValue::set(Some(self.epoch())),
                     answering_connection_lost: ActiveValue::set(false),
                     ..Default::default()
                 })
@@ -1591,7 +1613,7 @@ impl Database {
                 room_id: ActiveValue::set(participant.room_id),
                 host_user_id: ActiveValue::set(participant.user_id),
                 host_connection_id: ActiveValue::set(connection_id.0 as i32),
-                host_connection_epoch: ActiveValue::set(self.epoch),
+                host_connection_epoch: ActiveValue::set(self.epoch()),
                 ..Default::default()
             }
             .insert(&*tx)
@@ -1616,7 +1638,7 @@ impl Database {
             project_collaborator::ActiveModel {
                 project_id: ActiveValue::set(project.id),
                 connection_id: ActiveValue::set(connection_id.0 as i32),
-                connection_epoch: ActiveValue::set(self.epoch),
+                connection_epoch: ActiveValue::set(self.epoch()),
                 user_id: ActiveValue::set(participant.user_id),
                 replica_id: ActiveValue::set(ReplicaId(0)),
                 is_host: ActiveValue::set(true),
@@ -1930,7 +1952,7 @@ impl Database {
             let new_collaborator = project_collaborator::ActiveModel {
                 project_id: ActiveValue::set(project_id),
                 connection_id: ActiveValue::set(connection_id.0 as i32),
-                connection_epoch: ActiveValue::set(self.epoch),
+                connection_epoch: ActiveValue::set(self.epoch()),
                 user_id: ActiveValue::set(participant.user_id),
                 replica_id: ActiveValue::set(replica_id),
                 is_host: ActiveValue::set(false),

crates/collab/src/integration_tests.rs 🔗

@@ -4,7 +4,6 @@ use crate::{
     rpc::{Server, RECONNECT_TIMEOUT},
     AppState,
 };
-use ::rpc::Peer;
 use anyhow::anyhow;
 use call::{room, ActiveCall, ParticipantLocation, Room};
 use client::{
@@ -365,7 +364,7 @@ async fn test_room_uniqueness(
 }
 
 #[gpui::test(iterations = 10)]
-async fn test_disconnecting_from_room(
+async fn test_client_disconnecting_from_room(
     deterministic: Arc<Deterministic>,
     cx_a: &mut TestAppContext,
     cx_b: &mut TestAppContext,
@@ -516,6 +515,75 @@ async fn test_disconnecting_from_room(
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_server_restarts(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(cx_a.background()).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    server
+        .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+        .await;
+
+    let active_call_a = cx_a.read(ActiveCall::global);
+    let active_call_b = cx_b.read(ActiveCall::global);
+
+    // Call user B from client A.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+    let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+
+    // User B receives the call and joins the room.
+    let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming());
+    incoming_call_b.next().await.unwrap().unwrap();
+    active_call_b
+        .update(cx_b, |call, cx| call.accept_incoming(cx))
+        .await
+        .unwrap();
+    let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone());
+    deterministic.run_until_parked();
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: vec!["user_b".to_string()],
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: vec!["user_a".to_string()],
+            pending: Default::default()
+        }
+    );
+
+    // User A automatically reconnects to the room when the server restarts.
+    server.restart().await;
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: vec!["user_b".to_string()],
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: vec!["user_a".to_string()],
+            pending: Default::default()
+        }
+    );
+}
+
 #[gpui::test(iterations = 10)]
 async fn test_calls_on_multiple_connections(
     deterministic: Arc<Deterministic>,
@@ -5933,7 +6001,6 @@ async fn test_random_collaboration(
 }
 
 struct TestServer {
-    peer: Arc<Peer>,
     app_state: Arc<AppState>,
     server: Arc<Server>,
     connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
@@ -5962,10 +6029,9 @@ impl TestServer {
         )
         .unwrap();
         let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
-        let peer = Peer::new();
         let server = Server::new(app_state.clone());
+        server.start().await.unwrap();
         Self {
-            peer,
             app_state,
             server,
             connection_killers: Default::default(),
@@ -5975,6 +6041,14 @@ impl TestServer {
         }
     }
 
+    async fn restart(&self) {
+        self.forbid_connections();
+        self.server.teardown();
+        self.app_state.db.reset();
+        self.server.start().await.unwrap();
+        self.allow_connections();
+    }
+
     async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
         cx.update(|cx| {
             cx.set_global(HomeDir(Path::new("/tmp/").to_path_buf()));
@@ -6192,7 +6266,6 @@ impl Deref for TestServer {
 
 impl Drop for TestServer {
     fn drop(&mut self) {
-        self.peer.reset();
         self.server.teardown();
         self.test_live_kit_server.teardown().unwrap();
     }

crates/collab/src/main.rs 🔗

@@ -52,12 +52,12 @@ async fn main() -> Result<()> {
             init_tracing(&config);
 
             let state = AppState::new(config).await?;
-            state.db.clear_stale_data().await?;
 
             let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
                 .expect("failed to bind TCP listener");
 
             let rpc_server = collab::rpc::Server::new(state.clone());
+            rpc_server.start().await?;
 
             let app = collab::api::routes(rpc_server.clone(), state.clone())
                 .merge(collab::rpc::routes(rpc_server.clone()))

crates/collab/src/rpc.rs 🔗

@@ -237,7 +237,15 @@ impl Server {
         Arc::new(server)
     }
 
+    pub async fn start(&self) -> Result<()> {
+        self.app_state.db.delete_stale_projects().await?;
+        // TODO: delete stale rooms after timeout.
+        // self.app_state.db.delete_stale_rooms().await?;
+        Ok(())
+    }
+
     pub fn teardown(&self) {
+        self.peer.reset();
         let _ = self.teardown.send(());
     }
 
@@ -339,7 +347,7 @@ impl Server {
         let user_id = user.id;
         let login = user.github_login;
         let span = info_span!("handle connection", %user_id, %login, %address);
-        let teardown = self.teardown.subscribe();
+        let mut teardown = self.teardown.subscribe();
         async move {
             let (connection_id, handle_io, mut incoming_rx) = this
                 .peer
@@ -409,6 +417,7 @@ impl Server {
                 let next_message = incoming_rx.next().fuse();
                 futures::pin_mut!(next_message);
                 futures::select_biased! {
+                    _ = teardown.changed().fuse() => return Ok(()),
                     result = handle_io => {
                         if let Err(error) = result {
                             tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");