Only allow one release channel in a call

Conrad Irwin created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql           |  1 
crates/collab/migrations/20231009181554_add_release_channel_to_rooms.sql |  1 
crates/collab/src/db/queries/rooms.rs                                    | 32 
crates/collab/src/db/tables/room.rs                                      |  1 
crates/collab/src/db/tests.rs                                            |  2 
crates/collab/src/db/tests/channel_tests.rs                              | 20 
crates/collab/src/db/tests/db_tests.rs                                   | 92 
crates/collab/src/rpc.rs                                                 | 22 
8 files changed, 152 insertions(+), 19 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -37,6 +37,7 @@ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");
 CREATE TABLE "rooms" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "live_kit_room" VARCHAR NOT NULL,
+    "release_channel" VARCHAR,
     "channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE
 );
 

crates/collab/src/db/queries/rooms.rs 🔗

@@ -107,10 +107,12 @@ impl Database {
         user_id: UserId,
         connection: ConnectionId,
         live_kit_room: &str,
+        release_channel: &str,
     ) -> Result<proto::Room> {
         self.transaction(|tx| async move {
             let room = room::ActiveModel {
                 live_kit_room: ActiveValue::set(live_kit_room.into()),
+                release_channel: ActiveValue::set(Some(release_channel.to_string())),
                 ..Default::default()
             }
             .insert(&*tx)
@@ -270,20 +272,31 @@ impl Database {
         room_id: RoomId,
         user_id: UserId,
         connection: ConnectionId,
+        collab_release_channel: &str,
     ) -> Result<RoomGuard<JoinRoom>> {
         self.room_transaction(room_id, |tx| async move {
             #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-            enum QueryChannelId {
+            enum QueryChannelIdAndReleaseChannel {
                 ChannelId,
+                ReleaseChannel,
+            }
+
+            let (channel_id, release_channel): (Option<ChannelId>, Option<String>) =
+                room::Entity::find()
+                    .select_only()
+                    .column(room::Column::ChannelId)
+                    .column(room::Column::ReleaseChannel)
+                    .filter(room::Column::Id.eq(room_id))
+                    .into_values::<_, QueryChannelIdAndReleaseChannel>()
+                    .one(&*tx)
+                    .await?
+                    .ok_or_else(|| anyhow!("no such room"))?;
+
+            if let Some(release_channel) = release_channel {
+                if &release_channel != collab_release_channel {
+                    Err(anyhow!("must join using the {} release", release_channel))?;
+                }
             }
-            let channel_id: Option<ChannelId> = room::Entity::find()
-                .select_only()
-                .column(room::Column::ChannelId)
-                .filter(room::Column::Id.eq(room_id))
-                .into_values::<_, QueryChannelId>()
-                .one(&*tx)
-                .await?
-                .ok_or_else(|| anyhow!("no such room"))?;
 
             #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
             enum QueryParticipantIndices {
@@ -300,6 +313,7 @@ impl Database {
                 .into_values::<_, QueryParticipantIndices>()
                 .all(&*tx)
                 .await?;
+
             let mut participant_index = 0;
             while existing_participant_indices.contains(&participant_index) {
                 participant_index += 1;

crates/collab/src/db/tables/room.rs 🔗

@@ -8,6 +8,7 @@ pub struct Model {
     pub id: RoomId,
     pub live_kit_room: String,
     pub channel_id: Option<ChannelId>,
+    pub release_channel: Option<String>,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

crates/collab/src/db/tests.rs 🔗

@@ -12,6 +12,8 @@ use sea_orm::ConnectionTrait;
 use sqlx::migrate::MigrateDatabase;
 use std::sync::Arc;
 
+const TEST_RELEASE_CHANNEL: &'static str = "test";
+
 pub struct TestDb {
     pub db: Option<Arc<Database>>,
     pub connection: Option<sqlx::AnyConnection>,

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -5,7 +5,11 @@ use rpc::{
 };
 
 use crate::{
-    db::{queries::channels::ChannelGraph, tests::graph, ChannelId, Database, NewUserParams},
+    db::{
+        queries::channels::ChannelGraph,
+        tests::{graph, TEST_RELEASE_CHANNEL},
+        ChannelId, Database, NewUserParams,
+    },
     test_both_dbs,
 };
 use std::sync::Arc;
@@ -206,7 +210,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
 
     // can join a room with membership to its channel
     let joined_room = db
-        .join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
+        .join_room(
+            room_1,
+            user_1,
+            ConnectionId { owner_id, id: 1 },
+            TEST_RELEASE_CHANNEL,
+        )
         .await
         .unwrap();
     assert_eq!(joined_room.room.participants.len(), 1);
@@ -214,7 +223,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
     drop(joined_room);
     // cannot join a room without membership to its channel
     assert!(db
-        .join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
+        .join_room(
+            room_1,
+            user_2,
+            ConnectionId { owner_id, id: 1 },
+            TEST_RELEASE_CHANNEL
+        )
         .await
         .is_err());
 }

crates/collab/src/db/tests/db_tests.rs 🔗

@@ -479,7 +479,7 @@ async fn test_project_count(db: &Arc<Database>) {
         .unwrap();
 
     let room_id = RoomId::from_proto(
-        db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "")
+        db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "", "dev")
             .await
             .unwrap()
             .id,
@@ -493,9 +493,14 @@ async fn test_project_count(db: &Arc<Database>) {
     )
     .await
     .unwrap();
-    db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 })
-        .await
-        .unwrap();
+    db.join_room(
+        room_id,
+        user2.user_id,
+        ConnectionId { owner_id, id: 1 },
+        "dev",
+    )
+    .await
+    .unwrap();
     assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
 
     db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
@@ -575,6 +580,85 @@ async fn test_fuzzy_search_users() {
     }
 }
 
+test_both_dbs!(
+    test_non_matching_release_channels,
+    test_non_matching_release_channels_postgres,
+    test_non_matching_release_channels_sqlite
+);
+
+async fn test_non_matching_release_channels(db: &Arc<Database>) {
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
+
+    let user1 = db
+        .create_user(
+            &format!("admin@example.com"),
+            true,
+            NewUserParams {
+                github_login: "admin".into(),
+                github_user_id: 0,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap();
+    let user2 = db
+        .create_user(
+            &format!("user@example.com"),
+            false,
+            NewUserParams {
+                github_login: "user".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap();
+
+    let room = db
+        .create_room(
+            user1.user_id,
+            ConnectionId { owner_id, id: 0 },
+            "",
+            "stable",
+        )
+        .await
+        .unwrap();
+
+    db.call(
+        RoomId::from_proto(room.id),
+        user1.user_id,
+        ConnectionId { owner_id, id: 0 },
+        user2.user_id,
+        None,
+    )
+    .await
+    .unwrap();
+
+    // User attempts to join from preview
+    let result = db
+        .join_room(
+            RoomId::from_proto(room.id),
+            user2.user_id,
+            ConnectionId { owner_id, id: 1 },
+            "preview",
+        )
+        .await;
+
+    assert!(result.is_err());
+
+    // User switches to stable
+    let result = db
+        .join_room(
+            RoomId::from_proto(room.id),
+            user2.user_id,
+            ConnectionId { owner_id, id: 1 },
+            "stable",
+        )
+        .await;
+
+    assert!(result.is_ok())
+}
+
 fn build_background_executor() -> Arc<Background> {
     Deterministic::new(0).build_background()
 }

crates/collab/src/rpc.rs 🔗

@@ -63,6 +63,7 @@ use time::OffsetDateTime;
 use tokio::sync::{watch, Semaphore};
 use tower::ServiceBuilder;
 use tracing::{info_span, instrument, Instrument};
+use util::channel::RELEASE_CHANNEL_NAME;
 
 pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
 pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@@ -957,7 +958,12 @@ async fn create_room(
     let room = session
         .db()
         .await
-        .create_room(session.user_id, session.connection_id, &live_kit_room)
+        .create_room(
+            session.user_id,
+            session.connection_id,
+            &live_kit_room,
+            RELEASE_CHANNEL_NAME.as_str(),
+        )
         .await?;
 
     response.send(proto::CreateRoomResponse {
@@ -979,7 +985,12 @@ async fn join_room(
         let room = session
             .db()
             .await
-            .join_room(room_id, session.user_id, session.connection_id)
+            .join_room(
+                room_id,
+                session.user_id,
+                session.connection_id,
+                RELEASE_CHANNEL_NAME.as_str(),
+            )
             .await?;
         room_updated(&room.room, &session.peer);
         room.into_inner()
@@ -2616,7 +2627,12 @@ async fn join_channel(
         let room_id = db.room_id_for_channel(channel_id).await?;
 
         let joined_room = db
-            .join_room(room_id, session.user_id, session.connection_id)
+            .join_room(
+                room_id,
+                session.user_id,
+                session.connection_id,
+                RELEASE_CHANNEL_NAME.as_str(),
+            )
             .await?;
 
         let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {