channels.rs

   1use super::*;
   2use rpc::proto::ChannelEdge;
   3use smallvec::SmallVec;
   4
/// Map from each channel id to the set of ids of its direct parent channels,
/// as produced by `Database::get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
   6
   7impl Database {
   8    #[cfg(test)]
   9    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  10        self.transaction(move |tx| async move {
  11            let mut channels = Vec::new();
  12            let mut rows = channel::Entity::find().stream(&*tx).await?;
  13            while let Some(row) = rows.next().await {
  14                let row = row?;
  15                channels.push((row.id, row.name));
  16            }
  17            Ok(channels)
  18        })
  19        .await
  20    }
  21
  22    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  23        self.create_channel(name, None, creator_id).await
  24    }
  25
  26    pub async fn create_channel(
  27        &self,
  28        name: &str,
  29        parent: Option<ChannelId>,
  30        creator_id: UserId,
  31    ) -> Result<ChannelId> {
  32        let name = Self::sanitize_channel_name(name)?;
  33        self.transaction(move |tx| async move {
  34            if let Some(parent) = parent {
  35                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  36                    .await?;
  37            }
  38
  39            let channel = channel::ActiveModel {
  40                id: ActiveValue::NotSet,
  41                name: ActiveValue::Set(name.to_string()),
  42                visibility: ActiveValue::Set(ChannelVisibility::ChannelMembers),
  43            }
  44            .insert(&*tx)
  45            .await?;
  46
  47            if let Some(parent) = parent {
  48                let sql = r#"
  49                    INSERT INTO channel_paths
  50                    (id_path, channel_id)
  51                    SELECT
  52                        id_path || $1 || '/', $2
  53                    FROM
  54                        channel_paths
  55                    WHERE
  56                        channel_id = $3
  57                "#;
  58                let channel_paths_stmt = Statement::from_sql_and_values(
  59                    self.pool.get_database_backend(),
  60                    sql,
  61                    [
  62                        channel.id.to_proto().into(),
  63                        channel.id.to_proto().into(),
  64                        parent.to_proto().into(),
  65                    ],
  66                );
  67                tx.execute(channel_paths_stmt).await?;
  68            } else {
  69                channel_path::Entity::insert(channel_path::ActiveModel {
  70                    channel_id: ActiveValue::Set(channel.id),
  71                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  72                })
  73                .exec(&*tx)
  74                .await?;
  75            }
  76
  77            channel_member::ActiveModel {
  78                id: ActiveValue::NotSet,
  79                channel_id: ActiveValue::Set(channel.id),
  80                user_id: ActiveValue::Set(creator_id),
  81                accepted: ActiveValue::Set(true),
  82                admin: ActiveValue::Set(true),
  83                role: ActiveValue::Set(Some(ChannelRole::Admin)),
  84            }
  85            .insert(&*tx)
  86            .await?;
  87
  88            Ok(channel.id)
  89        })
  90        .await
  91    }
  92
  93    pub async fn set_channel_visibility(
  94        &self,
  95        channel_id: ChannelId,
  96        visibility: ChannelVisibility,
  97        user_id: UserId,
  98    ) -> Result<()> {
  99        self.transaction(move |tx| async move {
 100            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 101                .await?;
 102
 103            channel::ActiveModel {
 104                id: ActiveValue::Unchanged(channel_id),
 105                visibility: ActiveValue::Set(visibility),
 106                ..Default::default()
 107            }
 108            .update(&*tx)
 109            .await?;
 110
 111            Ok(())
 112        })
 113        .await
 114    }
 115
 116    pub async fn delete_channel(
 117        &self,
 118        channel_id: ChannelId,
 119        user_id: UserId,
 120    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 121        self.transaction(move |tx| async move {
 122            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 123                .await?;
 124
 125            // Don't remove descendant channels that have additional parents.
 126            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
 127            {
 128                let mut channels_to_keep = channel_path::Entity::find()
 129                    .filter(
 130                        channel_path::Column::ChannelId
 131                            .is_in(
 132                                channels_to_remove
 133                                    .keys()
 134                                    .copied()
 135                                    .filter(|&id| id != channel_id),
 136                            )
 137                            .and(
 138                                channel_path::Column::IdPath
 139                                    .not_like(&format!("%/{}/%", channel_id)),
 140                            ),
 141                    )
 142                    .stream(&*tx)
 143                    .await?;
 144                while let Some(row) = channels_to_keep.next().await {
 145                    let row = row?;
 146                    channels_to_remove.remove(&row.channel_id);
 147                }
 148            }
 149
 150            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 151            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 152                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 153                .select_only()
 154                .column(channel_member::Column::UserId)
 155                .distinct()
 156                .into_values::<_, QueryUserIds>()
 157                .all(&*tx)
 158                .await?;
 159
 160            channel::Entity::delete_many()
 161                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
 162                .exec(&*tx)
 163                .await?;
 164
 165            // Delete any other paths that include this channel
 166            let sql = r#"
 167                    DELETE FROM channel_paths
 168                    WHERE
 169                        id_path LIKE '%' || $1 || '%'
 170                "#;
 171            let channel_paths_stmt = Statement::from_sql_and_values(
 172                self.pool.get_database_backend(),
 173                sql,
 174                [channel_id.to_proto().into()],
 175            );
 176            tx.execute(channel_paths_stmt).await?;
 177
 178            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
 179        })
 180        .await
 181    }
 182
 183    pub async fn invite_channel_member(
 184        &self,
 185        channel_id: ChannelId,
 186        invitee_id: UserId,
 187        admin_id: UserId,
 188        role: ChannelRole,
 189    ) -> Result<()> {
 190        self.transaction(move |tx| async move {
 191            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 192                .await?;
 193
 194            channel_member::ActiveModel {
 195                id: ActiveValue::NotSet,
 196                channel_id: ActiveValue::Set(channel_id),
 197                user_id: ActiveValue::Set(invitee_id),
 198                accepted: ActiveValue::Set(false),
 199                admin: ActiveValue::Set(role == ChannelRole::Admin),
 200                role: ActiveValue::Set(Some(role)),
 201            }
 202            .insert(&*tx)
 203            .await?;
 204
 205            Ok(())
 206        })
 207        .await
 208    }
 209
 210    fn sanitize_channel_name(name: &str) -> Result<&str> {
 211        let new_name = name.trim().trim_start_matches('#');
 212        if new_name == "" {
 213            Err(anyhow!("channel name can't be blank"))?;
 214        }
 215        Ok(new_name)
 216    }
 217
 218    pub async fn rename_channel(
 219        &self,
 220        channel_id: ChannelId,
 221        user_id: UserId,
 222        new_name: &str,
 223    ) -> Result<String> {
 224        self.transaction(move |tx| async move {
 225            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 226
 227            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 228                .await?;
 229
 230            channel::ActiveModel {
 231                id: ActiveValue::Unchanged(channel_id),
 232                name: ActiveValue::Set(new_name.clone()),
 233                ..Default::default()
 234            }
 235            .update(&*tx)
 236            .await?;
 237
 238            Ok(new_name)
 239        })
 240        .await
 241    }
 242
 243    pub async fn respond_to_channel_invite(
 244        &self,
 245        channel_id: ChannelId,
 246        user_id: UserId,
 247        accept: bool,
 248    ) -> Result<()> {
 249        self.transaction(move |tx| async move {
 250            let rows_affected = if accept {
 251                channel_member::Entity::update_many()
 252                    .set(channel_member::ActiveModel {
 253                        accepted: ActiveValue::Set(accept),
 254                        ..Default::default()
 255                    })
 256                    .filter(
 257                        channel_member::Column::ChannelId
 258                            .eq(channel_id)
 259                            .and(channel_member::Column::UserId.eq(user_id))
 260                            .and(channel_member::Column::Accepted.eq(false)),
 261                    )
 262                    .exec(&*tx)
 263                    .await?
 264                    .rows_affected
 265            } else {
 266                channel_member::ActiveModel {
 267                    channel_id: ActiveValue::Unchanged(channel_id),
 268                    user_id: ActiveValue::Unchanged(user_id),
 269                    ..Default::default()
 270                }
 271                .delete(&*tx)
 272                .await?
 273                .rows_affected
 274            };
 275
 276            if rows_affected == 0 {
 277                Err(anyhow!("no such invitation"))?;
 278            }
 279
 280            Ok(())
 281        })
 282        .await
 283    }
 284
 285    pub async fn remove_channel_member(
 286        &self,
 287        channel_id: ChannelId,
 288        member_id: UserId,
 289        admin_id: UserId,
 290    ) -> Result<()> {
 291        self.transaction(|tx| async move {
 292            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 293                .await?;
 294
 295            let result = channel_member::Entity::delete_many()
 296                .filter(
 297                    channel_member::Column::ChannelId
 298                        .eq(channel_id)
 299                        .and(channel_member::Column::UserId.eq(member_id)),
 300                )
 301                .exec(&*tx)
 302                .await?;
 303
 304            if result.rows_affected == 0 {
 305                Err(anyhow!("no such member"))?;
 306            }
 307
 308            Ok(())
 309        })
 310        .await
 311    }
 312
 313    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 314        self.transaction(|tx| async move {
 315            let channel_invites = channel_member::Entity::find()
 316                .filter(
 317                    channel_member::Column::UserId
 318                        .eq(user_id)
 319                        .and(channel_member::Column::Accepted.eq(false)),
 320                )
 321                .all(&*tx)
 322                .await?;
 323
 324            let channels = channel::Entity::find()
 325                .filter(
 326                    channel::Column::Id.is_in(
 327                        channel_invites
 328                            .into_iter()
 329                            .map(|channel_member| channel_member.channel_id),
 330                    ),
 331                )
 332                .all(&*tx)
 333                .await?;
 334
 335            let channels = channels
 336                .into_iter()
 337                .map(|channel| Channel {
 338                    id: channel.id,
 339                    name: channel.name,
 340                })
 341                .collect();
 342
 343            Ok(channels)
 344        })
 345        .await
 346    }
 347
 348    async fn get_channel_graph(
 349        &self,
 350        parents_by_child_id: ChannelDescendants,
 351        trim_dangling_parents: bool,
 352        tx: &DatabaseTransaction,
 353    ) -> Result<ChannelGraph> {
 354        let mut channels = Vec::with_capacity(parents_by_child_id.len());
 355        {
 356            let mut rows = channel::Entity::find()
 357                .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
 358                .stream(&*tx)
 359                .await?;
 360            while let Some(row) = rows.next().await {
 361                let row = row?;
 362                channels.push(Channel {
 363                    id: row.id,
 364                    name: row.name,
 365                })
 366            }
 367        }
 368
 369        let mut edges = Vec::with_capacity(parents_by_child_id.len());
 370        for (channel, parents) in parents_by_child_id.iter() {
 371            for parent in parents.into_iter() {
 372                if trim_dangling_parents {
 373                    if parents_by_child_id.contains_key(parent) {
 374                        edges.push(ChannelEdge {
 375                            channel_id: channel.to_proto(),
 376                            parent_id: parent.to_proto(),
 377                        });
 378                    }
 379                } else {
 380                    edges.push(ChannelEdge {
 381                        channel_id: channel.to_proto(),
 382                        parent_id: parent.to_proto(),
 383                    });
 384                }
 385            }
 386        }
 387
 388        Ok(ChannelGraph { channels, edges })
 389    }
 390
 391    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 392        self.transaction(|tx| async move {
 393            let tx = tx;
 394
 395            let channel_memberships = channel_member::Entity::find()
 396                .filter(
 397                    channel_member::Column::UserId
 398                        .eq(user_id)
 399                        .and(channel_member::Column::Accepted.eq(true)),
 400                )
 401                .all(&*tx)
 402                .await?;
 403
 404            self.get_user_channels(user_id, channel_memberships, &tx)
 405                .await
 406        })
 407        .await
 408    }
 409
 410    pub async fn get_channel_for_user(
 411        &self,
 412        channel_id: ChannelId,
 413        user_id: UserId,
 414    ) -> Result<ChannelsForUser> {
 415        self.transaction(|tx| async move {
 416            let tx = tx;
 417
 418            let channel_membership = channel_member::Entity::find()
 419                .filter(
 420                    channel_member::Column::UserId
 421                        .eq(user_id)
 422                        .and(channel_member::Column::ChannelId.eq(channel_id))
 423                        .and(channel_member::Column::Accepted.eq(true)),
 424                )
 425                .all(&*tx)
 426                .await?;
 427
 428            self.get_user_channels(user_id, channel_membership, &tx)
 429                .await
 430        })
 431        .await
 432    }
 433
 434    pub async fn get_user_channels(
 435        &self,
 436        user_id: UserId,
 437        channel_memberships: Vec<channel_member::Model>,
 438        tx: &DatabaseTransaction,
 439    ) -> Result<ChannelsForUser> {
 440        let parents_by_child_id = self
 441            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 442            .await?;
 443
 444        let channels_with_admin_privileges = channel_memberships
 445            .iter()
 446            .filter_map(|membership| {
 447                if membership.role == Some(ChannelRole::Admin) || membership.admin {
 448                    Some(membership.channel_id)
 449                } else {
 450                    None
 451                }
 452            })
 453            .collect();
 454
 455        let graph = self
 456            .get_channel_graph(parents_by_child_id, true, &tx)
 457            .await?;
 458
 459        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 460        enum QueryUserIdsAndChannelIds {
 461            ChannelId,
 462            UserId,
 463        }
 464
 465        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 466        {
 467            let mut rows = room_participant::Entity::find()
 468                .inner_join(room::Entity)
 469                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
 470                .select_only()
 471                .column(room::Column::ChannelId)
 472                .column(room_participant::Column::UserId)
 473                .into_values::<_, QueryUserIdsAndChannelIds>()
 474                .stream(&*tx)
 475                .await?;
 476            while let Some(row) = rows.next().await {
 477                let row: (ChannelId, UserId) = row?;
 478                channel_participants.entry(row.0).or_default().push(row.1)
 479            }
 480        }
 481
 482        let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
 483        let channel_buffer_changes = self
 484            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
 485            .await?;
 486
 487        let unseen_messages = self
 488            .unseen_channel_messages(user_id, &channel_ids, &*tx)
 489            .await?;
 490
 491        Ok(ChannelsForUser {
 492            channels: graph,
 493            channel_participants,
 494            channels_with_admin_privileges,
 495            unseen_buffer_changes: channel_buffer_changes,
 496            channel_messages: unseen_messages,
 497        })
 498    }
 499
 500    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 501        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
 502            .await
 503    }
 504
 505    pub async fn set_channel_member_role(
 506        &self,
 507        channel_id: ChannelId,
 508        admin_id: UserId,
 509        for_user: UserId,
 510        role: ChannelRole,
 511    ) -> Result<()> {
 512        self.transaction(|tx| async move {
 513            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 514                .await?;
 515
 516            let result = channel_member::Entity::update_many()
 517                .filter(
 518                    channel_member::Column::ChannelId
 519                        .eq(channel_id)
 520                        .and(channel_member::Column::UserId.eq(for_user)),
 521                )
 522                .set(channel_member::ActiveModel {
 523                    admin: ActiveValue::set(role == ChannelRole::Admin),
 524                    role: ActiveValue::set(Some(role)),
 525                    ..Default::default()
 526                })
 527                .exec(&*tx)
 528                .await?;
 529
 530            if result.rows_affected == 0 {
 531                Err(anyhow!("no such member"))?;
 532            }
 533
 534            Ok(())
 535        })
 536        .await
 537    }
 538
    /// Lists the members of `channel_id`, one entry per user, for an admin view.
    ///
    /// `user_id` must be an admin of the channel. Membership rows on the channel
    /// itself and on all of its ancestors are considered:
    /// - a row on `channel_id` that is accepted → `Member`
    /// - a row on `channel_id` that is not accepted → `Invitee`
    /// - an accepted row only on an ancestor → `AncestorMember`
    /// - rows on ancestors that are not accepted are ignored.
    ///
    /// When a user has several rows, a row on `channel_id` itself determines the
    /// reported kind and role. Among several ancestor-only rows, the first one
    /// returned wins; the query orders only by user id, so that choice is
    /// unspecified.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column aliases for the tuple selected below; order must match the
            // tuple destructured in the loop.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                Role,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            // Sorting by user id groups each user's rows together, which the
            // merge logic in the loop relies on.
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                .column(channel_member::Column::Role)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, channel_role, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    Option<ChannelRole>,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invitations to ancestors don't make the user a member here.
                    (false, false) => continue,
                };
                // Rows without a role predate roles; fall back to the legacy admin flag.
                let channel_role = channel_role.unwrap_or(if is_admin {
                    ChannelRole::Admin
                } else {
                    ChannelRole::Member
                });
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Same user as the previous entry: only a row on `channel_id`
                // itself may overwrite what was already recorded.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.role = channel_role.into()
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    role: channel_role.into(),
                });
            }

            Ok(rows)
        })
        .await
    }
 618
 619    pub async fn get_channel_members_internal(
 620        &self,
 621        id: ChannelId,
 622        tx: &DatabaseTransaction,
 623    ) -> Result<Vec<UserId>> {
 624        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 625        let user_ids = channel_member::Entity::find()
 626            .distinct()
 627            .filter(
 628                channel_member::Column::ChannelId
 629                    .is_in(ancestor_ids.iter().copied())
 630                    .and(channel_member::Column::Accepted.eq(true)),
 631            )
 632            .select_only()
 633            .column(channel_member::Column::UserId)
 634            .into_values::<_, QueryUserIds>()
 635            .all(&*tx)
 636            .await?;
 637        Ok(user_ids)
 638    }
 639
 640    pub async fn check_user_is_channel_admin(
 641        &self,
 642        channel_id: ChannelId,
 643        user_id: UserId,
 644        tx: &DatabaseTransaction,
 645    ) -> Result<()> {
 646        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 647            Some(ChannelRole::Admin) => Ok(()),
 648            Some(ChannelRole::Member)
 649            | Some(ChannelRole::Banned)
 650            | Some(ChannelRole::Guest)
 651            | None => Err(anyhow!(
 652                "user is not a channel admin or channel does not exist"
 653            ))?,
 654        }
 655    }
 656
 657    pub async fn check_user_is_channel_member(
 658        &self,
 659        channel_id: ChannelId,
 660        user_id: UserId,
 661        tx: &DatabaseTransaction,
 662    ) -> Result<()> {
 663        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 664            Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
 665            Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
 666                "user is not a channel member or channel does not exist"
 667            ))?,
 668        }
 669    }
 670
 671    pub async fn check_user_is_channel_participant(
 672        &self,
 673        channel_id: ChannelId,
 674        user_id: UserId,
 675        tx: &DatabaseTransaction,
 676    ) -> Result<()> {
 677        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 678            Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
 679                Ok(())
 680            }
 681            Some(ChannelRole::Banned) | None => Err(anyhow!(
 682                "user is not a channel participant or channel does not exist"
 683            ))?,
 684        }
 685    }
 686
    /// Computes `user_id`'s effective role in `channel_id` from the user's
    /// `channel_member` rows on the channel and on all of its ancestors.
    ///
    /// Precedence, highest first: `Admin`, `Member`, `Banned`, `Guest`.
    /// - A `Guest` row only counts if the channel that row belongs to is `Public`.
    /// - `Guest` is returned only if `channel_id` itself is `Public`.
    /// - Rows with no `role` (legacy rows) count as `Admin` or `Member` according
    ///   to their `admin` flag.
    ///
    /// Returns `None` when no row grants anything, including when the channel has
    /// no paths.
    ///
    /// NOTE(review): rows are not filtered by `accepted`, so pending invitations
    /// also contribute a role here. Confirm this is intended.
    pub async fn channel_role_for_user(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        tx: &DatabaseTransaction,
    ) -> Result<Option<ChannelRole>> {
        // `channel_id` plus every ancestor on any of its paths.
        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;

        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryChannelMembership {
            ChannelId,
            Role,
            Admin,
            Visibility,
        }

        // One row per membership of this user in the lineage, carrying the
        // visibility of the channel that membership is on.
        let mut rows = channel_member::Entity::find()
            .left_join(channel::Entity)
            .filter(
                channel_member::Column::ChannelId
                    .is_in(channel_ids)
                    .and(channel_member::Column::UserId.eq(user_id)),
            )
            .select_only()
            .column(channel_member::Column::ChannelId)
            .column(channel_member::Column::Role)
            .column(channel_member::Column::Admin)
            .column(channel::Column::Visibility)
            .into_values::<_, QueryChannelMembership>()
            .stream(&*tx)
            .await?;

        let mut is_admin = false;
        let mut is_member = false;
        let mut is_participant = false;
        let mut is_banned = false;
        let mut current_channel_visibility = None;

        // Rows arrive in no particular order, so each role is only recorded as a
        // flag here; the precedence is applied after the loop.
        while let Some(row) = rows.next().await {
            let (ch_id, role, admin, visibility): (
                ChannelId,
                Option<ChannelRole>,
                bool,
                ChannelVisibility,
            ) = row?;
            match role {
                Some(ChannelRole::Admin) => is_admin = true,
                Some(ChannelRole::Member) => is_member = true,
                Some(ChannelRole::Guest) => {
                    if visibility == ChannelVisibility::Public {
                        is_participant = true
                    }
                }
                Some(ChannelRole::Banned) => is_banned = true,
                None => {
                    // rows created from pre-role collab server.
                    if admin {
                        is_admin = true
                    } else {
                        is_member = true
                    }
                }
            }
            // Remember the target channel's visibility if we happen to see it,
            // to avoid a second query below.
            if channel_id == ch_id {
                current_channel_visibility = Some(visibility);
            }
        }
        // free up database connection
        drop(rows);

        Ok(if is_admin {
            Some(ChannelRole::Admin)
        } else if is_member {
            Some(ChannelRole::Member)
        } else if is_banned {
            Some(ChannelRole::Banned)
        } else if is_participant {
            // Guest access additionally requires the target channel to be public;
            // look its visibility up if no membership row on it was seen.
            if current_channel_visibility.is_none() {
                current_channel_visibility = channel::Entity::find()
                    .filter(channel::Column::Id.eq(channel_id))
                    .one(&*tx)
                    .await?
                    .map(|channel| channel.visibility);
            }
            if current_channel_visibility == Some(ChannelVisibility::Public) {
                Some(ChannelRole::Guest)
            } else {
                None
            }
        } else {
            None
        })
    }
 782
 783    /// Returns the channel ancestors, deepest first
 784    pub async fn get_channel_ancestors(
 785        &self,
 786        channel_id: ChannelId,
 787        tx: &DatabaseTransaction,
 788    ) -> Result<Vec<ChannelId>> {
 789        let paths = channel_path::Entity::find()
 790            .filter(channel_path::Column::ChannelId.eq(channel_id))
 791            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 792            .all(tx)
 793            .await?;
 794        let mut channel_ids = Vec::new();
 795        for path in paths {
 796            for id in path.id_path.trim_matches('/').split('/') {
 797                if let Ok(id) = id.parse() {
 798                    let id = ChannelId::from_proto(id);
 799                    if let Err(ix) = channel_ids.binary_search(&id) {
 800                        channel_ids.insert(ix, id);
 801                    }
 802                }
 803            }
 804        }
 805        Ok(channel_ids)
 806    }
 807
 808    /// Returns the channel descendants,
 809    /// Structured as a map from child ids to their parent ids
 810    /// For example, the descendants of 'a' in this DAG:
 811    ///
 812    ///   /- b -\
 813    /// a -- c -- d
 814    ///
 815    /// would be:
 816    /// {
 817    ///     a: [],
 818    ///     b: [a],
 819    ///     c: [a],
 820    ///     d: [a, c],
 821    /// }
 822    async fn get_channel_descendants(
 823        &self,
 824        channel_ids: impl IntoIterator<Item = ChannelId>,
 825        tx: &DatabaseTransaction,
 826    ) -> Result<ChannelDescendants> {
 827        let mut values = String::new();
 828        for id in channel_ids {
 829            if !values.is_empty() {
 830                values.push_str(", ");
 831            }
 832            write!(&mut values, "({})", id).unwrap();
 833        }
 834
 835        if values.is_empty() {
 836            return Ok(HashMap::default());
 837        }
 838
 839        let sql = format!(
 840            r#"
 841            SELECT
 842                descendant_paths.*
 843            FROM
 844                channel_paths parent_paths, channel_paths descendant_paths
 845            WHERE
 846                parent_paths.channel_id IN ({values}) AND
 847                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
 848        "#
 849        );
 850
 851        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 852
 853        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
 854        let mut paths = channel_path::Entity::find()
 855            .from_raw_sql(stmt)
 856            .stream(tx)
 857            .await?;
 858
 859        while let Some(path) = paths.next().await {
 860            let path = path?;
 861            let ids = path.id_path.trim_matches('/').split('/');
 862            let mut parent_id = None;
 863            for id in ids {
 864                if let Ok(id) = id.parse() {
 865                    let id = ChannelId::from_proto(id);
 866                    if id == path.channel_id {
 867                        break;
 868                    }
 869                    parent_id = Some(id);
 870                }
 871            }
 872            let entry = parents_by_child_id.entry(path.channel_id).or_default();
 873            if let Some(parent_id) = parent_id {
 874                entry.insert(parent_id);
 875            }
 876        }
 877
 878        Ok(parents_by_child_id)
 879    }
 880
 881    /// Returns the channel with the given ID and:
 882    /// - true if the user is a member
 883    /// - false if the user hasn't accepted the invitation yet
 884    pub async fn get_channel(
 885        &self,
 886        channel_id: ChannelId,
 887        user_id: UserId,
 888    ) -> Result<Option<(Channel, bool)>> {
 889        self.transaction(|tx| async move {
 890            let tx = tx;
 891
 892            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
 893
 894            if let Some(channel) = channel {
 895                if self
 896                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 897                    .await
 898                    .is_err()
 899                {
 900                    return Ok(None);
 901                }
 902
 903                let channel_membership = channel_member::Entity::find()
 904                    .filter(
 905                        channel_member::Column::ChannelId
 906                            .eq(channel_id)
 907                            .and(channel_member::Column::UserId.eq(user_id)),
 908                    )
 909                    .one(&*tx)
 910                    .await?;
 911
 912                let is_accepted = channel_membership
 913                    .map(|membership| membership.accepted)
 914                    .unwrap_or(false);
 915
 916                Ok(Some((
 917                    Channel {
 918                        id: channel.id,
 919                        name: channel.name,
 920                    },
 921                    is_accepted,
 922                )))
 923            } else {
 924                Ok(None)
 925            }
 926        })
 927        .await
 928    }
 929
 930    pub async fn get_or_create_channel_room(
 931        &self,
 932        channel_id: ChannelId,
 933        live_kit_room: &str,
 934        enviroment: &str,
 935    ) -> Result<RoomId> {
 936        self.transaction(|tx| async move {
 937            let tx = tx;
 938
 939            let room = room::Entity::find()
 940                .filter(room::Column::ChannelId.eq(channel_id))
 941                .one(&*tx)
 942                .await?;
 943
 944            let room_id = if let Some(room) = room {
 945                room.id
 946            } else {
 947                let result = room::Entity::insert(room::ActiveModel {
 948                    channel_id: ActiveValue::Set(Some(channel_id)),
 949                    live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
 950                    enviroment: ActiveValue::Set(Some(enviroment.to_string())),
 951                    ..Default::default()
 952                })
 953                .exec(&*tx)
 954                .await?;
 955
 956                result.last_insert_id
 957            };
 958
 959            Ok(room_id)
 960        })
 961        .await
 962    }
 963
 964    // Insert an edge from the given channel to the given other channel.
 965    pub async fn link_channel(
 966        &self,
 967        user: UserId,
 968        channel: ChannelId,
 969        to: ChannelId,
 970    ) -> Result<ChannelGraph> {
 971        self.transaction(|tx| async move {
 972            // Note that even with these maxed permissions, this linking operation
 973            // is still insecure because you can't remove someone's permissions to a
 974            // channel if they've linked the channel to one where they're an admin.
 975            self.check_user_is_channel_admin(channel, user, &*tx)
 976                .await?;
 977
 978            self.link_channel_internal(user, channel, to, &*tx).await
 979        })
 980        .await
 981    }
 982
 983    pub async fn link_channel_internal(
 984        &self,
 985        user: UserId,
 986        channel: ChannelId,
 987        new_parent: ChannelId,
 988        tx: &DatabaseTransaction,
 989    ) -> Result<ChannelGraph> {
 990        self.check_user_is_channel_admin(new_parent, user, &*tx)
 991            .await?;
 992
 993        let paths = channel_path::Entity::find()
 994            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
 995            .all(tx)
 996            .await?;
 997
 998        let mut new_path_suffixes = HashSet::default();
 999        for path in paths {
1000            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
1001                new_path_suffixes.insert((
1002                    path.channel_id,
1003                    path.id_path[(start_offset + 1)..].to_string(),
1004                ));
1005            }
1006        }
1007
1008        let paths_to_new_parent = channel_path::Entity::find()
1009            .filter(channel_path::Column::ChannelId.eq(new_parent))
1010            .all(tx)
1011            .await?;
1012
1013        let mut new_paths = Vec::new();
1014        for path in paths_to_new_parent {
1015            if path.id_path.contains(&format!("/{}/", channel)) {
1016                Err(anyhow!("cycle"))?;
1017            }
1018
1019            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
1020                channel_path::ActiveModel {
1021                    channel_id: ActiveValue::Set(*channel_id),
1022                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
1023                }
1024            }));
1025        }
1026
1027        channel_path::Entity::insert_many(new_paths)
1028            .exec(&*tx)
1029            .await?;
1030
1031        // remove any root edges for the channel we just linked
1032        {
1033            channel_path::Entity::delete_many()
1034                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
1035                .exec(&*tx)
1036                .await?;
1037        }
1038
1039        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
1040        if let Some(channel) = channel_descendants.get_mut(&channel) {
1041            // Remove the other parents
1042            channel.clear();
1043            channel.insert(new_parent);
1044        }
1045
1046        let channels = self
1047            .get_channel_graph(channel_descendants, false, &*tx)
1048            .await?;
1049
1050        Ok(channels)
1051    }
1052
1053    /// Unlink a channel from a given parent. This will add in a root edge if
1054    /// the channel has no other parents after this operation.
1055    pub async fn unlink_channel(
1056        &self,
1057        user: UserId,
1058        channel: ChannelId,
1059        from: ChannelId,
1060    ) -> Result<()> {
1061        self.transaction(|tx| async move {
1062            // Note that even with these maxed permissions, this linking operation
1063            // is still insecure because you can't remove someone's permissions to a
1064            // channel if they've linked the channel to one where they're an admin.
1065            self.check_user_is_channel_admin(channel, user, &*tx)
1066                .await?;
1067
1068            self.unlink_channel_internal(user, channel, from, &*tx)
1069                .await?;
1070
1071            Ok(())
1072        })
1073        .await
1074    }
1075
1076    pub async fn unlink_channel_internal(
1077        &self,
1078        user: UserId,
1079        channel: ChannelId,
1080        from: ChannelId,
1081        tx: &DatabaseTransaction,
1082    ) -> Result<()> {
1083        self.check_user_is_channel_admin(from, user, &*tx).await?;
1084
1085        let sql = r#"
1086            DELETE FROM channel_paths
1087            WHERE
1088                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1089            RETURNING id_path, channel_id
1090        "#;
1091
1092        let paths = channel_path::Entity::find()
1093            .from_raw_sql(Statement::from_sql_and_values(
1094                self.pool.get_database_backend(),
1095                sql,
1096                [from.to_proto().into(), channel.to_proto().into()],
1097            ))
1098            .all(&*tx)
1099            .await?;
1100
1101        let is_stranded = channel_path::Entity::find()
1102            .filter(channel_path::Column::ChannelId.eq(channel))
1103            .count(&*tx)
1104            .await?
1105            == 0;
1106
1107        // Make sure that there is always at least one path to the channel
1108        if is_stranded {
1109            let root_paths: Vec<_> = paths
1110                .iter()
1111                .map(|path| {
1112                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1113                    channel_path::ActiveModel {
1114                        channel_id: ActiveValue::Set(path.channel_id),
1115                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1116                    }
1117                })
1118                .collect();
1119            channel_path::Entity::insert_many(root_paths)
1120                .exec(&*tx)
1121                .await?;
1122        }
1123
1124        Ok(())
1125    }
1126
1127    /// Move a channel from one parent to another, returns the
1128    /// Channels that were moved for notifying clients
1129    pub async fn move_channel(
1130        &self,
1131        user: UserId,
1132        channel: ChannelId,
1133        from: ChannelId,
1134        to: ChannelId,
1135    ) -> Result<ChannelGraph> {
1136        if from == to {
1137            return Ok(ChannelGraph {
1138                channels: vec![],
1139                edges: vec![],
1140            });
1141        }
1142
1143        self.transaction(|tx| async move {
1144            self.check_user_is_channel_admin(channel, user, &*tx)
1145                .await?;
1146
1147            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1148
1149            self.unlink_channel_internal(user, channel, from, &*tx)
1150                .await?;
1151
1152            Ok(moved_channels)
1153        })
1154        .await
1155    }
1156}
1157
/// Column selector with a single `UserId` variant, derived via sea-orm's
/// `DeriveColumn`.
///
/// NOTE(review): its call sites are outside this chunk. Presumably it is used to
/// select bare `user_id` values from a query — confirm.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1162
/// A set of channels together with the parent/child edges between them.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels contained in this graph.
    pub channels: Vec<Channel>,
    /// Parent relationships, each linking a `channel_id` to its `parent_id`.
    pub edges: Vec<ChannelEdge>,
}
1168
1169impl ChannelGraph {
1170    pub fn is_empty(&self) -> bool {
1171        self.channels.is_empty() && self.edges.is_empty()
1172    }
1173}
1174
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Test-only equality that ignores ordering and duplicates.
    ///
    /// Channels are compared as sets. Edges are compared as sets of
    /// `(channel_id, parent_id)` pairs.
    fn eq(&self, other: &Self) -> bool {
        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();

        same_channels && edge_pairs(self) == edge_pairs(other)
    }
}
1195
1196#[cfg(not(test))]
1197impl PartialEq for ChannelGraph {
1198    fn eq(&self, other: &Self) -> bool {
1199        self.channels == other.channels && self.edges == other.edges
1200    }
1201}
1202
/// A sorted, duplicate-free collection backed by a `SmallVec` with inline
/// capacity for one element.
///
/// `insert` keeps the elements sorted and unique. They are exposed read-only as
/// a slice through `Deref<Target = [T]>`.
struct SmallSet<T>(SmallVec<[T; 1]>);
1204
1205impl<T> Deref for SmallSet<T> {
1206    type Target = [T];
1207
1208    fn deref(&self) -> &Self::Target {
1209        self.0.deref()
1210    }
1211}
1212
1213impl<T> Default for SmallSet<T> {
1214    fn default() -> Self {
1215        Self(SmallVec::new())
1216    }
1217}
1218
1219impl<T> SmallSet<T> {
1220    fn insert(&mut self, value: T) -> bool
1221    where
1222        T: Ord,
1223    {
1224        match self.binary_search(&value) {
1225            Ok(_) => false,
1226            Err(ix) => {
1227                self.0.insert(ix, value);
1228                true
1229            }
1230        }
1231    }
1232
1233    fn clear(&mut self) {
1234        self.0.clear();
1235    }
1236}