channels.rs

   1use super::*;
   2use rpc::proto::{channel_member::Kind, ChannelEdge};
   3
   4impl Database {
   5    #[cfg(test)]
   6    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
   7        self.transaction(move |tx| async move {
   8            let mut channels = Vec::new();
   9            let mut rows = channel::Entity::find().stream(&*tx).await?;
  10            while let Some(row) = rows.next().await {
  11                let row = row?;
  12                channels.push((row.id, row.name));
  13            }
  14            Ok(channels)
  15        })
  16        .await
  17    }
  18
  19    pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
  20        self.create_channel(name, None, creator_id).await
  21    }
  22
  23    pub async fn create_channel(
  24        &self,
  25        name: &str,
  26        parent: Option<ChannelId>,
  27        creator_id: UserId,
  28    ) -> Result<ChannelId> {
  29        let name = Self::sanitize_channel_name(name)?;
  30        self.transaction(move |tx| async move {
  31            if let Some(parent) = parent {
  32                self.check_user_is_channel_admin(parent, creator_id, &*tx)
  33                    .await?;
  34            }
  35
  36            let channel = channel::ActiveModel {
  37                id: ActiveValue::NotSet,
  38                name: ActiveValue::Set(name.to_string()),
  39                visibility: ActiveValue::Set(ChannelVisibility::Members),
  40            }
  41            .insert(&*tx)
  42            .await?;
  43
  44            if let Some(parent) = parent {
  45                let sql = r#"
  46                    INSERT INTO channel_paths
  47                    (id_path, channel_id)
  48                    SELECT
  49                        id_path || $1 || '/', $2
  50                    FROM
  51                        channel_paths
  52                    WHERE
  53                        channel_id = $3
  54                "#;
  55                let channel_paths_stmt = Statement::from_sql_and_values(
  56                    self.pool.get_database_backend(),
  57                    sql,
  58                    [
  59                        channel.id.to_proto().into(),
  60                        channel.id.to_proto().into(),
  61                        parent.to_proto().into(),
  62                    ],
  63                );
  64                tx.execute(channel_paths_stmt).await?;
  65            } else {
  66                channel_path::Entity::insert(channel_path::ActiveModel {
  67                    channel_id: ActiveValue::Set(channel.id),
  68                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
  69                })
  70                .exec(&*tx)
  71                .await?;
  72            }
  73
  74            channel_member::ActiveModel {
  75                id: ActiveValue::NotSet,
  76                channel_id: ActiveValue::Set(channel.id),
  77                user_id: ActiveValue::Set(creator_id),
  78                accepted: ActiveValue::Set(true),
  79                role: ActiveValue::Set(ChannelRole::Admin),
  80            }
  81            .insert(&*tx)
  82            .await?;
  83
  84            Ok(channel.id)
  85        })
  86        .await
  87    }
  88
  89    pub async fn join_channel(
  90        &self,
  91        channel_id: ChannelId,
  92        user_id: UserId,
  93        connection: ConnectionId,
  94        environment: &str,
  95    ) -> Result<(JoinRoom, Option<ChannelId>)> {
  96        self.transaction(move |tx| async move {
  97            let mut joined_channel_id = None;
  98
  99            let channel = channel::Entity::find()
 100                .filter(channel::Column::Id.eq(channel_id))
 101                .one(&*tx)
 102                .await?;
 103
 104            let mut role = self
 105                .channel_role_for_user(channel_id, user_id, &*tx)
 106                .await?;
 107
 108            if role.is_none() && channel.is_some() {
 109                if let Some(invitation) = self
 110                    .pending_invite_for_channel(channel_id, user_id, &*tx)
 111                    .await?
 112                {
 113                    // note, this may be a parent channel
 114                    joined_channel_id = Some(invitation.channel_id);
 115                    role = Some(invitation.role);
 116
 117                    channel_member::Entity::update(channel_member::ActiveModel {
 118                        accepted: ActiveValue::Set(true),
 119                        ..invitation.into_active_model()
 120                    })
 121                    .exec(&*tx)
 122                    .await?;
 123
 124                    debug_assert!(
 125                        self.channel_role_for_user(channel_id, user_id, &*tx)
 126                            .await?
 127                            == role
 128                    );
 129                }
 130            }
 131            if role.is_none()
 132                && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
 133            {
 134                let channel_id_to_join = self
 135                    .most_public_ancestor_for_channel(channel_id, &*tx)
 136                    .await?
 137                    .unwrap_or(channel_id);
 138                // TODO: change this back to Guest.
 139                role = Some(ChannelRole::Member);
 140                joined_channel_id = Some(channel_id_to_join);
 141
 142                channel_member::Entity::insert(channel_member::ActiveModel {
 143                    id: ActiveValue::NotSet,
 144                    channel_id: ActiveValue::Set(channel_id_to_join),
 145                    user_id: ActiveValue::Set(user_id),
 146                    accepted: ActiveValue::Set(true),
 147                    // TODO: change this back to Guest.
 148                    role: ActiveValue::Set(ChannelRole::Member),
 149                })
 150                .exec(&*tx)
 151                .await?;
 152
 153                debug_assert!(
 154                    self.channel_role_for_user(channel_id, user_id, &*tx)
 155                        .await?
 156                        == role
 157                );
 158            }
 159
 160            if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
 161                Err(anyhow!("no such channel, or not allowed"))?
 162            }
 163
 164            let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
 165            let room_id = self
 166                .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
 167                .await?;
 168
 169            self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
 170                .await
 171                .map(|jr| (jr, joined_channel_id))
 172        })
 173        .await
 174    }
 175
 176    pub async fn set_channel_visibility(
 177        &self,
 178        channel_id: ChannelId,
 179        visibility: ChannelVisibility,
 180        user_id: UserId,
 181    ) -> Result<channel::Model> {
 182        self.transaction(move |tx| async move {
 183            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 184                .await?;
 185
 186            let channel = channel::ActiveModel {
 187                id: ActiveValue::Unchanged(channel_id),
 188                visibility: ActiveValue::Set(visibility),
 189                ..Default::default()
 190            }
 191            .update(&*tx)
 192            .await?;
 193
 194            Ok(channel)
 195        })
 196        .await
 197    }
 198
 199    pub async fn delete_channel(
 200        &self,
 201        channel_id: ChannelId,
 202        user_id: UserId,
 203    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 204        self.transaction(move |tx| async move {
 205            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 206                .await?;
 207
 208            // Don't remove descendant channels that have additional parents.
 209            let mut channels_to_remove: HashSet<ChannelId> = HashSet::default();
 210            channels_to_remove.insert(channel_id);
 211
 212            let graph = self.get_channel_descendants([channel_id], &*tx).await?;
 213            for edge in graph.iter() {
 214                channels_to_remove.insert(ChannelId::from_proto(edge.channel_id));
 215            }
 216
 217            {
 218                let mut channels_to_keep = channel_path::Entity::find()
 219                    .filter(
 220                        channel_path::Column::ChannelId
 221                            .is_in(channels_to_remove.iter().copied())
 222                            .and(
 223                                channel_path::Column::IdPath
 224                                    .not_like(&format!("%/{}/%", channel_id)),
 225                            ),
 226                    )
 227                    .stream(&*tx)
 228                    .await?;
 229                while let Some(row) = channels_to_keep.next().await {
 230                    let row = row?;
 231                    channels_to_remove.remove(&row.channel_id);
 232                }
 233            }
 234
 235            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
 236            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
 237                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
 238                .select_only()
 239                .column(channel_member::Column::UserId)
 240                .distinct()
 241                .into_values::<_, QueryUserIds>()
 242                .all(&*tx)
 243                .await?;
 244
 245            channel::Entity::delete_many()
 246                .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied()))
 247                .exec(&*tx)
 248                .await?;
 249
 250            // Delete any other paths that include this channel
 251            let sql = r#"
 252                    DELETE FROM channel_paths
 253                    WHERE
 254                        id_path LIKE '%' || $1 || '%'
 255                "#;
 256            let channel_paths_stmt = Statement::from_sql_and_values(
 257                self.pool.get_database_backend(),
 258                sql,
 259                [channel_id.to_proto().into()],
 260            );
 261            tx.execute(channel_paths_stmt).await?;
 262
 263            Ok((channels_to_remove.into_iter().collect(), members_to_notify))
 264        })
 265        .await
 266    }
 267
 268    pub async fn invite_channel_member(
 269        &self,
 270        channel_id: ChannelId,
 271        invitee_id: UserId,
 272        inviter_id: UserId,
 273        role: ChannelRole,
 274    ) -> Result<NotificationBatch> {
 275        self.transaction(move |tx| async move {
 276            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
 277                .await?;
 278
 279            let channel = channel::Entity::find_by_id(channel_id)
 280                .one(&*tx)
 281                .await?
 282                .ok_or_else(|| anyhow!("no such channel"))?;
 283
 284            channel_member::ActiveModel {
 285                id: ActiveValue::NotSet,
 286                channel_id: ActiveValue::Set(channel_id),
 287                user_id: ActiveValue::Set(invitee_id),
 288                accepted: ActiveValue::Set(false),
 289                role: ActiveValue::Set(role),
 290            }
 291            .insert(&*tx)
 292            .await?;
 293
 294            Ok(self
 295                .create_notification(
 296                    invitee_id,
 297                    rpc::Notification::ChannelInvitation {
 298                        channel_id: channel_id.to_proto(),
 299                        channel_name: channel.name,
 300                        inviter_id: inviter_id.to_proto(),
 301                    },
 302                    true,
 303                    &*tx,
 304                )
 305                .await?
 306                .into_iter()
 307                .collect())
 308        })
 309        .await
 310    }
 311
 312    fn sanitize_channel_name(name: &str) -> Result<&str> {
 313        let new_name = name.trim().trim_start_matches('#');
 314        if new_name == "" {
 315            Err(anyhow!("channel name can't be blank"))?;
 316        }
 317        Ok(new_name)
 318    }
 319
 320    pub async fn rename_channel(
 321        &self,
 322        channel_id: ChannelId,
 323        user_id: UserId,
 324        new_name: &str,
 325    ) -> Result<Channel> {
 326        self.transaction(move |tx| async move {
 327            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
 328
 329            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 330                .await?;
 331
 332            let channel = channel::ActiveModel {
 333                id: ActiveValue::Unchanged(channel_id),
 334                name: ActiveValue::Set(new_name.clone()),
 335                ..Default::default()
 336            }
 337            .update(&*tx)
 338            .await?;
 339
 340            Ok(Channel {
 341                id: channel.id,
 342                name: channel.name,
 343                visibility: channel.visibility,
 344            })
 345        })
 346        .await
 347    }
 348
 349    pub async fn respond_to_channel_invite(
 350        &self,
 351        channel_id: ChannelId,
 352        user_id: UserId,
 353        accept: bool,
 354    ) -> Result<NotificationBatch> {
 355        self.transaction(move |tx| async move {
 356            let rows_affected = if accept {
 357                channel_member::Entity::update_many()
 358                    .set(channel_member::ActiveModel {
 359                        accepted: ActiveValue::Set(accept),
 360                        ..Default::default()
 361                    })
 362                    .filter(
 363                        channel_member::Column::ChannelId
 364                            .eq(channel_id)
 365                            .and(channel_member::Column::UserId.eq(user_id))
 366                            .and(channel_member::Column::Accepted.eq(false)),
 367                    )
 368                    .exec(&*tx)
 369                    .await?
 370                    .rows_affected
 371            } else {
 372                channel_member::Entity::delete_many()
 373                    .filter(
 374                        channel_member::Column::ChannelId
 375                            .eq(channel_id)
 376                            .and(channel_member::Column::UserId.eq(user_id))
 377                            .and(channel_member::Column::Accepted.eq(false)),
 378                    )
 379                    .exec(&*tx)
 380                    .await?
 381                    .rows_affected
 382            };
 383
 384            if rows_affected == 0 {
 385                Err(anyhow!("no such invitation"))?;
 386            }
 387
 388            Ok(self
 389                .respond_to_notification(
 390                    user_id,
 391                    &rpc::Notification::ChannelInvitation {
 392                        channel_id: channel_id.to_proto(),
 393                        channel_name: Default::default(),
 394                        inviter_id: Default::default(),
 395                    },
 396                    accept,
 397                    &*tx,
 398                )
 399                .await?
 400                .into_iter()
 401                .collect())
 402        })
 403        .await
 404    }
 405
 406    pub async fn remove_channel_member(
 407        &self,
 408        channel_id: ChannelId,
 409        member_id: UserId,
 410        admin_id: UserId,
 411    ) -> Result<Option<NotificationId>> {
 412        self.transaction(|tx| async move {
 413            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 414                .await?;
 415
 416            let result = channel_member::Entity::delete_many()
 417                .filter(
 418                    channel_member::Column::ChannelId
 419                        .eq(channel_id)
 420                        .and(channel_member::Column::UserId.eq(member_id)),
 421                )
 422                .exec(&*tx)
 423                .await?;
 424
 425            if result.rows_affected == 0 {
 426                Err(anyhow!("no such member"))?;
 427            }
 428
 429            Ok(self
 430                .remove_notification(
 431                    member_id,
 432                    rpc::Notification::ChannelInvitation {
 433                        channel_id: channel_id.to_proto(),
 434                        channel_name: Default::default(),
 435                        inviter_id: Default::default(),
 436                    },
 437                    &*tx,
 438                )
 439                .await?)
 440        })
 441        .await
 442    }
 443
 444    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
 445        self.transaction(|tx| async move {
 446            let channel_invites = channel_member::Entity::find()
 447                .filter(
 448                    channel_member::Column::UserId
 449                        .eq(user_id)
 450                        .and(channel_member::Column::Accepted.eq(false)),
 451                )
 452                .all(&*tx)
 453                .await?;
 454
 455            let channels = channel::Entity::find()
 456                .filter(
 457                    channel::Column::Id.is_in(
 458                        channel_invites
 459                            .into_iter()
 460                            .map(|channel_member| channel_member.channel_id),
 461                    ),
 462                )
 463                .all(&*tx)
 464                .await?;
 465
 466            let channels = channels
 467                .into_iter()
 468                .map(|channel| Channel {
 469                    id: channel.id,
 470                    name: channel.name,
 471                    visibility: channel.visibility,
 472                })
 473                .collect();
 474
 475            Ok(channels)
 476        })
 477        .await
 478    }
 479
 480    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
 481        self.transaction(|tx| async move {
 482            let tx = tx;
 483
 484            let channel_memberships = channel_member::Entity::find()
 485                .filter(
 486                    channel_member::Column::UserId
 487                        .eq(user_id)
 488                        .and(channel_member::Column::Accepted.eq(true)),
 489                )
 490                .all(&*tx)
 491                .await?;
 492
 493            self.get_user_channels(user_id, channel_memberships, &tx)
 494                .await
 495        })
 496        .await
 497    }
 498
 499    pub async fn get_channel_for_user(
 500        &self,
 501        channel_id: ChannelId,
 502        user_id: UserId,
 503    ) -> Result<ChannelsForUser> {
 504        self.transaction(|tx| async move {
 505            let tx = tx;
 506
 507            let channel_membership = channel_member::Entity::find()
 508                .filter(
 509                    channel_member::Column::UserId
 510                        .eq(user_id)
 511                        .and(channel_member::Column::ChannelId.eq(channel_id))
 512                        .and(channel_member::Column::Accepted.eq(true)),
 513                )
 514                .all(&*tx)
 515                .await?;
 516
 517            self.get_user_channels(user_id, channel_membership, &tx)
 518                .await
 519        })
 520        .await
 521    }
 522
 523    pub async fn get_user_channels(
 524        &self,
 525        user_id: UserId,
 526        channel_memberships: Vec<channel_member::Model>,
 527        tx: &DatabaseTransaction,
 528    ) -> Result<ChannelsForUser> {
 529        let mut edges = self
 530            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
 531            .await?;
 532
 533        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();
 534
 535        for membership in channel_memberships.iter() {
 536            role_for_channel.insert(membership.channel_id, membership.role);
 537        }
 538
 539        for ChannelEdge {
 540            parent_id,
 541            channel_id,
 542        } in edges.iter()
 543        {
 544            let parent_id = ChannelId::from_proto(*parent_id);
 545            let channel_id = ChannelId::from_proto(*channel_id);
 546            debug_assert!(role_for_channel.get(&parent_id).is_some());
 547            let parent_role = role_for_channel[&parent_id];
 548            if let Some(existing_role) = role_for_channel.get(&channel_id) {
 549                if existing_role.should_override(parent_role) {
 550                    continue;
 551                }
 552            }
 553            role_for_channel.insert(channel_id, parent_role);
 554        }
 555
 556        let mut channels: Vec<Channel> = Vec::new();
 557        let mut channels_with_admin_privileges: HashSet<ChannelId> = HashSet::default();
 558        let mut channels_to_remove: HashSet<u64> = HashSet::default();
 559
 560        let mut rows = channel::Entity::find()
 561            .filter(channel::Column::Id.is_in(role_for_channel.keys().copied()))
 562            .stream(&*tx)
 563            .await?;
 564
 565        while let Some(row) = rows.next().await {
 566            let channel = row?;
 567            let role = role_for_channel[&channel.id];
 568
 569            if role == ChannelRole::Banned
 570                || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
 571            {
 572                channels_to_remove.insert(channel.id.0 as u64);
 573                continue;
 574            }
 575
 576            channels.push(Channel {
 577                id: channel.id,
 578                name: channel.name,
 579                visibility: channel.visibility,
 580            });
 581
 582            if role == ChannelRole::Admin {
 583                channels_with_admin_privileges.insert(channel.id);
 584            }
 585        }
 586        drop(rows);
 587
 588        if !channels_to_remove.is_empty() {
 589            // Note: this code assumes each channel has one parent.
 590            // If there are multiple valid public paths to a channel,
 591            // e.g.
 592            // If both of these paths are present (* indicating public):
 593            // - zed* -> projects -> vim*
 594            // - zed* -> conrad -> public-projects* -> vim*
 595            // Users would only see one of them (based on edge sort order)
 596            let mut replacement_parent: HashMap<u64, u64> = HashMap::default();
 597            for ChannelEdge {
 598                parent_id,
 599                channel_id,
 600            } in edges.iter()
 601            {
 602                if channels_to_remove.contains(channel_id) {
 603                    replacement_parent.insert(*channel_id, *parent_id);
 604                }
 605            }
 606
 607            let mut new_edges: Vec<ChannelEdge> = Vec::new();
 608            'outer: for ChannelEdge {
 609                mut parent_id,
 610                channel_id,
 611            } in edges.iter()
 612            {
 613                if channels_to_remove.contains(channel_id) {
 614                    continue;
 615                }
 616                while channels_to_remove.contains(&parent_id) {
 617                    if let Some(new_parent_id) = replacement_parent.get(&parent_id) {
 618                        parent_id = *new_parent_id;
 619                    } else {
 620                        continue 'outer;
 621                    }
 622                }
 623                new_edges.push(ChannelEdge {
 624                    parent_id,
 625                    channel_id: *channel_id,
 626                })
 627            }
 628            edges = new_edges;
 629        }
 630
 631        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 632        enum QueryUserIdsAndChannelIds {
 633            ChannelId,
 634            UserId,
 635        }
 636
 637        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
 638        {
 639            let mut rows = room_participant::Entity::find()
 640                .inner_join(room::Entity)
 641                .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
 642                .select_only()
 643                .column(room::Column::ChannelId)
 644                .column(room_participant::Column::UserId)
 645                .into_values::<_, QueryUserIdsAndChannelIds>()
 646                .stream(&*tx)
 647                .await?;
 648            while let Some(row) = rows.next().await {
 649                let row: (ChannelId, UserId) = row?;
 650                channel_participants.entry(row.0).or_default().push(row.1)
 651            }
 652        }
 653
 654        let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
 655        let channel_buffer_changes = self
 656            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
 657            .await?;
 658
 659        let unseen_messages = self
 660            .unseen_channel_messages(user_id, &channel_ids, &*tx)
 661            .await?;
 662
 663        Ok(ChannelsForUser {
 664            channels: ChannelGraph { channels, edges },
 665            channel_participants,
 666            channels_with_admin_privileges,
 667            unseen_buffer_changes: channel_buffer_changes,
 668            channel_messages: unseen_messages,
 669        })
 670    }
 671
 672    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
 673        self.transaction(|tx| async move { self.get_channel_participants_internal(id, &*tx).await })
 674            .await
 675    }
 676
 677    pub async fn set_channel_member_role(
 678        &self,
 679        channel_id: ChannelId,
 680        admin_id: UserId,
 681        for_user: UserId,
 682        role: ChannelRole,
 683    ) -> Result<channel_member::Model> {
 684        self.transaction(|tx| async move {
 685            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
 686                .await?;
 687
 688            let membership = channel_member::Entity::find()
 689                .filter(
 690                    channel_member::Column::ChannelId
 691                        .eq(channel_id)
 692                        .and(channel_member::Column::UserId.eq(for_user)),
 693                )
 694                .one(&*tx)
 695                .await?;
 696
 697            let Some(membership) = membership else {
 698                Err(anyhow!("no such member"))?
 699            };
 700
 701            let mut update = membership.into_active_model();
 702            update.role = ActiveValue::Set(role);
 703            let updated = channel_member::Entity::update(update).exec(&*tx).await?;
 704
 705            Ok(updated)
 706        })
 707        .await
 708    }
 709
 710    pub async fn get_channel_participant_details(
 711        &self,
 712        channel_id: ChannelId,
 713        user_id: UserId,
 714    ) -> Result<Vec<proto::ChannelMember>> {
 715        self.transaction(|tx| async move {
 716            let role = self
 717                .check_user_is_channel_member(channel_id, user_id, &*tx)
 718                .await?;
 719
 720            let channel_visibility = channel::Entity::find()
 721                .filter(channel::Column::Id.eq(channel_id))
 722                .one(&*tx)
 723                .await?
 724                .map(|channel| channel.visibility)
 725                .unwrap_or(ChannelVisibility::Members);
 726
 727            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 728            enum QueryMemberDetails {
 729                UserId,
 730                Role,
 731                IsDirectMember,
 732                Accepted,
 733                Visibility,
 734            }
 735
 736            let tx = tx;
 737            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
 738            let mut stream = channel_member::Entity::find()
 739                .left_join(channel::Entity)
 740                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
 741                .select_only()
 742                .column(channel_member::Column::UserId)
 743                .column(channel_member::Column::Role)
 744                .column_as(
 745                    channel_member::Column::ChannelId.eq(channel_id),
 746                    QueryMemberDetails::IsDirectMember,
 747                )
 748                .column(channel_member::Column::Accepted)
 749                .column(channel::Column::Visibility)
 750                .into_values::<_, QueryMemberDetails>()
 751                .stream(&*tx)
 752                .await?;
 753
 754            struct UserDetail {
 755                kind: Kind,
 756                channel_role: ChannelRole,
 757            }
 758            let mut user_details: HashMap<UserId, UserDetail> = HashMap::default();
 759
 760            while let Some(user_membership) = stream.next().await {
 761                let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): (
 762                    UserId,
 763                    ChannelRole,
 764                    bool,
 765                    bool,
 766                    ChannelVisibility,
 767                ) = user_membership?;
 768                let kind = match (is_direct_member, is_invite_accepted) {
 769                    (true, true) => proto::channel_member::Kind::Member,
 770                    (true, false) => proto::channel_member::Kind::Invitee,
 771                    (false, true) => proto::channel_member::Kind::AncestorMember,
 772                    (false, false) => continue,
 773                };
 774
 775                if channel_role == ChannelRole::Guest
 776                    && visibility != ChannelVisibility::Public
 777                    && channel_visibility != ChannelVisibility::Public
 778                {
 779                    continue;
 780                }
 781
 782                if let Some(details_mut) = user_details.get_mut(&user_id) {
 783                    if channel_role.should_override(details_mut.channel_role) {
 784                        details_mut.channel_role = channel_role;
 785                    }
 786                    if kind == Kind::Member {
 787                        details_mut.kind = kind;
 788                    // the UI is going to be a bit confusing if you already have permissions
 789                    // that are greater than or equal to the ones you're being invited to.
 790                    } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember {
 791                        details_mut.kind = kind;
 792                    }
 793                } else {
 794                    user_details.insert(user_id, UserDetail { kind, channel_role });
 795                }
 796            }
 797
 798            Ok(user_details
 799                .into_iter()
 800                .filter_map(|(user_id, details)| {
 801                    // If the user is not an admin, don't give them all of the details
 802                    if role != ChannelRole::Admin {
 803                        if details.kind == Kind::AncestorMember {
 804                            return None;
 805                        }
 806                        return None;
 807                    }
 808
 809                    Some(proto::ChannelMember {
 810                        user_id: user_id.to_proto(),
 811                        kind: details.kind.into(),
 812                        role: details.channel_role.into(),
 813                    })
 814                })
 815                .collect())
 816        })
 817        .await
 818    }
 819
 820    pub async fn get_channel_participants_internal(
 821        &self,
 822        id: ChannelId,
 823        tx: &DatabaseTransaction,
 824    ) -> Result<Vec<UserId>> {
 825        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
 826        let user_ids = channel_member::Entity::find()
 827            .distinct()
 828            .filter(
 829                channel_member::Column::ChannelId
 830                    .is_in(ancestor_ids.iter().copied())
 831                    .and(channel_member::Column::Accepted.eq(true)),
 832            )
 833            .select_only()
 834            .column(channel_member::Column::UserId)
 835            .into_values::<_, QueryUserIds>()
 836            .all(&*tx)
 837            .await?;
 838        Ok(user_ids)
 839    }
 840
 841    pub async fn check_user_is_channel_admin(
 842        &self,
 843        channel_id: ChannelId,
 844        user_id: UserId,
 845        tx: &DatabaseTransaction,
 846    ) -> Result<()> {
 847        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 848            Some(ChannelRole::Admin) => Ok(()),
 849            Some(ChannelRole::Member)
 850            | Some(ChannelRole::Banned)
 851            | Some(ChannelRole::Guest)
 852            | None => Err(anyhow!(
 853                "user is not a channel admin or channel does not exist"
 854            ))?,
 855        }
 856    }
 857
 858    pub async fn check_user_is_channel_member(
 859        &self,
 860        channel_id: ChannelId,
 861        user_id: UserId,
 862        tx: &DatabaseTransaction,
 863    ) -> Result<ChannelRole> {
 864        let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?;
 865        match channel_role {
 866            Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()),
 867            Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
 868                "user is not a channel member or channel does not exist"
 869            ))?,
 870        }
 871    }
 872
 873    pub async fn check_user_is_channel_participant(
 874        &self,
 875        channel_id: ChannelId,
 876        user_id: UserId,
 877        tx: &DatabaseTransaction,
 878    ) -> Result<()> {
 879        match self.channel_role_for_user(channel_id, user_id, tx).await? {
 880            Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
 881                Ok(())
 882            }
 883            Some(ChannelRole::Banned) | None => Err(anyhow!(
 884                "user is not a channel participant or channel does not exist"
 885            ))?,
 886        }
 887    }
 888
 889    pub async fn pending_invite_for_channel(
 890        &self,
 891        channel_id: ChannelId,
 892        user_id: UserId,
 893        tx: &DatabaseTransaction,
 894    ) -> Result<Option<channel_member::Model>> {
 895        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 896
 897        let row = channel_member::Entity::find()
 898            .filter(channel_member::Column::ChannelId.is_in(channel_ids))
 899            .filter(channel_member::Column::UserId.eq(user_id))
 900            .filter(channel_member::Column::Accepted.eq(false))
 901            .one(&*tx)
 902            .await?;
 903
 904        Ok(row)
 905    }
 906
 907    pub async fn most_public_ancestor_for_channel(
 908        &self,
 909        channel_id: ChannelId,
 910        tx: &DatabaseTransaction,
 911    ) -> Result<Option<ChannelId>> {
 912        // Note: if there are many paths to a channel, this will return just one
 913        let arbitary_path = channel_path::Entity::find()
 914            .filter(channel_path::Column::ChannelId.eq(channel_id))
 915            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
 916            .one(tx)
 917            .await?;
 918
 919        let Some(path) = arbitary_path else {
 920            return Ok(None);
 921        };
 922
 923        let ancestor_ids: Vec<ChannelId> = path
 924            .id_path
 925            .trim_matches('/')
 926            .split('/')
 927            .map(|id| ChannelId::from_proto(id.parse().unwrap()))
 928            .collect();
 929
 930        let rows = channel::Entity::find()
 931            .filter(channel::Column::Id.is_in(ancestor_ids.iter().copied()))
 932            .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
 933            .all(&*tx)
 934            .await?;
 935
 936        let mut visible_channels: HashSet<ChannelId> = HashSet::default();
 937
 938        for row in rows {
 939            visible_channels.insert(row.id);
 940        }
 941
 942        for ancestor in ancestor_ids {
 943            if visible_channels.contains(&ancestor) {
 944                return Ok(Some(ancestor));
 945            }
 946        }
 947
 948        Ok(None)
 949    }
 950
 951    pub async fn channel_role_for_user(
 952        &self,
 953        channel_id: ChannelId,
 954        user_id: UserId,
 955        tx: &DatabaseTransaction,
 956    ) -> Result<Option<ChannelRole>> {
 957        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
 958
 959        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 960        enum QueryChannelMembership {
 961            ChannelId,
 962            Role,
 963            Visibility,
 964        }
 965
 966        let mut rows = channel_member::Entity::find()
 967            .left_join(channel::Entity)
 968            .filter(
 969                channel_member::Column::ChannelId
 970                    .is_in(channel_ids)
 971                    .and(channel_member::Column::UserId.eq(user_id))
 972                    .and(channel_member::Column::Accepted.eq(true)),
 973            )
 974            .select_only()
 975            .column(channel_member::Column::ChannelId)
 976            .column(channel_member::Column::Role)
 977            .column(channel::Column::Visibility)
 978            .into_values::<_, QueryChannelMembership>()
 979            .stream(&*tx)
 980            .await?;
 981
 982        let mut user_role: Option<ChannelRole> = None;
 983
 984        let mut is_participant = false;
 985        let mut current_channel_visibility = None;
 986
 987        // note these channels are not iterated in any particular order,
 988        // our current logic takes the highest permission available.
 989        while let Some(row) = rows.next().await {
 990            let (membership_channel, role, visibility): (
 991                ChannelId,
 992                ChannelRole,
 993                ChannelVisibility,
 994            ) = row?;
 995
 996            match role {
 997                ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => {
 998                    if let Some(users_role) = user_role {
 999                        user_role = Some(users_role.max(role));
1000                    } else {
1001                        user_role = Some(role)
1002                    }
1003                }
1004                ChannelRole::Guest if visibility == ChannelVisibility::Public => {
1005                    is_participant = true
1006                }
1007                ChannelRole::Guest => {}
1008            }
1009            if channel_id == membership_channel {
1010                current_channel_visibility = Some(visibility);
1011            }
1012        }
1013        // free up database connection
1014        drop(rows);
1015
1016        if is_participant && user_role.is_none() {
1017            if current_channel_visibility.is_none() {
1018                current_channel_visibility = channel::Entity::find()
1019                    .filter(channel::Column::Id.eq(channel_id))
1020                    .one(&*tx)
1021                    .await?
1022                    .map(|channel| channel.visibility);
1023            }
1024            if current_channel_visibility == Some(ChannelVisibility::Public) {
1025                user_role = Some(ChannelRole::Guest);
1026            }
1027        }
1028
1029        Ok(user_role)
1030    }
1031
1032    /// Returns the channel ancestors in arbitrary order
1033    pub async fn get_channel_ancestors(
1034        &self,
1035        channel_id: ChannelId,
1036        tx: &DatabaseTransaction,
1037    ) -> Result<Vec<ChannelId>> {
1038        let paths = channel_path::Entity::find()
1039            .filter(channel_path::Column::ChannelId.eq(channel_id))
1040            .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
1041            .all(tx)
1042            .await?;
1043        let mut channel_ids = Vec::new();
1044        for path in paths {
1045            for id in path.id_path.trim_matches('/').split('/') {
1046                if let Ok(id) = id.parse() {
1047                    let id = ChannelId::from_proto(id);
1048                    if let Err(ix) = channel_ids.binary_search(&id) {
1049                        channel_ids.insert(ix, id);
1050                    }
1051                }
1052            }
1053        }
1054        Ok(channel_ids)
1055    }
1056
1057    // Returns the channel desendants as a sorted list of edges for further processing.
1058    // The edges are sorted such that you will see unknown channel ids as children
1059    // before you see them as parents.
1060    async fn get_channel_descendants(
1061        &self,
1062        channel_ids: impl IntoIterator<Item = ChannelId>,
1063        tx: &DatabaseTransaction,
1064    ) -> Result<Vec<ChannelEdge>> {
1065        let mut values = String::new();
1066        for id in channel_ids {
1067            if !values.is_empty() {
1068                values.push_str(", ");
1069            }
1070            write!(&mut values, "({})", id).unwrap();
1071        }
1072
1073        if values.is_empty() {
1074            return Ok(vec![]);
1075        }
1076
1077        let sql = format!(
1078            r#"
1079            SELECT
1080                descendant_paths.*
1081            FROM
1082                channel_paths parent_paths, channel_paths descendant_paths
1083            WHERE
1084                parent_paths.channel_id IN ({values}) AND
1085                descendant_paths.id_path != parent_paths.id_path AND
1086                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
1087            ORDER BY
1088                descendant_paths.id_path
1089        "#
1090        );
1091
1092        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
1093
1094        let mut paths = channel_path::Entity::find()
1095            .from_raw_sql(stmt)
1096            .stream(tx)
1097            .await?;
1098
1099        let mut results: Vec<ChannelEdge> = Vec::new();
1100        while let Some(path) = paths.next().await {
1101            let path = path?;
1102            let ids: Vec<&str> = path.id_path.trim_matches('/').split('/').collect();
1103
1104            debug_assert!(ids.len() >= 2);
1105            debug_assert!(ids[ids.len() - 1] == path.channel_id.to_string());
1106
1107            results.push(ChannelEdge {
1108                parent_id: ids[ids.len() - 2].parse().unwrap(),
1109                channel_id: ids[ids.len() - 1].parse().unwrap(),
1110            })
1111        }
1112
1113        Ok(results)
1114    }
1115
1116    /// Returns the channel with the given ID
1117    pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
1118        self.transaction(|tx| async move {
1119            self.check_user_is_channel_participant(channel_id, user_id, &*tx)
1120                .await?;
1121
1122            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
1123            let Some(channel) = channel else {
1124                Err(anyhow!("no such channel"))?
1125            };
1126
1127            Ok(Channel {
1128                id: channel.id,
1129                visibility: channel.visibility,
1130                name: channel.name,
1131            })
1132        })
1133        .await
1134    }
1135
1136    pub(crate) async fn get_or_create_channel_room(
1137        &self,
1138        channel_id: ChannelId,
1139        live_kit_room: &str,
1140        environment: &str,
1141        tx: &DatabaseTransaction,
1142    ) -> Result<RoomId> {
1143        let room = room::Entity::find()
1144            .filter(room::Column::ChannelId.eq(channel_id))
1145            .one(&*tx)
1146            .await?;
1147
1148        let room_id = if let Some(room) = room {
1149            if let Some(env) = room.enviroment {
1150                if &env != environment {
1151                    Err(anyhow!("must join using the {} release", env))?;
1152                }
1153            }
1154            room.id
1155        } else {
1156            let result = room::Entity::insert(room::ActiveModel {
1157                channel_id: ActiveValue::Set(Some(channel_id)),
1158                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
1159                enviroment: ActiveValue::Set(Some(environment.to_string())),
1160                ..Default::default()
1161            })
1162            .exec(&*tx)
1163            .await?;
1164
1165            result.last_insert_id
1166        };
1167
1168        Ok(room_id)
1169    }
1170
1171    // Insert an edge from the given channel to the given other channel.
1172    pub async fn link_channel(
1173        &self,
1174        user: UserId,
1175        channel: ChannelId,
1176        to: ChannelId,
1177    ) -> Result<ChannelGraph> {
1178        self.transaction(|tx| async move {
1179            // Note that even with these maxed permissions, this linking operation
1180            // is still insecure because you can't remove someone's permissions to a
1181            // channel if they've linked the channel to one where they're an admin.
1182            self.check_user_is_channel_admin(channel, user, &*tx)
1183                .await?;
1184
1185            self.link_channel_internal(user, channel, to, &*tx).await
1186        })
1187        .await
1188    }
1189
1190    pub async fn link_channel_internal(
1191        &self,
1192        user: UserId,
1193        channel: ChannelId,
1194        new_parent: ChannelId,
1195        tx: &DatabaseTransaction,
1196    ) -> Result<ChannelGraph> {
1197        self.check_user_is_channel_admin(new_parent, user, &*tx)
1198            .await?;
1199
1200        let paths = channel_path::Entity::find()
1201            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
1202            .all(tx)
1203            .await?;
1204
1205        let mut new_path_suffixes = HashSet::default();
1206        for path in paths {
1207            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
1208                new_path_suffixes.insert((
1209                    path.channel_id,
1210                    path.id_path[(start_offset + 1)..].to_string(),
1211                ));
1212            }
1213        }
1214
1215        let paths_to_new_parent = channel_path::Entity::find()
1216            .filter(channel_path::Column::ChannelId.eq(new_parent))
1217            .all(tx)
1218            .await?;
1219
1220        let mut new_paths = Vec::new();
1221        for path in paths_to_new_parent {
1222            if path.id_path.contains(&format!("/{}/", channel)) {
1223                Err(anyhow!("cycle"))?;
1224            }
1225
1226            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
1227                channel_path::ActiveModel {
1228                    channel_id: ActiveValue::Set(*channel_id),
1229                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
1230                }
1231            }));
1232        }
1233
1234        channel_path::Entity::insert_many(new_paths)
1235            .exec(&*tx)
1236            .await?;
1237
1238        // remove any root edges for the channel we just linked
1239        {
1240            channel_path::Entity::delete_many()
1241                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
1242                .exec(&*tx)
1243                .await?;
1244        }
1245
1246        let membership = channel_member::Entity::find()
1247            .filter(
1248                channel_member::Column::ChannelId
1249                    .eq(channel)
1250                    .and(channel_member::Column::UserId.eq(user)),
1251            )
1252            .all(tx)
1253            .await?;
1254
1255        let mut channel_info = self.get_user_channels(user, membership, &*tx).await?;
1256
1257        channel_info.channels.edges.push(ChannelEdge {
1258            channel_id: channel.to_proto(),
1259            parent_id: new_parent.to_proto(),
1260        });
1261
1262        Ok(channel_info.channels)
1263    }
1264
1265    /// Unlink a channel from a given parent. This will add in a root edge if
1266    /// the channel has no other parents after this operation.
1267    pub async fn unlink_channel(
1268        &self,
1269        user: UserId,
1270        channel: ChannelId,
1271        from: ChannelId,
1272    ) -> Result<()> {
1273        self.transaction(|tx| async move {
1274            // Note that even with these maxed permissions, this linking operation
1275            // is still insecure because you can't remove someone's permissions to a
1276            // channel if they've linked the channel to one where they're an admin.
1277            self.check_user_is_channel_admin(channel, user, &*tx)
1278                .await?;
1279
1280            self.unlink_channel_internal(user, channel, from, &*tx)
1281                .await?;
1282
1283            Ok(())
1284        })
1285        .await
1286    }
1287
1288    pub async fn unlink_channel_internal(
1289        &self,
1290        user: UserId,
1291        channel: ChannelId,
1292        from: ChannelId,
1293        tx: &DatabaseTransaction,
1294    ) -> Result<()> {
1295        self.check_user_is_channel_admin(from, user, &*tx).await?;
1296
1297        let sql = r#"
1298            DELETE FROM channel_paths
1299            WHERE
1300                id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1301            RETURNING id_path, channel_id
1302        "#;
1303
1304        let paths = channel_path::Entity::find()
1305            .from_raw_sql(Statement::from_sql_and_values(
1306                self.pool.get_database_backend(),
1307                sql,
1308                [from.to_proto().into(), channel.to_proto().into()],
1309            ))
1310            .all(&*tx)
1311            .await?;
1312
1313        let is_stranded = channel_path::Entity::find()
1314            .filter(channel_path::Column::ChannelId.eq(channel))
1315            .count(&*tx)
1316            .await?
1317            == 0;
1318
1319        // Make sure that there is always at least one path to the channel
1320        if is_stranded {
1321            let root_paths: Vec<_> = paths
1322                .iter()
1323                .map(|path| {
1324                    let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1325                    channel_path::ActiveModel {
1326                        channel_id: ActiveValue::Set(path.channel_id),
1327                        id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1328                    }
1329                })
1330                .collect();
1331            channel_path::Entity::insert_many(root_paths)
1332                .exec(&*tx)
1333                .await?;
1334        }
1335
1336        Ok(())
1337    }
1338
1339    /// Move a channel from one parent to another, returns the
1340    /// Channels that were moved for notifying clients
1341    pub async fn move_channel(
1342        &self,
1343        user: UserId,
1344        channel: ChannelId,
1345        from: ChannelId,
1346        to: ChannelId,
1347    ) -> Result<ChannelGraph> {
1348        if from == to {
1349            return Ok(ChannelGraph {
1350                channels: vec![],
1351                edges: vec![],
1352            });
1353        }
1354
1355        self.transaction(|tx| async move {
1356            self.check_user_is_channel_admin(channel, user, &*tx)
1357                .await?;
1358
1359            let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1360
1361            self.unlink_channel_internal(user, channel, from, &*tx)
1362                .await?;
1363
1364            Ok(moved_channels)
1365        })
1366        .await
1367    }
1368}
1369
/// Column selector for `into_values` queries that read back only the
/// `user_id` column of each row.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1374
/// A set of channels together with the parent/child edges between them.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels contained in the graph.
    pub channels: Vec<Channel>,
    /// Links between channels; each edge names a `channel_id` and its `parent_id`.
    pub edges: Vec<ChannelEdge>,
}
1380
1381impl ChannelGraph {
1382    pub fn is_empty(&self) -> bool {
1383        self.channels.is_empty() && self.edges.is_empty()
1384    }
1385}
1386
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Test-only equality that ignores the order (and multiplicity) of
    /// channels and edges: both sides are compared as sets.
    fn eq(&self, other: &Self) -> bool {
        let same_channels = {
            let mine: HashSet<_> = self.channels.iter().collect();
            let theirs: HashSet<_> = other.channels.iter().collect();
            mine == theirs
        };

        // Edges are compared by their (child, parent) id pair.
        let edge_key = |edge: &ChannelEdge| (edge.channel_id, edge.parent_id);
        let same_edges = {
            let mine: HashSet<_> = self.edges.iter().map(edge_key).collect();
            let theirs: HashSet<_> = other.edges.iter().map(edge_key).collect();
            mine == theirs
        };

        same_channels && same_edges
    }
}
1407
1408#[cfg(not(test))]
1409impl PartialEq for ChannelGraph {
1410    fn eq(&self, other: &Self) -> bool {
1411        self.channels == other.channels && self.edges == other.edges
1412    }
1413}