channels.rs

   1use super::*;
   2use rpc::proto::{channel_member::Kind, ChannelEdge};
   3
   4impl Database {
   5    #[cfg(test)]
   6    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
   7        self.transaction(move |tx| async move {
   8            let mut channels = Vec::new();
   9            let mut rows = channel::Entity::find().stream(&*tx).await?;
  10            while let Some(row) = rows.next().await {
  11                let row = row?;
  12                channels.push((row.id, row.name));
  13            }
  14            Ok(channels)
  15        })
  16        .await
  17    }
  18
  19    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  20        self.create_channel(name, None, creator_id).await
  21    }
  22
  23    pub async fn create_channel(
  24        &self,
  25        name: &str,
  26        parent: Option<ChannelId>,
  27        creator_id: UserId,
  28    ) -> Result<ChannelId> {
  29        let name = Self::sanitize_channel_name(name)?;
  30        self.transaction(move |tx| async move {
  31            if let Some(parent) = parent {
  32                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  33                    .await?;
  34            }
  35
  36            let channel = channel::ActiveModel {
  37                id: ActiveValue::NotSet,
  38                name: ActiveValue::Set(name.to_string()),
  39                visibility: ActiveValue::Set(ChannelVisibility::Members),
  40            }
  41            .insert(&*tx)
  42            .await?;
  43
  44            if let Some(parent) = parent {
  45                let sql = r#"
  46                    INSERT INTO channel_paths
  47                    (id_path, channel_id)
  48                    SELECT
  49                        id_path || $1 || '/', $2
  50                    FROM
  51                        channel_paths
  52                    WHERE
  53                        channel_id = $3
  54                "#;
  55                let channel_paths_stmt = Statement::from_sql_and_values(
  56                    self.pool.get_database_backend(),
  57                    sql,
  58                    [
  59                        channel.id.to_proto().into(),
  60                        channel.id.to_proto().into(),
  61                        parent.to_proto().into(),
  62                    ],
  63                );
  64                tx.execute(channel_paths_stmt).await?;
  65            } else {
  66                channel_path::Entity::insert(channel_path::ActiveModel {
  67                    channel_id: ActiveValue::Set(channel.id),
  68                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  69                })
  70                .exec(&*tx)
  71                .await?;
  72            }
  73
  74            channel_member::ActiveModel {
  75                id: ActiveValue::NotSet,
  76                channel_id: ActiveValue::Set(channel.id),
  77                user_id: ActiveValue::Set(creator_id),
  78                accepted: ActiveValue::Set(true),
  79                role: ActiveValue::Set(ChannelRole::Admin),
  80            }
  81            .insert(&*tx)
  82            .await?;
  83
  84            Ok(channel.id)
  85        })
  86        .await
  87    }
  88
    /// Joins `user_id` (on `connection`) to the call room of `channel_id`.
    ///
    /// The user's role is resolved in this order:
    /// 1. any existing role from `channel_role_for_user`;
    /// 2. otherwise, if the channel exists and the user has a pending
    ///    invitation (possibly to an ancestor channel), that invitation is
    ///    accepted and its role used;
    /// 3. otherwise, if the channel is `Public`, a new accepted membership is
    ///    inserted on the channel returned by `most_public_ancestor_for_channel`
    ///    (or on `channel_id` itself if that returns `None`).
    ///
    /// Fails with "no such channel, or not allowed" if the channel does not
    /// exist, no role could be established, or the role is `Banned`.
    ///
    /// Returns the room join result together with the id of the channel whose
    /// membership was accepted/created by this call (`None` if the user
    /// already had a role).
    pub async fn join_channel(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        connection: ConnectionId,
        environment: &str,
    ) -> Result<(JoinRoom, Option<ChannelId>)> {
        self.transaction(move |tx| async move {
            // Set only when this call accepts an invite or adds a membership.
            let mut joined_channel_id = None;

            let channel = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?;

            let mut role = self
                .channel_role_for_user(channel_id, user_id, &*tx)
                .await?;

            // Step 2: accept a pending invitation, if there is one.
            if role.is_none() && channel.is_some() {
                if let Some(invitation) = self
                    .pending_invite_for_channel(channel_id, user_id, &*tx)
                    .await?
                {
                    // note, this may be a parent channel
                    joined_channel_id = Some(invitation.channel_id);
                    role = Some(invitation.role);

                    channel_member::Entity::update(channel_member::ActiveModel {
                        accepted: ActiveValue::Set(true),
                        ..invitation.into_active_model()
                    })
                    .exec(&*tx)
                    .await?;

                    // Sanity check (debug builds only): the accepted invite
                    // must now yield the same role from the normal lookup.
                    debug_assert!(
                        self.channel_role_for_user(channel_id, user_id, &*tx)
                            .await?
                            == role
                    );
                }
            }
            // Step 3: anyone may join a public channel.
            if role.is_none()
                && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
            {
                let channel_id_to_join = self
                    .most_public_ancestor_for_channel(channel_id, &*tx)
                    .await?
                    .unwrap_or(channel_id);
                // TODO: change this back to Guest.
                // NOTE(review): this role and the inserted row's role below
                // must stay identical, or the debug_assert will fire.
                role = Some(ChannelRole::Member);
                joined_channel_id = Some(channel_id_to_join);

                channel_member::Entity::insert(channel_member::ActiveModel {
                    id: ActiveValue::NotSet,
                    channel_id: ActiveValue::Set(channel_id_to_join),
                    user_id: ActiveValue::Set(user_id),
                    accepted: ActiveValue::Set(true),
                    // TODO: change this back to Guest.
                    role: ActiveValue::Set(ChannelRole::Member),
                })
                .exec(&*tx)
                .await?;

                // Sanity check (debug builds only): the new membership must be
                // visible through the normal role lookup.
                debug_assert!(
                    self.channel_role_for_user(channel_id, user_id, &*tx)
                        .await?
                        == role
                );
            }

            // A single opaque error for every refusal reason.
            if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
                Err(anyhow!("no such channel, or not allowed"))?
            }

            // Candidate LiveKit room name; presumably only used by
            // `get_or_create_channel_room` when it has to create the room
            // (TODO confirm).
            let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
            let room_id = self
                .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
                .await?;

            self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
                .await
                .map(|jr| (jr, joined_channel_id))
        })
        .await
    }
 175
 176    pub async fn set_channel_visibility(
 177        &self,
 178        channel_id: ChannelId,
 179        visibility: ChannelVisibility,
 180        user_id: UserId,
 181    ) -> Result<channel::Model> {
 182        self.transaction(move |tx| async move {
 183            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 184                .await?;
 185
 186            let channel = channel::ActiveModel {
 187                id: ActiveValue::Unchanged(channel_id),
 188                visibility: ActiveValue::Set(visibility),
 189                ..Default::default()
 190            }
 191            .update(&*tx)
 192            .await?;
 193
 194            Ok(channel)
 195        })
 196        .await
 197    }
 198
 199    pub async fn delete_channel(
 200        &self,
 201        channel_id: ChannelId,
 202        user_id: UserId,
 203    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 204        self.transaction(move |tx| async move {
 205            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 206                .await?;
 207
 208            // Don't remove descendant channels that have additional parents.
 209            let mut channels_to_remove: HashSet<ChannelId> = HashSet::default();
 210            channels_to_remove.insert(channel_id);
 211
 212            let graph = self.get_channel_descendants([channel_id], &*tx).await?;
 213            for edge in graph.iter() {
 214                channels_to_remove.insert(ChannelId::from_proto(edge.channel_id));
 215            }
 216
 217            {
 218                let mut channels_to_keep = channel_path::Entity::find()
 219                    .filter(
 220                        channel_path::Column::ChannelId
 221                            .is_in(channels_to_remove.iter().copied())
 222                            .and(
 223                                channel_path::Column::IdPath
 224                                    .not_like(&format!("%/{}/%", channel_id)),
 225                            ),
 226                    )
 227                    .stream(&*tx)
 228                    .await?;
 229                while let Some(row) = channels_to_keep.next().await {
 230                    let row = row?;
 231                    channels_to_remove.remove(&row.channel_id);
 232                }
 233            }
 234
 235            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 236            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 237                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 238                .select_only()
 239                .column(channel_member::Column::UserId)
 240                .distinct()
 241                .into_values::<_, QueryUserIds>()
 242                .all(&*tx)
 243                .await?;
 244
 245            channel::Entity::delete_many()
 246                .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied()))
 247                .exec(&*tx)
 248                .await?;
 249
 250            // Delete any other paths that include this channel
 251            let sql = r#"
 252                    DELETE FROM channel_paths
 253                    WHERE
 254                        id_path LIKE '%' || $1 || '%'
 255                "#;
 256            let channel_paths_stmt = Statement::from_sql_and_values(
 257                self.pool.get_database_backend(),
 258                sql,
 259                [channel_id.to_proto().into()],
 260            );
 261            tx.execute(channel_paths_stmt).await?;
 262
 263            Ok((channels_to_remove.into_iter().collect(), members_to_notify))
 264        })
 265        .await
 266    }
 267
 268    pub async fn invite_channel_member(
 269        &self,
 270        channel_id: ChannelId,
 271        invitee_id: UserId,
 272        inviter_id: UserId,
 273        role: ChannelRole,
 274    ) -> Result<NotificationBatch> {
 275        self.transaction(move |tx| async move {
 276            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 277                .await?;
 278
 279            let channel = channel::Entity::find_by_id(channel_id)
 280                .one(&*tx)
 281                .await?
 282                .ok_or_else(|| anyhow!("no such channel"))?;
 283
 284            channel_member::ActiveModel {
 285                id: ActiveValue::NotSet,
 286                channel_id: ActiveValue::Set(channel_id),
 287                user_id: ActiveValue::Set(invitee_id),
 288                accepted: ActiveValue::Set(false),
 289                role: ActiveValue::Set(role),
 290            }
 291            .insert(&*tx)
 292            .await?;
 293
 294            Ok(self
 295                .create_notification(
 296                    invitee_id,
 297                    rpc::Notification::ChannelInvitation {
 298                        channel_id: channel_id.to_proto(),
 299                        channel_name: channel.name,
 300                        inviter_id: inviter_id.to_proto(),
 301                    },
 302                    true,
 303                    &*tx,
 304                )
 305                .await?
 306                .into_iter()
 307                .collect())
 308        })
 309        .await
 310    }
 311
 312    fn sanitize_channel_name(name: &str) -> Result<&str> {
 313        let new_name = name.trim().trim_start_matches('#');
 314        if new_name == "" {
 315            Err(anyhow!("channel name can't be blank"))?;
 316        }
 317        Ok(new_name)
 318    }
 319
 320    pub async fn rename_channel(
 321        &self,
 322        channel_id: ChannelId,
 323        user_id: UserId,
 324        new_name: &str,
 325    ) -> Result<Channel> {
 326        self.transaction(move |tx| async move {
 327            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 328
 329            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 330                .await?;
 331
 332            let channel = channel::ActiveModel {
 333                id: ActiveValue::Unchanged(channel_id),
 334                name: ActiveValue::Set(new_name.clone()),
 335                ..Default::default()
 336            }
 337            .update(&*tx)
 338            .await?;
 339
 340            Ok(Channel {
 341                id: channel.id,
 342                name: channel.name,
 343                visibility: channel.visibility,
 344            })
 345        })
 346        .await
 347    }
 348
 349    pub async fn respond_to_channel_invite(
 350        &self,
 351        channel_id: ChannelId,
 352        user_id: UserId,
 353        accept: bool,
 354    ) -> Result<NotificationBatch> {
 355        self.transaction(move |tx| async move {
 356            let rows_affected = if accept {
 357                channel_member::Entity::update_many()
 358                    .set(channel_member::ActiveModel {
 359                        accepted: ActiveValue::Set(accept),
 360                        ..Default::default()
 361                    })
 362                    .filter(
 363                        channel_member::Column::ChannelId
 364                            .eq(channel_id)
 365                            .and(channel_member::Column::UserId.eq(user_id))
 366                            .and(channel_member::Column::Accepted.eq(false)),
 367                    )
 368                    .exec(&*tx)
 369                    .await?
 370                    .rows_affected
 371            } else {
 372                channel_member::Entity::delete_many()
 373                    .filter(
 374                        channel_member::Column::ChannelId
 375                            .eq(channel_id)
 376                            .and(channel_member::Column::UserId.eq(user_id))
 377                            .and(channel_member::Column::Accepted.eq(false)),
 378                    )
 379                    .exec(&*tx)
 380                    .await?
 381                    .rows_affected
 382            };
 383
 384            if rows_affected == 0 {
 385                Err(anyhow!("no such invitation"))?;
 386            }
 387
 388            Ok(self
 389                .mark_notification_as_read_with_response(
 390                    user_id,
 391                    &rpc::Notification::ChannelInvitation {
 392                        channel_id: channel_id.to_proto(),
 393                        channel_name: Default::default(),
 394                        inviter_id: Default::default(),
 395                    },
 396                    accept,
 397                    &*tx,
 398                )
 399                .await?
 400                .into_iter()
 401                .collect())
 402        })
 403        .await
 404    }
 405
 406    pub async fn remove_channel_member(
 407        &self,
 408        channel_id: ChannelId,
 409        member_id: UserId,
 410        admin_id: UserId,
 411    ) -> Result<Option<NotificationId>> {
 412        self.transaction(|tx| async move {
 413            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 414                .await?;
 415
 416            let result = channel_member::Entity::delete_many()
 417                .filter(
 418                    channel_member::Column::ChannelId
 419                        .eq(channel_id)
 420                        .and(channel_member::Column::UserId.eq(member_id)),
 421                )
 422                .exec(&*tx)
 423                .await?;
 424
 425            if result.rows_affected == 0 {
 426                Err(anyhow!("no such member"))?;
 427            }
 428
 429            Ok(self
 430                .remove_notification(
 431                    member_id,
 432                    rpc::Notification::ChannelInvitation {
 433                        channel_id: channel_id.to_proto(),
 434                        channel_name: Default::default(),
 435                        inviter_id: Default::default(),
 436                    },
 437                    &*tx,
 438                )
 439                .await?)
 440        })
 441        .await
 442    }
 443
 444    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 445        self.transaction(|tx| async move {
 446            let channel_invites = channel_member::Entity::find()
 447                .filter(
 448                    channel_member::Column::UserId
 449                        .eq(user_id)
 450                        .and(channel_member::Column::Accepted.eq(false)),
 451                )
 452                .all(&*tx)
 453                .await?;
 454
 455            let channels = channel::Entity::find()
 456                .filter(
 457                    channel::Column::Id.is_in(
 458                        channel_invites
 459                            .into_iter()
 460                            .map(|channel_member| channel_member.channel_id),
 461                    ),
 462                )
 463                .all(&*tx)
 464                .await?;
 465
 466            let channels = channels
 467                .into_iter()
 468                .map(|channel| Channel {
 469                    id: channel.id,
 470                    name: channel.name,
 471                    visibility: channel.visibility,
 472                })
 473                .collect();
 474
 475            Ok(channels)
 476        })
 477        .await
 478    }
 479
 480    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 481        self.transaction(|tx| async move {
 482            let tx = tx;
 483
 484            let channel_memberships = channel_member::Entity::find()
 485                .filter(
 486                    channel_member::Column::UserId
 487                        .eq(user_id)
 488                        .and(channel_member::Column::Accepted.eq(true)),
 489                )
 490                .all(&*tx)
 491                .await?;
 492
 493            self.get_user_channels(user_id, channel_memberships, &tx)
 494                .await
 495        })
 496        .await
 497    }
 498
 499    pub async fn get_channel_for_user(
 500        &self,
 501        channel_id: ChannelId,
 502        user_id: UserId,
 503    ) -> Result<ChannelsForUser> {
 504        self.transaction(|tx| async move {
 505            let tx = tx;
 506
 507            let channel_membership = channel_member::Entity::find()
 508                .filter(
 509                    channel_member::Column::UserId
 510                        .eq(user_id)
 511                        .and(channel_member::Column::ChannelId.eq(channel_id))
 512                        .and(channel_member::Column::Accepted.eq(true)),
 513                )
 514                .all(&*tx)
 515                .await?;
 516
 517            self.get_user_channels(user_id, channel_membership, &tx)
 518                .await
 519        })
 520        .await
 521    }
 522
    /// Builds the `ChannelsForUser` view for `user_id` from the given
    /// memberships, using the caller's transaction.
    ///
    /// Steps:
    /// 1. Fetch the descendant edges of every membership channel.
    /// 2. Compute an effective role per channel: direct memberships first, then
    ///    each descendant inherits its parent's role unless its own role
    ///    `should_override` it.
    /// 3. Load the channels, hiding those where the role is `Banned`, or
    ///    `Guest` on a non-public channel, and noting `Admin` channels.
    /// 4. If anything was hidden, rewrite the edges to skip hidden channels.
    /// 5. Attach current room participants, unseen buffer changes and unseen
    ///    messages for the visible channels.
    pub async fn get_user_channels(
        &self,
        user_id: UserId,
        channel_memberships: Vec<channel_member::Model>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelsForUser> {
        let mut edges = self
            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
            .await?;

        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();

        for membership in channel_memberships.iter() {
            role_for_channel.insert(membership.channel_id, membership.role);
        }

        // Propagate roles down the edges. This relies on a parent's role being
        // in the map before any edge out of it is visited: the `[]` index below
        // panics on a missing key, and the debug_assert only catches that in
        // debug builds. Presumably `get_channel_descendants` returns edges in
        // such an order — TODO confirm.
        for ChannelEdge {
            parent_id,
            channel_id,
        } in edges.iter()
        {
            let parent_id = ChannelId::from_proto(*parent_id);
            let channel_id = ChannelId::from_proto(*channel_id);
            debug_assert!(role_for_channel.get(&parent_id).is_some());
            let parent_role = role_for_channel[&parent_id];
            if let Some(existing_role) = role_for_channel.get(&channel_id) {
                if existing_role.should_override(parent_role) {
                    continue;
                }
            }
            role_for_channel.insert(channel_id, parent_role);
        }

        let mut channels: Vec<Channel> = Vec::new();
        let mut channels_with_admin_privileges: HashSet<ChannelId> = HashSet::default();
        // Keyed by proto (u64) ids so it can be checked directly against the
        // `ChannelEdge` fields when rewriting edges below.
        let mut channels_to_remove: HashSet<u64> = HashSet::default();

        let mut rows = channel::Entity::find()
            .filter(channel::Column::Id.is_in(role_for_channel.keys().copied()))
            .stream(&*tx)
            .await?;

        while let Some(row) = rows.next().await {
            let channel = row?;
            let role = role_for_channel[&channel.id];

            // Hidden: banned users, and guests of channels that aren't public.
            if role == ChannelRole::Banned
                || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
            {
                channels_to_remove.insert(channel.id.0 as u64);
                continue;
            }

            channels.push(Channel {
                id: channel.id,
                name: channel.name,
                visibility: channel.visibility,
            });

            if role == ChannelRole::Admin {
                channels_with_admin_privileges.insert(channel.id);
            }
        }
        // Release the stream (and its borrow of the transaction) before
        // issuing further queries.
        drop(rows);

        if !channels_to_remove.is_empty() {
            // Note: this code assumes each channel has one parent.
            // If there are multiple valid public paths to a channel,
            // e.g.
            // If both of these paths are present (* indicating public):
            // - zed* -> projects -> vim*
            // - zed* -> conrad -> public-projects* -> vim*
            // Users would only see one of them (based on edge sort order)
            //
            // Map each hidden channel to its parent; when a hidden channel has
            // several incoming edges, the last one in `edges` wins.
            let mut replacement_parent: HashMap<u64, u64> = HashMap::default();
            for ChannelEdge {
                parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    replacement_parent.insert(*channel_id, *parent_id);
                }
            }

            // Drop edges into hidden channels, and re-parent edges out of hidden
            // channels onto the nearest visible ancestor. If the chain of
            // replacements runs out while still on a hidden channel, the edge
            // is dropped.
            let mut new_edges: Vec<ChannelEdge> = Vec::new();
            'outer: for ChannelEdge {
                mut parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    continue;
                }
                while channels_to_remove.contains(&parent_id) {
                    if let Some(new_parent_id) = replacement_parent.get(&parent_id) {
                        parent_id = *new_parent_id;
                    } else {
                        continue 'outer;
                    }
                }
                new_edges.push(ChannelEdge {
                    parent_id,
                    channel_id: *channel_id,
                })
            }
            edges = new_edges;
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryUserIdsAndChannelIds {
            ChannelId,
            UserId,
        }

        // Users currently in each visible channel's room.
        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
        {
            let mut rows = room_participant::Entity::find()
                .inner_join(room::Entity)
                .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
                .select_only()
                .column(room::Column::ChannelId)
                .column(room_participant::Column::UserId)
                .into_values::<_, QueryUserIdsAndChannelIds>()
                .stream(&*tx)
                .await?;
            while let Some(row) = rows.next().await {
                let row: (ChannelId, UserId) = row?;
                channel_participants.entry(row.0).or_default().push(row.1)
            }
        }

        let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
        let channel_buffer_changes = self
            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
            .await?;

        let unseen_messages = self
            .unseen_channel_messages(user_id, &channel_ids, &*tx)
            .await?;

        Ok(ChannelsForUser {
            channels: ChannelGraph { channels, edges },
            channel_participants,
            channels_with_admin_privileges,
            unseen_buffer_changes: channel_buffer_changes,
            channel_messages: unseen_messages,
        })
    }
 671
 672    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 673        self.transaction(|tx| async move { self.get_channel_participants_internal(id, &*tx).await })
 674            .await
 675    }
 676
 677    pub async fn set_channel_member_role(
 678        &self,
 679        channel_id: ChannelId,
 680        admin_id: UserId,
 681        for_user: UserId,
 682        role: ChannelRole,
 683    ) -> Result<channel_member::Model> {
 684        self.transaction(|tx| async move {
 685            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 686                .await?;
 687
 688            let membership = channel_member::Entity::find()
 689                .filter(
 690                    channel_member::Column::ChannelId
 691                        .eq(channel_id)
 692                        .and(channel_member::Column::UserId.eq(for_user)),
 693                )
 694                .one(&*tx)
 695                .await?;
 696
 697            let Some(membership) = membership else {
 698                Err(anyhow!("no such member"))?
 699            };
 700
 701            let mut update = membership.into_active_model();
 702            update.role = ActiveValue::Set(role);
 703            let updated = channel_member::Entity::update(update).exec(&*tx).await?;
 704
 705            Ok(updated)
 706        })
 707        .await
 708    }
 709
    /// Lists the members of `channel_id` as seen by `user_id`.
    ///
    /// `user_id` must pass `check_user_is_channel_member`; the role it returns
    /// decides how much is revealed. Membership rows are gathered from every
    /// channel returned by `get_channel_ancestors(channel_id)` and classified:
    /// - direct + accepted   -> `Member`
    /// - direct + pending    -> `Invitee`
    /// - ancestor + accepted -> `AncestorMember`
    /// - ancestor + pending  -> skipped
    ///
    /// `Guest` rows are skipped unless either the membership's channel or
    /// `channel_id` itself is public. Multiple rows for one user are merged.
    /// Non-admin viewers do not see invitees or banned users, and see admins
    /// reported as plain members.
    pub async fn get_channel_participant_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            // Role of the requesting user; controls the redaction at the end.
            let user_role = self
                .check_user_is_channel_member(channel_id, user_id, &*tx)
                .await?;

            // Visibility of the target channel (treated as `Members` if the
            // row is missing).
            let channel_visibility = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?
                .map(|channel| channel.visibility)
                .unwrap_or(ChannelVisibility::Members);

            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Role,
                IsDirectMember,
                Accepted,
                Visibility,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            // One row per membership in the lineage, with a flag for whether it
            // is on `channel_id` itself and the visibility of the membership's
            // own channel.
            let mut stream = channel_member::Entity::find()
                .left_join(channel::Entity)
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Role)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                .column(channel::Column::Visibility)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            struct UserDetail {
                kind: Kind,
                channel_role: ChannelRole,
            }
            let mut user_details: HashMap<UserId, UserDetail> = HashMap::default();

            while let Some(user_membership) = stream.next().await {
                // Note: `user_id` here shadows the requesting user's id.
                let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): (
                    UserId,
                    ChannelRole,
                    bool,
                    bool,
                    ChannelVisibility,
                ) = user_membership?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invites to ancestors don't count here.
                    (false, false) => continue,
                };

                // Guests only count where something in play is public.
                if channel_role == ChannelRole::Guest
                    && visibility != ChannelVisibility::Public
                    && channel_visibility != ChannelVisibility::Public
                {
                    continue;
                }

                // Merge with earlier rows for the same user: keep the role that
                // `should_override`, and prefer `Member` over other kinds.
                if let Some(details_mut) = user_details.get_mut(&user_id) {
                    if channel_role.should_override(details_mut.channel_role) {
                        details_mut.channel_role = channel_role;
                    }
                    if kind == Kind::Member {
                        details_mut.kind = kind;
                    } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember {
                        // the UI is going to be a bit confusing if you already have permissions
                        // that are greater than or equal to the ones you're being invited to.
                        details_mut.kind = kind;
                    }
                } else {
                    user_details.insert(user_id, UserDetail { kind, channel_role });
                }
            }

            Ok(user_details
                .into_iter()
                .filter_map(|(user_id, mut details)| {
                    // If the user is not an admin, don't give them as much
                    // information about the other members.
                    if user_role != ChannelRole::Admin {
                        if details.kind == Kind::Invitee
                            || details.channel_role == ChannelRole::Banned
                        {
                            return None;
                        }

                        if details.channel_role == ChannelRole::Admin {
                            details.channel_role = ChannelRole::Member;
                        }
                    }

                    Some(proto::ChannelMember {
                        user_id: user_id.to_proto(),
                        kind: details.kind.into(),
                        role: details.channel_role.into(),
                    })
                })
                .collect())
        })
        .await
    }
 825
 826    pub async fn get_channel_participants_internal(
 827        &self,
 828        id: ChannelId,
 829        tx: &DatabaseTransaction,
 830    ) -> Result<Vec<UserId>> {
 831        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 832        let user_ids = channel_member::Entity::find()
 833            .distinct()
 834            .filter(
 835                channel_member::Column::ChannelId
 836                    .is_in(ancestor_ids.iter().copied())
 837                    .and(channel_member::Column::Accepted.eq(true)),
 838            )
 839            .select_only()
 840            .column(channel_member::Column::UserId)
 841            .into_values::<_, QueryUserIds>()
 842            .all(&*tx)
 843            .await?;
 844        Ok(user_ids)
 845    }
 846
 847    pub async fn check_user_is_channel_admin(
 848        &self,
 849        channel_id: ChannelId,
 850        user_id: UserId,
 851        tx: &DatabaseTransaction,
 852    ) -> Result<()> {
 853        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 854            Some(ChannelRole::Admin) => Ok(()),
 855            Some(ChannelRole::Member)
 856            | Some(ChannelRole::Banned)
 857            | Some(ChannelRole::Guest)
 858            | None => Err(anyhow!(
 859                "user is not a channel admin or channel does not exist"
 860            ))?,
 861        }
 862    }
 863
 864    pub async fn check_user_is_channel_member(
 865        &self,
 866        channel_id: ChannelId,
 867        user_id: UserId,
 868        tx: &DatabaseTransaction,
 869    ) -> Result<ChannelRole> {
 870        let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?;
 871        match channel_role {
 872            Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()),
 873            Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
 874                "user is not a channel member or channel does not exist"
 875            ))?,
 876        }
 877    }
 878
 879    pub async fn check_user_is_channel_participant(
 880        &self,
 881        channel_id: ChannelId,
 882        user_id: UserId,
 883        tx: &DatabaseTransaction,
 884    ) -> Result<()> {
 885        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 886            Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
 887                Ok(())
 888            }
 889            Some(ChannelRole::Banned) | None => Err(anyhow!(
 890                "user is not a channel participant or channel does not exist"
 891            ))?,
 892        }
 893    }
 894
 895    pub async fn pending_invite_for_channel(
 896        &self,
 897        channel_id: ChannelId,
 898        user_id: UserId,
 899        tx: &DatabaseTransaction,
 900    ) -> Result<Option<channel_member::Model>> {
 901        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 902
 903        let row = channel_member::Entity::find()
 904            .filter(channel_member::Column::ChannelId.is_in(channel_ids))
 905            .filter(channel_member::Column::UserId.eq(user_id))
 906            .filter(channel_member::Column::Accepted.eq(false))
 907            .one(&*tx)
 908            .await?;
 909
 910        Ok(row)
 911    }
 912
 913    pub async fn most_public_ancestor_for_channel(
 914        &self,
 915        channel_id: ChannelId,
 916        tx: &DatabaseTransaction,
 917    ) -> Result<Option<ChannelId>> {
 918        // Note: if there are many paths to a channel, this will return just one
 919        let arbitary_path = channel_path::Entity::find()
 920            .filter(channel_path::Column::ChannelId.eq(channel_id))
 921            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 922            .one(tx)
 923            .await?;
 924
 925        let Some(path) = arbitary_path else {
 926            return Ok(None);
 927        };
 928
 929        let ancestor_ids: Vec<ChannelId> = path
 930            .id_path
 931            .trim_matches('/')
 932            .split('/')
 933            .map(|id| ChannelId::from_proto(id.parse().unwrap()))
 934            .collect();
 935
 936        let rows = channel::Entity::find()
 937            .filter(channel::Column::Id.is_in(ancestor_ids.iter().copied()))
 938            .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
 939            .all(&*tx)
 940            .await?;
 941
 942        let mut visible_channels: HashSet<ChannelId> = HashSet::default();
 943
 944        for row in rows {
 945            visible_channels.insert(row.id);
 946        }
 947
 948        for ancestor in ancestor_ids {
 949            if visible_channels.contains(&ancestor) {
 950                return Ok(Some(ancestor));
 951            }
 952        }
 953
 954        Ok(None)
 955    }
 956
    /// Computes the effective role of `user_id` in `channel_id`.
    ///
    /// Only accepted memberships are considered, in the channel itself and in
    /// all of its ancestors.
    ///
    /// How rows are combined:
    /// - `Admin`, `Member` and `Banned` memberships are merged with `Ord::max`.
    ///   NOTE(review): this relies on `ChannelRole`'s ordering to decide which
    ///   role wins (e.g. whether `Banned` outranks `Member`). Confirm against
    ///   the `ChannelRole` definition.
    /// - A `Guest` membership only counts when the channel it was granted on is
    ///   public. It yields `Some(ChannelRole::Guest)` only if no other role was
    ///   found and `channel_id` itself is public.
    ///
    /// Returns `Ok(None)` when no applicable membership exists, including
    /// when the channel has no paths, so no ancestors are found.
    pub async fn channel_role_for_user(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        tx: &DatabaseTransaction,
    ) -> Result<Option<ChannelRole>> {
        // Includes `channel_id` itself as well as every ancestor on every path.
        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;

        // Column selector for the (channel id, role, channel visibility) rows below.
        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryChannelMembership {
            ChannelId,
            Role,
            Visibility,
        }

        // Accepted memberships of this user anywhere in the hierarchy, joined
        // with the visibility of the channel each membership belongs to.
        let mut rows = channel_member::Entity::find()
            .left_join(channel::Entity)
            .filter(
                channel_member::Column::ChannelId
                    .is_in(channel_ids)
                    .and(channel_member::Column::UserId.eq(user_id))
                    .and(channel_member::Column::Accepted.eq(true)),
            )
            .select_only()
            .column(channel_member::Column::ChannelId)
            .column(channel_member::Column::Role)
            .column(channel::Column::Visibility)
            .into_values::<_, QueryChannelMembership>()
            .stream(&*tx)
            .await?;

        let mut user_role: Option<ChannelRole> = None;

        // Set when the user is a guest of some public channel in the hierarchy.
        let mut is_participant = false;
        // Visibility of `channel_id` itself, if one of the rows belongs to it.
        let mut current_channel_visibility = None;

        // note these channels are not iterated in any particular order,
        // our current logic takes the highest permission available.
        while let Some(row) = rows.next().await {
            let (membership_channel, role, visibility): (
                ChannelId,
                ChannelRole,
                ChannelVisibility,
            ) = row?;

            match role {
                ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => {
                    if let Some(users_role) = user_role {
                        user_role = Some(users_role.max(role));
                    } else {
                        user_role = Some(role)
                    }
                }
                // Guest grants only matter on public channels.
                ChannelRole::Guest if visibility == ChannelVisibility::Public => {
                    is_participant = true
                }
                ChannelRole::Guest => {}
            }
            if channel_id == membership_channel {
                current_channel_visibility = Some(visibility);
            }
        }
        // free up database connection
        drop(rows);

        // Guest access additionally requires the target channel itself to be
        // public. Fetch its visibility if no row for it was seen above.
        if is_participant && user_role.is_none() {
            if current_channel_visibility.is_none() {
                current_channel_visibility = channel::Entity::find()
                    .filter(channel::Column::Id.eq(channel_id))
                    .one(&*tx)
                    .await?
                    .map(|channel| channel.visibility);
            }
            if current_channel_visibility == Some(ChannelVisibility::Public) {
                user_role = Some(ChannelRole::Guest);
            }
        }

        Ok(user_role)
    }
1037
1038    /// Returns the channel ancestors in arbitrary order
1039    pub async fn get_channel_ancestors(
1040        &self,
1041        channel_id: ChannelId,
1042        tx: &DatabaseTransaction,
1043    ) -> Result<Vec<ChannelId>> {
1044        let paths = channel_path::Entity::find()
1045            .filter(channel_path::Column::ChannelId.eq(channel_id))
1046            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
1047            .all(tx)
1048            .await?;
1049        let mut channel_ids = Vec::new();
1050        for path in paths {
1051            for id in path.id_path.trim_matches('/').split('/') {
1052                if let Ok(id) = id.parse() {
1053                    let id = ChannelId::from_proto(id);
1054                    if let Err(ix) = channel_ids.binary_search(&id) {
1055                        channel_ids.insert(ix, id);
1056                    }
1057                }
1058            }
1059        }
1060        Ok(channel_ids)
1061    }
1062
    /// Returns the edges below `channel_ids`, built from the `channel_paths` table.
    ///
    /// A row is included when its path strictly extends one of the given
    /// channels' paths. Each such row yields one edge: the last two segments of
    /// its `id_path` become `(parent_id, channel_id)`.
    ///
    /// Rows are ordered by `id_path`, so a path always sorts before its
    /// extensions. As a result you will see unknown channel ids as children
    /// before you see them as parents.
    ///
    /// Returns an empty list when `channel_ids` is empty.
    ///
    /// NOTE(review): the same edge may appear more than once. This happens
    /// when a descendant path matches several parent paths (e.g. both a channel
    /// and one of its ancestors are in `channel_ids`), or when a channel is
    /// reachable via several paths. Confirm callers tolerate duplicates.
    async fn get_channel_descendants(
        &self,
        channel_ids: impl IntoIterator<Item = ChannelId>,
        tx: &DatabaseTransaction,
    ) -> Result<Vec<ChannelEdge>> {
        // Build a "(id), (id), ..." list to splice into the IN clause below.
        // The ids are rendered via ChannelId's Display, not taken from
        // free-form user text.
        let mut values = String::new();
        for id in channel_ids {
            if !values.is_empty() {
                values.push_str(", ");
            }
            write!(&mut values, "({})", id).unwrap();
        }

        // An empty IN list would be invalid SQL, and there is nothing to find anyway.
        if values.is_empty() {
            return Ok(vec![]);
        }

        let sql = format!(
            r#"
            SELECT
                descendant_paths.*
            FROM
                channel_paths parent_paths, channel_paths descendant_paths
            WHERE
                parent_paths.channel_id IN ({values}) AND
                descendant_paths.id_path != parent_paths.id_path AND
                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
            ORDER BY
                descendant_paths.id_path
        "#
        );

        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);

        let mut paths = channel_path::Entity::find()
            .from_raw_sql(stmt)
            .stream(tx)
            .await?;

        let mut results: Vec<ChannelEdge> = Vec::new();
        while let Some(path) = paths.next().await {
            let path = path?;
            let ids: Vec<&str> = path.id_path.trim_matches('/').split('/').collect();

            // A strict extension of a parent path has at least two segments,
            // and its last segment is the row's own channel id.
            debug_assert!(ids.len() >= 2);
            debug_assert!(ids[ids.len() - 1] == path.channel_id.to_string());

            // Panics if a stored segment is not a valid id.
            results.push(ChannelEdge {
                parent_id: ids[ids.len() - 2].parse().unwrap(),
                channel_id: ids[ids.len() - 1].parse().unwrap(),
            })
        }

        Ok(results)
    }
1121
1122    /// Returns the channel with the given ID
1123    pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
1124        self.transaction(|tx| async move {
1125            self.check_user_is_channel_participant(channel_id, user_id, &*tx)
1126                .await?;
1127
1128            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
1129            let Some(channel) = channel else {
1130                Err(anyhow!("no such channel"))?
1131            };
1132
1133            Ok(Channel {
1134                id: channel.id,
1135                visibility: channel.visibility,
1136                name: channel.name,
1137            })
1138        })
1139        .await
1140    }
1141
1142    pub(crate) async fn get_or_create_channel_room(
1143        &self,
1144        channel_id: ChannelId,
1145        live_kit_room: &str,
1146        environment: &str,
1147        tx: &DatabaseTransaction,
1148    ) -> Result<RoomId> {
1149        let room = room::Entity::find()
1150            .filter(room::Column::ChannelId.eq(channel_id))
1151            .one(&*tx)
1152            .await?;
1153
1154        let room_id = if let Some(room) = room {
1155            if let Some(env) = room.enviroment {
1156                if &env != environment {
1157                    Err(anyhow!("must join using the {} release", env))?;
1158                }
1159            }
1160            room.id
1161        } else {
1162            let result = room::Entity::insert(room::ActiveModel {
1163                channel_id: ActiveValue::Set(Some(channel_id)),
1164                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
1165                enviroment: ActiveValue::Set(Some(environment.to_string())),
1166                ..Default::default()
1167            })
1168            .exec(&*tx)
1169            .await?;
1170
1171            result.last_insert_id
1172        };
1173
1174        Ok(room_id)
1175    }
1176
1177    // Insert an edge from the given channel to the given other channel.
1178    pub async fn link_channel(
1179        &self,
1180        user: UserId,
1181        channel: ChannelId,
1182        to: ChannelId,
1183    ) -> Result<ChannelGraph> {
1184        self.transaction(|tx| async move {
1185            // Note that even with these maxed permissions, this linking operation
1186            // is still insecure because you can't remove someone's permissions to a
1187            // channel if they've linked the channel to one where they're an admin.
1188            self.check_user_is_channel_admin(channel, user, &*tx)
1189                .await?;
1190
1191            self.link_channel_internal(user, channel, to, &*tx).await
1192        })
1193        .await
1194    }
1195
    /// Adds `new_parent` as a parent of `channel` within the caller's transaction.
    ///
    /// Steps:
    /// 1. Every path that goes through `channel` (its own paths and those of
    ///    its descendants) has its tail, starting at `channel`, re-attached
    ///    under every path to `new_parent`.
    /// 2. Any paths that start at `channel` (root paths of it and its
    ///    descendants) are deleted.
    ///
    /// Errors:
    /// - "cycle" if `new_parent` is `channel` or one of its descendants;
    /// - if `user` is not an admin of `new_parent`.
    ///
    /// Returns the channel graph built by `get_user_channels` from `user`'s
    /// membership rows in `channel`, with the new edge appended.
    /// NOTE(review): inserting paths that already exist (e.g. re-linking to
    /// the same parent) presumably fails on a table constraint — confirm.
    pub async fn link_channel_internal(
        &self,
        user: UserId,
        channel: ChannelId,
        new_parent: ChannelId,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelGraph> {
        self.check_user_is_channel_admin(new_parent, user, &*tx)
            .await?;

        // All paths that contain `channel` as a segment.
        let paths = channel_path::Entity::find()
            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
            .all(tx)
            .await?;

        // (channel_id, "channel/.../") tails starting at `channel`, without the
        // leading slash. The set collapses identical tails reached via different
        // paths.
        let mut new_path_suffixes = HashSet::default();
        for path in paths {
            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
                new_path_suffixes.insert((
                    path.channel_id,
                    path.id_path[(start_offset + 1)..].to_string(),
                ));
            }
        }

        let paths_to_new_parent = channel_path::Entity::find()
            .filter(channel_path::Column::ChannelId.eq(new_parent))
            .all(tx)
            .await?;

        let mut new_paths = Vec::new();
        for path in paths_to_new_parent {
            // If `channel` already appears above `new_parent`, linking would create a loop.
            if path.id_path.contains(&format!("/{}/", channel)) {
                Err(anyhow!("cycle"))?;
            }

            // Parent path ends with '/', and each suffix starts with the channel
            // id, so plain concatenation yields a well-formed path.
            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
                channel_path::ActiveModel {
                    channel_id: ActiveValue::Set(*channel_id),
                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
                }
            }));
        }

        channel_path::Entity::insert_many(new_paths)
            .exec(&*tx)
            .await?;

        // remove any root edges for the channel we just linked
        // (paths starting at `channel`; it is now reachable through `new_parent`)
        {
            channel_path::Entity::delete_many()
                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
                .exec(&*tx)
                .await?;
        }

        let membership = channel_member::Entity::find()
            .filter(
                channel_member::Column::ChannelId
                    .eq(channel)
                    .and(channel_member::Column::UserId.eq(user)),
            )
            .all(tx)
            .await?;

        let mut channel_info = self.get_user_channels(user, membership, &*tx).await?;

        channel_info.channels.edges.push(ChannelEdge {
            channel_id: channel.to_proto(),
            parent_id: new_parent.to_proto(),
        });

        Ok(channel_info.channels)
    }
1270
1271    /// Unlink a channel from a given parent. This will add in a root edge if
1272    /// the channel has no other parents after this operation.
1273    pub async fn unlink_channel(
1274        &self,
1275        user: UserId,
1276        channel: ChannelId,
1277        from: ChannelId,
1278    ) -> Result<()> {
1279        self.transaction(|tx| async move {
1280            // Note that even with these maxed permissions, this linking operation
1281            // is still insecure because you can't remove someone's permissions to a
1282            // channel if they've linked the channel to one where they're an admin.
1283            self.check_user_is_channel_admin(channel, user, &*tx)
1284                .await?;
1285
1286            self.unlink_channel_internal(user, channel, from, &*tx)
1287                .await?;
1288
1289            Ok(())
1290        })
1291        .await
1292    }
1293
1294    pub async fn unlink_channel_internal(
1295        &self,
1296        user: UserId,
1297        channel: ChannelId,
1298        from: ChannelId,
1299        tx: &DatabaseTransaction,
1300    ) -> Result<()> {
1301        self.check_user_is_channel_admin(from, user, &*tx).await?;
1302
1303        let sql = r#"
1304            DELETE FROM channel_paths
1305            WHERE
1306                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1307            RETURNING id_path, channel_id
1308        "#;
1309
1310        let paths = channel_path::Entity::find()
1311            .from_raw_sql(Statement::from_sql_and_values(
1312                self.pool.get_database_backend(),
1313                sql,
1314                [from.to_proto().into(), channel.to_proto().into()],
1315            ))
1316            .all(&*tx)
1317            .await?;
1318
1319        let is_stranded = channel_path::Entity::find()
1320            .filter(channel_path::Column::ChannelId.eq(channel))
1321            .count(&*tx)
1322            .await?
1323            == 0;
1324
1325        // Make sure that there is always at least one path to the channel
1326        if is_stranded {
1327            let root_paths: Vec<_> = paths
1328                .iter()
1329                .map(|path| {
1330                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1331                    channel_path::ActiveModel {
1332                        channel_id: ActiveValue::Set(path.channel_id),
1333                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1334                    }
1335                })
1336                .collect();
1337            channel_path::Entity::insert_many(root_paths)
1338                .exec(&*tx)
1339                .await?;
1340        }
1341
1342        Ok(())
1343    }
1344
1345    /// Move a channel from one parent to another, returns the
1346    /// Channels that were moved for notifying clients
1347    pub async fn move_channel(
1348        &self,
1349        user: UserId,
1350        channel: ChannelId,
1351        from: ChannelId,
1352        to: ChannelId,
1353    ) -> Result<ChannelGraph> {
1354        if from == to {
1355            return Ok(ChannelGraph {
1356                channels: vec![],
1357                edges: vec![],
1358            });
1359        }
1360
1361        self.transaction(|tx| async move {
1362            self.check_user_is_channel_admin(channel, user, &*tx)
1363                .await?;
1364
1365            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1366
1367            self.unlink_channel_internal(user, channel, from, &*tx)
1368                .await?;
1369
1370            Ok(moved_channels)
1371        })
1372        .await
1373    }
1374}
1375
/// Single-column selector used with `into_values` to read just the `user_id`
/// column of `channel_member` rows (see `get_channel_participants_internal`).
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1380
/// A set of channels together with parent/child edges between channels.
/// Returned by `link_channel` and `move_channel`.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels included in the graph.
    pub channels: Vec<Channel>,
    /// Parent → child relationships between channels.
    pub edges: Vec<ChannelEdge>,
}
1386
1387impl ChannelGraph {
1388    pub fn is_empty(&self) -> bool {
1389        self.channels.is_empty() && self.edges.is_empty()
1390    }
1391}
1392
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Compares graphs as sets, so tests don't depend on the order in which
    /// channels or edges were produced.
    fn eq(&self, other: &Self) -> bool {
        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();

        same_channels && edge_pairs(self) == edge_pairs(other)
    }
}
1413
1414#[cfg(not(test))]
1415impl PartialEq for ChannelGraph {
1416    fn eq(&self, other: &Self) -> bool {
1417        self.channels == other.channels && self.edges == other.edges
1418    }
1419}