1use super::*;
2use rpc::proto::ChannelEdge;
3use smallvec::SmallVec;
4
/// Maps a channel id to the set of its immediate parent ids, as produced by
/// `Database::get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
6
7impl Database {
8 #[cfg(test)]
9 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
10 self.transaction(move |tx| async move {
11 let mut channels = Vec::new();
12 let mut rows = channel::Entity::find().stream(&*tx).await?;
13 while let Some(row) = rows.next().await {
14 let row = row?;
15 channels.push((row.id, row.name));
16 }
17 Ok(channels)
18 })
19 .await
20 }
21
22 pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
23 self.create_channel(name, None, creator_id).await
24 }
25
26 pub async fn create_channel(
27 &self,
28 name: &str,
29 parent: Option<ChannelId>,
30 creator_id: UserId,
31 ) -> Result<ChannelId> {
32 let name = Self::sanitize_channel_name(name)?;
33 self.transaction(move |tx| async move {
34 if let Some(parent) = parent {
35 self.check_user_is_channel_admin(parent, creator_id, &*tx)
36 .await?;
37 }
38
39 let channel = channel::ActiveModel {
40 name: ActiveValue::Set(name.to_string()),
41 ..Default::default()
42 }
43 .insert(&*tx)
44 .await?;
45
46 if let Some(parent) = parent {
47 let sql = r#"
48 INSERT INTO channel_paths
49 (id_path, channel_id)
50 SELECT
51 id_path || $1 || '/', $2
52 FROM
53 channel_paths
54 WHERE
55 channel_id = $3
56 "#;
57 let channel_paths_stmt = Statement::from_sql_and_values(
58 self.pool.get_database_backend(),
59 sql,
60 [
61 channel.id.to_proto().into(),
62 channel.id.to_proto().into(),
63 parent.to_proto().into(),
64 ],
65 );
66 tx.execute(channel_paths_stmt).await?;
67 } else {
68 channel_path::Entity::insert(channel_path::ActiveModel {
69 channel_id: ActiveValue::Set(channel.id),
70 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
71 })
72 .exec(&*tx)
73 .await?;
74 }
75
76 channel_member::ActiveModel {
77 channel_id: ActiveValue::Set(channel.id),
78 user_id: ActiveValue::Set(creator_id),
79 accepted: ActiveValue::Set(true),
80 admin: ActiveValue::Set(true),
81 ..Default::default()
82 }
83 .insert(&*tx)
84 .await?;
85
86 Ok(channel.id)
87 })
88 .await
89 }
90
91 pub async fn delete_channel(
92 &self,
93 channel_id: ChannelId,
94 user_id: UserId,
95 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
96 self.transaction(move |tx| async move {
97 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
98 .await?;
99
100 // Don't remove descendant channels that have additional parents.
101 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
102 {
103 let mut channels_to_keep = channel_path::Entity::find()
104 .filter(
105 channel_path::Column::ChannelId
106 .is_in(
107 channels_to_remove
108 .keys()
109 .copied()
110 .filter(|&id| id != channel_id),
111 )
112 .and(
113 channel_path::Column::IdPath
114 .not_like(&format!("%/{}/%", channel_id)),
115 ),
116 )
117 .stream(&*tx)
118 .await?;
119 while let Some(row) = channels_to_keep.next().await {
120 let row = row?;
121 channels_to_remove.remove(&row.channel_id);
122 }
123 }
124
125 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
126 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
127 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
128 .select_only()
129 .column(channel_member::Column::UserId)
130 .distinct()
131 .into_values::<_, QueryUserIds>()
132 .all(&*tx)
133 .await?;
134
135 channel::Entity::delete_many()
136 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
137 .exec(&*tx)
138 .await?;
139
140 // Delete any other paths that include this channel
141 let sql = r#"
142 DELETE FROM channel_paths
143 WHERE
144 id_path LIKE '%' || $1 || '%'
145 "#;
146 let channel_paths_stmt = Statement::from_sql_and_values(
147 self.pool.get_database_backend(),
148 sql,
149 [channel_id.to_proto().into()],
150 );
151 tx.execute(channel_paths_stmt).await?;
152
153 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
154 })
155 .await
156 }
157
158 pub async fn invite_channel_member(
159 &self,
160 channel_id: ChannelId,
161 invitee_id: UserId,
162 inviter_id: UserId,
163 is_admin: bool,
164 ) -> Result<Option<proto::Notification>> {
165 self.transaction(move |tx| async move {
166 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
167 .await?;
168
169 channel_member::ActiveModel {
170 channel_id: ActiveValue::Set(channel_id),
171 user_id: ActiveValue::Set(invitee_id),
172 accepted: ActiveValue::Set(false),
173 admin: ActiveValue::Set(is_admin),
174 ..Default::default()
175 }
176 .insert(&*tx)
177 .await?;
178
179 self.create_notification(
180 invitee_id,
181 rpc::Notification::ChannelInvitation {
182 actor_id: inviter_id.to_proto(),
183 channel_id: channel_id.to_proto(),
184 },
185 true,
186 &*tx,
187 )
188 .await
189 })
190 .await
191 }
192
193 fn sanitize_channel_name(name: &str) -> Result<&str> {
194 let new_name = name.trim().trim_start_matches('#');
195 if new_name == "" {
196 Err(anyhow!("channel name can't be blank"))?;
197 }
198 Ok(new_name)
199 }
200
201 pub async fn rename_channel(
202 &self,
203 channel_id: ChannelId,
204 user_id: UserId,
205 new_name: &str,
206 ) -> Result<String> {
207 self.transaction(move |tx| async move {
208 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
209
210 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
211 .await?;
212
213 channel::ActiveModel {
214 id: ActiveValue::Unchanged(channel_id),
215 name: ActiveValue::Set(new_name.clone()),
216 ..Default::default()
217 }
218 .update(&*tx)
219 .await?;
220
221 Ok(new_name)
222 })
223 .await
224 }
225
226 pub async fn respond_to_channel_invite(
227 &self,
228 channel_id: ChannelId,
229 user_id: UserId,
230 accept: bool,
231 ) -> Result<()> {
232 self.transaction(move |tx| async move {
233 let rows_affected = if accept {
234 channel_member::Entity::update_many()
235 .set(channel_member::ActiveModel {
236 accepted: ActiveValue::Set(accept),
237 ..Default::default()
238 })
239 .filter(
240 channel_member::Column::ChannelId
241 .eq(channel_id)
242 .and(channel_member::Column::UserId.eq(user_id))
243 .and(channel_member::Column::Accepted.eq(false)),
244 )
245 .exec(&*tx)
246 .await?
247 .rows_affected
248 } else {
249 channel_member::ActiveModel {
250 channel_id: ActiveValue::Unchanged(channel_id),
251 user_id: ActiveValue::Unchanged(user_id),
252 ..Default::default()
253 }
254 .delete(&*tx)
255 .await?
256 .rows_affected
257 };
258
259 if rows_affected == 0 {
260 Err(anyhow!("no such invitation"))?;
261 }
262
263 Ok(())
264 })
265 .await
266 }
267
268 pub async fn remove_channel_member(
269 &self,
270 channel_id: ChannelId,
271 member_id: UserId,
272 remover_id: UserId,
273 ) -> Result<()> {
274 self.transaction(|tx| async move {
275 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
276 .await?;
277
278 let result = channel_member::Entity::delete_many()
279 .filter(
280 channel_member::Column::ChannelId
281 .eq(channel_id)
282 .and(channel_member::Column::UserId.eq(member_id)),
283 )
284 .exec(&*tx)
285 .await?;
286
287 if result.rows_affected == 0 {
288 Err(anyhow!("no such member"))?;
289 }
290
291 Ok(())
292 })
293 .await
294 }
295
296 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
297 self.transaction(|tx| async move {
298 let channel_invites = channel_member::Entity::find()
299 .filter(
300 channel_member::Column::UserId
301 .eq(user_id)
302 .and(channel_member::Column::Accepted.eq(false)),
303 )
304 .all(&*tx)
305 .await?;
306
307 let channels = channel::Entity::find()
308 .filter(
309 channel::Column::Id.is_in(
310 channel_invites
311 .into_iter()
312 .map(|channel_member| channel_member.channel_id),
313 ),
314 )
315 .all(&*tx)
316 .await?;
317
318 let channels = channels
319 .into_iter()
320 .map(|channel| Channel {
321 id: channel.id,
322 name: channel.name,
323 })
324 .collect();
325
326 Ok(channels)
327 })
328 .await
329 }
330
331 async fn get_channel_graph(
332 &self,
333 parents_by_child_id: ChannelDescendants,
334 trim_dangling_parents: bool,
335 tx: &DatabaseTransaction,
336 ) -> Result<ChannelGraph> {
337 let mut channels = Vec::with_capacity(parents_by_child_id.len());
338 {
339 let mut rows = channel::Entity::find()
340 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
341 .stream(&*tx)
342 .await?;
343 while let Some(row) = rows.next().await {
344 let row = row?;
345 channels.push(Channel {
346 id: row.id,
347 name: row.name,
348 })
349 }
350 }
351
352 let mut edges = Vec::with_capacity(parents_by_child_id.len());
353 for (channel, parents) in parents_by_child_id.iter() {
354 for parent in parents.into_iter() {
355 if trim_dangling_parents {
356 if parents_by_child_id.contains_key(parent) {
357 edges.push(ChannelEdge {
358 channel_id: channel.to_proto(),
359 parent_id: parent.to_proto(),
360 });
361 }
362 } else {
363 edges.push(ChannelEdge {
364 channel_id: channel.to_proto(),
365 parent_id: parent.to_proto(),
366 });
367 }
368 }
369 }
370
371 Ok(ChannelGraph { channels, edges })
372 }
373
374 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
375 self.transaction(|tx| async move {
376 let tx = tx;
377
378 let channel_memberships = channel_member::Entity::find()
379 .filter(
380 channel_member::Column::UserId
381 .eq(user_id)
382 .and(channel_member::Column::Accepted.eq(true)),
383 )
384 .all(&*tx)
385 .await?;
386
387 self.get_user_channels(user_id, channel_memberships, &tx)
388 .await
389 })
390 .await
391 }
392
393 pub async fn get_channel_for_user(
394 &self,
395 channel_id: ChannelId,
396 user_id: UserId,
397 ) -> Result<ChannelsForUser> {
398 self.transaction(|tx| async move {
399 let tx = tx;
400
401 let channel_membership = channel_member::Entity::find()
402 .filter(
403 channel_member::Column::UserId
404 .eq(user_id)
405 .and(channel_member::Column::ChannelId.eq(channel_id))
406 .and(channel_member::Column::Accepted.eq(true)),
407 )
408 .all(&*tx)
409 .await?;
410
411 self.get_user_channels(user_id, channel_membership, &tx)
412 .await
413 })
414 .await
415 }
416
417 pub async fn get_user_channels(
418 &self,
419 user_id: UserId,
420 channel_memberships: Vec<channel_member::Model>,
421 tx: &DatabaseTransaction,
422 ) -> Result<ChannelsForUser> {
423 let parents_by_child_id = self
424 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
425 .await?;
426
427 let channels_with_admin_privileges = channel_memberships
428 .iter()
429 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
430 .collect();
431
432 let graph = self
433 .get_channel_graph(parents_by_child_id, true, &tx)
434 .await?;
435
436 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
437 enum QueryUserIdsAndChannelIds {
438 ChannelId,
439 UserId,
440 }
441
442 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
443 {
444 let mut rows = room_participant::Entity::find()
445 .inner_join(room::Entity)
446 .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
447 .select_only()
448 .column(room::Column::ChannelId)
449 .column(room_participant::Column::UserId)
450 .into_values::<_, QueryUserIdsAndChannelIds>()
451 .stream(&*tx)
452 .await?;
453 while let Some(row) = rows.next().await {
454 let row: (ChannelId, UserId) = row?;
455 channel_participants.entry(row.0).or_default().push(row.1)
456 }
457 }
458
459 let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
460 let channel_buffer_changes = self
461 .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
462 .await?;
463
464 let unseen_messages = self
465 .unseen_channel_messages(user_id, &channel_ids, &*tx)
466 .await?;
467
468 Ok(ChannelsForUser {
469 channels: graph,
470 channel_participants,
471 channels_with_admin_privileges,
472 unseen_buffer_changes: channel_buffer_changes,
473 channel_messages: unseen_messages,
474 })
475 }
476
477 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
478 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
479 .await
480 }
481
482 pub async fn set_channel_member_admin(
483 &self,
484 channel_id: ChannelId,
485 from: UserId,
486 for_user: UserId,
487 admin: bool,
488 ) -> Result<()> {
489 self.transaction(|tx| async move {
490 self.check_user_is_channel_admin(channel_id, from, &*tx)
491 .await?;
492
493 let result = channel_member::Entity::update_many()
494 .filter(
495 channel_member::Column::ChannelId
496 .eq(channel_id)
497 .and(channel_member::Column::UserId.eq(for_user)),
498 )
499 .set(channel_member::ActiveModel {
500 admin: ActiveValue::set(admin),
501 ..Default::default()
502 })
503 .exec(&*tx)
504 .await?;
505
506 if result.rows_affected == 0 {
507 Err(anyhow!("no such member"))?;
508 }
509
510 Ok(())
511 })
512 .await
513 }
514
    /// Lists the members of `channel_id` for an admin of that channel.
    ///
    /// `user_id` must be an admin of the channel (directly or through an
    /// ancestor). Membership rows of the channel and of all its ancestors are
    /// read, and every user is reported at most once:
    /// - `Member` / `Invitee` for an accepted / pending row on the channel itself,
    /// - `AncestorMember` for an accepted row on an ancestor only.
    /// Pending rows on ancestors are ignored.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column layout of the rows selected below.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                // True when the row belongs to `channel_id` itself rather than an ancestor.
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                // Rows of the same user must be adjacent for the merge below.
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // A pending invitation to an ancestor is not reported.
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Merge with the previous entry when it is for the same user:
                // a direct row overrides what an ancestor row contributed,
                // any other additional row for that user is dropped.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
586
587 pub async fn get_channel_members_internal(
588 &self,
589 id: ChannelId,
590 tx: &DatabaseTransaction,
591 ) -> Result<Vec<UserId>> {
592 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
593 let user_ids = channel_member::Entity::find()
594 .distinct()
595 .filter(
596 channel_member::Column::ChannelId
597 .is_in(ancestor_ids.iter().copied())
598 .and(channel_member::Column::Accepted.eq(true)),
599 )
600 .select_only()
601 .column(channel_member::Column::UserId)
602 .into_values::<_, QueryUserIds>()
603 .all(&*tx)
604 .await?;
605 Ok(user_ids)
606 }
607
608 pub async fn check_user_is_channel_member(
609 &self,
610 channel_id: ChannelId,
611 user_id: UserId,
612 tx: &DatabaseTransaction,
613 ) -> Result<()> {
614 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
615 channel_member::Entity::find()
616 .filter(
617 channel_member::Column::ChannelId
618 .is_in(channel_ids)
619 .and(channel_member::Column::UserId.eq(user_id)),
620 )
621 .one(&*tx)
622 .await?
623 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
624 Ok(())
625 }
626
627 pub async fn check_user_is_channel_admin(
628 &self,
629 channel_id: ChannelId,
630 user_id: UserId,
631 tx: &DatabaseTransaction,
632 ) -> Result<()> {
633 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
634 channel_member::Entity::find()
635 .filter(
636 channel_member::Column::ChannelId
637 .is_in(channel_ids)
638 .and(channel_member::Column::UserId.eq(user_id))
639 .and(channel_member::Column::Admin.eq(true)),
640 )
641 .one(&*tx)
642 .await?
643 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
644 Ok(())
645 }
646
647 /// Returns the channel ancestors, deepest first
648 pub async fn get_channel_ancestors(
649 &self,
650 channel_id: ChannelId,
651 tx: &DatabaseTransaction,
652 ) -> Result<Vec<ChannelId>> {
653 let paths = channel_path::Entity::find()
654 .filter(channel_path::Column::ChannelId.eq(channel_id))
655 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
656 .all(tx)
657 .await?;
658 let mut channel_ids = Vec::new();
659 for path in paths {
660 for id in path.id_path.trim_matches('/').split('/') {
661 if let Ok(id) = id.parse() {
662 let id = ChannelId::from_proto(id);
663 if let Err(ix) = channel_ids.binary_search(&id) {
664 channel_ids.insert(ix, id);
665 }
666 }
667 }
668 }
669 Ok(channel_ids)
670 }
671
672 /// Returns the channel descendants,
673 /// Structured as a map from child ids to their parent ids
674 /// For example, the descendants of 'a' in this DAG:
675 ///
676 /// /- b -\
677 /// a -- c -- d
678 ///
679 /// would be:
680 /// {
681 /// a: [],
682 /// b: [a],
683 /// c: [a],
684 /// d: [a, c],
685 /// }
686 async fn get_channel_descendants(
687 &self,
688 channel_ids: impl IntoIterator<Item = ChannelId>,
689 tx: &DatabaseTransaction,
690 ) -> Result<ChannelDescendants> {
691 let mut values = String::new();
692 for id in channel_ids {
693 if !values.is_empty() {
694 values.push_str(", ");
695 }
696 write!(&mut values, "({})", id).unwrap();
697 }
698
699 if values.is_empty() {
700 return Ok(HashMap::default());
701 }
702
703 let sql = format!(
704 r#"
705 SELECT
706 descendant_paths.*
707 FROM
708 channel_paths parent_paths, channel_paths descendant_paths
709 WHERE
710 parent_paths.channel_id IN ({values}) AND
711 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
712 "#
713 );
714
715 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
716
717 let mut parents_by_child_id: ChannelDescendants = HashMap::default();
718 let mut paths = channel_path::Entity::find()
719 .from_raw_sql(stmt)
720 .stream(tx)
721 .await?;
722
723 while let Some(path) = paths.next().await {
724 let path = path?;
725 let ids = path.id_path.trim_matches('/').split('/');
726 let mut parent_id = None;
727 for id in ids {
728 if let Ok(id) = id.parse() {
729 let id = ChannelId::from_proto(id);
730 if id == path.channel_id {
731 break;
732 }
733 parent_id = Some(id);
734 }
735 }
736 let entry = parents_by_child_id.entry(path.channel_id).or_default();
737 if let Some(parent_id) = parent_id {
738 entry.insert(parent_id);
739 }
740 }
741
742 Ok(parents_by_child_id)
743 }
744
745 /// Returns the channel with the given ID and:
746 /// - true if the user is a member
747 /// - false if the user hasn't accepted the invitation yet
748 pub async fn get_channel(
749 &self,
750 channel_id: ChannelId,
751 user_id: UserId,
752 ) -> Result<Option<(Channel, bool)>> {
753 self.transaction(|tx| async move {
754 let tx = tx;
755
756 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
757
758 if let Some(channel) = channel {
759 if self
760 .check_user_is_channel_member(channel_id, user_id, &*tx)
761 .await
762 .is_err()
763 {
764 return Ok(None);
765 }
766
767 let channel_membership = channel_member::Entity::find()
768 .filter(
769 channel_member::Column::ChannelId
770 .eq(channel_id)
771 .and(channel_member::Column::UserId.eq(user_id)),
772 )
773 .one(&*tx)
774 .await?;
775
776 let is_accepted = channel_membership
777 .map(|membership| membership.accepted)
778 .unwrap_or(false);
779
780 Ok(Some((
781 Channel {
782 id: channel.id,
783 name: channel.name,
784 },
785 is_accepted,
786 )))
787 } else {
788 Ok(None)
789 }
790 })
791 .await
792 }
793
794 pub async fn get_or_create_channel_room(
795 &self,
796 channel_id: ChannelId,
797 live_kit_room: &str,
798 enviroment: &str,
799 ) -> Result<RoomId> {
800 self.transaction(|tx| async move {
801 let tx = tx;
802
803 let room = room::Entity::find()
804 .filter(room::Column::ChannelId.eq(channel_id))
805 .one(&*tx)
806 .await?;
807
808 let room_id = if let Some(room) = room {
809 room.id
810 } else {
811 let result = room::Entity::insert(room::ActiveModel {
812 channel_id: ActiveValue::Set(Some(channel_id)),
813 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
814 enviroment: ActiveValue::Set(Some(enviroment.to_string())),
815 ..Default::default()
816 })
817 .exec(&*tx)
818 .await?;
819
820 result.last_insert_id
821 };
822
823 Ok(room_id)
824 })
825 .await
826 }
827
828 // Insert an edge from the given channel to the given other channel.
829 pub async fn link_channel(
830 &self,
831 user: UserId,
832 channel: ChannelId,
833 to: ChannelId,
834 ) -> Result<ChannelGraph> {
835 self.transaction(|tx| async move {
836 // Note that even with these maxed permissions, this linking operation
837 // is still insecure because you can't remove someone's permissions to a
838 // channel if they've linked the channel to one where they're an admin.
839 self.check_user_is_channel_admin(channel, user, &*tx)
840 .await?;
841
842 self.link_channel_internal(user, channel, to, &*tx).await
843 })
844 .await
845 }
846
    /// Makes `to` a parent of `channel` inside an existing transaction.
    ///
    /// `user` must be an admin of `to`; the check on `channel` is left to the
    /// callers. Every stored path of `to` is combined with every distinct path
    /// suffix that starts at `channel`, and root paths that start at `channel`
    /// are removed afterwards. Fails with "cycle" when `channel` already
    /// occurs in a path of `to`.
    ///
    /// Returns the graph of `channel` and its descendants, in which `channel`
    /// has `to` as its only parent.
    pub async fn link_channel_internal(
        &self,
        user: UserId,
        channel: ChannelId,
        to: ChannelId,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelGraph> {
        self.check_user_is_channel_admin(to, user, &*tx).await?;

        // All paths that pass through (or end at) the channel being linked.
        let paths = channel_path::Entity::find()
            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
            .all(tx)
            .await?;

        // Keep the part of each path from the channel's id onwards (without
        // the leading '/'); the set removes suffixes shared by several paths.
        let mut new_path_suffixes = HashSet::default();
        for path in paths {
            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
                new_path_suffixes.insert((
                    path.channel_id,
                    path.id_path[(start_offset + 1)..].to_string(),
                ));
            }
        }

        let paths_to_new_parent = channel_path::Entity::find()
            .filter(channel_path::Column::ChannelId.eq(to))
            .all(tx)
            .await?;

        let mut new_paths = Vec::new();
        for path in paths_to_new_parent {
            // The new parent must not already be below the channel.
            if path.id_path.contains(&format!("/{}/", channel)) {
                Err(anyhow!("cycle"))?;
            }

            // Parent paths end with '/', so appending a suffix yields a full path.
            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
                channel_path::ActiveModel {
                    channel_id: ActiveValue::Set(*channel_id),
                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
                }
            }));
        }

        channel_path::Entity::insert_many(new_paths)
            .exec(&*tx)
            .await?;

        // remove any root edges for the channel we just linked
        {
            channel_path::Entity::delete_many()
                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
                .exec(&*tx)
                .await?;
        }

        let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
        if let Some(channel) = channel_descendants.get_mut(&channel) {
            // Remove the other parents
            channel.clear();
            channel.insert(to);
        }

        // Dangling parents are kept so that the edge to `to` is reported.
        let channels = self
            .get_channel_graph(channel_descendants, false, &*tx)
            .await?;

        Ok(channels)
    }
915
916 /// Unlink a channel from a given parent. This will add in a root edge if
917 /// the channel has no other parents after this operation.
918 pub async fn unlink_channel(
919 &self,
920 user: UserId,
921 channel: ChannelId,
922 from: ChannelId,
923 ) -> Result<()> {
924 self.transaction(|tx| async move {
925 // Note that even with these maxed permissions, this linking operation
926 // is still insecure because you can't remove someone's permissions to a
927 // channel if they've linked the channel to one where they're an admin.
928 self.check_user_is_channel_admin(channel, user, &*tx)
929 .await?;
930
931 self.unlink_channel_internal(user, channel, from, &*tx)
932 .await?;
933
934 Ok(())
935 })
936 .await
937 }
938
939 pub async fn unlink_channel_internal(
940 &self,
941 user: UserId,
942 channel: ChannelId,
943 from: ChannelId,
944 tx: &DatabaseTransaction,
945 ) -> Result<()> {
946 self.check_user_is_channel_admin(from, user, &*tx).await?;
947
948 let sql = r#"
949 DELETE FROM channel_paths
950 WHERE
951 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
952 RETURNING id_path, channel_id
953 "#;
954
955 let paths = channel_path::Entity::find()
956 .from_raw_sql(Statement::from_sql_and_values(
957 self.pool.get_database_backend(),
958 sql,
959 [from.to_proto().into(), channel.to_proto().into()],
960 ))
961 .all(&*tx)
962 .await?;
963
964 let is_stranded = channel_path::Entity::find()
965 .filter(channel_path::Column::ChannelId.eq(channel))
966 .count(&*tx)
967 .await?
968 == 0;
969
970 // Make sure that there is always at least one path to the channel
971 if is_stranded {
972 let root_paths: Vec<_> = paths
973 .iter()
974 .map(|path| {
975 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
976 channel_path::ActiveModel {
977 channel_id: ActiveValue::Set(path.channel_id),
978 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
979 }
980 })
981 .collect();
982 channel_path::Entity::insert_many(root_paths)
983 .exec(&*tx)
984 .await?;
985 }
986
987 Ok(())
988 }
989
990 /// Move a channel from one parent to another, returns the
991 /// Channels that were moved for notifying clients
992 pub async fn move_channel(
993 &self,
994 user: UserId,
995 channel: ChannelId,
996 from: ChannelId,
997 to: ChannelId,
998 ) -> Result<ChannelGraph> {
999 if from == to {
1000 return Ok(ChannelGraph {
1001 channels: vec![],
1002 edges: vec![],
1003 });
1004 }
1005
1006 self.transaction(|tx| async move {
1007 self.check_user_is_channel_admin(channel, user, &*tx)
1008 .await?;
1009
1010 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1011
1012 self.unlink_channel_internal(user, channel, from, &*tx)
1013 .await?;
1014
1015 Ok(moved_channels)
1016 })
1017 .await
1018 }
1019}
1020
/// Column layout for queries in this file that select only the `user_id`
/// column and read it back through `into_values`.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1025
/// A set of channels together with the child -> parent edges between them.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels that are part of the graph.
    pub channels: Vec<Channel>,
    /// One entry per (channel, parent) relation.
    pub edges: Vec<ChannelEdge>,
}
1031
1032impl ChannelGraph {
1033 pub fn is_empty(&self) -> bool {
1034 self.channels.is_empty() && self.edges.is_empty()
1035 }
1036}
1037
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// In test builds graphs are compared as sets, so the order of channels
    /// and edges does not matter.
    fn eq(&self, other: &Self) -> bool {
        fn channel_set(graph: &ChannelGraph) -> HashSet<&Channel> {
            graph.channels.iter().collect()
        }

        let edge_set = |graph: &ChannelGraph| -> HashSet<_> {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect()
        };

        channel_set(self) == channel_set(other) && edge_set(self) == edge_set(other)
    }
}
1058
1059#[cfg(not(test))]
1060impl PartialEq for ChannelGraph {
1061 fn eq(&self, other: &Self) -> bool {
1062 self.channels == other.channels && self.edges == other.edges
1063 }
1064}
1065
/// A small set backed by a `SmallVec` with one inline slot. `insert` keeps the
/// elements sorted and free of duplicates.
struct SmallSet<T>(SmallVec<[T; 1]>);
1067
1068impl<T> Deref for SmallSet<T> {
1069 type Target = [T];
1070
1071 fn deref(&self) -> &Self::Target {
1072 self.0.deref()
1073 }
1074}
1075
1076impl<T> Default for SmallSet<T> {
1077 fn default() -> Self {
1078 Self(SmallVec::new())
1079 }
1080}
1081
1082impl<T> SmallSet<T> {
1083 fn insert(&mut self, value: T) -> bool
1084 where
1085 T: Ord,
1086 {
1087 match self.binary_search(&value) {
1088 Ok(_) => false,
1089 Err(ix) => {
1090 self.0.insert(ix, value);
1091 true
1092 }
1093 }
1094 }
1095
1096 fn clear(&mut self) {
1097 self.0.clear();
1098 }
1099}