Fix notifications for membership changes too

Conrad Irwin created

Change summary

crates/channel/src/channel_store.rs                    |  22 
crates/channel/src/channel_store/channel_index.rs      |   6 
crates/collab/src/db.rs                                |  13 
crates/collab/src/db/queries/channels.rs               | 298 +++++++----
crates/collab/src/db/tests/channel_tests.rs            |   2 
crates/collab/src/rpc.rs                               | 172 +++---
crates/collab/src/tests/channel_buffer_tests.rs        |  12 
crates/collab/src/tests/channel_tests.rs               | 170 +++++-
crates/collab/src/tests/random_channel_buffer_tests.rs |   4 
crates/live_kit_client/src/test.rs                     |  10 
crates/live_kit_server/src/api.rs                      |  10 
crates/live_kit_server/src/token.rs                    |   9 
12 files changed, 469 insertions(+), 259 deletions(-)

Detailed changes

crates/channel/src/channel_store.rs 🔗

@@ -939,27 +939,11 @@ impl ChannelStore {
 
         if channels_changed {
             if !payload.delete_channels.is_empty() {
-                let mut channels_to_delete: Vec<u64> = Vec::new();
-                let mut channels_to_rehome: Vec<u64> = Vec::new();
-                for channel_id in payload.delete_channels {
-                    if payload
-                        .channels
-                        .iter()
-                        .any(|channel| channel.id == channel_id)
-                    {
-                        channels_to_rehome.push(channel_id)
-                    } else {
-                        channels_to_delete.push(channel_id)
-                    }
-                }
-
-                self.channel_index.delete_channels(&channels_to_delete);
-                self.channel_index
-                    .delete_paths_through_channels(&channels_to_rehome);
+                self.channel_index.delete_channels(&payload.delete_channels);
                 self.channel_participants
-                    .retain(|channel_id, _| !channels_to_delete.contains(channel_id));
+                    .retain(|channel_id, _| !&payload.delete_channels.contains(channel_id));
 
-                for channel_id in &channels_to_delete {
+                for channel_id in &payload.delete_channels {
                     let channel_id = *channel_id;
                     if payload
                         .channels

crates/channel/src/channel_store/channel_index.rs 🔗

@@ -24,14 +24,8 @@ impl ChannelIndex {
 
     /// Delete the given channels from this index.
     pub fn delete_channels(&mut self, channels: &[ChannelId]) {
-        dbg!("delete_channels", &channels);
         self.channels_by_id
             .retain(|channel_id, _| !channels.contains(channel_id));
-        self.delete_paths_through_channels(channels)
-    }
-
-    pub fn delete_paths_through_channels(&mut self, channels: &[ChannelId]) {
-        dbg!("rehome_channels", &channels);
         self.paths
             .retain(|path| !path.iter().any(|channel_id| channels.contains(channel_id)));
     }

crates/collab/src/db.rs 🔗

@@ -453,6 +453,19 @@ pub struct SetChannelVisibilityResult {
     pub participants_to_remove: HashSet<UserId>,
 }
 
+#[derive(Debug)]
+pub struct MembershipUpdated {
+    pub channel_id: ChannelId,
+    pub new_channels: ChannelsForUser,
+    pub removed_channels: Vec<ChannelId>,
+}
+
+#[derive(Debug)]
+pub enum SetMemberRoleResult {
+    InviteUpdated(Channel),
+    MembershipUpdated(MembershipUpdated),
+}
+
 #[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
 pub struct Channel {
     pub id: ChannelId,

crates/collab/src/db/queries/channels.rs 🔗

@@ -128,9 +128,9 @@ impl Database {
         user_id: UserId,
         connection: ConnectionId,
         environment: &str,
-    ) -> Result<(JoinRoom, Option<ChannelId>)> {
+    ) -> Result<(JoinRoom, Option<MembershipUpdated>, ChannelRole)> {
         self.transaction(move |tx| async move {
-            let mut joined_channel_id = None;
+            let mut accept_invite_result = None;
 
             let channel = channel::Entity::find()
                 .filter(channel::Column::Id.eq(channel_id))
@@ -147,9 +147,7 @@ impl Database {
                     .await?
                 {
                     // note, this may be a parent channel
-                    joined_channel_id = Some(invitation.channel_id);
                     role = Some(invitation.role);
-
                     channel_member::Entity::update(channel_member::ActiveModel {
                         accepted: ActiveValue::Set(true),
                         ..invitation.into_active_model()
@@ -157,6 +155,11 @@ impl Database {
                     .exec(&*tx)
                     .await?;
 
+                    accept_invite_result = Some(
+                        self.calculate_membership_updated(channel_id, user_id, &*tx)
+                            .await?,
+                    );
+
                     debug_assert!(
                         self.channel_role_for_user(channel_id, user_id, &*tx)
                             .await?
@@ -167,6 +170,7 @@ impl Database {
             if role.is_none()
                 && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
             {
+                role = Some(ChannelRole::Guest);
                 let channel_id_to_join = self
                     .public_path_to_channel(channel_id, &*tx)
                     .await?
@@ -174,9 +178,6 @@ impl Database {
                     .cloned()
                     .unwrap_or(channel_id);
 
-                role = Some(ChannelRole::Guest);
-                joined_channel_id = Some(channel_id_to_join);
-
                 channel_member::Entity::insert(channel_member::ActiveModel {
                     id: ActiveValue::NotSet,
                     channel_id: ActiveValue::Set(channel_id_to_join),
@@ -187,6 +188,11 @@ impl Database {
                 .exec(&*tx)
                 .await?;
 
+                accept_invite_result = Some(
+                    self.calculate_membership_updated(channel_id, user_id, &*tx)
+                        .await?,
+                );
+
                 debug_assert!(
                     self.channel_role_for_user(channel_id, user_id, &*tx)
                         .await?
@@ -205,7 +211,7 @@ impl Database {
 
             self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
                 .await
-                .map(|jr| (jr, joined_channel_id))
+                .map(|jr| (jr, accept_invite_result, role.unwrap()))
         })
         .await
     }
@@ -345,7 +351,7 @@ impl Database {
         invitee_id: UserId,
         admin_id: UserId,
         role: ChannelRole,
-    ) -> Result<()> {
+    ) -> Result<Channel> {
         self.transaction(move |tx| async move {
             self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
@@ -360,7 +366,17 @@ impl Database {
             .insert(&*tx)
             .await?;
 
-            Ok(())
+            let channel = channel::Entity::find_by_id(channel_id)
+                .one(&*tx)
+                .await?
+                .unwrap();
+
+            Ok(Channel {
+                id: channel.id,
+                visibility: channel.visibility,
+                name: channel.name,
+                role,
+            })
         })
         .await
     }
@@ -429,10 +445,10 @@ impl Database {
         channel_id: ChannelId,
         user_id: UserId,
         accept: bool,
-    ) -> Result<()> {
+    ) -> Result<Option<MembershipUpdated>> {
         self.transaction(move |tx| async move {
-            let rows_affected = if accept {
-                channel_member::Entity::update_many()
+            if accept {
+                let rows_affected = channel_member::Entity::update_many()
                     .set(channel_member::ActiveModel {
                         accepted: ActiveValue::Set(accept),
                         ..Default::default()
@@ -445,33 +461,96 @@ impl Database {
                     )
                     .exec(&*tx)
                     .await?
-                    .rows_affected
-            } else {
-                channel_member::ActiveModel {
-                    channel_id: ActiveValue::Unchanged(channel_id),
-                    user_id: ActiveValue::Unchanged(user_id),
-                    ..Default::default()
+                    .rows_affected;
+
+                if rows_affected == 0 {
+                    Err(anyhow!("no such invitation"))?;
                 }
-                .delete(&*tx)
-                .await?
-                .rows_affected
-            };
+
+                return Ok(Some(
+                    self.calculate_membership_updated(channel_id, user_id, &*tx)
+                        .await?,
+                ));
+            }
+
+            let rows_affected = channel_member::ActiveModel {
+                channel_id: ActiveValue::Unchanged(channel_id),
+                user_id: ActiveValue::Unchanged(user_id),
+                ..Default::default()
+            }
+            .delete(&*tx)
+            .await?
+            .rows_affected;
 
             if rows_affected == 0 {
                 Err(anyhow!("no such invitation"))?;
             }
 
-            Ok(())
+            Ok(None)
         })
         .await
     }
 
+    async fn calculate_membership_updated(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        tx: &DatabaseTransaction,
+    ) -> Result<MembershipUpdated> {
+        let mut channel_to_refresh = channel_id;
+        let mut removed_channels: Vec<ChannelId> = Vec::new();
+
+        // if the user was previously a guest of a parent public channel they may have seen this
+        // channel (or its descendants) in the tree already.
+        // Now they have new permissions, the graph of channels needs updating from that point.
+        if let Some(public_parent) = self.public_parent_channel_id(channel_id, &*tx).await? {
+            if self
+                .channel_role_for_user(public_parent, user_id, &*tx)
+                .await?
+                == Some(ChannelRole::Guest)
+            {
+                channel_to_refresh = public_parent;
+            }
+        }
+
+        // remove all descendant channels from the user's tree
+        removed_channels.append(
+            &mut self
+                .get_channel_descendants(vec![channel_to_refresh], &*tx)
+                .await?
+                .into_iter()
+                .map(|edge| ChannelId::from_proto(edge.channel_id))
+                .collect(),
+        );
+
+        let new_channels = self
+            .get_user_channels(user_id, Some(channel_to_refresh), &*tx)
+            .await?;
+
+        // We only add the current channel to "moved" if the user has lost access,
+        // otherwise it would be made a root channel on the client.
+        if !new_channels
+            .channels
+            .channels
+            .iter()
+            .any(|c| c.id == channel_id)
+        {
+            removed_channels.push(channel_id);
+        }
+
+        Ok(MembershipUpdated {
+            channel_id,
+            new_channels,
+            removed_channels,
+        })
+    }
+
     pub async fn remove_channel_member(
         &self,
         channel_id: ChannelId,
         member_id: UserId,
         admin_id: UserId,
-    ) -> Result<()> {
+    ) -> Result<MembershipUpdated> {
         self.transaction(|tx| async move {
             self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
@@ -489,7 +568,9 @@ impl Database {
                 Err(anyhow!("no such member"))?;
             }
 
-            Ok(())
+            Ok(self
+                .calculate_membership_updated(channel_id, member_id, &*tx)
+                .await?)
         })
         .await
     }
@@ -535,44 +616,7 @@ impl Database {
         self.transaction(|tx| async move {
             let tx = tx;
 
-            let channel_memberships = channel_member::Entity::find()
-                .filter(
-                    channel_member::Column::UserId
-                        .eq(user_id)
-                        .and(channel_member::Column::Accepted.eq(true)),
-                )
-                .all(&*tx)
-                .await?;
-
-            self.get_user_channels(user_id, channel_memberships, &tx)
-                .await
-        })
-        .await
-    }
-
-    pub async fn get_channel_for_user(
-        &self,
-        channel_id: ChannelId,
-        user_id: UserId,
-    ) -> Result<ChannelsForUser> {
-        self.transaction(|tx| async move {
-            let tx = tx;
-            let role = self
-                .check_user_is_channel_participant(channel_id, user_id, &*tx)
-                .await?;
-
-            self.get_user_channels(
-                user_id,
-                vec![channel_member::Model {
-                    id: Default::default(),
-                    channel_id,
-                    user_id,
-                    role,
-                    accepted: true,
-                }],
-                &tx,
-            )
-            .await
+            self.get_user_channels(user_id, None, &tx).await
         })
         .await
     }
@@ -580,19 +624,42 @@ impl Database {
     pub async fn get_user_channels(
         &self,
         user_id: UserId,
-        channel_memberships: Vec<channel_member::Model>,
+        parent_channel_id: Option<ChannelId>,
         tx: &DatabaseTransaction,
     ) -> Result<ChannelsForUser> {
+        // note: we could (maybe) win some efficiency here when parent_channel_id
+        // is set by getting just the role for that channel, then getting descendants
+        // with roles attached; but that's not as straightforward as it sounds
+        // because we need to calculate the path to the channel to make the query
+        // efficient, which currently requires an extra round trip to the database.
+        // Fix this later...
+        let channel_memberships = channel_member::Entity::find()
+            .filter(
+                channel_member::Column::UserId
+                    .eq(user_id)
+                    .and(channel_member::Column::Accepted.eq(true)),
+            )
+            .all(&*tx)
+            .await?;
+
+        dbg!((user_id, &channel_memberships));
+
         let mut edges = self
             .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
             .await?;
 
-        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();
+        dbg!((user_id, &edges));
+
+        let mut role_for_channel: HashMap<ChannelId, (ChannelRole, bool)> = HashMap::default();
 
         for membership in channel_memberships.iter() {
-            role_for_channel.insert(membership.channel_id, membership.role);
+            let included =
+                parent_channel_id.is_none() || membership.channel_id == parent_channel_id.unwrap();
+            role_for_channel.insert(membership.channel_id, (membership.role, included));
         }
 
+        dbg!((&role_for_channel, parent_channel_id));
+
         for ChannelEdge {
             parent_id,
             channel_id,
@@ -601,14 +668,26 @@ impl Database {
             let parent_id = ChannelId::from_proto(*parent_id);
             let channel_id = ChannelId::from_proto(*channel_id);
             debug_assert!(role_for_channel.get(&parent_id).is_some());
-            let parent_role = role_for_channel[&parent_id];
-            if let Some(existing_role) = role_for_channel.get(&channel_id) {
-                if existing_role.should_override(parent_role) {
-                    continue;
-                }
+            let (parent_role, parent_included) = role_for_channel[&parent_id];
+
+            if let Some((existing_role, included)) = role_for_channel.get(&channel_id) {
+                role_for_channel.insert(
+                    channel_id,
+                    (existing_role.max(parent_role), *included || parent_included),
+                );
+            } else {
+                role_for_channel.insert(
+                    channel_id,
+                    (
+                        parent_role,
+                        parent_included
+                            || parent_channel_id.is_none()
+                            || Some(channel_id) == parent_channel_id,
+                    ),
+                );
             }
-            role_for_channel.insert(channel_id, parent_role);
         }
+        dbg!((&role_for_channel, parent_channel_id));
 
         let mut channels: Vec<Channel> = Vec::new();
         let mut channels_to_remove: HashSet<u64> = HashSet::default();
@@ -620,11 +699,13 @@ impl Database {
 
         while let Some(row) = rows.next().await {
             let channel = row?;
-            let role = role_for_channel[&channel.id];
+            let (role, included) = role_for_channel[&channel.id];
 
-            if role == ChannelRole::Banned
+            if !included
+                || role == ChannelRole::Banned
                 || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
             {
+                dbg!("remove", channel.id);
                 channels_to_remove.insert(channel.id.0 as u64);
                 continue;
             }
@@ -633,7 +714,7 @@ impl Database {
                 id: channel.id,
                 name: channel.name,
                 visibility: channel.visibility,
-                role: role,
+                role,
             });
         }
         drop(rows);
@@ -740,18 +821,8 @@ impl Database {
             }
             results.push((
                 member.user_id,
-                self.get_user_channels(
-                    member.user_id,
-                    vec![channel_member::Model {
-                        id: Default::default(),
-                        channel_id: new_parent,
-                        user_id: member.user_id,
-                        role: member.role,
-                        accepted: true,
-                    }],
-                    &*tx,
-                )
-                .await?,
+                self.get_user_channels(member.user_id, Some(new_parent), &*tx)
+                    .await?,
             ))
         }
 
@@ -782,18 +853,8 @@ impl Database {
             };
             results.push((
                 member.user_id,
-                self.get_user_channels(
-                    member.user_id,
-                    vec![channel_member::Model {
-                        id: Default::default(),
-                        channel_id: public_parent,
-                        user_id: member.user_id,
-                        role: member.role,
-                        accepted: true,
-                    }],
-                    &*tx,
-                )
-                .await?,
+                self.get_user_channels(member.user_id, Some(public_parent), &*tx)
+                    .await?,
             ))
         }
 
@@ -806,7 +867,7 @@ impl Database {
         admin_id: UserId,
         for_user: UserId,
         role: ChannelRole,
-    ) -> Result<channel_member::Model> {
+    ) -> Result<SetMemberRoleResult> {
         self.transaction(|tx| async move {
             self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                 .await?;
@@ -828,7 +889,24 @@ impl Database {
             update.role = ActiveValue::Set(role);
             let updated = channel_member::Entity::update(update).exec(&*tx).await?;
 
-            Ok(updated)
+            if !updated.accepted {
+                let channel = channel::Entity::find_by_id(channel_id)
+                    .one(&*tx)
+                    .await?
+                    .unwrap();
+
+                return Ok(SetMemberRoleResult::InviteUpdated(Channel {
+                    id: channel.id,
+                    visibility: channel.visibility,
+                    name: channel.name,
+                    role,
+                }));
+            }
+
+            Ok(SetMemberRoleResult::MembershipUpdated(
+                self.calculate_membership_updated(channel_id, for_user, &*tx)
+                    .await?,
+            ))
         })
         .await
     }
@@ -1396,16 +1474,7 @@ impl Database {
                 .await?;
         }
 
-        let membership = channel_member::Entity::find()
-            .filter(
-                channel_member::Column::ChannelId
-                    .eq(channel)
-                    .and(channel_member::Column::UserId.eq(user)),
-            )
-            .all(tx)
-            .await?;
-
-        let mut channel_info = self.get_user_channels(user, membership, &*tx).await?;
+        let mut channel_info = self.get_user_channels(user, Some(channel), &*tx).await?;
 
         channel_info.channels.edges.push(ChannelEdge {
             channel_id: channel.to_proto(),
@@ -1466,8 +1535,6 @@ impl Database {
             .await?
             == 0;
 
-        dbg!(is_stranded, &paths);
-
         // Make sure that there is always at least one path to the channel
         if is_stranded {
             let root_paths: Vec<_> = paths
@@ -1481,7 +1548,6 @@ impl Database {
                 })
                 .collect();
 
-            dbg!(is_stranded, &root_paths);
             channel_path::Entity::insert_many(root_paths)
                 .exec(&*tx)
                 .await?;
@@ -1528,6 +1594,8 @@ impl Database {
                 .into_iter()
                 .collect();
 
+            dbg!(&participants_to_update);
+
             let mut moved_channels: HashSet<ChannelId> = HashSet::default();
             moved_channels.insert(channel_id);
             for edge in self.get_channel_descendants([channel_id], &*tx).await? {

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -160,7 +160,7 @@ async fn test_joining_channels(db: &Arc<Database>) {
     let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap();
 
     // can join a room with membership to its channel
-    let (joined_room, _) = db
+    let (joined_room, _, _) = db
         .join_channel(
             channel_1,
             user_1,

crates/collab/src/rpc.rs 🔗

@@ -3,9 +3,9 @@ mod connection_pool;
 use crate::{
     auth,
     db::{
-        self, BufferId, ChannelId, ChannelsForUser, CreateChannelResult, Database, MessageId,
-        MoveChannelResult, ProjectId, RenameChannelResult, RoomId, ServerId,
-        SetChannelVisibilityResult, User, UserId,
+        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, Database,
+        MembershipUpdated, MessageId, MoveChannelResult, ProjectId, RenameChannelResult, RoomId,
+        ServerId, SetChannelVisibilityResult, User, UserId,
     },
     executor::Executor,
     AppState, Result,
@@ -2266,23 +2266,20 @@ async fn invite_channel_member(
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
     let invitee_id = UserId::from_proto(request.user_id);
-    db.invite_channel_member(
-        channel_id,
-        invitee_id,
-        session.user_id,
-        request.role().into(),
-    )
-    .await?;
+    let channel = db
+        .invite_channel_member(
+            channel_id,
+            invitee_id,
+            session.user_id,
+            request.role().into(),
+        )
+        .await?;
 
-    let channel = db.get_channel(channel_id, session.user_id).await?;
+    let update = proto::UpdateChannels {
+        channel_invitations: vec![channel.to_proto()],
+        ..Default::default()
+    };
 
-    let mut update = proto::UpdateChannels::default();
-    update.channel_invitations.push(proto::Channel {
-        id: channel.id.to_proto(),
-        visibility: channel.visibility.into(),
-        name: channel.name,
-        role: request.role().into(),
-    });
     for connection_id in session
         .connection_pool()
         .await
@@ -2304,19 +2301,13 @@ async fn remove_channel_member(
     let channel_id = ChannelId::from_proto(request.channel_id);
     let member_id = UserId::from_proto(request.user_id);
 
-    db.remove_channel_member(channel_id, member_id, session.user_id)
+    let membership_updated = db
+        .remove_channel_member(channel_id, member_id, session.user_id)
         .await?;
 
-    let mut update = proto::UpdateChannels::default();
-    update.delete_channels.push(channel_id.to_proto());
+    dbg!(&membership_updated);
 
-    for connection_id in session
-        .connection_pool()
-        .await
-        .user_connection_ids(member_id)
-    {
-        session.peer.send(connection_id, update.clone())?;
-    }
+    notify_membership_updated(membership_updated, member_id, &session).await?;
 
     response.send(proto::Ack {})?;
     Ok(())
@@ -2347,6 +2338,9 @@ async fn set_channel_visibility(
     }
     for user_id in participants_to_remove {
         let update = proto::UpdateChannels {
+            // for public participants  we only need to remove the current channel
+            // (not descendants)
+            // because they can still see any public descendants
             delete_channels: vec![channel_id.to_proto()],
             ..Default::default()
         };
@@ -2367,7 +2361,7 @@ async fn set_channel_member_role(
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
     let member_id = UserId::from_proto(request.user_id);
-    let channel_member = db
+    let result = db
         .set_channel_member_role(
             channel_id,
             session.user_id,
@@ -2376,26 +2370,24 @@ async fn set_channel_member_role(
         )
         .await?;
 
-    let mut update = proto::UpdateChannels::default();
-    if channel_member.accepted {
-        let channels = db.get_channel_for_user(channel_id, member_id).await?;
-        update = build_channels_update(channels, vec![]);
-    } else {
-        let channel = db.get_channel(channel_id, session.user_id).await?;
-        update.channel_invitations.push(proto::Channel {
-            id: channel_id.to_proto(),
-            visibility: channel.visibility.into(),
-            name: channel.name,
-            role: request.role().into(),
-        });
-    }
+    match result {
+        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
+            notify_membership_updated(membership_update, member_id, &session).await?;
+        }
+        db::SetMemberRoleResult::InviteUpdated(channel) => {
+            let update = proto::UpdateChannels {
+                channel_invitations: vec![channel.to_proto()],
+                ..Default::default()
+            };
 
-    for connection_id in session
-        .connection_pool()
-        .await
-        .user_connection_ids(member_id)
-    {
-        session.peer.send(connection_id, update.clone())?;
+            for connection_id in session
+                .connection_pool()
+                .await
+                .user_connection_ids(member_id)
+            {
+                session.peer.send(connection_id, update.clone())?;
+            }
+        }
     }
 
     response.send(proto::Ack {})?;
@@ -2541,35 +2533,26 @@ async fn respond_to_channel_invite(
 ) -> Result<()> {
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
-    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
+    let result = db
+        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
         .await?;
 
-    if request.accept {
-        channel_membership_updated(db, channel_id, &session).await?;
+    if let Some(accept_invite_result) = result {
+        notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
     } else {
-        let mut update = proto::UpdateChannels::default();
-        update
-            .remove_channel_invitations
-            .push(channel_id.to_proto());
-        session.peer.send(session.connection_id, update)?;
-    }
-    response.send(proto::Ack {})?;
+        let update = proto::UpdateChannels {
+            remove_channel_invitations: vec![channel_id.to_proto()],
+            ..Default::default()
+        };
 
-    Ok(())
-}
+        let connection_pool = session.connection_pool().await;
+        for connection_id in connection_pool.user_connection_ids(session.user_id) {
+            session.peer.send(connection_id, update.clone())?;
+        }
+    };
 
-async fn channel_membership_updated(
-    db: tokio::sync::MutexGuard<'_, DbHandle>,
-    channel_id: ChannelId,
-    session: &Session,
-) -> Result<(), crate::Error> {
-    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
-    let mut update = build_channels_update(result, vec![]);
-    update
-        .remove_channel_invitations
-        .push(channel_id.to_proto());
+    response.send(proto::Ack {})?;
 
-    session.peer.send(session.connection_id, update)?;
     Ok(())
 }
 
@@ -2605,7 +2588,7 @@ async fn join_channel_internal(
         leave_room_for_session(&session).await?;
         let db = session.db().await;
 
-        let (joined_room, joined_channel) = db
+        let (joined_room, accept_invite_result, role) = db
             .join_channel(
                 channel_id,
                 session.user_id,
@@ -2615,12 +2598,21 @@ async fn join_channel_internal(
             .await?;
 
         let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
-            let token = live_kit
-                .room_token(
-                    &joined_room.room.live_kit_room,
-                    &session.user_id.to_string(),
-                )
-                .trace_err()?;
+            let token = if role == ChannelRole::Guest {
+                live_kit
+                    .guest_token(
+                        &joined_room.room.live_kit_room,
+                        &session.user_id.to_string(),
+                    )
+                    .trace_err()?
+            } else {
+                live_kit
+                    .room_token(
+                        &joined_room.room.live_kit_room,
+                        &session.user_id.to_string(),
+                    )
+                    .trace_err()?
+            };
 
             Some(LiveKitConnectionInfo {
                 server_url: live_kit.url().into(),
@@ -2634,8 +2626,8 @@ async fn join_channel_internal(
             live_kit_connection_info,
         })?;
 
-        if let Some(joined_channel) = joined_channel {
-            channel_membership_updated(db, joined_channel, &session).await?
+        if let Some(accept_invite_result) = accept_invite_result {
+            notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
         }
 
         room_updated(&joined_room.room, &session.peer);
@@ -3051,6 +3043,26 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
     }
 }
 
+async fn notify_membership_updated(
+    result: MembershipUpdated,
+    user_id: UserId,
+    session: &Session,
+) -> Result<()> {
+    let mut update = build_channels_update(result.new_channels, vec![]);
+    update.delete_channels = result
+        .removed_channels
+        .into_iter()
+        .map(|id| id.to_proto())
+        .collect();
+    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
+
+    let connection_pool = session.connection_pool().await;
+    for connection_id in connection_pool.user_connection_ids(user_id) {
+        session.peer.send(connection_id, update.clone())?;
+    }
+    Ok(())
+}
+
 fn build_channels_update(
     channels: ChannelsForUser,
     channel_invites: Vec<db::Channel>,

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -410,7 +410,7 @@ async fn test_channel_buffer_disconnect(
     server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
 
-    channel_buffer_a.update(cx_a, |buffer, _| {
+    channel_buffer_a.update(cx_a, |buffer, cx| {
         assert_eq!(
             buffer.channel(cx).unwrap().as_ref(),
             &channel(channel_id, "the-channel", proto::ChannelRole::Admin)
@@ -435,7 +435,7 @@ async fn test_channel_buffer_disconnect(
     deterministic.run_until_parked();
 
     // Channel buffer observed the deletion
-    channel_buffer_b.update(cx_b, |buffer, _| {
+    channel_buffer_b.update(cx_b, |buffer, cx| {
         assert_eq!(
             buffer.channel(cx).unwrap().as_ref(),
             &channel(channel_id, "the-channel", proto::ChannelRole::Member)
@@ -699,7 +699,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
         .await
         .unwrap();
     channel_view_1_a.update(cx_a, |notes, cx| {
-        assert_eq!(notes.channel(cx).name, "channel-1");
+        assert_eq!(notes.channel(cx).unwrap().name, "channel-1");
         notes.editor.update(cx, |editor, cx| {
             editor.insert("Hello from A.", cx);
             editor.change_selections(None, cx, |selections| {
@@ -731,7 +731,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
             .expect("active item is not a channel view")
     });
     channel_view_1_b.read_with(cx_b, |notes, cx| {
-        assert_eq!(notes.channel(cx).name, "channel-1");
+        assert_eq!(notes.channel(cx).unwrap().name, "channel-1");
         let editor = notes.editor.read(cx);
         assert_eq!(editor.text(cx), "Hello from A.");
         assert_eq!(editor.selections.ranges::<usize>(cx), &[3..4]);
@@ -743,7 +743,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
         .await
         .unwrap();
     channel_view_2_a.read_with(cx_a, |notes, cx| {
-        assert_eq!(notes.channel(cx).name, "channel-2");
+        assert_eq!(notes.channel(cx).unwrap().name, "channel-2");
     });
 
     // Client B is taken to the notes for channel 2.
@@ -760,7 +760,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
             .expect("active item is not a channel view")
     });
     channel_view_2_b.read_with(cx_b, |notes, cx| {
-        assert_eq!(notes.channel(cx).name, "channel-2");
+        assert_eq!(notes.channel(cx).unwrap().name, "channel-2");
     });
 }
 

crates/collab/src/tests/channel_tests.rs 🔗

@@ -48,13 +48,13 @@ async fn test_core_channels(
                 id: channel_a_id,
                 name: "channel-a".to_string(),
                 depth: 0,
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
             ExpectedChannel {
                 id: channel_b_id,
                 name: "channel-b".to_string(),
                 depth: 1,
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
         ],
     );
@@ -95,7 +95,7 @@ async fn test_core_channels(
             id: channel_a_id,
             name: "channel-a".to_string(),
             depth: 0,
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
 
@@ -142,13 +142,13 @@ async fn test_core_channels(
             ExpectedChannel {
                 id: channel_a_id,
                 name: "channel-a".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
                 depth: 0,
             },
             ExpectedChannel {
                 id: channel_b_id,
                 name: "channel-b".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
                 depth: 1,
             },
         ],
@@ -171,19 +171,19 @@ async fn test_core_channels(
             ExpectedChannel {
                 id: channel_a_id,
                 name: "channel-a".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
                 depth: 0,
             },
             ExpectedChannel {
                 id: channel_b_id,
                 name: "channel-b".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
                 depth: 1,
             },
             ExpectedChannel {
                 id: channel_c_id,
                 name: "channel-c".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
                 depth: 2,
             },
         ],
@@ -215,19 +215,19 @@ async fn test_core_channels(
                 id: channel_a_id,
                 name: "channel-a".to_string(),
                 depth: 0,
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
             ExpectedChannel {
                 id: channel_b_id,
                 name: "channel-b".to_string(),
                 depth: 1,
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
             ExpectedChannel {
                 id: channel_c_id,
                 name: "channel-c".to_string(),
                 depth: 2,
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
         ],
     );
@@ -249,7 +249,7 @@ async fn test_core_channels(
             id: channel_a_id,
             name: "channel-a".to_string(),
             depth: 0,
-            user_is_admin: true,
+            role: ChannelRole::Admin,
         }],
     );
     assert_channels(
@@ -259,7 +259,7 @@ async fn test_core_channels(
             id: channel_a_id,
             name: "channel-a".to_string(),
             depth: 0,
-            user_is_admin: true,
+            role: ChannelRole::Admin,
         }],
     );
 
@@ -282,7 +282,7 @@ async fn test_core_channels(
             id: channel_a_id,
             name: "channel-a".to_string(),
             depth: 0,
-            user_is_admin: true,
+            role: ChannelRole::Admin,
         }],
     );
 
@@ -304,7 +304,7 @@ async fn test_core_channels(
             id: channel_a_id,
             name: "channel-a".to_string(),
             depth: 0,
-            user_is_admin: true,
+            role: ChannelRole::Admin,
         }],
     );
 }
@@ -412,7 +412,7 @@ async fn test_channel_room(
             id: zed_id,
             name: "zed".to_string(),
             depth: 0,
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
     client_b.channel_store().read_with(cx_b, |channels, _| {
@@ -645,7 +645,7 @@ async fn test_permissions_update_while_invited(
             depth: 0,
             id: rust_id,
             name: "rust".to_string(),
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
     assert_channels(client_b.channel_store(), cx_b, &[]);
@@ -673,7 +673,7 @@ async fn test_permissions_update_while_invited(
             depth: 0,
             id: rust_id,
             name: "rust".to_string(),
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
     assert_channels(client_b.channel_store(), cx_b, &[]);
@@ -713,7 +713,7 @@ async fn test_channel_rename(
             depth: 0,
             id: rust_id,
             name: "rust-archive".to_string(),
-            user_is_admin: true,
+            role: ChannelRole::Admin,
         }],
     );
 
@@ -725,7 +725,7 @@ async fn test_channel_rename(
             depth: 0,
             id: rust_id,
             name: "rust-archive".to_string(),
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
 }
@@ -848,7 +848,7 @@ async fn test_lost_channel_creation(
             depth: 0,
             id: channel_id,
             name: "x".to_string(),
-            user_is_admin: false,
+            role: ChannelRole::Member,
         }],
     );
 
@@ -872,13 +872,13 @@ async fn test_lost_channel_creation(
                 depth: 0,
                 id: channel_id,
                 name: "x".to_string(),
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
             ExpectedChannel {
                 depth: 1,
                 id: subchannel_id,
                 name: "subchannel".to_string(),
-                user_is_admin: true,
+                role: ChannelRole::Admin,
             },
         ],
     );
@@ -903,13 +903,13 @@ async fn test_lost_channel_creation(
                 depth: 0,
                 id: channel_id,
                 name: "x".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
             },
             ExpectedChannel {
                 depth: 1,
                 id: subchannel_id,
                 name: "subchannel".to_string(),
-                user_is_admin: false,
+                role: ChannelRole::Member,
             },
         ],
     );
@@ -969,8 +969,7 @@ async fn test_channel_link_notifications(
 
     // we have an admin (a), member (b) and guest (c) all part of the zed channel.
 
-    // create a new private sub-channel
-    // create a new priate channel, make it public, and move it under the previous one, and verify it shows for b and c
+    // create a new private channel, make it public, and move it under the previous one, and verify it shows for b and not c
     let active_channel = client_a
         .channel_store()
         .update(cx_a, |channel_store, cx| {
@@ -1118,6 +1117,117 @@ async fn test_channel_link_notifications(
     );
 }
 
+#[gpui::test]
+async fn test_channel_membership_notifications(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+
+    deterministic.forbid_parking();
+
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_c").await;
+
+    let user_b = client_b.user_id().unwrap();
+
+    let channels = server
+        .make_channel_tree(
+            &[
+                ("zed", None),
+                ("active", Some("zed")),
+                ("vim", Some("active")),
+            ],
+            (&client_a, cx_a),
+        )
+        .await;
+    let zed_channel = channels[0];
+    let _active_channel = channels[1];
+    let vim_channel = channels[2];
+
+    try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| {
+        [
+            channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx),
+            channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx),
+            channel_store.invite_member(vim_channel, user_b, proto::ChannelRole::Member, cx),
+            channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Guest, cx),
+        ]
+    }))
+    .await
+    .unwrap();
+
+    deterministic.run_until_parked();
+
+    client_b
+        .channel_store()
+        .update(cx_b, |channel_store, _| {
+            channel_store.respond_to_channel_invite(zed_channel, true)
+        })
+        .await
+        .unwrap();
+
+    client_b
+        .channel_store()
+        .update(cx_b, |channel_store, _| {
+            channel_store.respond_to_channel_invite(vim_channel, true)
+        })
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    // we have an admin (a), and a guest (b) with access to all of zed, and membership in vim.
+    assert_channels(
+        client_b.channel_store(),
+        cx_b,
+        &[
+            ExpectedChannel {
+                depth: 0,
+                id: zed_channel,
+                name: "zed".to_string(),
+                role: ChannelRole::Guest,
+            },
+            ExpectedChannel {
+                depth: 1,
+                id: vim_channel,
+                name: "vim".to_string(),
+                role: ChannelRole::Member,
+            },
+        ],
+    );
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.remove_member(vim_channel, user_b, cx)
+        })
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    assert_channels(
+        client_b.channel_store(),
+        cx_b,
+        &[
+            ExpectedChannel {
+                depth: 0,
+                id: zed_channel,
+                name: "zed".to_string(),
+                role: ChannelRole::Guest,
+            },
+            ExpectedChannel {
+                depth: 1,
+                id: vim_channel,
+                name: "vim".to_string(),
+                role: ChannelRole::Guest,
+            },
+        ],
+    )
+}
+
 #[gpui::test]
 async fn test_guest_access(
     deterministic: Arc<Deterministic>,
@@ -1485,7 +1595,7 @@ struct ExpectedChannel {
     depth: usize,
     id: ChannelId,
     name: String,
-    user_is_admin: bool,
+    role: ChannelRole,
 }
 
 #[track_caller]
@@ -1502,7 +1612,7 @@ fn assert_channel_invitations(
                 depth: 0,
                 name: channel.name.clone(),
                 id: channel.id,
-                user_is_admin: store.is_channel_admin(channel.id),
+                role: channel.role,
             })
             .collect::<Vec<_>>()
     });
@@ -1522,7 +1632,7 @@ fn assert_channels(
                 depth,
                 name: channel.name.clone(),
                 id: channel.id,
-                user_is_admin: store.is_channel_admin(channel.id),
+                role: channel.role,
             })
             .collect::<Vec<_>>()
     });

crates/collab/src/tests/random_channel_buffer_tests.rs 🔗

@@ -99,14 +99,14 @@ impl RandomizedTest for RandomChannelBufferTest {
                 30..=40 => {
                     if let Some(buffer) = channel_buffers.iter().choose(rng) {
                         let channel_name =
-                            buffer.read_with(cx, |b, _| b.channel(cx).unwrap().name.clone());
+                            buffer.read_with(cx, |b, cx| b.channel(cx).unwrap().name.clone());
                         break ChannelBufferOperation::LeaveChannelNotes { channel_name };
                     }
                 }
 
                 _ => {
                     if let Some(buffer) = channel_buffers.iter().choose(rng) {
-                        break buffer.read_with(cx, |b, _| {
+                        break buffer.read_with(cx, |b, cx| {
                             let channel_name = b.channel(cx).unwrap().name.clone();
                             let edits = b
                                 .buffer()

crates/live_kit_client/src/test.rs 🔗

@@ -306,6 +306,16 @@ impl live_kit_server::api::Client for TestApiClient {
             token::VideoGrant::to_join(room),
         )
     }
+
+    fn guest_token(&self, room: &str, identity: &str) -> Result<String> {
+        let server = TestServer::get(&self.url)?;
+        token::create(
+            &server.api_key,
+            &server.secret_key,
+            Some(identity),
+            token::VideoGrant::for_guest(room),
+        )
+    }
 }
 
 pub type Sid = String;

crates/live_kit_server/src/api.rs 🔗

@@ -12,6 +12,7 @@ pub trait Client: Send + Sync {
     async fn delete_room(&self, name: String) -> Result<()>;
     async fn remove_participant(&self, room: String, identity: String) -> Result<()>;
     fn room_token(&self, room: &str, identity: &str) -> Result<String>;
+    fn guest_token(&self, room: &str, identity: &str) -> Result<String>;
 }
 
 #[derive(Clone)]
@@ -138,4 +139,13 @@ impl Client for LiveKitClient {
             token::VideoGrant::to_join(room),
         )
     }
+
+    fn guest_token(&self, room: &str, identity: &str) -> Result<String> {
+        token::create(
+            &self.key,
+            &self.secret,
+            Some(identity),
+            token::VideoGrant::for_guest(room),
+        )
+    }
 }

crates/live_kit_server/src/token.rs 🔗

@@ -57,6 +57,15 @@ impl<'a> VideoGrant<'a> {
             ..Default::default()
         }
     }
+
+    pub fn for_guest(room: &'a str) -> Self {
+        Self {
+            room: Some(Cow::Borrowed(room)),
+            room_join: Some(true),
+            can_subscribe: Some(true),
+            ..Default::default()
+        }
+    }
 }
 
 pub fn create(