maintain channel subscriptions in RAM (#9512)

Conrad Irwin created

This avoids a giant database query on every leave/join event.

Release Notes:

- N/A

Change summary

crates/collab/src/db.rs                     |  14 -
crates/collab/src/db/ids.rs                 |   4 
crates/collab/src/db/queries/channels.rs    |  53 +-----
crates/collab/src/db/queries/rooms.rs       |  33 ---
crates/collab/src/db/tests/channel_tests.rs |   3 
crates/collab/src/rpc.rs                    | 174 ++++++++++++----------
crates/collab/src/rpc/connection_pool.rs    | 106 +++++++++++++
7 files changed, 222 insertions(+), 165 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -546,7 +546,7 @@ pub struct Channel {
 }
 
 impl Channel {
-    fn from_model(value: channel::Model) -> Self {
+    pub fn from_model(value: channel::Model) -> Self {
         Channel {
             id: value.id,
             visibility: value.visibility,
@@ -604,16 +604,14 @@ pub struct RejoinedChannelBuffer {
 #[derive(Clone)]
 pub struct JoinRoom {
     pub room: proto::Room,
-    pub channel_id: Option<ChannelId>,
-    pub channel_members: Vec<UserId>,
+    pub channel: Option<channel::Model>,
 }
 
 pub struct RejoinedRoom {
     pub room: proto::Room,
     pub rejoined_projects: Vec<RejoinedProject>,
     pub reshared_projects: Vec<ResharedProject>,
-    pub channel_id: Option<ChannelId>,
-    pub channel_members: Vec<UserId>,
+    pub channel: Option<channel::Model>,
 }
 
 pub struct ResharedProject {
@@ -649,8 +647,7 @@ pub struct RejoinedWorktree {
 
 pub struct LeftRoom {
     pub room: proto::Room,
-    pub channel_id: Option<ChannelId>,
-    pub channel_members: Vec<UserId>,
+    pub channel: Option<channel::Model>,
     pub left_projects: HashMap<ProjectId, LeftProject>,
     pub canceled_calls_to_user_ids: Vec<UserId>,
     pub deleted: bool,
@@ -658,8 +655,7 @@ pub struct LeftRoom {
 
 pub struct RefreshedRoom {
     pub room: proto::Room,
-    pub channel_id: Option<ChannelId>,
-    pub channel_members: Vec<UserId>,
+    pub channel: Option<channel::Model>,
     pub stale_participant_user_ids: Vec<UserId>,
     pub canceled_calls_to_user_ids: Vec<UserId>,
 }

crates/collab/src/db/ids.rs 🔗

@@ -91,7 +91,9 @@ id_type!(NotificationKindId);
 id_type!(HostedProjectId);
 
 /// ChannelRole gives you permissions for both channels and calls.
-#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
+#[derive(
+    Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
+)]
 #[sea_orm(rs_type = "String", db_type = "String(None)")]
 pub enum ChannelRole {
     /// Admin can read/write and change permissions.

crates/collab/src/db/queries/channels.rs 🔗

@@ -45,11 +45,7 @@ impl Database {
         name: &str,
         parent_channel_id: Option<ChannelId>,
         admin_id: UserId,
-    ) -> Result<(
-        Channel,
-        Option<channel_member::Model>,
-        Vec<channel_member::Model>,
-    )> {
+    ) -> Result<(channel::Model, Option<channel_member::Model>)> {
         let name = Self::sanitize_channel_name(name)?;
         self.transaction(move |tx| async move {
             let mut parent = None;
@@ -90,12 +86,7 @@ impl Database {
                 );
             }
 
-            let channel_members = channel_member::Entity::find()
-                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
-                .all(&*tx)
-                .await?;
-
-            Ok((Channel::from_model(channel), membership, channel_members))
+            Ok((channel, membership))
         })
         .await
     }
@@ -181,7 +172,7 @@ impl Database {
         channel_id: ChannelId,
         visibility: ChannelVisibility,
         admin_id: UserId,
-    ) -> Result<(Channel, Vec<channel_member::Model>)> {
+    ) -> Result<channel::Model> {
         self.transaction(move |tx| async move {
             let channel = self.get_channel_internal(channel_id, &tx).await?;
             self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -214,12 +205,7 @@ impl Database {
             model.visibility = ActiveValue::Set(visibility);
             let channel = model.update(&*tx).await?;
 
-            let channel_members = channel_member::Entity::find()
-                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
-                .all(&*tx)
-                .await?;
-
-            Ok((Channel::from_model(channel), channel_members))
+            Ok(channel)
         })
         .await
     }
@@ -245,21 +231,12 @@ impl Database {
         &self,
         channel_id: ChannelId,
         user_id: UserId,
-    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
+    ) -> Result<(ChannelId, Vec<ChannelId>)> {
         self.transaction(move |tx| async move {
             let channel = self.get_channel_internal(channel_id, &tx).await?;
             self.check_user_is_channel_admin(&channel, user_id, &tx)
                 .await?;
 
-            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
-                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
-                .select_only()
-                .column(channel_member::Column::UserId)
-                .distinct()
-                .into_values::<_, QueryUserIds>()
-                .all(&*tx)
-                .await?;
-
             let channels_to_remove = self
                 .get_channel_descendants_excluding_self([&channel], &tx)
                 .await?
@@ -273,7 +250,7 @@ impl Database {
                 .exec(&*tx)
                 .await?;
 
-            Ok((channels_to_remove, members_to_notify))
+            Ok((channel.root_id(), channels_to_remove))
         })
         .await
     }
@@ -343,7 +320,7 @@ impl Database {
         channel_id: ChannelId,
         admin_id: UserId,
         new_name: &str,
-    ) -> Result<(Channel, Vec<channel_member::Model>)> {
+    ) -> Result<channel::Model> {
         self.transaction(move |tx| async move {
             let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 
@@ -355,12 +332,7 @@ impl Database {
             model.name = ActiveValue::Set(new_name.clone());
             let channel = model.update(&*tx).await?;
 
-            let channel_members = channel_member::Entity::find()
-                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
-                .all(&*tx)
-                .await?;
-
-            Ok((Channel::from_model(channel), channel_members))
+            Ok(channel)
         })
         .await
     }
@@ -984,7 +956,7 @@ impl Database {
         channel_id: ChannelId,
         new_parent_id: ChannelId,
         admin_id: UserId,
-    ) -> Result<(Vec<Channel>, Vec<channel_member::Model>)> {
+    ) -> Result<(ChannelId, Vec<Channel>)> {
         self.transaction(|tx| async move {
             let channel = self.get_channel_internal(channel_id, &tx).await?;
             self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -1039,12 +1011,7 @@ impl Database {
                 .map(|c| Channel::from_model(c))
                 .collect::<Vec<_>>();
 
-            let channel_members = channel_member::Entity::find()
-                .filter(channel_member::Column::ChannelId.eq(root_id))
-                .all(&*tx)
-                .await?;
-
-            Ok((channels, channel_members))
+            Ok((root_id, channels))
         })
         .await
     }

crates/collab/src/db/queries/rooms.rs 🔗

@@ -52,12 +52,7 @@ impl Database {
             );
 
             let (channel, room) = self.get_channel_room(room_id, &tx).await?;
-            let channel_members;
-            if let Some(channel) = &channel {
-                channel_members = self.get_channel_participants(channel, &tx).await?;
-            } else {
-                channel_members = Vec::new();
-
+            if channel.is_none() {
                 // Delete the room if it becomes empty.
                 if room.participants.is_empty() {
                     project::Entity::delete_many()
@@ -70,8 +65,7 @@ impl Database {
 
             Ok(RefreshedRoom {
                 room,
-                channel_id: channel.map(|channel| channel.id),
-                channel_members,
+                channel,
                 stale_participant_user_ids,
                 canceled_calls_to_user_ids,
             })
@@ -349,8 +343,7 @@ impl Database {
             let room = self.get_room(room_id, &tx).await?;
             Ok(JoinRoom {
                 room,
-                channel_id: None,
-                channel_members: vec![],
+                channel: None,
             })
         })
         .await
@@ -446,11 +439,9 @@ impl Database {
 
         let (channel, room) = self.get_channel_room(room_id, &tx).await?;
         let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?;
-        let channel_members = self.get_channel_participants(&channel, tx).await?;
         Ok(JoinRoom {
             room,
-            channel_id: Some(channel.id),
-            channel_members,
+            channel: Some(channel),
         })
     }
 
@@ -736,16 +727,10 @@ impl Database {
             }
 
             let (channel, room) = self.get_channel_room(room_id, &tx).await?;
-            let channel_members = if let Some(channel) = &channel {
-                self.get_channel_participants(&channel, &tx).await?
-            } else {
-                Vec::new()
-            };
 
             Ok(RejoinedRoom {
                 room,
-                channel_id: channel.map(|channel| channel.id),
-                channel_members,
+                channel,
                 rejoined_projects,
                 reshared_projects,
             })
@@ -902,15 +887,9 @@ impl Database {
                     false
                 };
 
-                let channel_members = if let Some(channel) = &channel {
-                    self.get_channel_participants(channel, &tx).await?
-                } else {
-                    Vec::new()
-                };
                 let left_room = LeftRoom {
                     room,
-                    channel_id: channel.map(|channel| channel.id),
-                    channel_members,
+                    channel,
                     left_projects,
                     canceled_calls_to_user_ids,
                     deleted,

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -109,10 +109,9 @@ async fn test_channels(db: &Arc<Database>) {
     assert!(db.get_channel(crdb_id, a_id).await.is_err());
 
     // Remove a channel tree
-    let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
+    let (_, mut channel_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
     channel_ids.sort();
     assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
-    assert_eq!(user_ids, &[a_id]);
 
     assert!(db.get_channel(rust_id, a_id).await.is_err());
     assert!(db.get_channel(cargo_id, a_id).await.is_err());

crates/collab/src/rpc.rs 🔗

@@ -3,10 +3,10 @@ mod connection_pool;
 use crate::{
     auth::{self, Impersonator},
     db::{
-        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
-        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, ProjectId,
-        RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, User,
-        UserId,
+        self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
+        Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
+        ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
+        User, UserId,
     },
     executor::Executor,
     AppState, Error, Result,
@@ -351,14 +351,8 @@ impl Server {
                                 "refreshed room"
                             );
                             room_updated(&refreshed_room.room, &peer);
-                            if let Some(channel_id) = refreshed_room.channel_id {
-                                channel_updated(
-                                    channel_id,
-                                    &refreshed_room.room,
-                                    &refreshed_room.channel_members,
-                                    &peer,
-                                    &pool.lock(),
-                                );
+                            if let Some(channel) = refreshed_room.channel.as_ref() {
+                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                             }
                             contacts_to_update
                                 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
@@ -699,6 +693,9 @@ impl Server {
         {
             let mut pool = self.connection_pool.lock();
             pool.add_connection(connection_id, user.id, user.admin, zed_version);
+            for membership in &channels_for_user.channel_memberships {
+                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
+            }
             self.peer.send(
                 connection_id,
                 build_initial_contacts_update(contacts, &pool),
@@ -1148,8 +1145,7 @@ async fn rejoin_room(
     session: Session,
 ) -> Result<()> {
     let room;
-    let channel_id;
-    let channel_members;
+    let channel;
     {
         let mut rejoined_room = session
             .db()
@@ -1315,15 +1311,13 @@ async fn rejoin_room(
         let rejoined_room = rejoined_room.into_inner();
 
         room = rejoined_room.room;
-        channel_id = rejoined_room.channel_id;
-        channel_members = rejoined_room.channel_members;
+        channel = rejoined_room.channel;
     }
 
-    if let Some(channel_id) = channel_id {
+    if let Some(channel) = channel {
         channel_updated(
-            channel_id,
+            &channel,
             &room,
-            &channel_members,
             &session.peer,
             &*session.connection_pool().await,
         );
@@ -2427,31 +2421,39 @@ async fn create_channel(
     let db = session.db().await;
 
     let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
-    let (channel, owner, channel_members) = db
+    let (channel, membership) = db
         .create_channel(&request.name, parent_id, session.user_id)
         .await?;
 
+    let root_id = channel.root_id();
+    let channel = Channel::from_model(channel);
+
     response.send(proto::CreateChannelResponse {
         channel: Some(channel.to_proto()),
         parent_id: request.parent_id,
     })?;
 
-    let connection_pool = session.connection_pool().await;
-    if let Some(owner) = owner {
+    let mut connection_pool = session.connection_pool().await;
+    if let Some(membership) = membership {
+        connection_pool.subscribe_to_channel(
+            membership.user_id,
+            membership.channel_id,
+            membership.role,
+        );
         let update = proto::UpdateUserChannels {
             channel_memberships: vec![proto::ChannelMembership {
-                channel_id: owner.channel_id.to_proto(),
-                role: owner.role.into(),
+                channel_id: membership.channel_id.to_proto(),
+                role: membership.role.into(),
             }],
             ..Default::default()
         };
-        for connection_id in connection_pool.user_connection_ids(owner.user_id) {
+        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
             session.peer.send(connection_id, update.clone())?;
         }
     }
 
-    for channel_member in channel_members {
-        if !channel_member.role.can_see_channel(channel.visibility) {
+    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
+        if !role.can_see_channel(channel.visibility) {
             continue;
         }
 
@@ -2459,9 +2461,7 @@ async fn create_channel(
             channels: vec![channel.to_proto()],
             ..Default::default()
         };
-        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
-            session.peer.send(connection_id, update.clone())?;
-        }
+        session.peer.send(connection_id, update.clone())?;
     }
 
     Ok(())
@@ -2476,7 +2476,7 @@ async fn delete_channel(
     let db = session.db().await;
 
     let channel_id = request.channel_id;
-    let (removed_channels, member_ids) = db
+    let (root_channel, removed_channels) = db
         .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
         .await?;
     response.send(proto::Ack {})?;
@@ -2488,10 +2488,8 @@ async fn delete_channel(
         .extend(removed_channels.into_iter().map(|id| id.to_proto()));
 
     let connection_pool = session.connection_pool().await;
-    for member_id in member_ids {
-        for connection_id in connection_pool.user_connection_ids(member_id) {
-            session.peer.send(connection_id, update.clone())?;
-        }
+    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
+        session.peer.send(connection_id, update.clone())?;
     }
 
     Ok(())
@@ -2551,9 +2549,9 @@ async fn remove_channel_member(
         .remove_channel_member(channel_id, member_id, session.user_id)
         .await?;
 
-    let connection_pool = &session.connection_pool().await;
+    let mut connection_pool = session.connection_pool().await;
     notify_membership_updated(
-        &connection_pool,
+        &mut connection_pool,
         membership_update,
         member_id,
         &session.peer,
@@ -2588,25 +2586,33 @@ async fn set_channel_visibility(
     let channel_id = ChannelId::from_proto(request.channel_id);
     let visibility = request.visibility().into();
 
-    let (channel, channel_members) = db
+    let channel_model = db
         .set_channel_visibility(channel_id, visibility, session.user_id)
         .await?;
+    let root_id = channel_model.root_id();
+    let channel = Channel::from_model(channel_model);
 
-    let connection_pool = session.connection_pool().await;
-    for member in channel_members {
-        let update = if member.role.can_see_channel(channel.visibility) {
+    let mut connection_pool = session.connection_pool().await;
+    for (user_id, role) in connection_pool
+        .channel_user_ids(root_id)
+        .collect::<Vec<_>>()
+        .into_iter()
+    {
+        let update = if role.can_see_channel(channel.visibility) {
+            connection_pool.subscribe_to_channel(user_id, channel_id, role);
             proto::UpdateChannels {
                 channels: vec![channel.to_proto()],
                 ..Default::default()
             }
         } else {
+            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
             proto::UpdateChannels {
                 delete_channels: vec![channel.id.to_proto()],
                 ..Default::default()
             }
         };
 
-        for connection_id in connection_pool.user_connection_ids(member.user_id) {
+        for connection_id in connection_pool.user_connection_ids(user_id) {
             session.peer.send(connection_id, update.clone())?;
         }
     }
@@ -2635,9 +2641,9 @@ async fn set_channel_member_role(
 
     match result {
         db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
-            let connection_pool = session.connection_pool().await;
+            let mut connection_pool = session.connection_pool().await;
             notify_membership_updated(
-                &connection_pool,
+                &mut connection_pool,
                 membership_update,
                 member_id,
                 &session.peer,
@@ -2671,24 +2677,23 @@ async fn rename_channel(
 ) -> Result<()> {
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
-    let (channel, channel_members) = db
+    let channel_model = db
         .rename_channel(channel_id, session.user_id, &request.name)
         .await?;
+    let root_id = channel_model.root_id();
+    let channel = Channel::from_model(channel_model);
 
     response.send(proto::RenameChannelResponse {
         channel: Some(channel.to_proto()),
     })?;
 
     let connection_pool = session.connection_pool().await;
-    for channel_member in channel_members {
-        if !channel_member.role.can_see_channel(channel.visibility) {
-            continue;
-        }
-        let update = proto::UpdateChannels {
-            channels: vec![channel.to_proto()],
-            ..Default::default()
-        };
-        for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
+    let update = proto::UpdateChannels {
+        channels: vec![channel.to_proto()],
+        ..Default::default()
+    };
+    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
+        if role.can_see_channel(channel.visibility) {
             session.peer.send(connection_id, update.clone())?;
         }
     }
@@ -2705,18 +2710,18 @@ async fn move_channel(
     let channel_id = ChannelId::from_proto(request.channel_id);
     let to = ChannelId::from_proto(request.to);
 
-    let (channels, channel_members) = session
+    let (root_id, channels) = session
         .db()
         .await
         .move_channel(channel_id, to, session.user_id)
         .await?;
 
     let connection_pool = session.connection_pool().await;
-    for member in channel_members {
+    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
         let channels = channels
             .iter()
             .filter_map(|channel| {
-                if member.role.can_see_channel(channel.visibility) {
+                if role.can_see_channel(channel.visibility) {
                     Some(channel.to_proto())
                 } else {
                     None
@@ -2732,9 +2737,7 @@ async fn move_channel(
             ..Default::default()
         };
 
-        for connection_id in connection_pool.user_connection_ids(member.user_id) {
-            session.peer.send(connection_id, update.clone())?;
-        }
+        session.peer.send(connection_id, update.clone())?;
     }
 
     response.send(Ack {})?;
@@ -2771,10 +2774,10 @@ async fn respond_to_channel_invite(
         .respond_to_channel_invite(channel_id, session.user_id, request.accept)
         .await?;
 
-    let connection_pool = session.connection_pool().await;
+    let mut connection_pool = session.connection_pool().await;
     if let Some(membership_update) = membership_update {
         notify_membership_updated(
-            &connection_pool,
+            &mut connection_pool,
             membership_update,
             session.user_id,
             &session.peer,
@@ -2866,14 +2869,17 @@ async fn join_channel_internal(
 
         response.send(proto::JoinRoomResponse {
             room: Some(joined_room.room.clone()),
-            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
+            channel_id: joined_room
+                .channel
+                .as_ref()
+                .map(|channel| channel.id.to_proto()),
             live_kit_connection_info,
         })?;
 
-        let connection_pool = session.connection_pool().await;
+        let mut connection_pool = session.connection_pool().await;
         if let Some(membership_updated) = membership_updated {
             notify_membership_updated(
-                &connection_pool,
+                &mut connection_pool,
                 membership_updated,
                 session.user_id,
                 &session.peer,
@@ -2886,9 +2892,10 @@ async fn join_channel_internal(
     };
 
     channel_updated(
-        channel_id,
+        &joined_room
+            .channel
+            .ok_or_else(|| anyhow!("channel not returned"))?,
         &joined_room.room,
-        &joined_room.channel_members,
         &session.peer,
         &*session.connection_pool().await,
     );
@@ -3403,11 +3410,18 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
 }
 
 fn notify_membership_updated(
-    connection_pool: &ConnectionPool,
+    connection_pool: &mut ConnectionPool,
     result: MembershipUpdated,
     user_id: UserId,
     peer: &Peer,
 ) {
+    for membership in &result.new_channels.channel_memberships {
+        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
+    }
+    for channel_id in &result.removed_channels {
+        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
+    }
+
     let user_channels_update = proto::UpdateUserChannels {
         channel_memberships: result
             .new_channels
@@ -3420,6 +3434,7 @@ fn notify_membership_updated(
             .collect(),
         ..Default::default()
     };
+
     let mut update = build_channels_update(result.new_channels, vec![]);
     update.delete_channels = result
         .removed_channels
@@ -3533,9 +3548,8 @@ fn room_updated(room: &proto::Room, peer: &Peer) {
 }
 
 fn channel_updated(
-    channel_id: ChannelId,
+    channel: &db::channel::Model,
     room: &proto::Room,
-    channel_members: &[UserId],
     peer: &Peer,
     pool: &ConnectionPool,
 ) {
@@ -3547,15 +3561,16 @@ fn channel_updated(
 
     broadcast(
         None,
-        channel_members
-            .iter()
-            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
+        pool.channel_connection_ids(channel.root_id())
+            .filter_map(|(channel_id, role)| {
+                role.can_see_channel(channel.visibility).then(|| channel_id)
+            }),
         |peer_id| {
             peer.send(
                 peer_id,
                 proto::UpdateChannels {
                     channel_participants: vec![proto::ChannelParticipants {
-                        channel_id: channel_id.to_proto(),
+                        channel_id: channel.id.to_proto(),
                         participant_user_ids: participants.clone(),
                     }],
                     ..Default::default()
@@ -3608,8 +3623,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     let live_kit_room;
     let delete_live_kit_room;
     let room;
-    let channel_members;
-    let channel_id;
+    let channel;
 
     if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
         contacts_to_update.insert(session.user_id);
@@ -3623,19 +3637,17 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
         live_kit_room = mem::take(&mut left_room.room.live_kit_room);
         delete_live_kit_room = left_room.deleted;
         room = mem::take(&mut left_room.room);
-        channel_members = mem::take(&mut left_room.channel_members);
-        channel_id = left_room.channel_id;
+        channel = mem::take(&mut left_room.channel);
 
         room_updated(&room, &session.peer);
     } else {
         return Ok(());
     }
 
-    if let Some(channel_id) = channel_id {
+    if let Some(channel) = channel {
         channel_updated(
-            channel_id,
+            &channel,
             &room,
-            &channel_members,
             &session.peer,
             &*session.connection_pool().await,
         );

crates/collab/src/rpc/connection_pool.rs 🔗

@@ -1,6 +1,6 @@
-use crate::db::UserId;
+use crate::db::{ChannelId, ChannelRole, UserId};
 use anyhow::{anyhow, Result};
-use collections::{BTreeMap, HashSet};
+use collections::{BTreeMap, HashMap, HashSet};
 use rpc::ConnectionId;
 use serde::Serialize;
 use tracing::instrument;
@@ -10,6 +10,7 @@ use util::SemanticVersion;
 pub struct ConnectionPool {
     connections: BTreeMap<ConnectionId, Connection>,
     connected_users: BTreeMap<UserId, ConnectedUser>,
+    channels: ChannelPool,
 }
 
 #[derive(Default, Serialize)]
@@ -47,6 +48,7 @@ impl ConnectionPool {
     pub fn reset(&mut self) {
         self.connections.clear();
         self.connected_users.clear();
+        self.channels.clear();
     }
 
     #[instrument(skip(self))]
@@ -81,6 +83,7 @@ impl ConnectionPool {
         connected_user.connection_ids.remove(&connection_id);
         if connected_user.connection_ids.is_empty() {
             self.connected_users.remove(&user_id);
+            self.channels.remove_user(&user_id);
         }
         self.connections.remove(&connection_id).unwrap();
         Ok(())
@@ -110,6 +113,38 @@ impl ConnectionPool {
             .copied()
     }
 
+    pub fn channel_user_ids(
+        &self,
+        channel_id: ChannelId,
+    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
+        self.channels.users_to_notify(channel_id)
+    }
+
+    pub fn channel_connection_ids(
+        &self,
+        channel_id: ChannelId,
+    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
+        self.channels
+            .users_to_notify(channel_id)
+            .flat_map(|(user_id, role)| {
+                self.user_connection_ids(user_id)
+                    .map(move |connection_id| (connection_id, role))
+            })
+    }
+
+    pub fn subscribe_to_channel(
+        &mut self,
+        user_id: UserId,
+        channel_id: ChannelId,
+        role: ChannelRole,
+    ) {
+        self.channels.subscribe(user_id, channel_id, role);
+    }
+
+    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
+        self.channels.unsubscribe(user_id, channel_id);
+    }
+
     pub fn is_user_online(&self, user_id: UserId) -> bool {
         !self
             .connected_users
@@ -140,3 +175,70 @@ impl ConnectionPool {
         }
     }
 }
+
+#[derive(Default, Serialize)]
+pub struct ChannelPool {
+    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
+    by_channel: HashMap<ChannelId, HashSet<UserId>>,
+}
+
+impl ChannelPool {
+    pub fn clear(&mut self) {
+        self.by_user.clear();
+        self.by_channel.clear();
+    }
+
+    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
+        self.by_user
+            .entry(user_id)
+            .or_default()
+            .insert(channel_id, role);
+        self.by_channel
+            .entry(channel_id)
+            .or_default()
+            .insert(user_id);
+    }
+
+    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
+        if let Some(channels) = self.by_user.get_mut(user_id) {
+            channels.remove(channel_id);
+            if channels.is_empty() {
+                self.by_user.remove(user_id);
+            }
+        }
+        if let Some(users) = self.by_channel.get_mut(channel_id) {
+            users.remove(user_id);
+            if users.is_empty() {
+                self.by_channel.remove(channel_id);
+            }
+        }
+    }
+
+    pub fn remove_user(&mut self, user_id: &UserId) {
+        if let Some(channels) = self.by_user.remove(&user_id) {
+            for channel_id in channels.keys() {
+                self.unsubscribe(user_id, &channel_id)
+            }
+        }
+    }
+
+    pub fn users_to_notify(
+        &self,
+        channel_id: ChannelId,
+    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
+        self.by_channel
+            .get(&channel_id)
+            .into_iter()
+            .flat_map(move |users| {
+                users.iter().flat_map(move |user_id| {
+                    Some((
+                        *user_id,
+                        self.by_user
+                            .get(user_id)
+                            .and_then(|channels| channels.get(&channel_id))
+                            .copied()?,
+                    ))
+                })
+            })
+    }
+}