Use a synchronous mutex for `ConnectionPool`

Antonio Scandurra created

Change summary

crates/collab/src/integration_tests.rs   |  4 --
crates/collab/src/rpc.rs                 | 31 +++++++++++--------------
crates/collab/src/rpc/connection_pool.rs |  5 ++++
3 files changed, 20 insertions(+), 20 deletions(-)

Detailed changes

crates/collab/src/integration_tests.rs 🔗

@@ -6062,7 +6062,6 @@ async fn test_random_collaboration(
                 let user_connection_ids = server
                     .connection_pool
                     .lock()
-                    .await
                     .user_connection_ids(removed_guest_id)
                     .collect::<Vec<_>>();
                 assert_eq!(user_connection_ids.len(), 1);
@@ -6083,7 +6082,7 @@ async fn test_random_collaboration(
                 }
                 for user_id in &user_ids {
                     let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
-                    let pool = server.connection_pool.lock().await;
+                    let pool = server.connection_pool.lock();
                     for contact in contacts {
                         if let db::Contact::Accepted { user_id, .. } = contact {
                             if pool.is_user_online(user_id) {
@@ -6112,7 +6111,6 @@ async fn test_random_collaboration(
                 let user_connection_ids = server
                     .connection_pool
                     .lock()
-                    .await
                     .user_connection_ids(user_id)
                     .collect::<Vec<_>>();
                 assert_eq!(user_connection_ids.len(), 1);

crates/collab/src/rpc.rs 🔗

@@ -53,7 +53,7 @@ use std::{
     },
     time::Duration,
 };
-use tokio::sync::{watch, Mutex, MutexGuard};
+use tokio::sync::watch;
 use tower::ServiceBuilder;
 use tracing::{info_span, instrument, Instrument};
 
@@ -90,14 +90,14 @@ impl<R: RequestMessage> Response<R> {
 struct Session {
     user_id: UserId,
     connection_id: ConnectionId,
-    db: Arc<Mutex<DbHandle>>,
+    db: Arc<tokio::sync::Mutex<DbHandle>>,
     peer: Arc<Peer>,
-    connection_pool: Arc<Mutex<ConnectionPool>>,
+    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 }
 
 impl Session {
-    async fn db(&self) -> MutexGuard<DbHandle> {
+    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
         #[cfg(test)]
         tokio::task::yield_now().await;
         let guard = self.db.lock().await;
@@ -109,9 +109,7 @@ impl Session {
     async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
         #[cfg(test)]
         tokio::task::yield_now().await;
-        let guard = self.connection_pool.lock().await;
-        #[cfg(test)]
-        tokio::task::yield_now().await;
+        let guard = self.connection_pool.lock();
         ConnectionPoolGuard {
             guard,
             _not_send: PhantomData,
@@ -140,7 +138,7 @@ impl Deref for DbHandle {
 
 pub struct Server {
     peer: Arc<Peer>,
-    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
+    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     app_state: Arc<AppState>,
     executor: Executor,
     handlers: HashMap<TypeId, MessageHandler>,
@@ -148,7 +146,7 @@ pub struct Server {
 }
 
 pub(crate) struct ConnectionPoolGuard<'a> {
-    guard: MutexGuard<'a, ConnectionPool>,
+    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
     _not_send: PhantomData<Rc<()>>,
 }
 
@@ -268,7 +266,7 @@ impl Server {
                     }
 
                     {
-                        let pool = pool.lock().await;
+                        let pool = pool.lock();
                         for canceled_user_id in canceled_calls_to_user_ids {
                             for connection_id in pool.user_connection_ids(canceled_user_id) {
                                 peer.send(
@@ -286,7 +284,7 @@ impl Server {
                         let busy = db.is_user_busy(user_id).await.trace_err();
                         let contacts = db.get_contacts(user_id).await.trace_err();
                         if let Some((busy, contacts)) = busy.zip(contacts) {
-                            let pool = pool.lock().await;
+                            let pool = pool.lock();
                             let updated_contact = contact_for_user(user_id, false, busy, &pool);
                             for contact in contacts {
                                 if let db::Contact::Accepted {
@@ -456,7 +454,7 @@ impl Server {
             ).await?;
 
             {
-                let mut pool = this.connection_pool.lock().await;
+                let mut pool = this.connection_pool.lock();
                 pool.add_connection(connection_id, user_id, user.admin);
                 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 
@@ -475,7 +473,7 @@ impl Server {
             let session = Session {
                 user_id,
                 connection_id,
-                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
+                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                 peer: this.peer.clone(),
                 connection_pool: this.connection_pool.clone(),
                 live_kit_client: this.app_state.live_kit_client.clone()
@@ -550,7 +548,7 @@ impl Server {
     ) -> Result<()> {
         if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
             if let Some(code) = &user.invite_code {
-                let pool = self.connection_pool.lock().await;
+                let pool = self.connection_pool.lock();
                 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
                 for connection_id in pool.user_connection_ids(inviter_id) {
                     self.peer.send(
@@ -576,7 +574,7 @@ impl Server {
     pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
         if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
             if let Some(invite_code) = &user.invite_code {
-                let pool = self.connection_pool.lock().await;
+                let pool = self.connection_pool.lock();
                 for connection_id in pool.user_connection_ids(user_id) {
                     self.peer.send(
                         connection_id,
@@ -597,7 +595,7 @@ impl Server {
     pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
         ServerSnapshot {
             connection_pool: ConnectionPoolGuard {
-                guard: self.connection_pool.lock().await,
+                guard: self.connection_pool.lock(),
                 _not_send: PhantomData,
             },
             peer: &self.peer,
@@ -718,7 +716,6 @@ pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result
     let connections = server
         .connection_pool
         .lock()
-        .await
         .connections()
         .filter(|connection| !connection.admin)
         .count();

crates/collab/src/rpc/connection_pool.rs 🔗

@@ -23,6 +23,11 @@ pub struct Connection {
 }
 
 impl ConnectionPool {
+    pub fn reset(&mut self) {
+        self.connections.clear();
+        self.connected_users.clear();
+    }
+
     #[instrument(skip(self))]
     pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
         self.connections