Return project collaborators and connection IDs in a `RoomGuard`

Antonio Scandurra created

Change summary

crates/collab/src/db.rs  | 20 +++++++---
crates/collab/src/rpc.rs | 81 ++++++++++++++++++++++-------------------
2 files changed, 57 insertions(+), 44 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -1981,8 +1981,12 @@ impl Database {
         &self,
         project_id: ProjectId,
         connection_id: ConnectionId,
-    ) -> Result<Vec<project_collaborator::Model>> {
-        self.transaction(|tx| async move {
+    ) -> Result<RoomGuard<Vec<project_collaborator::Model>>> {
+        self.room_transaction(|tx| async move {
+            let project = project::Entity::find_by_id(project_id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such project"))?;
             let collaborators = project_collaborator::Entity::find()
                 .filter(project_collaborator::Column::ProjectId.eq(project_id))
                 .all(&*tx)
@@ -1992,7 +1996,7 @@ impl Database {
                 .iter()
                 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
             {
-                Ok(collaborators)
+                Ok((project.room_id, collaborators))
             } else {
                 Err(anyhow!("no such project"))?
             }
@@ -2004,13 +2008,17 @@ impl Database {
         &self,
         project_id: ProjectId,
         connection_id: ConnectionId,
-    ) -> Result<HashSet<ConnectionId>> {
-        self.transaction(|tx| async move {
+    ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
+        self.room_transaction(|tx| async move {
             #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
             enum QueryAs {
                 ConnectionId,
             }
 
+            let project = project::Entity::find_by_id(project_id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such project"))?;
             let mut db_connection_ids = project_collaborator::Entity::find()
                 .select_only()
                 .column_as(
@@ -2028,7 +2036,7 @@ impl Database {
             }
 
             if connection_ids.contains(&connection_id) {
-                Ok(connection_ids)
+                Ok((project.room_id, connection_ids))
             } else {
                 Err(anyhow!("no such project"))?
             }

crates/collab/src/rpc.rs 🔗

@@ -1245,7 +1245,7 @@ async fn update_language_server(
         .await?;
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer
@@ -1264,23 +1264,24 @@ where
     T: EntityMessage + RequestMessage,
 {
     let project_id = ProjectId::from_proto(request.remote_entity_id());
-    let collaborators = session
-        .db()
-        .await
-        .project_collaborators(project_id, session.connection_id)
-        .await?;
-    let host = collaborators
-        .iter()
-        .find(|collaborator| collaborator.is_host)
-        .ok_or_else(|| anyhow!("host not found"))?;
+    let host_connection_id = {
+        let collaborators = session
+            .db()
+            .await
+            .project_collaborators(project_id, session.connection_id)
+            .await?;
+        ConnectionId(
+            collaborators
+                .iter()
+                .find(|collaborator| collaborator.is_host)
+                .ok_or_else(|| anyhow!("host not found"))?
+                .connection_id as u32,
+        )
+    };
 
     let payload = session
         .peer
-        .forward_request(
-            session.connection_id,
-            ConnectionId(host.connection_id as u32),
-            request,
-        )
+        .forward_request(session.connection_id, host_connection_id, request)
         .await?;
 
     response.send(payload)?;
@@ -1293,16 +1294,18 @@ async fn save_buffer(
     session: Session,
 ) -> Result<()> {
     let project_id = ProjectId::from_proto(request.project_id);
-    let collaborators = session
-        .db()
-        .await
-        .project_collaborators(project_id, session.connection_id)
-        .await?;
-    let host = collaborators
-        .into_iter()
-        .find(|collaborator| collaborator.is_host)
-        .ok_or_else(|| anyhow!("host not found"))?;
-    let host_connection_id = ConnectionId(host.connection_id as u32);
+    let host_connection_id = {
+        let collaborators = session
+            .db()
+            .await
+            .project_collaborators(project_id, session.connection_id)
+            .await?;
+        let host = collaborators
+            .iter()
+            .find(|collaborator| collaborator.is_host)
+            .ok_or_else(|| anyhow!("host not found"))?;
+        ConnectionId(host.connection_id as u32)
+    };
     let response_payload = session
         .peer
         .forward_request(session.connection_id, host_connection_id, request.clone())
@@ -1316,7 +1319,7 @@ async fn save_buffer(
     collaborators
         .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
     let project_connection_ids = collaborators
-        .into_iter()
+        .iter()
         .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
     broadcast(host_connection_id, project_connection_ids, |conn_id| {
         session
@@ -1353,7 +1356,7 @@ async fn update_buffer(
 
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer
@@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session)
 
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer
@@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Re
         .await?;
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer
@@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<(
         .await?;
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer
@@ -1430,14 +1433,16 @@ async fn follow(
     let project_id = ProjectId::from_proto(request.project_id);
     let leader_id = ConnectionId(request.leader_id);
     let follower_id = session.connection_id;
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
+    {
+        let project_connection_ids = session
+            .db()
+            .await
+            .project_connection_ids(project_id, session.connection_id)
+            .await?;
 
-    if !project_connection_ids.contains(&leader_id) {
-        Err(anyhow!("no such peer"))?;
+        if !project_connection_ids.contains(&leader_id) {
+            Err(anyhow!("no such peer"))?;
+        }
     }
 
     let mut response_payload = session
@@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R
         .await?;
     broadcast(
         session.connection_id,
-        project_connection_ids,
+        project_connection_ids.iter().copied(),
         |connection_id| {
             session
                 .peer