channels.rs

   1use super::*;
   2use rpc::proto::ChannelEdge;
   3use smallvec::SmallVec;
   4
   5type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
   6
   7impl Database {
   8    #[cfg(test)]
   9    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  10        self.transaction(move |tx| async move {
  11            let mut channels = Vec::new();
  12            let mut rows = channel::Entity::find().stream(&*tx).await?;
  13            while let Some(row) = rows.next().await {
  14                let row = row?;
  15                channels.push((row.id, row.name));
  16            }
  17            Ok(channels)
  18        })
  19        .await
  20    }
  21
  22    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  23        self.create_channel(name, None, creator_id).await
  24    }
  25
  26    pub async fn create_channel(
  27        &self,
  28        name: &str,
  29        parent: Option<ChannelId>,
  30        creator_id: UserId,
  31    ) -> Result<ChannelId> {
  32        let name = Self::sanitize_channel_name(name)?;
  33        self.transaction(move |tx| async move {
  34            if let Some(parent) = parent {
  35                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  36                    .await?;
  37            }
  38
  39            let channel = channel::ActiveModel {
  40                name: ActiveValue::Set(name.to_string()),
  41                ..Default::default()
  42            }
  43            .insert(&*tx)
  44            .await?;
  45
  46            if let Some(parent) = parent {
  47                let sql = r#"
  48                    INSERT INTO channel_paths
  49                    (id_path, channel_id)
  50                    SELECT
  51                        id_path || $1 || '/', $2
  52                    FROM
  53                        channel_paths
  54                    WHERE
  55                        channel_id = $3
  56                "#;
  57                let channel_paths_stmt = Statement::from_sql_and_values(
  58                    self.pool.get_database_backend(),
  59                    sql,
  60                    [
  61                        channel.id.to_proto().into(),
  62                        channel.id.to_proto().into(),
  63                        parent.to_proto().into(),
  64                    ],
  65                );
  66                tx.execute(channel_paths_stmt).await?;
  67            } else {
  68                channel_path::Entity::insert(channel_path::ActiveModel {
  69                    channel_id: ActiveValue::Set(channel.id),
  70                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  71                })
  72                .exec(&*tx)
  73                .await?;
  74            }
  75
  76            channel_member::ActiveModel {
  77                channel_id: ActiveValue::Set(channel.id),
  78                user_id: ActiveValue::Set(creator_id),
  79                accepted: ActiveValue::Set(true),
  80                admin: ActiveValue::Set(true),
  81                ..Default::default()
  82            }
  83            .insert(&*tx)
  84            .await?;
  85
  86            Ok(channel.id)
  87        })
  88        .await
  89    }
  90
  91    pub async fn delete_channel(
  92        &self,
  93        channel_id: ChannelId,
  94        user_id: UserId,
  95    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
  96        self.transaction(move |tx| async move {
  97            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
  98                .await?;
  99
 100            // Don't remove descendant channels that have additional parents.
 101            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
 102            {
 103                let mut channels_to_keep = channel_path::Entity::find()
 104                    .filter(
 105                        channel_path::Column::ChannelId
 106                            .is_in(
 107                                channels_to_remove
 108                                    .keys()
 109                                    .copied()
 110                                    .filter(|&id| id != channel_id),
 111                            )
 112                            .and(
 113                                channel_path::Column::IdPath
 114                                    .not_like(&format!("%/{}/%", channel_id)),
 115                            ),
 116                    )
 117                    .stream(&*tx)
 118                    .await?;
 119                while let Some(row) = channels_to_keep.next().await {
 120                    let row = row?;
 121                    channels_to_remove.remove(&row.channel_id);
 122                }
 123            }
 124
 125            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 126            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 127                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 128                .select_only()
 129                .column(channel_member::Column::UserId)
 130                .distinct()
 131                .into_values::<_, QueryUserIds>()
 132                .all(&*tx)
 133                .await?;
 134
 135            channel::Entity::delete_many()
 136                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
 137                .exec(&*tx)
 138                .await?;
 139
 140            // Delete any other paths that include this channel
 141            let sql = r#"
 142                    DELETE FROM channel_paths
 143                    WHERE
 144                        id_path LIKE '%' || $1 || '%'
 145                "#;
 146            let channel_paths_stmt = Statement::from_sql_and_values(
 147                self.pool.get_database_backend(),
 148                sql,
 149                [channel_id.to_proto().into()],
 150            );
 151            tx.execute(channel_paths_stmt).await?;
 152
 153            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
 154        })
 155        .await
 156    }
 157
 158    pub async fn invite_channel_member(
 159        &self,
 160        channel_id: ChannelId,
 161        invitee_id: UserId,
 162        inviter_id: UserId,
 163        is_admin: bool,
 164    ) -> Result<()> {
 165        self.transaction(move |tx| async move {
 166            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 167                .await?;
 168
 169            channel_member::ActiveModel {
 170                channel_id: ActiveValue::Set(channel_id),
 171                user_id: ActiveValue::Set(invitee_id),
 172                accepted: ActiveValue::Set(false),
 173                admin: ActiveValue::Set(is_admin),
 174                ..Default::default()
 175            }
 176            .insert(&*tx)
 177            .await?;
 178
 179            Ok(())
 180        })
 181        .await
 182    }
 183
 184    fn sanitize_channel_name(name: &str) -> Result<&str> {
 185        let new_name = name.trim().trim_start_matches('#');
 186        if new_name == "" {
 187            Err(anyhow!("channel name can't be blank"))?;
 188        }
 189        Ok(new_name)
 190    }
 191
 192    pub async fn rename_channel(
 193        &self,
 194        channel_id: ChannelId,
 195        user_id: UserId,
 196        new_name: &str,
 197    ) -> Result<String> {
 198        self.transaction(move |tx| async move {
 199            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 200
 201            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 202                .await?;
 203
 204            channel::ActiveModel {
 205                id: ActiveValue::Unchanged(channel_id),
 206                name: ActiveValue::Set(new_name.clone()),
 207                ..Default::default()
 208            }
 209            .update(&*tx)
 210            .await?;
 211
 212            Ok(new_name)
 213        })
 214        .await
 215    }
 216
 217    pub async fn respond_to_channel_invite(
 218        &self,
 219        channel_id: ChannelId,
 220        user_id: UserId,
 221        accept: bool,
 222    ) -> Result<()> {
 223        self.transaction(move |tx| async move {
 224            let rows_affected = if accept {
 225                channel_member::Entity::update_many()
 226                    .set(channel_member::ActiveModel {
 227                        accepted: ActiveValue::Set(accept),
 228                        ..Default::default()
 229                    })
 230                    .filter(
 231                        channel_member::Column::ChannelId
 232                            .eq(channel_id)
 233                            .and(channel_member::Column::UserId.eq(user_id))
 234                            .and(channel_member::Column::Accepted.eq(false)),
 235                    )
 236                    .exec(&*tx)
 237                    .await?
 238                    .rows_affected
 239            } else {
 240                channel_member::ActiveModel {
 241                    channel_id: ActiveValue::Unchanged(channel_id),
 242                    user_id: ActiveValue::Unchanged(user_id),
 243                    ..Default::default()
 244                }
 245                .delete(&*tx)
 246                .await?
 247                .rows_affected
 248            };
 249
 250            if rows_affected == 0 {
 251                Err(anyhow!("no such invitation"))?;
 252            }
 253
 254            Ok(())
 255        })
 256        .await
 257    }
 258
 259    pub async fn remove_channel_member(
 260        &self,
 261        channel_id: ChannelId,
 262        member_id: UserId,
 263        remover_id: UserId,
 264    ) -> Result<()> {
 265        self.transaction(|tx| async move {
 266            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
 267                .await?;
 268
 269            let result = channel_member::Entity::delete_many()
 270                .filter(
 271                    channel_member::Column::ChannelId
 272                        .eq(channel_id)
 273                        .and(channel_member::Column::UserId.eq(member_id)),
 274                )
 275                .exec(&*tx)
 276                .await?;
 277
 278            if result.rows_affected == 0 {
 279                Err(anyhow!("no such member"))?;
 280            }
 281
 282            Ok(())
 283        })
 284        .await
 285    }
 286
 287    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 288        self.transaction(|tx| async move {
 289            let channel_invites = channel_member::Entity::find()
 290                .filter(
 291                    channel_member::Column::UserId
 292                        .eq(user_id)
 293                        .and(channel_member::Column::Accepted.eq(false)),
 294                )
 295                .all(&*tx)
 296                .await?;
 297
 298            let channels = channel::Entity::find()
 299                .filter(
 300                    channel::Column::Id.is_in(
 301                        channel_invites
 302                            .into_iter()
 303                            .map(|channel_member| channel_member.channel_id),
 304                    ),
 305                )
 306                .all(&*tx)
 307                .await?;
 308
 309            let channels = channels
 310                .into_iter()
 311                .map(|channel| Channel {
 312                    id: channel.id,
 313                    name: channel.name,
 314                })
 315                .collect();
 316
 317            Ok(channels)
 318        })
 319        .await
 320    }
 321
 322    async fn get_channel_graph(
 323        &self,
 324        parents_by_child_id: ChannelDescendants,
 325        trim_dangling_parents: bool,
 326        tx: &DatabaseTransaction,
 327    ) -> Result<ChannelGraph> {
 328        let mut channels = Vec::with_capacity(parents_by_child_id.len());
 329        {
 330            let mut rows = channel::Entity::find()
 331                .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
 332                .stream(&*tx)
 333                .await?;
 334            while let Some(row) = rows.next().await {
 335                let row = row?;
 336                channels.push(Channel {
 337                    id: row.id,
 338                    name: row.name,
 339                })
 340            }
 341        }
 342
 343        let mut edges = Vec::with_capacity(parents_by_child_id.len());
 344        for (channel, parents) in parents_by_child_id.iter() {
 345            for parent in parents.into_iter() {
 346                if trim_dangling_parents {
 347                    if parents_by_child_id.contains_key(parent) {
 348                        edges.push(ChannelEdge {
 349                            channel_id: channel.to_proto(),
 350                            parent_id: parent.to_proto(),
 351                        });
 352                    }
 353                } else {
 354                    edges.push(ChannelEdge {
 355                        channel_id: channel.to_proto(),
 356                        parent_id: parent.to_proto(),
 357                    });
 358                }
 359            }
 360        }
 361
 362        Ok(ChannelGraph { channels, edges })
 363    }
 364
 365    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 366        self.transaction(|tx| async move {
 367            let tx = tx;
 368
 369            let channel_memberships = channel_member::Entity::find()
 370                .filter(
 371                    channel_member::Column::UserId
 372                        .eq(user_id)
 373                        .and(channel_member::Column::Accepted.eq(true)),
 374                )
 375                .all(&*tx)
 376                .await?;
 377
 378            self.get_user_channels(user_id, channel_memberships, &tx)
 379                .await
 380        })
 381        .await
 382    }
 383
 384    pub async fn get_channel_for_user(
 385        &self,
 386        channel_id: ChannelId,
 387        user_id: UserId,
 388    ) -> Result<ChannelsForUser> {
 389        self.transaction(|tx| async move {
 390            let tx = tx;
 391
 392            let channel_membership = channel_member::Entity::find()
 393                .filter(
 394                    channel_member::Column::UserId
 395                        .eq(user_id)
 396                        .and(channel_member::Column::ChannelId.eq(channel_id))
 397                        .and(channel_member::Column::Accepted.eq(true)),
 398                )
 399                .all(&*tx)
 400                .await?;
 401
 402            self.get_user_channels(user_id, channel_membership, &tx)
 403                .await
 404        })
 405        .await
 406    }
 407
 408    pub async fn get_user_channels(
 409        &self,
 410        user_id: UserId,
 411        channel_memberships: Vec<channel_member::Model>,
 412        tx: &DatabaseTransaction,
 413    ) -> Result<ChannelsForUser> {
 414        let parents_by_child_id = self
 415            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 416            .await?;
 417
 418        let channels_with_admin_privileges = channel_memberships
 419            .iter()
 420            .filter_map(|membership| membership.admin.then_some(membership.channel_id))
 421            .collect();
 422
 423        let graph = self
 424            .get_channel_graph(parents_by_child_id, true, &tx)
 425            .await?;
 426
 427        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 428        enum QueryUserIdsAndChannelIds {
 429            ChannelId,
 430            UserId,
 431        }
 432
 433        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 434        {
 435            let mut rows = room_participant::Entity::find()
 436                .inner_join(room::Entity)
 437                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
 438                .select_only()
 439                .column(room::Column::ChannelId)
 440                .column(room_participant::Column::UserId)
 441                .into_values::<_, QueryUserIdsAndChannelIds>()
 442                .stream(&*tx)
 443                .await?;
 444            while let Some(row) = rows.next().await {
 445                let row: (ChannelId, UserId) = row?;
 446                channel_participants.entry(row.0).or_default().push(row.1)
 447            }
 448        }
 449
 450        let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
 451        let channel_buffer_changes = self
 452            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
 453            .await?;
 454
 455        let unseen_messages = self
 456            .unseen_channel_messages(user_id, &channel_ids, &*tx)
 457            .await?;
 458
 459        Ok(ChannelsForUser {
 460            channels: graph,
 461            channel_participants,
 462            channels_with_admin_privileges,
 463            unseen_buffer_changes: channel_buffer_changes,
 464            channel_messages: unseen_messages,
 465        })
 466    }
 467
 468    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 469        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
 470            .await
 471    }
 472
 473    pub async fn set_channel_member_admin(
 474        &self,
 475        channel_id: ChannelId,
 476        from: UserId,
 477        for_user: UserId,
 478        admin: bool,
 479    ) -> Result<()> {
 480        self.transaction(|tx| async move {
 481            self.check_user_is_channel_admin(channel_id, from, &*tx)
 482                .await?;
 483
 484            let result = channel_member::Entity::update_many()
 485                .filter(
 486                    channel_member::Column::ChannelId
 487                        .eq(channel_id)
 488                        .and(channel_member::Column::UserId.eq(for_user)),
 489                )
 490                .set(channel_member::ActiveModel {
 491                    admin: ActiveValue::set(admin),
 492                    ..Default::default()
 493                })
 494                .exec(&*tx)
 495                .await?;
 496
 497            if result.rows_affected == 0 {
 498                Err(anyhow!("no such member"))?;
 499            }
 500
 501            Ok(())
 502        })
 503        .await
 504    }
 505
 506    pub async fn get_channel_member_details(
 507        &self,
 508        channel_id: ChannelId,
 509        user_id: UserId,
 510    ) -> Result<Vec<proto::ChannelMember>> {
 511        self.transaction(|tx| async move {
 512            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 513                .await?;
 514
 515            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 516            enum QueryMemberDetails {
 517                UserId,
 518                Admin,
 519                IsDirectMember,
 520                Accepted,
 521            }
 522
 523            let tx = tx;
 524            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
 525            let mut stream = channel_member::Entity::find()
 526                .distinct()
 527                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
 528                .select_only()
 529                .column(channel_member::Column::UserId)
 530                .column(channel_member::Column::Admin)
 531                .column_as(
 532                    channel_member::Column::ChannelId.eq(channel_id),
 533                    QueryMemberDetails::IsDirectMember,
 534                )
 535                .column(channel_member::Column::Accepted)
 536                .order_by_asc(channel_member::Column::UserId)
 537                .into_values::<_, QueryMemberDetails>()
 538                .stream(&*tx)
 539                .await?;
 540
 541            let mut rows = Vec::<proto::ChannelMember>::new();
 542            while let Some(row) = stream.next().await {
 543                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
 544                    UserId,
 545                    bool,
 546                    bool,
 547                    bool,
 548                ) = row?;
 549                let kind = match (is_direct_member, is_invite_accepted) {
 550                    (true, true) => proto::channel_member::Kind::Member,
 551                    (true, false) => proto::channel_member::Kind::Invitee,
 552                    (false, true) => proto::channel_member::Kind::AncestorMember,
 553                    (false, false) => continue,
 554                };
 555                let user_id = user_id.to_proto();
 556                let kind = kind.into();
 557                if let Some(last_row) = rows.last_mut() {
 558                    if last_row.user_id == user_id {
 559                        if is_direct_member {
 560                            last_row.kind = kind;
 561                            last_row.admin = is_admin;
 562                        }
 563                        continue;
 564                    }
 565                }
 566                rows.push(proto::ChannelMember {
 567                    user_id,
 568                    kind,
 569                    admin: is_admin,
 570                });
 571            }
 572
 573            Ok(rows)
 574        })
 575        .await
 576    }
 577
 578    pub async fn get_channel_members_internal(
 579        &self,
 580        id: ChannelId,
 581        tx: &DatabaseTransaction,
 582    ) -> Result<Vec<UserId>> {
 583        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 584        let user_ids = channel_member::Entity::find()
 585            .distinct()
 586            .filter(
 587                channel_member::Column::ChannelId
 588                    .is_in(ancestor_ids.iter().copied())
 589                    .and(channel_member::Column::Accepted.eq(true)),
 590            )
 591            .select_only()
 592            .column(channel_member::Column::UserId)
 593            .into_values::<_, QueryUserIds>()
 594            .all(&*tx)
 595            .await?;
 596        Ok(user_ids)
 597    }
 598
 599    pub async fn check_user_is_channel_member(
 600        &self,
 601        channel_id: ChannelId,
 602        user_id: UserId,
 603        tx: &DatabaseTransaction,
 604    ) -> Result<()> {
 605        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 606        channel_member::Entity::find()
 607            .filter(
 608                channel_member::Column::ChannelId
 609                    .is_in(channel_ids)
 610                    .and(channel_member::Column::UserId.eq(user_id)),
 611            )
 612            .one(&*tx)
 613            .await?
 614            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
 615        Ok(())
 616    }
 617
 618    pub async fn check_user_is_channel_admin(
 619        &self,
 620        channel_id: ChannelId,
 621        user_id: UserId,
 622        tx: &DatabaseTransaction,
 623    ) -> Result<()> {
 624        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 625        channel_member::Entity::find()
 626            .filter(
 627                channel_member::Column::ChannelId
 628                    .is_in(channel_ids)
 629                    .and(channel_member::Column::UserId.eq(user_id))
 630                    .and(channel_member::Column::Admin.eq(true)),
 631            )
 632            .one(&*tx)
 633            .await?
 634            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
 635        Ok(())
 636    }
 637
 638    /// Returns the channel ancestors, deepest first
 639    pub async fn get_channel_ancestors(
 640        &self,
 641        channel_id: ChannelId,
 642        tx: &DatabaseTransaction,
 643    ) -> Result<Vec<ChannelId>> {
 644        let paths = channel_path::Entity::find()
 645            .filter(channel_path::Column::ChannelId.eq(channel_id))
 646            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 647            .all(tx)
 648            .await?;
 649        let mut channel_ids = Vec::new();
 650        for path in paths {
 651            for id in path.id_path.trim_matches('/').split('/') {
 652                if let Ok(id) = id.parse() {
 653                    let id = ChannelId::from_proto(id);
 654                    if let Err(ix) = channel_ids.binary_search(&id) {
 655                        channel_ids.insert(ix, id);
 656                    }
 657                }
 658            }
 659        }
 660        Ok(channel_ids)
 661    }
 662
 663    /// Returns the channel descendants,
 664    /// Structured as a map from child ids to their parent ids
 665    /// For example, the descendants of 'a' in this DAG:
 666    ///
 667    ///   /- b -\
 668    /// a -- c -- d
 669    ///
 670    /// would be:
 671    /// {
 672    ///     a: [],
 673    ///     b: [a],
 674    ///     c: [a],
 675    ///     d: [a, c],
 676    /// }
 677    async fn get_channel_descendants(
 678        &self,
 679        channel_ids: impl IntoIterator<Item = ChannelId>,
 680        tx: &DatabaseTransaction,
 681    ) -> Result<ChannelDescendants> {
 682        let mut values = String::new();
 683        for id in channel_ids {
 684            if !values.is_empty() {
 685                values.push_str(", ");
 686            }
 687            write!(&mut values, "({})", id).unwrap();
 688        }
 689
 690        if values.is_empty() {
 691            return Ok(HashMap::default());
 692        }
 693
 694        let sql = format!(
 695            r#"
 696            SELECT
 697                descendant_paths.*
 698            FROM
 699                channel_paths parent_paths, channel_paths descendant_paths
 700            WHERE
 701                parent_paths.channel_id IN ({values}) AND
 702                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
 703        "#
 704        );
 705
 706        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 707
 708        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
 709        let mut paths = channel_path::Entity::find()
 710            .from_raw_sql(stmt)
 711            .stream(tx)
 712            .await?;
 713
 714        while let Some(path) = paths.next().await {
 715            let path = path?;
 716            let ids = path.id_path.trim_matches('/').split('/');
 717            let mut parent_id = None;
 718            for id in ids {
 719                if let Ok(id) = id.parse() {
 720                    let id = ChannelId::from_proto(id);
 721                    if id == path.channel_id {
 722                        break;
 723                    }
 724                    parent_id = Some(id);
 725                }
 726            }
 727            let entry = parents_by_child_id.entry(path.channel_id).or_default();
 728            if let Some(parent_id) = parent_id {
 729                entry.insert(parent_id);
 730            }
 731        }
 732
 733        Ok(parents_by_child_id)
 734    }
 735
 736    /// Returns the channel with the given ID and:
 737    /// - true if the user is a member
 738    /// - false if the user hasn't accepted the invitation yet
 739    pub async fn get_channel(
 740        &self,
 741        channel_id: ChannelId,
 742        user_id: UserId,
 743    ) -> Result<Option<(Channel, bool)>> {
 744        self.transaction(|tx| async move {
 745            let tx = tx;
 746
 747            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
 748
 749            if let Some(channel) = channel {
 750                if self
 751                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 752                    .await
 753                    .is_err()
 754                {
 755                    return Ok(None);
 756                }
 757
 758                let channel_membership = channel_member::Entity::find()
 759                    .filter(
 760                        channel_member::Column::ChannelId
 761                            .eq(channel_id)
 762                            .and(channel_member::Column::UserId.eq(user_id)),
 763                    )
 764                    .one(&*tx)
 765                    .await?;
 766
 767                let is_accepted = channel_membership
 768                    .map(|membership| membership.accepted)
 769                    .unwrap_or(false);
 770
 771                Ok(Some((
 772                    Channel {
 773                        id: channel.id,
 774                        name: channel.name,
 775                    },
 776                    is_accepted,
 777                )))
 778            } else {
 779                Ok(None)
 780            }
 781        })
 782        .await
 783    }
 784
 785    pub async fn get_or_create_channel_room(
 786        &self,
 787        channel_id: ChannelId,
 788        live_kit_room: &str,
 789        enviroment: &str,
 790    ) -> Result<RoomId> {
 791        self.transaction(|tx| async move {
 792            let tx = tx;
 793
 794            let room = room::Entity::find()
 795                .filter(room::Column::ChannelId.eq(channel_id))
 796                .one(&*tx)
 797                .await?;
 798
 799            let room_id = if let Some(room) = room {
 800                room.id
 801            } else {
 802                let result = room::Entity::insert(room::ActiveModel {
 803                    channel_id: ActiveValue::Set(Some(channel_id)),
 804                    live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
 805                    enviroment: ActiveValue::Set(Some(enviroment.to_string())),
 806                    ..Default::default()
 807                })
 808                .exec(&*tx)
 809                .await?;
 810
 811                result.last_insert_id
 812            };
 813
 814            Ok(room_id)
 815        })
 816        .await
 817    }
 818
 819    // Insert an edge from the given channel to the given other channel.
 820    pub async fn link_channel(
 821        &self,
 822        user: UserId,
 823        channel: ChannelId,
 824        to: ChannelId,
 825    ) -> Result<ChannelGraph> {
 826        self.transaction(|tx| async move {
 827            // Note that even with these maxed permissions, this linking operation
 828            // is still insecure because you can't remove someone's permissions to a
 829            // channel if they've linked the channel to one where they're an admin.
 830            self.check_user_is_channel_admin(channel, user, &*tx)
 831                .await?;
 832
 833            self.link_channel_internal(user, channel, to, &*tx).await
 834        })
 835        .await
 836    }
 837
 838    pub async fn link_channel_internal(
 839        &self,
 840        user: UserId,
 841        channel: ChannelId,
 842        to: ChannelId,
 843        tx: &DatabaseTransaction,
 844    ) -> Result<ChannelGraph> {
 845        self.check_user_is_channel_admin(to, user, &*tx).await?;
 846
 847        let paths = channel_path::Entity::find()
 848            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
 849            .all(tx)
 850            .await?;
 851
 852        let mut new_path_suffixes = HashSet::default();
 853        for path in paths {
 854            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
 855                new_path_suffixes.insert((
 856                    path.channel_id,
 857                    path.id_path[(start_offset + 1)..].to_string(),
 858                ));
 859            }
 860        }
 861
 862        let paths_to_new_parent = channel_path::Entity::find()
 863            .filter(channel_path::Column::ChannelId.eq(to))
 864            .all(tx)
 865            .await?;
 866
 867        let mut new_paths = Vec::new();
 868        for path in paths_to_new_parent {
 869            if path.id_path.contains(&format!("/{}/", channel)) {
 870                Err(anyhow!("cycle"))?;
 871            }
 872
 873            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
 874                channel_path::ActiveModel {
 875                    channel_id: ActiveValue::Set(*channel_id),
 876                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
 877                }
 878            }));
 879        }
 880
 881        channel_path::Entity::insert_many(new_paths)
 882            .exec(&*tx)
 883            .await?;
 884
 885        // remove any root edges for the channel we just linked
 886        {
 887            channel_path::Entity::delete_many()
 888                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
 889                .exec(&*tx)
 890                .await?;
 891        }
 892
 893        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
 894        if let Some(channel) = channel_descendants.get_mut(&channel) {
 895            // Remove the other parents
 896            channel.clear();
 897            channel.insert(to);
 898        }
 899
 900        let channels = self
 901            .get_channel_graph(channel_descendants, false, &*tx)
 902            .await?;
 903
 904        Ok(channels)
 905    }
 906
 907    /// Unlink a channel from a given parent. This will add in a root edge if
 908    /// the channel has no other parents after this operation.
 909    pub async fn unlink_channel(
 910        &self,
 911        user: UserId,
 912        channel: ChannelId,
 913        from: ChannelId,
 914    ) -> Result<()> {
 915        self.transaction(|tx| async move {
 916            // Note that even with these maxed permissions, this linking operation
 917            // is still insecure because you can't remove someone's permissions to a
 918            // channel if they've linked the channel to one where they're an admin.
 919            self.check_user_is_channel_admin(channel, user, &*tx)
 920                .await?;
 921
 922            self.unlink_channel_internal(user, channel, from, &*tx)
 923                .await?;
 924
 925            Ok(())
 926        })
 927        .await
 928    }
 929
 930    pub async fn unlink_channel_internal(
 931        &self,
 932        user: UserId,
 933        channel: ChannelId,
 934        from: ChannelId,
 935        tx: &DatabaseTransaction,
 936    ) -> Result<()> {
 937        self.check_user_is_channel_admin(from, user, &*tx).await?;
 938
 939        let sql = r#"
 940            DELETE FROM channel_paths
 941            WHERE
 942                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
 943            RETURNING id_path, channel_id
 944        "#;
 945
 946        let paths = channel_path::Entity::find()
 947            .from_raw_sql(Statement::from_sql_and_values(
 948                self.pool.get_database_backend(),
 949                sql,
 950                [from.to_proto().into(), channel.to_proto().into()],
 951            ))
 952            .all(&*tx)
 953            .await?;
 954
 955        let is_stranded = channel_path::Entity::find()
 956            .filter(channel_path::Column::ChannelId.eq(channel))
 957            .count(&*tx)
 958            .await?
 959            == 0;
 960
 961        // Make sure that there is always at least one path to the channel
 962        if is_stranded {
 963            let root_paths: Vec<_> = paths
 964                .iter()
 965                .map(|path| {
 966                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
 967                    channel_path::ActiveModel {
 968                        channel_id: ActiveValue::Set(path.channel_id),
 969                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
 970                    }
 971                })
 972                .collect();
 973            channel_path::Entity::insert_many(root_paths)
 974                .exec(&*tx)
 975                .await?;
 976        }
 977
 978        Ok(())
 979    }
 980
 981    /// Move a channel from one parent to another, returns the
 982    /// Channels that were moved for notifying clients
 983    pub async fn move_channel(
 984        &self,
 985        user: UserId,
 986        channel: ChannelId,
 987        from: ChannelId,
 988        to: ChannelId,
 989    ) -> Result<ChannelGraph> {
 990        if from == to {
 991            return Ok(ChannelGraph {
 992                channels: vec![],
 993                edges: vec![],
 994            });
 995        }
 996
 997        self.transaction(|tx| async move {
 998            self.check_user_is_channel_admin(channel, user, &*tx)
 999                .await?;
1000
1001            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1002
1003            self.unlink_channel_internal(user, channel, from, &*tx)
1004                .await?;
1005
1006            Ok(moved_channels)
1007        })
1008        .await
1009    }
1010}
1011
/// Single-variant sea-orm column selector (`DeriveColumn` + `EnumIter`).
///
/// NOTE(review): its usage is outside this chunk; presumably it is used to read a lone
/// `user_id` column out of a `select_only` query — confirm against callers.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1016
1017#[derive(Debug)]
1018pub struct ChannelGraph {
1019    pub channels: Vec<Channel>,
1020    pub edges: Vec<ChannelEdge>,
1021}
1022
1023impl ChannelGraph {
1024    pub fn is_empty(&self) -> bool {
1025        self.channels.is_empty() && self.edges.is_empty()
1026    }
1027}
1028
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Set-based comparison: ignores ordering and duplicate entries.
    fn eq(&self, other: &Self) -> bool {
        // Tests don't control the order in which channels and edges come back, so compare
        // membership rather than sequence. Edges are identified by (channel_id, parent_id).
        let edge_key = |edge: &ChannelEdge| (edge.channel_id, edge.parent_id);

        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();
        let same_edges = self.edges.iter().map(edge_key).collect::<HashSet<_>>()
            == other.edges.iter().map(edge_key).collect::<HashSet<_>>();

        same_channels && same_edges
    }
}
1049
1050#[cfg(not(test))]
1051impl PartialEq for ChannelGraph {
1052    fn eq(&self, other: &Self) -> bool {
1053        self.channels == other.channels && self.edges == other.edges
1054    }
1055}
1056
1057struct SmallSet<T>(SmallVec<[T; 1]>);
1058
1059impl<T> Deref for SmallSet<T> {
1060    type Target = [T];
1061
1062    fn deref(&self) -> &Self::Target {
1063        self.0.deref()
1064    }
1065}
1066
1067impl<T> Default for SmallSet<T> {
1068    fn default() -> Self {
1069        Self(SmallVec::new())
1070    }
1071}
1072
1073impl<T> SmallSet<T> {
1074    fn insert(&mut self, value: T) -> bool
1075    where
1076        T: Ord,
1077    {
1078        match self.binary_search(&value) {
1079            Ok(_) => false,
1080            Err(ix) => {
1081                self.0.insert(ix, value);
1082                true
1083            }
1084        }
1085    }
1086
1087    fn clear(&mut self) {
1088        self.0.clear();
1089    }
1090}