1use super::*;
2
3impl Database {
4 pub async fn create_root_channel(
5 &self,
6 name: &str,
7 live_kit_room: &str,
8 creator_id: UserId,
9 ) -> Result<ChannelId> {
10 self.create_channel(name, None, live_kit_room, creator_id)
11 .await
12 }
13
14 pub async fn create_channel(
15 &self,
16 name: &str,
17 parent: Option<ChannelId>,
18 live_kit_room: &str,
19 creator_id: UserId,
20 ) -> Result<ChannelId> {
21 let name = Self::sanitize_channel_name(name)?;
22 self.transaction(move |tx| async move {
23 if let Some(parent) = parent {
24 self.check_user_is_channel_admin(parent, creator_id, &*tx)
25 .await?;
26 }
27
28 let channel = channel::ActiveModel {
29 name: ActiveValue::Set(name.to_string()),
30 ..Default::default()
31 }
32 .insert(&*tx)
33 .await?;
34
35 let channel_paths_stmt;
36 if let Some(parent) = parent {
37 let sql = r#"
38 INSERT INTO channel_paths
39 (id_path, channel_id)
40 SELECT
41 id_path || $1 || '/', $2
42 FROM
43 channel_paths
44 WHERE
45 channel_id = $3
46 "#;
47 channel_paths_stmt = Statement::from_sql_and_values(
48 self.pool.get_database_backend(),
49 sql,
50 [
51 channel.id.to_proto().into(),
52 channel.id.to_proto().into(),
53 parent.to_proto().into(),
54 ],
55 );
56 tx.execute(channel_paths_stmt).await?;
57 } else {
58 channel_path::Entity::insert(channel_path::ActiveModel {
59 channel_id: ActiveValue::Set(channel.id),
60 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
61 })
62 .exec(&*tx)
63 .await?;
64 }
65
66 channel_member::ActiveModel {
67 channel_id: ActiveValue::Set(channel.id),
68 user_id: ActiveValue::Set(creator_id),
69 accepted: ActiveValue::Set(true),
70 admin: ActiveValue::Set(true),
71 ..Default::default()
72 }
73 .insert(&*tx)
74 .await?;
75
76 room::ActiveModel {
77 channel_id: ActiveValue::Set(Some(channel.id)),
78 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
79 ..Default::default()
80 }
81 .insert(&*tx)
82 .await?;
83
84 Ok(channel.id)
85 })
86 .await
87 }
88
89 pub async fn remove_channel(
90 &self,
91 channel_id: ChannelId,
92 user_id: UserId,
93 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
94 self.transaction(move |tx| async move {
95 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
96 .await?;
97
98 // Don't remove descendant channels that have additional parents.
99 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
100 {
101 let mut channels_to_keep = channel_path::Entity::find()
102 .filter(
103 channel_path::Column::ChannelId
104 .is_in(
105 channels_to_remove
106 .keys()
107 .copied()
108 .filter(|&id| id != channel_id),
109 )
110 .and(
111 channel_path::Column::IdPath
112 .not_like(&format!("%/{}/%", channel_id)),
113 ),
114 )
115 .stream(&*tx)
116 .await?;
117 while let Some(row) = channels_to_keep.next().await {
118 let row = row?;
119 channels_to_remove.remove(&row.channel_id);
120 }
121 }
122
123 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
124 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
125 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
126 .select_only()
127 .column(channel_member::Column::UserId)
128 .distinct()
129 .into_values::<_, QueryUserIds>()
130 .all(&*tx)
131 .await?;
132
133 channel::Entity::delete_many()
134 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
135 .exec(&*tx)
136 .await?;
137
138 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
139 })
140 .await
141 }
142
143 pub async fn invite_channel_member(
144 &self,
145 channel_id: ChannelId,
146 invitee_id: UserId,
147 inviter_id: UserId,
148 is_admin: bool,
149 ) -> Result<()> {
150 self.transaction(move |tx| async move {
151 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
152 .await?;
153
154 channel_member::ActiveModel {
155 channel_id: ActiveValue::Set(channel_id),
156 user_id: ActiveValue::Set(invitee_id),
157 accepted: ActiveValue::Set(false),
158 admin: ActiveValue::Set(is_admin),
159 ..Default::default()
160 }
161 .insert(&*tx)
162 .await?;
163
164 Ok(())
165 })
166 .await
167 }
168
169 fn sanitize_channel_name(name: &str) -> Result<&str> {
170 let new_name = name.trim().trim_start_matches('#');
171 if new_name == "" {
172 Err(anyhow!("channel name can't be blank"))?;
173 }
174 Ok(new_name)
175 }
176
177 pub async fn rename_channel(
178 &self,
179 channel_id: ChannelId,
180 user_id: UserId,
181 new_name: &str,
182 ) -> Result<String> {
183 self.transaction(move |tx| async move {
184 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
185
186 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
187 .await?;
188
189 channel::ActiveModel {
190 id: ActiveValue::Unchanged(channel_id),
191 name: ActiveValue::Set(new_name.clone()),
192 ..Default::default()
193 }
194 .update(&*tx)
195 .await?;
196
197 Ok(new_name)
198 })
199 .await
200 }
201
202 pub async fn respond_to_channel_invite(
203 &self,
204 channel_id: ChannelId,
205 user_id: UserId,
206 accept: bool,
207 ) -> Result<()> {
208 self.transaction(move |tx| async move {
209 let rows_affected = if accept {
210 channel_member::Entity::update_many()
211 .set(channel_member::ActiveModel {
212 accepted: ActiveValue::Set(accept),
213 ..Default::default()
214 })
215 .filter(
216 channel_member::Column::ChannelId
217 .eq(channel_id)
218 .and(channel_member::Column::UserId.eq(user_id))
219 .and(channel_member::Column::Accepted.eq(false)),
220 )
221 .exec(&*tx)
222 .await?
223 .rows_affected
224 } else {
225 channel_member::ActiveModel {
226 channel_id: ActiveValue::Unchanged(channel_id),
227 user_id: ActiveValue::Unchanged(user_id),
228 ..Default::default()
229 }
230 .delete(&*tx)
231 .await?
232 .rows_affected
233 };
234
235 if rows_affected == 0 {
236 Err(anyhow!("no such invitation"))?;
237 }
238
239 Ok(())
240 })
241 .await
242 }
243
244 pub async fn remove_channel_member(
245 &self,
246 channel_id: ChannelId,
247 member_id: UserId,
248 remover_id: UserId,
249 ) -> Result<()> {
250 self.transaction(|tx| async move {
251 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
252 .await?;
253
254 let result = channel_member::Entity::delete_many()
255 .filter(
256 channel_member::Column::ChannelId
257 .eq(channel_id)
258 .and(channel_member::Column::UserId.eq(member_id)),
259 )
260 .exec(&*tx)
261 .await?;
262
263 if result.rows_affected == 0 {
264 Err(anyhow!("no such member"))?;
265 }
266
267 Ok(())
268 })
269 .await
270 }
271
272 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
273 self.transaction(|tx| async move {
274 let channel_invites = channel_member::Entity::find()
275 .filter(
276 channel_member::Column::UserId
277 .eq(user_id)
278 .and(channel_member::Column::Accepted.eq(false)),
279 )
280 .all(&*tx)
281 .await?;
282
283 let channels = channel::Entity::find()
284 .filter(
285 channel::Column::Id.is_in(
286 channel_invites
287 .into_iter()
288 .map(|channel_member| channel_member.channel_id),
289 ),
290 )
291 .all(&*tx)
292 .await?;
293
294 let channels = channels
295 .into_iter()
296 .map(|channel| Channel {
297 id: channel.id,
298 name: channel.name,
299 parent_id: None,
300 })
301 .collect();
302
303 Ok(channels)
304 })
305 .await
306 }
307
308 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
309 self.transaction(|tx| async move {
310 let tx = tx;
311
312 let channel_memberships = channel_member::Entity::find()
313 .filter(
314 channel_member::Column::UserId
315 .eq(user_id)
316 .and(channel_member::Column::Accepted.eq(true)),
317 )
318 .all(&*tx)
319 .await?;
320
321 let parents_by_child_id = self
322 .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
323 .await?;
324
325 let channels_with_admin_privileges = channel_memberships
326 .iter()
327 .filter_map(|membership| membership.admin.then_some(membership.channel_id))
328 .collect();
329
330 let mut channels = Vec::with_capacity(parents_by_child_id.len());
331 {
332 let mut rows = channel::Entity::find()
333 .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
334 .stream(&*tx)
335 .await?;
336 while let Some(row) = rows.next().await {
337 let row = row?;
338 channels.push(Channel {
339 id: row.id,
340 name: row.name,
341 parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
342 });
343 }
344 }
345
346 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
347 enum QueryUserIdsAndChannelIds {
348 ChannelId,
349 UserId,
350 }
351
352 let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
353 {
354 let mut rows = room_participant::Entity::find()
355 .inner_join(room::Entity)
356 .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
357 .select_only()
358 .column(room::Column::ChannelId)
359 .column(room_participant::Column::UserId)
360 .into_values::<_, QueryUserIdsAndChannelIds>()
361 .stream(&*tx)
362 .await?;
363 while let Some(row) = rows.next().await {
364 let row: (ChannelId, UserId) = row?;
365 channel_participants.entry(row.0).or_default().push(row.1)
366 }
367 }
368
369 Ok(ChannelsForUser {
370 channels,
371 channel_participants,
372 channels_with_admin_privileges,
373 })
374 })
375 .await
376 }
377
378 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
379 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
380 .await
381 }
382
383 pub async fn set_channel_member_admin(
384 &self,
385 channel_id: ChannelId,
386 from: UserId,
387 for_user: UserId,
388 admin: bool,
389 ) -> Result<()> {
390 self.transaction(|tx| async move {
391 self.check_user_is_channel_admin(channel_id, from, &*tx)
392 .await?;
393
394 let result = channel_member::Entity::update_many()
395 .filter(
396 channel_member::Column::ChannelId
397 .eq(channel_id)
398 .and(channel_member::Column::UserId.eq(for_user)),
399 )
400 .set(channel_member::ActiveModel {
401 admin: ActiveValue::set(admin),
402 ..Default::default()
403 })
404 .exec(&*tx)
405 .await?;
406
407 if result.rows_affected == 0 {
408 Err(anyhow!("no such member"))?;
409 }
410
411 Ok(())
412 })
413 .await
414 }
415
416 pub async fn get_channel_member_details(
417 &self,
418 channel_id: ChannelId,
419 user_id: UserId,
420 ) -> Result<Vec<proto::ChannelMember>> {
421 self.transaction(|tx| async move {
422 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
423 .await?;
424
425 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
426 enum QueryMemberDetails {
427 UserId,
428 Admin,
429 IsDirectMember,
430 Accepted,
431 }
432
433 let tx = tx;
434 let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
435 let mut stream = channel_member::Entity::find()
436 .distinct()
437 .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
438 .select_only()
439 .column(channel_member::Column::UserId)
440 .column(channel_member::Column::Admin)
441 .column_as(
442 channel_member::Column::ChannelId.eq(channel_id),
443 QueryMemberDetails::IsDirectMember,
444 )
445 .column(channel_member::Column::Accepted)
446 .order_by_asc(channel_member::Column::UserId)
447 .into_values::<_, QueryMemberDetails>()
448 .stream(&*tx)
449 .await?;
450
451 let mut rows = Vec::<proto::ChannelMember>::new();
452 while let Some(row) = stream.next().await {
453 let (user_id, is_admin, is_direct_member, is_invite_accepted): (
454 UserId,
455 bool,
456 bool,
457 bool,
458 ) = row?;
459 let kind = match (is_direct_member, is_invite_accepted) {
460 (true, true) => proto::channel_member::Kind::Member,
461 (true, false) => proto::channel_member::Kind::Invitee,
462 (false, true) => proto::channel_member::Kind::AncestorMember,
463 (false, false) => continue,
464 };
465 let user_id = user_id.to_proto();
466 let kind = kind.into();
467 if let Some(last_row) = rows.last_mut() {
468 if last_row.user_id == user_id {
469 if is_direct_member {
470 last_row.kind = kind;
471 last_row.admin = is_admin;
472 }
473 continue;
474 }
475 }
476 rows.push(proto::ChannelMember {
477 user_id,
478 kind,
479 admin: is_admin,
480 });
481 }
482
483 Ok(rows)
484 })
485 .await
486 }
487
488 pub async fn get_channel_members_internal(
489 &self,
490 id: ChannelId,
491 tx: &DatabaseTransaction,
492 ) -> Result<Vec<UserId>> {
493 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
494 let user_ids = channel_member::Entity::find()
495 .distinct()
496 .filter(
497 channel_member::Column::ChannelId
498 .is_in(ancestor_ids.iter().copied())
499 .and(channel_member::Column::Accepted.eq(true)),
500 )
501 .select_only()
502 .column(channel_member::Column::UserId)
503 .into_values::<_, QueryUserIds>()
504 .all(&*tx)
505 .await?;
506 Ok(user_ids)
507 }
508
509 pub async fn check_user_is_channel_member(
510 &self,
511 channel_id: ChannelId,
512 user_id: UserId,
513 tx: &DatabaseTransaction,
514 ) -> Result<()> {
515 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
516 channel_member::Entity::find()
517 .filter(
518 channel_member::Column::ChannelId
519 .is_in(channel_ids)
520 .and(channel_member::Column::UserId.eq(user_id)),
521 )
522 .one(&*tx)
523 .await?
524 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
525 Ok(())
526 }
527
528 pub async fn check_user_is_channel_admin(
529 &self,
530 channel_id: ChannelId,
531 user_id: UserId,
532 tx: &DatabaseTransaction,
533 ) -> Result<()> {
534 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
535 channel_member::Entity::find()
536 .filter(
537 channel_member::Column::ChannelId
538 .is_in(channel_ids)
539 .and(channel_member::Column::UserId.eq(user_id))
540 .and(channel_member::Column::Admin.eq(true)),
541 )
542 .one(&*tx)
543 .await?
544 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
545 Ok(())
546 }
547
548 pub async fn get_channel_ancestors(
549 &self,
550 channel_id: ChannelId,
551 tx: &DatabaseTransaction,
552 ) -> Result<Vec<ChannelId>> {
553 let paths = channel_path::Entity::find()
554 .filter(channel_path::Column::ChannelId.eq(channel_id))
555 .all(tx)
556 .await?;
557 let mut channel_ids = Vec::new();
558 for path in paths {
559 for id in path.id_path.trim_matches('/').split('/') {
560 if let Ok(id) = id.parse() {
561 let id = ChannelId::from_proto(id);
562 if let Err(ix) = channel_ids.binary_search(&id) {
563 channel_ids.insert(ix, id);
564 }
565 }
566 }
567 }
568 Ok(channel_ids)
569 }
570
571 async fn get_channel_descendants(
572 &self,
573 channel_ids: impl IntoIterator<Item = ChannelId>,
574 tx: &DatabaseTransaction,
575 ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
576 let mut values = String::new();
577 for id in channel_ids {
578 if !values.is_empty() {
579 values.push_str(", ");
580 }
581 write!(&mut values, "({})", id).unwrap();
582 }
583
584 if values.is_empty() {
585 return Ok(HashMap::default());
586 }
587
588 let sql = format!(
589 r#"
590 SELECT
591 descendant_paths.*
592 FROM
593 channel_paths parent_paths, channel_paths descendant_paths
594 WHERE
595 parent_paths.channel_id IN ({values}) AND
596 descendant_paths.id_path LIKE (parent_paths.id_path || '%')
597 "#
598 );
599
600 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
601
602 let mut parents_by_child_id = HashMap::default();
603 let mut paths = channel_path::Entity::find()
604 .from_raw_sql(stmt)
605 .stream(tx)
606 .await?;
607
608 while let Some(path) = paths.next().await {
609 let path = path?;
610 let ids = path.id_path.trim_matches('/').split('/');
611 let mut parent_id = None;
612 for id in ids {
613 if let Ok(id) = id.parse() {
614 let id = ChannelId::from_proto(id);
615 if id == path.channel_id {
616 break;
617 }
618 parent_id = Some(id);
619 }
620 }
621 parents_by_child_id.insert(path.channel_id, parent_id);
622 }
623
624 Ok(parents_by_child_id)
625 }
626
627 /// Returns the channel with the given ID and:
628 /// - true if the user is a member
629 /// - false if the user hasn't accepted the invitation yet
630 pub async fn get_channel(
631 &self,
632 channel_id: ChannelId,
633 user_id: UserId,
634 ) -> Result<Option<(Channel, bool)>> {
635 self.transaction(|tx| async move {
636 let tx = tx;
637
638 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
639
640 if let Some(channel) = channel {
641 if self
642 .check_user_is_channel_member(channel_id, user_id, &*tx)
643 .await
644 .is_err()
645 {
646 return Ok(None);
647 }
648
649 let channel_membership = channel_member::Entity::find()
650 .filter(
651 channel_member::Column::ChannelId
652 .eq(channel_id)
653 .and(channel_member::Column::UserId.eq(user_id)),
654 )
655 .one(&*tx)
656 .await?;
657
658 let is_accepted = channel_membership
659 .map(|membership| membership.accepted)
660 .unwrap_or(false);
661
662 Ok(Some((
663 Channel {
664 id: channel.id,
665 name: channel.name,
666 parent_id: None,
667 },
668 is_accepted,
669 )))
670 } else {
671 Ok(None)
672 }
673 })
674 .await
675 }
676
677 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
678 self.transaction(|tx| async move {
679 let tx = tx;
680 let room = channel::Model {
681 id: channel_id,
682 ..Default::default()
683 }
684 .find_related(room::Entity)
685 .one(&*tx)
686 .await?
687 .ok_or_else(|| anyhow!("invalid channel"))?;
688 Ok(room.id)
689 })
690 .await
691 }
692}
693
/// Single-column selector for queries that project only
/// `channel_member::Column::UserId`.
///
/// Passed to `into_values::<_, QueryUserIds>()` so those queries yield plain
/// `UserId` values instead of full models.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}