Add requires_zed_cla column to channels table

Max Brunsfeld and Marshall created

Don't allow granting guests write access in a call where the channel
or one of its ancestors requires the zed CLA, until that guest has
signed the Zed CLA.

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql                      |   3 
crates/collab/migrations/20240122224506_add_requires_zed_cla_column_to_channels.sql |   1 
crates/collab/src/api.rs                                                            |   2 
crates/collab/src/bin/seed.rs                                                       |   2 
crates/collab/src/db/ids.rs                                                         |   8 
crates/collab/src/db/queries/channels.rs                                            |  17 
crates/collab/src/db/queries/contributors.rs                                        |   2 
crates/collab/src/db/queries/rooms.rs                                               |  44 
crates/collab/src/db/queries/users.rs                                               |  69 
crates/collab/src/db/tables/channel.rs                                              |   1 
crates/collab/src/db/tests/contributor_tests.rs                                     |   4 
crates/collab/src/db/tests/db_tests.rs                                              |   4 
crates/collab/src/tests/channel_guest_tests.rs                                      | 102 
crates/collab/src/tests/test_server.rs                                              |   6 
14 files changed, 225 insertions(+), 40 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -196,7 +196,8 @@ CREATE TABLE "channels" (
     "name" VARCHAR NOT NULL,
     "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     "visibility" VARCHAR NOT NULL,
-    "parent_path" TEXT
+    "parent_path" TEXT,
+    "requires_zed_cla" BOOLEAN NOT NULL DEFAULT FALSE
 );
 
 CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path");

crates/collab/src/api.rs 🔗

@@ -69,7 +69,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
 
 #[derive(Debug, Deserialize)]
 struct AuthenticatedUserParams {
-    github_user_id: i32,
+    github_user_id: Option<i32>,
     github_login: String,
     github_email: Option<String>,
 }

crates/collab/src/bin/seed.rs 🔗

@@ -68,7 +68,7 @@ async fn main() {
             user_count += 1;
             db.get_or_create_user_by_github_account(
                 &github_user.login,
-                github_user.id,
+                Some(github_user.id),
                 github_user.email.as_deref(),
             )
             .await

crates/collab/src/db/ids.rs 🔗

@@ -173,6 +173,14 @@ impl ChannelRole {
             Banned => false,
         }
     }
+
+    pub fn requires_cla(&self) -> bool {
+        use ChannelRole::*;
+        match self {
+            Admin | Member => true,
+            Banned | Guest => false,
+        }
+    }
 }
 
 impl From<proto::ChannelRole> for ChannelRole {

crates/collab/src/db/queries/channels.rs 🔗

@@ -67,6 +67,7 @@ impl Database {
                         .as_ref()
                         .map_or(String::new(), |parent| parent.path()),
                 ),
+                requires_zed_cla: ActiveValue::NotSet,
             }
             .insert(&*tx)
             .await?;
@@ -261,6 +262,22 @@ impl Database {
         .await
     }
 
+    #[cfg(test)]
+    pub async fn set_channel_requires_zed_cla(
+        &self,
+        channel_id: ChannelId,
+        requires_zed_cla: bool,
+    ) -> Result<()> {
+        self.transaction(move |tx| async move {
+            let channel = self.get_channel_internal(channel_id, &*tx).await?;
+            let mut model = channel.into_active_model();
+            model.requires_zed_cla = ActiveValue::Set(requires_zed_cla);
+            model.update(&*tx).await?;
+            Ok(())
+        })
+        .await
+    }
+
     /// Deletes the channel with the specified ID.
     pub async fn delete_channel(
         &self,

crates/collab/src/db/queries/contributors.rs 🔗

@@ -58,7 +58,7 @@ impl Database {
     pub async fn add_contributor(
         &self,
         github_login: &str,
-        github_user_id: i32,
+        github_user_id: Option<i32>,
         github_email: Option<&str>,
     ) -> Result<()> {
         self.transaction(|tx| async move {

crates/collab/src/db/queries/rooms.rs 🔗

@@ -1029,6 +1029,11 @@ impl Database {
                 .await?
                 .ok_or_else(|| anyhow!("only admins can set participant role"))?;
 
+            if role.requires_cla() {
+                self.check_user_has_signed_cla(user_id, room_id, &*tx)
+                    .await?;
+            }
+
             let result = room_participant::Entity::update_many()
                 .filter(
                     Condition::all()
@@ -1050,6 +1055,45 @@ impl Database {
         .await
     }
 
+    async fn check_user_has_signed_cla(
+        &self,
+        user_id: UserId,
+        room_id: RoomId,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        let channel = room::Entity::find_by_id(room_id)
+            .one(&*tx)
+            .await?
+            .ok_or_else(|| anyhow!("could not find room"))?
+            .find_related(channel::Entity)
+            .one(&*tx)
+            .await?;
+
+        if let Some(channel) = channel {
+            let requires_zed_cla = channel.requires_zed_cla
+                || channel::Entity::find()
+                    .filter(
+                        channel::Column::Id
+                            .is_in(channel.ancestors())
+                            .and(channel::Column::RequiresZedCla.eq(true)),
+                    )
+                    .count(&*tx)
+                    .await?
+                    > 0;
+            if requires_zed_cla {
+                if contributor::Entity::find()
+                    .filter(contributor::Column::UserId.eq(user_id))
+                    .one(&*tx)
+                    .await?
+                    .is_none()
+                {
+                    Err(anyhow!("user has not signed the Zed CLA"))?;
+                }
+            }
+        }
+        Ok(())
+    }
+
     pub async fn connection_lost(&self, connection: ConnectionId) -> Result<()> {
         self.transaction(|tx| async move {
             self.room_connection_lost(connection, &*tx).await?;

crates/collab/src/db/queries/users.rs 🔗

@@ -72,7 +72,7 @@ impl Database {
     pub async fn get_or_create_user_by_github_account(
         &self,
         github_login: &str,
-        github_user_id: i32,
+        github_user_id: Option<i32>,
         github_email: Option<&str>,
     ) -> Result<User> {
         self.transaction(|tx| async move {
@@ -90,39 +90,48 @@ impl Database {
     pub async fn get_or_create_user_by_github_account_tx(
         &self,
         github_login: &str,
-        github_user_id: i32,
+        github_user_id: Option<i32>,
         github_email: Option<&str>,
         tx: &DatabaseTransaction,
     ) -> Result<User> {
-        if let Some(user_by_github_user_id) = user::Entity::find()
-            .filter(user::Column::GithubUserId.eq(github_user_id))
-            .one(tx)
-            .await?
-        {
-            let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
-            user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
-            Ok(user_by_github_user_id.update(tx).await?)
-        } else if let Some(user_by_github_login) = user::Entity::find()
-            .filter(user::Column::GithubLogin.eq(github_login))
-            .one(tx)
-            .await?
-        {
-            let mut user_by_github_login = user_by_github_login.into_active_model();
-            user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
-            Ok(user_by_github_login.update(tx).await?)
+        if let Some(github_user_id) = github_user_id {
+            if let Some(user_by_github_user_id) = user::Entity::find()
+                .filter(user::Column::GithubUserId.eq(github_user_id))
+                .one(tx)
+                .await?
+            {
+                let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
+                user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
+                Ok(user_by_github_user_id.update(tx).await?)
+            } else if let Some(user_by_github_login) = user::Entity::find()
+                .filter(user::Column::GithubLogin.eq(github_login))
+                .one(tx)
+                .await?
+            {
+                let mut user_by_github_login = user_by_github_login.into_active_model();
+                user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
+                Ok(user_by_github_login.update(tx).await?)
+            } else {
+                let user = user::Entity::insert(user::ActiveModel {
+                    email_address: ActiveValue::set(github_email.map(|email| email.into())),
+                    github_login: ActiveValue::set(github_login.into()),
+                    github_user_id: ActiveValue::set(Some(github_user_id)),
+                    admin: ActiveValue::set(false),
+                    invite_count: ActiveValue::set(0),
+                    invite_code: ActiveValue::set(None),
+                    metrics_id: ActiveValue::set(Uuid::new_v4()),
+                    ..Default::default()
+                })
+                .exec_with_returning(&*tx)
+                .await?;
+                Ok(user)
+            }
         } else {
-            let user = user::Entity::insert(user::ActiveModel {
-                email_address: ActiveValue::set(github_email.map(|email| email.into())),
-                github_login: ActiveValue::set(github_login.into()),
-                github_user_id: ActiveValue::set(Some(github_user_id)),
-                admin: ActiveValue::set(false),
-                invite_count: ActiveValue::set(0),
-                invite_code: ActiveValue::set(None),
-                metrics_id: ActiveValue::set(Uuid::new_v4()),
-                ..Default::default()
-            })
-            .exec_with_returning(&*tx)
-            .await?;
+            let user = user::Entity::find()
+                .filter(user::Column::GithubLogin.eq(github_login))
+                .one(tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such user {}", github_login))?;
             Ok(user)
         }
     }

crates/collab/src/db/tests/contributor_tests.rs 🔗

@@ -23,13 +23,13 @@ async fn test_contributors(db: &Arc<Database>) {
 
     assert_eq!(db.get_contributors().await.unwrap(), Vec::<String>::new());
 
-    db.add_contributor("user1", 1, None).await.unwrap();
+    db.add_contributor("user1", Some(1), None).await.unwrap();
     assert_eq!(
         db.get_contributors().await.unwrap(),
         vec!["user1".to_string()]
     );
 
-    db.add_contributor("user2", 2, None).await.unwrap();
+    db.add_contributor("user2", Some(2), None).await.unwrap();
     assert_eq!(
         db.get_contributors().await.unwrap(),
         vec!["user1".to_string(), "user2".to_string()]

crates/collab/src/db/tests/db_tests.rs 🔗

@@ -105,7 +105,7 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
         .user_id;
 
     let user = db
-        .get_or_create_user_by_github_account("the-new-login2", 102, None)
+        .get_or_create_user_by_github_account("the-new-login2", Some(102), None)
         .await
         .unwrap();
     assert_eq!(user.id, user_id2);
@@ -113,7 +113,7 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
     assert_eq!(user.github_user_id, Some(102));
 
     let user = db
-        .get_or_create_user_by_github_account("login3", 103, Some("user3@example.com"))
+        .get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com"))
         .await
         .unwrap();
     assert_eq!(&user.github_login, "login3");

crates/collab/src/tests/channel_guest_tests.rs 🔗

@@ -1,4 +1,4 @@
-use crate::tests::TestServer;
+use crate::{db::ChannelId, tests::TestServer};
 use call::ActiveCall;
 use editor::Editor;
 use gpui::{BackgroundExecutor, TestAppContext};
@@ -159,3 +159,103 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
         .await
         .is_err());
 }
+
+#[gpui::test]
+async fn test_channel_requires_zed_cla(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+    let mut server = TestServer::start(cx_a.executor()).await;
+
+    server
+        .app_state
+        .db
+        .get_or_create_user_by_github_account("user_b", Some(100), None)
+        .await
+        .unwrap();
+
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    let active_call_a = cx_a.read(ActiveCall::global);
+    let active_call_b = cx_b.read(ActiveCall::global);
+
+    // Create a parent channel that requires the Zed CLA
+    let parent_channel_id = server
+        .make_channel("the-channel", None, (&client_a, cx_a), &mut [])
+        .await;
+    server
+        .app_state
+        .db
+        .set_channel_requires_zed_cla(ChannelId::from_proto(parent_channel_id), true)
+        .await
+        .unwrap();
+
+    // Create a public channel that is a child of the parent channel.
+    let channel_id = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| {
+            store.create_channel("the-sub-channel", Some(parent_channel_id), cx)
+        })
+        .await
+        .unwrap();
+    client_a
+        .channel_store()
+        .update(cx_a, |store, cx| {
+            store.set_channel_visibility(channel_id, proto::ChannelVisibility::Public, cx)
+        })
+        .await
+        .unwrap();
+
+    // Users A and B join the channel. B is a guest.
+    active_call_a
+        .update(cx_a, |call, cx| call.join_channel(channel_id, cx))
+        .await
+        .unwrap();
+    active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_id, cx))
+        .await
+        .unwrap();
+    cx_a.run_until_parked();
+    let room_b = cx_b
+        .read(ActiveCall::global)
+        .update(cx_b, |call, _| call.room().unwrap().clone());
+    assert!(room_b.read_with(cx_b, |room, _| room.read_only()));
+
+    // A tries to grant write access to B, but cannot because B has not
+    // yet signed the zed CLA.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.room().unwrap().update(cx, |room, cx| {
+                room.set_participant_role(
+                    client_b.user_id().unwrap(),
+                    proto::ChannelRole::Member,
+                    cx,
+                )
+            })
+        })
+        .await
+        .unwrap_err();
+    cx_a.run_until_parked();
+    assert!(room_b.read_with(cx_b, |room, _| room.read_only()));
+
+    // User B signs the zed CLA.
+    server
+        .app_state
+        .db
+        .add_contributor("user_b", Some(100), None)
+        .await
+        .unwrap();
+
+    // A can now grant write access to B.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.room().unwrap().update(cx, |room, cx| {
+                room.set_participant_role(
+                    client_b.user_id().unwrap(),
+                    proto::ChannelRole::Member,
+                    cx,
+                )
+            })
+        })
+        .await
+        .unwrap();
+    cx_a.run_until_parked();
+    assert!(room_b.read_with(cx_b, |room, _| !room.read_only()));
+}

crates/collab/src/tests/test_server.rs 🔗

@@ -43,6 +43,7 @@ pub struct TestServer {
     pub app_state: Arc<AppState>,
     pub test_live_kit_server: Arc<live_kit_client::TestServer>,
     server: Arc<Server>,
+    next_github_user_id: i32,
     connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
     forbid_connections: Arc<AtomicBool>,
     _test_db: TestDb,
@@ -108,6 +109,7 @@ impl TestServer {
             server,
             connection_killers: Default::default(),
             forbid_connections: Default::default(),
+            next_github_user_id: 0,
             _test_db: test_db,
             test_live_kit_server: live_kit_server,
         }
@@ -157,6 +159,8 @@ impl TestServer {
         {
             user.id
         } else {
+            let github_user_id = self.next_github_user_id;
+            self.next_github_user_id += 1;
             self.app_state
                 .db
                 .create_user(
@@ -164,7 +168,7 @@ impl TestServer {
                     false,
                     NewUserParams {
                         github_login: name.into(),
-                        github_user_id: 0,
+                        github_user_id,
                     },
                 )
                 .await