Make joining a channel as a guest always succeed

Conrad Irwin created

Change summary

crates/channel/src/channel_store.rs                |   1 
crates/collab/src/db/queries/channels.rs           | 129 +++++++++--
crates/collab/src/db/queries/rooms.rs              | 184 +++++++++------
crates/collab/src/db/tests/channel_tests.rs        |  15 
crates/collab/src/rpc.rs                           | 158 ++++++++-----
crates/collab/src/tests/channel_tests.rs           |  52 ++++
crates/collab_ui/src/collab_panel/channel_modal.rs |   2 
7 files changed, 370 insertions(+), 171 deletions(-)

Detailed changes

crates/channel/src/channel_store.rs 🔗

@@ -972,6 +972,7 @@ impl ChannelStore {
 
         let mut all_user_ids = Vec::new();
         let channel_participants = payload.channel_participants;
+        dbg!(&channel_participants);
         for entry in &channel_participants {
             for user_id in entry.participant_user_ids.iter() {
                 if let Err(ix) = all_user_ids.binary_search(user_id) {

crates/collab/src/db/queries/channels.rs 🔗

@@ -88,6 +88,84 @@ impl Database {
         .await
     }
 
+    pub async fn join_channel_internal(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        connection: ConnectionId,
+        environment: &str,
+        tx: &DatabaseTransaction,
+    ) -> Result<(JoinRoom, bool)> {
+        let mut joined = false;
+
+        let channel = channel::Entity::find()
+            .filter(channel::Column::Id.eq(channel_id))
+            .one(&*tx)
+            .await?;
+
+        let mut role = self
+            .channel_role_for_user(channel_id, user_id, &*tx)
+            .await?;
+
+        if role.is_none() {
+            if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) {
+                channel_member::Entity::insert(channel_member::ActiveModel {
+                    id: ActiveValue::NotSet,
+                    channel_id: ActiveValue::Set(channel_id),
+                    user_id: ActiveValue::Set(user_id),
+                    accepted: ActiveValue::Set(true),
+                    role: ActiveValue::Set(ChannelRole::Guest),
+                })
+                .on_conflict(
+                    OnConflict::columns([
+                        channel_member::Column::UserId,
+                        channel_member::Column::ChannelId,
+                    ])
+                    .update_columns([channel_member::Column::Accepted])
+                    .to_owned(),
+                )
+                .exec(&*tx)
+                .await?;
+
+                debug_assert!(
+                    self.channel_role_for_user(channel_id, user_id, &*tx)
+                        .await?
+                        == Some(ChannelRole::Guest)
+                );
+
+                role = Some(ChannelRole::Guest);
+                joined = true;
+            }
+        }
+
+        if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
+            Err(anyhow!("no such channel, or not allowed"))?
+        }
+
+        let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+        let room_id = self
+            .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
+            .await?;
+
+        self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
+            .await
+            .map(|jr| (jr, joined))
+    }
+
+    pub async fn join_channel(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        connection: ConnectionId,
+        environment: &str,
+    ) -> Result<(JoinRoom, bool)> {
+        self.transaction(move |tx| async move {
+            self.join_channel_internal(channel_id, user_id, connection, environment, &*tx)
+                .await
+        })
+        .await
+    }
+
     pub async fn set_channel_visibility(
         &self,
         channel_id: ChannelId,
@@ -981,38 +1059,39 @@ impl Database {
         .await
     }
 
-    pub async fn get_or_create_channel_room(
+    pub(crate) async fn get_or_create_channel_room(
         &self,
         channel_id: ChannelId,
         live_kit_room: &str,
-        enviroment: &str,
+        environment: &str,
+        tx: &DatabaseTransaction,
     ) -> Result<RoomId> {
-        self.transaction(|tx| async move {
-            let tx = tx;
-
-            let room = room::Entity::find()
-                .filter(room::Column::ChannelId.eq(channel_id))
-                .one(&*tx)
-                .await?;
+        let room = room::Entity::find()
+            .filter(room::Column::ChannelId.eq(channel_id))
+            .one(&*tx)
+            .await?;
 
-            let room_id = if let Some(room) = room {
-                room.id
-            } else {
-                let result = room::Entity::insert(room::ActiveModel {
-                    channel_id: ActiveValue::Set(Some(channel_id)),
-                    live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
-                    enviroment: ActiveValue::Set(Some(enviroment.to_string())),
-                    ..Default::default()
-                })
-                .exec(&*tx)
-                .await?;
+        let room_id = if let Some(room) = room {
+            if let Some(env) = room.enviroment {
+                if &env != environment {
+                    Err(anyhow!("must join using the {} release", env))?;
+                }
+            }
+            room.id
+        } else {
+            let result = room::Entity::insert(room::ActiveModel {
+                channel_id: ActiveValue::Set(Some(channel_id)),
+                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
+                enviroment: ActiveValue::Set(Some(environment.to_string())),
+                ..Default::default()
+            })
+            .exec(&*tx)
+            .await?;
 
-                result.last_insert_id
-            };
+            result.last_insert_id
+        };
 
-            Ok(room_id)
-        })
-        .await
+        Ok(room_id)
     }
 
     // Insert an edge from the given channel to the given other channel.

crates/collab/src/db/queries/rooms.rs 🔗

@@ -300,99 +300,139 @@ impl Database {
                 }
             }
 
-            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-            enum QueryParticipantIndices {
-                ParticipantIndex,
-            }
-            let existing_participant_indices: Vec<i32> = room_participant::Entity::find()
-                .filter(
-                    room_participant::Column::RoomId
-                        .eq(room_id)
-                        .and(room_participant::Column::ParticipantIndex.is_not_null()),
-                )
-                .select_only()
-                .column(room_participant::Column::ParticipantIndex)
-                .into_values::<_, QueryParticipantIndices>()
-                .all(&*tx)
-                .await?;
-
-            let mut participant_index = 0;
-            while existing_participant_indices.contains(&participant_index) {
-                participant_index += 1;
+            if channel_id.is_some() {
+                Err(anyhow!("tried to join channel call directly"))?
             }
 
-            if let Some(channel_id) = channel_id {
-                self.check_user_is_channel_member(channel_id, user_id, &*tx)
-                    .await?;
+            let participant_index = self
+                .get_next_participant_index_internal(room_id, &*tx)
+                .await?;
 
-                room_participant::Entity::insert_many([room_participant::ActiveModel {
-                    room_id: ActiveValue::set(room_id),
-                    user_id: ActiveValue::set(user_id),
+            let result = room_participant::Entity::update_many()
+                .filter(
+                    Condition::all()
+                        .add(room_participant::Column::RoomId.eq(room_id))
+                        .add(room_participant::Column::UserId.eq(user_id))
+                        .add(room_participant::Column::AnsweringConnectionId.is_null()),
+                )
+                .set(room_participant::ActiveModel {
+                    participant_index: ActiveValue::Set(Some(participant_index)),
                     answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
                     answering_connection_server_id: ActiveValue::set(Some(ServerId(
                         connection.owner_id as i32,
                     ))),
                     answering_connection_lost: ActiveValue::set(false),
-                    calling_user_id: ActiveValue::set(user_id),
-                    calling_connection_id: ActiveValue::set(connection.id as i32),
-                    calling_connection_server_id: ActiveValue::set(Some(ServerId(
-                        connection.owner_id as i32,
-                    ))),
-                    participant_index: ActiveValue::Set(Some(participant_index)),
                     ..Default::default()
-                }])
-                .on_conflict(
-                    OnConflict::columns([room_participant::Column::UserId])
-                        .update_columns([
-                            room_participant::Column::AnsweringConnectionId,
-                            room_participant::Column::AnsweringConnectionServerId,
-                            room_participant::Column::AnsweringConnectionLost,
-                            room_participant::Column::ParticipantIndex,
-                        ])
-                        .to_owned(),
-                )
+                })
                 .exec(&*tx)
                 .await?;
-            } else {
-                let result = room_participant::Entity::update_many()
-                    .filter(
-                        Condition::all()
-                            .add(room_participant::Column::RoomId.eq(room_id))
-                            .add(room_participant::Column::UserId.eq(user_id))
-                            .add(room_participant::Column::AnsweringConnectionId.is_null()),
-                    )
-                    .set(room_participant::ActiveModel {
-                        participant_index: ActiveValue::Set(Some(participant_index)),
-                        answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
-                        answering_connection_server_id: ActiveValue::set(Some(ServerId(
-                            connection.owner_id as i32,
-                        ))),
-                        answering_connection_lost: ActiveValue::set(false),
-                        ..Default::default()
-                    })
-                    .exec(&*tx)
-                    .await?;
-                if result.rows_affected == 0 {
-                    Err(anyhow!("room does not exist or was already joined"))?;
-                }
+            if result.rows_affected == 0 {
+                Err(anyhow!("room does not exist or was already joined"))?;
             }
 
             let room = self.get_room(room_id, &tx).await?;
-            let channel_members = if let Some(channel_id) = channel_id {
-                self.get_channel_participants_internal(channel_id, &tx)
-                    .await?
-            } else {
-                Vec::new()
-            };
             Ok(JoinRoom {
                 room,
-                channel_id,
-                channel_members,
+                channel_id: None,
+                channel_members: vec![],
             })
         })
         .await
     }
 
+    async fn get_next_participant_index_internal(
+        &self,
+        room_id: RoomId,
+        tx: &DatabaseTransaction,
+    ) -> Result<i32> {
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryParticipantIndices {
+            ParticipantIndex,
+        }
+        let existing_participant_indices: Vec<i32> = room_participant::Entity::find()
+            .filter(
+                room_participant::Column::RoomId
+                    .eq(room_id)
+                    .and(room_participant::Column::ParticipantIndex.is_not_null()),
+            )
+            .select_only()
+            .column(room_participant::Column::ParticipantIndex)
+            .into_values::<_, QueryParticipantIndices>()
+            .all(&*tx)
+            .await?;
+
+        let mut participant_index = 0;
+        while existing_participant_indices.contains(&participant_index) {
+            participant_index += 1;
+        }
+
+        Ok(participant_index)
+    }
+
+    pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result<Option<ChannelId>> {
+        self.transaction(|tx| async move {
+            let room: Option<room::Model> = room::Entity::find()
+                .filter(room::Column::Id.eq(room_id))
+                .one(&*tx)
+                .await?;
+
+            Ok(room.and_then(|room| room.channel_id))
+        })
+        .await
+    }
+
+    pub(crate) async fn join_channel_room_internal(
+        &self,
+        channel_id: ChannelId,
+        room_id: RoomId,
+        user_id: UserId,
+        connection: ConnectionId,
+        tx: &DatabaseTransaction,
+    ) -> Result<JoinRoom> {
+        let participant_index = self
+            .get_next_participant_index_internal(room_id, &*tx)
+            .await?;
+
+        room_participant::Entity::insert_many([room_participant::ActiveModel {
+            room_id: ActiveValue::set(room_id),
+            user_id: ActiveValue::set(user_id),
+            answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
+            answering_connection_server_id: ActiveValue::set(Some(ServerId(
+                connection.owner_id as i32,
+            ))),
+            answering_connection_lost: ActiveValue::set(false),
+            calling_user_id: ActiveValue::set(user_id),
+            calling_connection_id: ActiveValue::set(connection.id as i32),
+            calling_connection_server_id: ActiveValue::set(Some(ServerId(
+                connection.owner_id as i32,
+            ))),
+            participant_index: ActiveValue::Set(Some(participant_index)),
+            ..Default::default()
+        }])
+        .on_conflict(
+            OnConflict::columns([room_participant::Column::UserId])
+                .update_columns([
+                    room_participant::Column::AnsweringConnectionId,
+                    room_participant::Column::AnsweringConnectionServerId,
+                    room_participant::Column::AnsweringConnectionLost,
+                    room_participant::Column::ParticipantIndex,
+                ])
+                .to_owned(),
+        )
+        .exec(&*tx)
+        .await?;
+
+        let room = self.get_room(room_id, &tx).await?;
+        let channel_members = self
+            .get_channel_participants_internal(channel_id, &tx)
+            .await?;
+        Ok(JoinRoom {
+            room,
+            channel_id: Some(channel_id),
+            channel_members,
+        })
+    }
+
     pub async fn rejoin_room(
         &self,
         rejoin_room: proto::RejoinRoom,

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -8,7 +8,7 @@ use crate::{
     db::{
         queries::channels::ChannelGraph,
         tests::{graph, TEST_RELEASE_CHANNEL},
-        ChannelId, ChannelRole, Database, NewUserParams, UserId,
+        ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
     },
     test_both_dbs,
 };
@@ -207,15 +207,11 @@ async fn test_joining_channels(db: &Arc<Database>) {
         .user_id;
 
     let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap();
-    let room_1 = db
-        .get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL)
-        .await
-        .unwrap();
 
     // can join a room with membership to its channel
-    let joined_room = db
-        .join_room(
-            room_1,
+    let (joined_room, _) = db
+        .join_channel(
+            channel_1,
             user_1,
             ConnectionId { owner_id, id: 1 },
             TEST_RELEASE_CHANNEL,
@@ -224,11 +220,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
         .unwrap();
     assert_eq!(joined_room.room.participants.len(), 1);
 
+    let room_id = RoomId::from_proto(joined_room.room.id);
     drop(joined_room);
     // cannot join a room without membership to its channel
     assert!(db
         .join_room(
-            room_1,
+            room_id,
             user_2,
             ConnectionId { owner_id, id: 1 },
             TEST_RELEASE_CHANNEL

crates/collab/src/rpc.rs 🔗

@@ -38,7 +38,7 @@ use lazy_static::lazy_static;
 use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
-        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
+        self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom,
         LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
     },
     Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
@@ -977,6 +977,13 @@ async fn join_room(
     session: Session,
 ) -> Result<()> {
     let room_id = RoomId::from_proto(request.id);
+
+    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
+
+    if let Some(channel_id) = channel_id {
+        return join_channel_internal(channel_id, Box::new(response), session).await;
+    }
+
     let joined_room = {
         let room = session
             .db()
@@ -992,16 +999,6 @@ async fn join_room(
         room.into_inner()
     };
 
-    if let Some(channel_id) = joined_room.channel_id {
-        channel_updated(
-            channel_id,
-            &joined_room.room,
-            &joined_room.channel_members,
-            &session.peer,
-            &*session.connection_pool().await,
-        )
-    }
-
     for connection_id in session
         .connection_pool()
         .await
@@ -1039,7 +1036,7 @@ async fn join_room(
 
     response.send(proto::JoinRoomResponse {
         room: Some(joined_room.room),
-        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
+        channel_id: None,
         live_kit_connection_info,
     })?;
 
@@ -2602,54 +2599,68 @@ async fn respond_to_channel_invite(
     db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
         .await?;
 
+    if request.accept {
+        channel_membership_updated(db, channel_id, &session).await?;
+    } else {
+        let mut update = proto::UpdateChannels::default();
+        update
+            .remove_channel_invitations
+            .push(channel_id.to_proto());
+        session.peer.send(session.connection_id, update)?;
+    }
+    response.send(proto::Ack {})?;
+
+    Ok(())
+}
+
+async fn channel_membership_updated(
+    db: tokio::sync::MutexGuard<'_, DbHandle>,
+    channel_id: ChannelId,
+    session: &Session,
+) -> Result<(), crate::Error> {
     let mut update = proto::UpdateChannels::default();
     update
         .remove_channel_invitations
         .push(channel_id.to_proto());
-    if request.accept {
-        let result = db.get_channel_for_user(channel_id, session.user_id).await?;
-        update
+
+    let result = db.get_channel_for_user(channel_id, session.user_id).await?;
+    update.channels.extend(
+        result
             .channels
-            .extend(
-                result
-                    .channels
-                    .channels
-                    .into_iter()
-                    .map(|channel| proto::Channel {
-                        id: channel.id.to_proto(),
-                        visibility: channel.visibility.into(),
-                        name: channel.name,
-                    }),
-            );
-        update.unseen_channel_messages = result.channel_messages;
-        update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
-        update.insert_edge = result.channels.edges;
-        update
-            .channel_participants
-            .extend(
-                result
-                    .channel_participants
-                    .into_iter()
-                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
-                        channel_id: channel_id.to_proto(),
-                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
-                    }),
-            );
-        update
-            .channel_permissions
-            .extend(
-                result
-                    .channels_with_admin_privileges
-                    .into_iter()
-                    .map(|channel_id| proto::ChannelPermission {
-                        channel_id: channel_id.to_proto(),
-                        role: proto::ChannelRole::Admin.into(),
-                    }),
-            );
-    }
+            .channels
+            .into_iter()
+            .map(|channel| proto::Channel {
+                id: channel.id.to_proto(),
+                visibility: channel.visibility.into(),
+                name: channel.name,
+            }),
+    );
+    update.unseen_channel_messages = result.channel_messages;
+    update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
+    update.insert_edge = result.channels.edges;
+    update
+        .channel_participants
+        .extend(
+            result
+                .channel_participants
+                .into_iter()
+                .map(|(channel_id, user_ids)| proto::ChannelParticipants {
+                    channel_id: channel_id.to_proto(),
+                    participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
+                }),
+        );
+    update
+        .channel_permissions
+        .extend(
+            result
+                .channels_with_admin_privileges
+                .into_iter()
+                .map(|channel_id| proto::ChannelPermission {
+                    channel_id: channel_id.to_proto(),
+                    role: proto::ChannelRole::Admin.into(),
+                }),
+        );
     session.peer.send(session.connection_id, update)?;
-    response.send(proto::Ack {})?;
-
     Ok(())
 }
 
@@ -2659,19 +2670,35 @@ async fn join_channel(
     session: Session,
 ) -> Result<()> {
     let channel_id = ChannelId::from_proto(request.channel_id);
-    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+    join_channel_internal(channel_id, Box::new(response), session).await
+}
+
+trait JoinChannelInternalResponse {
+    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
+}
+impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
+    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
+        Response::<proto::JoinChannel>::send(self, result)
+    }
+}
+impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
+    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
+        Response::<proto::JoinRoom>::send(self, result)
+    }
+}
 
+async fn join_channel_internal(
+    channel_id: ChannelId,
+    response: Box<impl JoinChannelInternalResponse>,
+    session: Session,
+) -> Result<()> {
     let joined_room = {
         leave_room_for_session(&session).await?;
         let db = session.db().await;
 
-        let room_id = db
-            .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
-            .await?;
-
-        let joined_room = db
-            .join_room(
-                room_id,
+        let (joined_room, joined_channel) = db
+            .join_channel(
+                channel_id,
                 session.user_id,
                 session.connection_id,
                 RELEASE_CHANNEL_NAME.as_str(),
@@ -2698,9 +2725,13 @@ async fn join_channel(
             live_kit_connection_info,
         })?;
 
+        if joined_channel {
+            channel_membership_updated(db, channel_id, &session).await?
+        }
+
         room_updated(&joined_room.room, &session.peer);
 
-        joined_room.into_inner()
+        joined_room
     };
 
     channel_updated(
@@ -2712,7 +2743,6 @@ async fn join_channel(
     );
 
     update_user_contacts(session.user_id, &session).await?;
-
     Ok(())
 }
 

crates/collab/src/tests/channel_tests.rs 🔗

@@ -912,6 +912,58 @@ async fn test_lost_channel_creation(
         ],
     );
 }
+#[gpui::test]
+async fn test_guest_access(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channels = server
+        .make_channel_tree(&[("channel-a", None)], (&client_a, cx_a))
+        .await;
+    let channel_a_id = channels[0];
+
+    let active_call_b = cx_b.read(ActiveCall::global);
+
+    // should not be allowed to join
+    assert!(active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+        .await
+        .is_err());
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, cx| {
+            channel_store.set_channel_visibility(channel_a_id, proto::ChannelVisibility::Public, cx)
+        })
+        .await
+        .unwrap();
+
+    active_call_b
+        .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    assert!(client_b
+        .channel_store()
+        .update(cx_b, |channel_store, _| channel_store
+            .channel_for_id(channel_a_id)
+            .is_some()));
+
+    client_a.channel_store().update(cx_a, |channel_store, _| {
+        let participants = channel_store.channel_participants(channel_a_id);
+        assert_eq!(participants.len(), 1);
+        assert_eq!(participants[0].id, client_b.user_id().unwrap());
+    })
+}
 
 #[gpui::test]
 async fn test_channel_moving(

crates/collab_ui/src/collab_panel/channel_modal.rs 🔗

@@ -1,4 +1,4 @@
-use channel::{Channel, ChannelId, ChannelMembership, ChannelStore};
+use channel::{ChannelId, ChannelMembership, ChannelStore};
 use client::{
     proto::{self, ChannelRole, ChannelVisibility},
     User, UserId, UserStore,