channels.rs

   1use super::*;
   2use rpc::proto::ChannelEdge;
   3use smallvec::SmallVec;
   4
/// Map from a channel id to the set of its direct parent channel ids, as built by
/// `get_channel_descendants`. A channel whose path has no segment before its own id
/// maps to an empty set.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
   6
   7impl Database {
   8    #[cfg(test)]
   9    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  10        self.transaction(move |tx| async move {
  11            let mut channels = Vec::new();
  12            let mut rows = channel::Entity::find().stream(&*tx).await?;
  13            while let Some(row) = rows.next().await {
  14                let row = row?;
  15                channels.push((row.id, row.name));
  16            }
  17            Ok(channels)
  18        })
  19        .await
  20    }
  21
  22    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  23        self.create_channel(name, None, creator_id).await
  24    }
  25
  26    pub async fn create_channel(
  27        &self,
  28        name: &str,
  29        parent: Option<ChannelId>,
  30        creator_id: UserId,
  31    ) -> Result<ChannelId> {
  32        let name = Self::sanitize_channel_name(name)?;
  33        self.transaction(move |tx| async move {
  34            if let Some(parent) = parent {
  35                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  36                    .await?;
  37            }
  38
  39            let channel = channel::ActiveModel {
  40                name: ActiveValue::Set(name.to_string()),
  41                ..Default::default()
  42            }
  43            .insert(&*tx)
  44            .await?;
  45
  46            if let Some(parent) = parent {
  47                let sql = r#"
  48                    INSERT INTO channel_paths
  49                    (id_path, channel_id)
  50                    SELECT
  51                        id_path || $1 || '/', $2
  52                    FROM
  53                        channel_paths
  54                    WHERE
  55                        channel_id = $3
  56                "#;
  57                let channel_paths_stmt = Statement::from_sql_and_values(
  58                    self.pool.get_database_backend(),
  59                    sql,
  60                    [
  61                        channel.id.to_proto().into(),
  62                        channel.id.to_proto().into(),
  63                        parent.to_proto().into(),
  64                    ],
  65                );
  66                tx.execute(channel_paths_stmt).await?;
  67            } else {
  68                channel_path::Entity::insert(channel_path::ActiveModel {
  69                    channel_id: ActiveValue::Set(channel.id),
  70                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  71                })
  72                .exec(&*tx)
  73                .await?;
  74            }
  75
  76            channel_member::ActiveModel {
  77                channel_id: ActiveValue::Set(channel.id),
  78                user_id: ActiveValue::Set(creator_id),
  79                accepted: ActiveValue::Set(true),
  80                admin: ActiveValue::Set(true),
  81                ..Default::default()
  82            }
  83            .insert(&*tx)
  84            .await?;
  85
  86            Ok(channel.id)
  87        })
  88        .await
  89    }
  90
  91    pub async fn delete_channel(
  92        &self,
  93        channel_id: ChannelId,
  94        user_id: UserId,
  95    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
  96        self.transaction(move |tx| async move {
  97            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
  98                .await?;
  99
 100            // Don't remove descendant channels that have additional parents.
 101            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
 102            {
 103                let mut channels_to_keep = channel_path::Entity::find()
 104                    .filter(
 105                        channel_path::Column::ChannelId
 106                            .is_in(
 107                                channels_to_remove
 108                                    .keys()
 109                                    .copied()
 110                                    .filter(|&id| id != channel_id),
 111                            )
 112                            .and(
 113                                channel_path::Column::IdPath
 114                                    .not_like(&format!("%/{}/%", channel_id)),
 115                            ),
 116                    )
 117                    .stream(&*tx)
 118                    .await?;
 119                while let Some(row) = channels_to_keep.next().await {
 120                    let row = row?;
 121                    channels_to_remove.remove(&row.channel_id);
 122                }
 123            }
 124
 125            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 126            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 127                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 128                .select_only()
 129                .column(channel_member::Column::UserId)
 130                .distinct()
 131                .into_values::<_, QueryUserIds>()
 132                .all(&*tx)
 133                .await?;
 134
 135            channel::Entity::delete_many()
 136                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
 137                .exec(&*tx)
 138                .await?;
 139
 140            // Delete any other paths that include this channel
 141            let sql = r#"
 142                    DELETE FROM channel_paths
 143                    WHERE
 144                        id_path LIKE '%' || $1 || '%'
 145                "#;
 146            let channel_paths_stmt = Statement::from_sql_and_values(
 147                self.pool.get_database_backend(),
 148                sql,
 149                [channel_id.to_proto().into()],
 150            );
 151            tx.execute(channel_paths_stmt).await?;
 152
 153            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
 154        })
 155        .await
 156    }
 157
 158    pub async fn invite_channel_member(
 159        &self,
 160        channel_id: ChannelId,
 161        invitee_id: UserId,
 162        inviter_id: UserId,
 163        is_admin: bool,
 164    ) -> Result<Option<proto::Notification>> {
 165        self.transaction(move |tx| async move {
 166            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 167                .await?;
 168
 169            channel_member::ActiveModel {
 170                channel_id: ActiveValue::Set(channel_id),
 171                user_id: ActiveValue::Set(invitee_id),
 172                accepted: ActiveValue::Set(false),
 173                admin: ActiveValue::Set(is_admin),
 174                ..Default::default()
 175            }
 176            .insert(&*tx)
 177            .await?;
 178
 179            self.create_notification(
 180                invitee_id,
 181                rpc::Notification::ChannelInvitation {
 182                    actor_id: inviter_id.to_proto(),
 183                    channel_id: channel_id.to_proto(),
 184                },
 185                true,
 186                &*tx,
 187            )
 188            .await
 189        })
 190        .await
 191    }
 192
 193    fn sanitize_channel_name(name: &str) -> Result<&str> {
 194        let new_name = name.trim().trim_start_matches('#');
 195        if new_name == "" {
 196            Err(anyhow!("channel name can't be blank"))?;
 197        }
 198        Ok(new_name)
 199    }
 200
 201    pub async fn rename_channel(
 202        &self,
 203        channel_id: ChannelId,
 204        user_id: UserId,
 205        new_name: &str,
 206    ) -> Result<String> {
 207        self.transaction(move |tx| async move {
 208            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 209
 210            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 211                .await?;
 212
 213            channel::ActiveModel {
 214                id: ActiveValue::Unchanged(channel_id),
 215                name: ActiveValue::Set(new_name.clone()),
 216                ..Default::default()
 217            }
 218            .update(&*tx)
 219            .await?;
 220
 221            Ok(new_name)
 222        })
 223        .await
 224    }
 225
 226    pub async fn respond_to_channel_invite(
 227        &self,
 228        channel_id: ChannelId,
 229        user_id: UserId,
 230        accept: bool,
 231    ) -> Result<()> {
 232        self.transaction(move |tx| async move {
 233            let rows_affected = if accept {
 234                channel_member::Entity::update_many()
 235                    .set(channel_member::ActiveModel {
 236                        accepted: ActiveValue::Set(accept),
 237                        ..Default::default()
 238                    })
 239                    .filter(
 240                        channel_member::Column::ChannelId
 241                            .eq(channel_id)
 242                            .and(channel_member::Column::UserId.eq(user_id))
 243                            .and(channel_member::Column::Accepted.eq(false)),
 244                    )
 245                    .exec(&*tx)
 246                    .await?
 247                    .rows_affected
 248            } else {
 249                channel_member::ActiveModel {
 250                    channel_id: ActiveValue::Unchanged(channel_id),
 251                    user_id: ActiveValue::Unchanged(user_id),
 252                    ..Default::default()
 253                }
 254                .delete(&*tx)
 255                .await?
 256                .rows_affected
 257            };
 258
 259            if rows_affected == 0 {
 260                Err(anyhow!("no such invitation"))?;
 261            }
 262
 263            Ok(())
 264        })
 265        .await
 266    }
 267
 268    pub async fn remove_channel_member(
 269        &self,
 270        channel_id: ChannelId,
 271        member_id: UserId,
 272        remover_id: UserId,
 273    ) -> Result<()> {
 274        self.transaction(|tx| async move {
 275            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
 276                .await?;
 277
 278            let result = channel_member::Entity::delete_many()
 279                .filter(
 280                    channel_member::Column::ChannelId
 281                        .eq(channel_id)
 282                        .and(channel_member::Column::UserId.eq(member_id)),
 283                )
 284                .exec(&*tx)
 285                .await?;
 286
 287            if result.rows_affected == 0 {
 288                Err(anyhow!("no such member"))?;
 289            }
 290
 291            Ok(())
 292        })
 293        .await
 294    }
 295
 296    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 297        self.transaction(|tx| async move {
 298            let channel_invites = channel_member::Entity::find()
 299                .filter(
 300                    channel_member::Column::UserId
 301                        .eq(user_id)
 302                        .and(channel_member::Column::Accepted.eq(false)),
 303                )
 304                .all(&*tx)
 305                .await?;
 306
 307            let channels = channel::Entity::find()
 308                .filter(
 309                    channel::Column::Id.is_in(
 310                        channel_invites
 311                            .into_iter()
 312                            .map(|channel_member| channel_member.channel_id),
 313                    ),
 314                )
 315                .all(&*tx)
 316                .await?;
 317
 318            let channels = channels
 319                .into_iter()
 320                .map(|channel| Channel {
 321                    id: channel.id,
 322                    name: channel.name,
 323                })
 324                .collect();
 325
 326            Ok(channels)
 327        })
 328        .await
 329    }
 330
 331    async fn get_channel_graph(
 332        &self,
 333        parents_by_child_id: ChannelDescendants,
 334        trim_dangling_parents: bool,
 335        tx: &DatabaseTransaction,
 336    ) -> Result<ChannelGraph> {
 337        let mut channels = Vec::with_capacity(parents_by_child_id.len());
 338        {
 339            let mut rows = channel::Entity::find()
 340                .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
 341                .stream(&*tx)
 342                .await?;
 343            while let Some(row) = rows.next().await {
 344                let row = row?;
 345                channels.push(Channel {
 346                    id: row.id,
 347                    name: row.name,
 348                })
 349            }
 350        }
 351
 352        let mut edges = Vec::with_capacity(parents_by_child_id.len());
 353        for (channel, parents) in parents_by_child_id.iter() {
 354            for parent in parents.into_iter() {
 355                if trim_dangling_parents {
 356                    if parents_by_child_id.contains_key(parent) {
 357                        edges.push(ChannelEdge {
 358                            channel_id: channel.to_proto(),
 359                            parent_id: parent.to_proto(),
 360                        });
 361                    }
 362                } else {
 363                    edges.push(ChannelEdge {
 364                        channel_id: channel.to_proto(),
 365                        parent_id: parent.to_proto(),
 366                    });
 367                }
 368            }
 369        }
 370
 371        Ok(ChannelGraph { channels, edges })
 372    }
 373
 374    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 375        self.transaction(|tx| async move {
 376            let tx = tx;
 377
 378            let channel_memberships = channel_member::Entity::find()
 379                .filter(
 380                    channel_member::Column::UserId
 381                        .eq(user_id)
 382                        .and(channel_member::Column::Accepted.eq(true)),
 383                )
 384                .all(&*tx)
 385                .await?;
 386
 387            self.get_user_channels(user_id, channel_memberships, &tx)
 388                .await
 389        })
 390        .await
 391    }
 392
 393    pub async fn get_channel_for_user(
 394        &self,
 395        channel_id: ChannelId,
 396        user_id: UserId,
 397    ) -> Result<ChannelsForUser> {
 398        self.transaction(|tx| async move {
 399            let tx = tx;
 400
 401            let channel_membership = channel_member::Entity::find()
 402                .filter(
 403                    channel_member::Column::UserId
 404                        .eq(user_id)
 405                        .and(channel_member::Column::ChannelId.eq(channel_id))
 406                        .and(channel_member::Column::Accepted.eq(true)),
 407                )
 408                .all(&*tx)
 409                .await?;
 410
 411            self.get_user_channels(user_id, channel_membership, &tx)
 412                .await
 413        })
 414        .await
 415    }
 416
 417    pub async fn get_user_channels(
 418        &self,
 419        user_id: UserId,
 420        channel_memberships: Vec<channel_member::Model>,
 421        tx: &DatabaseTransaction,
 422    ) -> Result<ChannelsForUser> {
 423        let parents_by_child_id = self
 424            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 425            .await?;
 426
 427        let channels_with_admin_privileges = channel_memberships
 428            .iter()
 429            .filter_map(|membership| membership.admin.then_some(membership.channel_id))
 430            .collect();
 431
 432        let graph = self
 433            .get_channel_graph(parents_by_child_id, true, &tx)
 434            .await?;
 435
 436        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 437        enum QueryUserIdsAndChannelIds {
 438            ChannelId,
 439            UserId,
 440        }
 441
 442        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 443        {
 444            let mut rows = room_participant::Entity::find()
 445                .inner_join(room::Entity)
 446                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
 447                .select_only()
 448                .column(room::Column::ChannelId)
 449                .column(room_participant::Column::UserId)
 450                .into_values::<_, QueryUserIdsAndChannelIds>()
 451                .stream(&*tx)
 452                .await?;
 453            while let Some(row) = rows.next().await {
 454                let row: (ChannelId, UserId) = row?;
 455                channel_participants.entry(row.0).or_default().push(row.1)
 456            }
 457        }
 458
 459        let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
 460        let channel_buffer_changes = self
 461            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
 462            .await?;
 463
 464        let unseen_messages = self
 465            .unseen_channel_messages(user_id, &channel_ids, &*tx)
 466            .await?;
 467
 468        Ok(ChannelsForUser {
 469            channels: graph,
 470            channel_participants,
 471            channels_with_admin_privileges,
 472            unseen_buffer_changes: channel_buffer_changes,
 473            channel_messages: unseen_messages,
 474        })
 475    }
 476
 477    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 478        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
 479            .await
 480    }
 481
 482    pub async fn set_channel_member_admin(
 483        &self,
 484        channel_id: ChannelId,
 485        from: UserId,
 486        for_user: UserId,
 487        admin: bool,
 488    ) -> Result<()> {
 489        self.transaction(|tx| async move {
 490            self.check_user_is_channel_admin(channel_id, from, &*tx)
 491                .await?;
 492
 493            let result = channel_member::Entity::update_many()
 494                .filter(
 495                    channel_member::Column::ChannelId
 496                        .eq(channel_id)
 497                        .and(channel_member::Column::UserId.eq(for_user)),
 498                )
 499                .set(channel_member::ActiveModel {
 500                    admin: ActiveValue::set(admin),
 501                    ..Default::default()
 502                })
 503                .exec(&*tx)
 504                .await?;
 505
 506            if result.rows_affected == 0 {
 507                Err(anyhow!("no such member"))?;
 508            }
 509
 510            Ok(())
 511        })
 512        .await
 513    }
 514
    /// Returns one `proto::ChannelMember` per user with a membership row on
    /// `channel_id` or on any of its ancestors, ordered by user id.
    ///
    /// `user_id` must be an admin of the channel (or of one of its ancestors).
    ///
    /// The entry's kind comes from the user's rows:
    /// - a row on `channel_id` itself gives `Member` if accepted, `Invitee` if pending;
    /// - an accepted row on an ancestor gives `AncestorMember`;
    /// - a pending row on an ancestor is skipped.
    ///
    /// When a user has several rows, the first kept row creates the entry. A later row
    /// on `channel_id` itself overwrites its kind and admin flag. Later ancestor rows
    /// are ignored.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column layout of each row selected below.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            // Contains `channel_id` itself as well as all of its ancestors.
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                // True for rows on `channel_id` itself, false for rows on an ancestor.
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                // Sorting by user id makes each user's rows adjacent, which the
                // de-duplication below relies on. NOTE(review): rows of the same user
                // have no defined relative order, so which ancestor row wins is
                // unspecified.
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // A pending invitation to an ancestor is not reported here.
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Same user as the previous entry: only a direct row may refine it.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
 586
 587    pub async fn get_channel_members_internal(
 588        &self,
 589        id: ChannelId,
 590        tx: &DatabaseTransaction,
 591    ) -> Result<Vec<UserId>> {
 592        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 593        let user_ids = channel_member::Entity::find()
 594            .distinct()
 595            .filter(
 596                channel_member::Column::ChannelId
 597                    .is_in(ancestor_ids.iter().copied())
 598                    .and(channel_member::Column::Accepted.eq(true)),
 599            )
 600            .select_only()
 601            .column(channel_member::Column::UserId)
 602            .into_values::<_, QueryUserIds>()
 603            .all(&*tx)
 604            .await?;
 605        Ok(user_ids)
 606    }
 607
 608    pub async fn check_user_is_channel_member(
 609        &self,
 610        channel_id: ChannelId,
 611        user_id: UserId,
 612        tx: &DatabaseTransaction,
 613    ) -> Result<()> {
 614        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 615        channel_member::Entity::find()
 616            .filter(
 617                channel_member::Column::ChannelId
 618                    .is_in(channel_ids)
 619                    .and(channel_member::Column::UserId.eq(user_id)),
 620            )
 621            .one(&*tx)
 622            .await?
 623            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
 624        Ok(())
 625    }
 626
 627    pub async fn check_user_is_channel_admin(
 628        &self,
 629        channel_id: ChannelId,
 630        user_id: UserId,
 631        tx: &DatabaseTransaction,
 632    ) -> Result<()> {
 633        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 634        channel_member::Entity::find()
 635            .filter(
 636                channel_member::Column::ChannelId
 637                    .is_in(channel_ids)
 638                    .and(channel_member::Column::UserId.eq(user_id))
 639                    .and(channel_member::Column::Admin.eq(true)),
 640            )
 641            .one(&*tx)
 642            .await?
 643            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
 644        Ok(())
 645    }
 646
 647    /// Returns the channel ancestors, deepest first
 648    pub async fn get_channel_ancestors(
 649        &self,
 650        channel_id: ChannelId,
 651        tx: &DatabaseTransaction,
 652    ) -> Result<Vec<ChannelId>> {
 653        let paths = channel_path::Entity::find()
 654            .filter(channel_path::Column::ChannelId.eq(channel_id))
 655            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 656            .all(tx)
 657            .await?;
 658        let mut channel_ids = Vec::new();
 659        for path in paths {
 660            for id in path.id_path.trim_matches('/').split('/') {
 661                if let Ok(id) = id.parse() {
 662                    let id = ChannelId::from_proto(id);
 663                    if let Err(ix) = channel_ids.binary_search(&id) {
 664                        channel_ids.insert(ix, id);
 665                    }
 666                }
 667            }
 668        }
 669        Ok(channel_ids)
 670    }
 671
    /// Returns the given channels together with all of their descendants,
    /// structured as a map from each channel id to the ids of its direct parents.
    ///
    /// Only paths that extend (or equal) a path of one of the given channels are
    /// consulted. A descendant's parents that are reachable solely through some other
    /// route are therefore not included. A given channel that is not a root maps to
    /// its own parent(s), which may not be keys of the map (see
    /// `trim_dangling_parents` in `get_channel_graph`).
    ///
    /// For example, the descendants of the root channel 'a' in this DAG:
    ///
    ///   /- b -\
    /// a -- c -- d
    ///
    /// would be:
    /// {
    ///     a: [],
    ///     b: [a],
    ///     c: [a],
    ///     d: [b, c],
    /// }
    ///
    /// An empty `channel_ids` yields an empty map without querying the database.
    async fn get_channel_descendants(
        &self,
        channel_ids: impl IntoIterator<Item = ChannelId>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelDescendants> {
        // Render the ids as a row-value list such as `(1), (2)`. Every element is a
        // `ChannelId` rather than user-supplied text, which is why it is interpolated
        // directly into the SQL below.
        let mut values = String::new();
        for id in channel_ids {
            if !values.is_empty() {
                values.push_str(", ");
            }
            write!(&mut values, "({})", id).unwrap();
        }

        if values.is_empty() {
            return Ok(HashMap::default());
        }

        // Select every path that starts with a path of one of the requested channels
        // (including those paths themselves). id_paths end with `/`, so `/1/%` does
        // not match `/12/`.
        let sql = format!(
            r#"
            SELECT
                descendant_paths.*
            FROM
                channel_paths parent_paths, channel_paths descendant_paths
            WHERE
                parent_paths.channel_id IN ({values}) AND
                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
        "#
        );

        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);

        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
        let mut paths = channel_path::Entity::find()
            .from_raw_sql(stmt)
            .stream(tx)
            .await?;

        while let Some(path) = paths.next().await {
            let path = path?;
            let ids = path.id_path.trim_matches('/').split('/');
            // The direct parent on this path is the id immediately preceding the
            // channel's own id; a path that begins with the channel has no parent.
            let mut parent_id = None;
            for id in ids {
                if let Ok(id) = id.parse() {
                    let id = ChannelId::from_proto(id);
                    if id == path.channel_id {
                        break;
                    }
                    parent_id = Some(id);
                }
            }
            // Create the entry even when there is no parent, so every channel appears.
            let entry = parents_by_child_id.entry(path.channel_id).or_default();
            if let Some(parent_id) = parent_id {
                entry.insert(parent_id);
            }
        }

        Ok(parents_by_child_id)
    }
 744
 745    /// Returns the channel with the given ID and:
 746    /// - true if the user is a member
 747    /// - false if the user hasn't accepted the invitation yet
 748    pub async fn get_channel(
 749        &self,
 750        channel_id: ChannelId,
 751        user_id: UserId,
 752    ) -> Result<Option<(Channel, bool)>> {
 753        self.transaction(|tx| async move {
 754            let tx = tx;
 755
 756            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
 757
 758            if let Some(channel) = channel {
 759                if self
 760                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 761                    .await
 762                    .is_err()
 763                {
 764                    return Ok(None);
 765                }
 766
 767                let channel_membership = channel_member::Entity::find()
 768                    .filter(
 769                        channel_member::Column::ChannelId
 770                            .eq(channel_id)
 771                            .and(channel_member::Column::UserId.eq(user_id)),
 772                    )
 773                    .one(&*tx)
 774                    .await?;
 775
 776                let is_accepted = channel_membership
 777                    .map(|membership| membership.accepted)
 778                    .unwrap_or(false);
 779
 780                Ok(Some((
 781                    Channel {
 782                        id: channel.id,
 783                        name: channel.name,
 784                    },
 785                    is_accepted,
 786                )))
 787            } else {
 788                Ok(None)
 789            }
 790        })
 791        .await
 792    }
 793
 794    pub async fn get_or_create_channel_room(
 795        &self,
 796        channel_id: ChannelId,
 797        live_kit_room: &str,
 798        enviroment: &str,
 799    ) -> Result<RoomId> {
 800        self.transaction(|tx| async move {
 801            let tx = tx;
 802
 803            let room = room::Entity::find()
 804                .filter(room::Column::ChannelId.eq(channel_id))
 805                .one(&*tx)
 806                .await?;
 807
 808            let room_id = if let Some(room) = room {
 809                room.id
 810            } else {
 811                let result = room::Entity::insert(room::ActiveModel {
 812                    channel_id: ActiveValue::Set(Some(channel_id)),
 813                    live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
 814                    enviroment: ActiveValue::Set(Some(enviroment.to_string())),
 815                    ..Default::default()
 816                })
 817                .exec(&*tx)
 818                .await?;
 819
 820                result.last_insert_id
 821            };
 822
 823            Ok(room_id)
 824        })
 825        .await
 826    }
 827
 828    // Insert an edge from the given channel to the given other channel.
 829    pub async fn link_channel(
 830        &self,
 831        user: UserId,
 832        channel: ChannelId,
 833        to: ChannelId,
 834    ) -> Result<ChannelGraph> {
 835        self.transaction(|tx| async move {
 836            // Note that even with these maxed permissions, this linking operation
 837            // is still insecure because you can't remove someone's permissions to a
 838            // channel if they've linked the channel to one where they're an admin.
 839            self.check_user_is_channel_admin(channel, user, &*tx)
 840                .await?;
 841
 842            self.link_channel_internal(user, channel, to, &*tx).await
 843        })
 844        .await
 845    }
 846
 847    pub async fn link_channel_internal(
 848        &self,
 849        user: UserId,
 850        channel: ChannelId,
 851        to: ChannelId,
 852        tx: &DatabaseTransaction,
 853    ) -> Result<ChannelGraph> {
 854        self.check_user_is_channel_admin(to, user, &*tx).await?;
 855
 856        let paths = channel_path::Entity::find()
 857            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
 858            .all(tx)
 859            .await?;
 860
 861        let mut new_path_suffixes = HashSet::default();
 862        for path in paths {
 863            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
 864                new_path_suffixes.insert((
 865                    path.channel_id,
 866                    path.id_path[(start_offset + 1)..].to_string(),
 867                ));
 868            }
 869        }
 870
 871        let paths_to_new_parent = channel_path::Entity::find()
 872            .filter(channel_path::Column::ChannelId.eq(to))
 873            .all(tx)
 874            .await?;
 875
 876        let mut new_paths = Vec::new();
 877        for path in paths_to_new_parent {
 878            if path.id_path.contains(&format!("/{}/", channel)) {
 879                Err(anyhow!("cycle"))?;
 880            }
 881
 882            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
 883                channel_path::ActiveModel {
 884                    channel_id: ActiveValue::Set(*channel_id),
 885                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
 886                }
 887            }));
 888        }
 889
 890        channel_path::Entity::insert_many(new_paths)
 891            .exec(&*tx)
 892            .await?;
 893
 894        // remove any root edges for the channel we just linked
 895        {
 896            channel_path::Entity::delete_many()
 897                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
 898                .exec(&*tx)
 899                .await?;
 900        }
 901
 902        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
 903        if let Some(channel) = channel_descendants.get_mut(&channel) {
 904            // Remove the other parents
 905            channel.clear();
 906            channel.insert(to);
 907        }
 908
 909        let channels = self
 910            .get_channel_graph(channel_descendants, false, &*tx)
 911            .await?;
 912
 913        Ok(channels)
 914    }
 915
 916    /// Unlink a channel from a given parent. This will add in a root edge if
 917    /// the channel has no other parents after this operation.
 918    pub async fn unlink_channel(
 919        &self,
 920        user: UserId,
 921        channel: ChannelId,
 922        from: ChannelId,
 923    ) -> Result<()> {
 924        self.transaction(|tx| async move {
 925            // Note that even with these maxed permissions, this linking operation
 926            // is still insecure because you can't remove someone's permissions to a
 927            // channel if they've linked the channel to one where they're an admin.
 928            self.check_user_is_channel_admin(channel, user, &*tx)
 929                .await?;
 930
 931            self.unlink_channel_internal(user, channel, from, &*tx)
 932                .await?;
 933
 934            Ok(())
 935        })
 936        .await
 937    }
 938
 939    pub async fn unlink_channel_internal(
 940        &self,
 941        user: UserId,
 942        channel: ChannelId,
 943        from: ChannelId,
 944        tx: &DatabaseTransaction,
 945    ) -> Result<()> {
 946        self.check_user_is_channel_admin(from, user, &*tx).await?;
 947
 948        let sql = r#"
 949            DELETE FROM channel_paths
 950            WHERE
 951                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
 952            RETURNING id_path, channel_id
 953        "#;
 954
 955        let paths = channel_path::Entity::find()
 956            .from_raw_sql(Statement::from_sql_and_values(
 957                self.pool.get_database_backend(),
 958                sql,
 959                [from.to_proto().into(), channel.to_proto().into()],
 960            ))
 961            .all(&*tx)
 962            .await?;
 963
 964        let is_stranded = channel_path::Entity::find()
 965            .filter(channel_path::Column::ChannelId.eq(channel))
 966            .count(&*tx)
 967            .await?
 968            == 0;
 969
 970        // Make sure that there is always at least one path to the channel
 971        if is_stranded {
 972            let root_paths: Vec<_> = paths
 973                .iter()
 974                .map(|path| {
 975                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
 976                    channel_path::ActiveModel {
 977                        channel_id: ActiveValue::Set(path.channel_id),
 978                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
 979                    }
 980                })
 981                .collect();
 982            channel_path::Entity::insert_many(root_paths)
 983                .exec(&*tx)
 984                .await?;
 985        }
 986
 987        Ok(())
 988    }
 989
 990    /// Move a channel from one parent to another, returns the
 991    /// Channels that were moved for notifying clients
 992    pub async fn move_channel(
 993        &self,
 994        user: UserId,
 995        channel: ChannelId,
 996        from: ChannelId,
 997        to: ChannelId,
 998    ) -> Result<ChannelGraph> {
 999        if from == to {
1000            return Ok(ChannelGraph {
1001                channels: vec![],
1002                edges: vec![],
1003            });
1004        }
1005
1006        self.transaction(|tx| async move {
1007            self.check_user_is_channel_admin(channel, user, &*tx)
1008                .await?;
1009
1010            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1011
1012            self.unlink_channel_internal(user, channel, from, &*tx)
1013                .await?;
1014
1015            Ok(moved_channels)
1016        })
1017        .await
1018    }
1019}
1020
/// Single-column selector for queries that read back just a user id per row
/// (presumably via `into_values::<_, QueryUserIds>()`; call sites are outside
/// this chunk).
///
/// NOTE(review): `DeriveColumn` presumably derives the column alias from the
/// variant name (`UserId` -> `user_id`), so renaming the variant would change
/// the generated SQL — confirm before touching it.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1025
/// A set of channels together with the parent/child edges between them.
///
/// Returned by `link_channel`, `link_channel_internal` and `move_channel`.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels contained in this graph.
    pub channels: Vec<Channel>,
    /// Edges between channels; each names a `channel_id` and its `parent_id`.
    pub edges: Vec<ChannelEdge>,
}
1031
1032impl ChannelGraph {
1033    pub fn is_empty(&self) -> bool {
1034        self.channels.is_empty() && self.edges.is_empty()
1035    }
1036}
1037
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Test-only equality: channels and edges are compared as sets, so ordering
    /// (and repetition) does not matter.
    fn eq(&self, other: &Self) -> bool {
        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();

        same_channels && edge_pairs(self) == edge_pairs(other)
    }
}
1058
1059#[cfg(not(test))]
1060impl PartialEq for ChannelGraph {
1061    fn eq(&self, other: &Self) -> bool {
1062        self.channels == other.channels && self.edges == other.edges
1063    }
1064}
1065
1066struct SmallSet<T>(SmallVec<[T; 1]>);
1067
1068impl<T> Deref for SmallSet<T> {
1069    type Target = [T];
1070
1071    fn deref(&self) -> &Self::Target {
1072        self.0.deref()
1073    }
1074}
1075
1076impl<T> Default for SmallSet<T> {
1077    fn default() -> Self {
1078        Self(SmallVec::new())
1079    }
1080}
1081
1082impl<T> SmallSet<T> {
1083    fn insert(&mut self, value: T) -> bool
1084    where
1085        T: Ord,
1086    {
1087        match self.binary_search(&value) {
1088            Ok(_) => false,
1089            Err(ix) => {
1090                self.0.insert(ix, value);
1091                true
1092            }
1093        }
1094    }
1095
1096    fn clear(&mut self) {
1097        self.0.clear();
1098    }
1099}