Fix a bug where channel invitations would show up in the channels section

Mikayla created

Block non-members from reading channel information
WIP: Make sure Arc::make_mut() works

Change summary

crates/client/src/channel_store.rs       |  7 +
crates/collab/src/db.rs                  | 60 +++++++++++++++-------
crates/collab/src/db/tests.rs            |  3 +
crates/collab/src/rpc.rs                 | 32 ++++++++----
crates/collab/src/tests/channel_tests.rs | 67 ++++++++++++++++++++++++++
5 files changed, 137 insertions(+), 32 deletions(-)

Detailed changes

crates/client/src/channel_store.rs 🔗

@@ -301,8 +301,10 @@ impl ChannelStore {
                 .iter_mut()
                 .find(|c| c.id == channel.id)
             {
-                let existing_channel = Arc::make_mut(existing_channel);
+                let existing_channel = Arc::get_mut(existing_channel)
+                    .expect("channel is shared, update would have been lost");
                 existing_channel.name = channel.name;
+                existing_channel.user_is_admin = channel.user_is_admin;
                 continue;
             }
 
@@ -320,7 +322,8 @@ impl ChannelStore {
 
         for channel in payload.channels {
             if let Some(existing_channel) = self.channels.iter_mut().find(|c| c.id == channel.id) {
-                let existing_channel = Arc::make_mut(existing_channel);
+                let existing_channel = Arc::get_mut(existing_channel)
+                    .expect("channel is shared, update would have been lost");
                 existing_channel.name = channel.name;
                 existing_channel.user_is_admin = channel.user_is_admin;
                 continue;

crates/collab/src/db.rs 🔗

@@ -3601,7 +3601,7 @@ impl Database {
             )
             .one(&*tx)
             .await?
-            .ok_or_else(|| anyhow!("user is not a channel member"))?;
+            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
         Ok(())
     }
 
@@ -3621,7 +3621,7 @@ impl Database {
             )
             .one(&*tx)
             .await?
-            .ok_or_else(|| anyhow!("user is not a channel admin"))?;
+            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
         Ok(())
     }
 
@@ -3723,31 +3723,53 @@ impl Database {
         Ok(parents_by_child_id)
     }
 
+    /// Returns the channel with the given ID and:
+    /// - true if the user is a member
+    /// - false if the user hasn't accepted the invitation yet
     pub async fn get_channel(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
-    ) -> Result<Option<Channel>> {
+    ) -> Result<Option<(Channel, bool)>> {
         self.transaction(|tx| async move {
             let tx = tx;
+
             let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
-            let user_is_admin = channel_member::Entity::find()
-                .filter(
-                    channel_member::Column::ChannelId
-                        .eq(channel_id)
-                        .and(channel_member::Column::UserId.eq(user_id))
-                        .and(channel_member::Column::Admin.eq(true)),
-                )
-                .count(&*tx)
-                .await?
-                > 0;
 
-            Ok(channel.map(|channel| Channel {
-                id: channel.id,
-                name: channel.name,
-                user_is_admin,
-                parent_id: None,
-            }))
+            if let Some(channel) = channel {
+                if self
+                    .check_user_is_channel_member(channel_id, user_id, &*tx)
+                    .await
+                    .is_err()
+                {
+                    return Ok(None);
+                }
+
+                let channel_membership = channel_member::Entity::find()
+                    .filter(
+                        channel_member::Column::ChannelId
+                            .eq(channel_id)
+                            .and(channel_member::Column::UserId.eq(user_id)),
+                    )
+                    .one(&*tx)
+                    .await?;
+
+                let (user_is_admin, is_accepted) = channel_membership
+                    .map(|membership| (membership.admin, membership.accepted))
+                    .unwrap_or((false, false));
+
+                Ok(Some((
+                    Channel {
+                        id: channel.id,
+                        name: channel.name,
+                        user_is_admin,
+                        parent_id: None,
+                    },
+                    is_accepted,
+                )))
+            } else {
+                Ok(None)
+            }
         })
         .await
     }

crates/collab/src/db/tests.rs 🔗

@@ -915,6 +915,9 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
 
     let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
 
+    // Make sure that people cannot read channels they haven't been invited to
+    assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
+
     db.invite_channel_member(zed_id, b_id, a_id, false)
         .await
         .unwrap();

crates/collab/src/rpc.rs 🔗

@@ -2210,14 +2210,15 @@ async fn invite_channel_member(
 ) -> Result<()> {
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
-    let channel = db
-        .get_channel(channel_id, session.user_id)
-        .await?
-        .ok_or_else(|| anyhow!("channel not found"))?;
     let invitee_id = UserId::from_proto(request.user_id);
     db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
         .await?;
 
+    let (channel, _) = db
+        .get_channel(channel_id, session.user_id)
+        .await?
+        .ok_or_else(|| anyhow!("channel not found"))?;
+
     let mut update = proto::UpdateChannels::default();
     update.channel_invitations.push(proto::Channel {
         id: channel.id.to_proto(),
@@ -2275,18 +2276,27 @@ async fn set_channel_member_admin(
     db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
         .await?;
 
-    let channel = db
+    let (channel, has_accepted) = db
         .get_channel(channel_id, member_id)
         .await?
         .ok_or_else(|| anyhow!("channel not found"))?;
 
     let mut update = proto::UpdateChannels::default();
-    update.channels.push(proto::Channel {
-        id: channel.id.to_proto(),
-        name: channel.name,
-        parent_id: None,
-        user_is_admin: request.admin,
-    });
+    if has_accepted {
+        update.channels.push(proto::Channel {
+            id: channel.id.to_proto(),
+            name: channel.name,
+            parent_id: None,
+            user_is_admin: request.admin,
+        });
+    } else {
+        update.channel_invitations.push(proto::Channel {
+            id: channel.id.to_proto(),
+            name: channel.name,
+            parent_id: None,
+            user_is_admin: request.admin,
+        });
+    }
 
     for connection_id in session
         .connection_pool()

crates/collab/src/tests/channel_tests.rs 🔗

@@ -584,3 +584,70 @@ async fn test_channel_jumping(deterministic: Arc<Deterministic>, cx_a: &mut Test
         );
     });
 }
+
+#[gpui::test]
+async fn test_permissions_update_while_invited(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let rust_id = server
+        .make_channel("rust", (&client_a, cx_a), &mut [])
+        .await;
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.invite_member(rust_id, client_b.user_id().unwrap(), false, cx)
+        })
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    client_b.channel_store().read_with(cx_b, |channels, _| {
+        assert_eq!(
+            channels.channel_invitations(),
+            &[Arc::new(Channel {
+                id: rust_id,
+                name: "rust".to_string(),
+                parent_id: None,
+                user_is_admin: false,
+                depth: 0,
+            })],
+        );
+
+        assert_eq!(channels.channels(), &[],);
+    });
+
+    // Update B's invite before they've accepted it
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.set_member_admin(rust_id, client_b.user_id().unwrap(), true, cx)
+        })
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    client_b.channel_store().read_with(cx_b, |channels, _| {
+        assert_eq!(
+            channels.channel_invitations(),
+            &[Arc::new(Channel {
+                id: rust_id,
+                name: "rust".to_string(),
+                parent_id: None,
+                user_is_admin: true,
+                depth: 0,
+            })],
+        );
+
+        assert_eq!(channels.channels(), &[],);
+    });
+}