Ensure that invitees do not have permissions

Conrad Irwin created

They have to accept the invite, (which joining the channel will do),
first.

Change summary

crates/collab/src/db/queries/channels.rs    | 228 ++++++++++++----------
crates/collab/src/db/tests/channel_tests.rs |  72 +++---
crates/collab/src/rpc.rs                    |  35 +-
crates/collab/src/tests/channel_tests.rs    |  63 ++++++
4 files changed, 238 insertions(+), 160 deletions(-)

Detailed changes

crates/collab/src/db/queries/channels.rs 🔗

@@ -88,80 +88,87 @@ impl Database {
         .await
     }
 
-    pub async fn join_channel_internal(
+    pub async fn join_channel(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
         connection: ConnectionId,
         environment: &str,
-        tx: &DatabaseTransaction,
-    ) -> Result<(JoinRoom, bool)> {
-        let mut joined = false;
+    ) -> Result<(JoinRoom, Option<ChannelId>)> {
+        self.transaction(move |tx| async move {
+            let mut joined_channel_id = None;
 
-        let channel = channel::Entity::find()
-            .filter(channel::Column::Id.eq(channel_id))
-            .one(&*tx)
-            .await?;
+            let channel = channel::Entity::find()
+                .filter(channel::Column::Id.eq(channel_id))
+                .one(&*tx)
+                .await?;
 
-        let mut role = self
-            .channel_role_for_user(channel_id, user_id, &*tx)
-            .await?;
+            let mut role = self
+                .channel_role_for_user(channel_id, user_id, &*tx)
+                .await?;
+
+            if role.is_none() && channel.is_some() {
+                if let Some(invitation) = self
+                    .pending_invite_for_channel(channel_id, user_id, &*tx)
+                    .await?
+                {
+                    // note, this may be a parent channel
+                    joined_channel_id = Some(invitation.channel_id);
+                    role = Some(invitation.role);
+
+                    channel_member::Entity::update(channel_member::ActiveModel {
+                        accepted: ActiveValue::Set(true),
+                        ..invitation.into_active_model()
+                    })
+                    .exec(&*tx)
+                    .await?;
+
+                    debug_assert!(
+                        self.channel_role_for_user(channel_id, user_id, &*tx)
+                            .await?
+                            == role
+                    );
+                }
+            }
+            if role.is_none()
+                && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
+            {
+                let channel_id_to_join = self
+                    .most_public_ancestor_for_channel(channel_id, &*tx)
+                    .await?
+                    .unwrap_or(channel_id);
+                role = Some(ChannelRole::Guest);
+                joined_channel_id = Some(channel_id_to_join);
 
-        if role.is_none() {
-            if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) {
                 channel_member::Entity::insert(channel_member::ActiveModel {
                     id: ActiveValue::NotSet,
-                    channel_id: ActiveValue::Set(channel_id),
+                    channel_id: ActiveValue::Set(channel_id_to_join),
                     user_id: ActiveValue::Set(user_id),
                     accepted: ActiveValue::Set(true),
                     role: ActiveValue::Set(ChannelRole::Guest),
                 })
-                .on_conflict(
-                    OnConflict::columns([
-                        channel_member::Column::UserId,
-                        channel_member::Column::ChannelId,
-                    ])
-                    .update_columns([channel_member::Column::Accepted])
-                    .to_owned(),
-                )
                 .exec(&*tx)
                 .await?;
 
                 debug_assert!(
                     self.channel_role_for_user(channel_id, user_id, &*tx)
                         .await?
-                        == Some(ChannelRole::Guest)
+                        == role
                 );
-
-                role = Some(ChannelRole::Guest);
-                joined = true;
             }
-        }
-
-        if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
-            Err(anyhow!("no such channel, or not allowed"))?
-        }
 
-        let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
-        let room_id = self
-            .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
-            .await?;
+            if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
+                Err(anyhow!("no such channel, or not allowed"))?
+            }
 
-        self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
-            .await
-            .map(|jr| (jr, joined))
-    }
+            let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+            let room_id = self
+                .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
+                .await?;
 
-    pub async fn join_channel(
-        &self,
-        channel_id: ChannelId,
-        user_id: UserId,
-        connection: ConnectionId,
-        environment: &str,
-    ) -> Result<(JoinRoom, bool)> {
-        self.transaction(move |tx| async move {
-            self.join_channel_internal(channel_id, user_id, connection, environment, &*tx)
+            self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
                 .await
+                .map(|jr| (jr, joined_channel_id))
         })
         .await
     }
@@ -624,29 +631,29 @@ impl Database {
         admin_id: UserId,
         for_user: UserId,
         role: ChannelRole,
-    ) -> Result<()> {
+    ) -> Result<channel_member::Model> {
         self.transaction(|tx| async move {
             self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
 
-            let result = channel_member::Entity::update_many()
+            let membership = channel_member::Entity::find()
                 .filter(
                     channel_member::Column::ChannelId
                         .eq(channel_id)
                         .and(channel_member::Column::UserId.eq(for_user)),
                 )
-                .set(channel_member::ActiveModel {
-                    role: ActiveValue::set(role),
-                    ..Default::default()
-                })
-                .exec(&*tx)
+                .one(&*tx)
                 .await?;
 
-            if result.rows_affected == 0 {
-                Err(anyhow!("no such member"))?;
-            }
+            let Some(membership) = membership else {
+                Err(anyhow!("no such member"))?
+            };
 
-            Ok(())
+            let mut update = membership.into_active_model();
+            update.role = ActiveValue::Set(role);
+            let updated = channel_member::Entity::update(update).exec(&*tx).await?;
+
+            Ok(updated)
         })
         .await
     }
@@ -844,6 +851,52 @@ impl Database {
         }
     }
 
+    pub async fn pending_invite_for_channel(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Option<channel_member::Model>> {
+        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
+
+        let row = channel_member::Entity::find()
+            .filter(channel_member::Column::ChannelId.is_in(channel_ids))
+            .filter(channel_member::Column::UserId.eq(user_id))
+            .filter(channel_member::Column::Accepted.eq(false))
+            .one(&*tx)
+            .await?;
+
+        Ok(row)
+    }
+
+    pub async fn most_public_ancestor_for_channel(
+        &self,
+        channel_id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Option<ChannelId>> {
+        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
+
+        let rows = channel::Entity::find()
+            .filter(channel::Column::Id.is_in(channel_ids.clone()))
+            .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
+            .all(&*tx)
+            .await?;
+
+        let mut visible_channels: HashSet<ChannelId> = HashSet::default();
+
+        for row in rows {
+            visible_channels.insert(row.id);
+        }
+
+        for ancestor in channel_ids.into_iter().rev() {
+            if visible_channels.contains(&ancestor) {
+                return Ok(Some(ancestor));
+            }
+        }
+
+        Ok(None)
+    }
+
     pub async fn channel_role_for_user(
         &self,
         channel_id: ChannelId,
@@ -864,7 +917,8 @@ impl Database {
             .filter(
                 channel_member::Column::ChannelId
                     .is_in(channel_ids)
-                    .and(channel_member::Column::UserId.eq(user_id)),
+                    .and(channel_member::Column::UserId.eq(user_id))
+                    .and(channel_member::Column::Accepted.eq(true)),
             )
             .select_only()
             .column(channel_member::Column::ChannelId)
@@ -1009,52 +1063,22 @@ impl Database {
         Ok(results)
     }
 
-    /// Returns the channel with the given ID and:
-    /// - true if the user is a member
-    /// - false if the user hasn't accepted the invitation yet
-    pub async fn get_channel(
-        &self,
-        channel_id: ChannelId,
-        user_id: UserId,
-    ) -> Result<Option<(Channel, bool)>> {
+    /// Returns the channel with the given ID
+    pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
         self.transaction(|tx| async move {
-            let tx = tx;
+            self.check_user_is_channel_participant(channel_id, user_id, &*tx)
+                .await?;
 
             let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
+            let Some(channel) = channel else {
+                Err(anyhow!("no such channel"))?
+            };
 
-            if let Some(channel) = channel {
-                if self
-                    .check_user_is_channel_member(channel_id, user_id, &*tx)
-                    .await
-                    .is_err()
-                {
-                    return Ok(None);
-                }
-
-                let channel_membership = channel_member::Entity::find()
-                    .filter(
-                        channel_member::Column::ChannelId
-                            .eq(channel_id)
-                            .and(channel_member::Column::UserId.eq(user_id)),
-                    )
-                    .one(&*tx)
-                    .await?;
-
-                let is_accepted = channel_membership
-                    .map(|membership| membership.accepted)
-                    .unwrap_or(false);
-
-                Ok(Some((
-                    Channel {
-                        id: channel.id,
-                        visibility: channel.visibility,
-                        name: channel.name,
-                    },
-                    is_accepted,
-                )))
-            } else {
-                Ok(None)
-            }
+            Ok(Channel {
+                id: channel.id,
+                visibility: channel.visibility,
+                name: channel.name,
+            })
         })
         .await
     }

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -51,7 +51,7 @@ async fn test_channels(db: &Arc<Database>) {
     let zed_id = db.create_root_channel("zed", a_id).await.unwrap();
 
     // Make sure that people cannot read channels they haven't been invited to
-    assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
+    assert!(db.get_channel(zed_id, b_id).await.is_err());
 
     db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member)
         .await
@@ -157,7 +157,7 @@ async fn test_channels(db: &Arc<Database>) {
 
     // Remove a single channel
     db.delete_channel(crdb_id, a_id).await.unwrap();
-    assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
+    assert!(db.get_channel(crdb_id, a_id).await.is_err());
 
     // Remove a channel tree
     let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
@@ -165,9 +165,9 @@ async fn test_channels(db: &Arc<Database>) {
     assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
     assert_eq!(user_ids, &[a_id]);
 
-    assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
-    assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
-    assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
+    assert!(db.get_channel(rust_id, a_id).await.is_err());
+    assert!(db.get_channel(cargo_id, a_id).await.is_err());
+    assert!(db.get_channel(cargo_ra_id, a_id).await.is_err());
 }
 
 test_both_dbs!(
@@ -381,11 +381,7 @@ async fn test_channel_renames(db: &Arc<Database>) {
 
     let zed_archive_id = zed_id;
 
-    let (channel, _) = db
-        .get_channel(zed_archive_id, user_1)
-        .await
-        .unwrap()
-        .unwrap();
+    let channel = db.get_channel(zed_archive_id, user_1).await.unwrap();
     assert_eq!(channel.name, "zed-archive");
 
     let non_permissioned_rename = db
@@ -860,12 +856,6 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
     })
     .await
     .unwrap();
-    db.transaction(|tx| async move {
-        db.check_user_is_channel_participant(vim_channel, guest, &*tx)
-            .await
-    })
-    .await
-    .unwrap();
 
     let members = db
         .get_channel_participant_details(vim_channel, admin)
@@ -896,6 +886,13 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
         .await
         .unwrap();
 
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(vim_channel, guest, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+
     let channels = db.get_channels_for_user(guest).await.unwrap().channels;
     assert_dag(channels, &[(vim_channel, None)]);
     let channels = db.get_channels_for_user(member).await.unwrap().channels;
@@ -953,29 +950,7 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
         .await
         .unwrap();
 
-    db.transaction(|tx| async move {
-        db.check_user_is_channel_participant(zed_channel, guest, &*tx)
-            .await
-    })
-    .await
-    .unwrap();
-    assert!(db
-        .transaction(|tx| async move {
-            db.check_user_is_channel_participant(active_channel, guest, &*tx)
-                .await
-        })
-        .await
-        .is_err(),);
-
-    db.transaction(|tx| async move {
-        db.check_user_is_channel_participant(vim_channel, guest, &*tx)
-            .await
-    })
-    .await
-    .unwrap();
-
     // currently people invited to parent channels are not shown here
-    // (though they *do* have permissions!)
     let members = db
         .get_channel_participant_details(vim_channel, admin)
         .await
@@ -1000,6 +975,27 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
         .await
         .unwrap();
 
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(zed_channel, guest, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+    assert!(db
+        .transaction(|tx| async move {
+            db.check_user_is_channel_participant(active_channel, guest, &*tx)
+                .await
+        })
+        .await
+        .is_err(),);
+
+    db.transaction(|tx| async move {
+        db.check_user_is_channel_participant(vim_channel, guest, &*tx)
+            .await
+    })
+    .await
+    .unwrap();
+
     let members = db
         .get_channel_participant_details(vim_channel, admin)
         .await

crates/collab/src/rpc.rs 🔗

@@ -38,7 +38,7 @@ use lazy_static::lazy_static;
 use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
-        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom,
+        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
         LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
     },
     Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
@@ -2289,10 +2289,7 @@ async fn invite_channel_member(
     )
     .await?;
 
-    let (channel, _) = db
-        .get_channel(channel_id, session.user_id)
-        .await?
-        .ok_or_else(|| anyhow!("channel not found"))?;
+    let channel = db.get_channel(channel_id, session.user_id).await?;
 
     let mut update = proto::UpdateChannels::default();
     update.channel_invitations.push(proto::Channel {
@@ -2380,21 +2377,19 @@ async fn set_channel_member_role(
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
     let member_id = UserId::from_proto(request.user_id);
-    db.set_channel_member_role(
-        channel_id,
-        session.user_id,
-        member_id,
-        request.role().into(),
-    )
-    .await?;
+    let channel_member = db
+        .set_channel_member_role(
+            channel_id,
+            session.user_id,
+            member_id,
+            request.role().into(),
+        )
+        .await?;
 
-    let (channel, has_accepted) = db
-        .get_channel(channel_id, member_id)
-        .await?
-        .ok_or_else(|| anyhow!("channel not found"))?;
+    let channel = db.get_channel(channel_id, session.user_id).await?;
 
     let mut update = proto::UpdateChannels::default();
-    if has_accepted {
+    if channel_member.accepted {
         update.channel_permissions.push(proto::ChannelPermission {
             channel_id: channel.id.to_proto(),
             role: request.role,
@@ -2724,9 +2719,11 @@ async fn join_channel_internal(
             channel_id: joined_room.channel_id.map(|id| id.to_proto()),
             live_kit_connection_info,
         })?;
+        dbg!("Joined channel", &joined_channel);
 
-        if joined_channel {
-            channel_membership_updated(db, channel_id, &session).await?
+        if let Some(joined_channel) = joined_channel {
+            dbg!("CMU");
+            channel_membership_updated(db, joined_channel, &session).await?
         }
 
         room_updated(&joined_room.room, &session.peer);

crates/collab/src/tests/channel_tests.rs 🔗

@@ -7,7 +7,7 @@ use channel::{ChannelId, ChannelMembership, ChannelStore};
 use client::User;
 use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
 use rpc::{
-    proto::{self},
+    proto::{self, ChannelRole},
     RECEIVE_TIMEOUT,
 };
 use std::sync::Arc;
@@ -965,6 +965,67 @@ async fn test_guest_access(
     })
 }
 
+#[gpui::test]
+async fn test_invite_access(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channels = server
+        .make_channel_tree(
+            &[("channel-a", None), ("channel-b", Some("channel-a"))],
+            (&client_a, cx_a),
+        )
+        .await;
+    let channel_a_id = channels[0];
+    let channel_b_id = channels[0];
+
+    let active_call_b = cx_b.read(ActiveCall::global);
+
+    // should not be allowed to join
+    assert!(active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx))
+        .await
+        .is_err());
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.invite_member(
+                channel_a_id,
+                client_b.user_id().unwrap(),
+                ChannelRole::Member,
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx))
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    client_b.channel_store().update(cx_b, |channel_store, _| {
+        assert!(channel_store.channel_for_id(channel_b_id).is_some());
+        assert!(channel_store.channel_for_id(channel_a_id).is_some());
+    });
+
+    client_a.channel_store().update(cx_a, |channel_store, _| {
+        let participants = channel_store.channel_participants(channel_b_id);
+        assert_eq!(participants.len(), 1);
+        assert_eq!(participants[0].id, client_b.user_id().unwrap());
+    })
+}
+
 #[gpui::test]
 async fn test_channel_moving(
     deterministic: Arc<Deterministic>,