Finish integration tests for channel moving

Mikayla created

Refactor channel store to combine the channels_by_id and channel_paths into a 'ChannelIndex'

Change summary

crates/channel/src/channel_store.rs               |  96 +++------
crates/channel/src/channel_store/channel_index.rs | 151 +++++++++++++++++
crates/collab/src/db/queries/channels.rs          |  42 ++--
crates/collab/src/db/tests/channel_tests.rs       |  29 --
crates/collab/src/rpc.rs                          |   4 
5 files changed, 214 insertions(+), 108 deletions(-)

Detailed changes

crates/channel/src/channel_store.rs 🔗

@@ -1,3 +1,5 @@
+mod channel_index;
+
 use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat};
 use anyhow::{anyhow, Result};
 use client::{Client, Subscription, User, UserId, UserStore};
@@ -8,13 +10,14 @@ use rpc::{proto, TypedEnvelope};
 use std::{mem, sync::Arc, time::Duration};
 use util::ResultExt;
 
+use self::channel_index::ChannelIndex;
+
 pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
 
 pub type ChannelId = u64;
 
 pub struct ChannelStore {
-    channels_by_id: HashMap<ChannelId, Arc<Channel>>,
-    channel_paths: Vec<Vec<ChannelId>>,
+    channel_index: ChannelIndex,
     channel_invitations: Vec<Arc<Channel>>,
     channel_participants: HashMap<ChannelId, Vec<Arc<User>>>,
     channels_with_admin_privileges: HashSet<ChannelId>,
@@ -82,9 +85,8 @@ impl ChannelStore {
         });
 
         Self {
-            channels_by_id: HashMap::default(),
             channel_invitations: Vec::default(),
-            channel_paths: Vec::default(),
+            channel_index: ChannelIndex::default(),
             channel_participants: Default::default(),
             channels_with_admin_privileges: Default::default(),
             outgoing_invites: Default::default(),
@@ -116,7 +118,7 @@ impl ChannelStore {
     }
 
     pub fn has_children(&self, channel_id: ChannelId) -> bool {
-        self.channel_paths.iter().any(|path| {
+        self.channel_index.iter().any(|path| {
             if let Some(ix) = path.iter().position(|id| *id == channel_id) {
                 path.len() > ix + 1
             } else {
@@ -126,7 +128,7 @@ impl ChannelStore {
     }
 
     pub fn channel_count(&self) -> usize {
-        self.channel_paths.len()
+        self.channel_index.len()
     }
 
     pub fn index_of_channel(&self, channel_id: ChannelId) -> Option<usize> {
@@ -136,7 +138,7 @@ impl ChannelStore {
     }
 
     pub fn channels(&self) -> impl '_ + Iterator<Item = (usize, &Arc<Channel>)> {
-        self.channel_paths.iter().map(move |path| {
+        self.channel_index.iter().map(move |path| {
             let id = path.last().unwrap();
             let channel = self.channel_for_id(*id).unwrap();
             (path.len() - 1, channel)
@@ -144,7 +146,7 @@ impl ChannelStore {
     }
 
     pub fn channel_at_index(&self, ix: usize) -> Option<(usize, &Arc<Channel>)> {
-        let path = self.channel_paths.get(ix)?;
+        let path = self.channel_index.get(ix)?;
         let id = path.last().unwrap();
         let channel = self.channel_for_id(*id).unwrap();
         Some((path.len() - 1, channel))
@@ -155,7 +157,7 @@ impl ChannelStore {
     }
 
     pub fn channel_for_id(&self, channel_id: ChannelId) -> Option<&Arc<Channel>> {
-        self.channels_by_id.get(&channel_id)
+        self.channel_index.by_id().get(&channel_id)
     }
 
     pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool {
@@ -268,7 +270,7 @@ impl ChannelStore {
     }
 
     pub fn is_user_admin(&self, channel_id: ChannelId) -> bool {
-        self.channel_paths.iter().any(|path| {
+        self.channel_index.iter().any(|path| {
             if let Some(ix) = path.iter().position(|id| *id == channel_id) {
                 path[..=ix]
                     .iter()
@@ -323,15 +325,24 @@ impl ChannelStore {
         })
     }
 
-
-    pub fn move_channel(&mut self, channel_id: ChannelId, from_parent: Option<ChannelId>, to: Option<ChannelId>, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+    pub fn move_channel(
+        &mut self,
+        channel_id: ChannelId,
+        from_parent: Option<ChannelId>,
+        to: Option<ChannelId>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(|_, _| async move {
             let _ = client
-                .request(proto::MoveChannel { channel_id, from_parent, to })
+                .request(proto::MoveChannel {
+                    channel_id,
+                    from_parent,
+                    to,
+                })
                 .await?;
 
-           Ok(())
+            Ok(())
         })
     }
 
@@ -651,11 +662,11 @@ impl ChannelStore {
     }
 
     fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
-        self.channels_by_id.clear();
+        self.channel_index.clear();
         self.channel_invitations.clear();
         self.channel_participants.clear();
         self.channels_with_admin_privileges.clear();
-        self.channel_paths.clear();
+        self.channel_index.clear();
         self.outgoing_invites.clear();
         cx.notify();
 
@@ -705,8 +716,7 @@ impl ChannelStore {
         let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty();
         if channels_changed {
             if !payload.delete_channels.is_empty() {
-                self.channels_by_id
-                    .retain(|channel_id, _| !payload.delete_channels.contains(channel_id));
+                self.channel_index.delete_channels(&payload.delete_channels);
                 self.channel_participants
                     .retain(|channel_id, _| !payload.delete_channels.contains(channel_id));
                 self.channels_with_admin_privileges
@@ -724,44 +734,12 @@ impl ChannelStore {
                 }
             }
 
-            for channel_proto in payload.channels {
-                if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
-                    Arc::make_mut(existing_channel).name = channel_proto.name;
-                } else {
-                    let channel = Arc::new(Channel {
-                        id: channel_proto.id,
-                        name: channel_proto.name,
-                    });
-                    self.channels_by_id.insert(channel.id, channel.clone());
-
-                    if let Some(parent_id) = channel_proto.parent_id {
-                        let mut ix = 0;
-                        while ix < self.channel_paths.len() {
-                            let path = &self.channel_paths[ix];
-                            if path.ends_with(&[parent_id]) {
-                                let mut new_path = path.clone();
-                                new_path.push(channel.id);
-                                self.channel_paths.insert(ix + 1, new_path);
-                                ix += 1;
-                            }
-                            ix += 1;
-                        }
-                    } else {
-                        self.channel_paths.push(vec![channel.id]);
-                    }
-                }
-            }
+            self.channel_index.insert_channels(payload.channels);
+        }
 
-            self.channel_paths.sort_by(|a, b| {
-                let a = Self::channel_path_sorting_key(a, &self.channels_by_id);
-                let b = Self::channel_path_sorting_key(b, &self.channels_by_id);
-                a.cmp(b)
-            });
-            self.channel_paths.dedup();
-            self.channel_paths.retain(|path| {
-                path.iter()
-                    .all(|channel_id| self.channels_by_id.contains_key(channel_id))
-            });
+        for edge in payload.delete_channel_edge {
+            self.channel_index
+                .remove_edge(edge.parent_id, edge.channel_id);
         }
 
         for permission in payload.channel_permissions {
@@ -820,11 +798,5 @@ impl ChannelStore {
         }))
     }
 
-    fn channel_path_sorting_key<'a>(
-        path: &'a [ChannelId],
-        channels_by_id: &'a HashMap<ChannelId, Arc<Channel>>,
-    ) -> impl 'a + Iterator<Item = Option<&'a str>> {
-        path.iter()
-            .map(|id| Some(channels_by_id.get(id)?.name.as_str()))
-    }
+
 }

crates/channel/src/channel_store/channel_index.rs 🔗

@@ -0,0 +1,151 @@
+use std::{ops::{Deref, DerefMut}, sync::Arc};
+
+use collections::HashMap;
+use rpc::proto;
+
+use crate::{ChannelId, Channel};
+
+pub type ChannelPath = Vec<ChannelId>;
+pub type ChannelsById = HashMap<ChannelId, Arc<Channel>>;
+
+#[derive(Default, Debug)]
+pub struct ChannelIndex {
+    paths: Vec<ChannelPath>,
+    channels_by_id: ChannelsById,
+}
+
+
+impl ChannelIndex {
+    pub fn by_id(&self) -> &ChannelsById {
+        &self.channels_by_id
+    }
+
+    /// Insert or update all of the given channels into the index
+    pub fn insert_channels(&mut self, channels: Vec<proto::Channel>) {
+        let mut insert = self.insert();
+
+        for channel_proto in channels {
+            if let Some(existing_channel) = insert.channels_by_id.get_mut(&channel_proto.id) {
+                Arc::make_mut(existing_channel).name = channel_proto.name;
+
+                if let Some(parent_id) = channel_proto.parent_id {
+                    insert.insert_edge(parent_id, channel_proto.id)
+                }
+            } else {
+                let channel = Arc::new(Channel {
+                    id: channel_proto.id,
+                    name: channel_proto.name,
+                });
+                insert.channels_by_id.insert(channel.id, channel.clone());
+
+                if let Some(parent_id) = channel_proto.parent_id {
+                    insert.insert_edge(parent_id, channel.id);
+                } else {
+                    insert.insert_root(channel.id);
+                }
+            }
+        }
+    }
+
+    pub fn clear(&mut self) {
+        self.paths.clear();
+        self.channels_by_id.clear();
+    }
+
+    /// Remove the given edge from this index. This will not remove the channel
+    /// and may result in dangling channels.
+    pub fn remove_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
+        self.paths.retain(|path| {
+            !path
+                .windows(2)
+                .any(|window| window == [parent_id, channel_id])
+        });
+    }
+
+    /// Delete the given channels from this index.
+    pub fn delete_channels(&mut self, channels: &[ChannelId]) {
+        self.channels_by_id.retain(|channel_id, _| !channels.contains(channel_id));
+        self.paths.retain(|channel_path| !channel_path.iter().any(|channel_id| {channels.contains(channel_id)}))
+    }
+
+    fn insert(& mut self) -> ChannelPathsInsertGuard {
+        ChannelPathsInsertGuard {
+            paths: &mut self.paths,
+            channels_by_id: &mut self.channels_by_id,
+        }
+    }
+}
+
+impl Deref for ChannelIndex {
+    type Target = Vec<ChannelPath>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.paths
+    }
+}
+
+/// A guard for ensuring that the paths index maintains its sort and uniqueness
+/// invariants after a series of insertions
+struct ChannelPathsInsertGuard<'a> {
+    paths:  &'a mut Vec<ChannelPath>,
+    channels_by_id: &'a mut ChannelsById,
+}
+
+impl Deref for ChannelPathsInsertGuard<'_> {
+    type Target = ChannelsById;
+
+    fn deref(&self) -> &Self::Target {
+        &self.channels_by_id
+    }
+}
+
+impl DerefMut for ChannelPathsInsertGuard<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.channels_by_id
+    }
+}
+
+
+impl<'a> ChannelPathsInsertGuard<'a> {
+    pub fn insert_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
+        let mut ix = 0;
+        while ix < self.paths.len() {
+            let path = &self.paths[ix];
+            if path.ends_with(&[parent_id]) {
+                let mut new_path = path.clone();
+                new_path.push(channel_id);
+                self.paths.insert(ix + 1, new_path);
+                ix += 1;
+            }
+            ix += 1;
+        }
+    }
+
+    pub fn insert_root(&mut self, channel_id: ChannelId) {
+        self.paths.push(vec![channel_id]);
+    }
+}
+
+impl<'a> Drop for ChannelPathsInsertGuard<'a> {
+    fn drop(&mut self) {
+        self.paths.sort_by(|a, b| {
+            let a = channel_path_sorting_key(a, &self.channels_by_id);
+            let b = channel_path_sorting_key(b, &self.channels_by_id);
+            a.cmp(b)
+        });
+        self.paths.dedup();
+        self.paths.retain(|path| {
+            path.iter()
+                .all(|channel_id| self.channels_by_id.contains_key(channel_id))
+        });
+    }
+}
+
+
+fn channel_path_sorting_key<'a>(
+    path: &'a [ChannelId],
+    channels_by_id: &'a ChannelsById,
+) -> impl 'a + Iterator<Item = Option<&'a str>> {
+    path.iter()
+        .map(|id| Some(channels_by_id.get(id)?.name.as_str()))
+}

crates/collab/src/db/queries/channels.rs 🔗

@@ -846,7 +846,8 @@ impl Database {
     /// - (`Some(id)`, `None`) Remove a channel from a given parent, and leave other parents
     /// - (`Some(id)`, `Some(id)`) Move channel from one parent to another, leaving other parents
     ///
-    /// Returns the channel that was moved + it's sub channels
+    /// Returns the channel that was moved + it's sub channels for use
+    /// by the members for `to`
     pub async fn move_channel(
         &self,
         user: UserId,
@@ -861,14 +862,9 @@ impl Database {
             self.check_user_is_channel_admin(from, user, &*tx).await?;
 
             let mut channel_descendants = None;
-            if let Some(from_parent) = from_parent {
-                self.check_user_is_channel_admin(from_parent, user, &*tx)
-                    .await?;
-
-                self.remove_channel_from_parent(from, from_parent, &*tx)
-                    .await?;
-            }
 
+            // Note that we have to do the linking before the removal, so that we
+            // can leave the channel_path table in a consistent state.
             if let Some(to) = to {
                 self.check_user_is_channel_admin(to, user, &*tx).await?;
 
@@ -880,20 +876,28 @@ impl Database {
                 None => self.get_channel_descendants([from], &*tx).await?,
             };
 
-            // Repair the parent ID of the channel in case it was from a cached call
-            if let Some(channel) = channel_descendants.get_mut(&from) {
-                if let Some(from_parent) = from_parent {
-                    channel.remove(&from_parent);
-                }
-                if let Some(to) = to {
-                    channel.insert(to);
-                }
+            if let Some(from_parent) = from_parent {
+                self.check_user_is_channel_admin(from_parent, user, &*tx)
+                    .await?;
+
+                self.remove_channel_from_parent(from, from_parent, &*tx)
+                    .await?;
             }
 
+            let channels;
+            if let Some(to) = to {
+                if let Some(channel) = channel_descendants.get_mut(&from) {
+                    // Remove the other parents
+                    channel.clear();
+                    channel.insert(to);
+                }
 
-            let channels = self
-                .get_all_channels(channel_descendants, &*tx)
-                .await?;
+                 channels = self
+                    .get_all_channels(channel_descendants, &*tx)
+                    .await?;
+            } else {
+                channels = vec![];
+            }
 
             Ok(channels)
         })

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -657,7 +657,7 @@ async fn test_channels_moving(db: &Arc<Database>) {
     // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
     //    \--------/
 
-    // make sure we're getting the new link
+    // make sure we're getting just the new link
     pretty_assertions::assert_eq!(
         channels,
         vec![
@@ -665,12 +665,7 @@ async fn test_channels_moving(db: &Arc<Database>) {
                 id: livestreaming_dag_sub_id,
                 name: "livestreaming_dag_sub".to_string(),
                 parent_id: Some(livestreaming_id),
-            },
-            Channel {
-                id: livestreaming_dag_sub_id,
-                name: "livestreaming_dag_sub".to_string(),
-                parent_id: Some(livestreaming_dag_id),
-            },
+            }
         ]
     );
 
@@ -738,16 +733,6 @@ async fn test_channels_moving(db: &Arc<Database>) {
             name: "livestreaming".to_string(),
             parent_id: Some(gpui2_id),
         },
-        Channel {
-            id: livestreaming_id,
-            name: "livestreaming".to_string(),
-            parent_id: Some(zed_id),
-        },
-        Channel {
-            id: livestreaming_id,
-            name: "livestreaming".to_string(),
-            parent_id: Some(crdb_id),
-        },
         Channel {
             id: livestreaming_dag_id,
             name: "livestreaming_dag".to_string(),
@@ -826,16 +811,10 @@ async fn test_channels_moving(db: &Arc<Database>) {
     // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub
     //    \---------/
 
-    // Make sure the recently removed link isn't returned
+    // Since we're not moving it to anywhere, there's nothing to notify anyone about
     pretty_assertions::assert_eq!(
         channels,
-        vec![
-            Channel {
-                id: livestreaming_dag_sub_id,
-                name: "livestreaming_dag_sub".to_string(),
-                parent_id: Some(livestreaming_dag_id),
-            },
-        ]
+        vec![]
     );
 
 

crates/collab/src/rpc.rs 🔗

@@ -2400,7 +2400,7 @@ async fn move_channel(
     let channel_id = ChannelId::from_proto(request.channel_id);
     let from_parent = request.from_parent.map(ChannelId::from_proto);
     let to = request.to.map(ChannelId::from_proto);
-    let channels = db
+    let channels_to_send = db
         .move_channel(
             session.user_id,
             channel_id,
@@ -2432,7 +2432,7 @@ async fn move_channel(
         let members = db.get_channel_members(to).await?;
         let connection_pool = session.connection_pool().await;
         let update = proto::UpdateChannels {
-            channels: channels.into_iter().map(|channel| proto::Channel {
+            channels: channels_to_send.into_iter().map(|channel| proto::Channel {
                 id: channel.id.to_proto(),
                 name: channel.name,
                 parent_id: channel.parent_id.map(ChannelId::to_proto),