1use rpc::proto::ChannelEdge;
2use smallvec::SmallVec;
3
4use super::*;
5
/// Maps a channel id to the set of its direct parent ids, as built by
/// `Database::get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
7
8impl Database {
9 #[cfg(test)]
10 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
11 self.transaction(move |tx| async move {
12 let mut channels = Vec::new();
13 let mut rows = channel::Entity::find().stream(&*tx).await?;
14 while let Some(row) = rows.next().await {
15 let row = row?;
16 channels.push((row.id, row.name));
17 }
18 Ok(channels)
19 })
20 .await
21 }
22
23 pub async fn create_root_channel(
24 &self,
25 name: &str,
26 live_kit_room: &str,
27 creator_id: UserId,
28 ) -> Result<ChannelId> {
29 self.create_channel(name, None, live_kit_room, creator_id)
30 .await
31 }
32
33 pub async fn create_channel(
34 &self,
35 name: &str,
36 parent: Option<ChannelId>,
37 live_kit_room: &str,
38 creator_id: UserId,
39 ) -> Result<ChannelId> {
40 let name = Self::sanitize_channel_name(name)?;
41 self.transaction(move |tx| async move {
42 if let Some(parent) = parent {
43 self.check_user_is_channel_admin(parent, creator_id, &*tx)
44 .await?;
45 }
46
47 let channel = channel::ActiveModel {
48 name: ActiveValue::Set(name.to_string()),
49 ..Default::default()
50 }
51 .insert(&*tx)
52 .await?;
53
54 if let Some(parent) = parent {
55 let sql = r#"
56 INSERT INTO channel_paths
57 (id_path, channel_id)
58 SELECT
59 id_path || $1 || '/', $2
60 FROM
61 channel_paths
62 WHERE
63 channel_id = $3
64 "#;
65 let channel_paths_stmt = Statement::from_sql_and_values(
66 self.pool.get_database_backend(),
67 sql,
68 [
69 channel.id.to_proto().into(),
70 channel.id.to_proto().into(),
71 parent.to_proto().into(),
72 ],
73 );
74 tx.execute(channel_paths_stmt).await?;
75 } else {
76 channel_path::Entity::insert(channel_path::ActiveModel {
77 channel_id: ActiveValue::Set(channel.id),
78 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
79 })
80 .exec(&*tx)
81 .await?;
82 }
83
84 channel_member::ActiveModel {
85 channel_id: ActiveValue::Set(channel.id),
86 user_id: ActiveValue::Set(creator_id),
87 accepted: ActiveValue::Set(true),
88 admin: ActiveValue::Set(true),
89 ..Default::default()
90 }
91 .insert(&*tx)
92 .await?;
93
94 room::ActiveModel {
95 channel_id: ActiveValue::Set(Some(channel.id)),
96 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
97 ..Default::default()
98 }
99 .insert(&*tx)
100 .await?;
101
102 Ok(channel.id)
103 })
104 .await
105 }
106
107 pub async fn delete_channel(
108 &self,
109 channel_id: ChannelId,
110 user_id: UserId,
111 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
112 self.transaction(move |tx| async move {
113 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
114 .await?;
115
116 // Don't remove descendant channels that have additional parents.
117 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
118 {
119 let mut channels_to_keep = channel_path::Entity::find()
120 .filter(
121 channel_path::Column::ChannelId
122 .is_in(
123 channels_to_remove
124 .keys()
125 .copied()
126 .filter(|&id| id != channel_id),
127 )
128 .and(
129 channel_path::Column::IdPath
130 .not_like(&format!("%/{}/%", channel_id)),
131 ),
132 )
133 .stream(&*tx)
134 .await?;
135 while let Some(row) = channels_to_keep.next().await {
136 let row = row?;
137 channels_to_remove.remove(&row.channel_id);
138 }
139 }
140
141 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
142 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
143 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
144 .select_only()
145 .column(channel_member::Column::UserId)
146 .distinct()
147 .into_values::<_, QueryUserIds>()
148 .all(&*tx)
149 .await?;
150
151 channel::Entity::delete_many()
152 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
153 .exec(&*tx)
154 .await?;
155
156 // Delete any other paths that include this channel
157 let sql = r#"
158 DELETE FROM channel_paths
159 WHERE
160 id_path LIKE '%' || $1 || '%'
161 "#;
162 let channel_paths_stmt = Statement::from_sql_and_values(
163 self.pool.get_database_backend(),
164 sql,
165 [channel_id.to_proto().into()],
166 );
167 tx.execute(channel_paths_stmt).await?;
168
169 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
170 })
171 .await
172 }
173
174 pub async fn invite_channel_member(
175 &self,
176 channel_id: ChannelId,
177 invitee_id: UserId,
178 inviter_id: UserId,
179 is_admin: bool,
180 ) -> Result<()> {
181 self.transaction(move |tx| async move {
182 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
183 .await?;
184
185 channel_member::ActiveModel {
186 channel_id: ActiveValue::Set(channel_id),
187 user_id: ActiveValue::Set(invitee_id),
188 accepted: ActiveValue::Set(false),
189 admin: ActiveValue::Set(is_admin),
190 ..Default::default()
191 }
192 .insert(&*tx)
193 .await?;
194
195 Ok(())
196 })
197 .await
198 }
199
200 fn sanitize_channel_name(name: &str) -> Result<&str> {
201 let new_name = name.trim().trim_start_matches('#');
202 if new_name == "" {
203 Err(anyhow!("channel name can't be blank"))?;
204 }
205 Ok(new_name)
206 }
207
208 pub async fn rename_channel(
209 &self,
210 channel_id: ChannelId,
211 user_id: UserId,
212 new_name: &str,
213 ) -> Result<String> {
214 self.transaction(move |tx| async move {
215 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
216
217 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
218 .await?;
219
220 channel::ActiveModel {
221 id: ActiveValue::Unchanged(channel_id),
222 name: ActiveValue::Set(new_name.clone()),
223 ..Default::default()
224 }
225 .update(&*tx)
226 .await?;
227
228 Ok(new_name)
229 })
230 .await
231 }
232
233 pub async fn respond_to_channel_invite(
234 &self,
235 channel_id: ChannelId,
236 user_id: UserId,
237 accept: bool,
238 ) -> Result<()> {
239 self.transaction(move |tx| async move {
240 let rows_affected = if accept {
241 channel_member::Entity::update_many()
242 .set(channel_member::ActiveModel {
243 accepted: ActiveValue::Set(accept),
244 ..Default::default()
245 })
246 .filter(
247 channel_member::Column::ChannelId
248 .eq(channel_id)
249 .and(channel_member::Column::UserId.eq(user_id))
250 .and(channel_member::Column::Accepted.eq(false)),
251 )
252 .exec(&*tx)
253 .await?
254 .rows_affected
255 } else {
256 channel_member::ActiveModel {
257 channel_id: ActiveValue::Unchanged(channel_id),
258 user_id: ActiveValue::Unchanged(user_id),
259 ..Default::default()
260 }
261 .delete(&*tx)
262 .await?
263 .rows_affected
264 };
265
266 if rows_affected == 0 {
267 Err(anyhow!("no such invitation"))?;
268 }
269
270 Ok(())
271 })
272 .await
273 }
274
275 pub async fn remove_channel_member(
276 &self,
277 channel_id: ChannelId,
278 member_id: UserId,
279 remover_id: UserId,
280 ) -> Result<()> {
281 self.transaction(|tx| async move {
282 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
283 .await?;
284
285 let result = channel_member::Entity::delete_many()
286 .filter(
287 channel_member::Column::ChannelId
288 .eq(channel_id)
289 .and(channel_member::Column::UserId.eq(member_id)),
290 )
291 .exec(&*tx)
292 .await?;
293
294 if result.rows_affected == 0 {
295 Err(anyhow!("no such member"))?;
296 }
297
298 Ok(())
299 })
300 .await
301 }
302
303 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
304 self.transaction(|tx| async move {
305 let channel_invites = channel_member::Entity::find()
306 .filter(
307 channel_member::Column::UserId
308 .eq(user_id)
309 .and(channel_member::Column::Accepted.eq(false)),
310 )
311 .all(&*tx)
312 .await?;
313
314 let channels = channel::Entity::find()
315 .filter(
316 channel::Column::Id.is_in(
317 channel_invites
318 .into_iter()
319 .map(|channel_member| channel_member.channel_id),
320 ),
321 )
322 .all(&*tx)
323 .await?;
324
325 let channels = channels
326 .into_iter()
327 .map(|channel| Channel {
328 id: channel.id,
329 name: channel.name,
330 })
331 .collect();
332
333 Ok(channels)
334 })
335 .await
336 }
337
338 async fn get_channel_graph(
339 &self,
340 parents_by_child_id: ChannelDescendants,
341 trim_dangling_parents: bool,
342 tx: &DatabaseTransaction,
343 ) -> Result<ChannelGraph> {
344 let mut channels = Vec::with_capacity(parents_by_child_id.len());
345 {
346 let mut rows = channel::Entity::find()
347 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
348 .stream(&*tx)
349 .await?;
350 while let Some(row) = rows.next().await {
351 let row = row?;
352 channels.push(Channel {
353 id: row.id,
354 name: row.name,
355 })
356 }
357 }
358
359 let mut edges = Vec::with_capacity(parents_by_child_id.len());
360 for (channel, parents) in parents_by_child_id.iter() {
361 for parent in parents.into_iter() {
362 if trim_dangling_parents {
363 if parents_by_child_id.contains_key(parent) {
364 edges.push(ChannelEdge {
365 channel_id: channel.to_proto(),
366 parent_id: parent.to_proto(),
367 });
368 }
369 } else {
370 edges.push(ChannelEdge {
371 channel_id: channel.to_proto(),
372 parent_id: parent.to_proto(),
373 });
374 }
375 }
376 }
377
378 Ok(ChannelGraph { channels, edges })
379 }
380
381 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
382 self.transaction(|tx| async move {
383 let tx = tx;
384
385 let channel_memberships = channel_member::Entity::find()
386 .filter(
387 channel_member::Column::UserId
388 .eq(user_id)
389 .and(channel_member::Column::Accepted.eq(true)),
390 )
391 .all(&*tx)
392 .await?;
393
394 self.get_user_channels(channel_memberships, &tx).await
395 })
396 .await
397 }
398
399 pub async fn get_channel_for_user(
400 &self,
401 channel_id: ChannelId,
402 user_id: UserId,
403 ) -> Result<ChannelsForUser> {
404 self.transaction(|tx| async move {
405 let tx = tx;
406
407 let channel_membership = channel_member::Entity::find()
408 .filter(
409 channel_member::Column::UserId
410 .eq(user_id)
411 .and(channel_member::Column::ChannelId.eq(channel_id))
412 .and(channel_member::Column::Accepted.eq(true)),
413 )
414 .all(&*tx)
415 .await?;
416
417 self.get_user_channels(channel_membership, &tx).await
418 })
419 .await
420 }
421
422 pub async fn get_user_channels(
423 &self,
424 channel_memberships: Vec<channel_member::Model>,
425 tx: &DatabaseTransaction,
426 ) -> Result<ChannelsForUser> {
427 let parents_by_child_id = self
428 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
429 .await?;
430
431 let channels_with_admin_privileges = channel_memberships
432 .iter()
433 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
434 .collect();
435
436 let graph = self
437 .get_channel_graph(parents_by_child_id, true, &tx)
438 .await?;
439
440 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
441 enum QueryUserIdsAndChannelIds {
442 ChannelId,
443 UserId,
444 }
445
446 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
447 {
448 let mut rows = room_participant::Entity::find()
449 .inner_join(room::Entity)
450 .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
451 .select_only()
452 .column(room::Column::ChannelId)
453 .column(room_participant::Column::UserId)
454 .into_values::<_, QueryUserIdsAndChannelIds>()
455 .stream(&*tx)
456 .await?;
457 while let Some(row) = rows.next().await {
458 let row: (ChannelId, UserId) = row?;
459 channel_participants.entry(row.0).or_default().push(row.1)
460 }
461 }
462
463 Ok(ChannelsForUser {
464 channels: graph,
465 channel_participants,
466 channels_with_admin_privileges,
467 })
468 }
469
470 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
471 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
472 .await
473 }
474
475 pub async fn set_channel_member_admin(
476 &self,
477 channel_id: ChannelId,
478 from: UserId,
479 for_user: UserId,
480 admin: bool,
481 ) -> Result<()> {
482 self.transaction(|tx| async move {
483 self.check_user_is_channel_admin(channel_id, from, &*tx)
484 .await?;
485
486 let result = channel_member::Entity::update_many()
487 .filter(
488 channel_member::Column::ChannelId
489 .eq(channel_id)
490 .and(channel_member::Column::UserId.eq(for_user)),
491 )
492 .set(channel_member::ActiveModel {
493 admin: ActiveValue::set(admin),
494 ..Default::default()
495 })
496 .exec(&*tx)
497 .await?;
498
499 if result.rows_affected == 0 {
500 Err(anyhow!("no such member"))?;
501 }
502
503 Ok(())
504 })
505 .await
506 }
507
508 pub async fn get_channel_member_details(
509 &self,
510 channel_id: ChannelId,
511 user_id: UserId,
512 ) -> Result<Vec<proto::ChannelMember>> {
513 self.transaction(|tx| async move {
514 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
515 .await?;
516
517 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
518 enum QueryMemberDetails {
519 UserId,
520 Admin,
521 IsDirectMember,
522 Accepted,
523 }
524
525 let tx = tx;
526 let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
527 let mut stream = channel_member::Entity::find()
528 .distinct()
529 .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
530 .select_only()
531 .column(channel_member::Column::UserId)
532 .column(channel_member::Column::Admin)
533 .column_as(
534 channel_member::Column::ChannelId.eq(channel_id),
535 QueryMemberDetails::IsDirectMember,
536 )
537 .column(channel_member::Column::Accepted)
538 .order_by_asc(channel_member::Column::UserId)
539 .into_values::<_, QueryMemberDetails>()
540 .stream(&*tx)
541 .await?;
542
543 let mut rows = Vec::<proto::ChannelMember>::new();
544 while let Some(row) = stream.next().await {
545 let (user_id, is_admin, is_direct_member, is_invite_accepted): (
546 UserId,
547 bool,
548 bool,
549 bool,
550 ) = row?;
551 let kind = match (is_direct_member, is_invite_accepted) {
552 (true, true) => proto::channel_member::Kind::Member,
553 (true, false) => proto::channel_member::Kind::Invitee,
554 (false, true) => proto::channel_member::Kind::AncestorMember,
555 (false, false) => continue,
556 };
557 let user_id = user_id.to_proto();
558 let kind = kind.into();
559 if let Some(last_row) = rows.last_mut() {
560 if last_row.user_id == user_id {
561 if is_direct_member {
562 last_row.kind = kind;
563 last_row.admin = is_admin;
564 }
565 continue;
566 }
567 }
568 rows.push(proto::ChannelMember {
569 user_id,
570 kind,
571 admin: is_admin,
572 });
573 }
574
575 Ok(rows)
576 })
577 .await
578 }
579
580 pub async fn get_channel_members_internal(
581 &self,
582 id: ChannelId,
583 tx: &DatabaseTransaction,
584 ) -> Result<Vec<UserId>> {
585 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
586 let user_ids = channel_member::Entity::find()
587 .distinct()
588 .filter(
589 channel_member::Column::ChannelId
590 .is_in(ancestor_ids.iter().copied())
591 .and(channel_member::Column::Accepted.eq(true)),
592 )
593 .select_only()
594 .column(channel_member::Column::UserId)
595 .into_values::<_, QueryUserIds>()
596 .all(&*tx)
597 .await?;
598 Ok(user_ids)
599 }
600
601 pub async fn check_user_is_channel_member(
602 &self,
603 channel_id: ChannelId,
604 user_id: UserId,
605 tx: &DatabaseTransaction,
606 ) -> Result<()> {
607 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
608 channel_member::Entity::find()
609 .filter(
610 channel_member::Column::ChannelId
611 .is_in(channel_ids)
612 .and(channel_member::Column::UserId.eq(user_id)),
613 )
614 .one(&*tx)
615 .await?
616 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
617 Ok(())
618 }
619
620 pub async fn check_user_is_channel_admin(
621 &self,
622 channel_id: ChannelId,
623 user_id: UserId,
624 tx: &DatabaseTransaction,
625 ) -> Result<()> {
626 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
627 channel_member::Entity::find()
628 .filter(
629 channel_member::Column::ChannelId
630 .is_in(channel_ids)
631 .and(channel_member::Column::UserId.eq(user_id))
632 .and(channel_member::Column::Admin.eq(true)),
633 )
634 .one(&*tx)
635 .await?
636 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
637 Ok(())
638 }
639
640 /// Returns the channel ancestors, deepest first
641 pub async fn get_channel_ancestors(
642 &self,
643 channel_id: ChannelId,
644 tx: &DatabaseTransaction,
645 ) -> Result<Vec<ChannelId>> {
646 let paths = channel_path::Entity::find()
647 .filter(channel_path::Column::ChannelId.eq(channel_id))
648 .order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
649 .all(tx)
650 .await?;
651 let mut channel_ids = Vec::new();
652 for path in paths {
653 for id in path.id_path.trim_matches('/').split('/') {
654 if let Ok(id) = id.parse() {
655 let id = ChannelId::from_proto(id);
656 if let Err(ix) = channel_ids.binary_search(&id) {
657 channel_ids.insert(ix, id);
658 }
659 }
660 }
661 }
662 Ok(channel_ids)
663 }
664
665 /// Returns the channel descendants,
666 /// Structured as a map from child ids to their parent ids
667 /// For example, the descendants of 'a' in this DAG:
668 ///
669 /// /- b -\
670 /// a -- c -- d
671 ///
672 /// would be:
673 /// {
674 /// a: [],
675 /// b: [a],
676 /// c: [a],
677 /// d: [a, c],
678 /// }
679 async fn get_channel_descendants(
680 &self,
681 channel_ids: impl IntoIterator<Item = ChannelId>,
682 tx: &DatabaseTransaction,
683 ) -> Result<ChannelDescendants> {
684 let mut values = String::new();
685 for id in channel_ids {
686 if !values.is_empty() {
687 values.push_str(", ");
688 }
689 write!(&mut values, "({})", id).unwrap();
690 }
691
692 if values.is_empty() {
693 return Ok(HashMap::default());
694 }
695
696 let sql = format!(
697 r#"
698 SELECT
699 descendant_paths.*
700 FROM
701 channel_paths parent_paths, channel_paths descendant_paths
702 WHERE
703 parent_paths.channel_id IN ({values}) AND
704 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
705 "#
706 );
707
708 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
709
710 let mut parents_by_child_id: ChannelDescendants = HashMap::default();
711 let mut paths = channel_path::Entity::find()
712 .from_raw_sql(stmt)
713 .stream(tx)
714 .await?;
715
716 while let Some(path) = paths.next().await {
717 let path = path?;
718 let ids = path.id_path.trim_matches('/').split('/');
719 let mut parent_id = None;
720 for id in ids {
721 if let Ok(id) = id.parse() {
722 let id = ChannelId::from_proto(id);
723 if id == path.channel_id {
724 break;
725 }
726 parent_id = Some(id);
727 }
728 }
729 let entry = parents_by_child_id.entry(path.channel_id).or_default();
730 if let Some(parent_id) = parent_id {
731 entry.insert(parent_id);
732 }
733 }
734
735 Ok(parents_by_child_id)
736 }
737
738 /// Returns the channel with the given ID and:
739 /// - true if the user is a member
740 /// - false if the user hasn't accepted the invitation yet
741 pub async fn get_channel(
742 &self,
743 channel_id: ChannelId,
744 user_id: UserId,
745 ) -> Result<Option<(Channel, bool)>> {
746 self.transaction(|tx| async move {
747 let tx = tx;
748
749 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
750
751 if let Some(channel) = channel {
752 if self
753 .check_user_is_channel_member(channel_id, user_id, &*tx)
754 .await
755 .is_err()
756 {
757 return Ok(None);
758 }
759
760 let channel_membership = channel_member::Entity::find()
761 .filter(
762 channel_member::Column::ChannelId
763 .eq(channel_id)
764 .and(channel_member::Column::UserId.eq(user_id)),
765 )
766 .one(&*tx)
767 .await?;
768
769 let is_accepted = channel_membership
770 .map(|membership| membership.accepted)
771 .unwrap_or(false);
772
773 Ok(Some((
774 Channel {
775 id: channel.id,
776 name: channel.name,
777 },
778 is_accepted,
779 )))
780 } else {
781 Ok(None)
782 }
783 })
784 .await
785 }
786
787 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
788 self.transaction(|tx| async move {
789 let tx = tx;
790 let room = channel::Model {
791 id: channel_id,
792 ..Default::default()
793 }
794 .find_related(room::Entity)
795 .one(&*tx)
796 .await?
797 .ok_or_else(|| anyhow!("invalid channel"))?;
798 Ok(room.id)
799 })
800 .await
801 }
802
803 // Insert an edge from the given channel to the given other channel.
804 pub async fn link_channel(
805 &self,
806 user: UserId,
807 channel: ChannelId,
808 to: ChannelId,
809 ) -> Result<ChannelGraph> {
810 self.transaction(|tx| async move {
811 // Note that even with these maxed permissions, this linking operation
812 // is still insecure because you can't remove someone's permissions to a
813 // channel if they've linked the channel to one where they're an admin.
814 self.check_user_is_channel_admin(channel, user, &*tx)
815 .await?;
816
817 self.link_channel_internal(user, channel, to, &*tx).await
818 })
819 .await
820 }
821
822 pub async fn link_channel_internal(
823 &self,
824 user: UserId,
825 channel: ChannelId,
826 to: ChannelId,
827 tx: &DatabaseTransaction,
828 ) -> Result<ChannelGraph> {
829 self.check_user_is_channel_admin(to, user, &*tx).await?;
830
831 let paths = channel_path::Entity::find()
832 .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
833 .all(tx)
834 .await?;
835
836 let mut new_path_suffixes = HashSet::default();
837 for path in paths {
838 if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
839 new_path_suffixes.insert((
840 path.channel_id,
841 path.id_path[(start_offset + 1)..].to_string(),
842 ));
843 }
844 }
845
846 let paths_to_new_parent = channel_path::Entity::find()
847 .filter(channel_path::Column::ChannelId.eq(to))
848 .all(tx)
849 .await?;
850
851 let mut new_paths = Vec::new();
852 for path in paths_to_new_parent {
853 if path.id_path.contains(&format!("/{}/", channel)) {
854 Err(anyhow!("cycle"))?;
855 }
856
857 new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
858 channel_path::ActiveModel {
859 channel_id: ActiveValue::Set(*channel_id),
860 id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
861 }
862 }));
863 }
864
865 channel_path::Entity::insert_many(new_paths)
866 .exec(&*tx)
867 .await?;
868
869 // remove any root edges for the channel we just linked
870 {
871 channel_path::Entity::delete_many()
872 .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
873 .exec(&*tx)
874 .await?;
875 }
876
877 let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
878 if let Some(channel) = channel_descendants.get_mut(&channel) {
879 // Remove the other parents
880 channel.clear();
881 channel.insert(to);
882 }
883
884 let channels = self
885 .get_channel_graph(channel_descendants, false, &*tx)
886 .await?;
887
888 Ok(channels)
889 }
890
891 /// Unlink a channel from a given parent. This will add in a root edge if
892 /// the channel has no other parents after this operation.
893 pub async fn unlink_channel(
894 &self,
895 user: UserId,
896 channel: ChannelId,
897 from: ChannelId,
898 ) -> Result<()> {
899 self.transaction(|tx| async move {
900 // Note that even with these maxed permissions, this linking operation
901 // is still insecure because you can't remove someone's permissions to a
902 // channel if they've linked the channel to one where they're an admin.
903 self.check_user_is_channel_admin(channel, user, &*tx)
904 .await?;
905
906 self.unlink_channel_internal(user, channel, from, &*tx)
907 .await?;
908
909 Ok(())
910 })
911 .await
912 }
913
914 pub async fn unlink_channel_internal(
915 &self,
916 user: UserId,
917 channel: ChannelId,
918 from: ChannelId,
919 tx: &DatabaseTransaction,
920 ) -> Result<()> {
921 self.check_user_is_channel_admin(from, user, &*tx).await?;
922
923 let sql = r#"
924 DELETE FROM channel_paths
925 WHERE
926 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
927 RETURNING id_path, channel_id
928 "#;
929
930 let paths = channel_path::Entity::find()
931 .from_raw_sql(Statement::from_sql_and_values(
932 self.pool.get_database_backend(),
933 sql,
934 [from.to_proto().into(), channel.to_proto().into()],
935 ))
936 .all(&*tx)
937 .await?;
938
939 let is_stranded = channel_path::Entity::find()
940 .filter(channel_path::Column::ChannelId.eq(channel))
941 .count(&*tx)
942 .await?
943 == 0;
944
945 // Make sure that there is always at least one path to the channel
946 if is_stranded {
947 let root_paths: Vec<_> = paths
948 .iter()
949 .map(|path| {
950 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
951 channel_path::ActiveModel {
952 channel_id: ActiveValue::Set(path.channel_id),
953 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
954 }
955 })
956 .collect();
957 channel_path::Entity::insert_many(root_paths)
958 .exec(&*tx)
959 .await?;
960 }
961
962 Ok(())
963 }
964
965 /// Move a channel from one parent to another, returns the
966 /// Channels that were moved for notifying clients
967 pub async fn move_channel(
968 &self,
969 user: UserId,
970 channel: ChannelId,
971 from: ChannelId,
972 to: ChannelId,
973 ) -> Result<ChannelGraph> {
974 if from == to {
975 return Ok(ChannelGraph {
976 channels: vec![],
977 edges: vec![],
978 });
979 }
980
981 self.transaction(|tx| async move {
982 self.check_user_is_channel_admin(channel, user, &*tx)
983 .await?;
984
985 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
986
987 self.unlink_channel_internal(user, channel, from, &*tx)
988 .await?;
989
990 Ok(moved_channels)
991 })
992 .await
993 }
994}
995
/// Column selector passed to `into_values` by the queries in this file that
/// read back only the `user_id` column.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1000
/// A set of channels together with the child-to-parent edges between them,
/// as assembled by `Database::get_channel_graph`.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels (id and name) that make up the graph.
    pub channels: Vec<Channel>,
    /// One entry per (channel, parent) pair, stored as proto ids.
    pub edges: Vec<ChannelEdge>,
}
1006
1007impl ChannelGraph {
1008 pub fn is_empty(&self) -> bool {
1009 self.channels.is_empty() && self.edges.is_empty()
1010 }
1011}
1012
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Test builds compare graphs as sets, ignoring the order (and any
    /// repetition) of channels and edges.
    fn eq(&self, other: &Self) -> bool {
        let own_channels: HashSet<_> = self.channels.iter().collect();
        let their_channels: HashSet<_> = other.channels.iter().collect();

        let own_edges: HashSet<_> = self
            .edges
            .iter()
            .map(|edge| (edge.channel_id, edge.parent_id))
            .collect();
        let their_edges: HashSet<_> = other
            .edges
            .iter()
            .map(|edge| (edge.channel_id, edge.parent_id))
            .collect();

        own_channels == their_channels && own_edges == their_edges
    }
}
1033
1034#[cfg(not(test))]
1035impl PartialEq for ChannelGraph {
1036 fn eq(&self, other: &Self) -> bool {
1037 self.channels == other.channels && self.edges == other.edges
1038 }
1039}
1040
/// A small set backed by a `SmallVec` whose elements are kept sorted and
/// unique by `SmallSet::insert`.
struct SmallSet<T>(SmallVec<[T; 1]>);
1042
1043impl<T> Deref for SmallSet<T> {
1044 type Target = [T];
1045
1046 fn deref(&self) -> &Self::Target {
1047 self.0.deref()
1048 }
1049}
1050
1051impl<T> Default for SmallSet<T> {
1052 fn default() -> Self {
1053 Self(SmallVec::new())
1054 }
1055}
1056
1057impl<T> SmallSet<T> {
1058 fn insert(&mut self, value: T) -> bool
1059 where
1060 T: Ord,
1061 {
1062 match self.binary_search(&value) {
1063 Ok(_) => false,
1064 Err(ix) => {
1065 self.0.insert(ix, value);
1066 true
1067 }
1068 }
1069 }
1070
1071 fn clear(&mut self) {
1072 self.0.clear();
1073 }
1074}