Make `Server::update_contacts_for_users` always synchronous

Antonio Scandurra created

Change summary

crates/server/src/rpc.rs | 149 +++++++++++++++--------------------------
1 file changed, 55 insertions(+), 94 deletions(-)

Detailed changes

crates/server/src/rpc.rs 🔗

@@ -249,10 +249,11 @@ impl Server {
                 let _ = send_connection_id.send(connection_id).await;
             }
 
-            this.state_mut()
-                .await
-                .add_connection(connection_id, user_id);
-            this.update_contacts_for_users(&[user_id]).await;
+            {
+                let mut state = this.state_mut().await;
+                state.add_connection(connection_id, user_id);
+                this.update_contacts_for_users(&*state, &[user_id]);
+            }
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
@@ -309,7 +310,8 @@ impl Server {
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id);
-        let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
+        let mut state = self.state_mut().await;
+        let removed_connection = state.remove_connection(connection_id)?;
 
         for (project_id, project) in removed_connection.hosted_projects {
             if let Some(share) = project.share {
@@ -336,8 +338,7 @@ impl Server {
             });
         }
 
-        self.update_contacts_for_users(removed_connection.contact_ids.iter())
-            .await;
+        self.update_contacts_for_users(&*state, removed_connection.contact_ids.iter());
         Ok(())
     }
 
@@ -346,7 +347,7 @@ impl Server {
     }
 
     async fn register_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterProject>,
     ) -> tide::Result<proto::RegisterProjectResponse> {
         let project_id = {
@@ -358,20 +359,17 @@ impl Server {
     }
 
     async fn unregister_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnregisterProject>,
     ) -> tide::Result<()> {
-        let project = self
-            .state_mut()
-            .await
-            .unregister_project(request.payload.project_id, request.sender_id)?;
-        self.update_contacts_for_users(project.authorized_user_ids().iter())
-            .await;
+        let mut state = self.state_mut().await;
+        let project = state.unregister_project(request.payload.project_id, request.sender_id)?;
+        self.update_contacts_for_users(&*state, &project.authorized_user_ids());
         Ok(())
     }
 
     async fn share_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::ShareProject>,
     ) -> tide::Result<proto::Ack> {
         self.state_mut()
@@ -381,21 +379,17 @@ impl Server {
     }
 
     async fn unshare_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnshareProject>,
     ) -> tide::Result<()> {
         let project_id = request.payload.project_id;
-        let project = self
-            .state_mut()
-            .await
-            .unshare_project(project_id, request.sender_id)?;
-
+        let mut state = self.state_mut().await;
+        let project = state.unshare_project(project_id, request.sender_id)?;
         broadcast(request.sender_id, project.connection_ids, |conn_id| {
             self.peer
                 .send(conn_id, proto::UnshareProject { project_id })
         });
-        self.update_contacts_for_users(&project.authorized_user_ids)
-            .await;
+        self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
         Ok(())
     }
 
@@ -469,21 +463,18 @@ impl Server {
                 },
             )
         });
-        self.update_contacts_for_users_sync(state, &contact_user_ids);
+        self.update_contacts_for_users(state, &contact_user_ids);
         Ok(response)
     }
 
     async fn leave_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::LeaveProject>,
     ) -> tide::Result<()> {
         let sender_id = request.sender_id;
         let project_id = request.payload.project_id;
-        let worktree = self
-            .state_mut()
-            .await
-            .leave_project(sender_id, project_id)?;
-
+        let mut state = self.state_mut().await;
+        let worktree = state.leave_project(sender_id, project_id)?;
         broadcast(sender_id, worktree.connection_ids, |conn_id| {
             self.peer.send(
                 conn_id,
@@ -493,65 +484,56 @@ impl Server {
                 },
             )
         });
-        self.update_contacts_for_users(&worktree.authorized_user_ids)
-            .await;
-
+        self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
         Ok(())
     }
 
     async fn register_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterWorktree>,
     ) -> tide::Result<proto::Ack> {
-        let host_user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
-
         let mut contact_user_ids = HashSet::default();
-        contact_user_ids.insert(host_user_id);
         for github_login in &request.payload.authorized_logins {
             let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
             contact_user_ids.insert(contact_user_id);
         }
 
+        let mut state = self.state_mut().await;
+        let host_user_id = state.user_id_for_connection(request.sender_id)?;
+        contact_user_ids.insert(host_user_id);
+
         let contact_user_ids = contact_user_ids.into_iter().collect::<Vec<_>>();
-        let guest_connection_ids;
-        {
-            let mut state = self.state_mut().await;
-            guest_connection_ids = state
-                .read_project(request.payload.project_id, request.sender_id)?
-                .guest_connection_ids();
-            state.register_worktree(
-                request.payload.project_id,
-                request.payload.worktree_id,
-                request.sender_id,
-                Worktree {
-                    authorized_user_ids: contact_user_ids.clone(),
-                    root_name: request.payload.root_name.clone(),
-                    visible: request.payload.visible,
-                },
-            )?;
-        }
+        let guest_connection_ids = state
+            .read_project(request.payload.project_id, request.sender_id)?
+            .guest_connection_ids();
+        state.register_worktree(
+            request.payload.project_id,
+            request.payload.worktree_id,
+            request.sender_id,
+            Worktree {
+                authorized_user_ids: contact_user_ids.clone(),
+                root_name: request.payload.root_name.clone(),
+                visible: request.payload.visible,
+            },
+        )?;
+
         broadcast(request.sender_id, guest_connection_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
         });
-        self.update_contacts_for_users(&contact_user_ids).await;
+        self.update_contacts_for_users(&*state, &contact_user_ids);
         Ok(proto::Ack {})
     }
 
     async fn unregister_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnregisterWorktree>,
     ) -> tide::Result<()> {
         let project_id = request.payload.project_id;
         let worktree_id = request.payload.worktree_id;
-        let (worktree, guest_connection_ids) = self.state_mut().await.unregister_worktree(
-            project_id,
-            worktree_id,
-            request.sender_id,
-        )?;
+        let mut state = self.state_mut().await;
+        let (worktree, guest_connection_ids) =
+            state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
         broadcast(request.sender_id, guest_connection_ids, |conn_id| {
             self.peer.send(
                 conn_id,
@@ -561,13 +543,12 @@ impl Server {
                 },
             )
         });
-        self.update_contacts_for_users(&worktree.authorized_user_ids)
-            .await;
+        self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
         Ok(())
     }
 
     async fn update_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<proto::Ack> {
         let connection_ids = self.state_mut().await.update_worktree(
@@ -587,7 +568,7 @@ impl Server {
     }
 
     async fn update_diagnostic_summary(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
     ) -> tide::Result<()> {
         let summary = request
@@ -610,7 +591,7 @@ impl Server {
     }
 
     async fn start_language_server(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::StartLanguageServer>,
     ) -> tide::Result<()> {
         let receiver_ids = self.state_mut().await.start_language_server(
@@ -863,27 +844,7 @@ impl Server {
         Ok(proto::GetUsersResponse { users })
     }
 
-    async fn update_contacts_for_users<'a>(
-        self: &Arc<Server>,
-        user_ids: impl IntoIterator<Item = &'a UserId>,
-    ) {
-        let state = self.state().await;
-        for user_id in user_ids {
-            let contacts = state.contacts_for_user(*user_id);
-            for connection_id in state.connection_ids_for_user(*user_id) {
-                self.peer
-                    .send(
-                        connection_id,
-                        proto::UpdateContacts {
-                            contacts: contacts.clone(),
-                        },
-                    )
-                    .log_err();
-            }
-        }
-    }
-
-    fn update_contacts_for_users_sync<'a>(
+    fn update_contacts_for_users<'a>(
         self: &Arc<Self>,
         state: &Store,
         user_ids: impl IntoIterator<Item = &'a UserId>,
@@ -904,7 +865,7 @@ impl Server {
     }
 
     async fn join_channel(
-        mut self: Arc<Self>,
+        self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<proto::JoinChannelResponse> {
         let user_id = self
@@ -945,7 +906,7 @@ impl Server {
     }
 
     async fn leave_channel(
-        mut self: Arc<Self>,
+        self: Arc<Self>,
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
         let user_id = self
@@ -1079,7 +1040,7 @@ impl Server {
         }
     }
 
-    async fn state_mut<'a>(self: &'a mut Arc<Self>) -> StoreWriteGuard<'a> {
+    async fn state_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
         #[cfg(test)]
         async_std::task::yield_now().await;
         let guard = self.store.write().await;