Add channel linking operation

Mikayla created

Change summary

crates/collab/Cargo.toml                    |   2 
crates/collab/src/db/queries/channels.rs    | 102 ++++++++
crates/collab/src/db/tests/channel_tests.rs | 266 +++++++++++++++++++++-
3 files changed, 348 insertions(+), 22 deletions(-)

Detailed changes

crates/collab/Cargo.toml 🔗

@@ -72,7 +72,6 @@ fs = { path = "../fs", features = ["test-support"] }
 git = { path = "../git", features = ["test-support"] }
 live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
 lsp = { path = "../lsp", features = ["test-support"] }
-pretty_assertions.workspace = true
 project = { path = "../project", features = ["test-support"] }
 rpc = { path = "../rpc", features = ["test-support"] }
 settings = { path = "../settings", features = ["test-support"] }
@@ -81,6 +80,7 @@ workspace = { path = "../workspace", features = ["test-support"] }
 collab_ui = { path = "../collab_ui", features = ["test-support"] }
 
 async-trait.workspace = true
+pretty_assertions.workspace = true
 ctor.workspace = true
 env_logger.workspace = true
 indoc.workspace = true

crates/collab/src/db/queries/channels.rs 🔗

@@ -68,6 +68,8 @@ impl Database {
                     ],
                 );
                 tx.execute(channel_paths_stmt).await?;
+
+                dbg!(channel_path::Entity::find().all(&*tx).await?);
             } else {
                 channel_path::Entity::insert(channel_path::ActiveModel {
                     channel_id: ActiveValue::Set(channel.id),
@@ -336,6 +338,8 @@ impl Database {
                 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
                 .await?;
 
+            dbg!(&parents_by_child_id);
+
             let channels_with_admin_privileges = channel_memberships
                 .iter()
                 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
@@ -349,11 +353,24 @@ impl Database {
                     .await?;
                 while let Some(row) = rows.next().await {
                     let row = row?;
-                    channels.push(Channel {
-                        id: row.id,
-                        name: row.name,
-                        parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
-                    });
+
+                    // As these rows are pulled from the map's keys, this unwrap is safe.
+                    let parents = parents_by_child_id.get(&row.id).unwrap();
+                    if parents.len() > 0 {
+                        for parent in parents {
+                            channels.push(Channel {
+                                id: row.id,
+                                name: row.name.clone(),
+                                parent_id: Some(*parent),
+                            });
+                        }
+                    } else {
+                        channels.push(Channel {
+                            id: row.id,
+                            name: row.name,
+                            parent_id: None,
+                        });
+                    }
                 }
             }
 
@@ -559,6 +576,7 @@ impl Database {
         Ok(())
     }
 
+    /// Returns the channel ancestors, deepest first
     pub async fn get_channel_ancestors(
         &self,
         channel_id: ChannelId,
@@ -566,6 +584,7 @@ impl Database {
     ) -> Result<Vec<ChannelId>> {
         let paths = channel_path::Entity::find()
             .filter(channel_path::Column::ChannelId.eq(channel_id))
+            .order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
             .all(tx)
             .await?;
         let mut channel_ids = Vec::new();
@@ -586,7 +605,7 @@ impl Database {
         &self,
         channel_ids: impl IntoIterator<Item = ChannelId>,
         tx: &DatabaseTransaction,
-    ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
+    ) -> Result<HashMap<ChannelId, HashSet<ChannelId>>> {
         let mut values = String::new();
         for id in channel_ids {
             if !values.is_empty() {
@@ -613,7 +632,7 @@ impl Database {
 
         let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 
-        let mut parents_by_child_id = HashMap::default();
+        let mut parents_by_child_id: HashMap<ChannelId, HashSet<ChannelId>> = HashMap::default();
         let mut paths = channel_path::Entity::find()
             .from_raw_sql(stmt)
             .stream(tx)
@@ -632,7 +651,10 @@ impl Database {
                     parent_id = Some(id);
                 }
             }
-            parents_by_child_id.insert(path.channel_id, parent_id);
+            let entry = parents_by_child_id.entry(path.channel_id).or_default();
+            if let Some(parent_id) = parent_id {
+                entry.insert(parent_id);
+            }
         }
 
         Ok(parents_by_child_id)
@@ -704,12 +726,74 @@ impl Database {
         .await
     }
 
+    pub async fn link_channel(&self, user: UserId, from: ChannelId, to: ChannelId) -> Result<()> {
+        self.transaction(|tx| async move {
+            self.check_user_is_channel_admin(to, user, &*tx).await?;
+
+            // TODO: Downgrade this check once our permissions system isn't busted
+            // You should be able to safely link a member channel for  your own uses. See:
+            // https://zed.dev/blog/this-week-at-zed-15 > Mikayla's section
+            //
+            // Note that even with these higher permissions, this linking operation
+            // is still insecure because you can't remove someone's permissions to a
+            // channel if they've linked the channel to one where they're an admin.
+            self.check_user_is_channel_admin(from, user, &*tx).await?;
+
+            let to_ancestors = self.get_channel_ancestors(to, &*tx).await?;
+            let from_descendants = self.get_channel_descendants([from], &*tx).await?;
+            for ancestor in to_ancestors {
+                if from_descendants.contains_key(&ancestor) {
+                    return Err(anyhow!("Cannot create a channel cycle").into());
+                }
+            }
+
+            let sql = r#"
+                INSERT INTO channel_paths
+                (id_path, channel_id)
+                SELECT
+                    id_path || $1 || '/', $2
+                FROM
+                    channel_paths
+                WHERE
+                    channel_id = $3
+                ON CONFLICT (id_path) DO NOTHING;
+            "#;
+            let channel_paths_stmt = Statement::from_sql_and_values(
+                self.pool.get_database_backend(),
+                sql,
+                [
+                    from.to_proto().into(),
+                    from.to_proto().into(),
+                    to.to_proto().into(),
+                ],
+            );
+            tx.execute(channel_paths_stmt).await?;
+
+            for (from_id, to_ids) in from_descendants.iter().filter(|(id, _)| id == &&from) {
+                for to_id in to_ids {
+                    let channel_paths_stmt = Statement::from_sql_and_values(
+                        self.pool.get_database_backend(),
+                        sql,
+                        [
+                            from_id.to_proto().into(),
+                            from_id.to_proto().into(),
+                            to_id.to_proto().into(),
+                        ],
+                    );
+                    tx.execute(channel_paths_stmt).await?;
+                }
+            }
+
+            Ok(())
+        })
+        .await
+    }
+
     pub async fn move_channel(
         &self,
         user: UserId,
         from: ChannelId,
         to: Option<ChannelId>,
-        link: bool,
     ) -> Result<()> {
         self.transaction(|tx| async move { todo!() }).await
     }

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -486,14 +486,24 @@ async fn test_channels_moving(db: &Arc<Database>) {
         .await
         .unwrap();
 
+    let gpui2_id = db
+        .create_channel("gpui2", Some(zed_id), "3", a_id)
+        .await
+        .unwrap();
+
     let livestreaming_id = db
-        .create_channel("livestreaming", Some(crdb_id), "3", a_id)
+        .create_channel("livestreaming", Some(crdb_id), "4", a_id)
+        .await
+        .unwrap();
+
+    let livestreaming_dag_id = db
+        .create_channel("livestreaming_dag", Some(livestreaming_id), "5", a_id)
         .await
         .unwrap();
 
     // sanity check
     let result = db.get_channels_for_user(a_id).await.unwrap();
-    assert_eq!(
+    pretty_assertions::assert_eq!(
         result.channels,
         vec![
             Channel {
@@ -506,33 +516,93 @@ async fn test_channels_moving(db: &Arc<Database>) {
                 name: "crdb".to_string(),
                 parent_id: Some(zed_id),
             },
+            Channel {
+                id: gpui2_id,
+                name: "gpui2".to_string(),
+                parent_id: Some(zed_id),
+            },
             Channel {
                 id: livestreaming_id,
                 name: "livestreaming".to_string(),
                 parent_id: Some(crdb_id),
             },
+            Channel {
+                id: livestreaming_dag_id,
+                name: "livestreaming_dag".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
         ]
     );
+    // Initial DAG:
+    //     /- gpui2
+    // zed -- crdb - livestreaming - livestreaming_dag
 
-    // Move channel up
-    db.move_channel(a_id, livestreaming_id, Some(zed_id), false)
-        .await
-        .unwrap();
-
-    // Attempt to make a cycle
+    // Attemp to make a cycle
     assert!(db
-        .move_channel(a_id, zed_id, Some(livestreaming_id), false)
+        .link_channel(a_id, zed_id, livestreaming_id)
         .await
         .is_err());
 
     // Make a link
-    db.move_channel(a_id, crdb_id, Some(livestreaming_id), true)
+    db.link_channel(a_id, livestreaming_id, zed_id)
         .await
         .unwrap();
 
+    // DAG is now:
+    //     /- gpui2
+    // zed -- crdb - livestreaming - livestreaming_dag
+    //    \---------/
+
     let result = db.get_channels_for_user(a_id).await.unwrap();
-    assert_eq!(
-        result.channels,
+    pretty_assertions::assert_eq!(
+        dbg!(result.channels),
+        vec![
+            Channel {
+                id: zed_id,
+                name: "zed".to_string(),
+                parent_id: None,
+            },
+            Channel {
+                id: crdb_id,
+                name: "crdb".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: gpui2_id,
+                name: "gpui2".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(crdb_id),
+            },
+            Channel {
+                id: livestreaming_dag_id,
+                name: "livestreaming_dag".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
+        ]
+    );
+
+    let livestreaming_dag_sub_id = db
+        .create_channel("livestreaming_dag_sub", Some(livestreaming_dag_id), "6", a_id)
+        .await
+        .unwrap();
+
+    // DAG is now:
+    //     /- gpui2
+    // zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
+    //    \---------/
+
+    let result = db.get_channels_for_user(a_id).await.unwrap();
+    pretty_assertions::assert_eq!(
+        dbg!(result.channels),
         vec![
             Channel {
                 id: zed_id,
@@ -544,16 +614,188 @@ async fn test_channels_moving(db: &Arc<Database>) {
                 name: "crdb".to_string(),
                 parent_id: Some(zed_id),
             },
+            Channel {
+                id: gpui2_id,
+                name: "gpui2".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(crdb_id),
+            },
+            Channel {
+                id: livestreaming_dag_id,
+                name: "livestreaming_dag".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
+            Channel {
+                id: livestreaming_dag_sub_id,
+                name: "livestreaming_dag_sub".to_string(),
+                parent_id: Some(livestreaming_dag_id),
+            },
+        ]
+    );
+
+    // Make a link
+    db.link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id)
+        .await
+        .unwrap();
+
+    // DAG is now:
+    //    /- gpui2                /---------------------\
+    // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
+    //    \--------/
+
+    let result = db.get_channels_for_user(a_id).await.unwrap();
+    pretty_assertions::assert_eq!(
+        dbg!(result.channels),
+        vec![
+            Channel {
+                id: zed_id,
+                name: "zed".to_string(),
+                parent_id: None,
+            },
             Channel {
                 id: crdb_id,
                 name: "crdb".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: gpui2_id,
+                name: "gpui2".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(crdb_id),
+            },
+            Channel {
+                id: livestreaming_dag_id,
+                name: "livestreaming_dag".to_string(),
                 parent_id: Some(livestreaming_id),
             },
+            Channel {
+                id: livestreaming_dag_sub_id,
+                name: "livestreaming_dag_sub".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
+            Channel {
+                id: livestreaming_dag_sub_id,
+                name: "livestreaming_dag_sub".to_string(),
+                parent_id: Some(livestreaming_dag_id),
+            },
+        ]
+    );
+
+    // Make another link
+    db.link_channel(a_id, livestreaming_id, gpui2_id)
+        .await
+        .unwrap();
+
+    // DAG is now:
+    //    /- gpui2 -\             /---------------------\
+    // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id
+    //    \---------/
+
+    let result = db.get_channels_for_user(a_id).await.unwrap();
+    pretty_assertions::assert_eq!(
+        dbg!(result.channels),
+        vec![
+            Channel {
+                id: zed_id,
+                name: "zed".to_string(),
+                parent_id: None,
+            },
+            Channel {
+                id: crdb_id,
+                name: "crdb".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: gpui2_id,
+                name: "gpui2".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(gpui2_id),
+            },
             Channel {
                 id: livestreaming_id,
                 name: "livestreaming".to_string(),
                 parent_id: Some(zed_id),
             },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(crdb_id),
+            },
+            Channel {
+                id: livestreaming_dag_id,
+                name: "livestreaming_dag".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
+            Channel {
+                id: livestreaming_dag_sub_id,
+                name: "livestreaming_dag_sub".to_string(),
+                parent_id: Some(livestreaming_id),
+            },
+            Channel {
+                id: livestreaming_dag_sub_id,
+                name: "livestreaming_dag_sub".to_string(),
+                parent_id: Some(livestreaming_dag_id),
+            },
         ]
     );
+
+    // // Attempt to make a cycle
+    // assert!(db
+    //     .move_channel(a_id, zed_id, Some(livestreaming_id))
+    //     .await
+    //     .is_err());
+
+    // // Move channel up
+    // db.move_channel(a_id, livestreaming_id, Some(zed_id))
+    //     .await
+    //     .unwrap();
+
+    // let result = db.get_channels_for_user(a_id).await.unwrap();
+    // pretty_assertions::assert_eq!(
+    //     result.channels,
+    //     vec![
+    //         Channel {
+    //             id: zed_id,
+    //             name: "zed".to_string(),
+    //             parent_id: None,
+    //         },
+    //         Channel {
+    //             id: crdb_id,
+    //             name: "crdb".to_string(),
+    //             parent_id: Some(zed_id),
+    //         },
+    //         Channel {
+    //             id: crdb_id,
+    //             name: "crdb".to_string(),
+    //             parent_id: Some(livestreaming_id),
+    //         },
+    //         Channel {
+    //             id: livestreaming_id,
+    //             name: "livestreaming".to_string(),
+    //             parent_id: Some(zed_id),
+    //         },
+    //     ]
+    // );
 }