Add check_is_channel_participant

Conrad Irwin created

Refactor permission checks to load ancestor permissions into memory
for all checks to make the different logics more explicit.

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   3 
crates/collab/src/db/ids.rs                                    |   4 
crates/collab/src/db/queries/channels.rs                       | 180 +++
crates/collab/src/db/tables/channel.rs                         |   2 
crates/collab/src/db/tests/channel_tests.rs                    | 121 ++
crates/collab/src/tests/channel_tests.rs                       |   5 
crates/rpc/proto/zed.proto                                     |   1 
7 files changed, 285 insertions(+), 31 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -192,7 +192,8 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
 CREATE TABLE "channels" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "name" VARCHAR NOT NULL,
-    "created_at" TIMESTAMP NOT NULL DEFAULT now
+    "created_at" TIMESTAMP NOT NULL DEFAULT now,
+    "visibility" VARCHAR NOT NULL
 );
 
 CREATE TABLE IF NOT EXISTS "channel_chat_participants" (

crates/collab/src/db/ids.rs 🔗

@@ -91,6 +91,8 @@ pub enum ChannelRole {
     Member,
     #[sea_orm(string_value = "guest")]
     Guest,
+    #[sea_orm(string_value = "banned")]
+    Banned,
 }
 
 impl From<proto::ChannelRole> for ChannelRole {
@@ -99,6 +101,7 @@ impl From<proto::ChannelRole> for ChannelRole {
             proto::ChannelRole::Admin => ChannelRole::Admin,
             proto::ChannelRole::Member => ChannelRole::Member,
             proto::ChannelRole::Guest => ChannelRole::Guest,
+            proto::ChannelRole::Banned => ChannelRole::Banned,
         }
     }
 }
@@ -109,6 +112,7 @@ impl Into<proto::ChannelRole> for ChannelRole {
             ChannelRole::Admin => proto::ChannelRole::Admin,
             ChannelRole::Member => proto::ChannelRole::Member,
             ChannelRole::Guest => proto::ChannelRole::Guest,
+            ChannelRole::Banned => proto::ChannelRole::Banned,
         }
     }
 }

crates/collab/src/db/queries/channels.rs 🔗

@@ -37,8 +37,9 @@ impl Database {
             }
 
             let channel = channel::ActiveModel {
+                id: ActiveValue::NotSet,
                 name: ActiveValue::Set(name.to_string()),
-                ..Default::default()
+                visibility: ActiveValue::Set(ChannelVisibility::ChannelMembers),
             }
             .insert(&*tx)
             .await?;
@@ -89,6 +90,29 @@ impl Database {
         .await
     }
 
+    pub async fn set_channel_visibility(
+        &self,
+        channel_id: ChannelId,
+        visibility: ChannelVisibility,
+        user_id: UserId,
+    ) -> Result<()> {
+        self.transaction(move |tx| async move {
+            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
+                .await?;
+
+            channel::ActiveModel {
+                id: ActiveValue::Unchanged(channel_id),
+                visibility: ActiveValue::Set(visibility),
+                ..Default::default()
+            }
+            .update(&*tx)
+            .await?;
+
+            Ok(())
+        })
+        .await
+    }
+
     pub async fn delete_channel(
         &self,
         channel_id: ChannelId,
@@ -160,11 +184,11 @@ impl Database {
         &self,
         channel_id: ChannelId,
         invitee_id: UserId,
-        inviter_id: UserId,
+        admin_id: UserId,
         role: ChannelRole,
     ) -> Result<()> {
         self.transaction(move |tx| async move {
-            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
+            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
 
             channel_member::ActiveModel {
@@ -262,10 +286,10 @@ impl Database {
         &self,
         channel_id: ChannelId,
         member_id: UserId,
-        remover_id: UserId,
+        admin_id: UserId,
     ) -> Result<()> {
         self.transaction(|tx| async move {
-            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
+            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
 
             let result = channel_member::Entity::delete_many()
@@ -481,12 +505,12 @@ impl Database {
     pub async fn set_channel_member_role(
         &self,
         channel_id: ChannelId,
-        from: UserId,
+        admin_id: UserId,
         for_user: UserId,
         role: ChannelRole,
     ) -> Result<()> {
         self.transaction(|tx| async move {
-            self.check_user_is_channel_admin(channel_id, from, &*tx)
+            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
 
             let result = channel_member::Entity::update_many()
@@ -613,43 +637,147 @@ impl Database {
         Ok(user_ids)
     }
 
+    pub async fn check_user_is_channel_admin(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        match self.channel_role_for_user(channel_id, user_id, tx).await? {
+            Some(ChannelRole::Admin) => Ok(()),
+            Some(ChannelRole::Member)
+            | Some(ChannelRole::Banned)
+            | Some(ChannelRole::Guest)
+            | None => Err(anyhow!(
+                "user is not a channel admin or channel does not exist"
+            ))?,
+        }
+    }
+
     pub async fn check_user_is_channel_member(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
         tx: &DatabaseTransaction,
     ) -> Result<()> {
-        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
-        channel_member::Entity::find()
-            .filter(
-                channel_member::Column::ChannelId
-                    .is_in(channel_ids)
-                    .and(channel_member::Column::UserId.eq(user_id)),
-            )
-            .one(&*tx)
-            .await?
-            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
-        Ok(())
+        match self.channel_role_for_user(channel_id, user_id, tx).await? {
+            Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
+            Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
+                "user is not a channel member or channel does not exist"
+            ))?,
+        }
     }
 
-    pub async fn check_user_is_channel_admin(
+    pub async fn check_user_is_channel_participant(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
         tx: &DatabaseTransaction,
     ) -> Result<()> {
+        match self.channel_role_for_user(channel_id, user_id, tx).await? {
+            Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
+                Ok(())
+            }
+            Some(ChannelRole::Banned) | None => Err(anyhow!(
+                "user is not a channel participant or channel does not exist"
+            ))?,
+        }
+    }
+
+    pub async fn channel_role_for_user(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Option<ChannelRole>> {
         let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
-        channel_member::Entity::find()
+
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryChannelMembership {
+            ChannelId,
+            Role,
+            Admin,
+            Visibility,
+        }
+
+        let mut rows = channel_member::Entity::find()
+            .left_join(channel::Entity)
             .filter(
                 channel_member::Column::ChannelId
                     .is_in(channel_ids)
-                    .and(channel_member::Column::UserId.eq(user_id))
-                    .and(channel_member::Column::Admin.eq(true)),
+                    .and(channel_member::Column::UserId.eq(user_id)),
             )
-            .one(&*tx)
-            .await?
-            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
-        Ok(())
+            .select_only()
+            .column(channel_member::Column::ChannelId)
+            .column(channel_member::Column::Role)
+            .column(channel_member::Column::Admin)
+            .column(channel::Column::Visibility)
+            .into_values::<_, QueryChannelMembership>()
+            .stream(&*tx)
+            .await?;
+
+        let mut is_admin = false;
+        let mut is_member = false;
+        let mut is_participant = false;
+        let mut is_banned = false;
+        let mut current_channel_visibility = None;
+
+        // note these channels are not iterated in any particular order,
+        // our current logic takes the highest permission available.
+        while let Some(row) = rows.next().await {
+            let (ch_id, role, admin, visibility): (
+                ChannelId,
+                Option<ChannelRole>,
+                bool,
+                ChannelVisibility,
+            ) = row?;
+            match role {
+                Some(ChannelRole::Admin) => is_admin = true,
+                Some(ChannelRole::Member) => is_member = true,
+                Some(ChannelRole::Guest) => {
+                    if visibility == ChannelVisibility::Public {
+                        is_participant = true
+                    }
+                }
+                Some(ChannelRole::Banned) => is_banned = true,
+                None => {
+                    // rows created from pre-role collab server.
+                    if admin {
+                        is_admin = true
+                    } else {
+                        is_member = true
+                    }
+                }
+            }
+            if channel_id == ch_id {
+                current_channel_visibility = Some(visibility);
+            }
+        }
+        // free up database connection
+        drop(rows);
+
+        Ok(if is_admin {
+            Some(ChannelRole::Admin)
+        } else if is_member {
+            Some(ChannelRole::Member)
+        } else if is_banned {
+            Some(ChannelRole::Banned)
+        } else if is_participant {
+            if current_channel_visibility.is_none() {
+                current_channel_visibility = channel::Entity::find()
+                    .filter(channel::Column::Id.eq(channel_id))
+                    .one(&*tx)
+                    .await?
+                    .map(|channel| channel.visibility);
+            }
+            if current_channel_visibility == Some(ChannelVisibility::Public) {
+                Some(ChannelRole::Guest)
+            } else {
+                None
+            }
+        } else {
+            None
+        })
     }
 
     /// Returns the channel ancestors, deepest first

crates/collab/src/db/tables/channel.rs 🔗

@@ -7,7 +7,7 @@ pub struct Model {
     #[sea_orm(primary_key)]
     pub id: ChannelId,
     pub name: String,
-    pub visbility: ChannelVisibility,
+    pub visibility: ChannelVisibility,
 }
 
 impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -8,11 +8,14 @@ use crate::{
     db::{
         queries::channels::ChannelGraph,
         tests::{graph, TEST_RELEASE_CHANNEL},
-        ChannelId, ChannelRole, Database, NewUserParams,
+        ChannelId, ChannelRole, Database, NewUserParams, UserId,
     },
     test_both_dbs,
 };
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicI32, Ordering},
+    Arc,
+};
 
 test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
 
@@ -850,6 +853,101 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
     );
 }
 
+test_both_dbs!(
+    test_user_is_channel_participant,
+    test_user_is_channel_participant_postgres,
+    test_user_is_channel_participant_sqlite
+);
+
+async fn test_user_is_channel_participant(db: &Arc<Database>) {
+    let admin_id = new_test_user(db, "admin@example.com").await;
+    let member_id = new_test_user(db, "member@example.com").await;
+    let guest_id = new_test_user(db, "guest@example.com").await;
+
+    let zed_id = db.create_root_channel("zed", admin_id).await.unwrap();
+    let intermediate_id = db
+        .create_channel("active", Some(zed_id), admin_id)
+        .await
+        .unwrap();
+    let public_id = db
+        .create_channel("active", Some(intermediate_id), admin_id)
+        .await
+        .unwrap();
+
+    db.set_channel_visibility(public_id, crate::db::ChannelVisibility::Public, admin_id)
+        .await
+        .unwrap();
+    db.invite_channel_member(intermediate_id, member_id, admin_id, ChannelRole::Member)
+        .await
+        .unwrap();
+    db.invite_channel_member(public_id, guest_id, admin_id, ChannelRole::Guest)
+        .await
+        .unwrap();
+
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(public_id, admin_id, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(public_id, member_id, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(public_id, guest_id, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+
+    db.set_channel_member_role(public_id, admin_id, guest_id, ChannelRole::Banned)
+        .await
+        .unwrap();
+    assert!(db
+        .transaction(|tx| async move {
+            db.check_user_is_channel_participant(public_id, guest_id, &*tx)
+                .await
+        })
+        .await
+        .is_err());
+
+    db.remove_channel_member(public_id, guest_id, admin_id)
+        .await
+        .unwrap();
+
+    db.set_channel_visibility(zed_id, crate::db::ChannelVisibility::Public, admin_id)
+        .await
+        .unwrap();
+
+    db.invite_channel_member(zed_id, guest_id, admin_id, ChannelRole::Guest)
+        .await
+        .unwrap();
+
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(zed_id, guest_id, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+    assert!(db
+        .transaction(|tx| async move {
+            db.check_user_is_channel_participant(intermediate_id, guest_id, &*tx)
+                .await
+        })
+        .await
+        .is_err(),);
+
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(public_id, guest_id, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+}
+
 #[track_caller]
 fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)]) {
     let mut actual_map: HashMap<ChannelId, HashSet<ChannelId>> = HashMap::default();
@@ -874,3 +972,22 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
 
     pretty_assertions::assert_eq!(actual_map, expected_map)
 }
+
+static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
+
+async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
+    let gid = GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst);
+
+    db.create_user(
+        email,
+        false,
+        NewUserParams {
+            github_login: email[0..email.find("@").unwrap()].to_string(),
+            github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
+            invite_count: 0,
+        },
+    )
+    .await
+    .unwrap()
+    .user_id
+}

crates/collab/src/tests/channel_tests.rs 🔗

@@ -6,7 +6,10 @@ use call::ActiveCall;
 use channel::{ChannelId, ChannelMembership, ChannelStore};
 use client::User;
 use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
-use rpc::{proto, RECEIVE_TIMEOUT};
+use rpc::{
+    proto::{self},
+    RECEIVE_TIMEOUT,
+};
 use std::sync::Arc;
 
 #[gpui::test]

crates/rpc/proto/zed.proto 🔗

@@ -1040,6 +1040,7 @@ enum ChannelRole {
     Admin = 0;
     Member = 1;
     Guest = 2;
+    Banned = 3;
 }
 
 message SetChannelMemberRole {