channels.rs

   1use super::*;
   2use rpc::proto::{channel_member::Kind, ChannelEdge};
   3
   4impl Database {
   5    #[cfg(test)]
   6    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
   7        self.transaction(move |tx| async move {
   8            let mut channels = Vec::new();
   9            let mut rows = channel::Entity::find().stream(&*tx).await?;
  10            while let Some(row) = rows.next().await {
  11                let row = row?;
  12                channels.push((row.id, row.name));
  13            }
  14            Ok(channels)
  15        })
  16        .await
  17    }
  18
  19    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  20        self.create_channel(name, None, creator_id).await
  21    }
  22
  23    pub async fn create_channel(
  24        &self,
  25        name: &str,
  26        parent: Option<ChannelId>,
  27        creator_id: UserId,
  28    ) -> Result<ChannelId> {
  29        let name = Self::sanitize_channel_name(name)?;
  30        self.transaction(move |tx| async move {
  31            if let Some(parent) = parent {
  32                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  33                    .await?;
  34            }
  35
  36            let channel = channel::ActiveModel {
  37                id: ActiveValue::NotSet,
  38                name: ActiveValue::Set(name.to_string()),
  39                visibility: ActiveValue::Set(ChannelVisibility::Members),
  40            }
  41            .insert(&*tx)
  42            .await?;
  43
  44            if let Some(parent) = parent {
  45                let sql = r#"
  46                    INSERT INTO channel_paths
  47                    (id_path, channel_id)
  48                    SELECT
  49                        id_path || $1 || '/', $2
  50                    FROM
  51                        channel_paths
  52                    WHERE
  53                        channel_id = $3
  54                "#;
  55                let channel_paths_stmt = Statement::from_sql_and_values(
  56                    self.pool.get_database_backend(),
  57                    sql,
  58                    [
  59                        channel.id.to_proto().into(),
  60                        channel.id.to_proto().into(),
  61                        parent.to_proto().into(),
  62                    ],
  63                );
  64                tx.execute(channel_paths_stmt).await?;
  65            } else {
  66                channel_path::Entity::insert(channel_path::ActiveModel {
  67                    channel_id: ActiveValue::Set(channel.id),
  68                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  69                })
  70                .exec(&*tx)
  71                .await?;
  72            }
  73
  74            channel_member::ActiveModel {
  75                id: ActiveValue::NotSet,
  76                channel_id: ActiveValue::Set(channel.id),
  77                user_id: ActiveValue::Set(creator_id),
  78                accepted: ActiveValue::Set(true),
  79                role: ActiveValue::Set(ChannelRole::Admin),
  80            }
  81            .insert(&*tx)
  82            .await?;
  83
  84            Ok(channel.id)
  85        })
  86        .await
  87    }
  88
    /// Joins `user_id` (on `connection`) to the room of `channel_id`.
    ///
    /// If the user has no role in the channel yet, this first tries to accept a
    /// pending invitation found by `pending_invite_for_channel` (which may be
    /// for an ancestor channel). Failing that, if the channel is `Public`, it
    /// inserts an accepted membership on the channel returned by
    /// `most_public_ancestor_for_channel` (or on `channel_id` itself when that
    /// returns `None`). The room is then fetched or created with
    /// `get_or_create_channel_room` and joined.
    ///
    /// Returns the join result and the id of the channel on which a membership
    /// was accepted or created during this call. That id is `None` if neither
    /// happened.
    ///
    /// # Errors
    ///
    /// Fails with "no such channel, or not allowed" if the channel does not
    /// exist, the user still has no role, or the user's role is `Banned`.
    pub async fn join_channel(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        connection: ConnectionId,
        environment: &str,
    ) -> Result<(JoinRoom, Option<ChannelId>)> {
        self.transaction(move |tx| async move {
            // Set only when this call accepts an invite or creates a membership.
            let mut joined_channel_id = None;

            let channel = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?;

            let mut role = self
                .channel_role_for_user(channel_id, user_id, &*tx)
                .await?;

            // No role yet: try to accept a pending invitation.
            if role.is_none() && channel.is_some() {
                if let Some(invitation) = self
                    .pending_invite_for_channel(channel_id, user_id, &*tx)
                    .await?
                {
                    // note, this may be a parent channel
                    joined_channel_id = Some(invitation.channel_id);
                    role = Some(invitation.role);

                    channel_member::Entity::update(channel_member::ActiveModel {
                        accepted: ActiveValue::Set(true),
                        ..invitation.into_active_model()
                    })
                    .exec(&*tx)
                    .await?;

                    // Sanity check (debug builds only): the role recomputed from
                    // the database matches the invitation's role.
                    debug_assert!(
                        self.channel_role_for_user(channel_id, user_id, &*tx)
                            .await?
                            == role
                    );
                }
            }
            // Still no role: public channels let anyone in by creating an
            // accepted membership on the most public ancestor (or the channel).
            if role.is_none()
                && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
            {
                let channel_id_to_join = self
                    .most_public_ancestor_for_channel(channel_id, &*tx)
                    .await?
                    .unwrap_or(channel_id);
                // TODO: change this back to Guest.
                role = Some(ChannelRole::Member);
                joined_channel_id = Some(channel_id_to_join);

                channel_member::Entity::insert(channel_member::ActiveModel {
                    id: ActiveValue::NotSet,
                    channel_id: ActiveValue::Set(channel_id_to_join),
                    user_id: ActiveValue::Set(user_id),
                    accepted: ActiveValue::Set(true),
                    // TODO: change this back to Guest.
                    role: ActiveValue::Set(ChannelRole::Member),
                })
                .exec(&*tx)
                .await?;

                debug_assert!(
                    self.channel_role_for_user(channel_id, user_id, &*tx)
                        .await?
                        == role
                );
            }

            if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
                Err(anyhow!("no such channel, or not allowed"))?
            }

            // A fresh LiveKit room name is generated on every call.
            // NOTE(review): presumably `get_or_create_channel_room` only uses it
            // when it has to create the room — TODO confirm.
            let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
            let room_id = self
                .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
                .await?;

            self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
                .await
                .map(|jr| (jr, joined_channel_id))
        })
        .await
    }
 175
 176    pub async fn set_channel_visibility(
 177        &self,
 178        channel_id: ChannelId,
 179        visibility: ChannelVisibility,
 180        user_id: UserId,
 181    ) -> Result<channel::Model> {
 182        self.transaction(move |tx| async move {
 183            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 184                .await?;
 185
 186            let channel = channel::ActiveModel {
 187                id: ActiveValue::Unchanged(channel_id),
 188                visibility: ActiveValue::Set(visibility),
 189                ..Default::default()
 190            }
 191            .update(&*tx)
 192            .await?;
 193
 194            Ok(channel)
 195        })
 196        .await
 197    }
 198
 199    pub async fn delete_channel(
 200        &self,
 201        channel_id: ChannelId,
 202        user_id: UserId,
 203    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 204        self.transaction(move |tx| async move {
 205            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 206                .await?;
 207
 208            // Don't remove descendant channels that have additional parents.
 209            let mut channels_to_remove: HashSet<ChannelId> = HashSet::default();
 210            channels_to_remove.insert(channel_id);
 211
 212            let graph = self.get_channel_descendants([channel_id], &*tx).await?;
 213            for edge in graph.iter() {
 214                channels_to_remove.insert(ChannelId::from_proto(edge.channel_id));
 215            }
 216
 217            {
 218                let mut channels_to_keep = channel_path::Entity::find()
 219                    .filter(
 220                        channel_path::Column::ChannelId
 221                            .is_in(channels_to_remove.iter().copied())
 222                            .and(
 223                                channel_path::Column::IdPath
 224                                    .not_like(&format!("%/{}/%", channel_id)),
 225                            ),
 226                    )
 227                    .stream(&*tx)
 228                    .await?;
 229                while let Some(row) = channels_to_keep.next().await {
 230                    let row = row?;
 231                    channels_to_remove.remove(&row.channel_id);
 232                }
 233            }
 234
 235            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 236            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 237                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 238                .select_only()
 239                .column(channel_member::Column::UserId)
 240                .distinct()
 241                .into_values::<_, QueryUserIds>()
 242                .all(&*tx)
 243                .await?;
 244
 245            channel::Entity::delete_many()
 246                .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied()))
 247                .exec(&*tx)
 248                .await?;
 249
 250            // Delete any other paths that include this channel
 251            let sql = r#"
 252                    DELETE FROM channel_paths
 253                    WHERE
 254                        id_path LIKE '%' || $1 || '%'
 255                "#;
 256            let channel_paths_stmt = Statement::from_sql_and_values(
 257                self.pool.get_database_backend(),
 258                sql,
 259                [channel_id.to_proto().into()],
 260            );
 261            tx.execute(channel_paths_stmt).await?;
 262
 263            Ok((channels_to_remove.into_iter().collect(), members_to_notify))
 264        })
 265        .await
 266    }
 267
 268    pub async fn invite_channel_member(
 269        &self,
 270        channel_id: ChannelId,
 271        invitee_id: UserId,
 272        admin_id: UserId,
 273        role: ChannelRole,
 274    ) -> Result<()> {
 275        self.transaction(move |tx| async move {
 276            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 277                .await?;
 278
 279            channel_member::ActiveModel {
 280                id: ActiveValue::NotSet,
 281                channel_id: ActiveValue::Set(channel_id),
 282                user_id: ActiveValue::Set(invitee_id),
 283                accepted: ActiveValue::Set(false),
 284                role: ActiveValue::Set(role),
 285            }
 286            .insert(&*tx)
 287            .await?;
 288
 289            Ok(())
 290        })
 291        .await
 292    }
 293
 294    fn sanitize_channel_name(name: &str) -> Result<&str> {
 295        let new_name = name.trim().trim_start_matches('#');
 296        if new_name == "" {
 297            Err(anyhow!("channel name can't be blank"))?;
 298        }
 299        Ok(new_name)
 300    }
 301
 302    pub async fn rename_channel(
 303        &self,
 304        channel_id: ChannelId,
 305        user_id: UserId,
 306        new_name: &str,
 307    ) -> Result<Channel> {
 308        self.transaction(move |tx| async move {
 309            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 310
 311            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 312                .await?;
 313
 314            let channel = channel::ActiveModel {
 315                id: ActiveValue::Unchanged(channel_id),
 316                name: ActiveValue::Set(new_name.clone()),
 317                ..Default::default()
 318            }
 319            .update(&*tx)
 320            .await?;
 321
 322            Ok(Channel {
 323                id: channel.id,
 324                name: channel.name,
 325                visibility: channel.visibility,
 326            })
 327        })
 328        .await
 329    }
 330
 331    pub async fn respond_to_channel_invite(
 332        &self,
 333        channel_id: ChannelId,
 334        user_id: UserId,
 335        accept: bool,
 336    ) -> Result<()> {
 337        self.transaction(move |tx| async move {
 338            let rows_affected = if accept {
 339                channel_member::Entity::update_many()
 340                    .set(channel_member::ActiveModel {
 341                        accepted: ActiveValue::Set(accept),
 342                        ..Default::default()
 343                    })
 344                    .filter(
 345                        channel_member::Column::ChannelId
 346                            .eq(channel_id)
 347                            .and(channel_member::Column::UserId.eq(user_id))
 348                            .and(channel_member::Column::Accepted.eq(false)),
 349                    )
 350                    .exec(&*tx)
 351                    .await?
 352                    .rows_affected
 353            } else {
 354                channel_member::ActiveModel {
 355                    channel_id: ActiveValue::Unchanged(channel_id),
 356                    user_id: ActiveValue::Unchanged(user_id),
 357                    ..Default::default()
 358                }
 359                .delete(&*tx)
 360                .await?
 361                .rows_affected
 362            };
 363
 364            if rows_affected == 0 {
 365                Err(anyhow!("no such invitation"))?;
 366            }
 367
 368            Ok(())
 369        })
 370        .await
 371    }
 372
 373    pub async fn remove_channel_member(
 374        &self,
 375        channel_id: ChannelId,
 376        member_id: UserId,
 377        admin_id: UserId,
 378    ) -> Result<()> {
 379        self.transaction(|tx| async move {
 380            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 381                .await?;
 382
 383            let result = channel_member::Entity::delete_many()
 384                .filter(
 385                    channel_member::Column::ChannelId
 386                        .eq(channel_id)
 387                        .and(channel_member::Column::UserId.eq(member_id)),
 388                )
 389                .exec(&*tx)
 390                .await?;
 391
 392            if result.rows_affected == 0 {
 393                Err(anyhow!("no such member"))?;
 394            }
 395
 396            Ok(())
 397        })
 398        .await
 399    }
 400
 401    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 402        self.transaction(|tx| async move {
 403            let channel_invites = channel_member::Entity::find()
 404                .filter(
 405                    channel_member::Column::UserId
 406                        .eq(user_id)
 407                        .and(channel_member::Column::Accepted.eq(false)),
 408                )
 409                .all(&*tx)
 410                .await?;
 411
 412            let channels = channel::Entity::find()
 413                .filter(
 414                    channel::Column::Id.is_in(
 415                        channel_invites
 416                            .into_iter()
 417                            .map(|channel_member| channel_member.channel_id),
 418                    ),
 419                )
 420                .all(&*tx)
 421                .await?;
 422
 423            let channels = channels
 424                .into_iter()
 425                .map(|channel| Channel {
 426                    id: channel.id,
 427                    name: channel.name,
 428                    visibility: channel.visibility,
 429                })
 430                .collect();
 431
 432            Ok(channels)
 433        })
 434        .await
 435    }
 436
 437    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 438        self.transaction(|tx| async move {
 439            let tx = tx;
 440
 441            let channel_memberships = channel_member::Entity::find()
 442                .filter(
 443                    channel_member::Column::UserId
 444                        .eq(user_id)
 445                        .and(channel_member::Column::Accepted.eq(true)),
 446                )
 447                .all(&*tx)
 448                .await?;
 449
 450            self.get_user_channels(user_id, channel_memberships, &tx)
 451                .await
 452        })
 453        .await
 454    }
 455
 456    pub async fn get_channel_for_user(
 457        &self,
 458        channel_id: ChannelId,
 459        user_id: UserId,
 460    ) -> Result<ChannelsForUser> {
 461        self.transaction(|tx| async move {
 462            let tx = tx;
 463
 464            let channel_membership = channel_member::Entity::find()
 465                .filter(
 466                    channel_member::Column::UserId
 467                        .eq(user_id)
 468                        .and(channel_member::Column::ChannelId.eq(channel_id))
 469                        .and(channel_member::Column::Accepted.eq(true)),
 470                )
 471                .all(&*tx)
 472                .await?;
 473
 474            self.get_user_channels(user_id, channel_membership, &tx)
 475                .await
 476        })
 477        .await
 478    }
 479
    /// Builds the [`ChannelsForUser`] view for `user_id` from
    /// `channel_memberships`. The callers in this file pass only accepted
    /// memberships.
    ///
    /// Steps:
    /// 1. Fetch the descendant edges of every membership channel. Propagate each
    ///    parent's role to its child unless the child already has a role that
    ///    `should_override` the parent's.
    /// 2. Load those channels. Drop ones where the user is `Banned`, or is a
    ///    `Guest` of a non-`Public` channel. Record the channels where the user
    ///    is `Admin`.
    /// 3. Re-attach edges around dropped channels to the nearest surviving
    ///    ancestor.
    /// 4. Collect room participants, unseen buffer changes and unseen messages
    ///    for the remaining channels.
    pub async fn get_user_channels(
        &self,
        user_id: UserId,
        channel_memberships: Vec<channel_member::Model>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelsForUser> {
        let mut edges = self
            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
            .await?;

        // Seed with the user's direct roles.
        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();

        for membership in channel_memberships.iter() {
            role_for_channel.insert(membership.channel_id, membership.role);
        }

        // Inherit roles down the tree.
        // NOTE(review): the indexing below panics if a parent's role was not
        // recorded before its edge is visited. This relies on
        // `get_channel_descendants` yielding edges parents-first — TODO confirm.
        for ChannelEdge {
            parent_id,
            channel_id,
        } in edges.iter()
        {
            let parent_id = ChannelId::from_proto(*parent_id);
            let channel_id = ChannelId::from_proto(*channel_id);
            debug_assert!(role_for_channel.get(&parent_id).is_some());
            let parent_role = role_for_channel[&parent_id];
            if let Some(existing_role) = role_for_channel.get(&channel_id) {
                if existing_role.should_override(parent_role) {
                    continue;
                }
            }
            role_for_channel.insert(channel_id, parent_role);
        }

        let mut channels: Vec<Channel> = Vec::new();
        let mut channels_with_admin_privileges: HashSet<ChannelId> = HashSet::default();
        // Ids as `u64`, to compare directly against the proto ids in `edges`.
        let mut channels_to_remove: HashSet<u64> = HashSet::default();

        let mut rows = channel::Entity::find()
            .filter(channel::Column::Id.is_in(role_for_channel.keys().copied()))
            .stream(&*tx)
            .await?;

        while let Some(row) = rows.next().await {
            let channel = row?;
            let role = role_for_channel[&channel.id];

            // Hide channels the user is banned from, and non-public channels
            // reached only with a guest role.
            if role == ChannelRole::Banned
                || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
            {
                channels_to_remove.insert(channel.id.0 as u64);
                continue;
            }

            channels.push(Channel {
                id: channel.id,
                name: channel.name,
                visibility: channel.visibility,
            });

            if role == ChannelRole::Admin {
                channels_with_admin_privileges.insert(channel.id);
            }
        }
        // Explicitly drop the result stream before the queries below. This is
        // presumably so it stops borrowing `tx` — TODO confirm.
        drop(rows);

        if !channels_to_remove.is_empty() {
            // Note: this code assumes each channel has one parent. If there are
            // multiple valid public paths to a channel (* indicating public), e.g.
            // both of these:
            // - zed* -> projects -> vim*
            // - zed* -> conrad -> public-projects* -> vim*
            // users would only see one of them (based on edge sort order).
            //
            // Map each removed channel to its (last-seen) parent.
            let mut replacement_parent: HashMap<u64, u64> = HashMap::default();
            for ChannelEdge {
                parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    replacement_parent.insert(*channel_id, *parent_id);
                }
            }

            // Handle each edge as follows:
            // - drop edges into removed channels;
            // - walk the parent of every other edge up past removed channels;
            // - drop the edge if the walk runs out of replacements.
            let mut new_edges: Vec<ChannelEdge> = Vec::new();
            'outer: for ChannelEdge {
                mut parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    continue;
                }
                while channels_to_remove.contains(&parent_id) {
                    if let Some(new_parent_id) = replacement_parent.get(&parent_id) {
                        parent_id = *new_parent_id;
                    } else {
                        continue 'outer;
                    }
                }
                new_edges.push(ChannelEdge {
                    parent_id,
                    channel_id: *channel_id,
                })
            }
            edges = new_edges;
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryUserIdsAndChannelIds {
            ChannelId,
            UserId,
        }

        // Users currently in each visible channel's room.
        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
        {
            let mut rows = room_participant::Entity::find()
                .inner_join(room::Entity)
                .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
                .select_only()
                .column(room::Column::ChannelId)
                .column(room_participant::Column::UserId)
                .into_values::<_, QueryUserIdsAndChannelIds>()
                .stream(&*tx)
                .await?;
            while let Some(row) = rows.next().await {
                let row: (ChannelId, UserId) = row?;
                channel_participants.entry(row.0).or_default().push(row.1)
            }
        }

        let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
        let channel_buffer_changes = self
            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
            .await?;

        let unseen_messages = self
            .unseen_channel_messages(user_id, &channel_ids, &*tx)
            .await?;

        Ok(ChannelsForUser {
            channels: ChannelGraph { channels, edges },
            channel_participants,
            channels_with_admin_privileges,
            unseen_buffer_changes: channel_buffer_changes,
            channel_messages: unseen_messages,
        })
    }
 628
 629    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 630        self.transaction(|tx| async move { self.get_channel_participants_internal(id, &*tx).await })
 631            .await
 632    }
 633
 634    pub async fn set_channel_member_role(
 635        &self,
 636        channel_id: ChannelId,
 637        admin_id: UserId,
 638        for_user: UserId,
 639        role: ChannelRole,
 640    ) -> Result<channel_member::Model> {
 641        self.transaction(|tx| async move {
 642            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 643                .await?;
 644
 645            let membership = channel_member::Entity::find()
 646                .filter(
 647                    channel_member::Column::ChannelId
 648                        .eq(channel_id)
 649                        .and(channel_member::Column::UserId.eq(for_user)),
 650                )
 651                .one(&*tx)
 652                .await?;
 653
 654            let Some(membership) = membership else {
 655                Err(anyhow!("no such member"))?
 656            };
 657
 658            let mut update = membership.into_active_model();
 659            update.role = ActiveValue::Set(role);
 660            let updated = channel_member::Entity::update(update).exec(&*tx).await?;
 661
 662            Ok(updated)
 663        })
 664        .await
 665    }
 666
    /// Lists the members of `channel_id` for an admin. `admin_id` must be an
    /// admin of the channel.
    ///
    /// Membership rows from every channel returned by
    /// `get_channel_ancestors(channel_id)` are merged per user.
    ///
    /// `kind` is one of:
    /// - `Member`: an accepted row on `channel_id` itself;
    /// - `Invitee`: a pending row on `channel_id`;
    /// - `AncestorMember`: an accepted row on another channel.
    ///
    /// Pending rows on other channels are ignored. `Guest` rows are skipped
    /// unless either the row's channel or `channel_id` is `Public`. When a user
    /// has several rows, the reported role is whichever `should_override` the
    /// others.
    pub async fn get_channel_participant_details(
        &self,
        channel_id: ChannelId,
        admin_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                .await?;

            // Visibility of the channel being inspected. A missing row is treated
            // as `Members`.
            let channel_visibility = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?
                .map(|channel| channel.visibility)
                .unwrap_or(ChannelVisibility::Members);

            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Role,
                IsDirectMember,
                Accepted,
                Visibility,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            // One row per membership: (user, role, is-on-this-channel, accepted,
            // visibility of the membership's channel).
            let mut stream = channel_member::Entity::find()
                .left_join(channel::Entity)
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Role)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                .column(channel::Column::Visibility)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            struct UserDetail {
                kind: Kind,
                channel_role: ChannelRole,
            }
            let mut user_details: HashMap<UserId, UserDetail> = HashMap::default();

            while let Some(user_membership) = stream.next().await {
                let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): (
                    UserId,
                    ChannelRole,
                    bool,
                    bool,
                    ChannelVisibility,
                ) = user_membership?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invites to ancestor channels are not shown.
                    (false, false) => continue,
                };

                // Guests only count where some public channel is involved.
                if channel_role == ChannelRole::Guest
                    && visibility != ChannelVisibility::Public
                    && channel_visibility != ChannelVisibility::Public
                {
                    continue;
                }

                if let Some(details_mut) = user_details.get_mut(&user_id) {
                    // NOTE(review): a pending invite's role also takes part in
                    // this override check.
                    if channel_role.should_override(details_mut.channel_role) {
                        details_mut.channel_role = channel_role;
                    }
                    // A direct membership always wins. A direct invite is shown
                    // instead of an inherited membership, even though the UI may
                    // be confusing if the inherited permissions are already >=
                    // the invited role.
                    if kind == Kind::Member {
                        details_mut.kind = kind;
                    // the UI is going to be a bit confusing if you already have permissions
                    // that are greater than or equal to the ones you're being invited to.
                    } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember {
                        details_mut.kind = kind;
                    }
                } else {
                    user_details.insert(user_id, UserDetail { kind, channel_role });
                }
            }

            Ok(user_details
                .into_iter()
                .map(|(user_id, details)| proto::ChannelMember {
                    user_id: user_id.to_proto(),
                    kind: details.kind.into(),
                    role: details.channel_role.into(),
                })
                .collect())
        })
        .await
    }
 765
 766    pub async fn get_channel_participants_internal(
 767        &self,
 768        id: ChannelId,
 769        tx: &DatabaseTransaction,
 770    ) -> Result<Vec<UserId>> {
 771        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 772        let user_ids = channel_member::Entity::find()
 773            .distinct()
 774            .filter(
 775                channel_member::Column::ChannelId
 776                    .is_in(ancestor_ids.iter().copied())
 777                    .and(channel_member::Column::Accepted.eq(true)),
 778            )
 779            .select_only()
 780            .column(channel_member::Column::UserId)
 781            .into_values::<_, QueryUserIds>()
 782            .all(&*tx)
 783            .await?;
 784        Ok(user_ids)
 785    }
 786
 787    pub async fn check_user_is_channel_admin(
 788        &self,
 789        channel_id: ChannelId,
 790        user_id: UserId,
 791        tx: &DatabaseTransaction,
 792    ) -> Result<()> {
 793        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 794            Some(ChannelRole::Admin) => Ok(()),
 795            Some(ChannelRole::Member)
 796            | Some(ChannelRole::Banned)
 797            | Some(ChannelRole::Guest)
 798            | None => Err(anyhow!(
 799                "user is not a channel admin or channel does not exist"
 800            ))?,
 801        }
 802    }
 803
 804    pub async fn check_user_is_channel_member(
 805        &self,
 806        channel_id: ChannelId,
 807        user_id: UserId,
 808        tx: &DatabaseTransaction,
 809    ) -> Result<()> {
 810        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 811            Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
 812            Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
 813                "user is not a channel member or channel does not exist"
 814            ))?,
 815        }
 816    }
 817
 818    pub async fn check_user_is_channel_participant(
 819        &self,
 820        channel_id: ChannelId,
 821        user_id: UserId,
 822        tx: &DatabaseTransaction,
 823    ) -> Result<()> {
 824        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 825            Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
 826                Ok(())
 827            }
 828            Some(ChannelRole::Banned) | None => Err(anyhow!(
 829                "user is not a channel participant or channel does not exist"
 830            ))?,
 831        }
 832    }
 833
 834    pub async fn pending_invite_for_channel(
 835        &self,
 836        channel_id: ChannelId,
 837        user_id: UserId,
 838        tx: &DatabaseTransaction,
 839    ) -> Result<Option<channel_member::Model>> {
 840        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 841
 842        let row = channel_member::Entity::find()
 843            .filter(channel_member::Column::ChannelId.is_in(channel_ids))
 844            .filter(channel_member::Column::UserId.eq(user_id))
 845            .filter(channel_member::Column::Accepted.eq(false))
 846            .one(&*tx)
 847            .await?;
 848
 849        Ok(row)
 850    }
 851
 852    pub async fn most_public_ancestor_for_channel(
 853        &self,
 854        channel_id: ChannelId,
 855        tx: &DatabaseTransaction,
 856    ) -> Result<Option<ChannelId>> {
 857        // Note: if there are many paths to a channel, this will return just one
 858        let arbitary_path = channel_path::Entity::find()
 859            .filter(channel_path::Column::ChannelId.eq(channel_id))
 860            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 861            .one(tx)
 862            .await?;
 863
 864        let Some(path) = arbitary_path else {
 865            return Ok(None);
 866        };
 867
 868        let ancestor_ids: Vec<ChannelId> = path
 869            .id_path
 870            .trim_matches('/')
 871            .split('/')
 872            .map(|id| ChannelId::from_proto(id.parse().unwrap()))
 873            .collect();
 874
 875        let rows = channel::Entity::find()
 876            .filter(channel::Column::Id.is_in(ancestor_ids.iter().copied()))
 877            .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
 878            .all(&*tx)
 879            .await?;
 880
 881        let mut visible_channels: HashSet<ChannelId> = HashSet::default();
 882
 883        for row in rows {
 884            visible_channels.insert(row.id);
 885        }
 886
 887        for ancestor in ancestor_ids {
 888            if visible_channels.contains(&ancestor) {
 889                return Ok(Some(ancestor));
 890            }
 891        }
 892
 893        Ok(None)
 894    }
 895
 896    pub async fn channel_role_for_user(
 897        &self,
 898        channel_id: ChannelId,
 899        user_id: UserId,
 900        tx: &DatabaseTransaction,
 901    ) -> Result<Option<ChannelRole>> {
 902        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 903
 904        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 905        enum QueryChannelMembership {
 906            ChannelId,
 907            Role,
 908            Visibility,
 909        }
 910
 911        let mut rows = channel_member::Entity::find()
 912            .left_join(channel::Entity)
 913            .filter(
 914                channel_member::Column::ChannelId
 915                    .is_in(channel_ids)
 916                    .and(channel_member::Column::UserId.eq(user_id))
 917                    .and(channel_member::Column::Accepted.eq(true)),
 918            )
 919            .select_only()
 920            .column(channel_member::Column::ChannelId)
 921            .column(channel_member::Column::Role)
 922            .column(channel::Column::Visibility)
 923            .into_values::<_, QueryChannelMembership>()
 924            .stream(&*tx)
 925            .await?;
 926
 927        let mut user_role: Option<ChannelRole> = None;
 928
 929        let mut is_participant = false;
 930        let mut current_channel_visibility = None;
 931
 932        // note these channels are not iterated in any particular order,
 933        // our current logic takes the highest permission available.
 934        while let Some(row) = rows.next().await {
 935            let (membership_channel, role, visibility): (
 936                ChannelId,
 937                ChannelRole,
 938                ChannelVisibility,
 939            ) = row?;
 940
 941            match role {
 942                ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => {
 943                    if let Some(users_role) = user_role {
 944                        user_role = Some(users_role.max(role));
 945                    } else {
 946                        user_role = Some(role)
 947                    }
 948                }
 949                ChannelRole::Guest if visibility == ChannelVisibility::Public => {
 950                    is_participant = true
 951                }
 952                ChannelRole::Guest => {}
 953            }
 954            if channel_id == membership_channel {
 955                current_channel_visibility = Some(visibility);
 956            }
 957        }
 958        // free up database connection
 959        drop(rows);
 960
 961        if is_participant && user_role.is_none() {
 962            if current_channel_visibility.is_none() {
 963                current_channel_visibility = channel::Entity::find()
 964                    .filter(channel::Column::Id.eq(channel_id))
 965                    .one(&*tx)
 966                    .await?
 967                    .map(|channel| channel.visibility);
 968            }
 969            if current_channel_visibility == Some(ChannelVisibility::Public) {
 970                user_role = Some(ChannelRole::Guest);
 971            }
 972        }
 973
 974        Ok(user_role)
 975    }
 976
 977    /// Returns the channel ancestors in arbitrary order
 978    pub async fn get_channel_ancestors(
 979        &self,
 980        channel_id: ChannelId,
 981        tx: &DatabaseTransaction,
 982    ) -> Result<Vec<ChannelId>> {
 983        let paths = channel_path::Entity::find()
 984            .filter(channel_path::Column::ChannelId.eq(channel_id))
 985            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 986            .all(tx)
 987            .await?;
 988        let mut channel_ids = Vec::new();
 989        for path in paths {
 990            for id in path.id_path.trim_matches('/').split('/') {
 991                if let Ok(id) = id.parse() {
 992                    let id = ChannelId::from_proto(id);
 993                    if let Err(ix) = channel_ids.binary_search(&id) {
 994                        channel_ids.insert(ix, id);
 995                    }
 996                }
 997            }
 998        }
 999        Ok(channel_ids)
1000    }
1001
1002    // Returns the channel desendants as a sorted list of edges for further processing.
1003    // The edges are sorted such that you will see unknown channel ids as children
1004    // before you see them as parents.
1005    async fn get_channel_descendants(
1006        &self,
1007        channel_ids: impl IntoIterator<Item = ChannelId>,
1008        tx: &DatabaseTransaction,
1009    ) -> Result<Vec<ChannelEdge>> {
1010        let mut values = String::new();
1011        for id in channel_ids {
1012            if !values.is_empty() {
1013                values.push_str(", ");
1014            }
1015            write!(&mut values, "({})", id).unwrap();
1016        }
1017
1018        if values.is_empty() {
1019            return Ok(vec![]);
1020        }
1021
1022        let sql = format!(
1023            r#"
1024            SELECT
1025                descendant_paths.*
1026            FROM
1027                channel_paths parent_paths, channel_paths descendant_paths
1028            WHERE
1029                parent_paths.channel_id IN ({values}) AND
1030                descendant_paths.id_path != parent_paths.id_path AND
1031                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
1032            ORDER BY
1033                descendant_paths.id_path
1034        "#
1035        );
1036
1037        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
1038
1039        let mut paths = channel_path::Entity::find()
1040            .from_raw_sql(stmt)
1041            .stream(tx)
1042            .await?;
1043
1044        let mut results: Vec<ChannelEdge> = Vec::new();
1045        while let Some(path) = paths.next().await {
1046            let path = path?;
1047            let ids: Vec<&str> = path.id_path.trim_matches('/').split('/').collect();
1048
1049            debug_assert!(ids.len() >= 2);
1050            debug_assert!(ids[ids.len() - 1] == path.channel_id.to_string());
1051
1052            results.push(ChannelEdge {
1053                parent_id: ids[ids.len() - 2].parse().unwrap(),
1054                channel_id: ids[ids.len() - 1].parse().unwrap(),
1055            })
1056        }
1057
1058        Ok(results)
1059    }
1060
1061    /// Returns the channel with the given ID
1062    pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
1063        self.transaction(|tx| async move {
1064            self.check_user_is_channel_participant(channel_id, user_id, &*tx)
1065                .await?;
1066
1067            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
1068            let Some(channel) = channel else {
1069                Err(anyhow!("no such channel"))?
1070            };
1071
1072            Ok(Channel {
1073                id: channel.id,
1074                visibility: channel.visibility,
1075                name: channel.name,
1076            })
1077        })
1078        .await
1079    }
1080
1081    pub(crate) async fn get_or_create_channel_room(
1082        &self,
1083        channel_id: ChannelId,
1084        live_kit_room: &str,
1085        environment: &str,
1086        tx: &DatabaseTransaction,
1087    ) -> Result<RoomId> {
1088        let room = room::Entity::find()
1089            .filter(room::Column::ChannelId.eq(channel_id))
1090            .one(&*tx)
1091            .await?;
1092
1093        let room_id = if let Some(room) = room {
1094            if let Some(env) = room.enviroment {
1095                if &env != environment {
1096                    Err(anyhow!("must join using the {} release", env))?;
1097                }
1098            }
1099            room.id
1100        } else {
1101            let result = room::Entity::insert(room::ActiveModel {
1102                channel_id: ActiveValue::Set(Some(channel_id)),
1103                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
1104                enviroment: ActiveValue::Set(Some(environment.to_string())),
1105                ..Default::default()
1106            })
1107            .exec(&*tx)
1108            .await?;
1109
1110            result.last_insert_id
1111        };
1112
1113        Ok(room_id)
1114    }
1115
1116    // Insert an edge from the given channel to the given other channel.
1117    pub async fn link_channel(
1118        &self,
1119        user: UserId,
1120        channel: ChannelId,
1121        to: ChannelId,
1122    ) -> Result<ChannelGraph> {
1123        self.transaction(|tx| async move {
1124            // Note that even with these maxed permissions, this linking operation
1125            // is still insecure because you can't remove someone's permissions to a
1126            // channel if they've linked the channel to one where they're an admin.
1127            self.check_user_is_channel_admin(channel, user, &*tx)
1128                .await?;
1129
1130            self.link_channel_internal(user, channel, to, &*tx).await
1131        })
1132        .await
1133    }
1134
    /// Makes `new_parent` a parent of `channel` by writing new `channel_paths`
    /// rows for `channel` and every path that runs through it (its descendants).
    ///
    /// Requires `user` to be an admin of `new_parent`. Callers in this file
    /// (`link_channel`, `move_channel`) check admin rights on `channel` themselves.
    ///
    /// Errors:
    /// - fails with "cycle" if `channel` already lies on a path to `new_parent`;
    ///   this includes `new_parent == channel`.
    ///
    /// After inserting, every path beginning with `/{channel}/` (i.e. the
    /// channel's former root paths) is deleted.
    ///
    /// Returns the graph built by `get_user_channels` from `user`'s direct
    /// membership rows on `channel`, with the new `channel -> new_parent` edge
    /// appended.
    ///
    /// NOTE(review): linking to a parent the channel already has would re-insert
    /// existing `id_path`s — confirm how the schema/constraints handle that.
    pub async fn link_channel_internal(
        &self,
        user: UserId,
        channel: ChannelId,
        new_parent: ChannelId,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelGraph> {
        self.check_user_is_channel_admin(new_parent, user, &*tx)
            .await?;

        // Every stored path that passes through `channel`: its own paths and
        // those of all channels below it.
        let paths = channel_path::Entity::find()
            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
            .all(tx)
            .await?;

        // Keep the part of each path starting at `channel`, e.g. `/a/b/c/d/`
        // with channel `c` becomes `c/d/`. The set removes duplicate suffixes
        // that arise when the channel has several paths.
        let mut new_path_suffixes = HashSet::default();
        for path in paths {
            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
                new_path_suffixes.insert((
                    path.channel_id,
                    path.id_path[(start_offset + 1)..].to_string(),
                ));
            }
        }

        let paths_to_new_parent = channel_path::Entity::find()
            .filter(channel_path::Column::ChannelId.eq(new_parent))
            .all(tx)
            .await?;

        let mut new_paths = Vec::new();
        for path in paths_to_new_parent {
            // `channel` already sits above `new_parent`: linking would form a cycle.
            if path.id_path.contains(&format!("/{}/", channel)) {
                Err(anyhow!("cycle"))?;
            }

            // Parent paths end in '/' and suffixes start with the channel id, so
            // plain concatenation yields a well-formed `/…/parent/channel/…/` path.
            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
                channel_path::ActiveModel {
                    channel_id: ActiveValue::Set(*channel_id),
                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
                }
            }));
        }

        channel_path::Entity::insert_many(new_paths)
            .exec(&*tx)
            .await?;

        // remove any root edges for the channel we just linked
        {
            channel_path::Entity::delete_many()
                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
                .exec(&*tx)
                .await?;
        }

        // Only `user`'s direct membership rows on `channel` are passed on.
        let membership = channel_member::Entity::find()
            .filter(
                channel_member::Column::ChannelId
                    .eq(channel)
                    .and(channel_member::Column::UserId.eq(user)),
            )
            .all(tx)
            .await?;

        let mut channel_info = self.get_user_channels(user, membership, &*tx).await?;

        channel_info.channels.edges.push(ChannelEdge {
            channel_id: channel.to_proto(),
            parent_id: new_parent.to_proto(),
        });

        Ok(channel_info.channels)
    }
1209
1210    /// Unlink a channel from a given parent. This will add in a root edge if
1211    /// the channel has no other parents after this operation.
1212    pub async fn unlink_channel(
1213        &self,
1214        user: UserId,
1215        channel: ChannelId,
1216        from: ChannelId,
1217    ) -> Result<()> {
1218        self.transaction(|tx| async move {
1219            // Note that even with these maxed permissions, this linking operation
1220            // is still insecure because you can't remove someone's permissions to a
1221            // channel if they've linked the channel to one where they're an admin.
1222            self.check_user_is_channel_admin(channel, user, &*tx)
1223                .await?;
1224
1225            self.unlink_channel_internal(user, channel, from, &*tx)
1226                .await?;
1227
1228            Ok(())
1229        })
1230        .await
1231    }
1232
1233    pub async fn unlink_channel_internal(
1234        &self,
1235        user: UserId,
1236        channel: ChannelId,
1237        from: ChannelId,
1238        tx: &DatabaseTransaction,
1239    ) -> Result<()> {
1240        self.check_user_is_channel_admin(from, user, &*tx).await?;
1241
1242        let sql = r#"
1243            DELETE FROM channel_paths
1244            WHERE
1245                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1246            RETURNING id_path, channel_id
1247        "#;
1248
1249        let paths = channel_path::Entity::find()
1250            .from_raw_sql(Statement::from_sql_and_values(
1251                self.pool.get_database_backend(),
1252                sql,
1253                [from.to_proto().into(), channel.to_proto().into()],
1254            ))
1255            .all(&*tx)
1256            .await?;
1257
1258        let is_stranded = channel_path::Entity::find()
1259            .filter(channel_path::Column::ChannelId.eq(channel))
1260            .count(&*tx)
1261            .await?
1262            == 0;
1263
1264        // Make sure that there is always at least one path to the channel
1265        if is_stranded {
1266            let root_paths: Vec<_> = paths
1267                .iter()
1268                .map(|path| {
1269                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1270                    channel_path::ActiveModel {
1271                        channel_id: ActiveValue::Set(path.channel_id),
1272                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1273                    }
1274                })
1275                .collect();
1276            channel_path::Entity::insert_many(root_paths)
1277                .exec(&*tx)
1278                .await?;
1279        }
1280
1281        Ok(())
1282    }
1283
1284    /// Move a channel from one parent to another, returns the
1285    /// Channels that were moved for notifying clients
1286    pub async fn move_channel(
1287        &self,
1288        user: UserId,
1289        channel: ChannelId,
1290        from: ChannelId,
1291        to: ChannelId,
1292    ) -> Result<ChannelGraph> {
1293        if from == to {
1294            return Ok(ChannelGraph {
1295                channels: vec![],
1296                edges: vec![],
1297            });
1298        }
1299
1300        self.transaction(|tx| async move {
1301            self.check_user_is_channel_admin(channel, user, &*tx)
1302                .await?;
1303
1304            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1305
1306            self.unlink_channel_internal(user, channel, from, &*tx)
1307                .await?;
1308
1309            Ok(moved_channels)
1310        })
1311        .await
1312    }
1313}
1314
/// Single-column selector used with `into_values` to read the `user_id`
/// column (see `get_channel_participants_internal`).
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1319
/// A set of channels together with parent/child edges between channels.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels in the graph.
    pub channels: Vec<Channel>,
    /// `(parent_id, channel_id)` edges, expressed with proto ids.
    pub edges: Vec<ChannelEdge>,
}
1325
1326impl ChannelGraph {
1327    pub fn is_empty(&self) -> bool {
1328        self.channels.is_empty() && self.edges.is_empty()
1329    }
1330}
1331
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Compares the graphs as sets, so tests do not depend on the order in which
    /// channels and edges were produced. Duplicate entries are ignored as well.
    fn eq(&self, other: &Self) -> bool {
        fn channel_set(graph: &ChannelGraph) -> HashSet<&Channel> {
            graph.channels.iter().collect()
        }

        let edge_set = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        channel_set(self) == channel_set(other) && edge_set(self) == edge_set(other)
    }
}
1352
1353#[cfg(not(test))]
1354impl PartialEq for ChannelGraph {
1355    fn eq(&self, other: &Self) -> bool {
1356        self.channels == other.channels && self.edges == other.edges
1357    }
1358}