1use super::*;
2use rpc::proto::ChannelEdge;
3use smallvec::SmallVec;
4
5type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
6
7impl Database {
/// Test-only helper that returns the `(id, name)` pair of every row in the channel table.
#[cfg(test)]
pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
    self.transaction(move |tx| async move {
        let mut stream = channel::Entity::find().stream(&*tx).await?;
        let mut result = Vec::new();
        while let Some(next) = stream.next().await {
            let channel = next?;
            result.push((channel.id, channel.name));
        }
        Ok(result)
    })
    .await
}
21
22 pub async fn create_root_channel(
23 &self,
24 name: &str,
25 live_kit_room: &str,
26 creator_id: UserId,
27 ) -> Result<ChannelId> {
28 self.create_channel(name, None, live_kit_room, creator_id)
29 .await
30 }
31
32 pub async fn create_channel(
33 &self,
34 name: &str,
35 parent: Option<ChannelId>,
36 live_kit_room: &str,
37 creator_id: UserId,
38 ) -> Result<ChannelId> {
39 let name = Self::sanitize_channel_name(name)?;
40 self.transaction(move |tx| async move {
41 if let Some(parent) = parent {
42 self.check_user_is_channel_admin(parent, creator_id, &*tx)
43 .await?;
44 }
45
46 let channel = channel::ActiveModel {
47 name: ActiveValue::Set(name.to_string()),
48 ..Default::default()
49 }
50 .insert(&*tx)
51 .await?;
52
53 if let Some(parent) = parent {
54 let sql = r#"
55 INSERT INTO channel_paths
56 (id_path, channel_id)
57 SELECT
58 id_path || $1 || '/', $2
59 FROM
60 channel_paths
61 WHERE
62 channel_id = $3
63 "#;
64 let channel_paths_stmt = Statement::from_sql_and_values(
65 self.pool.get_database_backend(),
66 sql,
67 [
68 channel.id.to_proto().into(),
69 channel.id.to_proto().into(),
70 parent.to_proto().into(),
71 ],
72 );
73 tx.execute(channel_paths_stmt).await?;
74 } else {
75 channel_path::Entity::insert(channel_path::ActiveModel {
76 channel_id: ActiveValue::Set(channel.id),
77 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
78 })
79 .exec(&*tx)
80 .await?;
81 }
82
83 channel_member::ActiveModel {
84 channel_id: ActiveValue::Set(channel.id),
85 user_id: ActiveValue::Set(creator_id),
86 accepted: ActiveValue::Set(true),
87 admin: ActiveValue::Set(true),
88 ..Default::default()
89 }
90 .insert(&*tx)
91 .await?;
92
93 room::ActiveModel {
94 channel_id: ActiveValue::Set(Some(channel.id)),
95 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
96 ..Default::default()
97 }
98 .insert(&*tx)
99 .await?;
100
101 Ok(channel.id)
102 })
103 .await
104 }
105
106 pub async fn delete_channel(
107 &self,
108 channel_id: ChannelId,
109 user_id: UserId,
110 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
111 self.transaction(move |tx| async move {
112 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
113 .await?;
114
115 // Don't remove descendant channels that have additional parents.
116 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
117 {
118 let mut channels_to_keep = channel_path::Entity::find()
119 .filter(
120 channel_path::Column::ChannelId
121 .is_in(
122 channels_to_remove
123 .keys()
124 .copied()
125 .filter(|&id| id != channel_id),
126 )
127 .and(
128 channel_path::Column::IdPath
129 .not_like(&format!("%/{}/%", channel_id)),
130 ),
131 )
132 .stream(&*tx)
133 .await?;
134 while let Some(row) = channels_to_keep.next().await {
135 let row = row?;
136 channels_to_remove.remove(&row.channel_id);
137 }
138 }
139
140 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
141 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
142 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
143 .select_only()
144 .column(channel_member::Column::UserId)
145 .distinct()
146 .into_values::<_, QueryUserIds>()
147 .all(&*tx)
148 .await?;
149
150 channel::Entity::delete_many()
151 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
152 .exec(&*tx)
153 .await?;
154
155 // Delete any other paths that include this channel
156 let sql = r#"
157 DELETE FROM channel_paths
158 WHERE
159 id_path LIKE '%' || $1 || '%'
160 "#;
161 let channel_paths_stmt = Statement::from_sql_and_values(
162 self.pool.get_database_backend(),
163 sql,
164 [channel_id.to_proto().into()],
165 );
166 tx.execute(channel_paths_stmt).await?;
167
168 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
169 })
170 .await
171 }
172
173 pub async fn invite_channel_member(
174 &self,
175 channel_id: ChannelId,
176 invitee_id: UserId,
177 inviter_id: UserId,
178 is_admin: bool,
179 ) -> Result<()> {
180 self.transaction(move |tx| async move {
181 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
182 .await?;
183
184 channel_member::ActiveModel {
185 channel_id: ActiveValue::Set(channel_id),
186 user_id: ActiveValue::Set(invitee_id),
187 accepted: ActiveValue::Set(false),
188 admin: ActiveValue::Set(is_admin),
189 ..Default::default()
190 }
191 .insert(&*tx)
192 .await?;
193
194 Ok(())
195 })
196 .await
197 }
198
199 fn sanitize_channel_name(name: &str) -> Result<&str> {
200 let new_name = name.trim().trim_start_matches('#');
201 if new_name == "" {
202 Err(anyhow!("channel name can't be blank"))?;
203 }
204 Ok(new_name)
205 }
206
207 pub async fn rename_channel(
208 &self,
209 channel_id: ChannelId,
210 user_id: UserId,
211 new_name: &str,
212 ) -> Result<String> {
213 self.transaction(move |tx| async move {
214 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
215
216 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
217 .await?;
218
219 channel::ActiveModel {
220 id: ActiveValue::Unchanged(channel_id),
221 name: ActiveValue::Set(new_name.clone()),
222 ..Default::default()
223 }
224 .update(&*tx)
225 .await?;
226
227 Ok(new_name)
228 })
229 .await
230 }
231
232 pub async fn respond_to_channel_invite(
233 &self,
234 channel_id: ChannelId,
235 user_id: UserId,
236 accept: bool,
237 ) -> Result<()> {
238 self.transaction(move |tx| async move {
239 let rows_affected = if accept {
240 channel_member::Entity::update_many()
241 .set(channel_member::ActiveModel {
242 accepted: ActiveValue::Set(accept),
243 ..Default::default()
244 })
245 .filter(
246 channel_member::Column::ChannelId
247 .eq(channel_id)
248 .and(channel_member::Column::UserId.eq(user_id))
249 .and(channel_member::Column::Accepted.eq(false)),
250 )
251 .exec(&*tx)
252 .await?
253 .rows_affected
254 } else {
255 channel_member::ActiveModel {
256 channel_id: ActiveValue::Unchanged(channel_id),
257 user_id: ActiveValue::Unchanged(user_id),
258 ..Default::default()
259 }
260 .delete(&*tx)
261 .await?
262 .rows_affected
263 };
264
265 if rows_affected == 0 {
266 Err(anyhow!("no such invitation"))?;
267 }
268
269 Ok(())
270 })
271 .await
272 }
273
274 pub async fn remove_channel_member(
275 &self,
276 channel_id: ChannelId,
277 member_id: UserId,
278 remover_id: UserId,
279 ) -> Result<()> {
280 self.transaction(|tx| async move {
281 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
282 .await?;
283
284 let result = channel_member::Entity::delete_many()
285 .filter(
286 channel_member::Column::ChannelId
287 .eq(channel_id)
288 .and(channel_member::Column::UserId.eq(member_id)),
289 )
290 .exec(&*tx)
291 .await?;
292
293 if result.rows_affected == 0 {
294 Err(anyhow!("no such member"))?;
295 }
296
297 Ok(())
298 })
299 .await
300 }
301
302 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
303 self.transaction(|tx| async move {
304 let channel_invites = channel_member::Entity::find()
305 .filter(
306 channel_member::Column::UserId
307 .eq(user_id)
308 .and(channel_member::Column::Accepted.eq(false)),
309 )
310 .all(&*tx)
311 .await?;
312
313 let channels = channel::Entity::find()
314 .filter(
315 channel::Column::Id.is_in(
316 channel_invites
317 .into_iter()
318 .map(|channel_member| channel_member.channel_id),
319 ),
320 )
321 .all(&*tx)
322 .await?;
323
324 let channels = channels
325 .into_iter()
326 .map(|channel| Channel {
327 id: channel.id,
328 name: channel.name,
329 })
330 .collect();
331
332 Ok(channels)
333 })
334 .await
335 }
336
337 async fn get_channel_graph(
338 &self,
339 parents_by_child_id: ChannelDescendants,
340 trim_dangling_parents: bool,
341 tx: &DatabaseTransaction,
342 ) -> Result<ChannelGraph> {
343 let mut channels = Vec::with_capacity(parents_by_child_id.len());
344 {
345 let mut rows = channel::Entity::find()
346 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
347 .stream(&*tx)
348 .await?;
349 while let Some(row) = rows.next().await {
350 let row = row?;
351 channels.push(Channel {
352 id: row.id,
353 name: row.name,
354 })
355 }
356 }
357
358 let mut edges = Vec::with_capacity(parents_by_child_id.len());
359 for (channel, parents) in parents_by_child_id.iter() {
360 for parent in parents.into_iter() {
361 if trim_dangling_parents {
362 if parents_by_child_id.contains_key(parent) {
363 edges.push(ChannelEdge {
364 channel_id: channel.to_proto(),
365 parent_id: parent.to_proto(),
366 });
367 }
368 } else {
369 edges.push(ChannelEdge {
370 channel_id: channel.to_proto(),
371 parent_id: parent.to_proto(),
372 });
373 }
374 }
375 }
376
377 Ok(ChannelGraph { channels, edges })
378 }
379
380 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
381 self.transaction(|tx| async move {
382 let tx = tx;
383
384 let channel_memberships = channel_member::Entity::find()
385 .filter(
386 channel_member::Column::UserId
387 .eq(user_id)
388 .and(channel_member::Column::Accepted.eq(true)),
389 )
390 .all(&*tx)
391 .await?;
392
393 self.get_user_channels(user_id, channel_memberships, &tx)
394 .await
395 })
396 .await
397 }
398
399 pub async fn get_channel_for_user(
400 &self,
401 channel_id: ChannelId,
402 user_id: UserId,
403 ) -> Result<ChannelsForUser> {
404 self.transaction(|tx| async move {
405 let tx = tx;
406
407 let channel_membership = channel_member::Entity::find()
408 .filter(
409 channel_member::Column::UserId
410 .eq(user_id)
411 .and(channel_member::Column::ChannelId.eq(channel_id))
412 .and(channel_member::Column::Accepted.eq(true)),
413 )
414 .all(&*tx)
415 .await?;
416
417 self.get_user_channels(user_id, channel_membership, &tx)
418 .await
419 })
420 .await
421 }
422
423 pub async fn get_user_channels(
424 &self,
425 user_id: UserId,
426 channel_memberships: Vec<channel_member::Model>,
427 tx: &DatabaseTransaction,
428 ) -> Result<ChannelsForUser> {
429 let parents_by_child_id = self
430 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
431 .await?;
432
433 let channels_with_admin_privileges = channel_memberships
434 .iter()
435 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
436 .collect();
437
438 let graph = self
439 .get_channel_graph(parents_by_child_id, true, &tx)
440 .await?;
441
442 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
443 enum QueryUserIdsAndChannelIds {
444 ChannelId,
445 UserId,
446 }
447
448 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
449 {
450 let mut rows = room_participant::Entity::find()
451 .inner_join(room::Entity)
452 .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
453 .select_only()
454 .column(room::Column::ChannelId)
455 .column(room_participant::Column::UserId)
456 .into_values::<_, QueryUserIdsAndChannelIds>()
457 .stream(&*tx)
458 .await?;
459 while let Some(row) = rows.next().await {
460 let row: (ChannelId, UserId) = row?;
461 channel_participants.entry(row.0).or_default().push(row.1)
462 }
463 }
464
465 let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
466 let channel_buffer_changes = self
467 .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
468 .await?;
469
470 let unseen_messages = self
471 .unseen_channel_messages(user_id, &channel_ids, &*tx)
472 .await?;
473
474 Ok(ChannelsForUser {
475 channels: graph,
476 channel_participants,
477 channels_with_admin_privileges,
478 unseen_buffer_changes: channel_buffer_changes,
479 channel_messages: unseen_messages,
480 })
481 }
482
483 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
484 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
485 .await
486 }
487
488 pub async fn set_channel_member_admin(
489 &self,
490 channel_id: ChannelId,
491 from: UserId,
492 for_user: UserId,
493 admin: bool,
494 ) -> Result<()> {
495 self.transaction(|tx| async move {
496 self.check_user_is_channel_admin(channel_id, from, &*tx)
497 .await?;
498
499 let result = channel_member::Entity::update_many()
500 .filter(
501 channel_member::Column::ChannelId
502 .eq(channel_id)
503 .and(channel_member::Column::UserId.eq(for_user)),
504 )
505 .set(channel_member::ActiveModel {
506 admin: ActiveValue::set(admin),
507 ..Default::default()
508 })
509 .exec(&*tx)
510 .await?;
511
512 if result.rows_affected == 0 {
513 Err(anyhow!("no such member"))?;
514 }
515
516 Ok(())
517 })
518 .await
519 }
520
521 pub async fn get_channel_member_details(
522 &self,
523 channel_id: ChannelId,
524 user_id: UserId,
525 ) -> Result<Vec<proto::ChannelMember>> {
526 self.transaction(|tx| async move {
527 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
528 .await?;
529
530 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
531 enum QueryMemberDetails {
532 UserId,
533 Admin,
534 IsDirectMember,
535 Accepted,
536 }
537
538 let tx = tx;
539 let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
540 let mut stream = channel_member::Entity::find()
541 .distinct()
542 .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
543 .select_only()
544 .column(channel_member::Column::UserId)
545 .column(channel_member::Column::Admin)
546 .column_as(
547 channel_member::Column::ChannelId.eq(channel_id),
548 QueryMemberDetails::IsDirectMember,
549 )
550 .column(channel_member::Column::Accepted)
551 .order_by_asc(channel_member::Column::UserId)
552 .into_values::<_, QueryMemberDetails>()
553 .stream(&*tx)
554 .await?;
555
556 let mut rows = Vec::<proto::ChannelMember>::new();
557 while let Some(row) = stream.next().await {
558 let (user_id, is_admin, is_direct_member, is_invite_accepted): (
559 UserId,
560 bool,
561 bool,
562 bool,
563 ) = row?;
564 let kind = match (is_direct_member, is_invite_accepted) {
565 (true, true) => proto::channel_member::Kind::Member,
566 (true, false) => proto::channel_member::Kind::Invitee,
567 (false, true) => proto::channel_member::Kind::AncestorMember,
568 (false, false) => continue,
569 };
570 let user_id = user_id.to_proto();
571 let kind = kind.into();
572 if let Some(last_row) = rows.last_mut() {
573 if last_row.user_id == user_id {
574 if is_direct_member {
575 last_row.kind = kind;
576 last_row.admin = is_admin;
577 }
578 continue;
579 }
580 }
581 rows.push(proto::ChannelMember {
582 user_id,
583 kind,
584 admin: is_admin,
585 });
586 }
587
588 Ok(rows)
589 })
590 .await
591 }
592
593 pub async fn get_channel_members_internal(
594 &self,
595 id: ChannelId,
596 tx: &DatabaseTransaction,
597 ) -> Result<Vec<UserId>> {
598 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
599 let user_ids = channel_member::Entity::find()
600 .distinct()
601 .filter(
602 channel_member::Column::ChannelId
603 .is_in(ancestor_ids.iter().copied())
604 .and(channel_member::Column::Accepted.eq(true)),
605 )
606 .select_only()
607 .column(channel_member::Column::UserId)
608 .into_values::<_, QueryUserIds>()
609 .all(&*tx)
610 .await?;
611 Ok(user_ids)
612 }
613
614 pub async fn check_user_is_channel_member(
615 &self,
616 channel_id: ChannelId,
617 user_id: UserId,
618 tx: &DatabaseTransaction,
619 ) -> Result<()> {
620 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
621 channel_member::Entity::find()
622 .filter(
623 channel_member::Column::ChannelId
624 .is_in(channel_ids)
625 .and(channel_member::Column::UserId.eq(user_id)),
626 )
627 .one(&*tx)
628 .await?
629 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
630 Ok(())
631 }
632
633 pub async fn check_user_is_channel_admin(
634 &self,
635 channel_id: ChannelId,
636 user_id: UserId,
637 tx: &DatabaseTransaction,
638 ) -> Result<()> {
639 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
640 channel_member::Entity::find()
641 .filter(
642 channel_member::Column::ChannelId
643 .is_in(channel_ids)
644 .and(channel_member::Column::UserId.eq(user_id))
645 .and(channel_member::Column::Admin.eq(true)),
646 )
647 .one(&*tx)
648 .await?
649 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
650 Ok(())
651 }
652
653 /// Returns the channel ancestors, deepest first
654 pub async fn get_channel_ancestors(
655 &self,
656 channel_id: ChannelId,
657 tx: &DatabaseTransaction,
658 ) -> Result<Vec<ChannelId>> {
659 let paths = channel_path::Entity::find()
660 .filter(channel_path::Column::ChannelId.eq(channel_id))
661 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
662 .all(tx)
663 .await?;
664 let mut channel_ids = Vec::new();
665 for path in paths {
666 for id in path.id_path.trim_matches('/').split('/') {
667 if let Ok(id) = id.parse() {
668 let id = ChannelId::from_proto(id);
669 if let Err(ix) = channel_ids.binary_search(&id) {
670 channel_ids.insert(ix, id);
671 }
672 }
673 }
674 }
675 Ok(channel_ids)
676 }
677
678 /// Returns the channel descendants,
679 /// Structured as a map from child ids to their parent ids
680 /// For example, the descendants of 'a' in this DAG:
681 ///
682 /// /- b -\
683 /// a -- c -- d
684 ///
685 /// would be:
686 /// {
687 /// a: [],
688 /// b: [a],
689 /// c: [a],
690 /// d: [a, c],
691 /// }
692 async fn get_channel_descendants(
693 &self,
694 channel_ids: impl IntoIterator<Item = ChannelId>,
695 tx: &DatabaseTransaction,
696 ) -> Result<ChannelDescendants> {
697 let mut values = String::new();
698 for id in channel_ids {
699 if !values.is_empty() {
700 values.push_str(", ");
701 }
702 write!(&mut values, "({})", id).unwrap();
703 }
704
705 if values.is_empty() {
706 return Ok(HashMap::default());
707 }
708
709 let sql = format!(
710 r#"
711 SELECT
712 descendant_paths.*
713 FROM
714 channel_paths parent_paths, channel_paths descendant_paths
715 WHERE
716 parent_paths.channel_id IN ({values}) AND
717 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
718 "#
719 );
720
721 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
722
723 let mut parents_by_child_id: ChannelDescendants = HashMap::default();
724 let mut paths = channel_path::Entity::find()
725 .from_raw_sql(stmt)
726 .stream(tx)
727 .await?;
728
729 while let Some(path) = paths.next().await {
730 let path = path?;
731 let ids = path.id_path.trim_matches('/').split('/');
732 let mut parent_id = None;
733 for id in ids {
734 if let Ok(id) = id.parse() {
735 let id = ChannelId::from_proto(id);
736 if id == path.channel_id {
737 break;
738 }
739 parent_id = Some(id);
740 }
741 }
742 let entry = parents_by_child_id.entry(path.channel_id).or_default();
743 if let Some(parent_id) = parent_id {
744 entry.insert(parent_id);
745 }
746 }
747
748 Ok(parents_by_child_id)
749 }
750
751 /// Returns the channel with the given ID and:
752 /// - true if the user is a member
753 /// - false if the user hasn't accepted the invitation yet
754 pub async fn get_channel(
755 &self,
756 channel_id: ChannelId,
757 user_id: UserId,
758 ) -> Result<Option<(Channel, bool)>> {
759 self.transaction(|tx| async move {
760 let tx = tx;
761
762 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
763
764 if let Some(channel) = channel {
765 if self
766 .check_user_is_channel_member(channel_id, user_id, &*tx)
767 .await
768 .is_err()
769 {
770 return Ok(None);
771 }
772
773 let channel_membership = channel_member::Entity::find()
774 .filter(
775 channel_member::Column::ChannelId
776 .eq(channel_id)
777 .and(channel_member::Column::UserId.eq(user_id)),
778 )
779 .one(&*tx)
780 .await?;
781
782 let is_accepted = channel_membership
783 .map(|membership| membership.accepted)
784 .unwrap_or(false);
785
786 Ok(Some((
787 Channel {
788 id: channel.id,
789 name: channel.name,
790 },
791 is_accepted,
792 )))
793 } else {
794 Ok(None)
795 }
796 })
797 .await
798 }
799
800 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
801 self.transaction(|tx| async move {
802 let tx = tx;
803 let room = channel::Model {
804 id: channel_id,
805 ..Default::default()
806 }
807 .find_related(room::Entity)
808 .one(&*tx)
809 .await?
810 .ok_or_else(|| anyhow!("invalid channel"))?;
811 Ok(room.id)
812 })
813 .await
814 }
815
816 // Insert an edge from the given channel to the given other channel.
817 pub async fn link_channel(
818 &self,
819 user: UserId,
820 channel: ChannelId,
821 to: ChannelId,
822 ) -> Result<ChannelGraph> {
823 self.transaction(|tx| async move {
824 // Note that even with these maxed permissions, this linking operation
825 // is still insecure because you can't remove someone's permissions to a
826 // channel if they've linked the channel to one where they're an admin.
827 self.check_user_is_channel_admin(channel, user, &*tx)
828 .await?;
829
830 self.link_channel_internal(user, channel, to, &*tx).await
831 })
832 .await
833 }
834
835 pub async fn link_channel_internal(
836 &self,
837 user: UserId,
838 channel: ChannelId,
839 to: ChannelId,
840 tx: &DatabaseTransaction,
841 ) -> Result<ChannelGraph> {
842 self.check_user_is_channel_admin(to, user, &*tx).await?;
843
844 let paths = channel_path::Entity::find()
845 .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
846 .all(tx)
847 .await?;
848
849 let mut new_path_suffixes = HashSet::default();
850 for path in paths {
851 if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
852 new_path_suffixes.insert((
853 path.channel_id,
854 path.id_path[(start_offset + 1)..].to_string(),
855 ));
856 }
857 }
858
859 let paths_to_new_parent = channel_path::Entity::find()
860 .filter(channel_path::Column::ChannelId.eq(to))
861 .all(tx)
862 .await?;
863
864 let mut new_paths = Vec::new();
865 for path in paths_to_new_parent {
866 if path.id_path.contains(&format!("/{}/", channel)) {
867 Err(anyhow!("cycle"))?;
868 }
869
870 new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
871 channel_path::ActiveModel {
872 channel_id: ActiveValue::Set(*channel_id),
873 id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
874 }
875 }));
876 }
877
878 channel_path::Entity::insert_many(new_paths)
879 .exec(&*tx)
880 .await?;
881
882 // remove any root edges for the channel we just linked
883 {
884 channel_path::Entity::delete_many()
885 .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
886 .exec(&*tx)
887 .await?;
888 }
889
890 let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
891 if let Some(channel) = channel_descendants.get_mut(&channel) {
892 // Remove the other parents
893 channel.clear();
894 channel.insert(to);
895 }
896
897 let channels = self
898 .get_channel_graph(channel_descendants, false, &*tx)
899 .await?;
900
901 Ok(channels)
902 }
903
904 /// Unlink a channel from a given parent. This will add in a root edge if
905 /// the channel has no other parents after this operation.
906 pub async fn unlink_channel(
907 &self,
908 user: UserId,
909 channel: ChannelId,
910 from: ChannelId,
911 ) -> Result<()> {
912 self.transaction(|tx| async move {
913 // Note that even with these maxed permissions, this linking operation
914 // is still insecure because you can't remove someone's permissions to a
915 // channel if they've linked the channel to one where they're an admin.
916 self.check_user_is_channel_admin(channel, user, &*tx)
917 .await?;
918
919 self.unlink_channel_internal(user, channel, from, &*tx)
920 .await?;
921
922 Ok(())
923 })
924 .await
925 }
926
927 pub async fn unlink_channel_internal(
928 &self,
929 user: UserId,
930 channel: ChannelId,
931 from: ChannelId,
932 tx: &DatabaseTransaction,
933 ) -> Result<()> {
934 self.check_user_is_channel_admin(from, user, &*tx).await?;
935
936 let sql = r#"
937 DELETE FROM channel_paths
938 WHERE
939 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
940 RETURNING id_path, channel_id
941 "#;
942
943 let paths = channel_path::Entity::find()
944 .from_raw_sql(Statement::from_sql_and_values(
945 self.pool.get_database_backend(),
946 sql,
947 [from.to_proto().into(), channel.to_proto().into()],
948 ))
949 .all(&*tx)
950 .await?;
951
952 let is_stranded = channel_path::Entity::find()
953 .filter(channel_path::Column::ChannelId.eq(channel))
954 .count(&*tx)
955 .await?
956 == 0;
957
958 // Make sure that there is always at least one path to the channel
959 if is_stranded {
960 let root_paths: Vec<_> = paths
961 .iter()
962 .map(|path| {
963 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
964 channel_path::ActiveModel {
965 channel_id: ActiveValue::Set(path.channel_id),
966 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
967 }
968 })
969 .collect();
970 channel_path::Entity::insert_many(root_paths)
971 .exec(&*tx)
972 .await?;
973 }
974
975 Ok(())
976 }
977
978 /// Move a channel from one parent to another, returns the
979 /// Channels that were moved for notifying clients
980 pub async fn move_channel(
981 &self,
982 user: UserId,
983 channel: ChannelId,
984 from: ChannelId,
985 to: ChannelId,
986 ) -> Result<ChannelGraph> {
987 if from == to {
988 return Ok(ChannelGraph {
989 channels: vec![],
990 edges: vec![],
991 });
992 }
993
994 self.transaction(|tx| async move {
995 self.check_user_is_channel_admin(channel, user, &*tx)
996 .await?;
997
998 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
999
1000 self.unlink_channel_internal(user, channel, from, &*tx)
1001 .await?;
1002
1003 Ok(moved_channels)
1004 })
1005 .await
1006 }
1007}
1008
1009#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
1010enum QueryUserIds {
1011 UserId,
1012}
1013
1014#[derive(Debug)]
1015pub struct ChannelGraph {
1016 pub channels: Vec<Channel>,
1017 pub edges: Vec<ChannelEdge>,
1018}
1019
1020impl ChannelGraph {
1021 pub fn is_empty(&self) -> bool {
1022 self.channels.is_empty() && self.edges.is_empty()
1023 }
1024}
1025
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Order-independent comparison for tests: channels and edges are compared as sets.
    fn eq(&self, other: &Self) -> bool {
        let channels_match = {
            let mine: HashSet<_> = self.channels.iter().collect();
            let theirs: HashSet<_> = other.channels.iter().collect();
            mine == theirs
        };
        if !channels_match {
            return false;
        }

        let as_pair = |edge: &ChannelEdge| (edge.channel_id, edge.parent_id);
        let mine: HashSet<_> = self.edges.iter().map(as_pair).collect();
        let theirs: HashSet<_> = other.edges.iter().map(as_pair).collect();
        mine == theirs
    }
}
1046
1047#[cfg(not(test))]
1048impl PartialEq for ChannelGraph {
1049 fn eq(&self, other: &Self) -> bool {
1050 self.channels == other.channels && self.edges == other.edges
1051 }
1052}
1053
1054struct SmallSet<T>(SmallVec<[T; 1]>);
1055
1056impl<T> Deref for SmallSet<T> {
1057 type Target = [T];
1058
1059 fn deref(&self) -> &Self::Target {
1060 self.0.deref()
1061 }
1062}
1063
1064impl<T> Default for SmallSet<T> {
1065 fn default() -> Self {
1066 Self(SmallVec::new())
1067 }
1068}
1069
1070impl<T> SmallSet<T> {
1071 fn insert(&mut self, value: T) -> bool
1072 where
1073 T: Ord,
1074 {
1075 match self.binary_search(&value) {
1076 Ok(_) => false,
1077 Err(ix) => {
1078 self.0.insert(ix, value);
1079 true
1080 }
1081 }
1082 }
1083
1084 fn clear(&mut self) {
1085 self.0.clear();
1086 }
1087}