1use rpc::proto::ChannelEdge;
2use smallvec::SmallVec;
3
4use super::*;
5
6type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
7
8impl Database {
9 #[cfg(test)]
10 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
11 self.transaction(move |tx| async move {
12 let mut channels = Vec::new();
13 let mut rows = channel::Entity::find().stream(&*tx).await?;
14 while let Some(row) = rows.next().await {
15 let row = row?;
16 channels.push((row.id, row.name));
17 }
18 Ok(channels)
19 })
20 .await
21 }
22
23 pub async fn create_root_channel(
24 &self,
25 name: &str,
26 live_kit_room: &str,
27 creator_id: UserId,
28 ) -> Result<ChannelId> {
29 self.create_channel(name, None, live_kit_room, creator_id)
30 .await
31 }
32
    /// Creates a channel and returns its id.
    ///
    /// The name is sanitized first (see `sanitize_channel_name`). Then, in a
    /// single transaction:
    /// - if `parent` is given, `creator_id` must be an admin of `parent` or of
    ///   one of its ancestors;
    /// - a `channel` row is inserted;
    /// - the channel's materialized paths are recorded in `channel_paths`:
    ///   every path of `parent` extended with `<id>/`, or `/<id>/` for a root
    ///   channel;
    /// - `creator_id` becomes an accepted admin member of the new channel;
    /// - a `room` row is created for the channel using `live_kit_room`.
    ///
    /// # Errors
    /// Fails if the sanitized name is blank, if the creator is not an admin of
    /// `parent` (or `parent` has no paths), or if any database operation fails.
    pub async fn create_channel(
        &self,
        name: &str,
        parent: Option<ChannelId>,
        live_kit_room: &str,
        creator_id: UserId,
    ) -> Result<ChannelId> {
        let name = Self::sanitize_channel_name(name)?;
        self.transaction(move |tx| async move {
            if let Some(parent) = parent {
                self.check_user_is_channel_admin(parent, creator_id, &*tx)
                    .await?;
            }

            let channel = channel::ActiveModel {
                name: ActiveValue::Set(name.to_string()),
                ..Default::default()
            }
            .insert(&*tx)
            .await?;

            if let Some(parent) = parent {
                // One new path per existing path of the parent ($3): a parent
                // reachable along several paths passes all of them on to the
                // new child ($1 appended to the path, $2 as the row's channel).
                let sql = r#"
                    INSERT INTO channel_paths
                    (id_path, channel_id)
                    SELECT
                        id_path || $1 || '/', $2
                    FROM
                        channel_paths
                    WHERE
                        channel_id = $3
                "#;
                let channel_paths_stmt = Statement::from_sql_and_values(
                    self.pool.get_database_backend(),
                    sql,
                    [
                        channel.id.to_proto().into(),
                        channel.id.to_proto().into(),
                        parent.to_proto().into(),
                    ],
                );
                tx.execute(channel_paths_stmt).await?;
            } else {
                // A root channel has exactly one path: `/<id>/`.
                channel_path::Entity::insert(channel_path::ActiveModel {
                    channel_id: ActiveValue::Set(channel.id),
                    id_path: ActiveValue::Set(format!("/{}/", channel.id)),
                })
                .exec(&*tx)
                .await?;
            }

            // The creator is an accepted admin of the new channel.
            channel_member::ActiveModel {
                channel_id: ActiveValue::Set(channel.id),
                user_id: ActiveValue::Set(creator_id),
                accepted: ActiveValue::Set(true),
                admin: ActiveValue::Set(true),
                ..Default::default()
            }
            .insert(&*tx)
            .await?;

            room::ActiveModel {
                channel_id: ActiveValue::Set(Some(channel.id)),
                live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
                ..Default::default()
            }
            .insert(&*tx)
            .await?;

            Ok(channel.id)
        })
        .await
    }
106
107 pub async fn delete_channel(
108 &self,
109 channel_id: ChannelId,
110 user_id: UserId,
111 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
112 self.transaction(move |tx| async move {
113 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
114 .await?;
115
116 // Don't remove descendant channels that have additional parents.
117 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
118 {
119 let mut channels_to_keep = channel_path::Entity::find()
120 .filter(
121 channel_path::Column::ChannelId
122 .is_in(
123 channels_to_remove
124 .keys()
125 .copied()
126 .filter(|&id| id != channel_id),
127 )
128 .and(
129 channel_path::Column::IdPath
130 .not_like(&format!("%/{}/%", channel_id)),
131 ),
132 )
133 .stream(&*tx)
134 .await?;
135 while let Some(row) = channels_to_keep.next().await {
136 let row = row?;
137 channels_to_remove.remove(&row.channel_id);
138 }
139 }
140
141 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
142 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
143 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
144 .select_only()
145 .column(channel_member::Column::UserId)
146 .distinct()
147 .into_values::<_, QueryUserIds>()
148 .all(&*tx)
149 .await?;
150
151 channel::Entity::delete_many()
152 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
153 .exec(&*tx)
154 .await?;
155
156 // Delete any other paths that include this channel
157 let sql = r#"
158 DELETE FROM channel_paths
159 WHERE
160 id_path LIKE '%' || $1 || '%'
161 "#;
162 let channel_paths_stmt = Statement::from_sql_and_values(
163 self.pool.get_database_backend(),
164 sql,
165 [channel_id.to_proto().into()],
166 );
167 tx.execute(channel_paths_stmt).await?;
168
169 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
170 })
171 .await
172 }
173
174 pub async fn invite_channel_member(
175 &self,
176 channel_id: ChannelId,
177 invitee_id: UserId,
178 inviter_id: UserId,
179 is_admin: bool,
180 ) -> Result<()> {
181 self.transaction(move |tx| async move {
182 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
183 .await?;
184
185 channel_member::ActiveModel {
186 channel_id: ActiveValue::Set(channel_id),
187 user_id: ActiveValue::Set(invitee_id),
188 accepted: ActiveValue::Set(false),
189 admin: ActiveValue::Set(is_admin),
190 ..Default::default()
191 }
192 .insert(&*tx)
193 .await?;
194
195 Ok(())
196 })
197 .await
198 }
199
200 fn sanitize_channel_name(name: &str) -> Result<&str> {
201 let new_name = name.trim().trim_start_matches('#');
202 if new_name == "" {
203 Err(anyhow!("channel name can't be blank"))?;
204 }
205 Ok(new_name)
206 }
207
208 pub async fn rename_channel(
209 &self,
210 channel_id: ChannelId,
211 user_id: UserId,
212 new_name: &str,
213 ) -> Result<String> {
214 self.transaction(move |tx| async move {
215 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
216
217 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
218 .await?;
219
220 channel::ActiveModel {
221 id: ActiveValue::Unchanged(channel_id),
222 name: ActiveValue::Set(new_name.clone()),
223 ..Default::default()
224 }
225 .update(&*tx)
226 .await?;
227
228 Ok(new_name)
229 })
230 .await
231 }
232
233 pub async fn respond_to_channel_invite(
234 &self,
235 channel_id: ChannelId,
236 user_id: UserId,
237 accept: bool,
238 ) -> Result<()> {
239 self.transaction(move |tx| async move {
240 let rows_affected = if accept {
241 channel_member::Entity::update_many()
242 .set(channel_member::ActiveModel {
243 accepted: ActiveValue::Set(accept),
244 ..Default::default()
245 })
246 .filter(
247 channel_member::Column::ChannelId
248 .eq(channel_id)
249 .and(channel_member::Column::UserId.eq(user_id))
250 .and(channel_member::Column::Accepted.eq(false)),
251 )
252 .exec(&*tx)
253 .await?
254 .rows_affected
255 } else {
256 channel_member::ActiveModel {
257 channel_id: ActiveValue::Unchanged(channel_id),
258 user_id: ActiveValue::Unchanged(user_id),
259 ..Default::default()
260 }
261 .delete(&*tx)
262 .await?
263 .rows_affected
264 };
265
266 if rows_affected == 0 {
267 Err(anyhow!("no such invitation"))?;
268 }
269
270 Ok(())
271 })
272 .await
273 }
274
275 pub async fn remove_channel_member(
276 &self,
277 channel_id: ChannelId,
278 member_id: UserId,
279 remover_id: UserId,
280 ) -> Result<()> {
281 self.transaction(|tx| async move {
282 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
283 .await?;
284
285 let result = channel_member::Entity::delete_many()
286 .filter(
287 channel_member::Column::ChannelId
288 .eq(channel_id)
289 .and(channel_member::Column::UserId.eq(member_id)),
290 )
291 .exec(&*tx)
292 .await?;
293
294 if result.rows_affected == 0 {
295 Err(anyhow!("no such member"))?;
296 }
297
298 Ok(())
299 })
300 .await
301 }
302
303 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
304 self.transaction(|tx| async move {
305 let channel_invites = channel_member::Entity::find()
306 .filter(
307 channel_member::Column::UserId
308 .eq(user_id)
309 .and(channel_member::Column::Accepted.eq(false)),
310 )
311 .all(&*tx)
312 .await?;
313
314 let channels = channel::Entity::find()
315 .filter(
316 channel::Column::Id.is_in(
317 channel_invites
318 .into_iter()
319 .map(|channel_member| channel_member.channel_id),
320 ),
321 )
322 .all(&*tx)
323 .await?;
324
325 let channels = channels
326 .into_iter()
327 .map(|channel| Channel {
328 id: channel.id,
329 name: channel.name,
330 })
331 .collect();
332
333 Ok(channels)
334 })
335 .await
336 }
337
338 async fn get_channel_graph(
339 &self,
340 parents_by_child_id: ChannelDescendants,
341 trim_dangling_parents: bool,
342 tx: &DatabaseTransaction,
343 ) -> Result<ChannelGraph> {
344 let mut channels = Vec::with_capacity(parents_by_child_id.len());
345 {
346 let mut rows = channel::Entity::find()
347 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
348 .stream(&*tx)
349 .await?;
350 while let Some(row) = rows.next().await {
351 let row = row?;
352 channels.push(Channel {
353 id: row.id,
354 name: row.name,
355 })
356 }
357 }
358
359 let mut edges = Vec::with_capacity(parents_by_child_id.len());
360 for (channel, parents) in parents_by_child_id.iter() {
361 for parent in parents.into_iter() {
362 if trim_dangling_parents {
363 if parents_by_child_id.contains_key(parent) {
364 edges.push(ChannelEdge {
365 channel_id: channel.to_proto(),
366 parent_id: parent.to_proto(),
367 });
368 }
369 } else {
370 edges.push(ChannelEdge {
371 channel_id: channel.to_proto(),
372 parent_id: parent.to_proto(),
373 });
374 }
375 }
376 }
377
378 Ok(ChannelGraph { channels, edges })
379 }
380
381 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
382 self.transaction(|tx| async move {
383 let tx = tx;
384
385 let channel_memberships = channel_member::Entity::find()
386 .filter(
387 channel_member::Column::UserId
388 .eq(user_id)
389 .and(channel_member::Column::Accepted.eq(true)),
390 )
391 .all(&*tx)
392 .await?;
393
394 self.get_user_channels(channel_memberships, &tx).await
395 })
396 .await
397 }
398
399 pub async fn get_channel_for_user(
400 &self,
401 channel_id: ChannelId,
402 user_id: UserId,
403 ) -> Result<ChannelsForUser> {
404 self.transaction(|tx| async move {
405 let tx = tx;
406
407 let channel_membership = channel_member::Entity::find()
408 .filter(
409 channel_member::Column::UserId
410 .eq(user_id)
411 .and(channel_member::Column::ChannelId.eq(channel_id))
412 .and(channel_member::Column::Accepted.eq(true)),
413 )
414 .all(&*tx)
415 .await?;
416
417 self.get_user_channels(channel_membership, &tx).await
418 })
419 .await
420 }
421
    /// Builds the `ChannelsForUser` view for the given memberships.
    ///
    /// - `channels`: the member channels plus all of their descendants, with
    ///   edges to parents outside that set trimmed away.
    /// - `channel_participants`: for each channel in that graph, the ids of
    ///   users that have a `room_participant` row in the channel's room.
    /// - `channels_with_admin_privileges`: the channel ids of the given
    ///   memberships whose `admin` flag is set (descendants are not added).
    ///
    /// The memberships' `accepted` flag is not inspected here; the callers in
    /// this file pass only accepted memberships.
    pub async fn get_user_channels(
        &self,
        channel_memberships: Vec<channel_member::Model>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelsForUser> {
        let parents_by_child_id = self
            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
            .await?;

        let channels_with_admin_privileges = channel_memberships
            .iter()
            .filter_map(|membership| membership.admin.then_some(membership.channel_id))
            .collect();

        // `true`: drop edges to parents the user can't see.
        let graph = self
            .get_channel_graph(parents_by_child_id, true, &tx)
            .await?;

        // Column selector for the participants query below. The variant names
        // correspond to the two selected columns, and each row is read back as
        // a `(ChannelId, UserId)` tuple.
        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryUserIdsAndChannelIds {
            ChannelId,
            UserId,
        }

        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
        {
            let mut rows = room_participant::Entity::find()
                .inner_join(room::Entity)
                .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
                .select_only()
                .column(room::Column::ChannelId)
                .column(room_participant::Column::UserId)
                .into_values::<_, QueryUserIdsAndChannelIds>()
                .stream(&*tx)
                .await?;
            while let Some(row) = rows.next().await {
                let row: (ChannelId, UserId) = row?;
                channel_participants.entry(row.0).or_default().push(row.1)
            }
        }

        Ok(ChannelsForUser {
            channels: graph,
            channel_participants,
            channels_with_admin_privileges,
        })
    }
469
470 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
471 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
472 .await
473 }
474
475 pub async fn set_channel_member_admin(
476 &self,
477 channel_id: ChannelId,
478 from: UserId,
479 for_user: UserId,
480 admin: bool,
481 ) -> Result<()> {
482 self.transaction(|tx| async move {
483 self.check_user_is_channel_admin(channel_id, from, &*tx)
484 .await?;
485
486 let result = channel_member::Entity::update_many()
487 .filter(
488 channel_member::Column::ChannelId
489 .eq(channel_id)
490 .and(channel_member::Column::UserId.eq(for_user)),
491 )
492 .set(channel_member::ActiveModel {
493 admin: ActiveValue::set(admin),
494 ..Default::default()
495 })
496 .exec(&*tx)
497 .await?;
498
499 if result.rows_affected == 0 {
500 Err(anyhow!("no such member"))?;
501 }
502
503 Ok(())
504 })
505 .await
506 }
507
    /// Lists the members of `channel_id`, for display to an admin.
    ///
    /// Membership rows of the channel and of all its ancestors are combined
    /// into one entry per user, ordered by user id:
    /// - `Member` / `Invitee` for an accepted / pending row on the channel
    ///   itself;
    /// - `AncestorMember` for an accepted row on an ancestor only.
    ///
    /// Pending invitations to ancestors are skipped. When a user has both a
    /// direct row and ancestor rows, the direct row's kind and admin flag win.
    ///
    /// # Errors
    /// Fails if `user_id` is not an admin of the channel or one of its
    /// ancestors, or if any database operation fails.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column selector for the query below; rows are read back as
            // `(user_id, admin, is_direct_member, accepted)`.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invitations to ancestors don't make a member.
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Rows are ordered by user id, so all rows for one user are
                // adjacent: merge them into the entry just pushed, letting a
                // direct row override whatever an ancestor row set.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
579
580 pub async fn get_channel_members_internal(
581 &self,
582 id: ChannelId,
583 tx: &DatabaseTransaction,
584 ) -> Result<Vec<UserId>> {
585 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
586 let user_ids = channel_member::Entity::find()
587 .distinct()
588 .filter(
589 channel_member::Column::ChannelId
590 .is_in(ancestor_ids.iter().copied())
591 .and(channel_member::Column::Accepted.eq(true)),
592 )
593 .select_only()
594 .column(channel_member::Column::UserId)
595 .into_values::<_, QueryUserIds>()
596 .all(&*tx)
597 .await?;
598 Ok(user_ids)
599 }
600
601 pub async fn check_user_is_channel_member(
602 &self,
603 channel_id: ChannelId,
604 user_id: UserId,
605 tx: &DatabaseTransaction,
606 ) -> Result<()> {
607 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
608 channel_member::Entity::find()
609 .filter(
610 channel_member::Column::ChannelId
611 .is_in(channel_ids)
612 .and(channel_member::Column::UserId.eq(user_id)),
613 )
614 .one(&*tx)
615 .await?
616 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
617 Ok(())
618 }
619
620 pub async fn check_user_is_channel_admin(
621 &self,
622 channel_id: ChannelId,
623 user_id: UserId,
624 tx: &DatabaseTransaction,
625 ) -> Result<()> {
626 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
627 channel_member::Entity::find()
628 .filter(
629 channel_member::Column::ChannelId
630 .is_in(channel_ids)
631 .and(channel_member::Column::UserId.eq(user_id))
632 .and(channel_member::Column::Admin.eq(true)),
633 )
634 .one(&*tx)
635 .await?
636 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
637 Ok(())
638 }
639
640 /// Returns the channel ancestors, deepest first
641 pub async fn get_channel_ancestors(
642 &self,
643 channel_id: ChannelId,
644 tx: &DatabaseTransaction,
645 ) -> Result<Vec<ChannelId>> {
646 let paths = channel_path::Entity::find()
647 .filter(channel_path::Column::ChannelId.eq(channel_id))
648 .order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
649 .all(tx)
650 .await?;
651 let mut channel_ids = Vec::new();
652 for path in paths {
653 for id in path.id_path.trim_matches('/').split('/') {
654 if let Ok(id) = id.parse() {
655 let id = ChannelId::from_proto(id);
656 if let Err(ix) = channel_ids.binary_search(&id) {
657 channel_ids.insert(ix, id);
658 }
659 }
660 }
661 }
662 Ok(channel_ids)
663 }
664
    /// Returns the given channels and all of their descendants, as a map from
    /// each channel id to the ids of its parents.
    ///
    /// Only paths that extend one of the given channels' own paths are
    /// considered. A descendant's parent set therefore lists the parents
    /// reached through the given channels, and a given channel's own entry
    /// lists the parents on its own paths (empty for a root channel).
    ///
    /// For example, the descendants of the root channel 'a' in this DAG:
    ///
    /// ```text
    ///   /- b -\
    /// a -- c -- d
    /// ```
    ///
    /// would be:
    ///
    /// ```text
    /// {
    ///     a: [],
    ///     b: [a],
    ///     c: [a],
    ///     d: [b, c],
    /// }
    /// ```
    async fn get_channel_descendants(
        &self,
        channel_ids: impl IntoIterator<Item = ChannelId>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelDescendants> {
        // Build a `(id1), (id2), ...` list for the SQL `IN` clause below.
        // NOTE(review): ids are interpolated directly into the SQL text; this
        // relies on `ChannelId`'s `Display` producing a bare integer.
        let mut values = String::new();
        for id in channel_ids {
            if !values.is_empty() {
                values.push_str(", ");
            }
            write!(&mut values, "({})", id).unwrap();
        }

        if values.is_empty() {
            return Ok(HashMap::default());
        }

        // Every path that has a path of one of the requested channels as a
        // prefix, including those channels' own paths.
        let sql = format!(
            r#"
            SELECT
                descendant_paths.*
            FROM
                channel_paths parent_paths, channel_paths descendant_paths
            WHERE
                parent_paths.channel_id IN ({values}) AND
                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
        "#
        );

        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);

        let mut parents_by_child_id: ChannelDescendants = HashMap::default();
        let mut paths = channel_path::Entity::find()
            .from_raw_sql(stmt)
            .stream(tx)
            .await?;

        while let Some(path) = paths.next().await {
            let path = path?;
            // The parent on this path is the id right before the channel's
            // own id (`None` when the path starts with the channel).
            let ids = path.id_path.trim_matches('/').split('/');
            let mut parent_id = None;
            for id in ids {
                if let Ok(id) = id.parse() {
                    let id = ChannelId::from_proto(id);
                    if id == path.channel_id {
                        break;
                    }
                    parent_id = Some(id);
                }
            }
            // Create the entry even without a parent so root channels appear
            // with an empty parent set.
            let entry = parents_by_child_id.entry(path.channel_id).or_default();
            if let Some(parent_id) = parent_id {
                entry.insert(parent_id);
            }
        }

        Ok(parents_by_child_id)
    }
737
738 /// Returns the channel with the given ID and:
739 /// - true if the user is a member
740 /// - false if the user hasn't accepted the invitation yet
741 pub async fn get_channel(
742 &self,
743 channel_id: ChannelId,
744 user_id: UserId,
745 ) -> Result<Option<(Channel, bool)>> {
746 self.transaction(|tx| async move {
747 let tx = tx;
748
749 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
750
751 if let Some(channel) = channel {
752 if self
753 .check_user_is_channel_member(channel_id, user_id, &*tx)
754 .await
755 .is_err()
756 {
757 return Ok(None);
758 }
759
760 let channel_membership = channel_member::Entity::find()
761 .filter(
762 channel_member::Column::ChannelId
763 .eq(channel_id)
764 .and(channel_member::Column::UserId.eq(user_id)),
765 )
766 .one(&*tx)
767 .await?;
768
769 let is_accepted = channel_membership
770 .map(|membership| membership.accepted)
771 .unwrap_or(false);
772
773 Ok(Some((
774 Channel {
775 id: channel.id,
776 name: channel.name,
777 },
778 is_accepted,
779 )))
780 } else {
781 Ok(None)
782 }
783 })
784 .await
785 }
786
787 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
788 self.transaction(|tx| async move {
789 let tx = tx;
790 let room = channel::Model {
791 id: channel_id,
792 ..Default::default()
793 }
794 .find_related(room::Entity)
795 .one(&*tx)
796 .await?
797 .ok_or_else(|| anyhow!("invalid channel"))?;
798 Ok(room.id)
799 })
800 .await
801 }
802
803 // Insert an edge from the given channel to the given other channel.
804 pub async fn link_channel(
805 &self,
806 user: UserId,
807 channel: ChannelId,
808 to: ChannelId,
809 ) -> Result<ChannelGraph> {
810 self.transaction(|tx| async move {
811 // Note that even with these maxed permissions, this linking operation
812 // is still insecure because you can't remove someone's permissions to a
813 // channel if they've linked the channel to one where they're an admin.
814 self.check_user_is_channel_admin(channel, user, &*tx)
815 .await?;
816
817 self.link_channel_internal(user, channel, to, &*tx).await
818 })
819 .await
820 }
821
822 pub async fn link_channel_internal(
823 &self,
824 user: UserId,
825 channel: ChannelId,
826 to: ChannelId,
827 tx: &DatabaseTransaction,
828 ) -> Result<ChannelGraph> {
829 self.check_user_is_channel_admin(to, user, &*tx).await?;
830
831 let to_ancestors = self.get_channel_ancestors(to, &*tx).await?;
832 let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
833 for ancestor in to_ancestors {
834 if channel_descendants.contains_key(&ancestor) {
835 return Err(anyhow!("Cannot create a channel cycle").into());
836 }
837 }
838
839 // Now insert all of the new paths
840 let sql = r#"
841 INSERT INTO channel_paths
842 (id_path, channel_id)
843 SELECT
844 id_path || $1 || '/', $2
845 FROM
846 channel_paths
847 WHERE
848 channel_id = $3
849 ON CONFLICT (id_path) DO NOTHING;
850 "#;
851 let channel_paths_stmt = Statement::from_sql_and_values(
852 self.pool.get_database_backend(),
853 sql,
854 [
855 channel.to_proto().into(),
856 channel.to_proto().into(),
857 to.to_proto().into(),
858 ],
859 );
860 tx.execute(channel_paths_stmt).await?;
861 for (descdenant_id, descendant_parent_ids) in
862 channel_descendants.iter().filter(|(id, _)| id != &&channel)
863 {
864 for descendant_parent_id in descendant_parent_ids.iter() {
865 let channel_paths_stmt = Statement::from_sql_and_values(
866 self.pool.get_database_backend(),
867 sql,
868 [
869 descdenant_id.to_proto().into(),
870 descdenant_id.to_proto().into(),
871 descendant_parent_id.to_proto().into(),
872 ],
873 );
874 tx.execute(channel_paths_stmt).await?;
875 }
876 }
877
878 // If we're linking a channel, remove any root edges for the channel
879 {
880 let sql = r#"
881 DELETE FROM channel_paths
882 WHERE
883 id_path = '/' || $1 || '/'
884 "#;
885 let channel_paths_stmt = Statement::from_sql_and_values(
886 self.pool.get_database_backend(),
887 sql,
888 [channel.to_proto().into()],
889 );
890 tx.execute(channel_paths_stmt).await?;
891 }
892
893 if let Some(channel) = channel_descendants.get_mut(&channel) {
894 // Remove the other parents
895 channel.clear();
896 channel.insert(to);
897 }
898
899 let channels = self
900 .get_channel_graph(channel_descendants, false, &*tx)
901 .await?;
902
903 Ok(channels)
904 }
905
906 /// Unlink a channel from a given parent. This will add in a root edge if
907 /// the channel has no other parents after this operation.
908 pub async fn unlink_channel(
909 &self,
910 user: UserId,
911 channel: ChannelId,
912 from: ChannelId,
913 ) -> Result<()> {
914 self.transaction(|tx| async move {
915 // Note that even with these maxed permissions, this linking operation
916 // is still insecure because you can't remove someone's permissions to a
917 // channel if they've linked the channel to one where they're an admin.
918 self.check_user_is_channel_admin(channel, user, &*tx)
919 .await?;
920
921 self.unlink_channel_internal(user, channel, from, &*tx)
922 .await?;
923
924 Ok(())
925 })
926 .await
927 }
928
929 pub async fn unlink_channel_internal(
930 &self,
931 user: UserId,
932 channel: ChannelId,
933 from: ChannelId,
934 tx: &DatabaseTransaction,
935 ) -> Result<()> {
936 self.check_user_is_channel_admin(from, user, &*tx).await?;
937
938 let sql = r#"
939 DELETE FROM channel_paths
940 WHERE
941 id_path LIKE '%' || $1 || '/' || $2 || '%'
942 "#;
943 let channel_paths_stmt = Statement::from_sql_and_values(
944 self.pool.get_database_backend(),
945 sql,
946 [from.to_proto().into(), channel.to_proto().into()],
947 );
948 tx.execute(channel_paths_stmt).await?;
949
950 // Make sure that there is always at least one path to the channel
951 let sql = r#"
952 INSERT INTO channel_paths
953 (id_path, channel_id)
954 SELECT
955 '/' || $1 || '/', $2
956 WHERE NOT EXISTS
957 (SELECT *
958 FROM channel_paths
959 WHERE channel_id = $2)
960 "#;
961
962 let channel_paths_stmt = Statement::from_sql_and_values(
963 self.pool.get_database_backend(),
964 sql,
965 [channel.to_proto().into(), channel.to_proto().into()],
966 );
967 tx.execute(channel_paths_stmt).await?;
968
969 Ok(())
970 }
971
972 /// Move a channel from one parent to another, returns the
973 /// Channels that were moved for notifying clients
974 pub async fn move_channel(
975 &self,
976 user: UserId,
977 channel: ChannelId,
978 from: ChannelId,
979 to: ChannelId,
980 ) -> Result<ChannelGraph> {
981 self.transaction(|tx| async move {
982 self.check_user_is_channel_admin(channel, user, &*tx)
983 .await?;
984
985 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
986
987 self.unlink_channel_internal(user, channel, from, &*tx)
988 .await?;
989
990 Ok(moved_channels)
991 })
992 .await
993 }
994}
995
/// Single-column selector used with `into_values` to read back only the
/// `user_id` column (e.g. in `get_channel_members_internal`).
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1000
/// A set of channels plus the child/parent edges between them, as returned
/// by the channel-graph queries in this file.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels (id and name) in the graph.
    pub channels: Vec<Channel>,
    /// One edge per (child, parent) relationship, in proto form.
    pub edges: Vec<ChannelEdge>,
}
1006
1007impl ChannelGraph {
1008 pub fn is_empty(&self) -> bool {
1009 self.channels.is_empty() && self.edges.is_empty()
1010 }
1011}
1012
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Order-independent comparison for tests: two graphs are equal when they
    /// contain the same set of channels and the same set of
    /// `(channel_id, parent_id)` edges.
    fn eq(&self, other: &Self) -> bool {
        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();
        if !same_channels {
            return false;
        }

        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };
        edge_pairs(self) == edge_pairs(other)
    }
}
1033
1034#[cfg(not(test))]
1035impl PartialEq for ChannelGraph {
1036 fn eq(&self, other: &Self) -> bool {
1037 self.channels == other.channels && self.edges == other.edges
1038 }
1039}
1040
1041struct SmallSet<T>(SmallVec<[T; 1]>);
1042
1043impl<T> Deref for SmallSet<T> {
1044 type Target = [T];
1045
1046 fn deref(&self) -> &Self::Target {
1047 self.0.deref()
1048 }
1049}
1050
1051impl<T> Default for SmallSet<T> {
1052 fn default() -> Self {
1053 Self(SmallVec::new())
1054 }
1055}
1056
1057impl<T> SmallSet<T> {
1058 fn insert(&mut self, value: T) -> bool
1059 where
1060 T: Ord,
1061 {
1062 match self.binary_search(&value) {
1063 Ok(_) => false,
1064 Err(ix) => {
1065 self.0.insert(ix, value);
1066 true
1067 }
1068 }
1069 }
1070
1071 fn clear(&mut self) {
1072 self.0.clear();
1073 }
1074}