1use super::*;
2use rpc::proto::ChannelEdge;
3use smallvec::SmallVec;
4
/// Map from a channel id to the (sorted, deduplicated) set of its parent ids,
/// as produced by `Database::get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
6
7impl Database {
8 #[cfg(test)]
9 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
10 self.transaction(move |tx| async move {
11 let mut channels = Vec::new();
12 let mut rows = channel::Entity::find().stream(&*tx).await?;
13 while let Some(row) = rows.next().await {
14 let row = row?;
15 channels.push((row.id, row.name));
16 }
17 Ok(channels)
18 })
19 .await
20 }
21
22 pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
23 self.create_channel(name, None, creator_id).await
24 }
25
26 pub async fn create_channel(
27 &self,
28 name: &str,
29 parent: Option<ChannelId>,
30 creator_id: UserId,
31 ) -> Result<ChannelId> {
32 let name = Self::sanitize_channel_name(name)?;
33 self.transaction(move |tx| async move {
34 if let Some(parent) = parent {
35 self.check_user_is_channel_admin(parent, creator_id, &*tx)
36 .await?;
37 }
38
39 let channel = channel::ActiveModel {
40 id: ActiveValue::NotSet,
41 name: ActiveValue::Set(name.to_string()),
42 visibility: ActiveValue::Set(ChannelVisibility::ChannelMembers),
43 }
44 .insert(&*tx)
45 .await?;
46
47 if let Some(parent) = parent {
48 let sql = r#"
49 INSERT INTO channel_paths
50 (id_path, channel_id)
51 SELECT
52 id_path || $1 || '/', $2
53 FROM
54 channel_paths
55 WHERE
56 channel_id = $3
57 "#;
58 let channel_paths_stmt = Statement::from_sql_and_values(
59 self.pool.get_database_backend(),
60 sql,
61 [
62 channel.id.to_proto().into(),
63 channel.id.to_proto().into(),
64 parent.to_proto().into(),
65 ],
66 );
67 tx.execute(channel_paths_stmt).await?;
68 } else {
69 channel_path::Entity::insert(channel_path::ActiveModel {
70 channel_id: ActiveValue::Set(channel.id),
71 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
72 })
73 .exec(&*tx)
74 .await?;
75 }
76
77 channel_member::ActiveModel {
78 id: ActiveValue::NotSet,
79 channel_id: ActiveValue::Set(channel.id),
80 user_id: ActiveValue::Set(creator_id),
81 accepted: ActiveValue::Set(true),
82 admin: ActiveValue::Set(true),
83 role: ActiveValue::Set(Some(ChannelRole::Admin)),
84 }
85 .insert(&*tx)
86 .await?;
87
88 Ok(channel.id)
89 })
90 .await
91 }
92
93 pub async fn set_channel_visibility(
94 &self,
95 channel_id: ChannelId,
96 visibility: ChannelVisibility,
97 user_id: UserId,
98 ) -> Result<()> {
99 self.transaction(move |tx| async move {
100 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
101 .await?;
102
103 channel::ActiveModel {
104 id: ActiveValue::Unchanged(channel_id),
105 visibility: ActiveValue::Set(visibility),
106 ..Default::default()
107 }
108 .update(&*tx)
109 .await?;
110
111 Ok(())
112 })
113 .await
114 }
115
116 pub async fn delete_channel(
117 &self,
118 channel_id: ChannelId,
119 user_id: UserId,
120 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
121 self.transaction(move |tx| async move {
122 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
123 .await?;
124
125 // Don't remove descendant channels that have additional parents.
126 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
127 {
128 let mut channels_to_keep = channel_path::Entity::find()
129 .filter(
130 channel_path::Column::ChannelId
131 .is_in(
132 channels_to_remove
133 .keys()
134 .copied()
135 .filter(|&id| id != channel_id),
136 )
137 .and(
138 channel_path::Column::IdPath
139 .not_like(&format!("%/{}/%", channel_id)),
140 ),
141 )
142 .stream(&*tx)
143 .await?;
144 while let Some(row) = channels_to_keep.next().await {
145 let row = row?;
146 channels_to_remove.remove(&row.channel_id);
147 }
148 }
149
150 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
151 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
152 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
153 .select_only()
154 .column(channel_member::Column::UserId)
155 .distinct()
156 .into_values::<_, QueryUserIds>()
157 .all(&*tx)
158 .await?;
159
160 channel::Entity::delete_many()
161 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
162 .exec(&*tx)
163 .await?;
164
165 // Delete any other paths that include this channel
166 let sql = r#"
167 DELETE FROM channel_paths
168 WHERE
169 id_path LIKE '%' || $1 || '%'
170 "#;
171 let channel_paths_stmt = Statement::from_sql_and_values(
172 self.pool.get_database_backend(),
173 sql,
174 [channel_id.to_proto().into()],
175 );
176 tx.execute(channel_paths_stmt).await?;
177
178 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
179 })
180 .await
181 }
182
183 pub async fn invite_channel_member(
184 &self,
185 channel_id: ChannelId,
186 invitee_id: UserId,
187 admin_id: UserId,
188 role: ChannelRole,
189 ) -> Result<()> {
190 self.transaction(move |tx| async move {
191 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
192 .await?;
193
194 channel_member::ActiveModel {
195 id: ActiveValue::NotSet,
196 channel_id: ActiveValue::Set(channel_id),
197 user_id: ActiveValue::Set(invitee_id),
198 accepted: ActiveValue::Set(false),
199 admin: ActiveValue::Set(role == ChannelRole::Admin),
200 role: ActiveValue::Set(Some(role)),
201 }
202 .insert(&*tx)
203 .await?;
204
205 Ok(())
206 })
207 .await
208 }
209
210 fn sanitize_channel_name(name: &str) -> Result<&str> {
211 let new_name = name.trim().trim_start_matches('#');
212 if new_name == "" {
213 Err(anyhow!("channel name can't be blank"))?;
214 }
215 Ok(new_name)
216 }
217
218 pub async fn rename_channel(
219 &self,
220 channel_id: ChannelId,
221 user_id: UserId,
222 new_name: &str,
223 ) -> Result<String> {
224 self.transaction(move |tx| async move {
225 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
226
227 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
228 .await?;
229
230 channel::ActiveModel {
231 id: ActiveValue::Unchanged(channel_id),
232 name: ActiveValue::Set(new_name.clone()),
233 ..Default::default()
234 }
235 .update(&*tx)
236 .await?;
237
238 Ok(new_name)
239 })
240 .await
241 }
242
243 pub async fn respond_to_channel_invite(
244 &self,
245 channel_id: ChannelId,
246 user_id: UserId,
247 accept: bool,
248 ) -> Result<()> {
249 self.transaction(move |tx| async move {
250 let rows_affected = if accept {
251 channel_member::Entity::update_many()
252 .set(channel_member::ActiveModel {
253 accepted: ActiveValue::Set(accept),
254 ..Default::default()
255 })
256 .filter(
257 channel_member::Column::ChannelId
258 .eq(channel_id)
259 .and(channel_member::Column::UserId.eq(user_id))
260 .and(channel_member::Column::Accepted.eq(false)),
261 )
262 .exec(&*tx)
263 .await?
264 .rows_affected
265 } else {
266 channel_member::ActiveModel {
267 channel_id: ActiveValue::Unchanged(channel_id),
268 user_id: ActiveValue::Unchanged(user_id),
269 ..Default::default()
270 }
271 .delete(&*tx)
272 .await?
273 .rows_affected
274 };
275
276 if rows_affected == 0 {
277 Err(anyhow!("no such invitation"))?;
278 }
279
280 Ok(())
281 })
282 .await
283 }
284
285 pub async fn remove_channel_member(
286 &self,
287 channel_id: ChannelId,
288 member_id: UserId,
289 admin_id: UserId,
290 ) -> Result<()> {
291 self.transaction(|tx| async move {
292 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
293 .await?;
294
295 let result = channel_member::Entity::delete_many()
296 .filter(
297 channel_member::Column::ChannelId
298 .eq(channel_id)
299 .and(channel_member::Column::UserId.eq(member_id)),
300 )
301 .exec(&*tx)
302 .await?;
303
304 if result.rows_affected == 0 {
305 Err(anyhow!("no such member"))?;
306 }
307
308 Ok(())
309 })
310 .await
311 }
312
313 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
314 self.transaction(|tx| async move {
315 let channel_invites = channel_member::Entity::find()
316 .filter(
317 channel_member::Column::UserId
318 .eq(user_id)
319 .and(channel_member::Column::Accepted.eq(false)),
320 )
321 .all(&*tx)
322 .await?;
323
324 let channels = channel::Entity::find()
325 .filter(
326 channel::Column::Id.is_in(
327 channel_invites
328 .into_iter()
329 .map(|channel_member| channel_member.channel_id),
330 ),
331 )
332 .all(&*tx)
333 .await?;
334
335 let channels = channels
336 .into_iter()
337 .map(|channel| Channel {
338 id: channel.id,
339 name: channel.name,
340 })
341 .collect();
342
343 Ok(channels)
344 })
345 .await
346 }
347
348 async fn get_channel_graph(
349 &self,
350 parents_by_child_id: ChannelDescendants,
351 trim_dangling_parents: bool,
352 tx: &DatabaseTransaction,
353 ) -> Result<ChannelGraph> {
354 let mut channels = Vec::with_capacity(parents_by_child_id.len());
355 {
356 let mut rows = channel::Entity::find()
357 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
358 .stream(&*tx)
359 .await?;
360 while let Some(row) = rows.next().await {
361 let row = row?;
362 channels.push(Channel {
363 id: row.id,
364 name: row.name,
365 })
366 }
367 }
368
369 let mut edges = Vec::with_capacity(parents_by_child_id.len());
370 for (channel, parents) in parents_by_child_id.iter() {
371 for parent in parents.into_iter() {
372 if trim_dangling_parents {
373 if parents_by_child_id.contains_key(parent) {
374 edges.push(ChannelEdge {
375 channel_id: channel.to_proto(),
376 parent_id: parent.to_proto(),
377 });
378 }
379 } else {
380 edges.push(ChannelEdge {
381 channel_id: channel.to_proto(),
382 parent_id: parent.to_proto(),
383 });
384 }
385 }
386 }
387
388 Ok(ChannelGraph { channels, edges })
389 }
390
391 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
392 self.transaction(|tx| async move {
393 let tx = tx;
394
395 let channel_memberships = channel_member::Entity::find()
396 .filter(
397 channel_member::Column::UserId
398 .eq(user_id)
399 .and(channel_member::Column::Accepted.eq(true)),
400 )
401 .all(&*tx)
402 .await?;
403
404 self.get_user_channels(user_id, channel_memberships, &tx)
405 .await
406 })
407 .await
408 }
409
410 pub async fn get_channel_for_user(
411 &self,
412 channel_id: ChannelId,
413 user_id: UserId,
414 ) -> Result<ChannelsForUser> {
415 self.transaction(|tx| async move {
416 let tx = tx;
417
418 let channel_membership = channel_member::Entity::find()
419 .filter(
420 channel_member::Column::UserId
421 .eq(user_id)
422 .and(channel_member::Column::ChannelId.eq(channel_id))
423 .and(channel_member::Column::Accepted.eq(true)),
424 )
425 .all(&*tx)
426 .await?;
427
428 self.get_user_channels(user_id, channel_membership, &tx)
429 .await
430 })
431 .await
432 }
433
434 pub async fn get_user_channels(
435 &self,
436 user_id: UserId,
437 channel_memberships: Vec<channel_member::Model>,
438 tx: &DatabaseTransaction,
439 ) -> Result<ChannelsForUser> {
440 let parents_by_child_id = self
441 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
442 .await?;
443
444 let channels_with_admin_privileges = channel_memberships
445 .iter()
446 .filter_map(|membership| {
447 if membership.role == Some(ChannelRole::Admin) || membership.admin {
448 Some(membership.channel_id)
449 } else {
450 None
451 }
452 })
453 .collect();
454
455 let graph = self
456 .get_channel_graph(parents_by_child_id, true, &tx)
457 .await?;
458
459 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
460 enum QueryUserIdsAndChannelIds {
461 ChannelId,
462 UserId,
463 }
464
465 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
466 {
467 let mut rows = room_participant::Entity::find()
468 .inner_join(room::Entity)
469 .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
470 .select_only()
471 .column(room::Column::ChannelId)
472 .column(room_participant::Column::UserId)
473 .into_values::<_, QueryUserIdsAndChannelIds>()
474 .stream(&*tx)
475 .await?;
476 while let Some(row) = rows.next().await {
477 let row: (ChannelId, UserId) = row?;
478 channel_participants.entry(row.0).or_default().push(row.1)
479 }
480 }
481
482 let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
483 let channel_buffer_changes = self
484 .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
485 .await?;
486
487 let unseen_messages = self
488 .unseen_channel_messages(user_id, &channel_ids, &*tx)
489 .await?;
490
491 Ok(ChannelsForUser {
492 channels: graph,
493 channel_participants,
494 channels_with_admin_privileges,
495 unseen_buffer_changes: channel_buffer_changes,
496 channel_messages: unseen_messages,
497 })
498 }
499
500 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
501 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
502 .await
503 }
504
505 pub async fn set_channel_member_role(
506 &self,
507 channel_id: ChannelId,
508 admin_id: UserId,
509 for_user: UserId,
510 role: ChannelRole,
511 ) -> Result<()> {
512 self.transaction(|tx| async move {
513 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
514 .await?;
515
516 let result = channel_member::Entity::update_many()
517 .filter(
518 channel_member::Column::ChannelId
519 .eq(channel_id)
520 .and(channel_member::Column::UserId.eq(for_user)),
521 )
522 .set(channel_member::ActiveModel {
523 admin: ActiveValue::set(role == ChannelRole::Admin),
524 role: ActiveValue::set(Some(role)),
525 ..Default::default()
526 })
527 .exec(&*tx)
528 .await?;
529
530 if result.rows_affected == 0 {
531 Err(anyhow!("no such member"))?;
532 }
533
534 Ok(())
535 })
536 .await
537 }
538
539 pub async fn get_channel_member_details(
540 &self,
541 channel_id: ChannelId,
542 user_id: UserId,
543 ) -> Result<Vec<proto::ChannelMember>> {
544 self.transaction(|tx| async move {
545 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
546 .await?;
547
548 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
549 enum QueryMemberDetails {
550 UserId,
551 Admin,
552 Role,
553 IsDirectMember,
554 Accepted,
555 }
556
557 let tx = tx;
558 let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
559 let mut stream = channel_member::Entity::find()
560 .distinct()
561 .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
562 .select_only()
563 .column(channel_member::Column::UserId)
564 .column(channel_member::Column::Admin)
565 .column(channel_member::Column::Role)
566 .column_as(
567 channel_member::Column::ChannelId.eq(channel_id),
568 QueryMemberDetails::IsDirectMember,
569 )
570 .column(channel_member::Column::Accepted)
571 .order_by_asc(channel_member::Column::UserId)
572 .into_values::<_, QueryMemberDetails>()
573 .stream(&*tx)
574 .await?;
575
576 let mut rows = Vec::<proto::ChannelMember>::new();
577 while let Some(row) = stream.next().await {
578 let (user_id, is_admin, channel_role, is_direct_member, is_invite_accepted): (
579 UserId,
580 bool,
581 Option<ChannelRole>,
582 bool,
583 bool,
584 ) = row?;
585 let kind = match (is_direct_member, is_invite_accepted) {
586 (true, true) => proto::channel_member::Kind::Member,
587 (true, false) => proto::channel_member::Kind::Invitee,
588 (false, true) => proto::channel_member::Kind::AncestorMember,
589 (false, false) => continue,
590 };
591 let channel_role = channel_role.unwrap_or(if is_admin {
592 ChannelRole::Admin
593 } else {
594 ChannelRole::Member
595 });
596 let user_id = user_id.to_proto();
597 let kind = kind.into();
598 if let Some(last_row) = rows.last_mut() {
599 if last_row.user_id == user_id {
600 if is_direct_member {
601 last_row.kind = kind;
602 last_row.role = channel_role.into()
603 }
604 continue;
605 }
606 }
607 rows.push(proto::ChannelMember {
608 user_id,
609 kind,
610 role: channel_role.into(),
611 });
612 }
613
614 Ok(rows)
615 })
616 .await
617 }
618
619 pub async fn get_channel_members_internal(
620 &self,
621 id: ChannelId,
622 tx: &DatabaseTransaction,
623 ) -> Result<Vec<UserId>> {
624 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
625 let user_ids = channel_member::Entity::find()
626 .distinct()
627 .filter(
628 channel_member::Column::ChannelId
629 .is_in(ancestor_ids.iter().copied())
630 .and(channel_member::Column::Accepted.eq(true)),
631 )
632 .select_only()
633 .column(channel_member::Column::UserId)
634 .into_values::<_, QueryUserIds>()
635 .all(&*tx)
636 .await?;
637 Ok(user_ids)
638 }
639
640 pub async fn check_user_is_channel_admin(
641 &self,
642 channel_id: ChannelId,
643 user_id: UserId,
644 tx: &DatabaseTransaction,
645 ) -> Result<()> {
646 match self.channel_role_for_user(channel_id, user_id, tx).await? {
647 Some(ChannelRole::Admin) => Ok(()),
648 Some(ChannelRole::Member)
649 | Some(ChannelRole::Banned)
650 | Some(ChannelRole::Guest)
651 | None => Err(anyhow!(
652 "user is not a channel admin or channel does not exist"
653 ))?,
654 }
655 }
656
657 pub async fn check_user_is_channel_member(
658 &self,
659 channel_id: ChannelId,
660 user_id: UserId,
661 tx: &DatabaseTransaction,
662 ) -> Result<()> {
663 match self.channel_role_for_user(channel_id, user_id, tx).await? {
664 Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
665 Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
666 "user is not a channel member or channel does not exist"
667 ))?,
668 }
669 }
670
671 pub async fn check_user_is_channel_participant(
672 &self,
673 channel_id: ChannelId,
674 user_id: UserId,
675 tx: &DatabaseTransaction,
676 ) -> Result<()> {
677 match self.channel_role_for_user(channel_id, user_id, tx).await? {
678 Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
679 Ok(())
680 }
681 Some(ChannelRole::Banned) | None => Err(anyhow!(
682 "user is not a channel participant or channel does not exist"
683 ))?,
684 }
685 }
686
687 pub async fn channel_role_for_user(
688 &self,
689 channel_id: ChannelId,
690 user_id: UserId,
691 tx: &DatabaseTransaction,
692 ) -> Result<Option<ChannelRole>> {
693 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
694
695 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
696 enum QueryChannelMembership {
697 ChannelId,
698 Role,
699 Admin,
700 Visibility,
701 }
702
703 let mut rows = channel_member::Entity::find()
704 .left_join(channel::Entity)
705 .filter(
706 channel_member::Column::ChannelId
707 .is_in(channel_ids)
708 .and(channel_member::Column::UserId.eq(user_id)),
709 )
710 .select_only()
711 .column(channel_member::Column::ChannelId)
712 .column(channel_member::Column::Role)
713 .column(channel_member::Column::Admin)
714 .column(channel::Column::Visibility)
715 .into_values::<_, QueryChannelMembership>()
716 .stream(&*tx)
717 .await?;
718
719 let mut is_admin = false;
720 let mut is_member = false;
721 let mut is_participant = false;
722 let mut is_banned = false;
723 let mut current_channel_visibility = None;
724
725 // note these channels are not iterated in any particular order,
726 // our current logic takes the highest permission available.
727 while let Some(row) = rows.next().await {
728 let (ch_id, role, admin, visibility): (
729 ChannelId,
730 Option<ChannelRole>,
731 bool,
732 ChannelVisibility,
733 ) = row?;
734 match role {
735 Some(ChannelRole::Admin) => is_admin = true,
736 Some(ChannelRole::Member) => is_member = true,
737 Some(ChannelRole::Guest) => {
738 if visibility == ChannelVisibility::Public {
739 is_participant = true
740 }
741 }
742 Some(ChannelRole::Banned) => is_banned = true,
743 None => {
744 // rows created from pre-role collab server.
745 if admin {
746 is_admin = true
747 } else {
748 is_member = true
749 }
750 }
751 }
752 if channel_id == ch_id {
753 current_channel_visibility = Some(visibility);
754 }
755 }
756 // free up database connection
757 drop(rows);
758
759 Ok(if is_admin {
760 Some(ChannelRole::Admin)
761 } else if is_member {
762 Some(ChannelRole::Member)
763 } else if is_banned {
764 Some(ChannelRole::Banned)
765 } else if is_participant {
766 if current_channel_visibility.is_none() {
767 current_channel_visibility = channel::Entity::find()
768 .filter(channel::Column::Id.eq(channel_id))
769 .one(&*tx)
770 .await?
771 .map(|channel| channel.visibility);
772 }
773 if current_channel_visibility == Some(ChannelVisibility::Public) {
774 Some(ChannelRole::Guest)
775 } else {
776 None
777 }
778 } else {
779 None
780 })
781 }
782
783 /// Returns the channel ancestors, deepest first
784 pub async fn get_channel_ancestors(
785 &self,
786 channel_id: ChannelId,
787 tx: &DatabaseTransaction,
788 ) -> Result<Vec<ChannelId>> {
789 let paths = channel_path::Entity::find()
790 .filter(channel_path::Column::ChannelId.eq(channel_id))
791 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
792 .all(tx)
793 .await?;
794 let mut channel_ids = Vec::new();
795 for path in paths {
796 for id in path.id_path.trim_matches('/').split('/') {
797 if let Ok(id) = id.parse() {
798 let id = ChannelId::from_proto(id);
799 if let Err(ix) = channel_ids.binary_search(&id) {
800 channel_ids.insert(ix, id);
801 }
802 }
803 }
804 }
805 Ok(channel_ids)
806 }
807
808 /// Returns the channel descendants,
809 /// Structured as a map from child ids to their parent ids
810 /// For example, the descendants of 'a' in this DAG:
811 ///
812 /// /- b -\
813 /// a -- c -- d
814 ///
815 /// would be:
816 /// {
817 /// a: [],
818 /// b: [a],
819 /// c: [a],
820 /// d: [a, c],
821 /// }
822 async fn get_channel_descendants(
823 &self,
824 channel_ids: impl IntoIterator<Item = ChannelId>,
825 tx: &DatabaseTransaction,
826 ) -> Result<ChannelDescendants> {
827 let mut values = String::new();
828 for id in channel_ids {
829 if !values.is_empty() {
830 values.push_str(", ");
831 }
832 write!(&mut values, "({})", id).unwrap();
833 }
834
835 if values.is_empty() {
836 return Ok(HashMap::default());
837 }
838
839 let sql = format!(
840 r#"
841 SELECT
842 descendant_paths.*
843 FROM
844 channel_paths parent_paths, channel_paths descendant_paths
845 WHERE
846 parent_paths.channel_id IN ({values}) AND
847 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
848 "#
849 );
850
851 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
852
853 let mut parents_by_child_id: ChannelDescendants = HashMap::default();
854 let mut paths = channel_path::Entity::find()
855 .from_raw_sql(stmt)
856 .stream(tx)
857 .await?;
858
859 while let Some(path) = paths.next().await {
860 let path = path?;
861 let ids = path.id_path.trim_matches('/').split('/');
862 let mut parent_id = None;
863 for id in ids {
864 if let Ok(id) = id.parse() {
865 let id = ChannelId::from_proto(id);
866 if id == path.channel_id {
867 break;
868 }
869 parent_id = Some(id);
870 }
871 }
872 let entry = parents_by_child_id.entry(path.channel_id).or_default();
873 if let Some(parent_id) = parent_id {
874 entry.insert(parent_id);
875 }
876 }
877
878 Ok(parents_by_child_id)
879 }
880
881 /// Returns the channel with the given ID and:
882 /// - true if the user is a member
883 /// - false if the user hasn't accepted the invitation yet
884 pub async fn get_channel(
885 &self,
886 channel_id: ChannelId,
887 user_id: UserId,
888 ) -> Result<Option<(Channel, bool)>> {
889 self.transaction(|tx| async move {
890 let tx = tx;
891
892 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
893
894 if let Some(channel) = channel {
895 if self
896 .check_user_is_channel_member(channel_id, user_id, &*tx)
897 .await
898 .is_err()
899 {
900 return Ok(None);
901 }
902
903 let channel_membership = channel_member::Entity::find()
904 .filter(
905 channel_member::Column::ChannelId
906 .eq(channel_id)
907 .and(channel_member::Column::UserId.eq(user_id)),
908 )
909 .one(&*tx)
910 .await?;
911
912 let is_accepted = channel_membership
913 .map(|membership| membership.accepted)
914 .unwrap_or(false);
915
916 Ok(Some((
917 Channel {
918 id: channel.id,
919 name: channel.name,
920 },
921 is_accepted,
922 )))
923 } else {
924 Ok(None)
925 }
926 })
927 .await
928 }
929
930 pub async fn get_or_create_channel_room(
931 &self,
932 channel_id: ChannelId,
933 live_kit_room: &str,
934 enviroment: &str,
935 ) -> Result<RoomId> {
936 self.transaction(|tx| async move {
937 let tx = tx;
938
939 let room = room::Entity::find()
940 .filter(room::Column::ChannelId.eq(channel_id))
941 .one(&*tx)
942 .await?;
943
944 let room_id = if let Some(room) = room {
945 room.id
946 } else {
947 let result = room::Entity::insert(room::ActiveModel {
948 channel_id: ActiveValue::Set(Some(channel_id)),
949 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
950 enviroment: ActiveValue::Set(Some(enviroment.to_string())),
951 ..Default::default()
952 })
953 .exec(&*tx)
954 .await?;
955
956 result.last_insert_id
957 };
958
959 Ok(room_id)
960 })
961 .await
962 }
963
964 // Insert an edge from the given channel to the given other channel.
965 pub async fn link_channel(
966 &self,
967 user: UserId,
968 channel: ChannelId,
969 to: ChannelId,
970 ) -> Result<ChannelGraph> {
971 self.transaction(|tx| async move {
972 // Note that even with these maxed permissions, this linking operation
973 // is still insecure because you can't remove someone's permissions to a
974 // channel if they've linked the channel to one where they're an admin.
975 self.check_user_is_channel_admin(channel, user, &*tx)
976 .await?;
977
978 self.link_channel_internal(user, channel, to, &*tx).await
979 })
980 .await
981 }
982
983 pub async fn link_channel_internal(
984 &self,
985 user: UserId,
986 channel: ChannelId,
987 new_parent: ChannelId,
988 tx: &DatabaseTransaction,
989 ) -> Result<ChannelGraph> {
990 self.check_user_is_channel_admin(new_parent, user, &*tx)
991 .await?;
992
993 let paths = channel_path::Entity::find()
994 .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
995 .all(tx)
996 .await?;
997
998 let mut new_path_suffixes = HashSet::default();
999 for path in paths {
1000 if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
1001 new_path_suffixes.insert((
1002 path.channel_id,
1003 path.id_path[(start_offset + 1)..].to_string(),
1004 ));
1005 }
1006 }
1007
1008 let paths_to_new_parent = channel_path::Entity::find()
1009 .filter(channel_path::Column::ChannelId.eq(new_parent))
1010 .all(tx)
1011 .await?;
1012
1013 let mut new_paths = Vec::new();
1014 for path in paths_to_new_parent {
1015 if path.id_path.contains(&format!("/{}/", channel)) {
1016 Err(anyhow!("cycle"))?;
1017 }
1018
1019 new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
1020 channel_path::ActiveModel {
1021 channel_id: ActiveValue::Set(*channel_id),
1022 id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
1023 }
1024 }));
1025 }
1026
1027 channel_path::Entity::insert_many(new_paths)
1028 .exec(&*tx)
1029 .await?;
1030
1031 // remove any root edges for the channel we just linked
1032 {
1033 channel_path::Entity::delete_many()
1034 .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
1035 .exec(&*tx)
1036 .await?;
1037 }
1038
1039 let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
1040 if let Some(channel) = channel_descendants.get_mut(&channel) {
1041 // Remove the other parents
1042 channel.clear();
1043 channel.insert(new_parent);
1044 }
1045
1046 let channels = self
1047 .get_channel_graph(channel_descendants, false, &*tx)
1048 .await?;
1049
1050 Ok(channels)
1051 }
1052
1053 /// Unlink a channel from a given parent. This will add in a root edge if
1054 /// the channel has no other parents after this operation.
1055 pub async fn unlink_channel(
1056 &self,
1057 user: UserId,
1058 channel: ChannelId,
1059 from: ChannelId,
1060 ) -> Result<()> {
1061 self.transaction(|tx| async move {
1062 // Note that even with these maxed permissions, this linking operation
1063 // is still insecure because you can't remove someone's permissions to a
1064 // channel if they've linked the channel to one where they're an admin.
1065 self.check_user_is_channel_admin(channel, user, &*tx)
1066 .await?;
1067
1068 self.unlink_channel_internal(user, channel, from, &*tx)
1069 .await?;
1070
1071 Ok(())
1072 })
1073 .await
1074 }
1075
1076 pub async fn unlink_channel_internal(
1077 &self,
1078 user: UserId,
1079 channel: ChannelId,
1080 from: ChannelId,
1081 tx: &DatabaseTransaction,
1082 ) -> Result<()> {
1083 self.check_user_is_channel_admin(from, user, &*tx).await?;
1084
1085 let sql = r#"
1086 DELETE FROM channel_paths
1087 WHERE
1088 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1089 RETURNING id_path, channel_id
1090 "#;
1091
1092 let paths = channel_path::Entity::find()
1093 .from_raw_sql(Statement::from_sql_and_values(
1094 self.pool.get_database_backend(),
1095 sql,
1096 [from.to_proto().into(), channel.to_proto().into()],
1097 ))
1098 .all(&*tx)
1099 .await?;
1100
1101 let is_stranded = channel_path::Entity::find()
1102 .filter(channel_path::Column::ChannelId.eq(channel))
1103 .count(&*tx)
1104 .await?
1105 == 0;
1106
1107 // Make sure that there is always at least one path to the channel
1108 if is_stranded {
1109 let root_paths: Vec<_> = paths
1110 .iter()
1111 .map(|path| {
1112 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1113 channel_path::ActiveModel {
1114 channel_id: ActiveValue::Set(path.channel_id),
1115 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1116 }
1117 })
1118 .collect();
1119 channel_path::Entity::insert_many(root_paths)
1120 .exec(&*tx)
1121 .await?;
1122 }
1123
1124 Ok(())
1125 }
1126
1127 /// Move a channel from one parent to another, returns the
1128 /// Channels that were moved for notifying clients
1129 pub async fn move_channel(
1130 &self,
1131 user: UserId,
1132 channel: ChannelId,
1133 from: ChannelId,
1134 to: ChannelId,
1135 ) -> Result<ChannelGraph> {
1136 if from == to {
1137 return Ok(ChannelGraph {
1138 channels: vec![],
1139 edges: vec![],
1140 });
1141 }
1142
1143 self.transaction(|tx| async move {
1144 self.check_user_is_channel_admin(channel, user, &*tx)
1145 .await?;
1146
1147 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1148
1149 self.unlink_channel_internal(user, channel, from, &*tx)
1150 .await?;
1151
1152 Ok(moved_channels)
1153 })
1154 .await
1155 }
1156}
1157
/// Column selector used with `into_values` to read only `user_id` values
/// from `channel_member` queries.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1162
1163#[derive(Debug)]
1164pub struct ChannelGraph {
1165 pub channels: Vec<Channel>,
1166 pub edges: Vec<ChannelEdge>,
1167}
1168
1169impl ChannelGraph {
1170 pub fn is_empty(&self) -> bool {
1171 self.channels.is_empty() && self.edges.is_empty()
1172 }
1173}
1174
1175#[cfg(test)]
1176impl PartialEq for ChannelGraph {
1177 fn eq(&self, other: &Self) -> bool {
1178 // Order independent comparison for tests
1179 let channels_set = self.channels.iter().collect::<HashSet<_>>();
1180 let other_channels_set = other.channels.iter().collect::<HashSet<_>>();
1181 let edges_set = self
1182 .edges
1183 .iter()
1184 .map(|edge| (edge.channel_id, edge.parent_id))
1185 .collect::<HashSet<_>>();
1186 let other_edges_set = other
1187 .edges
1188 .iter()
1189 .map(|edge| (edge.channel_id, edge.parent_id))
1190 .collect::<HashSet<_>>();
1191
1192 channels_set == other_channels_set && edges_set == other_edges_set
1193 }
1194}
1195
1196#[cfg(not(test))]
1197impl PartialEq for ChannelGraph {
1198 fn eq(&self, other: &Self) -> bool {
1199 self.channels == other.channels && self.edges == other.edges
1200 }
1201}
1202
1203struct SmallSet<T>(SmallVec<[T; 1]>);
1204
1205impl<T> Deref for SmallSet<T> {
1206 type Target = [T];
1207
1208 fn deref(&self) -> &Self::Target {
1209 self.0.deref()
1210 }
1211}
1212
1213impl<T> Default for SmallSet<T> {
1214 fn default() -> Self {
1215 Self(SmallVec::new())
1216 }
1217}
1218
1219impl<T> SmallSet<T> {
1220 fn insert(&mut self, value: T) -> bool
1221 where
1222 T: Ord,
1223 {
1224 match self.binary_search(&value) {
1225 Ok(_) => false,
1226 Err(ix) => {
1227 self.0.insert(ix, value);
1228 true
1229 }
1230 }
1231 }
1232
1233 fn clear(&mut self) {
1234 self.0.clear();
1235 }
1236}