Fix set_channel_visibility for public channels

Conrad Irwin created

Change summary

crates/collab/src/db.rs                  |  1 
crates/collab/src/db/queries/channels.rs | 36 ++++++++-----
crates/collab/src/rpc.rs                 | 12 ++--
crates/collab/src/tests/channel_tests.rs | 65 ++++++++++++++++++++------
4 files changed, 78 insertions(+), 36 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -458,6 +458,7 @@ pub struct CreateChannelResult {
 pub struct SetChannelVisibilityResult {
     pub participants_to_update: HashMap<UserId, ChannelsForUser>,
     pub participants_to_remove: HashSet<UserId>,
+    pub channels_to_remove: Vec<ChannelId>,
 }
 
 #[derive(Debug)]

crates/collab/src/db/queries/channels.rs 🔗

@@ -244,9 +244,30 @@ impl Database {
                 .into_iter()
                 .collect();
 
+            let mut channels_to_remove: Vec<ChannelId> = vec![];
             let mut participants_to_remove: HashSet<UserId> = HashSet::default();
             match visibility {
                 ChannelVisibility::Members => {
+                    let all_descendents: Vec<ChannelId> = self
+                        .get_channel_descendants(vec![channel_id], &*tx)
+                        .await?
+                        .into_iter()
+                        .map(|edge| ChannelId::from_proto(edge.channel_id))
+                        .collect();
+
+                    channels_to_remove = channel::Entity::find()
+                        .filter(
+                            channel::Column::Id
+                                .is_in(all_descendents)
+                                .and(channel::Column::Visibility.eq(ChannelVisibility::Public)),
+                        )
+                        .all(&*tx)
+                        .await?
+                        .into_iter()
+                        .map(|channel| channel.id)
+                        .collect();
+
+                    channels_to_remove.push(channel_id);
                     for member in previous_members {
                         if member.role.can_only_see_public_descendants() {
                             participants_to_remove.insert(member.user_id);
@@ -271,6 +292,7 @@ impl Database {
             Ok(SetChannelVisibilityResult {
                 participants_to_update,
                 participants_to_remove,
+                channels_to_remove,
             })
         })
         .await
@@ -694,14 +716,10 @@ impl Database {
             .all(&*tx)
             .await?;
 
-        dbg!((user_id, &channel_memberships));
-
         let mut edges = self
             .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
             .await?;
 
-        dbg!((user_id, &edges));
-
         let mut role_for_channel: HashMap<ChannelId, (ChannelRole, bool)> = HashMap::default();
 
         for membership in channel_memberships.iter() {
@@ -710,8 +728,6 @@ impl Database {
             role_for_channel.insert(membership.channel_id, (membership.role, included));
         }
 
-        dbg!((&role_for_channel, parent_channel_id));
-
         for ChannelEdge {
             parent_id,
             channel_id,
@@ -739,7 +755,6 @@ impl Database {
                 );
             }
         }
-        dbg!((&role_for_channel, parent_channel_id));
 
         let mut channels: Vec<Channel> = Vec::new();
         let mut channels_to_remove: HashSet<u64> = HashSet::default();
@@ -757,7 +772,6 @@ impl Database {
                 || role == ChannelRole::Banned
                 || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
             {
-                dbg!("remove", channel.id);
                 channels_to_remove.insert(channel.id.0 as u64);
                 continue;
             }
@@ -865,8 +879,6 @@ impl Database {
             .get_channel_participant_details_internal(new_parent, &*tx)
             .await?;
 
-        dbg!(&members);
-
         for member in members.iter() {
             if !member.role.can_see_all_descendants() {
                 continue;
@@ -897,8 +909,6 @@ impl Database {
                 .await?
         };
 
-        dbg!(&public_members);
-
         for member in public_members {
             if !member.role.can_only_see_public_descendants() {
                 continue;
@@ -1666,8 +1676,6 @@ impl Database {
                 .into_iter()
                 .collect();
 
-            dbg!(&participants_to_update);
-
             let mut moved_channels: HashSet<ChannelId> = HashSet::default();
             moved_channels.insert(channel_id);
             for edge in self.get_channel_descendants([channel_id], &*tx).await? {

crates/collab/src/rpc.rs 🔗

@@ -2366,6 +2366,7 @@ async fn set_channel_visibility(
     let SetChannelVisibilityResult {
         participants_to_update,
         participants_to_remove,
+        channels_to_remove,
     } = db
         .set_channel_visibility(channel_id, visibility, session.user_id)
         .await?;
@@ -2379,10 +2380,7 @@ async fn set_channel_visibility(
     }
     for user_id in participants_to_remove {
         let update = proto::UpdateChannels {
-            // for public participants  we only need to remove the current channel
-            // (not descendants)
-            // because they can still see any public descendants
-            delete_channels: vec![channel_id.to_proto()],
+            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
             ..Default::default()
         };
         for connection_id in connection_pool.user_connection_ids(user_id) {
@@ -2645,7 +2643,7 @@ async fn join_channel_internal(
         leave_room_for_session(&session).await?;
         let db = session.db().await;
 
-        let (joined_room, accept_invite_result, role) = db
+        let (joined_room, membership_updated, role) = db
             .join_channel(
                 channel_id,
                 session.user_id,
@@ -2691,10 +2689,10 @@ async fn join_channel_internal(
         })?;
 
         let connection_pool = session.connection_pool().await;
-        if let Some(accept_invite_result) = accept_invite_result {
+        if let Some(membership_updated) = membership_updated {
             notify_membership_updated(
                 &connection_pool,
-                accept_invite_result,
+                membership_updated,
                 session.user_id,
                 &session.peer,
             );

crates/collab/src/tests/channel_tests.rs 🔗

@@ -1122,7 +1122,7 @@ async fn test_channel_link_notifications(
     assert_channels_list_shape(
         client_c.channel_store(),
         cx_c,
-        &[(zed_channel, 0), (helix_channel, 1)],
+        &[(zed_channel, 0)],
     );
 }
 
@@ -1250,44 +1250,79 @@ async fn test_guest_access(
     let client_b = server.create_client(cx_b, "user_b").await;
 
     let channels = server
-        .make_channel_tree(&[("channel-a", None)], (&client_a, cx_a))
+        .make_channel_tree(
+            &[("channel-a", None), ("channel-b", Some("channel-a"))],
+            (&client_a, cx_a),
+        )
         .await;
-    let channel_a_id = channels[0];
+    let channel_a = channels[0];
+    let channel_b = channels[1];
 
     let active_call_b = cx_b.read(ActiveCall::global);
 
-    // should not be allowed to join
+    // Non-members should not be allowed to join
     assert!(active_call_b
-        .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+        .update(cx_b, |call, cx| call.join_channel(channel_a, cx))
         .await
         .is_err());
 
+    // Make channels A and B public
     client_a
         .channel_store()
         .update(cx_a, |channel_store, cx| {
-            channel_store.set_channel_visibility(channel_a_id, proto::ChannelVisibility::Public, cx)
+            channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Public, cx)
+        })
+        .await
+        .unwrap();
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.set_channel_visibility(channel_b, proto::ChannelVisibility::Public, cx)
         })
         .await
         .unwrap();
 
+    // Client B joins channel A as a guest
     active_call_b
-        .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+        .update(cx_b, |call, cx| call.join_channel(channel_a, cx))
         .await
         .unwrap();
 
     deterministic.run_until_parked();
-
-    assert!(client_b
-        .channel_store()
-        .update(cx_b, |channel_store, _| channel_store
-            .channel_for_id(channel_a_id)
-            .is_some()));
+    assert_channels_list_shape(
+        client_a.channel_store(),
+        cx_a,
+        &[(channel_a, 0), (channel_b, 1)],
+    );
+    assert_channels_list_shape(
+        client_b.channel_store(),
+        cx_b,
+        &[(channel_a, 0), (channel_b, 1)],
+    );
 
     client_a.channel_store().update(cx_a, |channel_store, _| {
-        let participants = channel_store.channel_participants(channel_a_id);
+        let participants = channel_store.channel_participants(channel_a);
         assert_eq!(participants.len(), 1);
         assert_eq!(participants[0].id, client_b.user_id().unwrap());
-    })
+    });
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Members, cx)
+        })
+        .await
+        .unwrap();
+
+    assert_channels_list_shape(client_b.channel_store(), cx_b, &[]);
+
+    active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_b, cx))
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+    assert_channels_list_shape(client_b.channel_store(), cx_b, &[(channel_b, 0)]);
 }
 
 #[gpui::test]