diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 066c93ec718f15a451164270dcaa0eb8b53dcd82..58607836cc4f44fab8acf4a88e5c972e92d9e9d8 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -3093,6 +3093,22 @@ impl Database { self.transaction(move |tx| async move { let tx = tx; + if let Some(parent) = parent { + let channels = self.get_channel_ancestors(parent, &*tx).await?; + channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channels.iter().copied())) + .filter( + channel_member::Column::UserId + .eq(creator_id) + .and(channel_member::Column::Accepted.eq(true)), + ) + .one(&*tx) + .await? + .ok_or_else(|| { + anyhow!("User does not have the permissions to create this channel") + })?; + } + let channel = channel::ActiveModel { name: ActiveValue::Set(name.to_string()), ..Default::default() @@ -3175,11 +3191,6 @@ impl Database { let channels_to_remove = descendants.keys().copied().collect::>(); - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryUserIds { - UserId, - } - let members_to_notify: Vec = channel_member::Entity::find() .filter(channel_member::Column::ChannelId.is_in(channels_to_remove.iter().copied())) .select_only() @@ -3325,11 +3336,6 @@ impl Database { self.transaction(|tx| async move { let tx = tx; - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryChannelIds { - ChannelId, - } - let starting_channel_ids: Vec = channel_member::Entity::find() .filter( channel_member::Column::UserId @@ -3368,6 +3374,65 @@ impl Database { .await } + pub async fn get_channel_members(&self, id: ChannelId) -> Result> { + self.transaction(|tx| async move { + let tx = tx; + let ancestor_ids = self.get_channel_ancestors(id, &*tx).await?; + let user_ids = channel_member::Entity::find() + .distinct() + .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied())) + .select_only() + .column(channel_member::Column::UserId) + .into_values::<_, QueryUserIds>() + .all(&*tx) + .await?; + Ok(user_ids) + }) + .await + } + + async fn get_channel_ancestors( + &self, + id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result> { + let sql = format!( + r#" + WITH RECURSIVE channel_tree(child_id, parent_id) AS ( + SELECT CAST(NULL as INTEGER) as child_id, root_ids.column1 as parent_id + FROM (VALUES ({})) as root_ids + UNION + SELECT channel_parents.child_id, channel_parents.parent_id + FROM channel_parents, channel_tree + WHERE channel_parents.child_id = channel_tree.parent_id + ) + SELECT DISTINCT channel_tree.parent_id + FROM channel_tree + "#, + id + ); + + #[derive(FromQueryResult, Debug, PartialEq)] + pub struct ChannelParent { + pub parent_id: ChannelId, + } + + let stmt = Statement::from_string(self.pool.get_database_backend(), sql); + + let mut channel_ids_stream = channel_parent::Entity::find() + .from_raw_sql(stmt) + .into_model::() + .stream(&*tx) + .await?; + + let mut channel_ids = vec![]; + while let Some(channel_id) = channel_ids_stream.next().await { + channel_ids.push(channel_id?.parent_id); + } + + Ok(channel_ids) + } + async fn get_channel_descendants( &self, channel_ids: impl IntoIterator, @@ -3948,6 +4013,16 @@ pub struct WorktreeSettingsFile { pub content: String, } +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryChannelIds { + ChannelId, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryUserIds { + UserId, +} + #[cfg(test)] pub use test::*; diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 3a47097f7d91f721c2fdd0da54b854c543a97e77..2ffcef454b631f496b248330345cbdfc73daca03 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -899,7 +899,30 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, { .unwrap() .user_id; + let b_id = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "user2".into(), + github_user_id: 6, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap(); + + db.invite_channel_member(zed_id, b_id, a_id, true) + .await + .unwrap(); + + db.respond_to_channel_invite(zed_id, b_id, true) + .await + .unwrap(); + let crdb_id = db .create_channel("crdb", Some(zed_id), "2", a_id) .await @@ -912,6 +935,11 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, { .create_channel("replace", Some(zed_id), "4", a_id) .await .unwrap(); + + let mut members = db.get_channel_members(replace_id).await.unwrap(); + members.sort(); + assert_eq!(members, &[a_id, b_id]); + let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap(); let cargo_id = db .create_channel("cargo", Some(rust_id), "6", a_id) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1465c666016f39f6b5de2f1dcbf4def1dbb2eaaf..819a3dc4f67ef5bbcf71ebbaf052e6529e108c11 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2108,25 +2108,33 @@ async fn create_channel( live_kit.create_room(live_kit_room.clone()).await?; } + let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); let id = db - .create_channel( - &request.name, - request.parent_id.map(|id| ChannelId::from_proto(id)), - &live_kit_room, - session.user_id, - ) + .create_channel(&request.name, parent_id, &live_kit_room, session.user_id) .await?; + response.send(proto::CreateChannelResponse { + channel_id: id.to_proto(), + })?; + let mut update = proto::UpdateChannels::default(); update.channels.push(proto::Channel { id: id.to_proto(), name: request.name, parent_id: request.parent_id, }); - session.peer.send(session.connection_id, update)?; - response.send(proto::CreateChannelResponse { - channel_id: id.to_proto(), - })?; + + if let Some(parent_id) = parent_id { + let member_ids = db.get_channel_members(parent_id).await?; + let connection_pool = session.connection_pool().await; + for member_id in member_ids { + for connection_id in connection_pool.user_connection_ids(member_id) { + session.peer.send(connection_id, update.clone())?; + } + } + } else { + session.peer.send(session.connection_id, update)?; + } Ok(()) }