Fix slow query for fetching descendants of channels (#7008)

Conrad Irwin and Max created

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/migrations/20240129193601_fix_parent_path_index.sql |  4 
crates/collab/src/db/queries/channels.rs                          | 98 
crates/collab/src/db/tables/channel.rs                            |  4 
3 files changed, 37 insertions(+), 69 deletions(-)

Detailed changes

crates/collab/src/db/queries/channels.rs 🔗

@@ -197,12 +197,10 @@ impl Database {
                 }
             } else if visibility == ChannelVisibility::Members {
                 if self
-                    .get_channel_descendants_including_self(vec![channel_id], &*tx)
+                    .get_channel_descendants_excluding_self([&channel], &*tx)
                     .await?
                     .into_iter()
-                    .any(|channel| {
-                        channel.id != channel_id && channel.visibility == ChannelVisibility::Public
-                    })
+                    .any(|channel| channel.visibility == ChannelVisibility::Public)
                 {
                     Err(ErrorCode::BadPublicNesting
                         .with_tag("direction", "children")
@@ -261,10 +259,11 @@ impl Database {
                 .await?;
 
             let channels_to_remove = self
-                .get_channel_descendants_including_self(vec![channel.id], &*tx)
+                .get_channel_descendants_excluding_self([&channel], &*tx)
                 .await?
                 .into_iter()
                 .map(|channel| channel.id)
+                .chain(Some(channel_id))
                 .collect::<Vec<_>>();
 
             channel::Entity::delete_many()
@@ -445,16 +444,12 @@ impl Database {
     ) -> Result<MembershipUpdated> {
         let new_channels = self.get_user_channels(user_id, Some(channel), &*tx).await?;
         let removed_channels = self
-            .get_channel_descendants_including_self(vec![channel.id], &*tx)
+            .get_channel_descendants_excluding_self([channel], &*tx)
             .await?
             .into_iter()
-            .filter_map(|channel| {
-                if !new_channels.channels.iter().any(|c| c.id == channel.id) {
-                    Some(channel.id)
-                } else {
-                    None
-                }
-            })
+            .map(|channel| channel.id)
+            .chain([channel.id])
+            .filter(|channel_id| !new_channels.channels.iter().any(|c| c.id == *channel_id))
             .collect::<Vec<_>>();
 
         Ok(MembershipUpdated {
@@ -545,26 +540,6 @@ impl Database {
         .await
     }
 
-    pub async fn get_channel_memberships(
-        &self,
-        user_id: UserId,
-    ) -> Result<(Vec<channel_member::Model>, Vec<channel::Model>)> {
-        self.transaction(|tx| async move {
-            let memberships = channel_member::Entity::find()
-                .filter(channel_member::Column::UserId.eq(user_id))
-                .all(&*tx)
-                .await?;
-            let channels = self
-                .get_channel_descendants_including_self(
-                    memberships.iter().map(|m| m.channel_id),
-                    &*tx,
-                )
-                .await?;
-            Ok((memberships, channels))
-        })
-        .await
-    }
-
     /// Returns all channels for the user with the given ID.
     pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
         self.transaction(|tx| async move {
@@ -596,13 +571,21 @@ impl Database {
             .all(&*tx)
             .await?;
 
-        let descendants = self
-            .get_channel_descendants_including_self(
-                channel_memberships.iter().map(|m| m.channel_id),
-                &*tx,
-            )
+        let channels = channel::Entity::find()
+            .filter(channel::Column::Id.is_in(channel_memberships.iter().map(|m| m.channel_id)))
+            .all(&*tx)
+            .await?;
+
+        let mut descendants = self
+            .get_channel_descendants_excluding_self(channels.iter(), &*tx)
             .await?;
 
+        for channel in channels {
+            if let Err(ix) = descendants.binary_search_by_key(&channel.path(), |c| c.path()) {
+                descendants.insert(ix, channel);
+            }
+        }
+
         let roles_by_channel_id = channel_memberships
             .iter()
             .map(|membership| (membership.channel_id, membership.role))
@@ -880,46 +863,23 @@ impl Database {
 
     // Get the descendants of the given set if channels, ordered by their
     // path.
-    async fn get_channel_descendants_including_self(
+    pub(crate) async fn get_channel_descendants_excluding_self(
         &self,
-        channel_ids: impl IntoIterator<Item = ChannelId>,
+        channels: impl IntoIterator<Item = &channel::Model>,
         tx: &DatabaseTransaction,
     ) -> Result<Vec<channel::Model>> {
-        let mut values = String::new();
-        for id in channel_ids {
-            if !values.is_empty() {
-                values.push_str(", ");
-            }
-            write!(&mut values, "({})", id).unwrap();
+        let mut filter = Condition::any();
+        for channel in channels.into_iter() {
+            filter = filter.add(channel::Column::ParentPath.like(channel.descendant_path_filter()));
         }
 
-        if values.is_empty() {
+        if filter.is_empty() {
             return Ok(vec![]);
         }
 
-        let sql = format!(
-            r#"
-            SELECT DISTINCT
-                descendant_channels.*,
-                descendant_channels.parent_path || descendant_channels.id as full_path
-            FROM
-                channels parent_channels, channels descendant_channels
-            WHERE
-                descendant_channels.id IN ({values}) OR
-                (
-                    parent_channels.id IN ({values}) AND
-                    descendant_channels.parent_path LIKE (parent_channels.parent_path || parent_channels.id || '/%')
-                )
-            ORDER BY
-                full_path ASC
-            "#
-        );
-
         Ok(channel::Entity::find()
-            .from_raw_sql(Statement::from_string(
-                self.pool.get_database_backend(),
-                sql,
-            ))
+            .filter(filter)
+            .order_by_asc(Expr::cust("parent_path || id || '/'"))
             .all(tx)
             .await?)
     }

crates/collab/src/db/tables/channel.rs 🔗

@@ -39,6 +39,10 @@ impl Model {
     pub fn path(&self) -> String {
         format!("{}{}/", self.parent_path, self.id)
     }
+
+    pub fn descendant_path_filter(&self) -> String {
+        format!("{}{}/%", self.parent_path, self.id)
+    }
 }
 
 impl ActiveModelBehavior for ActiveModel {}