1use super::*;
2use rpc::proto::ChannelEdge;
3use smallvec::SmallVec;
4
/// Map from a channel id to the (sorted, deduplicated) ids of its direct parents,
/// as built by `Database::get_channel_descendants`.
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
6
7impl Database {
8 #[cfg(test)]
9 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
10 self.transaction(move |tx| async move {
11 let mut channels = Vec::new();
12 let mut rows = channel::Entity::find().stream(&*tx).await?;
13 while let Some(row) = rows.next().await {
14 let row = row?;
15 channels.push((row.id, row.name));
16 }
17 Ok(channels)
18 })
19 .await
20 }
21
22 pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
23 self.create_channel(name, None, creator_id).await
24 }
25
26 pub async fn create_channel(
27 &self,
28 name: &str,
29 parent: Option<ChannelId>,
30 creator_id: UserId,
31 ) -> Result<ChannelId> {
32 let name = Self::sanitize_channel_name(name)?;
33 self.transaction(move |tx| async move {
34 if let Some(parent) = parent {
35 self.check_user_is_channel_admin(parent, creator_id, &*tx)
36 .await?;
37 }
38
39 let channel = channel::ActiveModel {
40 name: ActiveValue::Set(name.to_string()),
41 ..Default::default()
42 }
43 .insert(&*tx)
44 .await?;
45
46 if let Some(parent) = parent {
47 let sql = r#"
48 INSERT INTO channel_paths
49 (id_path, channel_id)
50 SELECT
51 id_path || $1 || '/', $2
52 FROM
53 channel_paths
54 WHERE
55 channel_id = $3
56 "#;
57 let channel_paths_stmt = Statement::from_sql_and_values(
58 self.pool.get_database_backend(),
59 sql,
60 [
61 channel.id.to_proto().into(),
62 channel.id.to_proto().into(),
63 parent.to_proto().into(),
64 ],
65 );
66 tx.execute(channel_paths_stmt).await?;
67 } else {
68 channel_path::Entity::insert(channel_path::ActiveModel {
69 channel_id: ActiveValue::Set(channel.id),
70 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
71 })
72 .exec(&*tx)
73 .await?;
74 }
75
76 channel_member::ActiveModel {
77 channel_id: ActiveValue::Set(channel.id),
78 user_id: ActiveValue::Set(creator_id),
79 accepted: ActiveValue::Set(true),
80 admin: ActiveValue::Set(true),
81 ..Default::default()
82 }
83 .insert(&*tx)
84 .await?;
85
86 Ok(channel.id)
87 })
88 .await
89 }
90
91 pub async fn delete_channel(
92 &self,
93 channel_id: ChannelId,
94 user_id: UserId,
95 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
96 self.transaction(move |tx| async move {
97 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
98 .await?;
99
100 // Don't remove descendant channels that have additional parents.
101 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
102 {
103 let mut channels_to_keep = channel_path::Entity::find()
104 .filter(
105 channel_path::Column::ChannelId
106 .is_in(
107 channels_to_remove
108 .keys()
109 .copied()
110 .filter(|&id| id != channel_id),
111 )
112 .and(
113 channel_path::Column::IdPath
114 .not_like(&format!("%/{}/%", channel_id)),
115 ),
116 )
117 .stream(&*tx)
118 .await?;
119 while let Some(row) = channels_to_keep.next().await {
120 let row = row?;
121 channels_to_remove.remove(&row.channel_id);
122 }
123 }
124
125 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
126 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
127 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
128 .select_only()
129 .column(channel_member::Column::UserId)
130 .distinct()
131 .into_values::<_, QueryUserIds>()
132 .all(&*tx)
133 .await?;
134
135 channel::Entity::delete_many()
136 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
137 .exec(&*tx)
138 .await?;
139
140 // Delete any other paths that include this channel
141 let sql = r#"
142 DELETE FROM channel_paths
143 WHERE
144 id_path LIKE '%' || $1 || '%'
145 "#;
146 let channel_paths_stmt = Statement::from_sql_and_values(
147 self.pool.get_database_backend(),
148 sql,
149 [channel_id.to_proto().into()],
150 );
151 tx.execute(channel_paths_stmt).await?;
152
153 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
154 })
155 .await
156 }
157
158 pub async fn invite_channel_member(
159 &self,
160 channel_id: ChannelId,
161 invitee_id: UserId,
162 inviter_id: UserId,
163 is_admin: bool,
164 ) -> Result<()> {
165 self.transaction(move |tx| async move {
166 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
167 .await?;
168
169 channel_member::ActiveModel {
170 channel_id: ActiveValue::Set(channel_id),
171 user_id: ActiveValue::Set(invitee_id),
172 accepted: ActiveValue::Set(false),
173 admin: ActiveValue::Set(is_admin),
174 ..Default::default()
175 }
176 .insert(&*tx)
177 .await?;
178
179 Ok(())
180 })
181 .await
182 }
183
184 fn sanitize_channel_name(name: &str) -> Result<&str> {
185 let new_name = name.trim().trim_start_matches('#');
186 if new_name == "" {
187 Err(anyhow!("channel name can't be blank"))?;
188 }
189 Ok(new_name)
190 }
191
192 pub async fn rename_channel(
193 &self,
194 channel_id: ChannelId,
195 user_id: UserId,
196 new_name: &str,
197 ) -> Result<String> {
198 self.transaction(move |tx| async move {
199 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
200
201 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
202 .await?;
203
204 channel::ActiveModel {
205 id: ActiveValue::Unchanged(channel_id),
206 name: ActiveValue::Set(new_name.clone()),
207 ..Default::default()
208 }
209 .update(&*tx)
210 .await?;
211
212 Ok(new_name)
213 })
214 .await
215 }
216
217 pub async fn respond_to_channel_invite(
218 &self,
219 channel_id: ChannelId,
220 user_id: UserId,
221 accept: bool,
222 ) -> Result<()> {
223 self.transaction(move |tx| async move {
224 let rows_affected = if accept {
225 channel_member::Entity::update_many()
226 .set(channel_member::ActiveModel {
227 accepted: ActiveValue::Set(accept),
228 ..Default::default()
229 })
230 .filter(
231 channel_member::Column::ChannelId
232 .eq(channel_id)
233 .and(channel_member::Column::UserId.eq(user_id))
234 .and(channel_member::Column::Accepted.eq(false)),
235 )
236 .exec(&*tx)
237 .await?
238 .rows_affected
239 } else {
240 channel_member::ActiveModel {
241 channel_id: ActiveValue::Unchanged(channel_id),
242 user_id: ActiveValue::Unchanged(user_id),
243 ..Default::default()
244 }
245 .delete(&*tx)
246 .await?
247 .rows_affected
248 };
249
250 if rows_affected == 0 {
251 Err(anyhow!("no such invitation"))?;
252 }
253
254 Ok(())
255 })
256 .await
257 }
258
259 pub async fn remove_channel_member(
260 &self,
261 channel_id: ChannelId,
262 member_id: UserId,
263 remover_id: UserId,
264 ) -> Result<()> {
265 self.transaction(|tx| async move {
266 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
267 .await?;
268
269 let result = channel_member::Entity::delete_many()
270 .filter(
271 channel_member::Column::ChannelId
272 .eq(channel_id)
273 .and(channel_member::Column::UserId.eq(member_id)),
274 )
275 .exec(&*tx)
276 .await?;
277
278 if result.rows_affected == 0 {
279 Err(anyhow!("no such member"))?;
280 }
281
282 Ok(())
283 })
284 .await
285 }
286
287 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
288 self.transaction(|tx| async move {
289 let channel_invites = channel_member::Entity::find()
290 .filter(
291 channel_member::Column::UserId
292 .eq(user_id)
293 .and(channel_member::Column::Accepted.eq(false)),
294 )
295 .all(&*tx)
296 .await?;
297
298 let channels = channel::Entity::find()
299 .filter(
300 channel::Column::Id.is_in(
301 channel_invites
302 .into_iter()
303 .map(|channel_member| channel_member.channel_id),
304 ),
305 )
306 .all(&*tx)
307 .await?;
308
309 let channels = channels
310 .into_iter()
311 .map(|channel| Channel {
312 id: channel.id,
313 name: channel.name,
314 })
315 .collect();
316
317 Ok(channels)
318 })
319 .await
320 }
321
322 async fn get_channel_graph(
323 &self,
324 parents_by_child_id: ChannelDescendants,
325 trim_dangling_parents: bool,
326 tx: &DatabaseTransaction,
327 ) -> Result<ChannelGraph> {
328 let mut channels = Vec::with_capacity(parents_by_child_id.len());
329 {
330 let mut rows = channel::Entity::find()
331 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
332 .stream(&*tx)
333 .await?;
334 while let Some(row) = rows.next().await {
335 let row = row?;
336 channels.push(Channel {
337 id: row.id,
338 name: row.name,
339 })
340 }
341 }
342
343 let mut edges = Vec::with_capacity(parents_by_child_id.len());
344 for (channel, parents) in parents_by_child_id.iter() {
345 for parent in parents.into_iter() {
346 if trim_dangling_parents {
347 if parents_by_child_id.contains_key(parent) {
348 edges.push(ChannelEdge {
349 channel_id: channel.to_proto(),
350 parent_id: parent.to_proto(),
351 });
352 }
353 } else {
354 edges.push(ChannelEdge {
355 channel_id: channel.to_proto(),
356 parent_id: parent.to_proto(),
357 });
358 }
359 }
360 }
361
362 Ok(ChannelGraph { channels, edges })
363 }
364
365 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
366 self.transaction(|tx| async move {
367 let tx = tx;
368
369 let channel_memberships = channel_member::Entity::find()
370 .filter(
371 channel_member::Column::UserId
372 .eq(user_id)
373 .and(channel_member::Column::Accepted.eq(true)),
374 )
375 .all(&*tx)
376 .await?;
377
378 self.get_user_channels(user_id, channel_memberships, &tx)
379 .await
380 })
381 .await
382 }
383
384 pub async fn get_channel_for_user(
385 &self,
386 channel_id: ChannelId,
387 user_id: UserId,
388 ) -> Result<ChannelsForUser> {
389 self.transaction(|tx| async move {
390 let tx = tx;
391
392 let channel_membership = channel_member::Entity::find()
393 .filter(
394 channel_member::Column::UserId
395 .eq(user_id)
396 .and(channel_member::Column::ChannelId.eq(channel_id))
397 .and(channel_member::Column::Accepted.eq(true)),
398 )
399 .all(&*tx)
400 .await?;
401
402 self.get_user_channels(user_id, channel_membership, &tx)
403 .await
404 })
405 .await
406 }
407
408 pub async fn get_user_channels(
409 &self,
410 user_id: UserId,
411 channel_memberships: Vec<channel_member::Model>,
412 tx: &DatabaseTransaction,
413 ) -> Result<ChannelsForUser> {
414 let parents_by_child_id = self
415 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
416 .await?;
417
418 let channels_with_admin_privileges = channel_memberships
419 .iter()
420 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
421 .collect();
422
423 let graph = self
424 .get_channel_graph(parents_by_child_id, true, &tx)
425 .await?;
426
427 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
428 enum QueryUserIdsAndChannelIds {
429 ChannelId,
430 UserId,
431 }
432
433 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
434 {
435 let mut rows = room_participant::Entity::find()
436 .inner_join(room::Entity)
437 .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id)))
438 .select_only()
439 .column(room::Column::ChannelId)
440 .column(room_participant::Column::UserId)
441 .into_values::<_, QueryUserIdsAndChannelIds>()
442 .stream(&*tx)
443 .await?;
444 while let Some(row) = rows.next().await {
445 let row: (ChannelId, UserId) = row?;
446 channel_participants.entry(row.0).or_default().push(row.1)
447 }
448 }
449
450 let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
451 let channel_buffer_changes = self
452 .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
453 .await?;
454
455 let unseen_messages = self
456 .unseen_channel_messages(user_id, &channel_ids, &*tx)
457 .await?;
458
459 Ok(ChannelsForUser {
460 channels: graph,
461 channel_participants,
462 channels_with_admin_privileges,
463 unseen_buffer_changes: channel_buffer_changes,
464 channel_messages: unseen_messages,
465 })
466 }
467
468 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
469 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
470 .await
471 }
472
473 pub async fn set_channel_member_admin(
474 &self,
475 channel_id: ChannelId,
476 from: UserId,
477 for_user: UserId,
478 admin: bool,
479 ) -> Result<()> {
480 self.transaction(|tx| async move {
481 self.check_user_is_channel_admin(channel_id, from, &*tx)
482 .await?;
483
484 let result = channel_member::Entity::update_many()
485 .filter(
486 channel_member::Column::ChannelId
487 .eq(channel_id)
488 .and(channel_member::Column::UserId.eq(for_user)),
489 )
490 .set(channel_member::ActiveModel {
491 admin: ActiveValue::set(admin),
492 ..Default::default()
493 })
494 .exec(&*tx)
495 .await?;
496
497 if result.rows_affected == 0 {
498 Err(anyhow!("no such member"))?;
499 }
500
501 Ok(())
502 })
503 .await
504 }
505
    /// Lists the members of `channel_id`, merged into one entry per user.
    ///
    /// `user_id` must be an admin of the channel or of one of its ancestors. Membership
    /// rows from the channel and every ancestor are read in user-id order, and each row
    /// is classified as follows:
    /// - a row on `channel_id` itself yields `Member` (accepted) or `Invitee` (pending);
    /// - an accepted row on an ancestor yields `AncestorMember`;
    /// - a pending row on an ancestor is ignored.
    ///
    /// When a user has both a direct row and ancestor rows, the direct row's kind and
    /// admin flag win. Otherwise the first accepted ancestor row seen for that user is kept.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column aliases for the tuple read by `into_values` below, in select order.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            // Contains `channel_id` itself as well as every ancestor.
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                // True when the row belongs to `channel_id` rather than to an ancestor.
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                // Ordering by user keeps each user's rows adjacent for the merge below.
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invitations to an ancestor are not reported.
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Same user as the previous entry: merge instead of pushing a duplicate.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            // The direct membership takes precedence over inherited rows.
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
577
578 pub async fn get_channel_members_internal(
579 &self,
580 id: ChannelId,
581 tx: &DatabaseTransaction,
582 ) -> Result<Vec<UserId>> {
583 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
584 let user_ids = channel_member::Entity::find()
585 .distinct()
586 .filter(
587 channel_member::Column::ChannelId
588 .is_in(ancestor_ids.iter().copied())
589 .and(channel_member::Column::Accepted.eq(true)),
590 )
591 .select_only()
592 .column(channel_member::Column::UserId)
593 .into_values::<_, QueryUserIds>()
594 .all(&*tx)
595 .await?;
596 Ok(user_ids)
597 }
598
599 pub async fn check_user_is_channel_member(
600 &self,
601 channel_id: ChannelId,
602 user_id: UserId,
603 tx: &DatabaseTransaction,
604 ) -> Result<()> {
605 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
606 channel_member::Entity::find()
607 .filter(
608 channel_member::Column::ChannelId
609 .is_in(channel_ids)
610 .and(channel_member::Column::UserId.eq(user_id)),
611 )
612 .one(&*tx)
613 .await?
614 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
615 Ok(())
616 }
617
618 pub async fn check_user_is_channel_admin(
619 &self,
620 channel_id: ChannelId,
621 user_id: UserId,
622 tx: &DatabaseTransaction,
623 ) -> Result<()> {
624 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
625 channel_member::Entity::find()
626 .filter(
627 channel_member::Column::ChannelId
628 .is_in(channel_ids)
629 .and(channel_member::Column::UserId.eq(user_id))
630 .and(channel_member::Column::Admin.eq(true)),
631 )
632 .one(&*tx)
633 .await?
634 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
635 Ok(())
636 }
637
638 /// Returns the channel ancestors, deepest first
639 pub async fn get_channel_ancestors(
640 &self,
641 channel_id: ChannelId,
642 tx: &DatabaseTransaction,
643 ) -> Result<Vec<ChannelId>> {
644 let paths = channel_path::Entity::find()
645 .filter(channel_path::Column::ChannelId.eq(channel_id))
646 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
647 .all(tx)
648 .await?;
649 let mut channel_ids = Vec::new();
650 for path in paths {
651 for id in path.id_path.trim_matches('/').split('/') {
652 if let Ok(id) = id.parse() {
653 let id = ChannelId::from_proto(id);
654 if let Err(ix) = channel_ids.binary_search(&id) {
655 channel_ids.insert(ix, id);
656 }
657 }
658 }
659 }
660 Ok(channel_ids)
661 }
662
663 /// Returns the channel descendants,
664 /// Structured as a map from child ids to their parent ids
665 /// For example, the descendants of 'a' in this DAG:
666 ///
667 /// /- b -\
668 /// a -- c -- d
669 ///
670 /// would be:
671 /// {
672 /// a: [],
673 /// b: [a],
674 /// c: [a],
675 /// d: [a, c],
676 /// }
677 async fn get_channel_descendants(
678 &self,
679 channel_ids: impl IntoIterator<Item = ChannelId>,
680 tx: &DatabaseTransaction,
681 ) -> Result<ChannelDescendants> {
682 let mut values = String::new();
683 for id in channel_ids {
684 if !values.is_empty() {
685 values.push_str(", ");
686 }
687 write!(&mut values, "({})", id).unwrap();
688 }
689
690 if values.is_empty() {
691 return Ok(HashMap::default());
692 }
693
694 let sql = format!(
695 r#"
696 SELECT
697 descendant_paths.*
698 FROM
699 channel_paths parent_paths, channel_paths descendant_paths
700 WHERE
701 parent_paths.channel_id IN ({values}) AND
702 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
703 "#
704 );
705
706 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
707
708 let mut parents_by_child_id: ChannelDescendants = HashMap::default();
709 let mut paths = channel_path::Entity::find()
710 .from_raw_sql(stmt)
711 .stream(tx)
712 .await?;
713
714 while let Some(path) = paths.next().await {
715 let path = path?;
716 let ids = path.id_path.trim_matches('/').split('/');
717 let mut parent_id = None;
718 for id in ids {
719 if let Ok(id) = id.parse() {
720 let id = ChannelId::from_proto(id);
721 if id == path.channel_id {
722 break;
723 }
724 parent_id = Some(id);
725 }
726 }
727 let entry = parents_by_child_id.entry(path.channel_id).or_default();
728 if let Some(parent_id) = parent_id {
729 entry.insert(parent_id);
730 }
731 }
732
733 Ok(parents_by_child_id)
734 }
735
736 /// Returns the channel with the given ID and:
737 /// - true if the user is a member
738 /// - false if the user hasn't accepted the invitation yet
739 pub async fn get_channel(
740 &self,
741 channel_id: ChannelId,
742 user_id: UserId,
743 ) -> Result<Option<(Channel, bool)>> {
744 self.transaction(|tx| async move {
745 let tx = tx;
746
747 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
748
749 if let Some(channel) = channel {
750 if self
751 .check_user_is_channel_member(channel_id, user_id, &*tx)
752 .await
753 .is_err()
754 {
755 return Ok(None);
756 }
757
758 let channel_membership = channel_member::Entity::find()
759 .filter(
760 channel_member::Column::ChannelId
761 .eq(channel_id)
762 .and(channel_member::Column::UserId.eq(user_id)),
763 )
764 .one(&*tx)
765 .await?;
766
767 let is_accepted = channel_membership
768 .map(|membership| membership.accepted)
769 .unwrap_or(false);
770
771 Ok(Some((
772 Channel {
773 id: channel.id,
774 name: channel.name,
775 },
776 is_accepted,
777 )))
778 } else {
779 Ok(None)
780 }
781 })
782 .await
783 }
784
785 pub async fn get_or_create_channel_room(
786 &self,
787 channel_id: ChannelId,
788 live_kit_room: &str,
789 enviroment: &str,
790 ) -> Result<RoomId> {
791 self.transaction(|tx| async move {
792 let tx = tx;
793
794 let room = room::Entity::find()
795 .filter(room::Column::ChannelId.eq(channel_id))
796 .one(&*tx)
797 .await?;
798
799 let room_id = if let Some(room) = room {
800 room.id
801 } else {
802 let result = room::Entity::insert(room::ActiveModel {
803 channel_id: ActiveValue::Set(Some(channel_id)),
804 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
805 enviroment: ActiveValue::Set(Some(enviroment.to_string())),
806 ..Default::default()
807 })
808 .exec(&*tx)
809 .await?;
810
811 result.last_insert_id
812 };
813
814 Ok(room_id)
815 })
816 .await
817 }
818
819 // Insert an edge from the given channel to the given other channel.
820 pub async fn link_channel(
821 &self,
822 user: UserId,
823 channel: ChannelId,
824 to: ChannelId,
825 ) -> Result<ChannelGraph> {
826 self.transaction(|tx| async move {
827 // Note that even with these maxed permissions, this linking operation
828 // is still insecure because you can't remove someone's permissions to a
829 // channel if they've linked the channel to one where they're an admin.
830 self.check_user_is_channel_admin(channel, user, &*tx)
831 .await?;
832
833 self.link_channel_internal(user, channel, to, &*tx).await
834 })
835 .await
836 }
837
838 pub async fn link_channel_internal(
839 &self,
840 user: UserId,
841 channel: ChannelId,
842 to: ChannelId,
843 tx: &DatabaseTransaction,
844 ) -> Result<ChannelGraph> {
845 self.check_user_is_channel_admin(to, user, &*tx).await?;
846
847 let paths = channel_path::Entity::find()
848 .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
849 .all(tx)
850 .await?;
851
852 let mut new_path_suffixes = HashSet::default();
853 for path in paths {
854 if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
855 new_path_suffixes.insert((
856 path.channel_id,
857 path.id_path[(start_offset + 1)..].to_string(),
858 ));
859 }
860 }
861
862 let paths_to_new_parent = channel_path::Entity::find()
863 .filter(channel_path::Column::ChannelId.eq(to))
864 .all(tx)
865 .await?;
866
867 let mut new_paths = Vec::new();
868 for path in paths_to_new_parent {
869 if path.id_path.contains(&format!("/{}/", channel)) {
870 Err(anyhow!("cycle"))?;
871 }
872
873 new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
874 channel_path::ActiveModel {
875 channel_id: ActiveValue::Set(*channel_id),
876 id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
877 }
878 }));
879 }
880
881 channel_path::Entity::insert_many(new_paths)
882 .exec(&*tx)
883 .await?;
884
885 // remove any root edges for the channel we just linked
886 {
887 channel_path::Entity::delete_many()
888 .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
889 .exec(&*tx)
890 .await?;
891 }
892
893 let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?;
894 if let Some(channel) = channel_descendants.get_mut(&channel) {
895 // Remove the other parents
896 channel.clear();
897 channel.insert(to);
898 }
899
900 let channels = self
901 .get_channel_graph(channel_descendants, false, &*tx)
902 .await?;
903
904 Ok(channels)
905 }
906
907 /// Unlink a channel from a given parent. This will add in a root edge if
908 /// the channel has no other parents after this operation.
909 pub async fn unlink_channel(
910 &self,
911 user: UserId,
912 channel: ChannelId,
913 from: ChannelId,
914 ) -> Result<()> {
915 self.transaction(|tx| async move {
916 // Note that even with these maxed permissions, this linking operation
917 // is still insecure because you can't remove someone's permissions to a
918 // channel if they've linked the channel to one where they're an admin.
919 self.check_user_is_channel_admin(channel, user, &*tx)
920 .await?;
921
922 self.unlink_channel_internal(user, channel, from, &*tx)
923 .await?;
924
925 Ok(())
926 })
927 .await
928 }
929
930 pub async fn unlink_channel_internal(
931 &self,
932 user: UserId,
933 channel: ChannelId,
934 from: ChannelId,
935 tx: &DatabaseTransaction,
936 ) -> Result<()> {
937 self.check_user_is_channel_admin(from, user, &*tx).await?;
938
939 let sql = r#"
940 DELETE FROM channel_paths
941 WHERE
942 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
943 RETURNING id_path, channel_id
944 "#;
945
946 let paths = channel_path::Entity::find()
947 .from_raw_sql(Statement::from_sql_and_values(
948 self.pool.get_database_backend(),
949 sql,
950 [from.to_proto().into(), channel.to_proto().into()],
951 ))
952 .all(&*tx)
953 .await?;
954
955 let is_stranded = channel_path::Entity::find()
956 .filter(channel_path::Column::ChannelId.eq(channel))
957 .count(&*tx)
958 .await?
959 == 0;
960
961 // Make sure that there is always at least one path to the channel
962 if is_stranded {
963 let root_paths: Vec<_> = paths
964 .iter()
965 .map(|path| {
966 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
967 channel_path::ActiveModel {
968 channel_id: ActiveValue::Set(path.channel_id),
969 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
970 }
971 })
972 .collect();
973 channel_path::Entity::insert_many(root_paths)
974 .exec(&*tx)
975 .await?;
976 }
977
978 Ok(())
979 }
980
981 /// Move a channel from one parent to another, returns the
982 /// Channels that were moved for notifying clients
983 pub async fn move_channel(
984 &self,
985 user: UserId,
986 channel: ChannelId,
987 from: ChannelId,
988 to: ChannelId,
989 ) -> Result<ChannelGraph> {
990 if from == to {
991 return Ok(ChannelGraph {
992 channels: vec![],
993 edges: vec![],
994 });
995 }
996
997 self.transaction(|tx| async move {
998 self.check_user_is_channel_admin(channel, user, &*tx)
999 .await?;
1000
1001 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1002
1003 self.unlink_channel_internal(user, channel, from, &*tx)
1004 .await?;
1005
1006 Ok(moved_channels)
1007 })
1008 .await
1009 }
1010}
1011
/// Single-column selector used with `into_values` to read only the `user_id` column
/// of a query as `UserId` values.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1016
/// A set of channels together with the parent edges between them.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels in the graph.
    pub channels: Vec<Channel>,
    /// Child→parent edges, expressed with the channels' protocol ids.
    pub edges: Vec<ChannelEdge>,
}
1022
1023impl ChannelGraph {
1024 pub fn is_empty(&self) -> bool {
1025 self.channels.is_empty() && self.edges.is_empty()
1026 }
1027}
1028
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Order-independent comparison for tests: graphs are equal when they hold the same
    /// set of channels and the same set of (child, parent) edge pairs.
    fn eq(&self, other: &Self) -> bool {
        let same_channels = self.channels.iter().collect::<HashSet<_>>()
            == other.channels.iter().collect::<HashSet<_>>();
        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };
        same_channels && edge_pairs(self) == edge_pairs(other)
    }
}
1049
1050#[cfg(not(test))]
1051impl PartialEq for ChannelGraph {
1052 fn eq(&self, other: &Self) -> bool {
1053 self.channels == other.channels && self.edges == other.edges
1054 }
1055}
1056
/// A sorted, deduplicated set backed by a `SmallVec` with room for one inline element,
/// so the common single-element case needs no heap allocation.
struct SmallSet<T>(SmallVec<[T; 1]>);
1058
1059impl<T> Deref for SmallSet<T> {
1060 type Target = [T];
1061
1062 fn deref(&self) -> &Self::Target {
1063 self.0.deref()
1064 }
1065}
1066
1067impl<T> Default for SmallSet<T> {
1068 fn default() -> Self {
1069 Self(SmallVec::new())
1070 }
1071}
1072
1073impl<T> SmallSet<T> {
1074 fn insert(&mut self, value: T) -> bool
1075 where
1076 T: Ord,
1077 {
1078 match self.binary_search(&value) {
1079 Ok(_) => false,
1080 Err(ix) => {
1081 self.0.insert(ix, value);
1082 true
1083 }
1084 }
1085 }
1086
1087 fn clear(&mut self) {
1088 self.0.clear();
1089 }
1090}