channels.rs

   1use rpc::proto::ChannelEdge;
   2use smallvec::SmallVec;
   3
   4use super::*;
   5
   6type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
   7
   8impl Database {
   9    #[cfg(test)]
  10    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  11        self.transaction(move |tx| async move {
  12            let mut channels = Vec::new();
  13            let mut rows = channel::Entity::find().stream(&*tx).await?;
  14            while let Some(row) = rows.next().await {
  15                let row = row?;
  16                channels.push((row.id, row.name));
  17            }
  18            Ok(channels)
  19        })
  20        .await
  21    }
  22
  23    pub async fn create_root_channel(
  24        &self,
  25        name: &str,
  26        live_kit_room: &str,
  27        creator_id: UserId,
  28    ) -> Result<ChannelId> {
  29        self.create_channel(name, None, live_kit_room, creator_id)
  30            .await
  31    }
  32
  33    pub async fn create_channel(
  34        &self,
  35        name: &str,
  36        parent: Option<ChannelId>,
  37        live_kit_room: &str,
  38        creator_id: UserId,
  39    ) -> Result<ChannelId> {
  40        let name = Self::sanitize_channel_name(name)?;
  41        self.transaction(move |tx| async move {
  42            if let Some(parent) = parent {
  43                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  44                    .await?;
  45            }
  46
  47            let channel = channel::ActiveModel {
  48                name: ActiveValue::Set(name.to_string()),
  49                ..Default::default()
  50            }
  51            .insert(&*tx)
  52            .await?;
  53
  54            if let Some(parent) = parent {
  55                let sql = r#"
  56                    INSERT INTO channel_paths
  57                    (id_path, channel_id)
  58                    SELECT
  59                        id_path || $1 || '/', $2
  60                    FROM
  61                        channel_paths
  62                    WHERE
  63                        channel_id = $3
  64                "#;
  65                let channel_paths_stmt = Statement::from_sql_and_values(
  66                    self.pool.get_database_backend(),
  67                    sql,
  68                    [
  69                        channel.id.to_proto().into(),
  70                        channel.id.to_proto().into(),
  71                        parent.to_proto().into(),
  72                    ],
  73                );
  74                tx.execute(channel_paths_stmt).await?;
  75            } else {
  76                channel_path::Entity::insert(channel_path::ActiveModel {
  77                    channel_id: ActiveValue::Set(channel.id),
  78                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  79                })
  80                .exec(&*tx)
  81                .await?;
  82            }
  83
  84            channel_member::ActiveModel {
  85                channel_id: ActiveValue::Set(channel.id),
  86                user_id: ActiveValue::Set(creator_id),
  87                accepted: ActiveValue::Set(true),
  88                admin: ActiveValue::Set(true),
  89                ..Default::default()
  90            }
  91            .insert(&*tx)
  92            .await?;
  93
  94            room::ActiveModel {
  95                channel_id: ActiveValue::Set(Some(channel.id)),
  96                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
  97                ..Default::default()
  98            }
  99            .insert(&*tx)
 100            .await?;
 101
 102            Ok(channel.id)
 103        })
 104        .await
 105    }
 106
 107    pub async fn delete_channel(
 108        &self,
 109        channel_id: ChannelId,
 110        user_id: UserId,
 111    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 112        self.transaction(move |tx| async move {
 113            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 114                .await?;
 115
 116            // Don't remove descendant channels that have additional parents.
 117            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
 118            {
 119                let mut channels_to_keep = channel_path::Entity::find()
 120                    .filter(
 121                        channel_path::Column::ChannelId
 122                            .is_in(
 123                                channels_to_remove
 124                                    .keys()
 125                                    .copied()
 126                                    .filter(|&id| id != channel_id),
 127                            )
 128                            .and(
 129                                channel_path::Column::IdPath
 130                                    .not_like(&format!("%/{}/%", channel_id)),
 131                            ),
 132                    )
 133                    .stream(&*tx)
 134                    .await?;
 135                while let Some(row) = channels_to_keep.next().await {
 136                    let row = row?;
 137                    channels_to_remove.remove(&row.channel_id);
 138                }
 139            }
 140
 141            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 142            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 143                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 144                .select_only()
 145                .column(channel_member::Column::UserId)
 146                .distinct()
 147                .into_values::<_, QueryUserIds>()
 148                .all(&*tx)
 149                .await?;
 150
 151            channel::Entity::delete_many()
 152                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
 153                .exec(&*tx)
 154                .await?;
 155
 156            // Delete any other paths that include this channel
 157            let sql = r#"
 158                    DELETE FROM channel_paths
 159                    WHERE
 160                        id_path LIKE '%' || $1 || '%'
 161                "#;
 162            let channel_paths_stmt = Statement::from_sql_and_values(
 163                self.pool.get_database_backend(),
 164                sql,
 165                [channel_id.to_proto().into()],
 166            );
 167            tx.execute(channel_paths_stmt).await?;
 168
 169            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
 170        })
 171        .await
 172    }
 173
 174    pub async fn invite_channel_member(
 175        &self,
 176        channel_id: ChannelId,
 177        invitee_id: UserId,
 178        inviter_id: UserId,
 179        is_admin: bool,
 180    ) -> Result<()> {
 181        self.transaction(move |tx| async move {
 182            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 183                .await?;
 184
 185            channel_member::ActiveModel {
 186                channel_id: ActiveValue::Set(channel_id),
 187                user_id: ActiveValue::Set(invitee_id),
 188                accepted: ActiveValue::Set(false),
 189                admin: ActiveValue::Set(is_admin),
 190                ..Default::default()
 191            }
 192            .insert(&*tx)
 193            .await?;
 194
 195            Ok(())
 196        })
 197        .await
 198    }
 199
 200    fn sanitize_channel_name(name: &str) -> Result<&str> {
 201        let new_name = name.trim().trim_start_matches('#');
 202        if new_name == "" {
 203            Err(anyhow!("channel name can't be blank"))?;
 204        }
 205        Ok(new_name)
 206    }
 207
 208    pub async fn rename_channel(
 209        &self,
 210        channel_id: ChannelId,
 211        user_id: UserId,
 212        new_name: &str,
 213    ) -> Result<String> {
 214        self.transaction(move |tx| async move {
 215            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 216
 217            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 218                .await?;
 219
 220            channel::ActiveModel {
 221                id: ActiveValue::Unchanged(channel_id),
 222                name: ActiveValue::Set(new_name.clone()),
 223                ..Default::default()
 224            }
 225            .update(&*tx)
 226            .await?;
 227
 228            Ok(new_name)
 229        })
 230        .await
 231    }
 232
 233    pub async fn respond_to_channel_invite(
 234        &self,
 235        channel_id: ChannelId,
 236        user_id: UserId,
 237        accept: bool,
 238    ) -> Result<()> {
 239        self.transaction(move |tx| async move {
 240            let rows_affected = if accept {
 241                channel_member::Entity::update_many()
 242                    .set(channel_member::ActiveModel {
 243                        accepted: ActiveValue::Set(accept),
 244                        ..Default::default()
 245                    })
 246                    .filter(
 247                        channel_member::Column::ChannelId
 248                            .eq(channel_id)
 249                            .and(channel_member::Column::UserId.eq(user_id))
 250                            .and(channel_member::Column::Accepted.eq(false)),
 251                    )
 252                    .exec(&*tx)
 253                    .await?
 254                    .rows_affected
 255            } else {
 256                channel_member::ActiveModel {
 257                    channel_id: ActiveValue::Unchanged(channel_id),
 258                    user_id: ActiveValue::Unchanged(user_id),
 259                    ..Default::default()
 260                }
 261                .delete(&*tx)
 262                .await?
 263                .rows_affected
 264            };
 265
 266            if rows_affected == 0 {
 267                Err(anyhow!("no such invitation"))?;
 268            }
 269
 270            Ok(())
 271        })
 272        .await
 273    }
 274
 275    pub async fn remove_channel_member(
 276        &self,
 277        channel_id: ChannelId,
 278        member_id: UserId,
 279        remover_id: UserId,
 280    ) -> Result<()> {
 281        self.transaction(|tx| async move {
 282            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
 283                .await?;
 284
 285            let result = channel_member::Entity::delete_many()
 286                .filter(
 287                    channel_member::Column::ChannelId
 288                        .eq(channel_id)
 289                        .and(channel_member::Column::UserId.eq(member_id)),
 290                )
 291                .exec(&*tx)
 292                .await?;
 293
 294            if result.rows_affected == 0 {
 295                Err(anyhow!("no such member"))?;
 296            }
 297
 298            Ok(())
 299        })
 300        .await
 301    }
 302
 303    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 304        self.transaction(|tx| async move {
 305            let channel_invites = channel_member::Entity::find()
 306                .filter(
 307                    channel_member::Column::UserId
 308                        .eq(user_id)
 309                        .and(channel_member::Column::Accepted.eq(false)),
 310                )
 311                .all(&*tx)
 312                .await?;
 313
 314            let channels = channel::Entity::find()
 315                .filter(
 316                    channel::Column::Id.is_in(
 317                        channel_invites
 318                            .into_iter()
 319                            .map(|channel_member| channel_member.channel_id),
 320                    ),
 321                )
 322                .all(&*tx)
 323                .await?;
 324
 325            let channels = channels
 326                .into_iter()
 327                .map(|channel| Channel {
 328                    id: channel.id,
 329                    name: channel.name,
 330                })
 331                .collect();
 332
 333            Ok(channels)
 334        })
 335        .await
 336    }
 337
 338    async fn get_channel_graph(
 339        &self,
 340        parents_by_child_id: ChannelDescendants,
 341        trim_dangling_parents: bool,
 342        tx: &DatabaseTransaction,
 343    ) -> Result<ChannelGraph> {
 344        let mut channels = Vec::with_capacity(parents_by_child_id.len());
 345        {
 346            let mut rows = channel::Entity::find()
 347                .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
 348                .stream(&*tx)
 349                .await?;
 350            while let Some(row) = rows.next().await {
 351                let row = row?;
 352                channels.push(Channel {
 353                    id: row.id,
 354                    name: row.name,
 355                })
 356            }
 357        }
 358
 359        let mut edges = Vec::with_capacity(parents_by_child_id.len());
 360        for (channel, parents) in parents_by_child_id.iter() {
 361            for parent in parents.into_iter() {
 362                if trim_dangling_parents {
 363                    if parents_by_child_id.contains_key(parent) {
 364                        edges.push(ChannelEdge {
 365                            channel_id: channel.to_proto(),
 366                            parent_id: parent.to_proto(),
 367                        });
 368                    }
 369                } else {
 370                    edges.push(ChannelEdge {
 371                        channel_id: channel.to_proto(),
 372                        parent_id: parent.to_proto(),
 373                    });
 374                }
 375            }
 376        }
 377
 378        Ok(ChannelGraph { channels, edges })
 379    }
 380
 381    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 382        self.transaction(|tx| async move {
 383            let tx = tx;
 384
 385            let channel_memberships = channel_member::Entity::find()
 386                .filter(
 387                    channel_member::Column::UserId
 388                        .eq(user_id)
 389                        .and(channel_member::Column::Accepted.eq(true)),
 390                )
 391                .all(&*tx)
 392                .await?;
 393
 394            self.get_user_channels(channel_memberships, &tx).await
 395        })
 396        .await
 397    }
 398
 399    pub async fn get_channel_for_user(
 400        &self,
 401        channel_id: ChannelId,
 402        user_id: UserId,
 403    ) -> Result<ChannelsForUser> {
 404        self.transaction(|tx| async move {
 405            let tx = tx;
 406
 407            let channel_membership = channel_member::Entity::find()
 408                .filter(
 409                    channel_member::Column::UserId
 410                        .eq(user_id)
 411                        .and(channel_member::Column::ChannelId.eq(channel_id))
 412                        .and(channel_member::Column::Accepted.eq(true)),
 413                )
 414                .all(&*tx)
 415                .await?;
 416
 417            self.get_user_channels(channel_membership, &tx).await
 418        })
 419        .await
 420    }
 421
 422    pub async fn get_user_channels(
 423        &self,
 424        channel_memberships: Vec<channel_member::Model>,
 425        tx: &DatabaseTransaction,
 426    ) -> Result<ChannelsForUser> {
 427        let parents_by_child_id = self
 428            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 429            .await?;
 430
 431        let channels_with_admin_privileges = channel_memberships
 432            .iter()
 433            .filter_map(|membership| membership.admin.then_some(membership.channel_id))
 434            .collect();
 435
 436        let graph = self
 437            .get_channel_graph(parents_by_child_id, true, &tx)
 438            .await?;
 439
 440        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 441        enum QueryUserIdsAndChannelIds {
 442            ChannelId,
 443            UserId,
 444        }
 445
 446        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 447        {
 448            let mut rows = room_participant::Entity::find()
 449                .inner_join(room::Entity)
 450                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
 451                .select_only()
 452                .column(room::Column::ChannelId)
 453                .column(room_participant::Column::UserId)
 454                .into_values::<_, QueryUserIdsAndChannelIds>()
 455                .stream(&*tx)
 456                .await?;
 457            while let Some(row) = rows.next().await {
 458                let row: (ChannelId, UserId) = row?;
 459                channel_participants.entry(row.0).or_default().push(row.1)
 460            }
 461        }
 462
 463        Ok(ChannelsForUser {
 464            channels: graph,
 465            channel_participants,
 466            channels_with_admin_privileges,
 467        })
 468    }
 469
 470    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 471        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
 472            .await
 473    }
 474
 475    pub async fn set_channel_member_admin(
 476        &self,
 477        channel_id: ChannelId,
 478        from: UserId,
 479        for_user: UserId,
 480        admin: bool,
 481    ) -> Result<()> {
 482        self.transaction(|tx| async move {
 483            self.check_user_is_channel_admin(channel_id, from, &*tx)
 484                .await?;
 485
 486            let result = channel_member::Entity::update_many()
 487                .filter(
 488                    channel_member::Column::ChannelId
 489                        .eq(channel_id)
 490                        .and(channel_member::Column::UserId.eq(for_user)),
 491                )
 492                .set(channel_member::ActiveModel {
 493                    admin: ActiveValue::set(admin),
 494                    ..Default::default()
 495                })
 496                .exec(&*tx)
 497                .await?;
 498
 499            if result.rows_affected == 0 {
 500                Err(anyhow!("no such member"))?;
 501            }
 502
 503            Ok(())
 504        })
 505        .await
 506    }
 507
 508    pub async fn get_channel_member_details(
 509        &self,
 510        channel_id: ChannelId,
 511        user_id: UserId,
 512    ) -> Result<Vec<proto::ChannelMember>> {
 513        self.transaction(|tx| async move {
 514            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 515                .await?;
 516
 517            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 518            enum QueryMemberDetails {
 519                UserId,
 520                Admin,
 521                IsDirectMember,
 522                Accepted,
 523            }
 524
 525            let tx = tx;
 526            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
 527            let mut stream = channel_member::Entity::find()
 528                .distinct()
 529                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
 530                .select_only()
 531                .column(channel_member::Column::UserId)
 532                .column(channel_member::Column::Admin)
 533                .column_as(
 534                    channel_member::Column::ChannelId.eq(channel_id),
 535                    QueryMemberDetails::IsDirectMember,
 536                )
 537                .column(channel_member::Column::Accepted)
 538                .order_by_asc(channel_member::Column::UserId)
 539                .into_values::<_, QueryMemberDetails>()
 540                .stream(&*tx)
 541                .await?;
 542
 543            let mut rows = Vec::<proto::ChannelMember>::new();
 544            while let Some(row) = stream.next().await {
 545                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
 546                    UserId,
 547                    bool,
 548                    bool,
 549                    bool,
 550                ) = row?;
 551                let kind = match (is_direct_member, is_invite_accepted) {
 552                    (true, true) => proto::channel_member::Kind::Member,
 553                    (true, false) => proto::channel_member::Kind::Invitee,
 554                    (false, true) => proto::channel_member::Kind::AncestorMember,
 555                    (false, false) => continue,
 556                };
 557                let user_id = user_id.to_proto();
 558                let kind = kind.into();
 559                if let Some(last_row) = rows.last_mut() {
 560                    if last_row.user_id == user_id {
 561                        if is_direct_member {
 562                            last_row.kind = kind;
 563                            last_row.admin = is_admin;
 564                        }
 565                        continue;
 566                    }
 567                }
 568                rows.push(proto::ChannelMember {
 569                    user_id,
 570                    kind,
 571                    admin: is_admin,
 572                });
 573            }
 574
 575            Ok(rows)
 576        })
 577        .await
 578    }
 579
 580    pub async fn get_channel_members_internal(
 581        &self,
 582        id: ChannelId,
 583        tx: &DatabaseTransaction,
 584    ) -> Result<Vec<UserId>> {
 585        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 586        let user_ids = channel_member::Entity::find()
 587            .distinct()
 588            .filter(
 589                channel_member::Column::ChannelId
 590                    .is_in(ancestor_ids.iter().copied())
 591                    .and(channel_member::Column::Accepted.eq(true)),
 592            )
 593            .select_only()
 594            .column(channel_member::Column::UserId)
 595            .into_values::<_, QueryUserIds>()
 596            .all(&*tx)
 597            .await?;
 598        Ok(user_ids)
 599    }
 600
 601    pub async fn check_user_is_channel_member(
 602        &self,
 603        channel_id: ChannelId,
 604        user_id: UserId,
 605        tx: &DatabaseTransaction,
 606    ) -> Result<()> {
 607        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 608        channel_member::Entity::find()
 609            .filter(
 610                channel_member::Column::ChannelId
 611                    .is_in(channel_ids)
 612                    .and(channel_member::Column::UserId.eq(user_id)),
 613            )
 614            .one(&*tx)
 615            .await?
 616            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
 617        Ok(())
 618    }
 619
 620    pub async fn check_user_is_channel_admin(
 621        &self,
 622        channel_id: ChannelId,
 623        user_id: UserId,
 624        tx: &DatabaseTransaction,
 625    ) -> Result<()> {
 626        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 627        channel_member::Entity::find()
 628            .filter(
 629                channel_member::Column::ChannelId
 630                    .is_in(channel_ids)
 631                    .and(channel_member::Column::UserId.eq(user_id))
 632                    .and(channel_member::Column::Admin.eq(true)),
 633            )
 634            .one(&*tx)
 635            .await?
 636            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
 637        Ok(())
 638    }
 639
 640    /// Returns the channel ancestors, deepest first
 641    pub async fn get_channel_ancestors(
 642        &self,
 643        channel_id: ChannelId,
 644        tx: &DatabaseTransaction,
 645    ) -> Result<Vec<ChannelId>> {
 646        let paths = channel_path::Entity::find()
 647            .filter(channel_path::Column::ChannelId.eq(channel_id))
 648            .order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
 649            .all(tx)
 650            .await?;
 651        let mut channel_ids = Vec::new();
 652        for path in paths {
 653            for id in path.id_path.trim_matches('/').split('/') {
 654                if let Ok(id) = id.parse() {
 655                    let id = ChannelId::from_proto(id);
 656                    if let Err(ix) = channel_ids.binary_search(&id) {
 657                        channel_ids.insert(ix, id);
 658                    }
 659                }
 660            }
 661        }
 662        Ok(channel_ids)
 663    }
 664
 665    /// Returns the channel descendants,
 666    /// Structured as a map from child ids to their parent ids
 667    /// For example, the descendants of 'a' in this DAG:
 668    ///
 669    ///   /- b -\
 670    /// a -- c -- d
 671    ///
 672    /// would be:
 673    /// {
 674    ///     a: [],
 675    ///     b: [a],
 676    ///     c: [a],
 677    ///     d: [a, c],
 678    /// }
 679    async fn get_channel_descendants(
 680        &self,
 681        channel_ids: impl IntoIterator<Item = ChannelId>,
 682        tx: &DatabaseTransaction,
 683    ) -> Result<ChannelDescendants> {
 684        let mut values = String::new();
 685        for id in channel_ids {
 686            if !values.is_empty() {
 687                values.push_str(", ");
 688            }
 689            write!(&mut values, "({})", id).unwrap();
 690        }
 691
 692        if values.is_empty() {
 693            return Ok(HashMap::default());
 694        }
 695
 696        let sql = format!(
 697            r#"
 698            SELECT
 699                descendant_paths.*
 700            FROM
 701                channel_paths parent_paths, channel_paths descendant_paths
 702            WHERE
 703                parent_paths.channel_id IN ({values}) AND
 704                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
 705        "#
 706        );
 707
 708        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 709
 710        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
 711        let mut paths = channel_path::Entity::find()
 712            .from_raw_sql(stmt)
 713            .stream(tx)
 714            .await?;
 715
 716        while let Some(path) = paths.next().await {
 717            let path = path?;
 718            let ids = path.id_path.trim_matches('/').split('/');
 719            let mut parent_id = None;
 720            for id in ids {
 721                if let Ok(id) = id.parse() {
 722                    let id = ChannelId::from_proto(id);
 723                    if id == path.channel_id {
 724                        break;
 725                    }
 726                    parent_id = Some(id);
 727                }
 728            }
 729            let entry = parents_by_child_id.entry(path.channel_id).or_default();
 730            if let Some(parent_id) = parent_id {
 731                entry.insert(parent_id);
 732            }
 733        }
 734
 735        Ok(parents_by_child_id)
 736    }
 737
 738    /// Returns the channel with the given ID and:
 739    /// - true if the user is a member
 740    /// - false if the user hasn't accepted the invitation yet
 741    pub async fn get_channel(
 742        &self,
 743        channel_id: ChannelId,
 744        user_id: UserId,
 745    ) -> Result<Option<(Channel, bool)>> {
 746        self.transaction(|tx| async move {
 747            let tx = tx;
 748
 749            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
 750
 751            if let Some(channel) = channel {
 752                if self
 753                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 754                    .await
 755                    .is_err()
 756                {
 757                    return Ok(None);
 758                }
 759
 760                let channel_membership = channel_member::Entity::find()
 761                    .filter(
 762                        channel_member::Column::ChannelId
 763                            .eq(channel_id)
 764                            .and(channel_member::Column::UserId.eq(user_id)),
 765                    )
 766                    .one(&*tx)
 767                    .await?;
 768
 769                let is_accepted = channel_membership
 770                    .map(|membership| membership.accepted)
 771                    .unwrap_or(false);
 772
 773                Ok(Some((
 774                    Channel {
 775                        id: channel.id,
 776                        name: channel.name,
 777                    },
 778                    is_accepted,
 779                )))
 780            } else {
 781                Ok(None)
 782            }
 783        })
 784        .await
 785    }
 786
 787    pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
 788        self.transaction(|tx| async move {
 789            let tx = tx;
 790            let room = channel::Model {
 791                id: channel_id,
 792                ..Default::default()
 793            }
 794            .find_related(room::Entity)
 795            .one(&*tx)
 796            .await?
 797            .ok_or_else(|| anyhow!("invalid channel"))?;
 798            Ok(room.id)
 799        })
 800        .await
 801    }
 802
 803    // Insert an edge from the given channel to the given other channel.
 804    pub async fn link_channel(
 805        &self,
 806        user: UserId,
 807        channel: ChannelId,
 808        to: ChannelId,
 809    ) -> Result<ChannelGraph> {
 810        self.transaction(|tx| async move {
 811            // Note that even with these maxed permissions, this linking operation
 812            // is still insecure because you can't remove someone's permissions to a
 813            // channel if they've linked the channel to one where they're an admin.
 814            self.check_user_is_channel_admin(channel, user, &*tx)
 815                .await?;
 816
 817            self.link_channel_internal(user, channel, to, &*tx).await
 818        })
 819        .await
 820    }
 821
 822    pub async fn link_channel_internal(
 823        &self,
 824        user: UserId,
 825        channel: ChannelId,
 826        to: ChannelId,
 827        tx: &DatabaseTransaction,
 828    ) -> Result<ChannelGraph> {
 829        self.check_user_is_channel_admin(to, user, &*tx).await?;
 830
 831        let paths = channel_path::Entity::find()
 832            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
 833            .all(tx)
 834            .await?;
 835
 836        let mut new_path_suffixes = HashSet::default();
 837        for path in paths {
 838            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
 839                new_path_suffixes.insert((
 840                    path.channel_id,
 841                    path.id_path[(start_offset + 1)..].to_string(),
 842                ));
 843            }
 844        }
 845
 846        let paths_to_new_parent = channel_path::Entity::find()
 847            .filter(channel_path::Column::ChannelId.eq(to))
 848            .all(tx)
 849            .await?;
 850
 851        let mut new_paths = Vec::new();
 852        for path in paths_to_new_parent {
 853            if path.id_path.contains(&format!("/{}/", channel)) {
 854                Err(anyhow!("cycle"))?;
 855            }
 856
 857            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
 858                channel_path::ActiveModel {
 859                    channel_id: ActiveValue::Set(*channel_id),
 860                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
 861                }
 862            }));
 863        }
 864
 865        channel_path::Entity::insert_many(new_paths)
 866            .exec(&*tx)
 867            .await?;
 868
 869        // remove any root edges for the channel we just linked
 870        {
 871            channel_path::Entity::delete_many()
 872                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
 873                .exec(&*tx)
 874                .await?;
 875        }
 876
 877        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
 878        if let Some(channel) = channel_descendants.get_mut(&channel) {
 879            // Remove the other parents
 880            channel.clear();
 881            channel.insert(to);
 882        }
 883
 884        let channels = self
 885            .get_channel_graph(channel_descendants, false, &*tx)
 886            .await?;
 887
 888        Ok(channels)
 889    }
 890
 891    /// Unlink a channel from a given parent. This will add in a root edge if
 892    /// the channel has no other parents after this operation.
 893    pub async fn unlink_channel(
 894        &self,
 895        user: UserId,
 896        channel: ChannelId,
 897        from: ChannelId,
 898    ) -> Result<()> {
 899        self.transaction(|tx| async move {
 900            // Note that even with these maxed permissions, this linking operation
 901            // is still insecure because you can't remove someone's permissions to a
 902            // channel if they've linked the channel to one where they're an admin.
 903            self.check_user_is_channel_admin(channel, user, &*tx)
 904                .await?;
 905
 906            self.unlink_channel_internal(user, channel, from, &*tx)
 907                .await?;
 908
 909            Ok(())
 910        })
 911        .await
 912    }
 913
 914    pub async fn unlink_channel_internal(
 915        &self,
 916        user: UserId,
 917        channel: ChannelId,
 918        from: ChannelId,
 919        tx: &DatabaseTransaction,
 920    ) -> Result<()> {
 921        self.check_user_is_channel_admin(from, user, &*tx).await?;
 922
 923        let sql = r#"
 924            DELETE FROM channel_paths
 925            WHERE
 926                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
 927            RETURNING id_path, channel_id
 928        "#;
 929
 930        let paths = channel_path::Entity::find()
 931            .from_raw_sql(Statement::from_sql_and_values(
 932                self.pool.get_database_backend(),
 933                sql,
 934                [from.to_proto().into(), channel.to_proto().into()],
 935            ))
 936            .all(&*tx)
 937            .await?;
 938
 939        let is_stranded = channel_path::Entity::find()
 940            .filter(channel_path::Column::ChannelId.eq(channel))
 941            .count(&*tx)
 942            .await?
 943            == 0;
 944
 945        // Make sure that there is always at least one path to the channel
 946        if is_stranded {
 947            let root_paths: Vec<_> = paths
 948                .iter()
 949                .map(|path| {
 950                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
 951                    channel_path::ActiveModel {
 952                        channel_id: ActiveValue::Set(path.channel_id),
 953                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
 954                    }
 955                })
 956                .collect();
 957            channel_path::Entity::insert_many(root_paths)
 958                .exec(&*tx)
 959                .await?;
 960        }
 961
 962        Ok(())
 963    }
 964
 965    /// Move a channel from one parent to another, returns the
 966    /// Channels that were moved for notifying clients
 967    pub async fn move_channel(
 968        &self,
 969        user: UserId,
 970        channel: ChannelId,
 971        from: ChannelId,
 972        to: ChannelId,
 973    ) -> Result<ChannelGraph> {
 974        if from == to {
 975            return Ok(ChannelGraph {
 976                channels: vec![],
 977                edges: vec![],
 978            });
 979        }
 980
 981        self.transaction(|tx| async move {
 982            self.check_user_is_channel_admin(channel, user, &*tx)
 983                .await?;
 984
 985            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
 986
 987            self.unlink_channel_internal(user, channel, from, &*tx)
 988                .await?;
 989
 990            Ok(moved_channels)
 991        })
 992        .await
 993    }
 994}
 995
/// Single-column selector with one variant, `UserId`.
///
/// NOTE(review): presumably used elsewhere in this file to select bare user ids
/// from a query. `DeriveColumn` derives the column name from the variant, so
/// renaming `UserId` likely changes the generated SQL — confirm before editing.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1000
/// A set of channels plus the parent/child edges between them, as returned by
/// the channel link/move operations.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels included in this graph.
    pub channels: Vec<Channel>,
    /// Edges between channels, each identified by its `channel_id` and `parent_id`.
    pub edges: Vec<ChannelEdge>,
}
1006
1007impl ChannelGraph {
1008    pub fn is_empty(&self) -> bool {
1009        self.channels.is_empty() && self.edges.is_empty()
1010    }
1011}
1012
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Order-independent equality used by tests: channels and edges are
    /// compared as sets, so the order they were produced in does not matter.
    fn eq(&self, other: &Self) -> bool {
        fn channel_set(graph: &ChannelGraph) -> HashSet<&Channel> {
            graph.channels.iter().collect()
        }

        let edge_set = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        channel_set(self) == channel_set(other) && edge_set(self) == edge_set(other)
    }
}
1033
1034#[cfg(not(test))]
1035impl PartialEq for ChannelGraph {
1036    fn eq(&self, other: &Self) -> bool {
1037        self.channels == other.channels && self.edges == other.edges
1038    }
1039}
1040
/// A small set backed by a `SmallVec` with inline room for one element.
///
/// `SmallSet::insert` keeps the elements sorted and free of duplicates. In this
/// file it is the value type of `ChannelDescendants`, presumably holding a
/// channel's parent ids (single-parent channels being the common case).
struct SmallSet<T>(SmallVec<[T; 1]>);
1042
1043impl<T> Deref for SmallSet<T> {
1044    type Target = [T];
1045
1046    fn deref(&self) -> &Self::Target {
1047        self.0.deref()
1048    }
1049}
1050
1051impl<T> Default for SmallSet<T> {
1052    fn default() -> Self {
1053        Self(SmallVec::new())
1054    }
1055}
1056
1057impl<T> SmallSet<T> {
1058    fn insert(&mut self, value: T) -> bool
1059    where
1060        T: Ord,
1061    {
1062        match self.binary_search(&value) {
1063            Ok(_) => false,
1064            Err(ix) => {
1065                self.0.insert(ix, value);
1066                true
1067            }
1068        }
1069    }
1070
1071    fn clear(&mut self) {
1072        self.0.clear();
1073    }
1074}