@@ -7,12 +7,14 @@ use super::{
};
use anyhow::anyhow;
use async_io::Timer;
-use async_std::task;
+use async_std::{
+ sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+ task,
+};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use collections::{HashMap, HashSet};
use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt};
use log::{as_debug, as_display};
-use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use rpc::{
proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Connection, ConnectionId, Peer, TypedEnvelope,
@@ -21,6 +23,9 @@ use sha1::{Digest as _, Sha1};
use std::{
any::TypeId,
future::Future,
+ marker::PhantomData,
+ ops::{Deref, DerefMut},
+ rc::Rc,
sync::Arc,
time::{Duration, Instant},
};
@@ -31,6 +36,7 @@ use tide::{
Request, Response,
};
use time::OffsetDateTime;
+use util::ResultExt;
type MessageHandler = Box<
dyn Send
@@ -58,6 +64,16 @@ pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
+struct StoreReadGuard<'a> {
+ guard: RwLockReadGuard<'a, Store>,
+ _not_send: PhantomData<Rc<()>>,
+}
+
+struct StoreWriteGuard<'a> {
+ guard: RwLockWriteGuard<'a, Store>,
+ _not_send: PhantomData<Rc<()>>,
+}
+
impl Server {
pub fn new(
app_state: Arc<AppState>,
@@ -197,10 +213,10 @@ impl Server {
let _ = send_connection_id.send(connection_id).await;
}
- this.state_mut().add_connection(connection_id, user_id);
- if let Err(err) = this.update_contacts_for_users(&[user_id]) {
- log::error!("error updating contacts for {:?}: {}", user_id, err);
- }
+ this.state_mut()
+ .await
+ .add_connection(connection_id, user_id);
+ this.update_contacts_for_users(&[user_id]).await;
let handle_io = handle_io.fuse();
futures::pin_mut!(handle_io);
@@ -257,7 +273,7 @@ impl Server {
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id);
- let removed_connection = self.state_mut().remove_connection(connection_id)?;
+ let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
for (project_id, project) in removed_connection.hosted_projects {
if let Some(share) = project.share {
@@ -268,7 +284,7 @@ impl Server {
self.peer
.send(conn_id, proto::UnshareProject { project_id })
},
- )?;
+ );
}
}
@@ -281,10 +297,11 @@ impl Server {
peer_id: connection_id.0,
},
)
- })?;
+ });
}
- self.update_contacts_for_users(removed_connection.contact_ids.iter())?;
+ self.update_contacts_for_users(removed_connection.contact_ids.iter())
+ .await;
Ok(())
}
@@ -297,7 +314,7 @@ impl Server {
request: TypedEnvelope<proto::RegisterProject>,
) -> tide::Result<proto::RegisterProjectResponse> {
let project_id = {
- let mut state = self.state_mut();
+ let mut state = self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?;
state.register_project(request.sender_id, user_id)
};
@@ -310,8 +327,10 @@ impl Server {
) -> tide::Result<()> {
let project = self
.state_mut()
+ .await
.unregister_project(request.payload.project_id, request.sender_id)?;
- self.update_contacts_for_users(project.authorized_user_ids().iter())?;
+ self.update_contacts_for_users(project.authorized_user_ids().iter())
+ .await;
Ok(())
}
@@ -320,6 +339,7 @@ impl Server {
request: TypedEnvelope<proto::ShareProject>,
) -> tide::Result<proto::Ack> {
self.state_mut()
+ .await
.share_project(request.payload.project_id, request.sender_id);
Ok(proto::Ack {})
}
@@ -331,13 +351,15 @@ impl Server {
let project_id = request.payload.project_id;
let project = self
.state_mut()
+ .await
.unshare_project(project_id, request.sender_id)?;
broadcast(request.sender_id, project.connection_ids, |conn_id| {
self.peer
.send(conn_id, proto::UnshareProject { project_id })
- })?;
- self.update_contacts_for_users(&project.authorized_user_ids)?;
+ });
+ self.update_contacts_for_users(&project.authorized_user_ids)
+ .await;
Ok(())
}
@@ -347,9 +369,13 @@ impl Server {
) -> tide::Result<proto::JoinProjectResponse> {
let project_id = request.payload.project_id;
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let (response, connection_ids, contact_user_ids) = self
.state_mut()
+ .await
.join_project(request.sender_id, user_id, project_id)
.and_then(|joined| {
let share = joined.project.share()?;
@@ -410,8 +436,8 @@ impl Server {
}),
},
)
- })?;
- self.update_contacts_for_users(&contact_user_ids)?;
+ });
+ self.update_contacts_for_users(&contact_user_ids).await;
Ok(response)
}
@@ -421,7 +447,10 @@ impl Server {
) -> tide::Result<()> {
let sender_id = request.sender_id;
let project_id = request.payload.project_id;
- let worktree = self.state_mut().leave_project(sender_id, project_id)?;
+ let worktree = self
+ .state_mut()
+ .await
+ .leave_project(sender_id, project_id)?;
broadcast(sender_id, worktree.connection_ids, |conn_id| {
self.peer.send(
@@ -431,8 +460,9 @@ impl Server {
peer_id: sender_id.0,
},
)
- })?;
- self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+ });
+ self.update_contacts_for_users(&worktree.authorized_user_ids)
+ .await;
Ok(())
}
@@ -441,7 +471,10 @@ impl Server {
mut self: Arc<Server>,
request: TypedEnvelope<proto::RegisterWorktree>,
) -> tide::Result<proto::Ack> {
- let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let host_user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let mut contact_user_ids = HashSet::default();
contact_user_ids.insert(host_user_id);
@@ -453,7 +486,7 @@ impl Server {
let contact_user_ids = contact_user_ids.into_iter().collect::<Vec<_>>();
let guest_connection_ids;
{
- let mut state = self.state_mut();
+ let mut state = self.state_mut().await;
guest_connection_ids = state
.read_project(request.payload.project_id, request.sender_id)?
.guest_connection_ids();
@@ -471,8 +504,8 @@ impl Server {
broadcast(request.sender_id, guest_connection_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
- self.update_contacts_for_users(&contact_user_ids)?;
+ });
+ self.update_contacts_for_users(&contact_user_ids).await;
Ok(proto::Ack {})
}
@@ -482,9 +515,11 @@ impl Server {
) -> tide::Result<()> {
let project_id = request.payload.project_id;
let worktree_id = request.payload.worktree_id;
- let (worktree, guest_connection_ids) =
- self.state_mut()
- .unregister_worktree(project_id, worktree_id, request.sender_id)?;
+ let (worktree, guest_connection_ids) = self.state_mut().await.unregister_worktree(
+ project_id,
+ worktree_id,
+ request.sender_id,
+ )?;
broadcast(request.sender_id, guest_connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -493,8 +528,9 @@ impl Server {
worktree_id,
},
)
- })?;
- self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+ });
+ self.update_contacts_for_users(&worktree.authorized_user_ids)
+ .await;
Ok(())
}
@@ -502,7 +538,7 @@ impl Server {
mut self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<proto::Ack> {
- let connection_ids = self.state_mut().update_worktree(
+ let connection_ids = self.state_mut().await.update_worktree(
request.sender_id,
request.payload.project_id,
request.payload.worktree_id,
@@ -513,7 +549,7 @@ impl Server {
broadcast(request.sender_id, connection_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(proto::Ack {})
}
@@ -527,7 +563,7 @@ impl Server {
.summary
.clone()
.ok_or_else(|| anyhow!("invalid summary"))?;
- let receiver_ids = self.state_mut().update_diagnostic_summary(
+ let receiver_ids = self.state_mut().await.update_diagnostic_summary(
request.payload.project_id,
request.payload.worktree_id,
request.sender_id,
@@ -537,7 +573,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -545,7 +581,7 @@ impl Server {
mut self: Arc<Server>,
request: TypedEnvelope<proto::StartLanguageServer>,
) -> tide::Result<()> {
- let receiver_ids = self.state_mut().start_language_server(
+ let receiver_ids = self.state_mut().await.start_language_server(
request.payload.project_id,
request.sender_id,
request
@@ -557,7 +593,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -567,11 +603,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -584,6 +621,7 @@ impl Server {
{
let host_connection_id = self
.state()
+ .await
.read_project(request.payload.remote_entity_id(), request.sender_id)?
.host_connection_id;
Ok(self
@@ -598,6 +636,7 @@ impl Server {
) -> tide::Result<proto::BufferSaved> {
let host = self
.state()
+ .await
.read_project(request.payload.project_id, request.sender_id)?
.host_connection_id;
let response = self
@@ -607,12 +646,13 @@ impl Server {
let mut guests = self
.state()
+ .await
.read_project(request.payload.project_id, request.sender_id)?
.connection_ids();
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
broadcast(host, guests, |conn_id| {
self.peer.forward_send(host, conn_id, response.clone())
- })?;
+ });
Ok(response)
}
@@ -623,11 +663,12 @@ impl Server {
) -> tide::Result<proto::Ack> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(proto::Ack {})
}
@@ -637,11 +678,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -651,11 +693,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -665,11 +708,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -681,6 +725,7 @@ impl Server {
let follower_id = request.sender_id;
if !self
.state()
+ .await
.project_connection_ids(request.payload.project_id, follower_id)?
.contains(&leader_id)
{
@@ -703,6 +748,7 @@ impl Server {
let leader_id = ConnectionId(request.payload.leader_id);
if !self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?
.contains(&leader_id)
{
@@ -719,6 +765,7 @@ impl Server {
) -> tide::Result<()> {
let connection_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
let leader_id = request
.payload
@@ -743,7 +790,10 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<proto::GetChannelsResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
Ok(proto::GetChannelsResponse {
channels: channels
@@ -781,33 +831,34 @@ impl Server {
Ok(proto::GetUsersResponse { users })
}
- fn update_contacts_for_users<'a>(
+ async fn update_contacts_for_users<'a>(
self: &Arc<Server>,
user_ids: impl IntoIterator<Item = &'a UserId>,
- ) -> anyhow::Result<()> {
- let mut result = Ok(());
- let state = self.state();
+ ) {
+ let state = self.state().await;
for user_id in user_ids {
let contacts = state.contacts_for_user(*user_id);
for connection_id in state.connection_ids_for_user(*user_id) {
- if let Err(error) = self.peer.send(
- connection_id,
- proto::UpdateContacts {
- contacts: contacts.clone(),
- },
- ) {
- result = Err(error);
- }
+ self.peer
+ .send(
+ connection_id,
+ proto::UpdateContacts {
+ contacts: contacts.clone(),
+ },
+ )
+ .log_err();
}
}
- result
}
async fn join_channel(
mut self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<proto::JoinChannelResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -818,7 +869,9 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.state_mut().join_channel(request.sender_id, channel_id);
+ self.state_mut()
+ .await
+ .join_channel(request.sender_id, channel_id);
let messages = self
.app_state
.db
@@ -843,7 +896,10 @@ impl Server {
mut self: Arc<Self>,
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -855,6 +911,7 @@ impl Server {
}
self.state_mut()
+ .await
.leave_channel(request.sender_id, channel_id);
Ok(())
@@ -868,7 +925,7 @@ impl Server {
let user_id;
let connection_ids;
{
- let state = self.state();
+ let state = self.state().await;
user_id = state.user_id_for_connection(request.sender_id)?;
connection_ids = state.channel_connection_ids(channel_id)?;
}
@@ -909,7 +966,7 @@ impl Server {
message: Some(message.clone()),
},
)
- })?;
+ });
Ok(proto::SendChannelMessageResponse {
message: Some(message),
})
@@ -919,7 +976,10 @@ impl Server {
self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<proto::GetChannelMessagesResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -955,12 +1015,57 @@ impl Server {
})
}
- fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
- self.store.read()
+ async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ let guard = self.store.read().await;
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ StoreReadGuard {
+ guard,
+ _not_send: PhantomData,
+ }
+ }
+
+ async fn state_mut<'a>(self: &'a mut Arc<Self>) -> StoreWriteGuard<'a> {
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ let guard = self.store.write().await;
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ StoreWriteGuard {
+ guard,
+ _not_send: PhantomData,
+ }
+ }
+}
+
+impl<'a> Deref for StoreReadGuard<'a> {
+ type Target = Store;
+
+ fn deref(&self) -> &Self::Target {
+ &*self.guard
+ }
+}
+
+impl<'a> Deref for StoreWriteGuard<'a> {
+ type Target = Store;
+
+ fn deref(&self) -> &Self::Target {
+ &*self.guard
}
+}
- fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
- self.store.write()
+impl<'a> DerefMut for StoreWriteGuard<'a> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut *self.guard
+ }
+}
+
+impl<'a> Drop for StoreWriteGuard<'a> {
+ fn drop(&mut self) {
+ #[cfg(test)]
+ self.check_invariants();
}
}
@@ -976,25 +1081,15 @@ impl Executor for RealExecutor {
}
}
-fn broadcast<F>(
- sender_id: ConnectionId,
- receiver_ids: Vec<ConnectionId>,
- mut f: F,
-) -> anyhow::Result<()>
+fn broadcast<F>(sender_id: ConnectionId, receiver_ids: Vec<ConnectionId>, mut f: F)
where
F: FnMut(ConnectionId) -> anyhow::Result<()>,
{
- let mut result = Ok(());
for receiver_id in receiver_ids {
if receiver_id != sender_id {
- if let Err(error) = f(receiver_id) {
- if result.is_ok() {
- result = Err(error);
- }
- }
+ f(receiver_id).log_err();
}
}
- result
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -5216,6 +5311,7 @@ mod tests {
let contacts = server
.store
.read()
+ .await
.contacts_for_user(guest.current_user_id(&guest_cx));
assert!(!contacts
.iter()
@@ -5292,7 +5388,7 @@ mod tests {
.unwrap()
.read_with(&guest_cx, |project, _| assert!(project.is_read_only()));
for user_id in &user_ids {
- for contact in server.store.read().contacts_for_user(*user_id) {
+ for contact in server.store.read().await.contacts_for_user(*user_id) {
assert_ne!(
contact.user_id, removed_guest_id.0 as u64,
"removed guest is still a contact of another peer"
@@ -5590,7 +5686,7 @@ mod tests {
}
async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
- self.server.store.read()
+ self.server.store.read().await
}
async fn condition<F>(&mut self, mut predicate: F)
@@ -5598,7 +5694,7 @@ mod tests {
F: FnMut(&Store) -> bool,
{
async_std::future::timeout(Duration::from_millis(500), async {
- while !(predicate)(&*self.server.store.read()) {
+ while !(predicate)(&*self.server.store.read().await) {
self.foreground.start_waiting();
self.notifications.next().await;
self.foreground.finish_waiting();