Replace synchronous `Store` lock with an async lock

Antonio Scandurra created

This also fixes some failures due to `broadcast` and `update_contacts_for_users`
being fallible. As part of this commit, these two functions don't return `Result`
anymore: the reason for this change is that we don't want a request to fail only
because a peer disconnected while we were trying to broadcast a message to them.

Change summary

crates/server/Cargo.toml       |   1 
crates/server/src/rpc.rs       | 256 ++++++++++++++++++++++++-----------
crates/server/src/rpc/store.rs |  25 ---
3 files changed, 178 insertions(+), 104 deletions(-)

Detailed changes

crates/server/Cargo.toml 🔗

@@ -15,6 +15,7 @@ required-features = ["seed-support"]
 [dependencies]
 collections = { path = "../collections" }
 rpc = { path = "../rpc" }
+util = { path = "../util" }
 anyhow = "1.0.40"
 async-io = "1.3"
 async-std = { version = "1.8.0", features = ["attributes"] }

crates/server/src/rpc.rs 🔗

@@ -7,12 +7,14 @@ use super::{
 };
 use anyhow::anyhow;
 use async_io::Timer;
-use async_std::task;
+use async_std::{
+    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+    task,
+};
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use collections::{HashMap, HashSet};
 use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt};
 use log::{as_debug, as_display};
-use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
 use rpc::{
     proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
     Connection, ConnectionId, Peer, TypedEnvelope,
@@ -21,6 +23,9 @@ use sha1::{Digest as _, Sha1};
 use std::{
     any::TypeId,
     future::Future,
+    marker::PhantomData,
+    ops::{Deref, DerefMut},
+    rc::Rc,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -31,6 +36,7 @@ use tide::{
     Request, Response,
 };
 use time::OffsetDateTime;
+use util::ResultExt;
 
 type MessageHandler = Box<
     dyn Send
@@ -58,6 +64,16 @@ pub struct RealExecutor;
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
 const MAX_MESSAGE_LEN: usize = 1024;
 
+struct StoreReadGuard<'a> {
+    guard: RwLockReadGuard<'a, Store>,
+    _not_send: PhantomData<Rc<()>>,
+}
+
+struct StoreWriteGuard<'a> {
+    guard: RwLockWriteGuard<'a, Store>,
+    _not_send: PhantomData<Rc<()>>,
+}
+
 impl Server {
     pub fn new(
         app_state: Arc<AppState>,
@@ -197,10 +213,10 @@ impl Server {
                 let _ = send_connection_id.send(connection_id).await;
             }
 
-            this.state_mut().add_connection(connection_id, user_id);
-            if let Err(err) = this.update_contacts_for_users(&[user_id]) {
-                log::error!("error updating contacts for {:?}: {}", user_id, err);
-            }
+            this.state_mut()
+                .await
+                .add_connection(connection_id, user_id);
+            this.update_contacts_for_users(&[user_id]).await;
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
@@ -257,7 +273,7 @@ impl Server {
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id);
-        let removed_connection = self.state_mut().remove_connection(connection_id)?;
+        let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
 
         for (project_id, project) in removed_connection.hosted_projects {
             if let Some(share) = project.share {
@@ -268,7 +284,7 @@ impl Server {
                         self.peer
                             .send(conn_id, proto::UnshareProject { project_id })
                     },
-                )?;
+                );
             }
         }
 
@@ -281,10 +297,11 @@ impl Server {
                         peer_id: connection_id.0,
                     },
                 )
-            })?;
+            });
         }
 
-        self.update_contacts_for_users(removed_connection.contact_ids.iter())?;
+        self.update_contacts_for_users(removed_connection.contact_ids.iter())
+            .await;
         Ok(())
     }
 
@@ -297,7 +314,7 @@ impl Server {
         request: TypedEnvelope<proto::RegisterProject>,
     ) -> tide::Result<proto::RegisterProjectResponse> {
         let project_id = {
-            let mut state = self.state_mut();
+            let mut state = self.state_mut().await;
             let user_id = state.user_id_for_connection(request.sender_id)?;
             state.register_project(request.sender_id, user_id)
         };
@@ -310,8 +327,10 @@ impl Server {
     ) -> tide::Result<()> {
         let project = self
             .state_mut()
+            .await
             .unregister_project(request.payload.project_id, request.sender_id)?;
-        self.update_contacts_for_users(project.authorized_user_ids().iter())?;
+        self.update_contacts_for_users(project.authorized_user_ids().iter())
+            .await;
         Ok(())
     }
 
@@ -320,6 +339,7 @@ impl Server {
         request: TypedEnvelope<proto::ShareProject>,
     ) -> tide::Result<proto::Ack> {
         self.state_mut()
+            .await
             .share_project(request.payload.project_id, request.sender_id);
         Ok(proto::Ack {})
     }
@@ -331,13 +351,15 @@ impl Server {
         let project_id = request.payload.project_id;
         let project = self
             .state_mut()
+            .await
             .unshare_project(project_id, request.sender_id)?;
 
         broadcast(request.sender_id, project.connection_ids, |conn_id| {
             self.peer
                 .send(conn_id, proto::UnshareProject { project_id })
-        })?;
-        self.update_contacts_for_users(&project.authorized_user_ids)?;
+        });
+        self.update_contacts_for_users(&project.authorized_user_ids)
+            .await;
         Ok(())
     }
 
@@ -347,9 +369,13 @@ impl Server {
     ) -> tide::Result<proto::JoinProjectResponse> {
         let project_id = request.payload.project_id;
 
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let (response, connection_ids, contact_user_ids) = self
             .state_mut()
+            .await
             .join_project(request.sender_id, user_id, project_id)
             .and_then(|joined| {
                 let share = joined.project.share()?;
@@ -410,8 +436,8 @@ impl Server {
                     }),
                 },
             )
-        })?;
-        self.update_contacts_for_users(&contact_user_ids)?;
+        });
+        self.update_contacts_for_users(&contact_user_ids).await;
         Ok(response)
     }
 
@@ -421,7 +447,10 @@ impl Server {
     ) -> tide::Result<()> {
         let sender_id = request.sender_id;
         let project_id = request.payload.project_id;
-        let worktree = self.state_mut().leave_project(sender_id, project_id)?;
+        let worktree = self
+            .state_mut()
+            .await
+            .leave_project(sender_id, project_id)?;
 
         broadcast(sender_id, worktree.connection_ids, |conn_id| {
             self.peer.send(
@@ -431,8 +460,9 @@ impl Server {
                     peer_id: sender_id.0,
                 },
             )
-        })?;
-        self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+        });
+        self.update_contacts_for_users(&worktree.authorized_user_ids)
+            .await;
 
         Ok(())
     }
@@ -441,7 +471,10 @@ impl Server {
         mut self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterWorktree>,
     ) -> tide::Result<proto::Ack> {
-        let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let host_user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
 
         let mut contact_user_ids = HashSet::default();
         contact_user_ids.insert(host_user_id);
@@ -453,7 +486,7 @@ impl Server {
         let contact_user_ids = contact_user_ids.into_iter().collect::<Vec<_>>();
         let guest_connection_ids;
         {
-            let mut state = self.state_mut();
+            let mut state = self.state_mut().await;
             guest_connection_ids = state
                 .read_project(request.payload.project_id, request.sender_id)?
                 .guest_connection_ids();
@@ -471,8 +504,8 @@ impl Server {
         broadcast(request.sender_id, guest_connection_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
-        self.update_contacts_for_users(&contact_user_ids)?;
+        });
+        self.update_contacts_for_users(&contact_user_ids).await;
         Ok(proto::Ack {})
     }
 
@@ -482,9 +515,11 @@ impl Server {
     ) -> tide::Result<()> {
         let project_id = request.payload.project_id;
         let worktree_id = request.payload.worktree_id;
-        let (worktree, guest_connection_ids) =
-            self.state_mut()
-                .unregister_worktree(project_id, worktree_id, request.sender_id)?;
+        let (worktree, guest_connection_ids) = self.state_mut().await.unregister_worktree(
+            project_id,
+            worktree_id,
+            request.sender_id,
+        )?;
         broadcast(request.sender_id, guest_connection_ids, |conn_id| {
             self.peer.send(
                 conn_id,
@@ -493,8 +528,9 @@ impl Server {
                     worktree_id,
                 },
             )
-        })?;
-        self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+        });
+        self.update_contacts_for_users(&worktree.authorized_user_ids)
+            .await;
         Ok(())
     }
 
@@ -502,7 +538,7 @@ impl Server {
         mut self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<proto::Ack> {
-        let connection_ids = self.state_mut().update_worktree(
+        let connection_ids = self.state_mut().await.update_worktree(
             request.sender_id,
             request.payload.project_id,
             request.payload.worktree_id,
@@ -513,7 +549,7 @@ impl Server {
         broadcast(request.sender_id, connection_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
 
         Ok(proto::Ack {})
     }
@@ -527,7 +563,7 @@ impl Server {
             .summary
             .clone()
             .ok_or_else(|| anyhow!("invalid summary"))?;
-        let receiver_ids = self.state_mut().update_diagnostic_summary(
+        let receiver_ids = self.state_mut().await.update_diagnostic_summary(
             request.payload.project_id,
             request.payload.worktree_id,
             request.sender_id,
@@ -537,7 +573,7 @@ impl Server {
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -545,7 +581,7 @@ impl Server {
         mut self: Arc<Server>,
         request: TypedEnvelope<proto::StartLanguageServer>,
     ) -> tide::Result<()> {
-        let receiver_ids = self.state_mut().start_language_server(
+        let receiver_ids = self.state_mut().await.start_language_server(
             request.payload.project_id,
             request.sender_id,
             request
@@ -557,7 +593,7 @@ impl Server {
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -567,11 +603,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -584,6 +621,7 @@ impl Server {
     {
         let host_connection_id = self
             .state()
+            .await
             .read_project(request.payload.remote_entity_id(), request.sender_id)?
             .host_connection_id;
         Ok(self
@@ -598,6 +636,7 @@ impl Server {
     ) -> tide::Result<proto::BufferSaved> {
         let host = self
             .state()
+            .await
             .read_project(request.payload.project_id, request.sender_id)?
             .host_connection_id;
         let response = self
@@ -607,12 +646,13 @@ impl Server {
 
         let mut guests = self
             .state()
+            .await
             .read_project(request.payload.project_id, request.sender_id)?
             .connection_ids();
         guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
         broadcast(host, guests, |conn_id| {
             self.peer.forward_send(host, conn_id, response.clone())
-        })?;
+        });
 
         Ok(response)
     }
@@ -623,11 +663,12 @@ impl Server {
     ) -> tide::Result<proto::Ack> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(proto::Ack {})
     }
 
@@ -637,11 +678,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -651,11 +693,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -665,11 +708,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -681,6 +725,7 @@ impl Server {
         let follower_id = request.sender_id;
         if !self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, follower_id)?
             .contains(&leader_id)
         {
@@ -703,6 +748,7 @@ impl Server {
         let leader_id = ConnectionId(request.payload.leader_id);
         if !self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?
             .contains(&leader_id)
         {
@@ -719,6 +765,7 @@ impl Server {
     ) -> tide::Result<()> {
         let connection_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         let leader_id = request
             .payload
@@ -743,7 +790,10 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetChannels>,
     ) -> tide::Result<proto::GetChannelsResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
         Ok(proto::GetChannelsResponse {
             channels: channels
@@ -781,33 +831,34 @@ impl Server {
         Ok(proto::GetUsersResponse { users })
     }
 
-    fn update_contacts_for_users<'a>(
+    async fn update_contacts_for_users<'a>(
         self: &Arc<Server>,
         user_ids: impl IntoIterator<Item = &'a UserId>,
-    ) -> anyhow::Result<()> {
-        let mut result = Ok(());
-        let state = self.state();
+    ) {
+        let state = self.state().await;
         for user_id in user_ids {
             let contacts = state.contacts_for_user(*user_id);
             for connection_id in state.connection_ids_for_user(*user_id) {
-                if let Err(error) = self.peer.send(
-                    connection_id,
-                    proto::UpdateContacts {
-                        contacts: contacts.clone(),
-                    },
-                ) {
-                    result = Err(error);
-                }
+                self.peer
+                    .send(
+                        connection_id,
+                        proto::UpdateContacts {
+                            contacts: contacts.clone(),
+                        },
+                    )
+                    .log_err();
             }
         }
-        result
     }
 
     async fn join_channel(
         mut self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<proto::JoinChannelResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -818,7 +869,9 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state_mut().join_channel(request.sender_id, channel_id);
+        self.state_mut()
+            .await
+            .join_channel(request.sender_id, channel_id);
         let messages = self
             .app_state
             .db
@@ -843,7 +896,10 @@ impl Server {
         mut self: Arc<Self>,
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -855,6 +911,7 @@ impl Server {
         }
 
         self.state_mut()
+            .await
             .leave_channel(request.sender_id, channel_id);
 
         Ok(())
@@ -868,7 +925,7 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.state();
+            let state = self.state().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
             connection_ids = state.channel_connection_ids(channel_id)?;
         }
@@ -909,7 +966,7 @@ impl Server {
                     message: Some(message.clone()),
                 },
             )
-        })?;
+        });
         Ok(proto::SendChannelMessageResponse {
             message: Some(message),
         })
@@ -919,7 +976,10 @@ impl Server {
         self: Arc<Self>,
         request: TypedEnvelope<proto::GetChannelMessages>,
     ) -> tide::Result<proto::GetChannelMessagesResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -955,12 +1015,57 @@ impl Server {
         })
     }
 
-    fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
-        self.store.read()
+    async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        let guard = self.store.read().await;
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        StoreReadGuard {
+            guard,
+            _not_send: PhantomData,
+        }
+    }
+
+    async fn state_mut<'a>(self: &'a mut Arc<Self>) -> StoreWriteGuard<'a> {
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        let guard = self.store.write().await;
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        StoreWriteGuard {
+            guard,
+            _not_send: PhantomData,
+        }
+    }
+}
+
+impl<'a> Deref for StoreReadGuard<'a> {
+    type Target = Store;
+
+    fn deref(&self) -> &Self::Target {
+        &*self.guard
+    }
+}
+
+impl<'a> Deref for StoreWriteGuard<'a> {
+    type Target = Store;
+
+    fn deref(&self) -> &Self::Target {
+        &*self.guard
     }
+}
 
-    fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
-        self.store.write()
+impl<'a> DerefMut for StoreWriteGuard<'a> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut *self.guard
+    }
+}
+
+impl<'a> Drop for StoreWriteGuard<'a> {
+    fn drop(&mut self) {
+        #[cfg(test)]
+        self.check_invariants();
     }
 }
 
@@ -976,25 +1081,15 @@ impl Executor for RealExecutor {
     }
 }
 
-fn broadcast<F>(
-    sender_id: ConnectionId,
-    receiver_ids: Vec<ConnectionId>,
-    mut f: F,
-) -> anyhow::Result<()>
+fn broadcast<F>(sender_id: ConnectionId, receiver_ids: Vec<ConnectionId>, mut f: F)
 where
     F: FnMut(ConnectionId) -> anyhow::Result<()>,
 {
-    let mut result = Ok(());
     for receiver_id in receiver_ids {
         if receiver_id != sender_id {
-            if let Err(error) = f(receiver_id) {
-                if result.is_ok() {
-                    result = Err(error);
-                }
-            }
+            f(receiver_id).log_err();
         }
     }
-    result
 }
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -5216,6 +5311,7 @@ mod tests {
                     let contacts = server
                         .store
                         .read()
+                        .await
                         .contacts_for_user(guest.current_user_id(&guest_cx));
                     assert!(!contacts
                         .iter()
@@ -5292,7 +5388,7 @@ mod tests {
                         .unwrap()
                         .read_with(&guest_cx, |project, _| assert!(project.is_read_only()));
                     for user_id in &user_ids {
-                        for contact in server.store.read().contacts_for_user(*user_id) {
+                        for contact in server.store.read().await.contacts_for_user(*user_id) {
                             assert_ne!(
                                 contact.user_id, removed_guest_id.0 as u64,
                                 "removed guest is still a contact of another peer"
@@ -5590,7 +5686,7 @@ mod tests {
         }
 
         async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
-            self.server.store.read()
+            self.server.store.read().await
         }
 
         async fn condition<F>(&mut self, mut predicate: F)
@@ -5598,7 +5694,7 @@ mod tests {
             F: FnMut(&Store) -> bool,
         {
             async_std::future::timeout(Duration::from_millis(500), async {
-                while !(predicate)(&*self.server.store.read()) {
+                while !(predicate)(&*self.server.store.read().await) {
                     self.foreground.start_waiting();
                     self.notifications.next().await;
                     self.foreground.finish_waiting();

crates/server/src/rpc/store.rs 🔗

@@ -130,9 +130,6 @@ impl Store {
             }
         }
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(result)
     }
 
@@ -275,8 +272,6 @@ impl Store {
                 share.worktrees.insert(worktree_id, Default::default());
             }
 
-            #[cfg(test)]
-            self.check_invariants();
             Ok(())
         } else {
             Err(anyhow!("no such project"))?
@@ -313,8 +308,6 @@ impl Store {
                         }
                     }
 
-                    #[cfg(test)]
-                    self.check_invariants();
                     Ok(project)
                 } else {
                     Err(anyhow!("no such project"))?
@@ -359,9 +352,6 @@ impl Store {
             }
         }
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok((worktree, guest_connection_ids))
     }
 
@@ -403,9 +393,6 @@ impl Store {
                 }
             }
 
-            #[cfg(test)]
-            self.check_invariants();
-
             Ok(UnsharedProject {
                 connection_ids,
                 authorized_user_ids,
@@ -491,9 +478,6 @@ impl Store {
         share.active_replica_ids.insert(replica_id);
         share.guests.insert(connection_id, (replica_id, user_id));
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(JoinedProject {
             replica_id,
             project: &self.projects[&project_id],
@@ -526,9 +510,6 @@ impl Store {
         let connection_ids = project.connection_ids();
         let authorized_user_ids = project.authorized_user_ids();
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(LeftProject {
             connection_ids,
             authorized_user_ids,
@@ -556,10 +537,6 @@ impl Store {
             worktree.entries.insert(entry.id, entry.clone());
         }
         let connection_ids = project.connection_ids();
-
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(connection_ids)
     }
 
@@ -633,7 +610,7 @@ impl Store {
     }
 
     #[cfg(test)]
-    fn check_invariants(&self) {
+    pub fn check_invariants(&self) {
         for (connection_id, connection) in &self.connections {
             for project_id in &connection.projects {
                 let project = &self.projects.get(&project_id).unwrap();