1use super::*;
2use rpc::proto::{channel_member::Kind, ChannelEdge};
3
4impl Database {
5 #[cfg(test)]
6 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
7 self.transaction(move |tx| async move {
8 let mut channels = Vec::new();
9 let mut rows = channel::Entity::find().stream(&*tx).await?;
10 while let Some(row) = rows.next().await {
11 let row = row?;
12 channels.push((row.id, row.name));
13 }
14 Ok(channels)
15 })
16 .await
17 }
18
19 pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
20 self.create_channel(name, None, creator_id).await
21 }
22
23 pub async fn create_channel(
24 &self,
25 name: &str,
26 parent: Option<ChannelId>,
27 creator_id: UserId,
28 ) -> Result<ChannelId> {
29 let name = Self::sanitize_channel_name(name)?;
30 self.transaction(move |tx| async move {
31 if let Some(parent) = parent {
32 self.check_user_is_channel_admin(parent, creator_id, &*tx)
33 .await?;
34 }
35
36 let channel = channel::ActiveModel {
37 id: ActiveValue::NotSet,
38 name: ActiveValue::Set(name.to_string()),
39 visibility: ActiveValue::Set(ChannelVisibility::Members),
40 }
41 .insert(&*tx)
42 .await?;
43
44 if let Some(parent) = parent {
45 let sql = r#"
46 INSERT INTO channel_paths
47 (id_path, channel_id)
48 SELECT
49 id_path || $1 || '/', $2
50 FROM
51 channel_paths
52 WHERE
53 channel_id = $3
54 "#;
55 let channel_paths_stmt = Statement::from_sql_and_values(
56 self.pool.get_database_backend(),
57 sql,
58 [
59 channel.id.to_proto().into(),
60 channel.id.to_proto().into(),
61 parent.to_proto().into(),
62 ],
63 );
64 tx.execute(channel_paths_stmt).await?;
65 } else {
66 channel_path::Entity::insert(channel_path::ActiveModel {
67 channel_id: ActiveValue::Set(channel.id),
68 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
69 })
70 .exec(&*tx)
71 .await?;
72 }
73
74 channel_member::ActiveModel {
75 id: ActiveValue::NotSet,
76 channel_id: ActiveValue::Set(channel.id),
77 user_id: ActiveValue::Set(creator_id),
78 accepted: ActiveValue::Set(true),
79 role: ActiveValue::Set(ChannelRole::Admin),
80 }
81 .insert(&*tx)
82 .await?;
83
84 Ok(channel.id)
85 })
86 .await
87 }
88
    /// Joins `user_id` (on `connection`) to the call room of `channel_id`.
    ///
    /// If the user has no role on the channel yet, the function tries two things in
    /// order:
    /// - It accepts a pending invitation found on the channel or one of its
    ///   ancestors.
    /// - Failing that, when the channel itself is public, it inserts an accepted
    ///   `Member` membership on the channel's most public ancestor. If no public
    ///   ancestor is found, the channel itself is used.
    ///
    /// Returns the joined room and the id of the channel whose membership this call
    /// accepted or created. That id is `None` if the user already had a role.
    ///
    /// # Errors
    /// - Fails with "no such channel, or not allowed" if the channel does not exist,
    ///   if the user still has no role, or if the user is banned.
    /// - Propagates errors from room lookup or creation (for example an environment
    ///   mismatch) and from joining the room.
    pub async fn join_channel(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        connection: ConnectionId,
        environment: &str,
    ) -> Result<(JoinRoom, Option<ChannelId>)> {
        self.transaction(move |tx| async move {
            // Set only when this call accepts an invitation or creates a membership.
            let mut joined_channel_id = None;

            let channel = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?;

            let mut role = self
                .channel_role_for_user(channel_id, user_id, &*tx)
                .await?;

            // No effective role yet: accept a pending invitation if one exists.
            if role.is_none() && channel.is_some() {
                if let Some(invitation) = self
                    .pending_invite_for_channel(channel_id, user_id, &*tx)
                    .await?
                {
                    // note, this may be a parent channel
                    joined_channel_id = Some(invitation.channel_id);
                    role = Some(invitation.role);

                    channel_member::Entity::update(channel_member::ActiveModel {
                        accepted: ActiveValue::Set(true),
                        ..invitation.into_active_model()
                    })
                    .exec(&*tx)
                    .await?;

                    // Debug-only sanity check that accepting produced the expected role.
                    debug_assert!(
                        self.channel_role_for_user(channel_id, user_id, &*tx)
                            .await?
                            == role
                    );
                }
            }
            // Still no role, but the channel is public: grant membership on the
            // most public ancestor so the user's access covers this channel.
            if role.is_none()
                && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
            {
                let channel_id_to_join = self
                    .most_public_ancestor_for_channel(channel_id, &*tx)
                    .await?
                    .unwrap_or(channel_id);
                // TODO: change this back to Guest.
                role = Some(ChannelRole::Member);
                joined_channel_id = Some(channel_id_to_join);

                channel_member::Entity::insert(channel_member::ActiveModel {
                    id: ActiveValue::NotSet,
                    channel_id: ActiveValue::Set(channel_id_to_join),
                    user_id: ActiveValue::Set(user_id),
                    accepted: ActiveValue::Set(true),
                    // TODO: change this back to Guest.
                    role: ActiveValue::Set(ChannelRole::Member),
                })
                .exec(&*tx)
                .await?;

                // Debug-only sanity check that the new membership yields this role.
                debug_assert!(
                    self.channel_role_for_user(channel_id, user_id, &*tx)
                        .await?
                        == role
                );
            }

            if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
                Err(anyhow!("no such channel, or not allowed"))?
            }

            // Random LiveKit room name. get_or_create_channel_room only uses it when
            // it has to create the channel's room.
            let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
            let room_id = self
                .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
                .await?;

            self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
                .await
                .map(|jr| (jr, joined_channel_id))
        })
        .await
    }
175
176 pub async fn set_channel_visibility(
177 &self,
178 channel_id: ChannelId,
179 visibility: ChannelVisibility,
180 user_id: UserId,
181 ) -> Result<channel::Model> {
182 self.transaction(move |tx| async move {
183 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
184 .await?;
185
186 let channel = channel::ActiveModel {
187 id: ActiveValue::Unchanged(channel_id),
188 visibility: ActiveValue::Set(visibility),
189 ..Default::default()
190 }
191 .update(&*tx)
192 .await?;
193
194 Ok(channel)
195 })
196 .await
197 }
198
199 pub async fn delete_channel(
200 &self,
201 channel_id: ChannelId,
202 user_id: UserId,
203 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
204 self.transaction(move |tx| async move {
205 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
206 .await?;
207
208 // Don't remove descendant channels that have additional parents.
209 let mut channels_to_remove: HashSet<ChannelId> = HashSet::default();
210 channels_to_remove.insert(channel_id);
211
212 let graph = self.get_channel_descendants([channel_id], &*tx).await?;
213 for edge in graph.iter() {
214 channels_to_remove.insert(ChannelId::from_proto(edge.channel_id));
215 }
216
217 {
218 let mut channels_to_keep = channel_path::Entity::find()
219 .filter(
220 channel_path::Column::ChannelId
221 .is_in(channels_to_remove.iter().copied())
222 .and(
223 channel_path::Column::IdPath
224 .not_like(&format!("%/{}/%", channel_id)),
225 ),
226 )
227 .stream(&*tx)
228 .await?;
229 while let Some(row) = channels_to_keep.next().await {
230 let row = row?;
231 channels_to_remove.remove(&row.channel_id);
232 }
233 }
234
235 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
236 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
237 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
238 .select_only()
239 .column(channel_member::Column::UserId)
240 .distinct()
241 .into_values::<_, QueryUserIds>()
242 .all(&*tx)
243 .await?;
244
245 channel::Entity::delete_many()
246 .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied()))
247 .exec(&*tx)
248 .await?;
249
250 // Delete any other paths that include this channel
251 let sql = r#"
252 DELETE FROM channel_paths
253 WHERE
254 id_path LIKE '%' || $1 || '%'
255 "#;
256 let channel_paths_stmt = Statement::from_sql_and_values(
257 self.pool.get_database_backend(),
258 sql,
259 [channel_id.to_proto().into()],
260 );
261 tx.execute(channel_paths_stmt).await?;
262
263 Ok((channels_to_remove.into_iter().collect(), members_to_notify))
264 })
265 .await
266 }
267
268 pub async fn invite_channel_member(
269 &self,
270 channel_id: ChannelId,
271 invitee_id: UserId,
272 admin_id: UserId,
273 role: ChannelRole,
274 ) -> Result<()> {
275 self.transaction(move |tx| async move {
276 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
277 .await?;
278
279 channel_member::ActiveModel {
280 id: ActiveValue::NotSet,
281 channel_id: ActiveValue::Set(channel_id),
282 user_id: ActiveValue::Set(invitee_id),
283 accepted: ActiveValue::Set(false),
284 role: ActiveValue::Set(role),
285 }
286 .insert(&*tx)
287 .await?;
288
289 Ok(())
290 })
291 .await
292 }
293
294 fn sanitize_channel_name(name: &str) -> Result<&str> {
295 let new_name = name.trim().trim_start_matches('#');
296 if new_name == "" {
297 Err(anyhow!("channel name can't be blank"))?;
298 }
299 Ok(new_name)
300 }
301
302 pub async fn rename_channel(
303 &self,
304 channel_id: ChannelId,
305 user_id: UserId,
306 new_name: &str,
307 ) -> Result<Channel> {
308 self.transaction(move |tx| async move {
309 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
310
311 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
312 .await?;
313
314 let channel = channel::ActiveModel {
315 id: ActiveValue::Unchanged(channel_id),
316 name: ActiveValue::Set(new_name.clone()),
317 ..Default::default()
318 }
319 .update(&*tx)
320 .await?;
321
322 Ok(Channel {
323 id: channel.id,
324 name: channel.name,
325 visibility: channel.visibility,
326 })
327 })
328 .await
329 }
330
331 pub async fn respond_to_channel_invite(
332 &self,
333 channel_id: ChannelId,
334 user_id: UserId,
335 accept: bool,
336 ) -> Result<()> {
337 self.transaction(move |tx| async move {
338 let rows_affected = if accept {
339 channel_member::Entity::update_many()
340 .set(channel_member::ActiveModel {
341 accepted: ActiveValue::Set(accept),
342 ..Default::default()
343 })
344 .filter(
345 channel_member::Column::ChannelId
346 .eq(channel_id)
347 .and(channel_member::Column::UserId.eq(user_id))
348 .and(channel_member::Column::Accepted.eq(false)),
349 )
350 .exec(&*tx)
351 .await?
352 .rows_affected
353 } else {
354 channel_member::ActiveModel {
355 channel_id: ActiveValue::Unchanged(channel_id),
356 user_id: ActiveValue::Unchanged(user_id),
357 ..Default::default()
358 }
359 .delete(&*tx)
360 .await?
361 .rows_affected
362 };
363
364 if rows_affected == 0 {
365 Err(anyhow!("no such invitation"))?;
366 }
367
368 Ok(())
369 })
370 .await
371 }
372
373 pub async fn remove_channel_member(
374 &self,
375 channel_id: ChannelId,
376 member_id: UserId,
377 admin_id: UserId,
378 ) -> Result<()> {
379 self.transaction(|tx| async move {
380 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
381 .await?;
382
383 let result = channel_member::Entity::delete_many()
384 .filter(
385 channel_member::Column::ChannelId
386 .eq(channel_id)
387 .and(channel_member::Column::UserId.eq(member_id)),
388 )
389 .exec(&*tx)
390 .await?;
391
392 if result.rows_affected == 0 {
393 Err(anyhow!("no such member"))?;
394 }
395
396 Ok(())
397 })
398 .await
399 }
400
401 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
402 self.transaction(|tx| async move {
403 let channel_invites = channel_member::Entity::find()
404 .filter(
405 channel_member::Column::UserId
406 .eq(user_id)
407 .and(channel_member::Column::Accepted.eq(false)),
408 )
409 .all(&*tx)
410 .await?;
411
412 let channels = channel::Entity::find()
413 .filter(
414 channel::Column::Id.is_in(
415 channel_invites
416 .into_iter()
417 .map(|channel_member| channel_member.channel_id),
418 ),
419 )
420 .all(&*tx)
421 .await?;
422
423 let channels = channels
424 .into_iter()
425 .map(|channel| Channel {
426 id: channel.id,
427 name: channel.name,
428 visibility: channel.visibility,
429 })
430 .collect();
431
432 Ok(channels)
433 })
434 .await
435 }
436
437 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
438 self.transaction(|tx| async move {
439 let tx = tx;
440
441 let channel_memberships = channel_member::Entity::find()
442 .filter(
443 channel_member::Column::UserId
444 .eq(user_id)
445 .and(channel_member::Column::Accepted.eq(true)),
446 )
447 .all(&*tx)
448 .await?;
449
450 self.get_user_channels(user_id, channel_memberships, &tx)
451 .await
452 })
453 .await
454 }
455
456 pub async fn get_channel_for_user(
457 &self,
458 channel_id: ChannelId,
459 user_id: UserId,
460 ) -> Result<ChannelsForUser> {
461 self.transaction(|tx| async move {
462 let tx = tx;
463
464 let channel_membership = channel_member::Entity::find()
465 .filter(
466 channel_member::Column::UserId
467 .eq(user_id)
468 .and(channel_member::Column::ChannelId.eq(channel_id))
469 .and(channel_member::Column::Accepted.eq(true)),
470 )
471 .all(&*tx)
472 .await?;
473
474 self.get_user_channels(user_id, channel_membership, &tx)
475 .await
476 })
477 .await
478 }
479
    /// Builds the channel graph visible to `user_id` from the given, already
    /// fetched, memberships.
    ///
    /// The function proceeds in these steps:
    /// - Roles from `channel_memberships` are propagated to every descendant
    ///   channel. A descendant keeps its own role when that role `should_override`
    ///   the inherited one.
    /// - Channels whose resulting role is `Banned`, or `Guest` on a non-public
    ///   channel, are dropped. Edges touching them are re-pointed to the nearest
    ///   surviving ancestor, or dropped.
    /// - For the remaining channels it collects room participants, the channels
    ///   where the user is `Admin`, and unseen buffer and message state.
    pub async fn get_user_channels(
        &self,
        user_id: UserId,
        channel_memberships: Vec<channel_member::Model>,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelsForUser> {
        // Edges below the membership channels, ordered (see get_channel_descendants)
        // so that a channel appears as a child before it appears as a parent.
        let mut edges = self
            .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
            .await?;

        let mut role_for_channel: HashMap<ChannelId, ChannelRole> = HashMap::default();

        for membership in channel_memberships.iter() {
            role_for_channel.insert(membership.channel_id, membership.role);
        }

        // Push each parent's role down to its child, in edge order.
        for ChannelEdge {
            parent_id,
            channel_id,
        } in edges.iter()
        {
            let parent_id = ChannelId::from_proto(*parent_id);
            let channel_id = ChannelId::from_proto(*channel_id);
            // Relies on the edge ordering: the parent is either a membership channel
            // or the child of an earlier edge. The indexing below panics otherwise.
            debug_assert!(role_for_channel.get(&parent_id).is_some());
            let parent_role = role_for_channel[&parent_id];
            if let Some(existing_role) = role_for_channel.get(&channel_id) {
                if existing_role.should_override(parent_role) {
                    continue;
                }
            }
            role_for_channel.insert(channel_id, parent_role);
        }

        let mut channels: Vec<Channel> = Vec::new();
        let mut channels_with_admin_privileges: HashSet<ChannelId> = HashSet::default();
        // Kept as u64 proto ids so they compare directly against `ChannelEdge` ids.
        let mut channels_to_remove: HashSet<u64> = HashSet::default();

        let mut rows = channel::Entity::find()
            .filter(channel::Column::Id.is_in(role_for_channel.keys().copied()))
            .stream(&*tx)
            .await?;

        while let Some(row) = rows.next().await {
            let channel = row?;
            let role = role_for_channel[&channel.id];

            // Hidden: banned users, and guests of channels that are not public.
            if role == ChannelRole::Banned
                || role == ChannelRole::Guest && channel.visibility != ChannelVisibility::Public
            {
                channels_to_remove.insert(channel.id.0 as u64);
                continue;
            }

            channels.push(Channel {
                id: channel.id,
                name: channel.name,
                visibility: channel.visibility,
            });

            if role == ChannelRole::Admin {
                channels_with_admin_privileges.insert(channel.id);
            }
        }
        // Drop the stream before issuing further queries on `tx`.
        drop(rows);

        if !channels_to_remove.is_empty() {
            // Note: this code assumes each channel has one parent.
            // If there are multiple valid public paths to a channel,
            // e.g.
            // If both of these paths are present (* indicating public):
            // - zed* -> projects -> vim*
            // - zed* -> conrad -> public-projects* -> vim*
            // Users would only see one of them (based on edge sort order)
            //
            // Map each removed channel to its parent, so edges can skip over it.
            let mut replacement_parent: HashMap<u64, u64> = HashMap::default();
            for ChannelEdge {
                parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    replacement_parent.insert(*channel_id, *parent_id);
                }
            }

            // Rebuild the edges. Edges into removed channels are dropped. Edges out of
            // removed channels are re-parented up the chain. If the chain ends in a
            // removed channel with no recorded parent, the edge is dropped and the
            // child becomes a root.
            let mut new_edges: Vec<ChannelEdge> = Vec::new();
            'outer: for ChannelEdge {
                mut parent_id,
                channel_id,
            } in edges.iter()
            {
                if channels_to_remove.contains(channel_id) {
                    continue;
                }
                while channels_to_remove.contains(&parent_id) {
                    if let Some(new_parent_id) = replacement_parent.get(&parent_id) {
                        parent_id = *new_parent_id;
                    } else {
                        continue 'outer;
                    }
                }
                new_edges.push(ChannelEdge {
                    parent_id,
                    channel_id: *channel_id,
                })
            }
            edges = new_edges;
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryUserIdsAndChannelIds {
            ChannelId,
            UserId,
        }

        // Users currently in each visible channel's room.
        let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
        {
            let mut rows = room_participant::Entity::find()
                .inner_join(room::Entity)
                .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
                .select_only()
                .column(room::Column::ChannelId)
                .column(room_participant::Column::UserId)
                .into_values::<_, QueryUserIdsAndChannelIds>()
                .stream(&*tx)
                .await?;
            while let Some(row) = rows.next().await {
                let row: (ChannelId, UserId) = row?;
                channel_participants.entry(row.0).or_default().push(row.1)
            }
        }

        let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
        let channel_buffer_changes = self
            .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
            .await?;

        let unseen_messages = self
            .unseen_channel_messages(user_id, &channel_ids, &*tx)
            .await?;

        Ok(ChannelsForUser {
            channels: ChannelGraph { channels, edges },
            channel_participants,
            channels_with_admin_privileges,
            unseen_buffer_changes: channel_buffer_changes,
            channel_messages: unseen_messages,
        })
    }
628
629 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
630 self.transaction(|tx| async move { self.get_channel_participants_internal(id, &*tx).await })
631 .await
632 }
633
634 pub async fn set_channel_member_role(
635 &self,
636 channel_id: ChannelId,
637 admin_id: UserId,
638 for_user: UserId,
639 role: ChannelRole,
640 ) -> Result<channel_member::Model> {
641 self.transaction(|tx| async move {
642 self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
643 .await?;
644
645 let membership = channel_member::Entity::find()
646 .filter(
647 channel_member::Column::ChannelId
648 .eq(channel_id)
649 .and(channel_member::Column::UserId.eq(for_user)),
650 )
651 .one(&*tx)
652 .await?;
653
654 let Some(membership) = membership else {
655 Err(anyhow!("no such member"))?
656 };
657
658 let mut update = membership.into_active_model();
659 update.role = ActiveValue::Set(role);
660 let updated = channel_member::Entity::update(update).exec(&*tx).await?;
661
662 Ok(updated)
663 })
664 .await
665 }
666
    /// Lists the members of `channel_id`, with their kind and role, for a channel
    /// admin.
    ///
    /// `admin_id` must be an admin of the channel. Membership rows on the channel
    /// and all of its ancestors are merged per user, as follows.
    ///
    /// Each row is classified:
    /// - a direct accepted row becomes `Member`;
    /// - a direct unaccepted row becomes `Invitee`;
    /// - an accepted ancestor row becomes `AncestorMember`;
    /// - an unaccepted ancestor row is ignored.
    ///
    /// Guest rows are skipped unless the row's channel or `channel_id` itself is
    /// public.
    ///
    /// When a user has several rows, the role that `should_override` the others
    /// wins. `Member` beats any other kind, and `Invitee` beats `AncestorMember`.
    /// The result comes from a `HashMap`, so its order is unspecified.
    pub async fn get_channel_participant_details(
        &self,
        channel_id: ChannelId,
        admin_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
                .await?;

            let channel_visibility = channel::Entity::find()
                .filter(channel::Column::Id.eq(channel_id))
                .one(&*tx)
                .await?
                .map(|channel| channel.visibility)
                .unwrap_or(ChannelVisibility::Members);

            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Role,
                IsDirectMember,
                Accepted,
                Visibility,
            }

            let tx = tx;
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            // One row per membership on the channel or an ancestor. Each row
            // carries whether it is on `channel_id` itself, and the visibility of
            // the membership's channel.
            let mut stream = channel_member::Entity::find()
                .left_join(channel::Entity)
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Role)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                .column(channel::Column::Visibility)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            struct UserDetail {
                kind: Kind,
                channel_role: ChannelRole,
            }
            let mut user_details: HashMap<UserId, UserDetail> = HashMap::default();

            while let Some(user_membership) = stream.next().await {
                let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): (
                    UserId,
                    ChannelRole,
                    bool,
                    bool,
                    ChannelVisibility,
                ) = user_membership?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invitations on ancestors are not shown.
                    (false, false) => continue,
                };

                // Guest rows only count when something on the path is public.
                if channel_role == ChannelRole::Guest
                    && visibility != ChannelVisibility::Public
                    && channel_visibility != ChannelVisibility::Public
                {
                    continue;
                }

                // Merge with any earlier row for the same user.
                if let Some(details_mut) = user_details.get_mut(&user_id) {
                    if channel_role.should_override(details_mut.channel_role) {
                        details_mut.channel_role = channel_role;
                    }
                    if kind == Kind::Member {
                        details_mut.kind = kind;
                    // the UI is going to be a bit confusing if you already have permissions
                    // that are greater than or equal to the ones you're being invited to.
                    } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember {
                        details_mut.kind = kind;
                    }
                } else {
                    user_details.insert(user_id, UserDetail { kind, channel_role });
                }
            }

            Ok(user_details
                .into_iter()
                .map(|(user_id, details)| proto::ChannelMember {
                    user_id: user_id.to_proto(),
                    kind: details.kind.into(),
                    role: details.channel_role.into(),
                })
                .collect())
        })
        .await
    }
765
766 pub async fn get_channel_participants_internal(
767 &self,
768 id: ChannelId,
769 tx: &DatabaseTransaction,
770 ) -> Result<Vec<UserId>> {
771 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
772 let user_ids = channel_member::Entity::find()
773 .distinct()
774 .filter(
775 channel_member::Column::ChannelId
776 .is_in(ancestor_ids.iter().copied())
777 .and(channel_member::Column::Accepted.eq(true)),
778 )
779 .select_only()
780 .column(channel_member::Column::UserId)
781 .into_values::<_, QueryUserIds>()
782 .all(&*tx)
783 .await?;
784 Ok(user_ids)
785 }
786
787 pub async fn check_user_is_channel_admin(
788 &self,
789 channel_id: ChannelId,
790 user_id: UserId,
791 tx: &DatabaseTransaction,
792 ) -> Result<()> {
793 match self.channel_role_for_user(channel_id, user_id, tx).await? {
794 Some(ChannelRole::Admin) => Ok(()),
795 Some(ChannelRole::Member)
796 | Some(ChannelRole::Banned)
797 | Some(ChannelRole::Guest)
798 | None => Err(anyhow!(
799 "user is not a channel admin or channel does not exist"
800 ))?,
801 }
802 }
803
804 pub async fn check_user_is_channel_member(
805 &self,
806 channel_id: ChannelId,
807 user_id: UserId,
808 tx: &DatabaseTransaction,
809 ) -> Result<()> {
810 match self.channel_role_for_user(channel_id, user_id, tx).await? {
811 Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
812 Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
813 "user is not a channel member or channel does not exist"
814 ))?,
815 }
816 }
817
818 pub async fn check_user_is_channel_participant(
819 &self,
820 channel_id: ChannelId,
821 user_id: UserId,
822 tx: &DatabaseTransaction,
823 ) -> Result<()> {
824 match self.channel_role_for_user(channel_id, user_id, tx).await? {
825 Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
826 Ok(())
827 }
828 Some(ChannelRole::Banned) | None => Err(anyhow!(
829 "user is not a channel participant or channel does not exist"
830 ))?,
831 }
832 }
833
834 pub async fn pending_invite_for_channel(
835 &self,
836 channel_id: ChannelId,
837 user_id: UserId,
838 tx: &DatabaseTransaction,
839 ) -> Result<Option<channel_member::Model>> {
840 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
841
842 let row = channel_member::Entity::find()
843 .filter(channel_member::Column::ChannelId.is_in(channel_ids))
844 .filter(channel_member::Column::UserId.eq(user_id))
845 .filter(channel_member::Column::Accepted.eq(false))
846 .one(&*tx)
847 .await?;
848
849 Ok(row)
850 }
851
852 pub async fn most_public_ancestor_for_channel(
853 &self,
854 channel_id: ChannelId,
855 tx: &DatabaseTransaction,
856 ) -> Result<Option<ChannelId>> {
857 // Note: if there are many paths to a channel, this will return just one
858 let arbitary_path = channel_path::Entity::find()
859 .filter(channel_path::Column::ChannelId.eq(channel_id))
860 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
861 .one(tx)
862 .await?;
863
864 let Some(path) = arbitary_path else {
865 return Ok(None);
866 };
867
868 let ancestor_ids: Vec<ChannelId> = path
869 .id_path
870 .trim_matches('/')
871 .split('/')
872 .map(|id| ChannelId::from_proto(id.parse().unwrap()))
873 .collect();
874
875 let rows = channel::Entity::find()
876 .filter(channel::Column::Id.is_in(ancestor_ids.iter().copied()))
877 .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
878 .all(&*tx)
879 .await?;
880
881 let mut visible_channels: HashSet<ChannelId> = HashSet::default();
882
883 for row in rows {
884 visible_channels.insert(row.id);
885 }
886
887 for ancestor in ancestor_ids {
888 if visible_channels.contains(&ancestor) {
889 return Ok(Some(ancestor));
890 }
891 }
892
893 Ok(None)
894 }
895
    /// Computes `user_id`'s effective role on `channel_id` from the user's accepted
    /// memberships on the channel and all of its ancestors.
    ///
    /// `Admin`, `Member` and `Banned` memberships are combined by taking `max`
    /// under `ChannelRole`'s `Ord`. A `Guest` membership only counts when its own
    /// channel is public. Even then it yields `Guest` only if no other role was
    /// found and `channel_id` itself is public.
    ///
    /// Returns `None` when nothing applies, e.g. when the channel has no paths and
    /// so no ancestors.
    pub async fn channel_role_for_user(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        tx: &DatabaseTransaction,
    ) -> Result<Option<ChannelRole>> {
        let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;

        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
        enum QueryChannelMembership {
            ChannelId,
            Role,
            Visibility,
        }

        // Accepted memberships on the lineage, each with its channel's visibility.
        let mut rows = channel_member::Entity::find()
            .left_join(channel::Entity)
            .filter(
                channel_member::Column::ChannelId
                    .is_in(channel_ids)
                    .and(channel_member::Column::UserId.eq(user_id))
                    .and(channel_member::Column::Accepted.eq(true)),
            )
            .select_only()
            .column(channel_member::Column::ChannelId)
            .column(channel_member::Column::Role)
            .column(channel::Column::Visibility)
            .into_values::<_, QueryChannelMembership>()
            .stream(&*tx)
            .await?;

        let mut user_role: Option<ChannelRole> = None;

        // Set when a Guest membership exists on some public channel in the lineage.
        let mut is_participant = false;
        // Visibility of `channel_id` itself, if a membership row happened to carry it.
        let mut current_channel_visibility = None;

        // note these channels are not iterated in any particular order,
        // our current logic takes the highest permission available.
        // NOTE(review): Banned also goes through `max`, so whether a ban wins over
        // an Admin/Member grant depends on ChannelRole's Ord - confirm.
        while let Some(row) = rows.next().await {
            let (membership_channel, role, visibility): (
                ChannelId,
                ChannelRole,
                ChannelVisibility,
            ) = row?;

            match role {
                ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => {
                    if let Some(users_role) = user_role {
                        user_role = Some(users_role.max(role));
                    } else {
                        user_role = Some(role)
                    }
                }
                ChannelRole::Guest if visibility == ChannelVisibility::Public => {
                    is_participant = true
                }
                ChannelRole::Guest => {}
            }
            if channel_id == membership_channel {
                current_channel_visibility = Some(visibility);
            }
        }
        // free up database connection
        drop(rows);

        // Guest access additionally requires the target channel itself to be public.
        if is_participant && user_role.is_none() {
            if current_channel_visibility.is_none() {
                current_channel_visibility = channel::Entity::find()
                    .filter(channel::Column::Id.eq(channel_id))
                    .one(&*tx)
                    .await?
                    .map(|channel| channel.visibility);
            }
            if current_channel_visibility == Some(ChannelVisibility::Public) {
                user_role = Some(ChannelRole::Guest);
            }
        }

        Ok(user_role)
    }
976
977 /// Returns the channel ancestors in arbitrary order
978 pub async fn get_channel_ancestors(
979 &self,
980 channel_id: ChannelId,
981 tx: &DatabaseTransaction,
982 ) -> Result<Vec<ChannelId>> {
983 let paths = channel_path::Entity::find()
984 .filter(channel_path::Column::ChannelId.eq(channel_id))
985 .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
986 .all(tx)
987 .await?;
988 let mut channel_ids = Vec::new();
989 for path in paths {
990 for id in path.id_path.trim_matches('/').split('/') {
991 if let Ok(id) = id.parse() {
992 let id = ChannelId::from_proto(id);
993 if let Err(ix) = channel_ids.binary_search(&id) {
994 channel_ids.insert(ix, id);
995 }
996 }
997 }
998 }
999 Ok(channel_ids)
1000 }
1001
1002 // Returns the channel desendants as a sorted list of edges for further processing.
1003 // The edges are sorted such that you will see unknown channel ids as children
1004 // before you see them as parents.
1005 async fn get_channel_descendants(
1006 &self,
1007 channel_ids: impl IntoIterator<Item = ChannelId>,
1008 tx: &DatabaseTransaction,
1009 ) -> Result<Vec<ChannelEdge>> {
1010 let mut values = String::new();
1011 for id in channel_ids {
1012 if !values.is_empty() {
1013 values.push_str(", ");
1014 }
1015 write!(&mut values, "({})", id).unwrap();
1016 }
1017
1018 if values.is_empty() {
1019 return Ok(vec![]);
1020 }
1021
1022 let sql = format!(
1023 r#"
1024 SELECT
1025 descendant_paths.*
1026 FROM
1027 channel_paths parent_paths, channel_paths descendant_paths
1028 WHERE
1029 parent_paths.channel_id IN ({values}) AND
1030 descendant_paths.id_path != parent_paths.id_path AND
1031 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
1032 ORDER BY
1033 descendant_paths.id_path
1034 "#
1035 );
1036
1037 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
1038
1039 let mut paths = channel_path::Entity::find()
1040 .from_raw_sql(stmt)
1041 .stream(tx)
1042 .await?;
1043
1044 let mut results: Vec<ChannelEdge> = Vec::new();
1045 while let Some(path) = paths.next().await {
1046 let path = path?;
1047 let ids: Vec<&str> = path.id_path.trim_matches('/').split('/').collect();
1048
1049 debug_assert!(ids.len() >= 2);
1050 debug_assert!(ids[ids.len() - 1] == path.channel_id.to_string());
1051
1052 results.push(ChannelEdge {
1053 parent_id: ids[ids.len() - 2].parse().unwrap(),
1054 channel_id: ids[ids.len() - 1].parse().unwrap(),
1055 })
1056 }
1057
1058 Ok(results)
1059 }
1060
1061 /// Returns the channel with the given ID
1062 pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
1063 self.transaction(|tx| async move {
1064 self.check_user_is_channel_participant(channel_id, user_id, &*tx)
1065 .await?;
1066
1067 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
1068 let Some(channel) = channel else {
1069 Err(anyhow!("no such channel"))?
1070 };
1071
1072 Ok(Channel {
1073 id: channel.id,
1074 visibility: channel.visibility,
1075 name: channel.name,
1076 })
1077 })
1078 .await
1079 }
1080
1081 pub(crate) async fn get_or_create_channel_room(
1082 &self,
1083 channel_id: ChannelId,
1084 live_kit_room: &str,
1085 environment: &str,
1086 tx: &DatabaseTransaction,
1087 ) -> Result<RoomId> {
1088 let room = room::Entity::find()
1089 .filter(room::Column::ChannelId.eq(channel_id))
1090 .one(&*tx)
1091 .await?;
1092
1093 let room_id = if let Some(room) = room {
1094 if let Some(env) = room.enviroment {
1095 if &env != environment {
1096 Err(anyhow!("must join using the {} release", env))?;
1097 }
1098 }
1099 room.id
1100 } else {
1101 let result = room::Entity::insert(room::ActiveModel {
1102 channel_id: ActiveValue::Set(Some(channel_id)),
1103 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
1104 enviroment: ActiveValue::Set(Some(environment.to_string())),
1105 ..Default::default()
1106 })
1107 .exec(&*tx)
1108 .await?;
1109
1110 result.last_insert_id
1111 };
1112
1113 Ok(room_id)
1114 }
1115
1116 // Insert an edge from the given channel to the given other channel.
1117 pub async fn link_channel(
1118 &self,
1119 user: UserId,
1120 channel: ChannelId,
1121 to: ChannelId,
1122 ) -> Result<ChannelGraph> {
1123 self.transaction(|tx| async move {
1124 // Note that even with these maxed permissions, this linking operation
1125 // is still insecure because you can't remove someone's permissions to a
1126 // channel if they've linked the channel to one where they're an admin.
1127 self.check_user_is_channel_admin(channel, user, &*tx)
1128 .await?;
1129
1130 self.link_channel_internal(user, channel, to, &*tx).await
1131 })
1132 .await
1133 }
1134
    /// Makes `new_parent` an additional parent of `channel` inside the caller's
    /// transaction.
    ///
    /// The work is done in three steps:
    /// - For every path that reaches `new_parent`, the path is copied beneath it
    ///   for `channel` and for each channel in `channel`'s subtree.
    /// - Root paths starting at `channel` are deleted, since `channel` is no longer
    ///   a root.
    /// - The channel graph for `user` is computed from `user`'s membership rows
    ///   directly on `channel`, and the new `channel -> new_parent` edge is pushed
    ///   onto it.
    ///
    /// `user` must be an admin of `new_parent`. Admin rights on `channel` are not
    /// checked here; `link_channel` checks them before calling this.
    ///
    /// # Errors
    /// Fails with "cycle" if `new_parent` is `channel` itself or lies below it.
    pub async fn link_channel_internal(
        &self,
        user: UserId,
        channel: ChannelId,
        new_parent: ChannelId,
        tx: &DatabaseTransaction,
    ) -> Result<ChannelGraph> {
        self.check_user_is_channel_admin(new_parent, user, &*tx)
            .await?;

        // Every path that passes through `channel`: its own and its descendants'.
        let paths = channel_path::Entity::find()
            .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel)))
            .all(tx)
            .await?;

        // The portion of each such path starting at `channel`,
        // e.g. "/1/5/7/" with channel 5 -> (7, "5/7/").
        let mut new_path_suffixes = HashSet::default();
        for path in paths {
            if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) {
                new_path_suffixes.insert((
                    path.channel_id,
                    path.id_path[(start_offset + 1)..].to_string(),
                ));
            }
        }

        let paths_to_new_parent = channel_path::Entity::find()
            .filter(channel_path::Column::ChannelId.eq(new_parent))
            .all(tx)
            .await?;

        // Graft each suffix onto each path of the new parent.
        let mut new_paths = Vec::new();
        for path in paths_to_new_parent {
            // The new parent already sits at or below `channel`: linking would loop.
            if path.id_path.contains(&format!("/{}/", channel)) {
                Err(anyhow!("cycle"))?;
            }

            new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| {
                channel_path::ActiveModel {
                    channel_id: ActiveValue::Set(*channel_id),
                    id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)),
                }
            }));
        }

        channel_path::Entity::insert_many(new_paths)
            .exec(&*tx)
            .await?;

        // remove any root edges for the channel we just linked
        {
            channel_path::Entity::delete_many()
                .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel)))
                .exec(&*tx)
                .await?;
        }

        let membership = channel_member::Entity::find()
            .filter(
                channel_member::Column::ChannelId
                    .eq(channel)
                    .and(channel_member::Column::UserId.eq(user)),
            )
            .all(tx)
            .await?;

        let mut channel_info = self.get_user_channels(user, membership, &*tx).await?;

        channel_info.channels.edges.push(ChannelEdge {
            channel_id: channel.to_proto(),
            parent_id: new_parent.to_proto(),
        });

        Ok(channel_info.channels)
    }
1209
1210 /// Unlink a channel from a given parent. This will add in a root edge if
1211 /// the channel has no other parents after this operation.
1212 pub async fn unlink_channel(
1213 &self,
1214 user: UserId,
1215 channel: ChannelId,
1216 from: ChannelId,
1217 ) -> Result<()> {
1218 self.transaction(|tx| async move {
1219 // Note that even with these maxed permissions, this linking operation
1220 // is still insecure because you can't remove someone's permissions to a
1221 // channel if they've linked the channel to one where they're an admin.
1222 self.check_user_is_channel_admin(channel, user, &*tx)
1223 .await?;
1224
1225 self.unlink_channel_internal(user, channel, from, &*tx)
1226 .await?;
1227
1228 Ok(())
1229 })
1230 .await
1231 }
1232
1233 pub async fn unlink_channel_internal(
1234 &self,
1235 user: UserId,
1236 channel: ChannelId,
1237 from: ChannelId,
1238 tx: &DatabaseTransaction,
1239 ) -> Result<()> {
1240 self.check_user_is_channel_admin(from, user, &*tx).await?;
1241
1242 let sql = r#"
1243 DELETE FROM channel_paths
1244 WHERE
1245 id_path LIKE '%/' || $1 || '/' || $2 || '/%'
1246 RETURNING id_path, channel_id
1247 "#;
1248
1249 let paths = channel_path::Entity::find()
1250 .from_raw_sql(Statement::from_sql_and_values(
1251 self.pool.get_database_backend(),
1252 sql,
1253 [from.to_proto().into(), channel.to_proto().into()],
1254 ))
1255 .all(&*tx)
1256 .await?;
1257
1258 let is_stranded = channel_path::Entity::find()
1259 .filter(channel_path::Column::ChannelId.eq(channel))
1260 .count(&*tx)
1261 .await?
1262 == 0;
1263
1264 // Make sure that there is always at least one path to the channel
1265 if is_stranded {
1266 let root_paths: Vec<_> = paths
1267 .iter()
1268 .map(|path| {
1269 let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap();
1270 channel_path::ActiveModel {
1271 channel_id: ActiveValue::Set(path.channel_id),
1272 id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()),
1273 }
1274 })
1275 .collect();
1276 channel_path::Entity::insert_many(root_paths)
1277 .exec(&*tx)
1278 .await?;
1279 }
1280
1281 Ok(())
1282 }
1283
1284 /// Move a channel from one parent to another, returns the
1285 /// Channels that were moved for notifying clients
1286 pub async fn move_channel(
1287 &self,
1288 user: UserId,
1289 channel: ChannelId,
1290 from: ChannelId,
1291 to: ChannelId,
1292 ) -> Result<ChannelGraph> {
1293 if from == to {
1294 return Ok(ChannelGraph {
1295 channels: vec![],
1296 edges: vec![],
1297 });
1298 }
1299
1300 self.transaction(|tx| async move {
1301 self.check_user_is_channel_admin(channel, user, &*tx)
1302 .await?;
1303
1304 let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
1305
1306 self.unlink_channel_internal(user, channel, from, &*tx)
1307 .await?;
1308
1309 Ok(moved_channels)
1310 })
1311 .await
1312 }
1313}
1314
/// Column selector enum with a single `UserId` variant.
///
/// NOTE(review): presumably used with SeaORM's `select_only()` /
/// `into_values::<_, QueryUserIds>()` elsewhere in this file to read just a
/// `user_id` column. Confirm against call sites. The variant name determines
/// the generated column name, so it must not be renamed casually.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
1319
/// A collection of channels together with parent edges.
#[derive(Debug)]
pub struct ChannelGraph {
    /// The channels in this graph.
    pub channels: Vec<Channel>,
    /// Parent edges; each `ChannelEdge` pairs a `channel_id` with a `parent_id`.
    pub edges: Vec<ChannelEdge>,
}
1325
1326impl ChannelGraph {
1327 pub fn is_empty(&self) -> bool {
1328 self.channels.is_empty() && self.edges.is_empty()
1329 }
1330}
1331
#[cfg(test)]
impl PartialEq for ChannelGraph {
    /// Test-only equality that ignores ordering.
    ///
    /// Two graphs are equal when they hold the same set of channels and the
    /// same set of `(channel_id, parent_id)` edge pairs. Vector order and
    /// duplicate entries do not matter.
    fn eq(&self, other: &Self) -> bool {
        let edge_pairs = |graph: &ChannelGraph| {
            graph
                .edges
                .iter()
                .map(|edge| (edge.channel_id, edge.parent_id))
                .collect::<HashSet<_>>()
        };

        let channels: HashSet<_> = self.channels.iter().collect();
        let other_channels: HashSet<_> = other.channels.iter().collect();

        channels == other_channels && edge_pairs(self) == edge_pairs(other)
    }
}
1352
1353#[cfg(not(test))]
1354impl PartialEq for ChannelGraph {
1355 fn eq(&self, other: &Self) -> bool {
1356 self.channels == other.channels && self.edges == other.edges
1357 }
1358}