From 9a8b0388fa0dd2f1f0bfd37685dccfed35f97012 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 11 Apr 2022 17:38:17 +0200 Subject: [PATCH] Replace synchronous `Store` lock with an async lock This also fixes some failures due to `broadcast` and `update_contacts_for_users` being fallible. As part of this commit, these two functions don't return `Result` anymore: the reason for this change is that we don't want a request to fail only because a peer disconnected while we were trying to broadcast a message to them. --- crates/server/Cargo.toml | 1 + crates/server/src/rpc.rs | 256 ++++++++++++++++++++++----------- crates/server/src/rpc/store.rs | 25 +--- 3 files changed, 178 insertions(+), 104 deletions(-) diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index e0834b76c5ecb1c518e918a70af78824685e6f42..7c9bb8078597f1bbf37dcb2312d19eb594d94cf2 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -15,6 +15,7 @@ required-features = ["seed-support"] [dependencies] collections = { path = "../collections" } rpc = { path = "../rpc" } +util = { path = "../util" } anyhow = "1.0.40" async-io = "1.3" async-std = { version = "1.8.0", features = ["attributes"] } diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 2fe0931c4ca8c6e97bfe1704e763a7f6d643e786..0ffcde9176890ad2b541772f9a828f049823b15c 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -7,12 +7,14 @@ use super::{ }; use anyhow::anyhow; use async_io::Timer; -use async_std::task; +use async_std::{ + sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}, + task, +}; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use collections::{HashMap, HashSet}; use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt}; use log::{as_debug, as_display}; -use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use rpc::{ proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, Connection, ConnectionId, Peer, TypedEnvelope, @@ -21,6 +23,9 @@ use sha1::{Digest as _, Sha1}; use std::{ any::TypeId, future::Future, + marker::PhantomData, + ops::{Deref, DerefMut}, + rc::Rc, sync::Arc, time::{Duration, Instant}, }; @@ -31,6 +36,7 @@ use tide::{ Request, Response, }; use time::OffsetDateTime; +use util::ResultExt; type MessageHandler = Box< dyn Send @@ -58,6 +64,16 @@ pub struct RealExecutor; const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; +struct StoreReadGuard<'a> { + guard: RwLockReadGuard<'a, Store>, + _not_send: PhantomData>, +} + +struct StoreWriteGuard<'a> { + guard: RwLockWriteGuard<'a, Store>, + _not_send: PhantomData>, +} + impl Server { pub fn new( app_state: Arc, @@ -197,10 +213,10 @@ impl Server { let _ = send_connection_id.send(connection_id).await; } - this.state_mut().add_connection(connection_id, user_id); - if let Err(err) = this.update_contacts_for_users(&[user_id]) { - log::error!("error updating contacts for {:?}: {}", user_id, err); - } + this.state_mut() + .await + .add_connection(connection_id, user_id); + this.update_contacts_for_users(&[user_id]).await; let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); @@ -257,7 +273,7 @@ impl Server { async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> tide::Result<()> { self.peer.disconnect(connection_id); - let removed_connection = self.state_mut().remove_connection(connection_id)?; + let removed_connection = self.state_mut().await.remove_connection(connection_id)?; for (project_id, project) in removed_connection.hosted_projects { if let Some(share) = project.share { @@ -268,7 +284,7 @@ impl Server { self.peer .send(conn_id, proto::UnshareProject { project_id }) }, - )?; + ); } } @@ -281,10 +297,11 @@ impl Server { peer_id: connection_id.0, }, ) - })?; + }); } - self.update_contacts_for_users(removed_connection.contact_ids.iter())?; + self.update_contacts_for_users(removed_connection.contact_ids.iter()) + .await; Ok(()) } @@ -297,7 +314,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result { let project_id = { - let mut state = self.state_mut(); + let mut state = self.state_mut().await; let user_id = state.user_id_for_connection(request.sender_id)?; state.register_project(request.sender_id, user_id) }; @@ -310,8 +327,10 @@ impl Server { ) -> tide::Result<()> { let project = self .state_mut() + .await .unregister_project(request.payload.project_id, request.sender_id)?; - self.update_contacts_for_users(project.authorized_user_ids().iter())?; + self.update_contacts_for_users(project.authorized_user_ids().iter()) + .await; Ok(()) } @@ -320,6 +339,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result { self.state_mut() + .await .share_project(request.payload.project_id, request.sender_id); Ok(proto::Ack {}) } @@ -331,13 +351,15 @@ impl Server { let project_id = request.payload.project_id; let project = self .state_mut() + .await .unshare_project(project_id, request.sender_id)?; broadcast(request.sender_id, project.connection_ids, |conn_id| { self.peer .send(conn_id, proto::UnshareProject { project_id }) - })?; - self.update_contacts_for_users(&project.authorized_user_ids)?; + }); + self.update_contacts_for_users(&project.authorized_user_ids) + .await; Ok(()) } @@ -347,9 +369,13 @@ impl Server { ) -> tide::Result { let project_id = request.payload.project_id; - let user_id = self.state().user_id_for_connection(request.sender_id)?; + let user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let (response, connection_ids, contact_user_ids) = self .state_mut() + .await .join_project(request.sender_id, user_id, project_id) .and_then(|joined| { let share = joined.project.share()?; @@ -410,8 +436,8 @@ impl Server { }), }, ) - })?; - self.update_contacts_for_users(&contact_user_ids)?; + }); + self.update_contacts_for_users(&contact_user_ids).await; Ok(response) } @@ -421,7 +447,10 @@ impl Server { ) -> tide::Result<()> { let sender_id = request.sender_id; let project_id = request.payload.project_id; - let worktree = self.state_mut().leave_project(sender_id, project_id)?; + let worktree = self + .state_mut() + .await + .leave_project(sender_id, project_id)?; broadcast(sender_id, worktree.connection_ids, |conn_id| { self.peer.send( @@ -431,8 +460,9 @@ impl Server { peer_id: sender_id.0, }, ) - })?; - self.update_contacts_for_users(&worktree.authorized_user_ids)?; + }); + self.update_contacts_for_users(&worktree.authorized_user_ids) + .await; Ok(()) } @@ -441,7 +471,10 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result { - let host_user_id = self.state().user_id_for_connection(request.sender_id)?; + let host_user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let mut contact_user_ids = HashSet::default(); contact_user_ids.insert(host_user_id); @@ -453,7 +486,7 @@ impl Server { let contact_user_ids = contact_user_ids.into_iter().collect::>(); let guest_connection_ids; { - let mut state = self.state_mut(); + let mut state = self.state_mut().await; guest_connection_ids = state .read_project(request.payload.project_id, request.sender_id)? .guest_connection_ids(); @@ -471,8 +504,8 @@ impl Server { broadcast(request.sender_id, guest_connection_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; - self.update_contacts_for_users(&contact_user_ids)?; + }); + self.update_contacts_for_users(&contact_user_ids).await; Ok(proto::Ack {}) } @@ -482,9 +515,11 @@ impl Server { ) -> tide::Result<()> { let project_id = request.payload.project_id; let worktree_id = request.payload.worktree_id; - let (worktree, guest_connection_ids) = - self.state_mut() - .unregister_worktree(project_id, worktree_id, request.sender_id)?; + let (worktree, guest_connection_ids) = self.state_mut().await.unregister_worktree( + project_id, + worktree_id, + request.sender_id, + )?; broadcast(request.sender_id, guest_connection_ids, |conn_id| { self.peer.send( conn_id, @@ -493,8 +528,9 @@ impl Server { worktree_id, }, ) - })?; - self.update_contacts_for_users(&worktree.authorized_user_ids)?; + }); + self.update_contacts_for_users(&worktree.authorized_user_ids) + .await; Ok(()) } @@ -502,7 +538,7 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result { - let connection_ids = self.state_mut().update_worktree( + let connection_ids = self.state_mut().await.update_worktree( request.sender_id, request.payload.project_id, request.payload.worktree_id, @@ -513,7 +549,7 @@ impl Server { broadcast(request.sender_id, connection_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(proto::Ack {}) } @@ -527,7 +563,7 @@ impl Server { .summary .clone() .ok_or_else(|| anyhow!("invalid summary"))?; - let receiver_ids = self.state_mut().update_diagnostic_summary( + let receiver_ids = self.state_mut().await.update_diagnostic_summary( request.payload.project_id, request.payload.worktree_id, request.sender_id, @@ -537,7 +573,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -545,7 +581,7 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let receiver_ids = self.state_mut().start_language_server( + let receiver_ids = self.state_mut().await.start_language_server( request.payload.project_id, request.sender_id, request @@ -557,7 +593,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -567,11 +603,12 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -584,6 +621,7 @@ impl Server { { let host_connection_id = self .state() + .await .read_project(request.payload.remote_entity_id(), request.sender_id)? .host_connection_id; Ok(self @@ -598,6 +636,7 @@ impl Server { ) -> tide::Result { let host = self .state() + .await .read_project(request.payload.project_id, request.sender_id)? .host_connection_id; let response = self @@ -607,12 +646,13 @@ impl Server { let mut guests = self .state() + .await .read_project(request.payload.project_id, request.sender_id)? .connection_ids(); guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); broadcast(host, guests, |conn_id| { self.peer.forward_send(host, conn_id, response.clone()) - })?; + }); Ok(response) } @@ -623,11 +663,12 @@ impl Server { ) -> tide::Result { let receiver_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(proto::Ack {}) } @@ -637,11 +678,12 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -651,11 +693,12 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -665,11 +708,12 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - })?; + }); Ok(()) } @@ -681,6 +725,7 @@ impl Server { let follower_id = request.sender_id; if !self .state() + .await .project_connection_ids(request.payload.project_id, follower_id)? .contains(&leader_id) { @@ -703,6 +748,7 @@ impl Server { let leader_id = ConnectionId(request.payload.leader_id); if !self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)? .contains(&leader_id) { @@ -719,6 +765,7 @@ impl Server { ) -> tide::Result<()> { let connection_ids = self .state() + .await .project_connection_ids(request.payload.project_id, request.sender_id)?; let leader_id = request .payload @@ -743,7 +790,10 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result { - let user_id = self.state().user_id_for_connection(request.sender_id)?; + let user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; Ok(proto::GetChannelsResponse { channels: channels @@ -781,33 +831,34 @@ impl Server { Ok(proto::GetUsersResponse { users }) } - fn update_contacts_for_users<'a>( + async fn update_contacts_for_users<'a>( self: &Arc, user_ids: impl IntoIterator, - ) -> anyhow::Result<()> { - let mut result = Ok(()); - let state = self.state(); + ) { + let state = self.state().await; for user_id in user_ids { let contacts = state.contacts_for_user(*user_id); for connection_id in state.connection_ids_for_user(*user_id) { - if let Err(error) = self.peer.send( - connection_id, - proto::UpdateContacts { - contacts: contacts.clone(), - }, - ) { - result = Err(error); - } + self.peer + .send( + connection_id, + proto::UpdateContacts { + contacts: contacts.clone(), + }, + ) + .log_err(); } } - result } async fn join_channel( mut self: Arc, request: TypedEnvelope, ) -> tide::Result { - let user_id = self.state().user_id_for_connection(request.sender_id)?; + let user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -818,7 +869,9 @@ impl Server { Err(anyhow!("access denied"))?; } - self.state_mut().join_channel(request.sender_id, channel_id); + self.state_mut() + .await + .join_channel(request.sender_id, channel_id); let messages = self .app_state .db @@ -843,7 +896,10 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let user_id = self.state().user_id_for_connection(request.sender_id)?; + let user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -855,6 +911,7 @@ impl Server { } self.state_mut() + .await .leave_channel(request.sender_id, channel_id); Ok(()) @@ -868,7 +925,7 @@ impl Server { let user_id; let connection_ids; { - let state = self.state(); + let state = self.state().await; user_id = state.user_id_for_connection(request.sender_id)?; connection_ids = state.channel_connection_ids(channel_id)?; } @@ -909,7 +966,7 @@ impl Server { message: Some(message.clone()), }, ) - })?; + }); Ok(proto::SendChannelMessageResponse { message: Some(message), }) @@ -919,7 +976,10 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result { - let user_id = self.state().user_id_for_connection(request.sender_id)?; + let user_id = self + .state() + .await + .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -955,12 +1015,57 @@ impl Server { }) } - fn state<'a>(self: &'a Arc) -> RwLockReadGuard<'a, Store> { - self.store.read() + async fn state<'a>(self: &'a Arc) -> StoreReadGuard<'a> { + #[cfg(test)] + async_std::task::yield_now().await; + let guard = self.store.read().await; + #[cfg(test)] + async_std::task::yield_now().await; + StoreReadGuard { + guard, + _not_send: PhantomData, + } + } + + async fn state_mut<'a>(self: &'a mut Arc) -> StoreWriteGuard<'a> { + #[cfg(test)] + async_std::task::yield_now().await; + let guard = self.store.write().await; + #[cfg(test)] + async_std::task::yield_now().await; + StoreWriteGuard { + guard, + _not_send: PhantomData, + } + } +} + +impl<'a> Deref for StoreReadGuard<'a> { + type Target = Store; + + fn deref(&self) -> &Self::Target { + &*self.guard + } +} + +impl<'a> Deref for StoreWriteGuard<'a> { + type Target = Store; + + fn deref(&self) -> &Self::Target { + &*self.guard } +} - fn state_mut<'a>(self: &'a mut Arc) -> RwLockWriteGuard<'a, Store> { - self.store.write() +impl<'a> DerefMut for StoreWriteGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.guard + } +} + +impl<'a> Drop for StoreWriteGuard<'a> { + fn drop(&mut self) { + #[cfg(test)] + self.check_invariants(); } } @@ -976,25 +1081,15 @@ impl Executor for RealExecutor { } } -fn broadcast( - sender_id: ConnectionId, - receiver_ids: Vec, - mut f: F, -) -> anyhow::Result<()> +fn broadcast(sender_id: ConnectionId, receiver_ids: Vec, mut f: F) where F: FnMut(ConnectionId) -> anyhow::Result<()>, { - let mut result = Ok(()); for receiver_id in receiver_ids { if receiver_id != sender_id { - if let Err(error) = f(receiver_id) { - if result.is_ok() { - result = Err(error); - } - } + f(receiver_id).log_err(); } } - result } pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { @@ -5216,6 +5311,7 @@ mod tests { let contacts = server .store .read() + .await .contacts_for_user(guest.current_user_id(&guest_cx)); assert!(!contacts .iter() @@ -5292,7 +5388,7 @@ mod tests { .unwrap() .read_with(&guest_cx, |project, _| assert!(project.is_read_only())); for user_id in &user_ids { - for contact in server.store.read().contacts_for_user(*user_id) { + for contact in server.store.read().await.contacts_for_user(*user_id) { assert_ne!( contact.user_id, removed_guest_id.0 as u64, "removed guest is still a contact of another peer" @@ -5590,7 +5686,7 @@ mod tests { } async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> { - self.server.store.read() + self.server.store.read().await } async fn condition(&mut self, mut predicate: F) @@ -5598,7 +5694,7 @@ mod tests { F: FnMut(&Store) -> bool, { async_std::future::timeout(Duration::from_millis(500), async { - while !(predicate)(&*self.server.store.read()) { + while !(predicate)(&*self.server.store.read().await) { self.foreground.start_waiting(); self.notifications.next().await; self.foreground.finish_waiting(); diff --git a/crates/server/src/rpc/store.rs b/crates/server/src/rpc/store.rs index 6c330c9c8bae3e3558280ea940fc180207ce5c70..33d2a399816ad90f6e809373e8ece3ca67c8046e 100644 --- a/crates/server/src/rpc/store.rs +++ b/crates/server/src/rpc/store.rs @@ -130,9 +130,6 @@ impl Store { } } - #[cfg(test)] - self.check_invariants(); - Ok(result) } @@ -275,8 +272,6 @@ impl Store { share.worktrees.insert(worktree_id, Default::default()); } - #[cfg(test)] - self.check_invariants(); Ok(()) } else { Err(anyhow!("no such project"))? @@ -313,8 +308,6 @@ impl Store { } } - #[cfg(test)] - self.check_invariants(); Ok(project) } else { Err(anyhow!("no such project"))? @@ -359,9 +352,6 @@ impl Store { } } - #[cfg(test)] - self.check_invariants(); - Ok((worktree, guest_connection_ids)) } @@ -403,9 +393,6 @@ impl Store { } } - #[cfg(test)] - self.check_invariants(); - Ok(UnsharedProject { connection_ids, authorized_user_ids, @@ -491,9 +478,6 @@ impl Store { share.active_replica_ids.insert(replica_id); share.guests.insert(connection_id, (replica_id, user_id)); - #[cfg(test)] - self.check_invariants(); - Ok(JoinedProject { replica_id, project: &self.projects[&project_id], @@ -526,9 +510,6 @@ impl Store { let connection_ids = project.connection_ids(); let authorized_user_ids = project.authorized_user_ids(); - #[cfg(test)] - self.check_invariants(); - Ok(LeftProject { connection_ids, authorized_user_ids, @@ -556,10 +537,6 @@ impl Store { worktree.entries.insert(entry.id, entry.clone()); } let connection_ids = project.connection_ids(); - - #[cfg(test)] - self.check_invariants(); - Ok(connection_ids) } @@ -633,7 +610,7 @@ impl Store { } #[cfg(test)] - fn check_invariants(&self) { + pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { for project_id in &connection.projects { let project = &self.projects.get(&project_id).unwrap();