Hold the state lock while responding to guest joining a project

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/server/src/rpc.rs | 74 +++++++++++++++++++++++++++++++++++------
1 file changed, 63 insertions(+), 11 deletions(-)

Detailed changes

crates/server/src/rpc.rs 🔗

@@ -94,7 +94,7 @@ impl Server {
             .add_message_handler(Server::unregister_project)
             .add_request_handler(Server::share_project)
             .add_message_handler(Server::unshare_project)
-            .add_request_handler(Server::join_project)
+            .add_sync_request_handler(Server::join_project)
             .add_message_handler(Server::leave_project)
             .add_request_handler(Server::register_worktree)
             .add_message_handler(Server::unregister_worktree)
@@ -186,6 +186,42 @@ impl Server {
         })
     }
 
+    /// Handle a request while holding a lock to the store. This is useful when we're registering
+    /// a connection but we want to respond on the connection before anybody else can send on it.
+    fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
+    where
+        F: 'static
+            + Send
+            + Sync
+            + Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> tide::Result<M::Response>,
+        M: RequestMessage,
+    {
+        let handler = Arc::new(handler);
+        self.add_message_handler(move |server, envelope| {
+            let receipt = envelope.receipt();
+            let handler = handler.clone();
+            async move {
+                let mut store = server.store.write().await;
+                let response = (handler)(server.clone(), &mut *store, envelope);
+                match response {
+                    Ok(response) => {
+                        server.peer.respond(receipt, response)?;
+                        Ok(())
+                    }
+                    Err(error) => {
+                        server.peer.respond_with_error(
+                            receipt,
+                            proto::Error {
+                                message: error.to_string(),
+                            },
+                        )?;
+                        Err(error)
+                    }
+                }
+            }
+        })
+    }
+
     pub fn handle_connection<E: Executor>(
         self: &Arc<Self>,
         connection: Connection,
@@ -363,19 +399,15 @@ impl Server {
         Ok(())
     }
 
-    async fn join_project(
-        mut self: Arc<Server>,
+    fn join_project(
+        self: Arc<Server>,
+        state: &mut Store,
         request: TypedEnvelope<proto::JoinProject>,
     ) -> tide::Result<proto::JoinProjectResponse> {
         let project_id = request.payload.project_id;
 
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
-        let (response, connection_ids, contact_user_ids) = self
-            .state_mut()
-            .await
+        let user_id = state.user_id_for_connection(request.sender_id)?;
+        let (response, connection_ids, contact_user_ids) = state
             .join_project(request.sender_id, user_id, project_id)
             .and_then(|joined| {
                 let share = joined.project.share()?;
@@ -437,7 +469,7 @@ impl Server {
                 },
             )
         });
-        self.update_contacts_for_users(&contact_user_ids).await;
+        self.update_contacts_for_users_sync(state, &contact_user_ids);
         Ok(response)
     }
 
@@ -851,6 +883,26 @@ impl Server {
         }
     }
 
+    fn update_contacts_for_users_sync<'a>(
+        self: &Arc<Self>,
+        state: &Store,
+        user_ids: impl IntoIterator<Item = &'a UserId>,
+    ) {
+        for user_id in user_ids {
+            let contacts = state.contacts_for_user(*user_id);
+            for connection_id in state.connection_ids_for_user(*user_id) {
+                self.peer
+                    .send(
+                        connection_id,
+                        proto::UpdateContacts {
+                            contacts: contacts.clone(),
+                        },
+                    )
+                    .log_err();
+            }
+        }
+    }
+
     async fn join_channel(
         mut self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,