implement recursive channel query

Mikayla Maki created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   3 
crates/collab/src/db.rs                                        | 256 +--
crates/collab/src/db/channel.rs                                |   1 
crates/collab/src/db/channel_parent.rs                         |   3 
crates/collab/src/tests/channel_tests.rs                       |  66 
5 files changed, 191 insertions(+), 138 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -187,7 +187,6 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
 
 CREATE TABLE "channels" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
-    -- "id_path" TEXT NOT NULL,
     "name" VARCHAR NOT NULL,
     "room_id" INTEGER REFERENCES rooms (id) ON DELETE SET NULL,
     "created_at" TIMESTAMP NOT NULL DEFAULT now
@@ -199,8 +198,6 @@ CREATE TABLE "channel_parents" (
     PRIMARY KEY(child_id, parent_id)
 );
 
--- CREATE UNIQUE INDEX "index_channels_on_id_path" ON "channels" ("id_path");
-
 CREATE TABLE "channel_members" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,

crates/collab/src/db.rs 🔗

@@ -1,7 +1,7 @@
 mod access_token;
 mod channel;
 mod channel_member;
-// mod channel_parent;
+mod channel_parent;
 mod contact;
 mod follower;
 mod language_server;
@@ -39,7 +39,10 @@ use sea_orm::{
     DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect,
     Statement, TransactionTrait,
 };
-use sea_query::{Alias, Expr, OnConflict, Query, SelectStatement};
+use sea_query::{
+    Alias, ColumnRef, CommonTableExpression, Expr, OnConflict, Order, Query, QueryStatementWriter,
+    SelectStatement, UnionType, WithClause,
+};
 use serde::{Deserialize, Serialize};
 pub use signup::{Invite, NewSignup, WaitlistSummary};
 use sqlx::migrate::{Migrate, Migration, MigrationSource};
@@ -3032,7 +3035,11 @@ impl Database {
 
     // channels
 
-    pub async fn create_channel(&self, name: &str) -> Result<ChannelId> {
+    pub async fn create_root_channel(&self, name: &str) -> Result<ChannelId> {
+        self.create_channel(name, None).await
+    }
+
+    pub async fn create_channel(&self, name: &str, parent: Option<ChannelId>) -> Result<ChannelId> {
         self.transaction(move |tx| async move {
             let tx = tx;
 
@@ -3043,10 +3050,21 @@ impl Database {
 
             let channel = channel.insert(&*tx).await?;
 
+            if let Some(parent) = parent {
+                channel_parent::ActiveModel {
+                    child_id: ActiveValue::Set(channel.id),
+                    parent_id: ActiveValue::Set(parent),
+                }
+                .insert(&*tx)
+                .await?;
+            }
+
             Ok(channel.id)
-        }).await
+        })
+        .await
     }
 
+    // Property: Members are only
     pub async fn add_channel_member(&self, channel_id: ChannelId, user_id: UserId) -> Result<()> {
         self.transaction(move |tx| async move {
             let tx = tx;
@@ -3060,139 +3078,108 @@ impl Database {
             channel_membership.insert(&*tx).await?;
 
             Ok(())
-        }).await
+        })
+        .await
     }
 
-    pub async fn get_channels(&self, user_id: UserId) -> Vec<ChannelId> {
+    pub async fn get_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
         self.transaction(|tx| async move {
             let tx = tx;
 
+            // This is the SQL statement we want to generate:
+            let sql = r#"
+            WITH RECURSIVE channel_tree(child_id, parent_id, depth) AS (
+                    SELECT channel_id as child_id, NULL as parent_id, 0
+                    FROM channel_members
+                    WHERE user_id = ?
+                UNION ALL
+                    SELECT channel_parents.child_id, channel_parents.parent_id, channel_tree.depth + 1
+                    FROM channel_parents, channel_tree
+                    WHERE channel_parents.parent_id = channel_tree.child_id
+            )
+            SELECT channel_tree.child_id as id, channels.name, channel_tree.parent_id
+            FROM channel_tree
+            JOIN channels ON channels.id = channel_tree.child_id
+            ORDER BY channel_tree.depth;
+            "#;
+
+            // let root_channel_ids_query = SelectStatement::new()
+            //     .column(channel_member::Column::ChannelId)
+            //     .expr(Expr::value("NULL"))
+            //     .from(channel_member::Entity.table_ref())
+            //     .and_where(
+            //         Expr::col(channel_member::Column::UserId)
+            //             .eq(Expr::cust_with_values("?", vec![user_id])),
+            //     );
+
+            // let build_tree_query = SelectStatement::new()
+            //     .column(channel_parent::Column::ChildId)
+            //     .column(channel_parent::Column::ParentId)
+            //     .expr(Expr::col(Alias::new("channel_tree.depth")).add(1i32))
+            //     .from(Alias::new("channel_tree"))
+            //     .and_where(
+            //         Expr::col(channel_parent::Column::ParentId)
+            //             .equals(Alias::new("channel_tree"), Alias::new("child_id")),
+            //     )
+            //     .to_owned();
+
+            // let common_table_expression = CommonTableExpression::new()
+            //     .query(
+            //         root_channel_ids_query
+            //             .union(UnionType::Distinct, build_tree_query)
+            //             .to_owned(),
+            //     )
+            //     .column(Alias::new("child_id"))
+            //     .column(Alias::new("parent_id"))
+            //     .column(Alias::new("depth"))
+            //     .table_name(Alias::new("channel_tree"))
+            //     .to_owned();
+
+            // let select = SelectStatement::new()
+            //     .expr_as(
+            //         Expr::col(Alias::new("channel_tree.child_id")),
+            //         Alias::new("id"),
+            //     )
+            //     .column(channel::Column::Name)
+            //     .column(Alias::new("channel_tree.parent_id"))
+            //     .from(Alias::new("channel_tree"))
+            //     .inner_join(
+            //         channel::Entity.table_ref(),
+            //         Expr::eq(
+            //             channel::Column::Id.into_expr(),
+            //             Expr::tbl(Alias::new("channel_tree"), Alias::new("child_id")),
+            //         ),
+            //     )
+            //     .order_by(Alias::new("channel_tree.child_id"), Order::Asc)
+            //     .to_owned();
+
+            // let with_clause = WithClause::new()
+            //     .recursive(true)
+            //     .cte(common_table_expression)
+            //     .to_owned();
+
+            // let query = select.with(with_clause);
+
+            // let query = SelectStatement::new()
+            //     .column(ColumnRef::Asterisk)
+            //     .from_subquery(query, Alias::new("channel_tree")
+            //     .to_owned();
+
+            // let stmt = self.pool.get_database_backend().build(&query);
+
+            let stmt = Statement::from_sql_and_values(
+                self.pool.get_database_backend(),
+                sql,
+                vec![user_id.into()],
+            );
+
+            Ok(channel_parent::Entity::find()
+                .from_raw_sql(stmt)
+                .into_model::<Channel>()
+                .all(&*tx)
+                .await?)
         })
-        //     let user = user::Model {
-        //         id: user_id,
-        //         ..Default::default()
-        //     };
-        //     let mut channel_ids = user
-        //         .find_related(channel_member::Entity)
-        //         .select_only()
-        //         .column(channel_member::Column::ChannelId)
-        //         .all(&*tx)
-        //         .await;
-
-        //     // let descendants = Alias::new("descendants");
-        //     // let cte_referencing = SelectStatement::new()
-        //     //     .column(channel_parent::Column::ChildId)
-        //     //     .from(channel::Entity)
-        //     //     .and_where(
-        //     //         Expr::col(channel_parent::Column::ParentId)
-        //     //             .in_subquery(SelectStatement::new().from(descendants).take())
-        //     //     );
-
-        //     // /*
-        //     // WITH RECURSIVE descendant_ids(id) AS (
-        //     //     $1
-        //     //     UNION ALL
-        //     //     SELECT child_id as id FROM channel_parents WHERE parent_id IN descendants
-        //     // )
-        //     // SELECT * from channels where id in descendant_ids
-        //     // */
-
-
-        //     // // WITH RECURSIVE descendants(id) AS (
-        //     // //    // SQL QUERY FOR SELECTING Initial IDs
-        //     // //   UNION
-        //     // //    SELECT id FROM ancestors WHERE p.parent = id
-        //     // // )
-        //     // // SELECT * FROM descendants;
-
-
-
-        //     // // let descendant_channel_ids =
-
-
-
-        //     // // let query = sea_query::Query::with().recursive(true);
-
-
-        //     // for id_path in id_paths {
-        //     //     //
-        //     // }
-
-
-        //     // // zed/public/plugins
-        //     // // zed/public/plugins/js
-        //     // // zed/zed-livekit
-        //     // // livekit/zed-livekit
-        //     // // zed - 101
-        //     // // livekit - 500
-        //     // // zed-livekit - 510
-        //     // // public - 150
-        //     // // plugins - 200
-        //     // // js - 300
-        //     // //
-        //     // // Channel, Parent - edges
-        //     // // 510 - 500
-        //     // // 510 - 101
-        //     // //
-        //     // // Given the channel 'Zed' (101)
-        //     // // Select * from EDGES where parent = 101 => 510
-        //     // //
-
-
-        //     // "SELECT * from channels where id_path like '$1?'"
-
-        //     // // https://www.postgresql.org/docs/current/queries-with.html
-        //     // // https://www.sqlite.org/lang_with.html
-
-        //     // "SELECT channel_id from channel_ancestors where ancestor_id IN $()"
-
-        //     // // | channel_id | ancestor_ids |
-        //     // // 150              150
-        //     // // 150              101
-        //     // // 200              101
-        //     // // 300              101
-        //     // // 200              150
-        //     // // 300              150
-        //     // // 300              200
-        //     // //
-        //     // // // | channel_id | ancestor_ids |
-        //     // // 150              101
-        //     // // 200              101
-        //     // // 300              101
-        //     // // 200              150
-        //     // // 300              [150, 200]
-
-        //     // channel::Entity::find()
-        //     //     .filter(channel::Column::IdPath.like(id_paths.unwrap()))
-
-        //     // dbg!(&id_paths.unwrap()[0].id_path);
-
-        //     // // let mut channel_members_by_channel_id = HashMap::new();
-        //     // // for channel_member in channel_members {
-        //     // //     channel_members_by_channel_id
-        //     // //         .entry(channel_member.channel_id)
-        //     // //         .or_insert_with(Vec::new)
-        //     // //         .push(channel_member);
-        //     // // }
-
-        //     // // let mut channel_messages = channel_message::Entity::find()
-        //     // //     .filter(channel_message::Column::ChannelId.in_selection(channel_ids))
-        //     // //     .all(&*tx)
-        //     // //     .await?;
-
-        //     // // let mut channel_messages_by_channel_id = HashMap::new();
-        //     // // for channel_message in channel_messages {
-        //     // //     channel_messages_by_channel_id
-        //     // //         .entry(channel_message.channel_id)
-        //     // //         .or_insert_with(Vec::new)
-        //     // //         .push(channel_message);
-        //     // // }
-
-        //     // todo!();
-        //     // // Ok(channels)
-        //     Err(Error("not implemented"))
-        // })
-        // .await
+        .await
     }
 
     async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
@@ -3440,6 +3427,13 @@ pub struct NewUserResult {
     pub signup_device_id: Option<String>,
 }
 
+#[derive(FromQueryResult, Debug, PartialEq)]
+pub struct Channel {
+    pub id: ChannelId,
+    pub name: String,
+    pub parent_id: Option<ChannelId>,
+}
+
 fn random_invite_code() -> String {
     nanoid::nanoid!(16)
 }

crates/collab/src/db/channel.rs 🔗

@@ -8,7 +8,6 @@ pub struct Model {
     pub id: ChannelId,
     pub name: String,
     pub room_id: Option<RoomId>,
-    // pub id_path: String,
 }
 
 impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/tests/channel_tests.rs 🔗

@@ -1,6 +1,8 @@
 use gpui::{executor::Deterministic, TestAppContext};
 use std::sync::Arc;
 
+use crate::db::Channel;
+
 use super::TestServer;
 
 #[gpui::test]
@@ -11,13 +13,71 @@ async fn test_basic_channels(deterministic: Arc<Deterministic>, cx: &mut TestApp
     let a_id = crate::db::UserId(client_a.user_id().unwrap() as i32);
     let db = server._test_db.db();
 
-    let zed_id = db.create_channel("zed").await.unwrap();
+    let zed_id = db.create_root_channel("zed").await.unwrap();
+    let crdb_id = db.create_channel("crdb", Some(zed_id)).await.unwrap();
+    let livestreaming_id = db
+        .create_channel("livestreaming", Some(zed_id))
+        .await
+        .unwrap();
+    let replace_id = db.create_channel("replace", Some(zed_id)).await.unwrap();
+    let rust_id = db.create_root_channel("rust").await.unwrap();
+    let cargo_id = db.create_channel("cargo", Some(rust_id)).await.unwrap();
 
     db.add_channel_member(zed_id, a_id).await.unwrap();
+    db.add_channel_member(rust_id, a_id).await.unwrap();
+
+    let channels = db.get_channels(a_id).await.unwrap();
+    assert_eq!(
+        channels,
+        vec![
+            Channel {
+                id: zed_id,
+                name: "zed".to_string(),
+                parent_id: None,
+            },
+            Channel {
+                id: rust_id,
+                name: "rust".to_string(),
+                parent_id: None,
+            },
+            Channel {
+                id: crdb_id,
+                name: "crdb".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: livestreaming_id,
+                name: "livestreaming".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: replace_id,
+                name: "replace".to_string(),
+                parent_id: Some(zed_id),
+            },
+            Channel {
+                id: cargo_id,
+                name: "cargo".to_string(),
+                parent_id: Some(rust_id),
+            }
+        ]
+    );
+}
 
-    let channels = db.get_channels(a_id).await;
+#[gpui::test]
+async fn test_block_cycle_creation(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx, "user_a").await;
+    let a_id = crate::db::UserId(client_a.user_id().unwrap() as i32);
+    let db = server._test_db.db();
 
-    assert_eq!(channels, vec![zed_id]);
+    let zed_id = db.create_root_channel("zed").await.unwrap();
+    let first_id = db.create_channel("first", Some(zed_id)).await.unwrap();
+    let second_id = db
+        .create_channel("second_id", Some(first_id))
+        .await
+        .unwrap();
 }
 
 /*