channels.rs

  1use super::*;
  2
  3impl Database {
  4    #[cfg(test)]
  5    pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
  6        self.transaction(move |tx| async move {
  7            let mut channels = Vec::new();
  8            let mut rows = channel::Entity::find().stream(&*tx).await?;
  9            while let Some(row) = rows.next().await {
 10                let row = row?;
 11                channels.push((row.id, row.name));
 12            }
 13            Ok(channels)
 14        })
 15        .await
 16    }
 17
 18    pub async fn create_root_channel(
 19        &self,
 20        name: &str,
 21        live_kit_room: &str,
 22        creator_id: UserId,
 23    ) -> Result<ChannelId> {
 24        self.create_channel(name, None, live_kit_room, creator_id)
 25            .await
 26    }
 27
 28    pub async fn create_channel(
 29        &self,
 30        name: &str,
 31        parent: Option<ChannelId>,
 32        live_kit_room: &str,
 33        creator_id: UserId,
 34    ) -> Result<ChannelId> {
 35        let name = Self::sanitize_channel_name(name)?;
 36        self.transaction(move |tx| async move {
 37            if let Some(parent) = parent {
 38                self.check_user_is_channel_admin(parent, creator_id, &*tx)
 39                    .await?;
 40            }
 41
 42            let channel = channel::ActiveModel {
 43                name: ActiveValue::Set(name.to_string()),
 44                ..Default::default()
 45            }
 46            .insert(&*tx)
 47            .await?;
 48
 49            let channel_paths_stmt;
 50            if let Some(parent) = parent {
 51                let sql = r#"
 52                    INSERT INTO channel_paths
 53                    (id_path, channel_id)
 54                    SELECT
 55                        id_path || $1 || '/', $2
 56                    FROM
 57                        channel_paths
 58                    WHERE
 59                        channel_id = $3
 60                "#;
 61                channel_paths_stmt = Statement::from_sql_and_values(
 62                    self.pool.get_database_backend(),
 63                    sql,
 64                    [
 65                        channel.id.to_proto().into(),
 66                        channel.id.to_proto().into(),
 67                        parent.to_proto().into(),
 68                    ],
 69                );
 70                tx.execute(channel_paths_stmt).await?;
 71            } else {
 72                channel_path::Entity::insert(channel_path::ActiveModel {
 73                    channel_id: ActiveValue::Set(channel.id),
 74                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
 75                })
 76                .exec(&*tx)
 77                .await?;
 78            }
 79
 80            channel_member::ActiveModel {
 81                channel_id: ActiveValue::Set(channel.id),
 82                user_id: ActiveValue::Set(creator_id),
 83                accepted: ActiveValue::Set(true),
 84                admin: ActiveValue::Set(true),
 85                ..Default::default()
 86            }
 87            .insert(&*tx)
 88            .await?;
 89
 90            room::ActiveModel {
 91                channel_id: ActiveValue::Set(Some(channel.id)),
 92                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
 93                ..Default::default()
 94            }
 95            .insert(&*tx)
 96            .await?;
 97
 98            Ok(channel.id)
 99        })
100        .await
101    }
102
103    pub async fn remove_channel(
104        &self,
105        channel_id: ChannelId,
106        user_id: UserId,
107    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
108        self.transaction(move |tx| async move {
109            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
110                .await?;
111
112            // Don't remove descendant channels that have additional parents.
113            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
114            {
115                let mut channels_to_keep = channel_path::Entity::find()
116                    .filter(
117                        channel_path::Column::ChannelId
118                            .is_in(
119                                channels_to_remove
120                                    .keys()
121                                    .copied()
122                                    .filter(|&id| id != channel_id),
123                            )
124                            .and(
125                                channel_path::Column::IdPath
126                                    .not_like(&format!("%/{}/%", channel_id)),
127                            ),
128                    )
129                    .stream(&*tx)
130                    .await?;
131                while let Some(row) = channels_to_keep.next().await {
132                    let row = row?;
133                    channels_to_remove.remove(&row.channel_id);
134                }
135            }
136
137            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
138            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
139                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
140                .select_only()
141                .column(channel_member::Column::UserId)
142                .distinct()
143                .into_values::<_, QueryUserIds>()
144                .all(&*tx)
145                .await?;
146
147            channel::Entity::delete_many()
148                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
149                .exec(&*tx)
150                .await?;
151
152            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
153        })
154        .await
155    }
156
157    pub async fn invite_channel_member(
158        &self,
159        channel_id: ChannelId,
160        invitee_id: UserId,
161        inviter_id: UserId,
162        is_admin: bool,
163    ) -> Result<()> {
164        self.transaction(move |tx| async move {
165            self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
166                .await?;
167
168            channel_member::ActiveModel {
169                channel_id: ActiveValue::Set(channel_id),
170                user_id: ActiveValue::Set(invitee_id),
171                accepted: ActiveValue::Set(false),
172                admin: ActiveValue::Set(is_admin),
173                ..Default::default()
174            }
175            .insert(&*tx)
176            .await?;
177
178            Ok(())
179        })
180        .await
181    }
182
183    fn sanitize_channel_name(name: &str) -> Result<&str> {
184        let new_name = name.trim().trim_start_matches('#');
185        if new_name == "" {
186            Err(anyhow!("channel name can't be blank"))?;
187        }
188        Ok(new_name)
189    }
190
191    pub async fn rename_channel(
192        &self,
193        channel_id: ChannelId,
194        user_id: UserId,
195        new_name: &str,
196    ) -> Result<String> {
197        self.transaction(move |tx| async move {
198            let new_name = Self::sanitize_channel_name(new_name)?.to_string();
199
200            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
201                .await?;
202
203            channel::ActiveModel {
204                id: ActiveValue::Unchanged(channel_id),
205                name: ActiveValue::Set(new_name.clone()),
206                ..Default::default()
207            }
208            .update(&*tx)
209            .await?;
210
211            Ok(new_name)
212        })
213        .await
214    }
215
216    pub async fn respond_to_channel_invite(
217        &self,
218        channel_id: ChannelId,
219        user_id: UserId,
220        accept: bool,
221    ) -> Result<()> {
222        self.transaction(move |tx| async move {
223            let rows_affected = if accept {
224                channel_member::Entity::update_many()
225                    .set(channel_member::ActiveModel {
226                        accepted: ActiveValue::Set(accept),
227                        ..Default::default()
228                    })
229                    .filter(
230                        channel_member::Column::ChannelId
231                            .eq(channel_id)
232                            .and(channel_member::Column::UserId.eq(user_id))
233                            .and(channel_member::Column::Accepted.eq(false)),
234                    )
235                    .exec(&*tx)
236                    .await?
237                    .rows_affected
238            } else {
239                channel_member::ActiveModel {
240                    channel_id: ActiveValue::Unchanged(channel_id),
241                    user_id: ActiveValue::Unchanged(user_id),
242                    ..Default::default()
243                }
244                .delete(&*tx)
245                .await?
246                .rows_affected
247            };
248
249            if rows_affected == 0 {
250                Err(anyhow!("no such invitation"))?;
251            }
252
253            Ok(())
254        })
255        .await
256    }
257
258    pub async fn remove_channel_member(
259        &self,
260        channel_id: ChannelId,
261        member_id: UserId,
262        remover_id: UserId,
263    ) -> Result<()> {
264        self.transaction(|tx| async move {
265            self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
266                .await?;
267
268            let result = channel_member::Entity::delete_many()
269                .filter(
270                    channel_member::Column::ChannelId
271                        .eq(channel_id)
272                        .and(channel_member::Column::UserId.eq(member_id)),
273                )
274                .exec(&*tx)
275                .await?;
276
277            if result.rows_affected == 0 {
278                Err(anyhow!("no such member"))?;
279            }
280
281            Ok(())
282        })
283        .await
284    }
285
286    pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
287        self.transaction(|tx| async move {
288            let channel_invites = channel_member::Entity::find()
289                .filter(
290                    channel_member::Column::UserId
291                        .eq(user_id)
292                        .and(channel_member::Column::Accepted.eq(false)),
293                )
294                .all(&*tx)
295                .await?;
296
297            let channels = channel::Entity::find()
298                .filter(
299                    channel::Column::Id.is_in(
300                        channel_invites
301                            .into_iter()
302                            .map(|channel_member| channel_member.channel_id),
303                    ),
304                )
305                .all(&*tx)
306                .await?;
307
308            let channels = channels
309                .into_iter()
310                .map(|channel| Channel {
311                    id: channel.id,
312                    name: channel.name,
313                    parent_id: None,
314                })
315                .collect();
316
317            Ok(channels)
318        })
319        .await
320    }
321
322    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
323        self.transaction(|tx| async move {
324            let tx = tx;
325
326            let channel_memberships = channel_member::Entity::find()
327                .filter(
328                    channel_member::Column::UserId
329                        .eq(user_id)
330                        .and(channel_member::Column::Accepted.eq(true)),
331                )
332                .all(&*tx)
333                .await?;
334
335            let parents_by_child_id = self
336                .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
337                .await?;
338
339            let channels_with_admin_privileges = channel_memberships
340                .iter()
341                .filter_map(|membership| membership.admin.then_some(membership.channel_id))
342                .collect();
343
344            let mut channels = Vec::with_capacity(parents_by_child_id.len());
345            {
346                let mut rows = channel::Entity::find()
347                    .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
348                    .stream(&*tx)
349                    .await?;
350                while let Some(row) = rows.next().await {
351                    let row = row?;
352                    channels.push(Channel {
353                        id: row.id,
354                        name: row.name,
355                        parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
356                    });
357                }
358            }
359
360            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
361            enum QueryUserIdsAndChannelIds {
362                ChannelId,
363                UserId,
364            }
365
366            let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
367            {
368                let mut rows = room_participant::Entity::find()
369                    .inner_join(room::Entity)
370                    .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
371                    .select_only()
372                    .column(room::Column::ChannelId)
373                    .column(room_participant::Column::UserId)
374                    .into_values::<_, QueryUserIdsAndChannelIds>()
375                    .stream(&*tx)
376                    .await?;
377                while let Some(row) = rows.next().await {
378                    let row: (ChannelId, UserId) = row?;
379                    channel_participants.entry(row.0).or_default().push(row.1)
380                }
381            }
382
383            Ok(ChannelsForUser {
384                channels,
385                channel_participants,
386                channels_with_admin_privileges,
387            })
388        })
389        .await
390    }
391
392    pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
393        self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
394            .await
395    }
396
397    pub async fn set_channel_member_admin(
398        &self,
399        channel_id: ChannelId,
400        from: UserId,
401        for_user: UserId,
402        admin: bool,
403    ) -> Result<()> {
404        self.transaction(|tx| async move {
405            self.check_user_is_channel_admin(channel_id, from, &*tx)
406                .await?;
407
408            let result = channel_member::Entity::update_many()
409                .filter(
410                    channel_member::Column::ChannelId
411                        .eq(channel_id)
412                        .and(channel_member::Column::UserId.eq(for_user)),
413                )
414                .set(channel_member::ActiveModel {
415                    admin: ActiveValue::set(admin),
416                    ..Default::default()
417                })
418                .exec(&*tx)
419                .await?;
420
421            if result.rows_affected == 0 {
422                Err(anyhow!("no such member"))?;
423            }
424
425            Ok(())
426        })
427        .await
428    }
429
430    pub async fn get_channel_member_details(
431        &self,
432        channel_id: ChannelId,
433        user_id: UserId,
434    ) -> Result<Vec<proto::ChannelMember>> {
435        self.transaction(|tx| async move {
436            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
437                .await?;
438
439            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
440            enum QueryMemberDetails {
441                UserId,
442                Admin,
443                IsDirectMember,
444                Accepted,
445            }
446
447            let tx = tx;
448            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
449            let mut stream = channel_member::Entity::find()
450                .distinct()
451                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
452                .select_only()
453                .column(channel_member::Column::UserId)
454                .column(channel_member::Column::Admin)
455                .column_as(
456                    channel_member::Column::ChannelId.eq(channel_id),
457                    QueryMemberDetails::IsDirectMember,
458                )
459                .column(channel_member::Column::Accepted)
460                .order_by_asc(channel_member::Column::UserId)
461                .into_values::<_, QueryMemberDetails>()
462                .stream(&*tx)
463                .await?;
464
465            let mut rows = Vec::<proto::ChannelMember>::new();
466            while let Some(row) = stream.next().await {
467                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
468                    UserId,
469                    bool,
470                    bool,
471                    bool,
472                ) = row?;
473                let kind = match (is_direct_member, is_invite_accepted) {
474                    (true, true) => proto::channel_member::Kind::Member,
475                    (true, false) => proto::channel_member::Kind::Invitee,
476                    (false, true) => proto::channel_member::Kind::AncestorMember,
477                    (false, false) => continue,
478                };
479                let user_id = user_id.to_proto();
480                let kind = kind.into();
481                if let Some(last_row) = rows.last_mut() {
482                    if last_row.user_id == user_id {
483                        if is_direct_member {
484                            last_row.kind = kind;
485                            last_row.admin = is_admin;
486                        }
487                        continue;
488                    }
489                }
490                rows.push(proto::ChannelMember {
491                    user_id,
492                    kind,
493                    admin: is_admin,
494                });
495            }
496
497            Ok(rows)
498        })
499        .await
500    }
501
502    pub async fn get_channel_members_internal(
503        &self,
504        id: ChannelId,
505        tx: &DatabaseTransaction,
506    ) -> Result<Vec<UserId>> {
507        let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
508        let user_ids = channel_member::Entity::find()
509            .distinct()
510            .filter(
511                channel_member::Column::ChannelId
512                    .is_in(ancestor_ids.iter().copied())
513                    .and(channel_member::Column::Accepted.eq(true)),
514            )
515            .select_only()
516            .column(channel_member::Column::UserId)
517            .into_values::<_, QueryUserIds>()
518            .all(&*tx)
519            .await?;
520        Ok(user_ids)
521    }
522
523    pub async fn check_user_is_channel_member(
524        &self,
525        channel_id: ChannelId,
526        user_id: UserId,
527        tx: &DatabaseTransaction,
528    ) -> Result<()> {
529        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
530        channel_member::Entity::find()
531            .filter(
532                channel_member::Column::ChannelId
533                    .is_in(channel_ids)
534                    .and(channel_member::Column::UserId.eq(user_id)),
535            )
536            .one(&*tx)
537            .await?
538            .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
539        Ok(())
540    }
541
542    pub async fn check_user_is_channel_admin(
543        &self,
544        channel_id: ChannelId,
545        user_id: UserId,
546        tx: &DatabaseTransaction,
547    ) -> Result<()> {
548        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
549        channel_member::Entity::find()
550            .filter(
551                channel_member::Column::ChannelId
552                    .is_in(channel_ids)
553                    .and(channel_member::Column::UserId.eq(user_id))
554                    .and(channel_member::Column::Admin.eq(true)),
555            )
556            .one(&*tx)
557            .await?
558            .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
559        Ok(())
560    }
561
562    pub async fn get_channel_ancestors(
563        &self,
564        channel_id: ChannelId,
565        tx: &DatabaseTransaction,
566    ) -> Result<Vec<ChannelId>> {
567        let paths = channel_path::Entity::find()
568            .filter(channel_path::Column::ChannelId.eq(channel_id))
569            .all(tx)
570            .await?;
571        let mut channel_ids = Vec::new();
572        for path in paths {
573            for id in path.id_path.trim_matches('/').split('/') {
574                if let Ok(id) = id.parse() {
575                    let id = ChannelId::from_proto(id);
576                    if let Err(ix) = channel_ids.binary_search(&id) {
577                        channel_ids.insert(ix, id);
578                    }
579                }
580            }
581        }
582        Ok(channel_ids)
583    }
584
585    async fn get_channel_descendants(
586        &self,
587        channel_ids: impl IntoIterator<Item = ChannelId>,
588        tx: &DatabaseTransaction,
589    ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
590        let mut values = String::new();
591        for id in channel_ids {
592            if !values.is_empty() {
593                values.push_str(", ");
594            }
595            write!(&mut values, "({})", id).unwrap();
596        }
597
598        if values.is_empty() {
599            return Ok(HashMap::default());
600        }
601
602        let sql = format!(
603            r#"
604            SELECT
605                descendant_paths.*
606            FROM
607                channel_paths parent_paths, channel_paths descendant_paths
608            WHERE
609                parent_paths.channel_id IN ({values}) AND
610                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
611        "#
612        );
613
614        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
615
616        let mut parents_by_child_id = HashMap::default();
617        let mut paths = channel_path::Entity::find()
618            .from_raw_sql(stmt)
619            .stream(tx)
620            .await?;
621
622        while let Some(path) = paths.next().await {
623            let path = path?;
624            let ids = path.id_path.trim_matches('/').split('/');
625            let mut parent_id = None;
626            for id in ids {
627                if let Ok(id) = id.parse() {
628                    let id = ChannelId::from_proto(id);
629                    if id == path.channel_id {
630                        break;
631                    }
632                    parent_id = Some(id);
633                }
634            }
635            parents_by_child_id.insert(path.channel_id, parent_id);
636        }
637
638        Ok(parents_by_child_id)
639    }
640
641    /// Returns the channel with the given ID and:
642    /// - true if the user is a member
643    /// - false if the user hasn't accepted the invitation yet
644    pub async fn get_channel(
645        &self,
646        channel_id: ChannelId,
647        user_id: UserId,
648    ) -> Result<Option<(Channel, bool)>> {
649        self.transaction(|tx| async move {
650            let tx = tx;
651
652            let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
653
654            if let Some(channel) = channel {
655                if self
656                    .check_user_is_channel_member(channel_id, user_id, &*tx)
657                    .await
658                    .is_err()
659                {
660                    return Ok(None);
661                }
662
663                let channel_membership = channel_member::Entity::find()
664                    .filter(
665                        channel_member::Column::ChannelId
666                            .eq(channel_id)
667                            .and(channel_member::Column::UserId.eq(user_id)),
668                    )
669                    .one(&*tx)
670                    .await?;
671
672                let is_accepted = channel_membership
673                    .map(|membership| membership.accepted)
674                    .unwrap_or(false);
675
676                Ok(Some((
677                    Channel {
678                        id: channel.id,
679                        name: channel.name,
680                        parent_id: None,
681                    },
682                    is_accepted,
683                )))
684            } else {
685                Ok(None)
686            }
687        })
688        .await
689    }
690
691    pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
692        self.transaction(|tx| async move {
693            let tx = tx;
694            let room = channel::Model {
695                id: channel_id,
696                ..Default::default()
697            }
698            .find_related(room::Entity)
699            .one(&*tx)
700            .await?
701            .ok_or_else(|| anyhow!("invalid channel"))?;
702            Ok(room.id)
703        })
704        .await
705    }
706}
707
708#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
709enum QueryUserIds {
710    UserId,
711}