Rename `Store` to `ConnectionPool`

Antonio Scandurra created

Change summary

crates/collab/src/integration_tests.rs   |  21 +-
crates/collab/src/rpc.rs                 | 167 ++++++++++++++++++-------
crates/collab/src/rpc/connection_pool.rs |  57 --------
3 files changed, 133 insertions(+), 112 deletions(-)

Detailed changes

crates/collab/src/integration_tests.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    db::{NewUserParams, SqliteTestDb as TestDb, UserId},
+    db::{self, NewUserParams, SqliteTestDb as TestDb, UserId},
     rpc::{Executor, Server},
     AppState,
 };
@@ -5469,18 +5469,15 @@ async fn test_random_collaboration(
                 }
                 for user_id in &user_ids {
                     let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
-                    let contacts = server
-                        .store
-                        .lock()
-                        .await
-                        .build_initial_contacts_update(contacts)
-                        .contacts;
+                    let pool = server.connection_pool.lock().await;
                     for contact in contacts {
-                        if contact.online {
-                            assert_ne!(
-                                contact.user_id, removed_guest_id.0 as u64,
-                                "removed guest is still a contact of another peer"
-                            );
+                        if let db::Contact::Accepted { user_id, .. } = contact {
+                            if pool.is_user_online(user_id) {
+                                assert_ne!(
+                                    user_id, removed_guest_id,
+                                    "removed guest is still a contact of another peer"
+                                );
+                            }
                         }
                     }
                 }

crates/collab/src/rpc.rs 🔗

@@ -1,4 +1,4 @@
-mod store;
+mod connection_pool;
 
 use crate::{
     auth,
@@ -23,6 +23,7 @@ use axum::{
     Extension, Router, TypedHeader,
 };
 use collections::{HashMap, HashSet};
+pub use connection_pool::ConnectionPool;
 use futures::{
     channel::oneshot,
     future::{self, BoxFuture},
@@ -49,7 +50,6 @@ use std::{
     },
     time::Duration,
 };
-pub use store::Store;
 use tokio::{
     sync::{Mutex, MutexGuard},
     time::Sleep,
@@ -103,7 +103,7 @@ impl<R: RequestMessage> Response<R> {
 
 pub struct Server {
     peer: Arc<Peer>,
-    pub(crate) store: Mutex<Store>,
+    pub(crate) connection_pool: Mutex<ConnectionPool>,
     app_state: Arc<AppState>,
     handlers: HashMap<TypeId, MessageHandler>,
 }
@@ -117,8 +117,8 @@ pub trait Executor: Send + Clone {
 #[derive(Clone)]
 pub struct RealExecutor;
 
-pub(crate) struct StoreGuard<'a> {
-    guard: MutexGuard<'a, Store>,
+pub(crate) struct ConnectionPoolGuard<'a> {
+    guard: MutexGuard<'a, ConnectionPool>,
     _not_send: PhantomData<Rc<()>>,
 }
 
@@ -126,7 +126,7 @@ pub(crate) struct StoreGuard<'a> {
 pub struct ServerSnapshot<'a> {
     peer: &'a Peer,
     #[serde(serialize_with = "serialize_deref")]
-    store: StoreGuard<'a>,
+    connection_pool: ConnectionPoolGuard<'a>,
 }
 
 pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
@@ -143,7 +143,7 @@ impl Server {
         let mut server = Self {
             peer: Peer::new(),
             app_state,
-            store: Default::default(),
+            connection_pool: Default::default(),
             handlers: Default::default(),
         };
 
@@ -257,8 +257,6 @@ impl Server {
         self
     }
 
-    /// Handle a request while holding a lock to the store. This is useful when we're registering
-    /// a connection but we want to respond on the connection before anybody else can send on it.
     fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
         F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
@@ -342,9 +340,9 @@ impl Server {
             ).await?;
 
             {
-                let mut store = this.store().await;
-                store.add_connection(connection_id, user_id, user.admin);
-                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
+                let mut pool = this.connection_pool().await;
+                pool.add_connection(connection_id, user_id, user.admin);
+                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 
                 if let Some((code, count)) = invite_code {
                     this.peer.send(connection_id, proto::UpdateInviteInfo {
@@ -435,9 +433,9 @@ impl Server {
     ) -> Result<()> {
         self.peer.disconnect(connection_id);
         let decline_calls = {
-            let mut store = self.store().await;
-            store.remove_connection(connection_id)?;
-            let mut connections = store.user_connection_ids(user_id);
+            let mut pool = self.connection_pool().await;
+            pool.remove_connection(connection_id)?;
+            let mut connections = pool.user_connection_ids(user_id);
             connections.next().is_none()
         };
 
@@ -468,9 +466,9 @@ impl Server {
     ) -> Result<()> {
         if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
             if let Some(code) = &user.invite_code {
-                let store = self.store().await;
-                let invitee_contact = store.contact_for_user(invitee_id, true, false);
-                for connection_id in store.user_connection_ids(inviter_id) {
+                let pool = self.connection_pool().await;
+                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
+                for connection_id in pool.user_connection_ids(inviter_id) {
                     self.peer.send(
                         connection_id,
                         proto::UpdateContacts {
@@ -494,8 +492,8 @@ impl Server {
     pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
         if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
             if let Some(invite_code) = &user.invite_code {
-                let store = self.store().await;
-                for connection_id in store.user_connection_ids(user_id) {
+                let pool = self.connection_pool().await;
+                for connection_id in pool.user_connection_ids(user_id) {
                     self.peer.send(
                         connection_id,
                         proto::UpdateInviteInfo {
@@ -582,7 +580,11 @@ impl Server {
                 session.connection_id,
             )
             .await?;
-        for connection_id in self.store().await.user_connection_ids(session.user_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(session.user_id)
+        {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -672,9 +674,9 @@ impl Server {
 
         self.room_updated(&left_room.room);
         {
-            let store = self.store().await;
+            let pool = self.connection_pool().await;
             for canceled_user_id in left_room.canceled_calls_to_user_ids {
-                for connection_id in store.user_connection_ids(canceled_user_id) {
+                for connection_id in pool.user_connection_ids(canceled_user_id) {
                     self.peer
                         .send(connection_id, proto::CallCanceled {})
                         .trace_err();
@@ -742,7 +744,7 @@ impl Server {
         self.update_user_contacts(called_user_id).await?;
 
         let mut calls = self
-            .store()
+            .connection_pool()
             .await
             .user_connection_ids(called_user_id)
             .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
@@ -784,7 +786,11 @@ impl Server {
             .db
             .cancel_call(Some(room_id), session.connection_id, called_user_id)
             .await?;
-        for connection_id in self.store().await.user_connection_ids(called_user_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(called_user_id)
+        {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -807,7 +813,11 @@ impl Server {
             .db
             .decline_call(Some(room_id), session.user_id)
             .await?;
-        for connection_id in self.store().await.user_connection_ids(session.user_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(session.user_id)
+        {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -897,15 +907,15 @@ impl Server {
     async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
         let contacts = self.app_state.db.get_contacts(user_id).await?;
         let busy = self.app_state.db.is_user_busy(user_id).await?;
-        let store = self.store().await;
-        let updated_contact = store.contact_for_user(user_id, false, busy);
+        let pool = self.connection_pool().await;
+        let updated_contact = contact_for_user(user_id, false, busy, &pool);
         for contact in contacts {
             if let db::Contact::Accepted {
                 user_id: contact_user_id,
                 ..
             } = contact
             {
-                for contact_conn_id in store.user_connection_ids(contact_user_id) {
+                for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                     self.peer
                         .send(
                             contact_conn_id,
@@ -1522,7 +1532,11 @@ impl Server {
         // Update outgoing contact requests of requester
         let mut update = proto::UpdateContacts::default();
         update.outgoing_requests.push(responder_id.to_proto());
-        for connection_id in self.store().await.user_connection_ids(requester_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(requester_id)
+        {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1534,7 +1548,11 @@ impl Server {
                 requester_id: requester_id.to_proto(),
                 should_notify: true,
             });
-        for connection_id in self.store().await.user_connection_ids(responder_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(responder_id)
+        {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1563,18 +1581,18 @@ impl Server {
                 .await?;
             let busy = self.app_state.db.is_user_busy(requester_id).await?;
 
-            let store = self.store().await;
+            let pool = self.connection_pool().await;
             // Update responder with new contact
             let mut update = proto::UpdateContacts::default();
             if accept {
                 update
                     .contacts
-                    .push(store.contact_for_user(requester_id, false, busy));
+                    .push(contact_for_user(requester_id, false, busy, &pool));
             }
             update
                 .remove_incoming_requests
                 .push(requester_id.to_proto());
-            for connection_id in store.user_connection_ids(responder_id) {
+            for connection_id in pool.user_connection_ids(responder_id) {
                 self.peer.send(connection_id, update.clone())?;
             }
 
@@ -1583,12 +1601,12 @@ impl Server {
             if accept {
                 update
                     .contacts
-                    .push(store.contact_for_user(responder_id, true, busy));
+                    .push(contact_for_user(responder_id, true, busy, &pool));
             }
             update
                 .remove_outgoing_requests
                 .push(responder_id.to_proto());
-            for connection_id in store.user_connection_ids(requester_id) {
+            for connection_id in pool.user_connection_ids(requester_id) {
                 self.peer.send(connection_id, update.clone())?;
             }
         }
@@ -1615,7 +1633,11 @@ impl Server {
         update
             .remove_outgoing_requests
             .push(responder_id.to_proto());
-        for connection_id in self.store().await.user_connection_ids(requester_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(requester_id)
+        {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1624,7 +1646,11 @@ impl Server {
         update
             .remove_incoming_requests
             .push(requester_id.to_proto());
-        for connection_id in self.store().await.user_connection_ids(responder_id) {
+        for connection_id in self
+            .connection_pool()
+            .await
+            .user_connection_ids(responder_id)
+        {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1678,13 +1704,13 @@ impl Server {
         Ok(())
     }
 
-    pub(crate) async fn store(&self) -> StoreGuard<'_> {
+    pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
         #[cfg(test)]
         tokio::task::yield_now().await;
-        let guard = self.store.lock().await;
+        let guard = self.connection_pool.lock().await;
         #[cfg(test)]
         tokio::task::yield_now().await;
-        StoreGuard {
+        ConnectionPoolGuard {
             guard,
             _not_send: PhantomData,
         }
@@ -1692,27 +1718,27 @@ impl Server {
 
     pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
         ServerSnapshot {
-            store: self.store().await,
+            connection_pool: self.connection_pool().await,
             peer: &self.peer,
         }
     }
 }
 
-impl<'a> Deref for StoreGuard<'a> {
-    type Target = Store;
+impl<'a> Deref for ConnectionPoolGuard<'a> {
+    type Target = ConnectionPool;
 
     fn deref(&self) -> &Self::Target {
         &*self.guard
     }
 }
 
-impl<'a> DerefMut for StoreGuard<'a> {
+impl<'a> DerefMut for ConnectionPoolGuard<'a> {
     fn deref_mut(&mut self) -> &mut Self::Target {
         &mut *self.guard
     }
 }
 
-impl<'a> Drop for StoreGuard<'a> {
+impl<'a> Drop for ConnectionPoolGuard<'a> {
     fn drop(&mut self) {
         #[cfg(test)]
         self.check_invariants();
@@ -1821,7 +1847,7 @@ pub async fn handle_websocket_request(
 
 pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
     let connections = server
-        .store()
+        .connection_pool()
         .await
         .connections()
         .filter(|connection| !connection.admin)
@@ -1868,6 +1894,53 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
     }
 }
 
+fn build_initial_contacts_update(
+    contacts: Vec<db::Contact>,
+    pool: &ConnectionPool,
+) -> proto::UpdateContacts {
+    let mut update = proto::UpdateContacts::default();
+
+    for contact in contacts {
+        match contact {
+            db::Contact::Accepted {
+                user_id,
+                should_notify,
+                busy,
+            } => {
+                update
+                    .contacts
+                    .push(contact_for_user(user_id, should_notify, busy, &pool));
+            }
+            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
+            db::Contact::Incoming {
+                user_id,
+                should_notify,
+            } => update
+                .incoming_requests
+                .push(proto::IncomingContactRequest {
+                    requester_id: user_id.to_proto(),
+                    should_notify,
+                }),
+        }
+    }
+
+    update
+}
+
+fn contact_for_user(
+    user_id: UserId,
+    should_notify: bool,
+    busy: bool,
+    pool: &ConnectionPool,
+) -> proto::Contact {
+    proto::Contact {
+        user_id: user_id.to_proto(),
+        online: pool.is_user_online(user_id),
+        busy,
+        should_notify,
+    }
+}
+
 pub trait ResultExt {
     type Ok;
 

crates/collab/src/rpc/store.rs → crates/collab/src/rpc/connection_pool.rs 🔗

@@ -1,12 +1,12 @@
-use crate::db::{self, UserId};
+use crate::db::UserId;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashSet};
-use rpc::{proto, ConnectionId};
+use rpc::ConnectionId;
 use serde::Serialize;
 use tracing::instrument;
 
 #[derive(Default, Serialize)]
-pub struct Store {
+pub struct ConnectionPool {
     connections: BTreeMap<ConnectionId, Connection>,
     connected_users: BTreeMap<UserId, ConnectedUser>,
 }
@@ -22,7 +22,7 @@ pub struct Connection {
     pub admin: bool,
 }
 
-impl Store {
+impl ConnectionPool {
     #[instrument(skip(self))]
     pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
         self.connections
@@ -70,55 +70,6 @@ impl Store {
             .is_empty()
     }
 
-    pub fn build_initial_contacts_update(
-        &self,
-        contacts: Vec<db::Contact>,
-    ) -> proto::UpdateContacts {
-        let mut update = proto::UpdateContacts::default();
-
-        for contact in contacts {
-            match contact {
-                db::Contact::Accepted {
-                    user_id,
-                    should_notify,
-                    busy,
-                } => {
-                    update
-                        .contacts
-                        .push(self.contact_for_user(user_id, should_notify, busy));
-                }
-                db::Contact::Outgoing { user_id } => {
-                    update.outgoing_requests.push(user_id.to_proto())
-                }
-                db::Contact::Incoming {
-                    user_id,
-                    should_notify,
-                } => update
-                    .incoming_requests
-                    .push(proto::IncomingContactRequest {
-                        requester_id: user_id.to_proto(),
-                        should_notify,
-                    }),
-            }
-        }
-
-        update
-    }
-
-    pub fn contact_for_user(
-        &self,
-        user_id: UserId,
-        should_notify: bool,
-        busy: bool,
-    ) -> proto::Contact {
-        proto::Contact {
-            user_id: user_id.to_proto(),
-            online: self.is_user_online(user_id),
-            busy,
-            should_notify,
-        }
-    }
-
     #[cfg(test)]
     pub fn check_invariants(&self) {
         for (connection_id, connection) in &self.connections {