Rewrite get_user_channels with new permissions

Conrad Irwin created

Change summary

crates/channel/src/channel_store_tests.rs |   2 
crates/collab/src/db/queries/channels.rs  | 174 ++++++++++++++++++++++--
2 files changed, 159 insertions(+), 17 deletions(-)

Detailed changes

crates/channel/src/channel_store_tests.rs 🔗

@@ -3,7 +3,7 @@ use crate::channel_chat::ChannelChatEvent;
 use super::*;
 use client::{test::FakeServer, Client, UserStore};
 use gpui::{AppContext, ModelHandle, TestAppContext};
-use rpc::proto::{self, ChannelRole};
+use rpc::proto::{self};
 use settings::SettingsStore;
 use util::http::FakeHttpClient;
 

crates/collab/src/db/queries/channels.rs 🔗

@@ -439,25 +439,108 @@ impl Database {
         channel_memberships: Vec<channel_member::Model>,
         tx: &DatabaseTransaction,
     ) -> Result<ChannelsForUser> {
-        let parents_by_child_id = self
-            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
+        let mut edges = self
+            .get_channel_descendants_2(channel_memberships.iter().map(|m| m.channel_id), &*tx)
             .await?;
 
-        let channels_with_admin_privileges = channel_memberships
-            .iter()
-            .filter_map(|membership| {
-                if membership.role == Some(ChannelRole::Admin) || membership.admin {
-                    Some(membership.channel_id)
+        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();
+
+        for membership in channel_memberships.iter() {
+            role_for_channel.insert(
+                membership.channel_id,
+                membership.role.unwrap_or(if membership.admin {
+                    ChannelRole::Admin
                 } else {
-                    None
+                    ChannelRole::Member
+                }),
+            );
+        }
+
+        for ChannelEdge {
+            parent_id,
+            channel_id,
+        } in edges.iter()
+        {
+            let parent_id = ChannelId::from_proto(*parent_id);
+            let channel_id = ChannelId::from_proto(*channel_id);
+            debug_assert!(role_for_channel.get(&parent_id).is_some());
+            let parent_role = role_for_channel[&parent_id];
+            if let Some(existing_role) = role_for_channel.get(&channel_id) {
+                if existing_role.should_override(parent_role) {
+                    continue;
                 }
-            })
-            .collect();
+            }
+            role_for_channel.insert(channel_id, parent_role);
+        }
+
+        let mut channels: Vec<Channel> = Vec::new();
+        let mut channels_with_admin_privileges: HashSet<ChannelId> = HashSet::default();
+        let mut channels_to_remove: HashSet<u64> = HashSet::default();
 
-        let graph = self
-            .get_channel_graph(parents_by_child_id, true, &tx)
+        let mut rows = channel::Entity::find()
+            .filter(channel::Column::Id.is_in(role_for_channel.keys().cloned()))
+            .stream(&*tx)
             .await?;
 
+        while let Some(row) = rows.next().await {
+            let channel = row?;
+            let role = role_for_channel[&channel.id];
+
+            if role == ChannelRole::Banned
+                || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
+            {
+                channels_to_remove.insert(channel.id.0 as u64);
+                continue;
+            }
+
+            channels.push(Channel {
+                id: channel.id,
+                name: channel.name,
+            });
+
+            if role == ChannelRole::Admin {
+                channels_with_admin_privileges.insert(channel.id);
+            }
+        }
+        drop(rows);
+
+        if !channels_to_remove.is_empty() {
+            // Note: this code assumes each channel has one parent.
+            let mut replacement_parent: HashMap<u64, u64> = HashMap::default();
+            for ChannelEdge {
+                parent_id,
+                channel_id,
+            } in edges.iter()
+            {
+                if channels_to_remove.contains(channel_id) {
+                    replacement_parent.insert(*channel_id, *parent_id);
+                }
+            }
+
+            let mut new_edges: Vec<ChannelEdge> = Vec::new();
+            'outer: for ChannelEdge {
+                mut parent_id,
+                channel_id,
+            } in edges.iter()
+            {
+                if channels_to_remove.contains(channel_id) {
+                    continue;
+                }
+                while channels_to_remove.contains(&parent_id) {
+                    if let Some(new_parent_id) = replacement_parent.get(&parent_id) {
+                        parent_id = *new_parent_id;
+                    } else {
+                        continue 'outer;
+                    }
+                }
+                new_edges.push(ChannelEdge {
+                    parent_id,
+                    channel_id: *channel_id,
+                })
+            }
+            edges = new_edges;
+        }
+
         #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
         enum QueryUserIdsAndChannelIds {
             ChannelId,
@@ -468,7 +551,7 @@ impl Database {
         {
             let mut rows = room_participant::Entity::find()
                 .inner_join(room::Entity)
-                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
+                .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
                 .select_only()
                 .column(room::Column::ChannelId)
                 .column(room_participant::Column::UserId)
@@ -481,7 +564,7 @@ impl Database {
             }
         }
 
-        let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
+        let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
         let channel_buffer_changes = self
             .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
             .await?;
@@ -491,7 +574,7 @@ impl Database {
             .await?;
 
         Ok(ChannelsForUser {
-            channels: graph,
+            channels: ChannelGraph { channels, edges },
             channel_participants,
             channels_with_admin_privileges,
             unseen_buffer_changes: channel_buffer_changes,
@@ -842,7 +925,7 @@ impl Database {
         })
     }
 
-    /// Returns the channel ancestors, deepest first
+    /// Returns the channel ancestors, include itself, deepest first
     pub async fn get_channel_ancestors(
         &self,
         channel_id: ChannelId,
@@ -867,6 +950,65 @@ impl Database {
         Ok(channel_ids)
     }
 
+    // Returns the channel desendants as a sorted list of edges for further processing.
+    // The edges are sorted such that you will see unknown channel ids as children
+    // before you see them as parents.
+    async fn get_channel_descendants_2(
+        &self,
+        channel_ids: impl IntoIterator<Item = ChannelId>,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<ChannelEdge>> {
+        let mut values = String::new();
+        for id in channel_ids {
+            if !values.is_empty() {
+                values.push_str(", ");
+            }
+            write!(&mut values, "({})", id).unwrap();
+        }
+
+        if values.is_empty() {
+            return Ok(vec![]);
+        }
+
+        let sql = format!(
+            r#"
+            SELECT
+                descendant_paths.*
+            FROM
+                channel_paths parent_paths, channel_paths descendant_paths
+            WHERE
+                parent_paths.channel_id IN ({values}) AND
+                descendant_paths.id_path != parent_paths.id_path AND
+                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
+            ORDER BY
+                descendant_paths.id_path
+        "#
+        );
+
+        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
+
+        let mut paths = channel_path::Entity::find()
+            .from_raw_sql(stmt)
+            .stream(tx)
+            .await?;
+
+        let mut results: Vec<ChannelEdge> = Vec::new();
+        while let Some(path) = paths.next().await {
+            let path = path?;
+            let ids: Vec<&str> = path.id_path.trim_matches('/').split('/').collect();
+
+            debug_assert!(ids.len() >= 2);
+            debug_assert!(ids[ids.len() - 1] == path.channel_id.to_string());
+
+            results.push(ChannelEdge {
+                parent_id: ids[ids.len() - 2].parse().unwrap(),
+                channel_id: ids[ids.len() - 1].parse().unwrap(),
+            })
+        }
+
+        Ok(results)
+    }
+
     /// Returns the channel descendants,
     /// Structured as a map from child ids to their parent ids
     /// For example, the descendants of 'a' in this DAG: