channels.rs

  1use super::*;
  2
  3impl Database {
  4    pub async fn create_root_channel(
  5        &self,
  6        name: &str,
  7        live_kit_room: &str,
  8        creator_id: UserId,
  9    ) -> Result<ChannelId> {
 10        self.create_channel(name, None, live_kit_room, creator_id)
 11            .await
 12    }
 13
 14    pub async fn create_channel(
 15        &self,
 16        name: &str,
 17        parent: Option<ChannelId>,
 18        live_kit_room: &str,
 19        creator_id: UserId,
 20    ) -> Result<ChannelId> {
 21        let name = Self::sanitize_channel_name(name)?;
 22        self.transaction(move |tx| async move {
 23            if let Some(parent) = parent {
 24                self.check_user_is_channel_admin(parent, creator_id, &*tx)
 25                    .await?;
 26            }
 27
 28            let channel = channel::ActiveModel {
 29                name: ActiveValue::Set(name.to_string()),
 30                ..Default::default()
 31            }
 32            .insert(&*tx)
 33            .await?;
 34
 35            let channel_paths_stmt;
 36            if let Some(parent) = parent {
 37                let sql = r#"
 38                    INSERT INTO channel_paths
 39                    (id_path, channel_id)
 40                    SELECT
 41                        id_path || $1 || '/', $2
 42                    FROM
 43                        channel_paths
 44                    WHERE
 45                        channel_id = $3
 46                "#;
 47                channel_paths_stmt = Statement::from_sql_and_values(
 48                    self.pool.get_database_backend(),
 49                    sql,
 50                    [
 51                        channel.id.to_proto().into(),
 52                        channel.id.to_proto().into(),
 53                        parent.to_proto().into(),
 54                    ],
 55                );
 56                tx.execute(channel_paths_stmt).await?;
 57            } else {
 58                channel_path::Entity::insert(channel_path::ActiveModel {
 59                    channel_id: ActiveValue::Set(channel.id),
 60                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
 61                })
 62                .exec(&*tx)
 63                .await?;
 64            }
 65
 66            channel_member::ActiveModel {
 67                channel_id: ActiveValue::Set(channel.id),
 68                user_id: ActiveValue::Set(creator_id),
 69                accepted: ActiveValue::Set(true),
 70                admin: ActiveValue::Set(true),
 71                ..Default::default()
 72            }
 73            .insert(&*tx)
 74            .await?;
 75
 76            room::ActiveModel {
 77                channel_id: ActiveValue::Set(Some(channel.id)),
 78                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
 79                ..Default::default()
 80            }
 81            .insert(&*tx)
 82            .await?;
 83
 84            Ok(channel.id)
 85        })
 86        .await
 87    }
 88
 89    pub async fn remove_channel(
 90        &self,
 91        channel_id: ChannelId,
 92        user_id: UserId,
 93    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
 94        self.transaction(move |tx| async move {
 95            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
 96                .await?;
 97
 98            // Don't remove descendant channels that have additional parents.
 99            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
100            {
101                let mut channels_to_keep = channel_path::Entity::find()
102                    .filter(
103                        channel_path::Column::ChannelId
104                            .is_in(
105                                channels_to_remove
106                                    .keys()
107                                    .copied()
108                                    .filter(|&id| id != channel_id),
109                            )
110                            .and(
111                                channel_path::Column::IdPath
112                                    .not_like(&format!("%/{}/%", channel_id)),
113                            ),
114                    )
115                    .stream(&*tx)
116                    .await?;
117                while let Some(row) = channels_to_keep.next().await {
118                    let row = row?;
119                    channels_to_remove.remove(&row.channel_id);
120                }
121            }
122
123            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
124            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
125                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
126                .select_only()
127                .column(channel_member::Column::UserId)
128                .distinct()
129                .into_values::<_, QueryUserIds>()
130                .all(&*tx)
131                .await?;
132
133            channel::Entity::delete_many()
134                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
135                .exec(&*tx)
136                .await?;
137
138            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
139        })
140        .await
141    }
142
143    pub async fn invite_channel_member(
144        &self,
145        channel_id: ChannelId,
146        invitee_id: UserId,
147        inviter_id: UserId,
148        is_admin: bool,
149    ) -> Result<()> {
150        self.transaction(move |tx| async move {
151            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
152                .await?;
153
154            channel_member::ActiveModel {
155                channel_id: ActiveValue::Set(channel_id),
156                user_id: ActiveValue::Set(invitee_id),
157                accepted: ActiveValue::Set(false),
158                admin: ActiveValue::Set(is_admin),
159                ..Default::default()
160            }
161            .insert(&*tx)
162            .await?;
163
164            Ok(())
165        })
166        .await
167    }
168
169    fn sanitize_channel_name(name: &str) -> Result<&str> {
170        let new_name = name.trim().trim_start_matches('#');
171        if new_name == "" {
172            Err(anyhow!("channel name can't be blank"))?;
173        }
174        Ok(new_name)
175    }
176
177    pub async fn rename_channel(
178        &self,
179        channel_id: ChannelId,
180        user_id: UserId,
181        new_name: &str,
182    ) -> Result<String> {
183        self.transaction(move |tx| async move {
184            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
185
186            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
187                .await?;
188
189            channel::ActiveModel {
190                id: ActiveValue::Unchanged(channel_id),
191                name: ActiveValue::Set(new_name.clone()),
192                ..Default::default()
193            }
194            .update(&*tx)
195            .await?;
196
197            Ok(new_name)
198        })
199        .await
200    }
201
202    pub async fn respond_to_channel_invite(
203        &self,
204        channel_id: ChannelId,
205        user_id: UserId,
206        accept: bool,
207    ) -> Result<()> {
208        self.transaction(move |tx| async move {
209            let rows_affected = if accept {
210                channel_member::Entity::update_many()
211                    .set(channel_member::ActiveModel {
212                        accepted: ActiveValue::Set(accept),
213                        ..Default::default()
214                    })
215                    .filter(
216                        channel_member::Column::ChannelId
217                            .eq(channel_id)
218                            .and(channel_member::Column::UserId.eq(user_id))
219                            .and(channel_member::Column::Accepted.eq(false)),
220                    )
221                    .exec(&*tx)
222                    .await?
223                    .rows_affected
224            } else {
225                channel_member::ActiveModel {
226                    channel_id: ActiveValue::Unchanged(channel_id),
227                    user_id: ActiveValue::Unchanged(user_id),
228                    ..Default::default()
229                }
230                .delete(&*tx)
231                .await?
232                .rows_affected
233            };
234
235            if rows_affected == 0 {
236                Err(anyhow!("no such invitation"))?;
237            }
238
239            Ok(())
240        })
241        .await
242    }
243
244    pub async fn remove_channel_member(
245        &self,
246        channel_id: ChannelId,
247        member_id: UserId,
248        remover_id: UserId,
249    ) -> Result<()> {
250        self.transaction(|tx| async move {
251            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
252                .await?;
253
254            let result = channel_member::Entity::delete_many()
255                .filter(
256                    channel_member::Column::ChannelId
257                        .eq(channel_id)
258                        .and(channel_member::Column::UserId.eq(member_id)),
259                )
260                .exec(&*tx)
261                .await?;
262
263            if result.rows_affected == 0 {
264                Err(anyhow!("no such member"))?;
265            }
266
267            Ok(())
268        })
269        .await
270    }
271
272    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
273        self.transaction(|tx| async move {
274            let channel_invites = channel_member::Entity::find()
275                .filter(
276                    channel_member::Column::UserId
277                        .eq(user_id)
278                        .and(channel_member::Column::Accepted.eq(false)),
279                )
280                .all(&*tx)
281                .await?;
282
283            let channels = channel::Entity::find()
284                .filter(
285                    channel::Column::Id.is_in(
286                        channel_invites
287                            .into_iter()
288                            .map(|channel_member| channel_member.channel_id),
289                    ),
290                )
291                .all(&*tx)
292                .await?;
293
294            let channels = channels
295                .into_iter()
296                .map(|channel| Channel {
297                    id: channel.id,
298                    name: channel.name,
299                    parent_id: None,
300                })
301                .collect();
302
303            Ok(channels)
304        })
305        .await
306    }
307
308    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
309        self.transaction(|tx| async move {
310            let tx = tx;
311
312            let channel_memberships = channel_member::Entity::find()
313                .filter(
314                    channel_member::Column::UserId
315                        .eq(user_id)
316                        .and(channel_member::Column::Accepted.eq(true)),
317                )
318                .all(&*tx)
319                .await?;
320
321            let parents_by_child_id = self
322                .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
323                .await?;
324
325            let channels_with_admin_privileges = channel_memberships
326                .iter()
327                .filter_map(|membership| membership.admin.then_some(membership.channel_id))
328                .collect();
329
330            let mut channels = Vec::with_capacity(parents_by_child_id.len());
331            {
332                let mut rows = channel::Entity::find()
333                    .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
334                    .stream(&*tx)
335                    .await?;
336                while let Some(row) = rows.next().await {
337                    let row = row?;
338                    channels.push(Channel {
339                        id: row.id,
340                        name: row.name,
341                        parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
342                    });
343                }
344            }
345
346            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
347            enum QueryUserIdsAndChannelIds {
348                ChannelId,
349                UserId,
350            }
351
352            let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
353            {
354                let mut rows = room_participant::Entity::find()
355                    .inner_join(room::Entity)
356                    .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
357                    .select_only()
358                    .column(room::Column::ChannelId)
359                    .column(room_participant::Column::UserId)
360                    .into_values::<_, QueryUserIdsAndChannelIds>()
361                    .stream(&*tx)
362                    .await?;
363                while let Some(row) = rows.next().await {
364                    let row: (ChannelId, UserId) = row?;
365                    channel_participants.entry(row.0).or_default().push(row.1)
366                }
367            }
368
369            Ok(ChannelsForUser {
370                channels,
371                channel_participants,
372                channels_with_admin_privileges,
373            })
374        })
375        .await
376    }
377
378    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
379        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
380            .await
381    }
382
383    pub async fn set_channel_member_admin(
384        &self,
385        channel_id: ChannelId,
386        from: UserId,
387        for_user: UserId,
388        admin: bool,
389    ) -> Result<()> {
390        self.transaction(|tx| async move {
391            self.check_user_is_channel_admin(channel_id, from, &*tx)
392                .await?;
393
394            let result = channel_member::Entity::update_many()
395                .filter(
396                    channel_member::Column::ChannelId
397                        .eq(channel_id)
398                        .and(channel_member::Column::UserId.eq(for_user)),
399                )
400                .set(channel_member::ActiveModel {
401                    admin: ActiveValue::set(admin),
402                    ..Default::default()
403                })
404                .exec(&*tx)
405                .await?;
406
407            if result.rows_affected == 0 {
408                Err(anyhow!("no such member"))?;
409            }
410
411            Ok(())
412        })
413        .await
414    }
415
416    pub async fn get_channel_member_details(
417        &self,
418        channel_id: ChannelId,
419        user_id: UserId,
420    ) -> Result<Vec<proto::ChannelMember>> {
421        self.transaction(|tx| async move {
422            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
423                .await?;
424
425            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
426            enum QueryMemberDetails {
427                UserId,
428                Admin,
429                IsDirectMember,
430                Accepted,
431            }
432
433            let tx = tx;
434            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
435            let mut stream = channel_member::Entity::find()
436                .distinct()
437                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
438                .select_only()
439                .column(channel_member::Column::UserId)
440                .column(channel_member::Column::Admin)
441                .column_as(
442                    channel_member::Column::ChannelId.eq(channel_id),
443                    QueryMemberDetails::IsDirectMember,
444                )
445                .column(channel_member::Column::Accepted)
446                .order_by_asc(channel_member::Column::UserId)
447                .into_values::<_, QueryMemberDetails>()
448                .stream(&*tx)
449                .await?;
450
451            let mut rows = Vec::<proto::ChannelMember>::new();
452            while let Some(row) = stream.next().await {
453                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
454                    UserId,
455                    bool,
456                    bool,
457                    bool,
458                ) = row?;
459                let kind = match (is_direct_member, is_invite_accepted) {
460                    (true, true) => proto::channel_member::Kind::Member,
461                    (true, false) => proto::channel_member::Kind::Invitee,
462                    (false, true) => proto::channel_member::Kind::AncestorMember,
463                    (false, false) => continue,
464                };
465                let user_id = user_id.to_proto();
466                let kind = kind.into();
467                if let Some(last_row) = rows.last_mut() {
468                    if last_row.user_id == user_id {
469                        if is_direct_member {
470                            last_row.kind = kind;
471                            last_row.admin = is_admin;
472                        }
473                        continue;
474                    }
475                }
476                rows.push(proto::ChannelMember {
477                    user_id,
478                    kind,
479                    admin: is_admin,
480                });
481            }
482
483            Ok(rows)
484        })
485        .await
486    }
487
488    pub async fn get_channel_members_internal(
489        &self,
490        id: ChannelId,
491        tx: &DatabaseTransaction,
492    ) -> Result<Vec<UserId>> {
493        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
494        let user_ids = channel_member::Entity::find()
495            .distinct()
496            .filter(
497                channel_member::Column::ChannelId
498                    .is_in(ancestor_ids.iter().copied())
499                    .and(channel_member::Column::Accepted.eq(true)),
500            )
501            .select_only()
502            .column(channel_member::Column::UserId)
503            .into_values::<_, QueryUserIds>()
504            .all(&*tx)
505            .await?;
506        Ok(user_ids)
507    }
508
509    pub async fn check_user_is_channel_member(
510        &self,
511        channel_id: ChannelId,
512        user_id: UserId,
513        tx: &DatabaseTransaction,
514    ) -> Result<()> {
515        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
516        channel_member::Entity::find()
517            .filter(
518                channel_member::Column::ChannelId
519                    .is_in(channel_ids)
520                    .and(channel_member::Column::UserId.eq(user_id)),
521            )
522            .one(&*tx)
523            .await?
524            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
525        Ok(())
526    }
527
528    pub async fn check_user_is_channel_admin(
529        &self,
530        channel_id: ChannelId,
531        user_id: UserId,
532        tx: &DatabaseTransaction,
533    ) -> Result<()> {
534        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
535        channel_member::Entity::find()
536            .filter(
537                channel_member::Column::ChannelId
538                    .is_in(channel_ids)
539                    .and(channel_member::Column::UserId.eq(user_id))
540                    .and(channel_member::Column::Admin.eq(true)),
541            )
542            .one(&*tx)
543            .await?
544            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
545        Ok(())
546    }
547
548    pub async fn get_channel_ancestors(
549        &self,
550        channel_id: ChannelId,
551        tx: &DatabaseTransaction,
552    ) -> Result<Vec<ChannelId>> {
553        let paths = channel_path::Entity::find()
554            .filter(channel_path::Column::ChannelId.eq(channel_id))
555            .all(tx)
556            .await?;
557        let mut channel_ids = Vec::new();
558        for path in paths {
559            for id in path.id_path.trim_matches('/').split('/') {
560                if let Ok(id) = id.parse() {
561                    let id = ChannelId::from_proto(id);
562                    if let Err(ix) = channel_ids.binary_search(&id) {
563                        channel_ids.insert(ix, id);
564                    }
565                }
566            }
567        }
568        Ok(channel_ids)
569    }
570
571    async fn get_channel_descendants(
572        &self,
573        channel_ids: impl IntoIterator<Item = ChannelId>,
574        tx: &DatabaseTransaction,
575    ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
576        let mut values = String::new();
577        for id in channel_ids {
578            if !values.is_empty() {
579                values.push_str(", ");
580            }
581            write!(&mut values, "({})", id).unwrap();
582        }
583
584        if values.is_empty() {
585            return Ok(HashMap::default());
586        }
587
588        let sql = format!(
589            r#"
590            SELECT
591                descendant_paths.*
592            FROM
593                channel_paths parent_paths, channel_paths descendant_paths
594            WHERE
595                parent_paths.channel_id IN ({values}) AND
596                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
597        "#
598        );
599
600        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
601
602        let mut parents_by_child_id = HashMap::default();
603        let mut paths = channel_path::Entity::find()
604            .from_raw_sql(stmt)
605            .stream(tx)
606            .await?;
607
608        while let Some(path) = paths.next().await {
609            let path = path?;
610            let ids = path.id_path.trim_matches('/').split('/');
611            let mut parent_id = None;
612            for id in ids {
613                if let Ok(id) = id.parse() {
614                    let id = ChannelId::from_proto(id);
615                    if id == path.channel_id {
616                        break;
617                    }
618                    parent_id = Some(id);
619                }
620            }
621            parents_by_child_id.insert(path.channel_id, parent_id);
622        }
623
624        Ok(parents_by_child_id)
625    }
626
627    /// Returns the channel with the given ID and:
628    /// - true if the user is a member
629    /// - false if the user hasn't accepted the invitation yet
630    pub async fn get_channel(
631        &self,
632        channel_id: ChannelId,
633        user_id: UserId,
634    ) -> Result<Option<(Channel, bool)>> {
635        self.transaction(|tx| async move {
636            let tx = tx;
637
638            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
639
640            if let Some(channel) = channel {
641                if self
642                    .check_user_is_channel_member(channel_id, user_id, &*tx)
643                    .await
644                    .is_err()
645                {
646                    return Ok(None);
647                }
648
649                let channel_membership = channel_member::Entity::find()
650                    .filter(
651                        channel_member::Column::ChannelId
652                            .eq(channel_id)
653                            .and(channel_member::Column::UserId.eq(user_id)),
654                    )
655                    .one(&*tx)
656                    .await?;
657
658                let is_accepted = channel_membership
659                    .map(|membership| membership.accepted)
660                    .unwrap_or(false);
661
662                Ok(Some((
663                    Channel {
664                        id: channel.id,
665                        name: channel.name,
666                        parent_id: None,
667                    },
668                    is_accepted,
669                )))
670            } else {
671                Ok(None)
672            }
673        })
674        .await
675    }
676
677    pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
678        self.transaction(|tx| async move {
679            let tx = tx;
680            let room = channel::Model {
681                id: channel_id,
682                ..Default::default()
683            }
684            .find_related(room::Entity)
685            .one(&*tx)
686            .await?
687            .ok_or_else(|| anyhow!("invalid channel"))?;
688            Ok(room.id)
689        })
690        .await
691    }
692}
693
694#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
695enum QueryUserIds {
696    UserId,
697}