Only allow read-write users to update buffers

Conrad Irwin created

Change summary

crates/collab/src/db/ids.rs              |  8 +++
crates/collab/src/db/queries/projects.rs | 56 ++++++++++++++++++++++++-
crates/collab/src/rpc.rs                 | 26 +++--------
3 files changed, 69 insertions(+), 21 deletions(-)

Detailed changes

crates/collab/src/db/ids.rs 🔗

@@ -148,6 +148,14 @@ impl ChannelRole {
             Guest | Banned => false,
         }
     }
+
+    pub fn can_read_projects(&self) -> bool {
+        use ChannelRole::*;
+        match self {
+            Admin | Member | Guest => true,
+            Banned => false,
+        }
+    }
 }
 
 impl From<proto::ChannelRole> for ChannelRole {

crates/collab/src/db/queries/projects.rs 🔗

@@ -805,6 +805,43 @@ impl Database {
         .map(|guard| guard.into_inner())
     }
 
+    pub async fn host_for_read_only_project_request(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<ConnectionId> {
+        let room_id = self.room_id_for_project(project_id).await?;
+        self.room_transaction(room_id, |tx| async move {
+            let current_participant = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
+            if !current_participant
+                .role
+                .map_or(false, |role| role.can_read_projects())
+            {
+                Err(anyhow!("not authorized to read projects"))?;
+            }
+
+            let host = project_collaborator::Entity::find()
+                .filter(
+                    project_collaborator::Column::ProjectId
+                        .eq(project_id)
+                        .and(project_collaborator::Column::IsHost.eq(true)),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to read project host"))?;
+
+            Ok(host.connection())
+        })
+        .await
+        .map(|guard| guard.into_inner())
+    }
+
     pub async fn host_for_mutating_project_request(
         &self,
         project_id: ProjectId,
@@ -821,8 +858,7 @@ impl Database {
 
             if !current_participant
                 .role
-                .unwrap_or(ChannelRole::Guest)
-                .can_edit_projects()
+                .map_or(false, |role| role.can_edit_projects())
             {
                 Err(anyhow!("not authorized to edit projects"))?;
             }
@@ -843,13 +879,27 @@ impl Database {
         .map(|guard| guard.into_inner())
     }
 
-    pub async fn project_collaborators(
+    pub async fn project_collaborators_for_buffer_update(
         &self,
         project_id: ProjectId,
         connection_id: ConnectionId,
     ) -> Result<RoomGuard<Vec<ProjectCollaborator>>> {
         let room_id = self.room_id_for_project(project_id).await?;
         self.room_transaction(room_id, |tx| async move {
+            let current_participant = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
+            if !current_participant
+                .role
+                .map_or(false, |role| role.can_edit_projects())
+            {
+                Err(anyhow!("not authorized to edit projects"))?;
+            }
+
             let collaborators = project_collaborator::Entity::find()
                 .filter(project_collaborator::Column::ProjectId.eq(project_id))
                 .all(&*tx)

crates/collab/src/rpc.rs 🔗

@@ -227,7 +227,7 @@ impl Server {
             .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
             .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
             .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
-            .add_request_handler(forward_mutating_project_request::<proto::OpenBufferByPath>)
+            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
             .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
             .add_request_handler(
                 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
@@ -1750,24 +1750,15 @@ where
     T: EntityMessage + RequestMessage,
 {
     let project_id = ProjectId::from_proto(request.remote_entity_id());
-    let host_connection_id = {
-        let collaborators = session
-            .db()
-            .await
-            .project_collaborators(project_id, session.connection_id)
-            .await?;
-        collaborators
-            .iter()
-            .find(|collaborator| collaborator.is_host)
-            .ok_or_else(|| anyhow!("host not found"))?
-            .connection_id
-    };
-
+    let host_connection_id = session
+        .db()
+        .await
+        .host_for_read_only_project_request(project_id, session.connection_id)
+        .await?;
     let payload = session
         .peer
         .forward_request(session.connection_id, host_connection_id, request)
         .await?;
-
     response.send(payload)?;
     Ok(())
 }
@@ -1786,12 +1777,10 @@ where
         .await
         .host_for_mutating_project_request(project_id, session.connection_id)
         .await?;
-
     let payload = session
         .peer
         .forward_request(session.connection_id, host_connection_id, request)
         .await?;
-
     response.send(payload)?;
     Ok(())
 }
@@ -1823,11 +1812,12 @@ async fn update_buffer(
     let project_id = ProjectId::from_proto(request.project_id);
     let mut guest_connection_ids;
     let mut host_connection_id = None;
+
     {
         let collaborators = session
             .db()
             .await
-            .project_collaborators(project_id, session.connection_id)
+            .project_collaborators_for_buffer_update(project_id, session.connection_id)
             .await?;
         guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
         for collaborator in collaborators.iter() {