@@ -3093,6 +3093,22 @@ impl Database {
self.transaction(move |tx| async move {
let tx = tx;
+ if let Some(parent) = parent {
+ let channels = self.get_channel_ancestors(parent, &*tx).await?;
+ channel_member::Entity::find()
+ .filter(channel_member::Column::ChannelId.is_in(channels.iter().copied()))
+ .filter(
+ channel_member::Column::UserId
+ .eq(creator_id)
+ .and(channel_member::Column::Accepted.eq(true)),
+ )
+ .one(&*tx)
+ .await?
+ .ok_or_else(|| {
+ anyhow!("User does not have the permissions to create this channel")
+ })?;
+ }
+
let channel = channel::ActiveModel {
name: ActiveValue::Set(name.to_string()),
..Default::default()
@@ -3175,11 +3191,6 @@ impl Database {
let channels_to_remove = descendants.keys().copied().collect::<Vec<_>>();
- #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
- enum QueryUserIds {
- UserId,
- }
-
let members_to_notify: Vec<UserId> = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.is_in(channels_to_remove.iter().copied()))
.select_only()
@@ -3325,11 +3336,6 @@ impl Database {
self.transaction(|tx| async move {
let tx = tx;
- #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
- enum QueryChannelIds {
- ChannelId,
- }
-
let starting_channel_ids: Vec<ChannelId> = channel_member::Entity::find()
.filter(
channel_member::Column::UserId
@@ -3368,6 +3374,65 @@ impl Database {
.await
}
+ pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
+ self.transaction(|tx| async move {
+ let tx = tx;
+ let ancestor_ids = self.get_channel_ancestors(id, &*tx).await?;
+ let user_ids = channel_member::Entity::find()
+ .distinct()
+ .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
+ .select_only()
+ .column(channel_member::Column::UserId)
+ .into_values::<_, QueryUserIds>()
+ .all(&*tx)
+ .await?;
+ Ok(user_ids)
+ })
+ .await
+ }
+
+ async fn get_channel_ancestors(
+ &self,
+ id: ChannelId,
+ tx: &DatabaseTransaction,
+ ) -> Result<Vec<ChannelId>> {
+ let sql = format!(
+ r#"
+ WITH RECURSIVE channel_tree(child_id, parent_id) AS (
+ SELECT CAST(NULL as INTEGER) as child_id, root_ids.column1 as parent_id
+ FROM (VALUES ({})) as root_ids
+ UNION
+ SELECT channel_parents.child_id, channel_parents.parent_id
+ FROM channel_parents, channel_tree
+ WHERE channel_parents.child_id = channel_tree.parent_id
+ )
+ SELECT DISTINCT channel_tree.parent_id
+ FROM channel_tree
+ "#,
+ id
+ );
+
+ #[derive(FromQueryResult, Debug, PartialEq)]
+ pub struct ChannelParent {
+ pub parent_id: ChannelId,
+ }
+
+ let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
+
+ let mut channel_ids_stream = channel_parent::Entity::find()
+ .from_raw_sql(stmt)
+ .into_model::<ChannelParent>()
+ .stream(&*tx)
+ .await?;
+
+ let mut channel_ids = vec![];
+ while let Some(channel_id) = channel_ids_stream.next().await {
+ channel_ids.push(channel_id?.parent_id);
+ }
+
+ Ok(channel_ids)
+ }
+
async fn get_channel_descendants(
&self,
channel_ids: impl IntoIterator<Item = ChannelId>,
@@ -3948,6 +4013,16 @@ pub struct WorktreeSettingsFile {
pub content: String,
}
+#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+enum QueryChannelIds {
+ ChannelId,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+enum QueryUserIds {
+ UserId,
+}
+
#[cfg(test)]
pub use test::*;
@@ -899,7 +899,30 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
.unwrap()
.user_id;
+ let b_id = db
+ .create_user(
+ "user2@example.com",
+ false,
+ NewUserParams {
+ github_login: "user2".into(),
+ github_user_id: 6,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
+
+ db.invite_channel_member(zed_id, b_id, a_id, true)
+ .await
+ .unwrap();
+
+ db.respond_to_channel_invite(zed_id, b_id, true)
+ .await
+ .unwrap();
+
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
@@ -912,6 +935,11 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
.create_channel("replace", Some(zed_id), "4", a_id)
.await
.unwrap();
+
+ let mut members = db.get_channel_members(replace_id).await.unwrap();
+ members.sort();
+ assert_eq!(members, &[a_id, b_id]);
+
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
let cargo_id = db
.create_channel("cargo", Some(rust_id), "6", a_id)
@@ -2108,25 +2108,33 @@ async fn create_channel(
live_kit.create_room(live_kit_room.clone()).await?;
}
+ let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
let id = db
- .create_channel(
- &request.name,
- request.parent_id.map(|id| ChannelId::from_proto(id)),
- &live_kit_room,
- session.user_id,
- )
+ .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
.await?;
+ response.send(proto::CreateChannelResponse {
+ channel_id: id.to_proto(),
+ })?;
+
let mut update = proto::UpdateChannels::default();
update.channels.push(proto::Channel {
id: id.to_proto(),
name: request.name,
parent_id: request.parent_id,
});
- session.peer.send(session.connection_id, update)?;
- response.send(proto::CreateChannelResponse {
- channel_id: id.to_proto(),
- })?;
+
+ if let Some(parent_id) = parent_id {
+ let member_ids = db.get_channel_members(parent_id).await?;
+ let connection_pool = session.connection_pool().await;
+ for member_id in member_ids {
+ for connection_id in connection_pool.user_connection_ids(member_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
+ }
+ } else {
+ session.peer.send(session.connection_id, update)?;
+ }
Ok(())
}