Always use the database to retrieve collaborators for a project

Antonio Scandurra created

Change summary

crates/collab/src/db.rs        |  58 ++++++++++++
crates/collab/src/rpc.rs       | 174 +++++++++++++++++++++--------------
crates/collab/src/rpc/store.rs |  28 -----
3 files changed, 160 insertions(+), 100 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -1886,6 +1886,64 @@ where
         .await
     }
 
+    pub async fn project_collaborators(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<Vec<ProjectCollaborator>> {
+        self.transact(|mut tx| async move {
+            let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
+                "
+                SELECT *
+                FROM project_collaborators
+                WHERE project_id = $1
+                ",
+            )
+            .bind(project_id)
+            .fetch_all(&mut tx)
+            .await?;
+
+            if collaborators
+                .iter()
+                .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
+            {
+                Ok(collaborators)
+            } else {
+                Err(anyhow!("no such project"))?
+            }
+        })
+        .await
+    }
+
+    pub async fn project_connection_ids(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<HashSet<ConnectionId>> {
+        self.transact(|mut tx| async move {
+            let connection_ids = sqlx::query_scalar::<_, i32>(
+                "
+                SELECT connection_id
+                FROM project_collaborators
+                WHERE project_id = $1
+                ",
+            )
+            .bind(project_id)
+            .fetch_all(&mut tx)
+            .await?;
+
+            if connection_ids.contains(&(connection_id.0 as i32)) {
+                Ok(connection_ids
+                    .into_iter()
+                    .map(|connection_id| ConnectionId(connection_id as u32))
+                    .collect())
+            } else {
+                Err(anyhow!("no such project"))?
+            }
+        })
+        .await
+    }
+
     pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
         todo!()
         // test_support!(self, {

crates/collab/src/rpc.rs 🔗

@@ -1187,13 +1187,15 @@ impl Server {
         self: Arc<Server>,
         request: Message<proto::UpdateLanguageServer>,
     ) -> Result<()> {
-        let receiver_ids = self.store().await.project_connection_ids(
-            ProjectId::from_proto(request.payload.project_id),
-            request.sender_connection_id,
-        )?;
+        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1214,25 +1216,25 @@ impl Server {
         T: EntityMessage + RequestMessage,
     {
         let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
-        let host_connection_id = self
-            .store()
-            .await
-            .read_project(project_id, request.sender_connection_id)?
-            .host_connection_id;
+        let collaborators = self
+            .app_state
+            .db
+            .project_collaborators(project_id, request.sender_connection_id)
+            .await?;
+        let host = collaborators
+            .iter()
+            .find(|collaborator| collaborator.is_host)
+            .ok_or_else(|| anyhow!("host not found"))?;
+
         let payload = self
             .peer
             .forward_request(
                 request.sender_connection_id,
-                host_connection_id,
+                ConnectionId(host.connection_id as u32),
                 request.payload,
             )
             .await?;
 
-        // Ensure project still exists by the time we get the response from the host.
-        self.store()
-            .await
-            .read_project(project_id, request.sender_connection_id)?;
-
         response.send(payload)?;
         Ok(())
     }
@@ -1243,25 +1245,39 @@ impl Server {
         response: Response<proto::SaveBuffer>,
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
-        let host = self
-            .store()
-            .await
-            .read_project(project_id, request.sender_connection_id)?
-            .host_connection_id;
+        let collaborators = self
+            .app_state
+            .db
+            .project_collaborators(project_id, request.sender_connection_id)
+            .await?;
+        let host = collaborators
+            .into_iter()
+            .find(|collaborator| collaborator.is_host)
+            .ok_or_else(|| anyhow!("host not found"))?;
+        let host_connection_id = ConnectionId(host.connection_id as u32);
         let response_payload = self
             .peer
-            .forward_request(request.sender_connection_id, host, request.payload.clone())
+            .forward_request(
+                request.sender_connection_id,
+                host_connection_id,
+                request.payload.clone(),
+            )
             .await?;
 
-        let mut guests = self
-            .store()
-            .await
-            .read_project(project_id, request.sender_connection_id)?
-            .connection_ids();
-        guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
-        broadcast(host, guests, |conn_id| {
+        let mut collaborators = self
+            .app_state
+            .db
+            .project_collaborators(project_id, request.sender_connection_id)
+            .await?;
+        collaborators.retain(|collaborator| {
+            collaborator.connection_id != request.sender_connection_id.0 as i32
+        });
+        let project_connection_ids = collaborators
+            .into_iter()
+            .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
+        broadcast(host_connection_id, project_connection_ids, |conn_id| {
             self.peer
-                .forward_send(host, conn_id, response_payload.clone())
+                .forward_send(host_connection_id, conn_id, response_payload.clone())
         });
         response.send(response_payload)?;
         Ok(())
@@ -1285,14 +1301,15 @@ impl Server {
         response: Response<proto::UpdateBuffer>,
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
-        let receiver_ids = {
-            let store = self.store().await;
-            store.project_connection_ids(project_id, request.sender_connection_id)?
-        };
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
 
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1309,13 +1326,16 @@ impl Server {
         self: Arc<Server>,
         request: Message<proto::UpdateBufferFile>,
     ) -> Result<()> {
-        let receiver_ids = self.store().await.project_connection_ids(
-            ProjectId::from_proto(request.payload.project_id),
-            request.sender_connection_id,
-        )?;
+        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
+
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1331,13 +1351,15 @@ impl Server {
         self: Arc<Server>,
         request: Message<proto::BufferReloaded>,
     ) -> Result<()> {
-        let receiver_ids = self.store().await.project_connection_ids(
-            ProjectId::from_proto(request.payload.project_id),
-            request.sender_connection_id,
-        )?;
+        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1350,13 +1372,15 @@ impl Server {
     }
 
     async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
-        let receiver_ids = self.store().await.project_connection_ids(
-            ProjectId::from_proto(request.payload.project_id),
-            request.sender_connection_id,
-        )?;
+        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1376,14 +1400,14 @@ impl Server {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let leader_id = ConnectionId(request.payload.leader_id);
         let follower_id = request.sender_connection_id;
-        {
-            let store = self.store().await;
-            if !store
-                .project_connection_ids(project_id, follower_id)?
-                .contains(&leader_id)
-            {
-                Err(anyhow!("no such peer"))?;
-            }
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
+
+        if !project_connection_ids.contains(&leader_id) {
+            Err(anyhow!("no such peer"))?;
         }
 
         let mut response_payload = self
@@ -1400,11 +1424,12 @@ impl Server {
     async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
         let leader_id = ConnectionId(request.payload.leader_id);
-        let store = self.store().await;
-        if !store
-            .project_connection_ids(project_id, request.sender_connection_id)?
-            .contains(&leader_id)
-        {
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
+        if !project_connection_ids.contains(&leader_id) {
             Err(anyhow!("no such peer"))?;
         }
         self.peer
@@ -1417,9 +1442,12 @@ impl Server {
         request: Message<proto::UpdateFollowers>,
     ) -> Result<()> {
         let project_id = ProjectId::from_proto(request.payload.project_id);
-        let store = self.store().await;
-        let connection_ids =
-            store.project_connection_ids(project_id, request.sender_connection_id)?;
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
+
         let leader_id = request
             .payload
             .variant
@@ -1431,7 +1459,7 @@ impl Server {
             });
         for follower_id in &request.payload.follower_ids {
             let follower_id = ConnectionId(*follower_id);
-            if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
+            if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
                 self.peer.forward_send(
                     request.sender_connection_id,
                     follower_id,
@@ -1629,13 +1657,15 @@ impl Server {
         self: Arc<Server>,
         request: Message<proto::UpdateDiffBase>,
     ) -> Result<()> {
-        let receiver_ids = self.store().await.project_connection_ids(
-            ProjectId::from_proto(request.payload.project_id),
-            request.sender_connection_id,
-        )?;
+        let project_id = ProjectId::from_proto(request.payload.project_id);
+        let project_connection_ids = self
+            .app_state
+            .db
+            .project_connection_ids(project_id, request.sender_connection_id)
+            .await?;
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            project_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,

crates/collab/src/rpc/store.rs 🔗

@@ -325,34 +325,6 @@ impl Store {
         })
     }
 
-    pub fn project_connection_ids(
-        &self,
-        project_id: ProjectId,
-        acting_connection_id: ConnectionId,
-    ) -> Result<Vec<ConnectionId>> {
-        Ok(self
-            .read_project(project_id, acting_connection_id)?
-            .connection_ids())
-    }
-
-    pub fn read_project(
-        &self,
-        project_id: ProjectId,
-        connection_id: ConnectionId,
-    ) -> Result<&Project> {
-        let project = self
-            .projects
-            .get(&project_id)
-            .ok_or_else(|| anyhow!("no such project"))?;
-        if project.host_connection_id == connection_id
-            || project.guests.contains_key(&connection_id)
-        {
-            Ok(project)
-        } else {
-            Err(anyhow!("no such project"))?
-        }
-    }
-
     #[cfg(test)]
     pub fn check_invariants(&self) {
         for (connection_id, connection) in &self.connections {