Ensure worktrees have been sent before responding with definitions

Antonio Scandurra created

Changing the frequency at which we update worktrees highlighted a
problem in the randomized tests that was causing clients to receive
a definition to a worktree *before* observing the registration of
the worktree itself. This was most likely caused by #1224 because
the scenario that pull request enabled was the following:

- Guest requests a definition pointing to a non-existant worktree
- Server forwards the request to the host
- Host sends an `UpdateProject` message
- Host sends a response to the definition request
- Server observes the `UpdateProject` message and tries to acquire
  the store
- Given that we're waiting, the server goes ahead to process the
  response for the definition request, responding *before*
  `UpdateProject` is forwarded
- Server finally forwards `UpdateProject` to the guest

This commit ensures that, after forwarding a project request and getting a
response, we acquire a lock to the store again to ensure the project still
exists. This has the effect of ordering the forwarded request *after* any
message that was received prior to the response and for which we are still
waiting to acquire a lock to the store.

Change summary

crates/collab/src/integration_tests.rs |  19 +---
crates/collab/src/rpc.rs               | 106 +++++++++++----------------
2 files changed, 49 insertions(+), 76 deletions(-)

Detailed changes

crates/collab/src/integration_tests.rs 🔗

@@ -50,7 +50,6 @@ use std::{
     time::Duration,
 };
 use theme::ThemeRegistry;
-use tokio::sync::RwLockReadGuard;
 use workspace::{Item, SplitDirection, ToggleFollow, Workspace};
 
 #[ctor::ctor]
@@ -589,7 +588,7 @@ async fn test_offline_projects(
     deterministic.run_until_parked();
     assert!(server
         .store
-        .read()
+        .lock()
         .await
         .project_metadata_for_user(user_a)
         .is_empty());
@@ -620,7 +619,7 @@ async fn test_offline_projects(
     cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
     assert!(server
         .store
-        .read()
+        .lock()
         .await
         .project_metadata_for_user(user_a)
         .is_empty());
@@ -1446,7 +1445,7 @@ async fn test_collaborating_with_diagnostics(
     // Wait for server to see the diagnostics update.
     deterministic.run_until_parked();
     {
-        let store = server.store.read().await;
+        let store = server.store.lock().await;
         let project = store.project(ProjectId::from_proto(project_id)).unwrap();
         let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap();
         assert!(!worktree.diagnostic_summaries.is_empty());
@@ -3172,7 +3171,7 @@ async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
 
     assert_eq!(
         server
-            .state()
+            .store()
             .await
             .channel(channel_id)
             .unwrap()
@@ -4660,7 +4659,7 @@ async fn test_random_collaboration(
                     .unwrap();
                 let contacts = server
                     .store
-                    .read()
+                    .lock()
                     .await
                     .build_initial_contacts_update(contacts)
                     .contacts;
@@ -4745,7 +4744,7 @@ async fn test_random_collaboration(
                     let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
                     let contacts = server
                         .store
-                        .read()
+                        .lock()
                         .await
                         .build_initial_contacts_update(contacts)
                         .contacts;
@@ -5077,10 +5076,6 @@ impl TestServer {
         })
     }
 
-    async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
-        self.server.store.read().await
-    }
-
     async fn condition<F>(&mut self, mut predicate: F)
     where
         F: FnMut(&Store) -> bool,
@@ -5089,7 +5084,7 @@ impl TestServer {
             self.foreground.parking_forbidden(),
             "you must call forbid_parking to use server conditions so we don't block indefinitely"
         );
-        while !(predicate)(&*self.server.store.read().await) {
+        while !(predicate)(&*self.server.store.lock().await) {
             self.foreground.start_waiting();
             self.notifications.next().await;
             self.foreground.finish_waiting();

crates/collab/src/rpc.rs 🔗

@@ -51,7 +51,7 @@ use std::{
 };
 use time::OffsetDateTime;
 use tokio::{
-    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+    sync::{Mutex, MutexGuard},
     time::Sleep,
 };
 use tower::ServiceBuilder;
@@ -97,7 +97,7 @@ impl<R: RequestMessage> Response<R> {
 
 pub struct Server {
     peer: Arc<Peer>,
-    pub(crate) store: RwLock<Store>,
+    pub(crate) store: Mutex<Store>,
     app_state: Arc<AppState>,
     handlers: HashMap<TypeId, MessageHandler>,
     notifications: Option<mpsc::UnboundedSender<()>>,
@@ -115,13 +115,8 @@ pub struct RealExecutor;
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
 const MAX_MESSAGE_LEN: usize = 1024;
 
-struct StoreReadGuard<'a> {
-    guard: RwLockReadGuard<'a, Store>,
-    _not_send: PhantomData<Rc<()>>,
-}
-
-struct StoreWriteGuard<'a> {
-    guard: RwLockWriteGuard<'a, Store>,
+pub(crate) struct StoreGuard<'a> {
+    guard: MutexGuard<'a, Store>,
     _not_send: PhantomData<Rc<()>>,
 }
 
@@ -129,7 +124,7 @@ struct StoreWriteGuard<'a> {
 pub struct ServerSnapshot<'a> {
     peer: &'a Peer,
     #[serde(serialize_with = "serialize_deref")]
-    store: RwLockReadGuard<'a, Store>,
+    store: StoreGuard<'a>,
 }
 
 pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
@@ -384,7 +379,7 @@ impl Server {
             ).await?;
 
             {
-                let mut store = this.store_mut().await;
+                let mut store = this.store().await;
                 store.add_connection(connection_id, user_id, user.admin);
                 this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
 
@@ -471,7 +466,7 @@ impl Server {
         let mut projects_to_unregister = Vec::new();
         let removed_user_id;
         {
-            let mut store = self.store_mut().await;
+            let mut store = self.store().await;
             let removed_connection = store.remove_connection(connection_id)?;
 
             for (project_id, project) in removed_connection.hosted_projects {
@@ -605,7 +600,7 @@ impl Server {
             .await
             .user_id_for_connection(request.sender_id)?;
         let project_id = self.app_state.db.register_project(user_id).await?;
-        self.store_mut()
+        self.store()
             .await
             .register_project(request.sender_id, project_id)?;
 
@@ -623,7 +618,7 @@ impl Server {
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let (user_id, project) = {
-            let mut state = self.store_mut().await;
+            let mut state = self.store().await;
             let project = state.unregister_project(project_id, request.sender_id)?;
             (state.user_id_for_connection(request.sender_id)?, project)
         };
@@ -725,7 +720,7 @@ impl Server {
             return Err(anyhow!("no such project"))?;
         }
 
-        self.store_mut().await.request_join_project(
+        self.store().await.request_join_project(
             guest_user_id,
             project_id,
             response.into_receipt(),
@@ -747,7 +742,7 @@ impl Server {
         let host_user_id;
 
         {
-            let mut state = self.store_mut().await;
+            let mut state = self.store().await;
             let project_id = ProjectId::from_proto(request.payload.project_id);
             let project = state.project(project_id)?;
             if project.host_connection_id != request.sender_id {
@@ -897,7 +892,7 @@ impl Server {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let project;
         {
-            let mut store = self.store_mut().await;
+            let mut store = self.store().await;
             project = store.leave_project(sender_id, project_id)?;
             tracing::info!(
                 %project_id,
@@ -948,7 +943,7 @@ impl Server {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let user_id;
         {
-            let mut state = self.store_mut().await;
+            let mut state = self.store().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
             let guest_connection_ids = state
                 .read_project(project_id, request.sender_id)?
@@ -967,7 +962,7 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterProjectActivity>,
     ) -> Result<()> {
-        self.store_mut().await.register_project_activity(
+        self.store().await.register_project_activity(
             ProjectId::from_proto(request.payload.project_id),
             request.sender_id,
         )?;
@@ -982,7 +977,7 @@ impl Server {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let worktree_id = request.payload.worktree_id;
         let (connection_ids, metadata_changed, extension_counts) = {
-            let mut store = self.store_mut().await;
+            let mut store = self.store().await;
             let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
                 request.sender_id,
                 project_id,
@@ -1024,7 +1019,7 @@ impl Server {
             .summary
             .clone()
             .ok_or_else(|| anyhow!("invalid summary"))?;
-        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
+        let receiver_ids = self.store().await.update_diagnostic_summary(
             ProjectId::from_proto(request.payload.project_id),
             request.payload.worktree_id,
             request.sender_id,
@@ -1042,7 +1037,7 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::StartLanguageServer>,
     ) -> Result<()> {
-        let receiver_ids = self.store_mut().await.start_language_server(
+        let receiver_ids = self.store().await.start_language_server(
             ProjectId::from_proto(request.payload.project_id),
             request.sender_id,
             request
@@ -1081,20 +1076,23 @@ impl Server {
     where
         T: EntityMessage + RequestMessage,
     {
+        let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
         let host_connection_id = self
             .store()
             .await
-            .read_project(
-                ProjectId::from_proto(request.payload.remote_entity_id()),
-                request.sender_id,
-            )?
+            .read_project(project_id, request.sender_id)?
             .host_connection_id;
+        let payload = self
+            .peer
+            .forward_request(request.sender_id, host_connection_id, request.payload)
+            .await?;
 
-        response.send(
-            self.peer
-                .forward_request(request.sender_id, host_connection_id, request.payload)
-                .await?,
-        )?;
+        // Ensure project still exists by the time we get the response from the host.
+        self.store()
+            .await
+            .read_project(project_id, request.sender_id)?;
+
+        response.send(payload)?;
         Ok(())
     }
 
@@ -1135,7 +1133,7 @@ impl Server {
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let receiver_ids = {
-            let mut store = self.store_mut().await;
+            let mut store = self.store().await;
             store.register_project_activity(project_id, request.sender_id)?;
             store.project_connection_ids(project_id, request.sender_id)?
         };
@@ -1202,7 +1200,7 @@ impl Server {
         let leader_id = ConnectionId(request.payload.leader_id);
         let follower_id = request.sender_id;
         {
-            let mut store = self.store_mut().await;
+            let mut store = self.store().await;
             if !store
                 .project_connection_ids(project_id, follower_id)?
                 .contains(&leader_id)
@@ -1227,7 +1225,7 @@ impl Server {
     async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let leader_id = ConnectionId(request.payload.leader_id);
-        let mut store = self.store_mut().await;
+        let mut store = self.store().await;
         if !store
             .project_connection_ids(project_id, request.sender_id)?
             .contains(&leader_id)
@@ -1245,7 +1243,7 @@ impl Server {
         request: TypedEnvelope<proto::UpdateFollowers>,
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
-        let mut store = self.store_mut().await;
+        let mut store = self.store().await;
         store.register_project_activity(project_id, request.sender_id)?;
         let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
         let leader_id = request
@@ -1503,7 +1501,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.store_mut()
+        self.store()
             .await
             .join_channel(request.sender_id, channel_id);
         let messages = self
@@ -1545,7 +1543,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.store_mut()
+        self.store()
             .await
             .leave_channel(request.sender_id, channel_id);
 
@@ -1653,25 +1651,13 @@ impl Server {
         Ok(())
     }
 
-    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+    pub(crate) async fn store<'a>(&'a self) -> StoreGuard<'a> {
         #[cfg(test)]
         tokio::task::yield_now().await;
-        let guard = self.store.read().await;
+        let guard = self.store.lock().await;
         #[cfg(test)]
         tokio::task::yield_now().await;
-        StoreReadGuard {
-            guard,
-            _not_send: PhantomData,
-        }
-    }
-
-    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
-        #[cfg(test)]
-        tokio::task::yield_now().await;
-        let guard = self.store.write().await;
-        #[cfg(test)]
-        tokio::task::yield_now().await;
-        StoreWriteGuard {
+        StoreGuard {
             guard,
             _not_send: PhantomData,
         }
@@ -1679,21 +1665,13 @@ impl Server {
 
     pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
         ServerSnapshot {
-            store: self.store.read().await,
+            store: self.store().await,
             peer: &self.peer,
         }
     }
 }
 
-impl<'a> Deref for StoreReadGuard<'a> {
-    type Target = Store;
-
-    fn deref(&self) -> &Self::Target {
-        &*self.guard
-    }
-}
-
-impl<'a> Deref for StoreWriteGuard<'a> {
+impl<'a> Deref for StoreGuard<'a> {
     type Target = Store;
 
     fn deref(&self) -> &Self::Target {
@@ -1701,13 +1679,13 @@ impl<'a> Deref for StoreWriteGuard<'a> {
     }
 }
 
-impl<'a> DerefMut for StoreWriteGuard<'a> {
+impl<'a> DerefMut for StoreGuard<'a> {
     fn deref_mut(&mut self) -> &mut Self::Target {
         &mut *self.guard
     }
 }
 
-impl<'a> Drop for StoreWriteGuard<'a> {
+impl<'a> Drop for StoreGuard<'a> {
     fn drop(&mut self) {
         #[cfg(test)]
         self.check_invariants();