channels.rs

   1use super::*;
   2use rpc::proto::ChannelEdge;
   3use smallvec::SmallVec;
   4
/// Maps a channel id to the set of ids of that channel's parents
/// (child id -> parent ids), as produced by `get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
   6
   7impl Database {
   8    #[cfg(test)]
   9    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  10        self.transaction(move |tx| async move {
  11            let mut channels = Vec::new();
  12            let mut rows = channel::Entity::find().stream(&*tx).await?;
  13            while let Some(row) = rows.next().await {
  14                let row = row?;
  15                channels.push((row.id, row.name));
  16            }
  17            Ok(channels)
  18        })
  19        .await
  20    }
  21
  22    pub async fn create_root_channel(
  23        &self,
  24        name: &str,
  25        live_kit_room: &str,
  26        creator_id: UserId,
  27    ) -> Result<ChannelId> {
  28        self.create_channel(name, None, live_kit_room, creator_id)
  29            .await
  30    }
  31
  32    pub async fn create_channel(
  33        &self,
  34        name: &str,
  35        parent: Option<ChannelId>,
  36        live_kit_room: &str,
  37        creator_id: UserId,
  38    ) -> Result<ChannelId> {
  39        let name = Self::sanitize_channel_name(name)?;
  40        self.transaction(move |tx| async move {
  41            if let Some(parent) = parent {
  42                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  43                    .await?;
  44            }
  45
  46            let channel = channel::ActiveModel {
  47                name: ActiveValue::Set(name.to_string()),
  48                ..Default::default()
  49            }
  50            .insert(&*tx)
  51            .await?;
  52
  53            if let Some(parent) = parent {
  54                let sql = r#"
  55                    INSERT INTO channel_paths
  56                    (id_path, channel_id)
  57                    SELECT
  58                        id_path || $1 || '/', $2
  59                    FROM
  60                        channel_paths
  61                    WHERE
  62                        channel_id = $3
  63                "#;
  64                let channel_paths_stmt = Statement::from_sql_and_values(
  65                    self.pool.get_database_backend(),
  66                    sql,
  67                    [
  68                        channel.id.to_proto().into(),
  69                        channel.id.to_proto().into(),
  70                        parent.to_proto().into(),
  71                    ],
  72                );
  73                tx.execute(channel_paths_stmt).await?;
  74            } else {
  75                channel_path::Entity::insert(channel_path::ActiveModel {
  76                    channel_id: ActiveValue::Set(channel.id),
  77                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  78                })
  79                .exec(&*tx)
  80                .await?;
  81            }
  82
  83            channel_member::ActiveModel {
  84                channel_id: ActiveValue::Set(channel.id),
  85                user_id: ActiveValue::Set(creator_id),
  86                accepted: ActiveValue::Set(true),
  87                admin: ActiveValue::Set(true),
  88                ..Default::default()
  89            }
  90            .insert(&*tx)
  91            .await?;
  92
  93            room::ActiveModel {
  94                channel_id: ActiveValue::Set(Some(channel.id)),
  95                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
  96                ..Default::default()
  97            }
  98            .insert(&*tx)
  99            .await?;
 100
 101            Ok(channel.id)
 102        })
 103        .await
 104    }
 105
 106    pub async fn delete_channel(
 107        &self,
 108        channel_id: ChannelId,
 109        user_id: UserId,
 110    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 111        self.transaction(move |tx| async move {
 112            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 113                .await?;
 114
 115            // Don't remove descendant channels that have additional parents.
 116            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
 117            {
 118                let mut channels_to_keep = channel_path::Entity::find()
 119                    .filter(
 120                        channel_path::Column::ChannelId
 121                            .is_in(
 122                                channels_to_remove
 123                                    .keys()
 124                                    .copied()
 125                                    .filter(|&id| id != channel_id),
 126                            )
 127                            .and(
 128                                channel_path::Column::IdPath
 129                                    .not_like(&format!("%/{}/%", channel_id)),
 130                            ),
 131                    )
 132                    .stream(&*tx)
 133                    .await?;
 134                while let Some(row) = channels_to_keep.next().await {
 135                    let row = row?;
 136                    channels_to_remove.remove(&row.channel_id);
 137                }
 138            }
 139
 140            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 141            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 142                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 143                .select_only()
 144                .column(channel_member::Column::UserId)
 145                .distinct()
 146                .into_values::<_, QueryUserIds>()
 147                .all(&*tx)
 148                .await?;
 149
 150            channel::Entity::delete_many()
 151                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
 152                .exec(&*tx)
 153                .await?;
 154
 155            // Delete any other paths that include this channel
 156            let sql = r#"
 157                    DELETE FROM channel_paths
 158                    WHERE
 159                        id_path LIKE '%' || $1 || '%'
 160                "#;
 161            let channel_paths_stmt = Statement::from_sql_and_values(
 162                self.pool.get_database_backend(),
 163                sql,
 164                [channel_id.to_proto().into()],
 165            );
 166            tx.execute(channel_paths_stmt).await?;
 167
 168            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
 169        })
 170        .await
 171    }
 172
 173    pub async fn invite_channel_member(
 174        &self,
 175        channel_id: ChannelId,
 176        invitee_id: UserId,
 177        inviter_id: UserId,
 178        is_admin: bool,
 179    ) -> Result<()> {
 180        self.transaction(move |tx| async move {
 181            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 182                .await?;
 183
 184            channel_member::ActiveModel {
 185                channel_id: ActiveValue::Set(channel_id),
 186                user_id: ActiveValue::Set(invitee_id),
 187                accepted: ActiveValue::Set(false),
 188                admin: ActiveValue::Set(is_admin),
 189                ..Default::default()
 190            }
 191            .insert(&*tx)
 192            .await?;
 193
 194            Ok(())
 195        })
 196        .await
 197    }
 198
 199    fn sanitize_channel_name(name: &str) -> Result<&str> {
 200        let new_name = name.trim().trim_start_matches('#');
 201        if new_name == "" {
 202            Err(anyhow!("channel name can't be blank"))?;
 203        }
 204        Ok(new_name)
 205    }
 206
 207    pub async fn rename_channel(
 208        &self,
 209        channel_id: ChannelId,
 210        user_id: UserId,
 211        new_name: &str,
 212    ) -> Result<String> {
 213        self.transaction(move |tx| async move {
 214            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 215
 216            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 217                .await?;
 218
 219            channel::ActiveModel {
 220                id: ActiveValue::Unchanged(channel_id),
 221                name: ActiveValue::Set(new_name.clone()),
 222                ..Default::default()
 223            }
 224            .update(&*tx)
 225            .await?;
 226
 227            Ok(new_name)
 228        })
 229        .await
 230    }
 231
 232    pub async fn respond_to_channel_invite(
 233        &self,
 234        channel_id: ChannelId,
 235        user_id: UserId,
 236        accept: bool,
 237    ) -> Result<()> {
 238        self.transaction(move |tx| async move {
 239            let rows_affected = if accept {
 240                channel_member::Entity::update_many()
 241                    .set(channel_member::ActiveModel {
 242                        accepted: ActiveValue::Set(accept),
 243                        ..Default::default()
 244                    })
 245                    .filter(
 246                        channel_member::Column::ChannelId
 247                            .eq(channel_id)
 248                            .and(channel_member::Column::UserId.eq(user_id))
 249                            .and(channel_member::Column::Accepted.eq(false)),
 250                    )
 251                    .exec(&*tx)
 252                    .await?
 253                    .rows_affected
 254            } else {
 255                channel_member::ActiveModel {
 256                    channel_id: ActiveValue::Unchanged(channel_id),
 257                    user_id: ActiveValue::Unchanged(user_id),
 258                    ..Default::default()
 259                }
 260                .delete(&*tx)
 261                .await?
 262                .rows_affected
 263            };
 264
 265            if rows_affected == 0 {
 266                Err(anyhow!("no such invitation"))?;
 267            }
 268
 269            Ok(())
 270        })
 271        .await
 272    }
 273
 274    pub async fn remove_channel_member(
 275        &self,
 276        channel_id: ChannelId,
 277        member_id: UserId,
 278        remover_id: UserId,
 279    ) -> Result<()> {
 280        self.transaction(|tx| async move {
 281            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
 282                .await?;
 283
 284            let result = channel_member::Entity::delete_many()
 285                .filter(
 286                    channel_member::Column::ChannelId
 287                        .eq(channel_id)
 288                        .and(channel_member::Column::UserId.eq(member_id)),
 289                )
 290                .exec(&*tx)
 291                .await?;
 292
 293            if result.rows_affected == 0 {
 294                Err(anyhow!("no such member"))?;
 295            }
 296
 297            Ok(())
 298        })
 299        .await
 300    }
 301
 302    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 303        self.transaction(|tx| async move {
 304            let channel_invites = channel_member::Entity::find()
 305                .filter(
 306                    channel_member::Column::UserId
 307                        .eq(user_id)
 308                        .and(channel_member::Column::Accepted.eq(false)),
 309                )
 310                .all(&*tx)
 311                .await?;
 312
 313            let channels = channel::Entity::find()
 314                .filter(
 315                    channel::Column::Id.is_in(
 316                        channel_invites
 317                            .into_iter()
 318                            .map(|channel_member| channel_member.channel_id),
 319                    ),
 320                )
 321                .all(&*tx)
 322                .await?;
 323
 324            let channels = channels
 325                .into_iter()
 326                .map(|channel| Channel {
 327                    id: channel.id,
 328                    name: channel.name,
 329                })
 330                .collect();
 331
 332            Ok(channels)
 333        })
 334        .await
 335    }
 336
 337    async fn get_channel_graph(
 338        &self,
 339        parents_by_child_id: ChannelDescendants,
 340        trim_dangling_parents: bool,
 341        tx: &DatabaseTransaction,
 342    ) -> Result<ChannelGraph> {
 343        let mut channels = Vec::with_capacity(parents_by_child_id.len());
 344        {
 345            let mut rows = channel::Entity::find()
 346                .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
 347                .stream(&*tx)
 348                .await?;
 349            while let Some(row) = rows.next().await {
 350                let row = row?;
 351                channels.push(Channel {
 352                    id: row.id,
 353                    name: row.name,
 354                })
 355            }
 356        }
 357
 358        let mut edges = Vec::with_capacity(parents_by_child_id.len());
 359        for (channel, parents) in parents_by_child_id.iter() {
 360            for parent in parents.into_iter() {
 361                if trim_dangling_parents {
 362                    if parents_by_child_id.contains_key(parent) {
 363                        edges.push(ChannelEdge {
 364                            channel_id: channel.to_proto(),
 365                            parent_id: parent.to_proto(),
 366                        });
 367                    }
 368                } else {
 369                    edges.push(ChannelEdge {
 370                        channel_id: channel.to_proto(),
 371                        parent_id: parent.to_proto(),
 372                    });
 373                }
 374            }
 375        }
 376
 377        Ok(ChannelGraph { channels, edges })
 378    }
 379
 380    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 381        self.transaction(|tx| async move {
 382            let tx = tx;
 383
 384            let channel_memberships = channel_member::Entity::find()
 385                .filter(
 386                    channel_member::Column::UserId
 387                        .eq(user_id)
 388                        .and(channel_member::Column::Accepted.eq(true)),
 389                )
 390                .all(&*tx)
 391                .await?;
 392
 393            self.get_user_channels(user_id, channel_memberships, &tx)
 394                .await
 395        })
 396        .await
 397    }
 398
 399    pub async fn get_channel_for_user(
 400        &self,
 401        channel_id: ChannelId,
 402        user_id: UserId,
 403    ) -> Result<ChannelsForUser> {
 404        self.transaction(|tx| async move {
 405            let tx = tx;
 406
 407            let channel_membership = channel_member::Entity::find()
 408                .filter(
 409                    channel_member::Column::UserId
 410                        .eq(user_id)
 411                        .and(channel_member::Column::ChannelId.eq(channel_id))
 412                        .and(channel_member::Column::Accepted.eq(true)),
 413                )
 414                .all(&*tx)
 415                .await?;
 416
 417            self.get_user_channels(user_id, channel_membership, &tx)
 418                .await
 419        })
 420        .await
 421    }
 422
 423    pub async fn get_user_channels(
 424        &self,
 425        user_id: UserId,
 426        channel_memberships: Vec<channel_member::Model>,
 427        tx: &DatabaseTransaction,
 428    ) -> Result<ChannelsForUser> {
 429        let parents_by_child_id = self
 430            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 431            .await?;
 432
 433        let channels_with_admin_privileges = channel_memberships
 434            .iter()
 435            .filter_map(|membership| membership.admin.then_some(membership.channel_id))
 436            .collect();
 437
 438        let graph = self
 439            .get_channel_graph(parents_by_child_id, true, &tx)
 440            .await?;
 441
 442        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 443        enum QueryUserIdsAndChannelIds {
 444            ChannelId,
 445            UserId,
 446        }
 447
 448        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 449        {
 450            let mut rows = room_participant::Entity::find()
 451                .inner_join(room::Entity)
 452                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
 453                .select_only()
 454                .column(room::Column::ChannelId)
 455                .column(room_participant::Column::UserId)
 456                .into_values::<_, QueryUserIdsAndChannelIds>()
 457                .stream(&*tx)
 458                .await?;
 459            while let Some(row) = rows.next().await {
 460                let row: (ChannelId, UserId) = row?;
 461                channel_participants.entry(row.0).or_default().push(row.1)
 462            }
 463        }
 464
 465        let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
 466        let channel_buffer_changes = self
 467            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
 468            .await?;
 469
 470        let unseen_messages = self
 471            .unseen_channel_messages(user_id, &channel_ids, &*tx)
 472            .await?;
 473
 474        Ok(ChannelsForUser {
 475            channels: graph,
 476            channel_participants,
 477            channels_with_admin_privileges,
 478            unseen_buffer_changes: channel_buffer_changes,
 479            channel_messages: unseen_messages,
 480        })
 481    }
 482
 483    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 484        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
 485            .await
 486    }
 487
 488    pub async fn set_channel_member_admin(
 489        &self,
 490        channel_id: ChannelId,
 491        from: UserId,
 492        for_user: UserId,
 493        admin: bool,
 494    ) -> Result<()> {
 495        self.transaction(|tx| async move {
 496            self.check_user_is_channel_admin(channel_id, from, &*tx)
 497                .await?;
 498
 499            let result = channel_member::Entity::update_many()
 500                .filter(
 501                    channel_member::Column::ChannelId
 502                        .eq(channel_id)
 503                        .and(channel_member::Column::UserId.eq(for_user)),
 504                )
 505                .set(channel_member::ActiveModel {
 506                    admin: ActiveValue::set(admin),
 507                    ..Default::default()
 508                })
 509                .exec(&*tx)
 510                .await?;
 511
 512            if result.rows_affected == 0 {
 513                Err(anyhow!("no such member"))?;
 514            }
 515
 516            Ok(())
 517        })
 518        .await
 519    }
 520
 521    pub async fn get_channel_member_details(
 522        &self,
 523        channel_id: ChannelId,
 524        user_id: UserId,
 525    ) -> Result<Vec<proto::ChannelMember>> {
 526        self.transaction(|tx| async move {
 527            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 528                .await?;
 529
 530            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 531            enum QueryMemberDetails {
 532                UserId,
 533                Admin,
 534                IsDirectMember,
 535                Accepted,
 536            }
 537
 538            let tx = tx;
 539            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
 540            let mut stream = channel_member::Entity::find()
 541                .distinct()
 542                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
 543                .select_only()
 544                .column(channel_member::Column::UserId)
 545                .column(channel_member::Column::Admin)
 546                .column_as(
 547                    channel_member::Column::ChannelId.eq(channel_id),
 548                    QueryMemberDetails::IsDirectMember,
 549                )
 550                .column(channel_member::Column::Accepted)
 551                .order_by_asc(channel_member::Column::UserId)
 552                .into_values::<_, QueryMemberDetails>()
 553                .stream(&*tx)
 554                .await?;
 555
 556            let mut rows = Vec::<proto::ChannelMember>::new();
 557            while let Some(row) = stream.next().await {
 558                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
 559                    UserId,
 560                    bool,
 561                    bool,
 562                    bool,
 563                ) = row?;
 564                let kind = match (is_direct_member, is_invite_accepted) {
 565                    (true, true) => proto::channel_member::Kind::Member,
 566                    (true, false) => proto::channel_member::Kind::Invitee,
 567                    (false, true) => proto::channel_member::Kind::AncestorMember,
 568                    (false, false) => continue,
 569                };
 570                let user_id = user_id.to_proto();
 571                let kind = kind.into();
 572                if let Some(last_row) = rows.last_mut() {
 573                    if last_row.user_id == user_id {
 574                        if is_direct_member {
 575                            last_row.kind = kind;
 576                            last_row.admin = is_admin;
 577                        }
 578                        continue;
 579                    }
 580                }
 581                rows.push(proto::ChannelMember {
 582                    user_id,
 583                    kind,
 584                    admin: is_admin,
 585                });
 586            }
 587
 588            Ok(rows)
 589        })
 590        .await
 591    }
 592
 593    pub async fn get_channel_members_internal(
 594        &self,
 595        id: ChannelId,
 596        tx: &DatabaseTransaction,
 597    ) -> Result<Vec<UserId>> {
 598        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 599        let user_ids = channel_member::Entity::find()
 600            .distinct()
 601            .filter(
 602                channel_member::Column::ChannelId
 603                    .is_in(ancestor_ids.iter().copied())
 604                    .and(channel_member::Column::Accepted.eq(true)),
 605            )
 606            .select_only()
 607            .column(channel_member::Column::UserId)
 608            .into_values::<_, QueryUserIds>()
 609            .all(&*tx)
 610            .await?;
 611        Ok(user_ids)
 612    }
 613
 614    pub async fn check_user_is_channel_member(
 615        &self,
 616        channel_id: ChannelId,
 617        user_id: UserId,
 618        tx: &DatabaseTransaction,
 619    ) -> Result<()> {
 620        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 621        channel_member::Entity::find()
 622            .filter(
 623                channel_member::Column::ChannelId
 624                    .is_in(channel_ids)
 625                    .and(channel_member::Column::UserId.eq(user_id)),
 626            )
 627            .one(&*tx)
 628            .await?
 629            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
 630        Ok(())
 631    }
 632
 633    pub async fn check_user_is_channel_admin(
 634        &self,
 635        channel_id: ChannelId,
 636        user_id: UserId,
 637        tx: &DatabaseTransaction,
 638    ) -> Result<()> {
 639        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 640        channel_member::Entity::find()
 641            .filter(
 642                channel_member::Column::ChannelId
 643                    .is_in(channel_ids)
 644                    .and(channel_member::Column::UserId.eq(user_id))
 645                    .and(channel_member::Column::Admin.eq(true)),
 646            )
 647            .one(&*tx)
 648            .await?
 649            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
 650        Ok(())
 651    }
 652
 653    /// Returns the channel ancestors, deepest first
 654    pub async fn get_channel_ancestors(
 655        &self,
 656        channel_id: ChannelId,
 657        tx: &DatabaseTransaction,
 658    ) -> Result<Vec<ChannelId>> {
 659        let paths = channel_path::Entity::find()
 660            .filter(channel_path::Column::ChannelId.eq(channel_id))
 661            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 662            .all(tx)
 663            .await?;
 664        let mut channel_ids = Vec::new();
 665        for path in paths {
 666            for id in path.id_path.trim_matches('/').split('/') {
 667                if let Ok(id) = id.parse() {
 668                    let id = ChannelId::from_proto(id);
 669                    if let Err(ix) = channel_ids.binary_search(&id) {
 670                        channel_ids.insert(ix, id);
 671                    }
 672                }
 673            }
 674        }
 675        Ok(channel_ids)
 676    }
 677
 678    /// Returns the channel descendants,
 679    /// Structured as a map from child ids to their parent ids
 680    /// For example, the descendants of 'a' in this DAG:
 681    ///
 682    ///   /- b -\
 683    /// a -- c -- d
 684    ///
 685    /// would be:
 686    /// {
 687    ///     a: [],
 688    ///     b: [a],
 689    ///     c: [a],
 690    ///     d: [a, c],
 691    /// }
 692    async fn get_channel_descendants(
 693        &self,
 694        channel_ids: impl IntoIterator<Item = ChannelId>,
 695        tx: &DatabaseTransaction,
 696    ) -> Result<ChannelDescendants> {
 697        let mut values = String::new();
 698        for id in channel_ids {
 699            if !values.is_empty() {
 700                values.push_str(", ");
 701            }
 702            write!(&mut values, "({})", id).unwrap();
 703        }
 704
 705        if values.is_empty() {
 706            return Ok(HashMap::default());
 707        }
 708
 709        let sql = format!(
 710            r#"
 711            SELECT
 712                descendant_paths.*
 713            FROM
 714                channel_paths parent_paths, channel_paths descendant_paths
 715            WHERE
 716                parent_paths.channel_id IN ({values}) AND
 717                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
 718        "#
 719        );
 720
 721        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 722
 723        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
 724        let mut paths = channel_path::Entity::find()
 725            .from_raw_sql(stmt)
 726            .stream(tx)
 727            .await?;
 728
 729        while let Some(path) = paths.next().await {
 730            let path = path?;
 731            let ids = path.id_path.trim_matches('/').split('/');
 732            let mut parent_id = None;
 733            for id in ids {
 734                if let Ok(id) = id.parse() {
 735                    let id = ChannelId::from_proto(id);
 736                    if id == path.channel_id {
 737                        break;
 738                    }
 739                    parent_id = Some(id);
 740                }
 741            }
 742            let entry = parents_by_child_id.entry(path.channel_id).or_default();
 743            if let Some(parent_id) = parent_id {
 744                entry.insert(parent_id);
 745            }
 746        }
 747
 748        Ok(parents_by_child_id)
 749    }
 750
 751    /// Returns the channel with the given ID and:
 752    /// - true if the user is a member
 753    /// - false if the user hasn't accepted the invitation yet
 754    pub async fn get_channel(
 755        &self,
 756        channel_id: ChannelId,
 757        user_id: UserId,
 758    ) -> Result<Option<(Channel, bool)>> {
 759        self.transaction(|tx| async move {
 760            let tx = tx;
 761
 762            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
 763
 764            if let Some(channel) = channel {
 765                if self
 766                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 767                    .await
 768                    .is_err()
 769                {
 770                    return Ok(None);
 771                }
 772
 773                let channel_membership = channel_member::Entity::find()
 774                    .filter(
 775                        channel_member::Column::ChannelId
 776                            .eq(channel_id)
 777                            .and(channel_member::Column::UserId.eq(user_id)),
 778                    )
 779                    .one(&*tx)
 780                    .await?;
 781
 782                let is_accepted = channel_membership
 783                    .map(|membership| membership.accepted)
 784                    .unwrap_or(false);
 785
 786                Ok(Some((
 787                    Channel {
 788                        id: channel.id,
 789                        name: channel.name,
 790                    },
 791                    is_accepted,
 792                )))
 793            } else {
 794                Ok(None)
 795            }
 796        })
 797        .await
 798    }
 799
 800    pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
 801        self.transaction(|tx| async move {
 802            let tx = tx;
 803            let room = channel::Model {
 804                id: channel_id,
 805                ..Default::default()
 806            }
 807            .find_related(room::Entity)
 808            .one(&*tx)
 809            .await?
 810            .ok_or_else(|| anyhow!("invalid channel"))?;
 811            Ok(room.id)
 812        })
 813        .await
 814    }
 815
 816    // Insert an edge from the given channel to the given other channel.
 817    pub async fn link_channel(
 818        &self,
 819        user: UserId,
 820        channel: ChannelId,
 821        to: ChannelId,
 822    ) -> Result<ChannelGraph> {
 823        self.transaction(|tx| async move {
 824            // Note that even with these maxed permissions, this linking operation
 825            // is still insecure because you can't remove someone's permissions to a
 826            // channel if they've linked the channel to one where they're an admin.
 827            self.check_user_is_channel_admin(channel, user, &*tx)
 828                .await?;
 829
 830            self.link_channel_internal(user, channel, to, &*tx).await
 831        })
 832        .await
 833    }
 834
 835    pub async fn link_channel_internal(
 836        &self,
 837        user: UserId,
 838        channel: ChannelId,
 839        to: ChannelId,
 840        tx: &DatabaseTransaction,
 841    ) -> Result<ChannelGraph> {
 842        self.check_user_is_channel_admin(to, user, &*tx).await?;
 843
 844        let paths = channel_path::Entity::find()
 845            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
 846            .all(tx)
 847            .await?;
 848
 849        let mut new_path_suffixes = HashSet::default();
 850        for path in paths {
 851            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
 852                new_path_suffixes.insert((
 853                    path.channel_id,
 854                    path.id_path[(start_offset + 1)..].to_string(),
 855                ));
 856            }
 857        }
 858
 859        let paths_to_new_parent = channel_path::Entity::find()
 860            .filter(channel_path::Column::ChannelId.eq(to))
 861            .all(tx)
 862            .await?;
 863
 864        let mut new_paths = Vec::new();
 865        for path in paths_to_new_parent {
 866            if path.id_path.contains(&format!("/{}/", channel)) {
 867                Err(anyhow!("cycle"))?;
 868            }
 869
 870            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
 871                channel_path::ActiveModel {
 872                    channel_id: ActiveValue::Set(*channel_id),
 873                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
 874                }
 875            }));
 876        }
 877
 878        channel_path::Entity::insert_many(new_paths)
 879            .exec(&*tx)
 880            .await?;
 881
 882        // remove any root edges for the channel we just linked
 883        {
 884            channel_path::Entity::delete_many()
 885                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
 886                .exec(&*tx)
 887                .await?;
 888        }
 889
 890        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
 891        if let Some(channel) = channel_descendants.get_mut(&channel) {
 892            // Remove the other parents
 893            channel.clear();
 894            channel.insert(to);
 895        }
 896
 897        let channels = self
 898            .get_channel_graph(channel_descendants, false, &*tx)
 899            .await?;
 900
 901        Ok(channels)
 902    }
 903
 904    /// Unlink a channel from a given parent. This will add in a root edge if
 905    /// the channel has no other parents after this operation.
 906    pub async fn unlink_channel(
 907        &self,
 908        user: UserId,
 909        channel: ChannelId,
 910        from: ChannelId,
 911    ) -> Result<()> {
 912        self.transaction(|tx| async move {
 913            // Note that even with these maxed permissions, this linking operation
 914            // is still insecure because you can't remove someone's permissions to a
 915            // channel if they've linked the channel to one where they're an admin.
 916            self.check_user_is_channel_admin(channel, user, &*tx)
 917                .await?;
 918
 919            self.unlink_channel_internal(user, channel, from, &*tx)
 920                .await?;
 921
 922            Ok(())
 923        })
 924        .await
 925    }
 926
 927    pub async fn unlink_channel_internal(
 928        &self,
 929        user: UserId,
 930        channel: ChannelId,
 931        from: ChannelId,
 932        tx: &DatabaseTransaction,
 933    ) -> Result<()> {
 934        self.check_user_is_channel_admin(from, user, &*tx).await?;
 935
 936        let sql = r#"
 937            DELETE FROM channel_paths
 938            WHERE
 939                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
 940            RETURNING id_path, channel_id
 941        "#;
 942
 943        let paths = channel_path::Entity::find()
 944            .from_raw_sql(Statement::from_sql_and_values(
 945                self.pool.get_database_backend(),
 946                sql,
 947                [from.to_proto().into(), channel.to_proto().into()],
 948            ))
 949            .all(&*tx)
 950            .await?;
 951
 952        let is_stranded = channel_path::Entity::find()
 953            .filter(channel_path::Column::ChannelId.eq(channel))
 954            .count(&*tx)
 955            .await?
 956            == 0;
 957
 958        // Make sure that there is always at least one path to the channel
 959        if is_stranded {
 960            let root_paths: Vec<_> = paths
 961                .iter()
 962                .map(|path| {
 963                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
 964                    channel_path::ActiveModel {
 965                        channel_id: ActiveValue::Set(path.channel_id),
 966                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
 967                    }
 968                })
 969                .collect();
 970            channel_path::Entity::insert_many(root_paths)
 971                .exec(&*tx)
 972                .await?;
 973        }
 974
 975        Ok(())
 976    }
 977
 978    /// Move a channel from one parent to another, returns the
 979    /// Channels that were moved for notifying clients
 980    pub async fn move_channel(
 981        &self,
 982        user: UserId,
 983        channel: ChannelId,
 984        from: ChannelId,
 985        to: ChannelId,
 986    ) -> Result<ChannelGraph> {
 987        if from == to {
 988            return Ok(ChannelGraph {
 989                channels: vec![],
 990                edges: vec![],
 991            });
 992        }
 993
 994        self.transaction(|tx| async move {
 995            self.check_user_is_channel_admin(channel, user, &*tx)
 996                .await?;
 997
 998            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
 999
1000            self.unlink_channel_internal(user, channel, from, &*tx)
1001                .await?;
1002
1003            Ok(moved_channels)
1004        })
1005        .await
1006    }
1007}
1008
1009#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
1010enum QueryUserIds {
1011    UserId,
1012}
1013
1014#[derive(Debug)]
1015pub struct ChannelGraph {
1016    pub channels: Vec<Channel>,
1017    pub edges: Vec<ChannelEdge>,
1018}
1019
1020impl ChannelGraph {
1021    pub fn is_empty(&self) -> bool {
1022        self.channels.is_empty() && self.edges.is_empty()
1023    }
1024}
1025
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Order-independent comparison for tests: two graphs are equal when
    /// they contain the same set of channels and the same set of
    /// `(channel_id, parent_id)` edge pairs, regardless of vector order.
    fn eq(&self, other: &Self) -> bool {
        // Reduce the edges of a graph to a set of (child, parent) id pairs.
        let edge_pairs = |graph: &Self| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        let own_channels = self.channels.iter().collect::<HashSet<_>>();
        let their_channels = other.channels.iter().collect::<HashSet<_>>();
        let own_edges = edge_pairs(self);
        let their_edges = edge_pairs(other);

        own_channels == their_channels && own_edges == their_edges
    }
}
1046
1047#[cfg(not(test))]
1048impl PartialEq for ChannelGraph {
1049    fn eq(&self, other: &Self) -> bool {
1050        self.channels == other.channels && self.edges == other.edges
1051    }
1052}
1053
1054struct SmallSet<T>(SmallVec<[T; 1]>);
1055
1056impl<T> Deref for SmallSet<T> {
1057    type Target = [T];
1058
1059    fn deref(&self) -> &Self::Target {
1060        self.0.deref()
1061    }
1062}
1063
1064impl<T> Default for SmallSet<T> {
1065    fn default() -> Self {
1066        Self(SmallVec::new())
1067    }
1068}
1069
1070impl<T> SmallSet<T> {
1071    fn insert(&mut self, value: T) -> bool
1072    where
1073        T: Ord,
1074    {
1075        match self.binary_search(&value) {
1076            Ok(_) => false,
1077            Err(ix) => {
1078                self.0.insert(ix, value);
1079                true
1080            }
1081        }
1082    }
1083
1084    fn clear(&mut self) {
1085        self.0.clear();
1086    }
1087}