1use super::*;
2
3impl Database {
4 #[cfg(test)]
5 pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
6 self.transaction(move |tx| async move {
7 let mut channels = Vec::new();
8 let mut rows = channel::Entity::find().stream(&*tx).await?;
9 while let Some(row) = rows.next().await {
10 let row = row?;
11 channels.push((row.id, row.name));
12 }
13 Ok(channels)
14 })
15 .await
16 }
17
18 pub async fn create_root_channel(
19 &self,
20 name: &str,
21 live_kit_room: &str,
22 creator_id: UserId,
23 ) -> Result<ChannelId> {
24 self.create_channel(name, None, live_kit_room, creator_id)
25 .await
26 }
27
28 pub async fn create_channel(
29 &self,
30 name: &str,
31 parent: Option<ChannelId>,
32 live_kit_room: &str,
33 creator_id: UserId,
34 ) -> Result<ChannelId> {
35 let name = Self::sanitize_channel_name(name)?;
36 self.transaction(move |tx| async move {
37 if let Some(parent) = parent {
38 self.check_user_is_channel_admin(parent, creator_id, &*tx)
39 .await?;
40 }
41
42 let channel = channel::ActiveModel {
43 name: ActiveValue::Set(name.to_string()),
44 ..Default::default()
45 }
46 .insert(&*tx)
47 .await?;
48
49 let channel_paths_stmt;
50 if let Some(parent) = parent {
51 let sql = r#"
52 INSERT INTO channel_paths
53 (id_path, channel_id)
54 SELECT
55 id_path || $1 || '/', $2
56 FROM
57 channel_paths
58 WHERE
59 channel_id = $3
60 "#;
61 channel_paths_stmt = Statement::from_sql_and_values(
62 self.pool.get_database_backend(),
63 sql,
64 [
65 channel.id.to_proto().into(),
66 channel.id.to_proto().into(),
67 parent.to_proto().into(),
68 ],
69 );
70 tx.execute(channel_paths_stmt).await?;
71 } else {
72 channel_path::Entity::insert(channel_path::ActiveModel {
73 channel_id: ActiveValue::Set(channel.id),
74 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
75 })
76 .exec(&*tx)
77 .await?;
78 }
79
80 channel_member::ActiveModel {
81 channel_id: ActiveValue::Set(channel.id),
82 user_id: ActiveValue::Set(creator_id),
83 accepted: ActiveValue::Set(true),
84 admin: ActiveValue::Set(true),
85 ..Default::default()
86 }
87 .insert(&*tx)
88 .await?;
89
90 room::ActiveModel {
91 channel_id: ActiveValue::Set(Some(channel.id)),
92 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
93 ..Default::default()
94 }
95 .insert(&*tx)
96 .await?;
97
98 Ok(channel.id)
99 })
100 .await
101 }
102
    /// Deletes `channel_id` together with those descendants that can only be reached
    /// through it. `user_id` must be an admin of the channel or of one of its ancestors.
    ///
    /// Returns `(removed_channel_ids, members_to_notify)`. `members_to_notify` is every
    /// user with a membership row (accepted or pending) on `channel_id` or on any of its
    /// ancestors, gathered before the delete runs.
    ///
    /// NOTE(review): users whose only membership is on a removed descendant are not part
    /// of `members_to_notify` — confirm callers don't need to notify them.
    pub async fn remove_channel(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
        self.transaction(move |tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Don't remove descendant channels that have additional parents.
            // `channels_to_remove` starts as every channel under `channel_id` (itself
            // included). Any descendant that still has a path not passing through
            // `channel_id` is then taken back out of the set.
            let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
            {
                // The stream is scoped so it is dropped before the next query on `tx`.
                let mut channels_to_keep = channel_path::Entity::find()
                    .filter(
                        channel_path::Column::ChannelId
                            .is_in(
                                channels_to_remove
                                    .keys()
                                    .copied()
                                    .filter(|&id| id != channel_id),
                            )
                            .and(
                                channel_path::Column::IdPath
                                    .not_like(&format!("%/{}/%", channel_id)),
                            ),
                    )
                    .stream(&*tx)
                    .await?;
                while let Some(row) = channels_to_keep.next().await {
                    let row = row?;
                    channels_to_remove.remove(&row.channel_id);
                }
            }

            // The ancestor list also contains `channel_id` itself, since every id on the
            // channel's own paths is collected.
            let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
                .select_only()
                .column(channel_member::Column::UserId)
                .distinct()
                .into_values::<_, QueryUserIds>()
                .all(&*tx)
                .await?;

            // NOTE(review): kept descendants may still have `channel_paths` rows whose
            // `id_path` mentions `channel_id`. Nothing here rewrites them — confirm the
            // schema (e.g. a cascade) handles that.
            channel::Entity::delete_many()
                .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
                .exec(&*tx)
                .await?;

            Ok((channels_to_remove.into_keys().collect(), members_to_notify))
        })
        .await
    }
156
157 pub async fn invite_channel_member(
158 &self,
159 channel_id: ChannelId,
160 invitee_id: UserId,
161 inviter_id: UserId,
162 is_admin: bool,
163 ) -> Result<()> {
164 self.transaction(move |tx| async move {
165 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
166 .await?;
167
168 channel_member::ActiveModel {
169 channel_id: ActiveValue::Set(channel_id),
170 user_id: ActiveValue::Set(invitee_id),
171 accepted: ActiveValue::Set(false),
172 admin: ActiveValue::Set(is_admin),
173 ..Default::default()
174 }
175 .insert(&*tx)
176 .await?;
177
178 Ok(())
179 })
180 .await
181 }
182
183 fn sanitize_channel_name(name: &str) -> Result<&str> {
184 let new_name = name.trim().trim_start_matches('#');
185 if new_name == "" {
186 Err(anyhow!("channel name can't be blank"))?;
187 }
188 Ok(new_name)
189 }
190
191 pub async fn rename_channel(
192 &self,
193 channel_id: ChannelId,
194 user_id: UserId,
195 new_name: &str,
196 ) -> Result<String> {
197 self.transaction(move |tx| async move {
198 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
199
200 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
201 .await?;
202
203 channel::ActiveModel {
204 id: ActiveValue::Unchanged(channel_id),
205 name: ActiveValue::Set(new_name.clone()),
206 ..Default::default()
207 }
208 .update(&*tx)
209 .await?;
210
211 Ok(new_name)
212 })
213 .await
214 }
215
216 pub async fn respond_to_channel_invite(
217 &self,
218 channel_id: ChannelId,
219 user_id: UserId,
220 accept: bool,
221 ) -> Result<()> {
222 self.transaction(move |tx| async move {
223 let rows_affected = if accept {
224 channel_member::Entity::update_many()
225 .set(channel_member::ActiveModel {
226 accepted: ActiveValue::Set(accept),
227 ..Default::default()
228 })
229 .filter(
230 channel_member::Column::ChannelId
231 .eq(channel_id)
232 .and(channel_member::Column::UserId.eq(user_id))
233 .and(channel_member::Column::Accepted.eq(false)),
234 )
235 .exec(&*tx)
236 .await?
237 .rows_affected
238 } else {
239 channel_member::ActiveModel {
240 channel_id: ActiveValue::Unchanged(channel_id),
241 user_id: ActiveValue::Unchanged(user_id),
242 ..Default::default()
243 }
244 .delete(&*tx)
245 .await?
246 .rows_affected
247 };
248
249 if rows_affected == 0 {
250 Err(anyhow!("no such invitation"))?;
251 }
252
253 Ok(())
254 })
255 .await
256 }
257
258 pub async fn remove_channel_member(
259 &self,
260 channel_id: ChannelId,
261 member_id: UserId,
262 remover_id: UserId,
263 ) -> Result<()> {
264 self.transaction(|tx| async move {
265 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
266 .await?;
267
268 let result = channel_member::Entity::delete_many()
269 .filter(
270 channel_member::Column::ChannelId
271 .eq(channel_id)
272 .and(channel_member::Column::UserId.eq(member_id)),
273 )
274 .exec(&*tx)
275 .await?;
276
277 if result.rows_affected == 0 {
278 Err(anyhow!("no such member"))?;
279 }
280
281 Ok(())
282 })
283 .await
284 }
285
286 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
287 self.transaction(|tx| async move {
288 let channel_invites = channel_member::Entity::find()
289 .filter(
290 channel_member::Column::UserId
291 .eq(user_id)
292 .and(channel_member::Column::Accepted.eq(false)),
293 )
294 .all(&*tx)
295 .await?;
296
297 let channels = channel::Entity::find()
298 .filter(
299 channel::Column::Id.is_in(
300 channel_invites
301 .into_iter()
302 .map(|channel_member| channel_member.channel_id),
303 ),
304 )
305 .all(&*tx)
306 .await?;
307
308 let channels = channels
309 .into_iter()
310 .map(|channel| Channel {
311 id: channel.id,
312 name: channel.name,
313 parent_id: None,
314 })
315 .collect();
316
317 Ok(channels)
318 })
319 .await
320 }
321
322 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
323 self.transaction(|tx| async move {
324 let tx = tx;
325
326 let channel_memberships = channel_member::Entity::find()
327 .filter(
328 channel_member::Column::UserId
329 .eq(user_id)
330 .and(channel_member::Column::Accepted.eq(true)),
331 )
332 .all(&*tx)
333 .await?;
334
335 let parents_by_child_id = self
336 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
337 .await?;
338
339 let channels_with_admin_privileges = channel_memberships
340 .iter()
341 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
342 .collect();
343
344 let mut channels = Vec::with_capacity(parents_by_child_id.len());
345 {
346 let mut rows = channel::Entity::find()
347 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
348 .stream(&*tx)
349 .await?;
350 while let Some(row) = rows.next().await {
351 let row = row?;
352 channels.push(Channel {
353 id: row.id,
354 name: row.name,
355 parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
356 });
357 }
358 }
359
360 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
361 enum QueryUserIdsAndChannelIds {
362 ChannelId,
363 UserId,
364 }
365
366 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
367 {
368 let mut rows = room_participant::Entity::find()
369 .inner_join(room::Entity)
370 .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
371 .select_only()
372 .column(room::Column::ChannelId)
373 .column(room_participant::Column::UserId)
374 .into_values::<_, QueryUserIdsAndChannelIds>()
375 .stream(&*tx)
376 .await?;
377 while let Some(row) = rows.next().await {
378 let row: (ChannelId, UserId) = row?;
379 channel_participants.entry(row.0).or_default().push(row.1)
380 }
381 }
382
383 Ok(ChannelsForUser {
384 channels,
385 channel_participants,
386 channels_with_admin_privileges,
387 })
388 })
389 .await
390 }
391
392 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
393 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
394 .await
395 }
396
397 pub async fn set_channel_member_admin(
398 &self,
399 channel_id: ChannelId,
400 from: UserId,
401 for_user: UserId,
402 admin: bool,
403 ) -> Result<()> {
404 self.transaction(|tx| async move {
405 self.check_user_is_channel_admin(channel_id, from, &*tx)
406 .await?;
407
408 let result = channel_member::Entity::update_many()
409 .filter(
410 channel_member::Column::ChannelId
411 .eq(channel_id)
412 .and(channel_member::Column::UserId.eq(for_user)),
413 )
414 .set(channel_member::ActiveModel {
415 admin: ActiveValue::set(admin),
416 ..Default::default()
417 })
418 .exec(&*tx)
419 .await?;
420
421 if result.rows_affected == 0 {
422 Err(anyhow!("no such member"))?;
423 }
424
425 Ok(())
426 })
427 .await
428 }
429
    /// Lists everyone who has a membership row on `channel_id` or on one of its ancestors.
    /// There is one entry per user, ordered by user id. `user_id` must be an admin of the
    /// channel or of one of its ancestors.
    ///
    /// Each entry's kind is:
    /// - `Member`: an accepted row on the channel itself,
    /// - `Invitee`: a pending row on the channel itself,
    /// - `AncestorMember`: an accepted row on an ancestor only.
    ///
    /// Pending rows on ancestors are skipped. When a user has several rows, a row on the
    /// channel itself overrides the kind and admin flag of an ancestor row.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column aliases for the projected row:
            // (user_id, admin, is_direct_member, accepted).
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            // Includes `channel_id` itself, so direct rows are part of the result.
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                // Ordering by user id keeps each user's rows adjacent, which the merge
                // below relies on.
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Merge additional rows for the same user into the entry already pushed.
                // A direct row always wins. Among ancestor-only rows, the first one seen
                // is kept.
                // NOTE(review): rows within one user id are not further ordered, so the
                // admin flag of an ancestor-only member depends on which row arrives
                // first — confirm that's acceptable.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
501
502 pub async fn get_channel_members_internal(
503 &self,
504 id: ChannelId,
505 tx: &DatabaseTransaction,
506 ) -> Result<Vec<UserId>> {
507 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
508 let user_ids = channel_member::Entity::find()
509 .distinct()
510 .filter(
511 channel_member::Column::ChannelId
512 .is_in(ancestor_ids.iter().copied())
513 .and(channel_member::Column::Accepted.eq(true)),
514 )
515 .select_only()
516 .column(channel_member::Column::UserId)
517 .into_values::<_, QueryUserIds>()
518 .all(&*tx)
519 .await?;
520 Ok(user_ids)
521 }
522
523 pub async fn check_user_is_channel_member(
524 &self,
525 channel_id: ChannelId,
526 user_id: UserId,
527 tx: &DatabaseTransaction,
528 ) -> Result<()> {
529 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
530 channel_member::Entity::find()
531 .filter(
532 channel_member::Column::ChannelId
533 .is_in(channel_ids)
534 .and(channel_member::Column::UserId.eq(user_id)),
535 )
536 .one(&*tx)
537 .await?
538 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
539 Ok(())
540 }
541
542 pub async fn check_user_is_channel_admin(
543 &self,
544 channel_id: ChannelId,
545 user_id: UserId,
546 tx: &DatabaseTransaction,
547 ) -> Result<()> {
548 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
549 channel_member::Entity::find()
550 .filter(
551 channel_member::Column::ChannelId
552 .is_in(channel_ids)
553 .and(channel_member::Column::UserId.eq(user_id))
554 .and(channel_member::Column::Admin.eq(true)),
555 )
556 .one(&*tx)
557 .await?
558 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
559 Ok(())
560 }
561
562 pub async fn get_channel_ancestors(
563 &self,
564 channel_id: ChannelId,
565 tx: &DatabaseTransaction,
566 ) -> Result<Vec<ChannelId>> {
567 let paths = channel_path::Entity::find()
568 .filter(channel_path::Column::ChannelId.eq(channel_id))
569 .all(tx)
570 .await?;
571 let mut channel_ids = Vec::new();
572 for path in paths {
573 for id in path.id_path.trim_matches('/').split('/') {
574 if let Ok(id) = id.parse() {
575 let id = ChannelId::from_proto(id);
576 if let Err(ix) = channel_ids.binary_search(&id) {
577 channel_ids.insert(ix, id);
578 }
579 }
580 }
581 }
582 Ok(channel_ids)
583 }
584
    /// Returns every channel whose path lies under a path of one of `channel_ids` (the
    /// given channels included). Each is mapped to the channel just before it on that path,
    /// or `None` when it is the first element, i.e. a root.
    ///
    /// An empty `channel_ids` yields an empty map without querying.
    ///
    /// NOTE(review): a channel reachable through several paths produces several rows, and
    /// the parent recorded is from whichever row is streamed last.
    async fn get_channel_descendants(
        &self,
        channel_ids: impl IntoIterator<Item = ChannelId>,
        tx: &DatabaseTransaction,
    ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
        // Build the SQL `IN` list as "(id1), (id2), ...".
        // NOTE(review): ids are formatted into the SQL text rather than bound as
        // parameters. This is presumably safe because ChannelId renders as a number —
        // confirm.
        let mut values = String::new();
        for id in channel_ids {
            if !values.is_empty() {
                values.push_str(", ");
            }
            write!(&mut values, "({})", id).unwrap();
        }

        if values.is_empty() {
            return Ok(HashMap::default());
        }

        // A descendant path is any path that starts with one of the requested channels'
        // paths; the LIKE prefix also matches each requested path itself.
        let sql = format!(
            r#"
            SELECT
                descendant_paths.*
            FROM
                channel_paths parent_paths, channel_paths descendant_paths
            WHERE
                parent_paths.channel_id IN ({values}) AND
                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
        "#
        );

        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);

        let mut parents_by_child_id = HashMap::default();
        let mut paths = channel_path::Entity::find()
            .from_raw_sql(stmt)
            .stream(tx)
            .await?;

        while let Some(path) = paths.next().await {
            let path = path?;
            let ids = path.id_path.trim_matches('/').split('/');
            // Walk the path until the channel itself. The last id seen before it is the
            // parent.
            let mut parent_id = None;
            for id in ids {
                if let Ok(id) = id.parse() {
                    let id = ChannelId::from_proto(id);
                    if id == path.channel_id {
                        break;
                    }
                    parent_id = Some(id);
                }
            }
            parents_by_child_id.insert(path.channel_id, parent_id);
        }

        Ok(parents_by_child_id)
    }
640
641 /// Returns the channel with the given ID and:
642 /// - true if the user is a member
643 /// - false if the user hasn't accepted the invitation yet
644 pub async fn get_channel(
645 &self,
646 channel_id: ChannelId,
647 user_id: UserId,
648 ) -> Result<Option<(Channel, bool)>> {
649 self.transaction(|tx| async move {
650 let tx = tx;
651
652 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
653
654 if let Some(channel) = channel {
655 if self
656 .check_user_is_channel_member(channel_id, user_id, &*tx)
657 .await
658 .is_err()
659 {
660 return Ok(None);
661 }
662
663 let channel_membership = channel_member::Entity::find()
664 .filter(
665 channel_member::Column::ChannelId
666 .eq(channel_id)
667 .and(channel_member::Column::UserId.eq(user_id)),
668 )
669 .one(&*tx)
670 .await?;
671
672 let is_accepted = channel_membership
673 .map(|membership| membership.accepted)
674 .unwrap_or(false);
675
676 Ok(Some((
677 Channel {
678 id: channel.id,
679 name: channel.name,
680 parent_id: None,
681 },
682 is_accepted,
683 )))
684 } else {
685 Ok(None)
686 }
687 })
688 .await
689 }
690
691 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
692 self.transaction(|tx| async move {
693 let tx = tx;
694 let room = channel::Model {
695 id: channel_id,
696 ..Default::default()
697 }
698 .find_related(room::Entity)
699 .one(&*tx)
700 .await?
701 .ok_or_else(|| anyhow!("invalid channel"))?;
702 Ok(room.id)
703 })
704 .await
705 }
706}
707
708#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
709enum QueryUserIds {
710 UserId,
711}