Add subchannel creation

Mikayla Maki and max created

co-authored-by: max <max@zed.dev>

Change summary

crates/collab/src/db.rs       | 95 +++++++++++++++++++++++++++++++++---
crates/collab/src/db/tests.rs | 28 ++++++++++
crates/collab/src/rpc.rs      | 28 +++++++---
3 files changed, 131 insertions(+), 20 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -3093,6 +3093,22 @@ impl Database {
         self.transaction(move |tx| async move {
             let tx = tx;
 
+            if let Some(parent) = parent {
+                let channels = self.get_channel_ancestors(parent, &*tx).await?;
+                channel_member::Entity::find()
+                    .filter(channel_member::Column::ChannelId.is_in(channels.iter().copied()))
+                    .filter(
+                        channel_member::Column::UserId
+                            .eq(creator_id)
+                            .and(channel_member::Column::Accepted.eq(true)),
+                    )
+                    .one(&*tx)
+                    .await?
+                    .ok_or_else(|| {
+                        anyhow!("User does not have the permissions to create this channel")
+                    })?;
+            }
+
             let channel = channel::ActiveModel {
                 name: ActiveValue::Set(name.to_string()),
                 ..Default::default()
@@ -3175,11 +3191,6 @@ impl Database {
 
             let channels_to_remove = descendants.keys().copied().collect::<Vec<_>>();
 
-            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-            enum QueryUserIds {
-                UserId,
-            }
-
             let members_to_notify: Vec<UserId> = channel_member::Entity::find()
                 .filter(channel_member::Column::ChannelId.is_in(channels_to_remove.iter().copied()))
                 .select_only()
@@ -3325,11 +3336,6 @@ impl Database {
         self.transaction(|tx| async move {
             let tx = tx;
 
-            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-            enum QueryChannelIds {
-                ChannelId,
-            }
-
             let starting_channel_ids: Vec<ChannelId> = channel_member::Entity::find()
                 .filter(
                     channel_member::Column::UserId
@@ -3368,6 +3374,65 @@ impl Database {
         .await
     }
 
+    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
+        self.transaction(|tx| async move {
+            let tx = tx;
+            let ancestor_ids = self.get_channel_ancestors(id, &*tx).await?;
+            let user_ids = channel_member::Entity::find()
+                .distinct()
+                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
+                .select_only()
+                .column(channel_member::Column::UserId)
+                .into_values::<_, QueryUserIds>()
+                .all(&*tx)
+                .await?;
+            Ok(user_ids)
+        })
+        .await
+    }
+
+    async fn get_channel_ancestors(
+        &self,
+        id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<ChannelId>> {
+        let sql = format!(
+            r#"
+            WITH RECURSIVE channel_tree(child_id, parent_id) AS (
+                    SELECT CAST(NULL as INTEGER) as child_id, root_ids.column1 as parent_id
+                    FROM (VALUES ({})) as root_ids
+                UNION
+                    SELECT channel_parents.child_id, channel_parents.parent_id
+                    FROM channel_parents, channel_tree
+                    WHERE channel_parents.child_id = channel_tree.parent_id
+            )
+            SELECT DISTINCT channel_tree.parent_id
+            FROM channel_tree
+            "#,
+            id
+        );
+
+        #[derive(FromQueryResult, Debug, PartialEq)]
+        pub struct ChannelParent {
+            pub parent_id: ChannelId,
+        }
+
+        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
+
+        let mut channel_ids_stream = channel_parent::Entity::find()
+            .from_raw_sql(stmt)
+            .into_model::<ChannelParent>()
+            .stream(&*tx)
+            .await?;
+
+        let mut channel_ids = vec![];
+        while let Some(channel_id) = channel_ids_stream.next().await {
+            channel_ids.push(channel_id?.parent_id);
+        }
+
+        Ok(channel_ids)
+    }
+
     async fn get_channel_descendants(
         &self,
         channel_ids: impl IntoIterator<Item = ChannelId>,
@@ -3948,6 +4013,16 @@ pub struct WorktreeSettingsFile {
     pub content: String,
 }
 
+#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+enum QueryChannelIds {
+    ChannelId,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+enum QueryUserIds {
+    UserId,
+}
+
 #[cfg(test)]
 pub use test::*;
 

crates/collab/src/db/tests.rs 🔗

@@ -899,7 +899,30 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
         .unwrap()
         .user_id;
 
+    let b_id = db
+        .create_user(
+            "user2@example.com",
+            false,
+            NewUserParams {
+                github_login: "user2".into(),
+                github_user_id: 6,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+
     let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
+
+    db.invite_channel_member(zed_id, b_id, a_id, true)
+        .await
+        .unwrap();
+
+    db.respond_to_channel_invite(zed_id, b_id, true)
+        .await
+        .unwrap();
+
     let crdb_id = db
         .create_channel("crdb", Some(zed_id), "2", a_id)
         .await
@@ -912,6 +935,11 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
         .create_channel("replace", Some(zed_id), "4", a_id)
         .await
         .unwrap();
+
+    let mut members = db.get_channel_members(replace_id).await.unwrap();
+    members.sort();
+    assert_eq!(members, &[a_id, b_id]);
+
     let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
     let cargo_id = db
         .create_channel("cargo", Some(rust_id), "6", a_id)

crates/collab/src/rpc.rs 🔗

@@ -2108,25 +2108,33 @@ async fn create_channel(
         live_kit.create_room(live_kit_room.clone()).await?;
     }
 
+    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
     let id = db
-        .create_channel(
-            &request.name,
-            request.parent_id.map(|id| ChannelId::from_proto(id)),
-            &live_kit_room,
-            session.user_id,
-        )
+        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
         .await?;
 
+    response.send(proto::CreateChannelResponse {
+        channel_id: id.to_proto(),
+    })?;
+
     let mut update = proto::UpdateChannels::default();
     update.channels.push(proto::Channel {
         id: id.to_proto(),
         name: request.name,
         parent_id: request.parent_id,
     });
-    session.peer.send(session.connection_id, update)?;
-    response.send(proto::CreateChannelResponse {
-        channel_id: id.to_proto(),
-    })?;
+
+    if let Some(parent_id) = parent_id {
+        let member_ids = db.get_channel_members(parent_id).await?;
+        let connection_pool = session.connection_pool().await;
+        for member_id in member_ids {
+            for connection_id in connection_pool.user_connection_ids(member_id) {
+                session.peer.send(connection_id, update.clone())?;
+            }
+        }
+    } else {
+        session.peer.send(session.connection_id, update)?;
+    }
 
     Ok(())
 }