diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 1236af42cb05af4b544f74166284d34aa3e44739..006598a6b191e593c7934d145a3c146da0a7c496 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{NewUserParams, SqliteTestDb as TestDb, UserId}, + db::{self, NewUserParams, SqliteTestDb as TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -5469,18 +5469,15 @@ async fn test_random_collaboration( } for user_id in &user_ids { let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); - let contacts = server - .store - .lock() - .await - .build_initial_contacts_update(contacts) - .contacts; + let pool = server.connection_pool.lock().await; for contact in contacts { - if contact.online { - assert_ne!( - contact.user_id, removed_guest_id.0 as u64, - "removed guest is still a contact of another peer" - ); + if let db::Contact::Accepted { user_id, .. } = contact { + if pool.is_user_online(user_id) { + assert_ne!( + user_id, removed_guest_id, + "removed guest is still a contact of another peer" + ); + } } } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 58870163f50f349082636e9753171bc80560ea7f..175e3604c04acc522348a6f2c92e7fdb53b16599 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,4 +1,4 @@ -mod store; +mod connection_pool; use crate::{ auth, @@ -23,6 +23,7 @@ use axum::{ Extension, Router, TypedHeader, }; use collections::{HashMap, HashSet}; +pub use connection_pool::ConnectionPool; use futures::{ channel::oneshot, future::{self, BoxFuture}, @@ -49,7 +50,6 @@ use std::{ }, time::Duration, }; -pub use store::Store; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -103,7 +103,7 @@ impl Response { pub struct Server { peer: Arc, - pub(crate) store: Mutex, + pub(crate) connection_pool: Mutex, app_state: Arc, handlers: HashMap, } @@ -117,8 +117,8 @@ pub trait Executor: Send + Clone { #[derive(Clone)] pub struct RealExecutor; -pub(crate) struct StoreGuard<'a> { - guard: MutexGuard<'a, Store>, +pub(crate) struct ConnectionPoolGuard<'a> { + guard: MutexGuard<'a, ConnectionPool>, _not_send: PhantomData>, } @@ -126,7 +126,7 @@ pub(crate) struct StoreGuard<'a> { pub struct ServerSnapshot<'a> { peer: &'a Peer, #[serde(serialize_with = "serialize_deref")] - store: StoreGuard<'a>, + connection_pool: ConnectionPoolGuard<'a>, } pub fn serialize_deref(value: &T, serializer: S) -> Result @@ -143,7 +143,7 @@ impl Server { let mut server = Self { peer: Peer::new(), app_state, - store: Default::default(), + connection_pool: Default::default(), handlers: Default::default(), }; @@ -257,8 +257,6 @@ impl Server { self } - /// Handle a request while holding a lock to the store. This is useful when we're registering - /// a connection but we want to respond on the connection before anybody else can send on it. fn add_request_handler(&mut self, handler: F) -> &mut Self where F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, @@ -342,9 +340,9 @@ impl Server { ).await?; { - let mut store = this.store().await; - store.add_connection(connection_id, user_id, user.admin); - this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; + let mut pool = this.connection_pool().await; + pool.add_connection(connection_id, user_id, user.admin); + this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; if let Some((code, count)) = invite_code { this.peer.send(connection_id, proto::UpdateInviteInfo { @@ -435,9 +433,9 @@ impl Server { ) -> Result<()> { self.peer.disconnect(connection_id); let decline_calls = { - let mut store = self.store().await; - store.remove_connection(connection_id)?; - let mut connections = store.user_connection_ids(user_id); + let mut pool = self.connection_pool().await; + pool.remove_connection(connection_id)?; + let mut connections = pool.user_connection_ids(user_id); connections.next().is_none() }; @@ -468,9 +466,9 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { - let store = self.store().await; - let invitee_contact = store.contact_for_user(invitee_id, true, false); - for connection_id in store.user_connection_ids(inviter_id) { + let pool = self.connection_pool().await; + let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, proto::UpdateContacts { @@ -494,8 +492,8 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { - let store = self.store().await; - for connection_id in store.user_connection_ids(user_id) { + let pool = self.connection_pool().await; + for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, proto::UpdateInviteInfo { @@ -582,7 +580,11 @@ impl Server { session.connection_id, ) .await?; - for connection_id in self.store().await.user_connection_ids(session.user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(session.user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -672,9 +674,9 @@ impl Server { self.room_updated(&left_room.room); { - let store = self.store().await; + let pool = self.connection_pool().await; for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in store.user_connection_ids(canceled_user_id) { + for connection_id in pool.user_connection_ids(canceled_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -742,7 +744,7 @@ impl Server { self.update_user_contacts(called_user_id).await?; let mut calls = self - .store() + .connection_pool() .await .user_connection_ids(called_user_id) .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) @@ -784,7 +786,11 @@ impl Server { .db .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; - for connection_id in self.store().await.user_connection_ids(called_user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(called_user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -807,7 +813,11 @@ impl Server { .db .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self.store().await.user_connection_ids(session.user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(session.user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -897,15 +907,15 @@ impl Server { async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { let contacts = self.app_state.db.get_contacts(user_id).await?; let busy = self.app_state.db.is_user_busy(user_id).await?; - let store = self.store().await; - let updated_contact = store.contact_for_user(user_id, false, busy); + let pool = self.connection_pool().await; + let updated_contact = contact_for_user(user_id, false, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, .. } = contact { - for contact_conn_id in store.user_connection_ids(contact_user_id) { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { self.peer .send( contact_conn_id, @@ -1522,7 +1532,11 @@ impl Server { // Update outgoing contact requests of requester let mut update = proto::UpdateContacts::default(); update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(requester_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(requester_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1534,7 +1548,11 @@ impl Server { requester_id: requester_id.to_proto(), should_notify: true, }); - for connection_id in self.store().await.user_connection_ids(responder_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(responder_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1563,18 +1581,18 @@ impl Server { .await?; let busy = self.app_state.db.is_user_busy(requester_id).await?; - let store = self.store().await; + let pool = self.connection_pool().await; // Update responder with new contact let mut update = proto::UpdateContacts::default(); if accept { update .contacts - .push(store.contact_for_user(requester_id, false, busy)); + .push(contact_for_user(requester_id, false, busy, &pool)); } update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in store.user_connection_ids(responder_id) { + for connection_id in pool.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1583,12 +1601,12 @@ impl Server { if accept { update .contacts - .push(store.contact_for_user(responder_id, true, busy)); + .push(contact_for_user(responder_id, true, busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in store.user_connection_ids(requester_id) { + for connection_id in pool.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } } @@ -1615,7 +1633,11 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(requester_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(requester_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1624,7 +1646,11 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(responder_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(responder_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1678,13 +1704,13 @@ impl Server { Ok(()) } - pub(crate) async fn store(&self) -> StoreGuard<'_> { + pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { #[cfg(test)] tokio::task::yield_now().await; - let guard = self.store.lock().await; + let guard = self.connection_pool.lock().await; #[cfg(test)] tokio::task::yield_now().await; - StoreGuard { + ConnectionPoolGuard { guard, _not_send: PhantomData, } @@ -1692,27 +1718,27 @@ impl Server { pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { - store: self.store().await, + connection_pool: self.connection_pool().await, peer: &self.peer, } } } -impl<'a> Deref for StoreGuard<'a> { - type Target = Store; +impl<'a> Deref for ConnectionPoolGuard<'a> { + type Target = ConnectionPool; fn deref(&self) -> &Self::Target { &*self.guard } } -impl<'a> DerefMut for StoreGuard<'a> { +impl<'a> DerefMut for ConnectionPoolGuard<'a> { fn deref_mut(&mut self) -> &mut Self::Target { &mut *self.guard } } -impl<'a> Drop for StoreGuard<'a> { +impl<'a> Drop for ConnectionPoolGuard<'a> { fn drop(&mut self) { #[cfg(test)] self.check_invariants(); @@ -1821,7 +1847,7 @@ pub async fn handle_websocket_request( pub async fn handle_metrics(Extension(server): Extension>) -> Result { let connections = server - .store() + .connection_pool() .await .connections() .filter(|connection| !connection.admin) @@ -1868,6 +1894,53 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } +fn build_initial_contacts_update( + contacts: Vec, + pool: &ConnectionPool, +) -> proto::UpdateContacts { + let mut update = proto::UpdateContacts::default(); + + for contact in contacts { + match contact { + db::Contact::Accepted { + user_id, + should_notify, + busy, + } => { + update + .contacts + .push(contact_for_user(user_id, should_notify, busy, &pool)); + } + db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), + db::Contact::Incoming { + user_id, + should_notify, + } => update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + should_notify, + }), + } + } + + update +} + +fn contact_for_user( + user_id: UserId, + should_notify: bool, + busy: bool, + pool: &ConnectionPool, +) -> proto::Contact { + proto::Contact { + user_id: user_id.to_proto(), + online: pool.is_user_online(user_id), + busy, + should_notify, + } +} + pub trait ResultExt { type Ok; diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/connection_pool.rs similarity index 64% rename from crates/collab/src/rpc/store.rs rename to crates/collab/src/rpc/connection_pool.rs index 2bb6d89f401a0274c3ac83b70eaa9cd192c882d1..ac7632f7da2ae6d4d6beb95aeb298d8e409f8d80 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -1,12 +1,12 @@ -use crate::db::{self, UserId}; +use crate::db::UserId; use anyhow::{anyhow, Result}; use collections::{BTreeMap, HashSet}; -use rpc::{proto, ConnectionId}; +use rpc::ConnectionId; use serde::Serialize; use tracing::instrument; #[derive(Default, Serialize)] -pub struct Store { +pub struct ConnectionPool { connections: BTreeMap, connected_users: BTreeMap, } @@ -22,7 +22,7 @@ pub struct Connection { pub admin: bool, } -impl Store { +impl ConnectionPool { #[instrument(skip(self))] pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { self.connections @@ -70,55 +70,6 @@ impl Store { .is_empty() } - pub fn build_initial_contacts_update( - &self, - contacts: Vec, - ) -> proto::UpdateContacts { - let mut update = proto::UpdateContacts::default(); - - for contact in contacts { - match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(self.contact_for_user(user_id, should_notify, busy)); - } - db::Contact::Outgoing { user_id } => { - update.outgoing_requests.push(user_id.to_proto()) - } - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), - } - } - - update - } - - pub fn contact_for_user( - &self, - user_id: UserId, - should_notify: bool, - busy: bool, - ) -> proto::Contact { - proto::Contact { - user_id: user_id.to_proto(), - online: self.is_user_online(user_id), - busy, - should_notify, - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections {