use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use serde::Serialize;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
use sqlx::{types::Uuid, FromRow};
use std::convert::TryFrom;
use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
14 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
15 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
16 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
17 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
18 async fn destroy_user(&self, id: UserId) -> Result<()>;
19
20 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
21 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
22 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
23 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
24 async fn dismiss_contact_notification(
25 &self,
26 responder_id: UserId,
27 requester_id: UserId,
28 ) -> Result<()>;
29 async fn respond_to_contact_request(
30 &self,
31 responder_id: UserId,
32 requester_id: UserId,
33 accept: bool,
34 ) -> Result<()>;
35
36 async fn create_access_token_hash(
37 &self,
38 user_id: UserId,
39 access_token_hash: &str,
40 max_access_token_count: usize,
41 ) -> Result<()>;
42 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
43 #[cfg(any(test, feature = "seed-support"))]
44
45 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
46 #[cfg(any(test, feature = "seed-support"))]
47 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
48 #[cfg(any(test, feature = "seed-support"))]
49 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
50 #[cfg(any(test, feature = "seed-support"))]
51 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
52 #[cfg(any(test, feature = "seed-support"))]
53
54 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
55 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
56 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
57 -> Result<bool>;
58 #[cfg(any(test, feature = "seed-support"))]
59 async fn add_channel_member(
60 &self,
61 channel_id: ChannelId,
62 user_id: UserId,
63 is_admin: bool,
64 ) -> Result<()>;
65 async fn create_channel_message(
66 &self,
67 channel_id: ChannelId,
68 sender_id: UserId,
69 body: &str,
70 timestamp: OffsetDateTime,
71 nonce: u128,
72 ) -> Result<MessageId>;
73 async fn get_channel_messages(
74 &self,
75 channel_id: ChannelId,
76 count: usize,
77 before_id: Option<MessageId>,
78 ) -> Result<Vec<ChannelMessage>>;
79 #[cfg(test)]
80 async fn teardown(&self, url: &str);
81 #[cfg(test)]
82 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
83}
84
85pub struct PostgresDb {
86 pool: sqlx::PgPool,
87}
88
89impl PostgresDb {
90 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
91 let pool = DbOptions::new()
92 .max_connections(max_connections)
93 .connect(&url)
94 .await
95 .context("failed to connect to postgres database")?;
96 Ok(Self { pool })
97 }
98}
99
100#[async_trait]
101impl Db for PostgresDb {
102 // users
103
104 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
105 let query = "
106 INSERT INTO users (github_login, admin)
107 VALUES ($1, $2)
108 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
109 RETURNING id
110 ";
111 Ok(sqlx::query_scalar(query)
112 .bind(github_login)
113 .bind(admin)
114 .fetch_one(&self.pool)
115 .await
116 .map(UserId)?)
117 }
118
119 async fn get_all_users(&self) -> Result<Vec<User>> {
120 let query = "SELECT * FROM users ORDER BY github_login ASC";
121 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
122 }
123
124 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
125 let like_string = fuzzy_like_string(name_query);
126 let query = "
127 SELECT users.*
128 FROM users
129 WHERE github_login ILIKE $1
130 ORDER BY github_login <-> $2
131 LIMIT $3
132 ";
133 Ok(sqlx::query_as(query)
134 .bind(like_string)
135 .bind(name_query)
136 .bind(limit)
137 .fetch_all(&self.pool)
138 .await?)
139 }
140
141 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
142 let users = self.get_users_by_ids(vec![id]).await?;
143 Ok(users.into_iter().next())
144 }
145
146 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
147 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
148 let query = "
149 SELECT users.*
150 FROM users
151 WHERE users.id = ANY ($1)
152 ";
153 Ok(sqlx::query_as(query)
154 .bind(&ids)
155 .fetch_all(&self.pool)
156 .await?)
157 }
158
159 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
160 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
161 Ok(sqlx::query_as(query)
162 .bind(github_login)
163 .fetch_optional(&self.pool)
164 .await?)
165 }
166
167 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
168 let query = "UPDATE users SET admin = $1 WHERE id = $2";
169 Ok(sqlx::query(query)
170 .bind(is_admin)
171 .bind(id.0)
172 .execute(&self.pool)
173 .await
174 .map(drop)?)
175 }
176
177 async fn destroy_user(&self, id: UserId) -> Result<()> {
178 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
179 sqlx::query(query)
180 .bind(id.0)
181 .execute(&self.pool)
182 .await
183 .map(drop)?;
184 let query = "DELETE FROM users WHERE id = $1;";
185 Ok(sqlx::query(query)
186 .bind(id.0)
187 .execute(&self.pool)
188 .await
189 .map(drop)?)
190 }
191
192 // contacts
193
194 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
195 let query = "
196 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
197 FROM contacts
198 WHERE user_id_a = $1 OR user_id_b = $1;
199 ";
200
201 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
202 .bind(user_id)
203 .fetch(&self.pool);
204
205 let mut contacts = vec![Contact::Accepted {
206 user_id,
207 should_notify: false,
208 }];
209 while let Some(row) = rows.next().await {
210 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
211
212 if user_id_a == user_id {
213 if accepted {
214 contacts.push(Contact::Accepted {
215 user_id: user_id_b,
216 should_notify: should_notify && a_to_b,
217 });
218 } else if a_to_b {
219 contacts.push(Contact::Outgoing { user_id: user_id_b })
220 } else {
221 contacts.push(Contact::Incoming {
222 user_id: user_id_b,
223 should_notify,
224 });
225 }
226 } else {
227 if accepted {
228 contacts.push(Contact::Accepted {
229 user_id: user_id_a,
230 should_notify: should_notify && !a_to_b,
231 });
232 } else if a_to_b {
233 contacts.push(Contact::Incoming {
234 user_id: user_id_a,
235 should_notify,
236 });
237 } else {
238 contacts.push(Contact::Outgoing { user_id: user_id_a });
239 }
240 }
241 }
242
243 contacts.sort_unstable_by_key(|contact| contact.user_id());
244
245 Ok(contacts)
246 }
247
248 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
249 let (id_a, id_b) = if user_id_1 < user_id_2 {
250 (user_id_1, user_id_2)
251 } else {
252 (user_id_2, user_id_1)
253 };
254
255 let query = "
256 SELECT 1 FROM contacts
257 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
258 LIMIT 1
259 ";
260 Ok(sqlx::query_scalar::<_, i32>(query)
261 .bind(id_a.0)
262 .bind(id_b.0)
263 .fetch_optional(&self.pool)
264 .await?
265 .is_some())
266 }
267
268 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
269 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
270 (sender_id, receiver_id, true)
271 } else {
272 (receiver_id, sender_id, false)
273 };
274 let query = "
275 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
276 VALUES ($1, $2, $3, 'f', 't')
277 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
278 SET
279 accepted = 't',
280 should_notify = 'f'
281 WHERE
282 NOT contacts.accepted AND
283 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
284 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
285 ";
286 let result = sqlx::query(query)
287 .bind(id_a.0)
288 .bind(id_b.0)
289 .bind(a_to_b)
290 .execute(&self.pool)
291 .await?;
292
293 if result.rows_affected() == 1 {
294 Ok(())
295 } else {
296 Err(anyhow!("contact already requested"))
297 }
298 }
299
300 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
301 let (id_a, id_b) = if responder_id < requester_id {
302 (responder_id, requester_id)
303 } else {
304 (requester_id, responder_id)
305 };
306 let query = "
307 DELETE FROM contacts
308 WHERE user_id_a = $1 AND user_id_b = $2;
309 ";
310 let result = sqlx::query(query)
311 .bind(id_a.0)
312 .bind(id_b.0)
313 .execute(&self.pool)
314 .await?;
315
316 if result.rows_affected() == 1 {
317 Ok(())
318 } else {
319 Err(anyhow!("no such contact"))
320 }
321 }
322
323 async fn dismiss_contact_notification(
324 &self,
325 user_id: UserId,
326 contact_user_id: UserId,
327 ) -> Result<()> {
328 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
329 (user_id, contact_user_id, true)
330 } else {
331 (contact_user_id, user_id, false)
332 };
333
334 let query = "
335 UPDATE contacts
336 SET should_notify = 'f'
337 WHERE
338 user_id_a = $1 AND user_id_b = $2 AND
339 (
340 (a_to_b = $3 AND accepted) OR
341 (a_to_b != $3 AND NOT accepted)
342 );
343 ";
344
345 let result = sqlx::query(query)
346 .bind(id_a.0)
347 .bind(id_b.0)
348 .bind(a_to_b)
349 .execute(&self.pool)
350 .await?;
351
352 if result.rows_affected() == 0 {
353 Err(anyhow!("no such contact request"))?;
354 }
355
356 Ok(())
357 }
358
359 async fn respond_to_contact_request(
360 &self,
361 responder_id: UserId,
362 requester_id: UserId,
363 accept: bool,
364 ) -> Result<()> {
365 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
366 (responder_id, requester_id, false)
367 } else {
368 (requester_id, responder_id, true)
369 };
370 let result = if accept {
371 let query = "
372 UPDATE contacts
373 SET accepted = 't', should_notify = 't'
374 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
375 ";
376 sqlx::query(query)
377 .bind(id_a.0)
378 .bind(id_b.0)
379 .bind(a_to_b)
380 .execute(&self.pool)
381 .await?
382 } else {
383 let query = "
384 DELETE FROM contacts
385 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
386 ";
387 sqlx::query(query)
388 .bind(id_a.0)
389 .bind(id_b.0)
390 .bind(a_to_b)
391 .execute(&self.pool)
392 .await?
393 };
394 if result.rows_affected() == 1 {
395 Ok(())
396 } else {
397 Err(anyhow!("no such contact request"))
398 }
399 }
400
401 // access tokens
402
403 async fn create_access_token_hash(
404 &self,
405 user_id: UserId,
406 access_token_hash: &str,
407 max_access_token_count: usize,
408 ) -> Result<()> {
409 let insert_query = "
410 INSERT INTO access_tokens (user_id, hash)
411 VALUES ($1, $2);
412 ";
413 let cleanup_query = "
414 DELETE FROM access_tokens
415 WHERE id IN (
416 SELECT id from access_tokens
417 WHERE user_id = $1
418 ORDER BY id DESC
419 OFFSET $3
420 )
421 ";
422
423 let mut tx = self.pool.begin().await?;
424 sqlx::query(insert_query)
425 .bind(user_id.0)
426 .bind(access_token_hash)
427 .execute(&mut tx)
428 .await?;
429 sqlx::query(cleanup_query)
430 .bind(user_id.0)
431 .bind(access_token_hash)
432 .bind(max_access_token_count as u32)
433 .execute(&mut tx)
434 .await?;
435 Ok(tx.commit().await?)
436 }
437
438 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
439 let query = "
440 SELECT hash
441 FROM access_tokens
442 WHERE user_id = $1
443 ORDER BY id DESC
444 ";
445 Ok(sqlx::query_scalar(query)
446 .bind(user_id.0)
447 .fetch_all(&self.pool)
448 .await?)
449 }
450
451 // orgs
452
453 #[allow(unused)] // Help rust-analyzer
454 #[cfg(any(test, feature = "seed-support"))]
455 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
456 let query = "
457 SELECT *
458 FROM orgs
459 WHERE slug = $1
460 ";
461 Ok(sqlx::query_as(query)
462 .bind(slug)
463 .fetch_optional(&self.pool)
464 .await?)
465 }
466
467 #[cfg(any(test, feature = "seed-support"))]
468 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
469 let query = "
470 INSERT INTO orgs (name, slug)
471 VALUES ($1, $2)
472 RETURNING id
473 ";
474 Ok(sqlx::query_scalar(query)
475 .bind(name)
476 .bind(slug)
477 .fetch_one(&self.pool)
478 .await
479 .map(OrgId)?)
480 }
481
482 #[cfg(any(test, feature = "seed-support"))]
483 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
484 let query = "
485 INSERT INTO org_memberships (org_id, user_id, admin)
486 VALUES ($1, $2, $3)
487 ON CONFLICT DO NOTHING
488 ";
489 Ok(sqlx::query(query)
490 .bind(org_id.0)
491 .bind(user_id.0)
492 .bind(is_admin)
493 .execute(&self.pool)
494 .await
495 .map(drop)?)
496 }
497
498 // channels
499
500 #[cfg(any(test, feature = "seed-support"))]
501 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
502 let query = "
503 INSERT INTO channels (owner_id, owner_is_user, name)
504 VALUES ($1, false, $2)
505 RETURNING id
506 ";
507 Ok(sqlx::query_scalar(query)
508 .bind(org_id.0)
509 .bind(name)
510 .fetch_one(&self.pool)
511 .await
512 .map(ChannelId)?)
513 }
514
515 #[allow(unused)] // Help rust-analyzer
516 #[cfg(any(test, feature = "seed-support"))]
517 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
518 let query = "
519 SELECT *
520 FROM channels
521 WHERE
522 channels.owner_is_user = false AND
523 channels.owner_id = $1
524 ";
525 Ok(sqlx::query_as(query)
526 .bind(org_id.0)
527 .fetch_all(&self.pool)
528 .await?)
529 }
530
531 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
532 let query = "
533 SELECT
534 channels.*
535 FROM
536 channel_memberships, channels
537 WHERE
538 channel_memberships.user_id = $1 AND
539 channel_memberships.channel_id = channels.id
540 ";
541 Ok(sqlx::query_as(query)
542 .bind(user_id.0)
543 .fetch_all(&self.pool)
544 .await?)
545 }
546
547 async fn can_user_access_channel(
548 &self,
549 user_id: UserId,
550 channel_id: ChannelId,
551 ) -> Result<bool> {
552 let query = "
553 SELECT id
554 FROM channel_memberships
555 WHERE user_id = $1 AND channel_id = $2
556 LIMIT 1
557 ";
558 Ok(sqlx::query_scalar::<_, i32>(query)
559 .bind(user_id.0)
560 .bind(channel_id.0)
561 .fetch_optional(&self.pool)
562 .await
563 .map(|e| e.is_some())?)
564 }
565
566 #[cfg(any(test, feature = "seed-support"))]
567 async fn add_channel_member(
568 &self,
569 channel_id: ChannelId,
570 user_id: UserId,
571 is_admin: bool,
572 ) -> Result<()> {
573 let query = "
574 INSERT INTO channel_memberships (channel_id, user_id, admin)
575 VALUES ($1, $2, $3)
576 ON CONFLICT DO NOTHING
577 ";
578 Ok(sqlx::query(query)
579 .bind(channel_id.0)
580 .bind(user_id.0)
581 .bind(is_admin)
582 .execute(&self.pool)
583 .await
584 .map(drop)?)
585 }
586
587 // messages
588
589 async fn create_channel_message(
590 &self,
591 channel_id: ChannelId,
592 sender_id: UserId,
593 body: &str,
594 timestamp: OffsetDateTime,
595 nonce: u128,
596 ) -> Result<MessageId> {
597 let query = "
598 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
599 VALUES ($1, $2, $3, $4, $5)
600 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
601 RETURNING id
602 ";
603 Ok(sqlx::query_scalar(query)
604 .bind(channel_id.0)
605 .bind(sender_id.0)
606 .bind(body)
607 .bind(timestamp)
608 .bind(Uuid::from_u128(nonce))
609 .fetch_one(&self.pool)
610 .await
611 .map(MessageId)?)
612 }
613
614 async fn get_channel_messages(
615 &self,
616 channel_id: ChannelId,
617 count: usize,
618 before_id: Option<MessageId>,
619 ) -> Result<Vec<ChannelMessage>> {
620 let query = r#"
621 SELECT * FROM (
622 SELECT
623 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
624 FROM
625 channel_messages
626 WHERE
627 channel_id = $1 AND
628 id < $2
629 ORDER BY id DESC
630 LIMIT $3
631 ) as recent_messages
632 ORDER BY id ASC
633 "#;
634 Ok(sqlx::query_as(query)
635 .bind(channel_id.0)
636 .bind(before_id.unwrap_or(MessageId::MAX))
637 .bind(count as i64)
638 .fetch_all(&self.pool)
639 .await?)
640 }
641
642 #[cfg(test)]
643 async fn teardown(&self, url: &str) {
644 use util::ResultExt;
645
646 let query = "
647 SELECT pg_terminate_backend(pg_stat_activity.pid)
648 FROM pg_stat_activity
649 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
650 ";
651 sqlx::query(query).execute(&self.pool).await.log_err();
652 self.pool.close().await;
653 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
654 .await
655 .log_err();
656 }
657
658 #[cfg(test)]
659 fn as_fake(&self) -> Option<&tests::FakeDb> {
660 None
661 }
662}
663
/// Declares `$name`, a `Copy` newtype around `i32` used as a row id. It maps
/// transparently to an integer SQL column (`sqlx::Type`) and serializes as a bare
/// integer (serde).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protocol (`u64`) form.
            ///
            /// NOTE(review): this is an `as` cast, so values above `i32::MAX` are truncated
            /// rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Returns the protocol (`u64`) form of this id.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            // Delegate to the inner integer so width/fill flags are honoured.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
695
696id_type!(UserId);
697#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
698pub struct User {
699 pub id: UserId,
700 pub github_login: String,
701 pub admin: bool,
702}
703
704id_type!(OrgId);
705#[derive(FromRow)]
706pub struct Org {
707 pub id: OrgId,
708 pub name: String,
709 pub slug: String,
710}
711
712id_type!(ChannelId);
713#[derive(Clone, Debug, FromRow, Serialize)]
714pub struct Channel {
715 pub id: ChannelId,
716 pub name: String,
717 pub owner_id: i32,
718 pub owner_is_user: bool,
719}
720
721id_type!(MessageId);
722#[derive(Clone, Debug, FromRow)]
723pub struct ChannelMessage {
724 pub id: MessageId,
725 pub channel_id: ChannelId,
726 pub sender_id: UserId,
727 pub body: String,
728 pub sent_at: OffsetDateTime,
729 pub nonce: Uuid,
730}
731
732#[derive(Clone, Debug, PartialEq, Eq)]
733pub enum Contact {
734 Accepted {
735 user_id: UserId,
736 should_notify: bool,
737 },
738 Outgoing {
739 user_id: UserId,
740 },
741 Incoming {
742 user_id: UserId,
743 should_notify: bool,
744 },
745}
746
747impl Contact {
748 pub fn user_id(&self) -> UserId {
749 match self {
750 Contact::Accepted { user_id, .. } => *user_id,
751 Contact::Outgoing { user_id } => *user_id,
752 Contact::Incoming { user_id, .. } => *user_id,
753 }
754 }
755}
756
757#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
758pub struct IncomingContactRequest {
759 pub requester_id: UserId,
760 pub should_notify: bool,
761}
762
/// Builds a SQL `LIKE` pattern that matches any string containing the alphanumeric
/// characters of `string` in order, e.g. `"a b"` becomes `"%a%b%"`.
///
/// Non-alphanumeric characters are dropped, which also keeps the `LIKE`
/// metacharacters `%`, `_` and `\` out of the pattern.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char contributes at most its own bytes plus one `%`, so this never
    // reallocates.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
774
775#[cfg(test)]
776pub mod tests {
777 use super::*;
778 use anyhow::anyhow;
779 use collections::BTreeMap;
780 use gpui::executor::Background;
781 use lazy_static::lazy_static;
782 use parking_lot::Mutex;
783 use rand::prelude::*;
784 use sqlx::{
785 migrate::{MigrateDatabase, Migrator},
786 Postgres,
787 };
788 use std::{path::Path, sync::Arc};
789 use util::post_inc;
790
791 #[tokio::test(flavor = "multi_thread")]
792 async fn test_get_users_by_ids() {
793 for test_db in [
794 TestDb::postgres().await,
795 TestDb::fake(Arc::new(gpui::executor::Background::new())),
796 ] {
797 let db = test_db.db();
798
799 let user = db.create_user("user", false).await.unwrap();
800 let friend1 = db.create_user("friend-1", false).await.unwrap();
801 let friend2 = db.create_user("friend-2", false).await.unwrap();
802 let friend3 = db.create_user("friend-3", false).await.unwrap();
803
804 assert_eq!(
805 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
806 .await
807 .unwrap(),
808 vec![
809 User {
810 id: user,
811 github_login: "user".to_string(),
812 admin: false,
813 },
814 User {
815 id: friend1,
816 github_login: "friend-1".to_string(),
817 admin: false,
818 },
819 User {
820 id: friend2,
821 github_login: "friend-2".to_string(),
822 admin: false,
823 },
824 User {
825 id: friend3,
826 github_login: "friend-3".to_string(),
827 admin: false,
828 }
829 ]
830 );
831 }
832 }
833
834 #[tokio::test(flavor = "multi_thread")]
835 async fn test_recent_channel_messages() {
836 for test_db in [
837 TestDb::postgres().await,
838 TestDb::fake(Arc::new(gpui::executor::Background::new())),
839 ] {
840 let db = test_db.db();
841 let user = db.create_user("user", false).await.unwrap();
842 let org = db.create_org("org", "org").await.unwrap();
843 let channel = db.create_org_channel(org, "channel").await.unwrap();
844 for i in 0..10 {
845 db.create_channel_message(
846 channel,
847 user,
848 &i.to_string(),
849 OffsetDateTime::now_utc(),
850 i,
851 )
852 .await
853 .unwrap();
854 }
855
856 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
857 assert_eq!(
858 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
859 ["5", "6", "7", "8", "9"]
860 );
861
862 let prev_messages = db
863 .get_channel_messages(channel, 4, Some(messages[0].id))
864 .await
865 .unwrap();
866 assert_eq!(
867 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
868 ["1", "2", "3", "4"]
869 );
870 }
871 }
872
873 #[tokio::test(flavor = "multi_thread")]
874 async fn test_channel_message_nonces() {
875 for test_db in [
876 TestDb::postgres().await,
877 TestDb::fake(Arc::new(gpui::executor::Background::new())),
878 ] {
879 let db = test_db.db();
880 let user = db.create_user("user", false).await.unwrap();
881 let org = db.create_org("org", "org").await.unwrap();
882 let channel = db.create_org_channel(org, "channel").await.unwrap();
883
884 let msg1_id = db
885 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
886 .await
887 .unwrap();
888 let msg2_id = db
889 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
890 .await
891 .unwrap();
892 let msg3_id = db
893 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
894 .await
895 .unwrap();
896 let msg4_id = db
897 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
898 .await
899 .unwrap();
900
901 assert_ne!(msg1_id, msg2_id);
902 assert_eq!(msg1_id, msg3_id);
903 assert_eq!(msg2_id, msg4_id);
904 }
905 }
906
907 #[tokio::test(flavor = "multi_thread")]
908 async fn test_create_access_tokens() {
909 let test_db = TestDb::postgres().await;
910 let db = test_db.db();
911 let user = db.create_user("the-user", false).await.unwrap();
912
913 db.create_access_token_hash(user, "h1", 3).await.unwrap();
914 db.create_access_token_hash(user, "h2", 3).await.unwrap();
915 assert_eq!(
916 db.get_access_token_hashes(user).await.unwrap(),
917 &["h2".to_string(), "h1".to_string()]
918 );
919
920 db.create_access_token_hash(user, "h3", 3).await.unwrap();
921 assert_eq!(
922 db.get_access_token_hashes(user).await.unwrap(),
923 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
924 );
925
926 db.create_access_token_hash(user, "h4", 3).await.unwrap();
927 assert_eq!(
928 db.get_access_token_hashes(user).await.unwrap(),
929 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
930 );
931
932 db.create_access_token_hash(user, "h5", 3).await.unwrap();
933 assert_eq!(
934 db.get_access_token_hashes(user).await.unwrap(),
935 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
936 );
937 }
938
939 #[test]
940 fn test_fuzzy_like_string() {
941 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
942 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
943 assert_eq!(fuzzy_like_string(" z "), "%z%");
944 }
945
946 #[tokio::test(flavor = "multi_thread")]
947 async fn test_fuzzy_search_users() {
948 let test_db = TestDb::postgres().await;
949 let db = test_db.db();
950 for github_login in [
951 "California",
952 "colorado",
953 "oregon",
954 "washington",
955 "florida",
956 "delaware",
957 "rhode-island",
958 ] {
959 db.create_user(github_login, false).await.unwrap();
960 }
961
962 assert_eq!(
963 fuzzy_search_user_names(db, "clr").await,
964 &["colorado", "California"]
965 );
966 assert_eq!(
967 fuzzy_search_user_names(db, "ro").await,
968 &["rhode-island", "colorado", "oregon"],
969 );
970
971 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
972 db.fuzzy_search_users(query, 10)
973 .await
974 .unwrap()
975 .into_iter()
976 .map(|user| user.github_login)
977 .collect::<Vec<_>>()
978 }
979 }
980
981 #[tokio::test(flavor = "multi_thread")]
982 async fn test_add_contacts() {
983 for test_db in [
984 TestDb::postgres().await,
985 TestDb::fake(Arc::new(gpui::executor::Background::new())),
986 ] {
987 let db = test_db.db();
988
989 let user_1 = db.create_user("user1", false).await.unwrap();
990 let user_2 = db.create_user("user2", false).await.unwrap();
991 let user_3 = db.create_user("user3", false).await.unwrap();
992
993 // User starts with no contacts
994 assert_eq!(
995 db.get_contacts(user_1).await.unwrap(),
996 vec![Contact::Accepted {
997 user_id: user_1,
998 should_notify: false
999 }],
1000 );
1001
1002 // User requests a contact. Both users see the pending request.
1003 db.send_contact_request(user_1, user_2).await.unwrap();
1004 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1005 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1006 assert_eq!(
1007 db.get_contacts(user_1).await.unwrap(),
1008 &[
1009 Contact::Accepted {
1010 user_id: user_1,
1011 should_notify: false
1012 },
1013 Contact::Outgoing { user_id: user_2 }
1014 ],
1015 );
1016 assert_eq!(
1017 db.get_contacts(user_2).await.unwrap(),
1018 &[
1019 Contact::Incoming {
1020 user_id: user_1,
1021 should_notify: true
1022 },
1023 Contact::Accepted {
1024 user_id: user_2,
1025 should_notify: false
1026 },
1027 ]
1028 );
1029
1030 // User 2 dismisses the contact request notification without accepting or rejecting.
1031 // We shouldn't notify them again.
1032 db.dismiss_contact_notification(user_1, user_2)
1033 .await
1034 .unwrap_err();
1035 db.dismiss_contact_notification(user_2, user_1)
1036 .await
1037 .unwrap();
1038 assert_eq!(
1039 db.get_contacts(user_2).await.unwrap(),
1040 &[
1041 Contact::Incoming {
1042 user_id: user_1,
1043 should_notify: false
1044 },
1045 Contact::Accepted {
1046 user_id: user_2,
1047 should_notify: false
1048 },
1049 ]
1050 );
1051
1052 // User can't accept their own contact request
1053 db.respond_to_contact_request(user_1, user_2, true)
1054 .await
1055 .unwrap_err();
1056
1057 // User accepts a contact request. Both users see the contact.
1058 db.respond_to_contact_request(user_2, user_1, true)
1059 .await
1060 .unwrap();
1061 assert_eq!(
1062 db.get_contacts(user_1).await.unwrap(),
1063 &[
1064 Contact::Accepted {
1065 user_id: user_1,
1066 should_notify: false
1067 },
1068 Contact::Accepted {
1069 user_id: user_2,
1070 should_notify: true
1071 }
1072 ],
1073 );
1074 assert!(db.has_contact(user_1, user_2).await.unwrap());
1075 assert!(db.has_contact(user_2, user_1).await.unwrap());
1076 assert_eq!(
1077 db.get_contacts(user_2).await.unwrap(),
1078 &[
1079 Contact::Accepted {
1080 user_id: user_1,
1081 should_notify: false,
1082 },
1083 Contact::Accepted {
1084 user_id: user_2,
1085 should_notify: false,
1086 },
1087 ]
1088 );
1089
1090 // Users cannot re-request existing contacts.
1091 db.send_contact_request(user_1, user_2).await.unwrap_err();
1092 db.send_contact_request(user_2, user_1).await.unwrap_err();
1093
1094 // Users can't dismiss notifications of them accepting other users' requests.
1095 db.dismiss_contact_notification(user_2, user_1)
1096 .await
1097 .unwrap_err();
1098 assert_eq!(
1099 db.get_contacts(user_1).await.unwrap(),
1100 &[
1101 Contact::Accepted {
1102 user_id: user_1,
1103 should_notify: false
1104 },
1105 Contact::Accepted {
1106 user_id: user_2,
1107 should_notify: true,
1108 },
1109 ]
1110 );
1111
1112 // Users can dismiss notifications of other users accepting their requests.
1113 db.dismiss_contact_notification(user_1, user_2)
1114 .await
1115 .unwrap();
1116 assert_eq!(
1117 db.get_contacts(user_1).await.unwrap(),
1118 &[
1119 Contact::Accepted {
1120 user_id: user_1,
1121 should_notify: false
1122 },
1123 Contact::Accepted {
1124 user_id: user_2,
1125 should_notify: false,
1126 },
1127 ]
1128 );
1129
1130 // Users send each other concurrent contact requests and
1131 // see that they are immediately accepted.
1132 db.send_contact_request(user_1, user_3).await.unwrap();
1133 db.send_contact_request(user_3, user_1).await.unwrap();
1134 assert_eq!(
1135 db.get_contacts(user_1).await.unwrap(),
1136 &[
1137 Contact::Accepted {
1138 user_id: user_1,
1139 should_notify: false
1140 },
1141 Contact::Accepted {
1142 user_id: user_2,
1143 should_notify: false,
1144 },
1145 Contact::Accepted {
1146 user_id: user_3,
1147 should_notify: false
1148 },
1149 ]
1150 );
1151 assert_eq!(
1152 db.get_contacts(user_3).await.unwrap(),
1153 &[
1154 Contact::Accepted {
1155 user_id: user_1,
1156 should_notify: false
1157 },
1158 Contact::Accepted {
1159 user_id: user_3,
1160 should_notify: false
1161 }
1162 ],
1163 );
1164
1165 // User declines a contact request. Both users see that it is gone.
1166 db.send_contact_request(user_2, user_3).await.unwrap();
1167 db.respond_to_contact_request(user_3, user_2, false)
1168 .await
1169 .unwrap();
1170 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1171 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1172 assert_eq!(
1173 db.get_contacts(user_2).await.unwrap(),
1174 &[
1175 Contact::Accepted {
1176 user_id: user_1,
1177 should_notify: false
1178 },
1179 Contact::Accepted {
1180 user_id: user_2,
1181 should_notify: false
1182 }
1183 ]
1184 );
1185 assert_eq!(
1186 db.get_contacts(user_3).await.unwrap(),
1187 &[
1188 Contact::Accepted {
1189 user_id: user_1,
1190 should_notify: false
1191 },
1192 Contact::Accepted {
1193 user_id: user_3,
1194 should_notify: false
1195 }
1196 ],
1197 );
1198 }
1199 }
1200
1201 pub struct TestDb {
1202 pub db: Option<Arc<dyn Db>>,
1203 pub url: String,
1204 }
1205
1206 impl TestDb {
1207 pub async fn postgres() -> Self {
1208 lazy_static! {
1209 static ref LOCK: Mutex<()> = Mutex::new(());
1210 }
1211
1212 let _guard = LOCK.lock();
1213 let mut rng = StdRng::from_entropy();
1214 let name = format!("zed-test-{}", rng.gen::<u128>());
1215 let url = format!("postgres://postgres@localhost/{}", name);
1216 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1217 Postgres::create_database(&url)
1218 .await
1219 .expect("failed to create test db");
1220 let db = PostgresDb::new(&url, 5).await.unwrap();
1221 let migrator = Migrator::new(migrations_path).await.unwrap();
1222 migrator.run(&db.pool).await.unwrap();
1223 Self {
1224 db: Some(Arc::new(db)),
1225 url,
1226 }
1227 }
1228
1229 pub fn fake(background: Arc<Background>) -> Self {
1230 Self {
1231 db: Some(Arc::new(FakeDb::new(background))),
1232 url: Default::default(),
1233 }
1234 }
1235
1236 pub fn db(&self) -> &Arc<dyn Db> {
1237 self.db.as_ref().unwrap()
1238 }
1239 }
1240
1241 impl Drop for TestDb {
1242 fn drop(&mut self) {
1243 if let Some(db) = self.db.take() {
1244 futures::executor::block_on(db.teardown(&self.url));
1245 }
1246 }
1247 }
1248
1249 pub struct FakeDb {
1250 background: Arc<Background>,
1251 pub users: Mutex<BTreeMap<UserId, User>>,
1252 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1253 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1254 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1255 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1256 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1257 pub contacts: Mutex<Vec<FakeContact>>,
1258 next_channel_message_id: Mutex<i32>,
1259 next_user_id: Mutex<i32>,
1260 next_org_id: Mutex<i32>,
1261 next_channel_id: Mutex<i32>,
1262 }
1263
1264 #[derive(Debug)]
1265 pub struct FakeContact {
1266 pub requester_id: UserId,
1267 pub responder_id: UserId,
1268 pub accepted: bool,
1269 pub should_notify: bool,
1270 }
1271
1272 impl FakeDb {
1273 pub fn new(background: Arc<Background>) -> Self {
1274 Self {
1275 background,
1276 users: Default::default(),
1277 next_user_id: Mutex::new(1),
1278 orgs: Default::default(),
1279 next_org_id: Mutex::new(1),
1280 org_memberships: Default::default(),
1281 channels: Default::default(),
1282 next_channel_id: Mutex::new(1),
1283 channel_memberships: Default::default(),
1284 channel_messages: Default::default(),
1285 next_channel_message_id: Mutex::new(1),
1286 contacts: Default::default(),
1287 }
1288 }
1289 }
1290
1291 #[async_trait]
1292 impl Db for FakeDb {
1293 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1294 self.background.simulate_random_delay().await;
1295
1296 let mut users = self.users.lock();
1297 if let Some(user) = users
1298 .values()
1299 .find(|user| user.github_login == github_login)
1300 {
1301 Ok(user.id)
1302 } else {
1303 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1304 users.insert(
1305 user_id,
1306 User {
1307 id: user_id,
1308 github_login: github_login.to_string(),
1309 admin,
1310 },
1311 );
1312 Ok(user_id)
1313 }
1314 }
1315
1316 async fn get_all_users(&self) -> Result<Vec<User>> {
1317 unimplemented!()
1318 }
1319
1320 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1321 unimplemented!()
1322 }
1323
1324 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1325 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1326 }
1327
1328 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1329 self.background.simulate_random_delay().await;
1330 let users = self.users.lock();
1331 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1332 }
1333
1334 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1335 Ok(self
1336 .users
1337 .lock()
1338 .values()
1339 .find(|user| user.github_login == github_login)
1340 .cloned())
1341 }
1342
1343 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1344 unimplemented!()
1345 }
1346
1347 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1348 unimplemented!()
1349 }
1350
1351 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1352 self.background.simulate_random_delay().await;
1353 let mut contacts = vec![Contact::Accepted {
1354 user_id: id,
1355 should_notify: false,
1356 }];
1357
1358 for contact in self.contacts.lock().iter() {
1359 if contact.requester_id == id {
1360 if contact.accepted {
1361 contacts.push(Contact::Accepted {
1362 user_id: contact.responder_id,
1363 should_notify: contact.should_notify,
1364 });
1365 } else {
1366 contacts.push(Contact::Outgoing {
1367 user_id: contact.responder_id,
1368 });
1369 }
1370 } else if contact.responder_id == id {
1371 if contact.accepted {
1372 contacts.push(Contact::Accepted {
1373 user_id: contact.requester_id,
1374 should_notify: false,
1375 });
1376 } else {
1377 contacts.push(Contact::Incoming {
1378 user_id: contact.requester_id,
1379 should_notify: contact.should_notify,
1380 });
1381 }
1382 }
1383 }
1384
1385 contacts.sort_unstable_by_key(|contact| contact.user_id());
1386 Ok(contacts)
1387 }
1388
1389 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1390 self.background.simulate_random_delay().await;
1391 Ok(self.contacts.lock().iter().any(|contact| {
1392 contact.accepted
1393 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1394 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1395 }))
1396 }
1397
1398 async fn send_contact_request(
1399 &self,
1400 requester_id: UserId,
1401 responder_id: UserId,
1402 ) -> Result<()> {
1403 let mut contacts = self.contacts.lock();
1404 for contact in contacts.iter_mut() {
1405 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1406 if contact.accepted {
1407 Err(anyhow!("contact already exists"))?;
1408 } else {
1409 Err(anyhow!("contact already requested"))?;
1410 }
1411 }
1412 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1413 if contact.accepted {
1414 Err(anyhow!("contact already exists"))?;
1415 } else {
1416 contact.accepted = true;
1417 contact.should_notify = false;
1418 return Ok(());
1419 }
1420 }
1421 }
1422 contacts.push(FakeContact {
1423 requester_id,
1424 responder_id,
1425 accepted: false,
1426 should_notify: true,
1427 });
1428 Ok(())
1429 }
1430
1431 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1432 self.contacts.lock().retain(|contact| {
1433 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1434 });
1435 Ok(())
1436 }
1437
1438 async fn dismiss_contact_notification(
1439 &self,
1440 user_id: UserId,
1441 contact_user_id: UserId,
1442 ) -> Result<()> {
1443 let mut contacts = self.contacts.lock();
1444 for contact in contacts.iter_mut() {
1445 if contact.requester_id == contact_user_id
1446 && contact.responder_id == user_id
1447 && !contact.accepted
1448 {
1449 contact.should_notify = false;
1450 return Ok(());
1451 }
1452 if contact.requester_id == user_id
1453 && contact.responder_id == contact_user_id
1454 && contact.accepted
1455 {
1456 contact.should_notify = false;
1457 return Ok(());
1458 }
1459 }
1460 Err(anyhow!("no such notification"))
1461 }
1462
1463 async fn respond_to_contact_request(
1464 &self,
1465 responder_id: UserId,
1466 requester_id: UserId,
1467 accept: bool,
1468 ) -> Result<()> {
1469 let mut contacts = self.contacts.lock();
1470 for (ix, contact) in contacts.iter_mut().enumerate() {
1471 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1472 if contact.accepted {
1473 return Err(anyhow!("contact already confirmed"));
1474 }
1475 if accept {
1476 contact.accepted = true;
1477 contact.should_notify = true;
1478 } else {
1479 contacts.remove(ix);
1480 }
1481 return Ok(());
1482 }
1483 }
1484 Err(anyhow!("no such contact request"))
1485 }
1486
1487 async fn create_access_token_hash(
1488 &self,
1489 _user_id: UserId,
1490 _access_token_hash: &str,
1491 _max_access_token_count: usize,
1492 ) -> Result<()> {
1493 unimplemented!()
1494 }
1495
1496 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1497 unimplemented!()
1498 }
1499
1500 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1501 unimplemented!()
1502 }
1503
1504 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1505 self.background.simulate_random_delay().await;
1506 let mut orgs = self.orgs.lock();
1507 if orgs.values().any(|org| org.slug == slug) {
1508 Err(anyhow!("org already exists"))
1509 } else {
1510 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1511 orgs.insert(
1512 org_id,
1513 Org {
1514 id: org_id,
1515 name: name.to_string(),
1516 slug: slug.to_string(),
1517 },
1518 );
1519 Ok(org_id)
1520 }
1521 }
1522
1523 async fn add_org_member(
1524 &self,
1525 org_id: OrgId,
1526 user_id: UserId,
1527 is_admin: bool,
1528 ) -> Result<()> {
1529 self.background.simulate_random_delay().await;
1530 if !self.orgs.lock().contains_key(&org_id) {
1531 return Err(anyhow!("org does not exist"));
1532 }
1533 if !self.users.lock().contains_key(&user_id) {
1534 return Err(anyhow!("user does not exist"));
1535 }
1536
1537 self.org_memberships
1538 .lock()
1539 .entry((org_id, user_id))
1540 .or_insert(is_admin);
1541 Ok(())
1542 }
1543
1544 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1545 self.background.simulate_random_delay().await;
1546 if !self.orgs.lock().contains_key(&org_id) {
1547 return Err(anyhow!("org does not exist"));
1548 }
1549
1550 let mut channels = self.channels.lock();
1551 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1552 channels.insert(
1553 channel_id,
1554 Channel {
1555 id: channel_id,
1556 name: name.to_string(),
1557 owner_id: org_id.0,
1558 owner_is_user: false,
1559 },
1560 );
1561 Ok(channel_id)
1562 }
1563
1564 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1565 self.background.simulate_random_delay().await;
1566 Ok(self
1567 .channels
1568 .lock()
1569 .values()
1570 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1571 .cloned()
1572 .collect())
1573 }
1574
1575 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1576 self.background.simulate_random_delay().await;
1577 let channels = self.channels.lock();
1578 let memberships = self.channel_memberships.lock();
1579 Ok(channels
1580 .values()
1581 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1582 .cloned()
1583 .collect())
1584 }
1585
1586 async fn can_user_access_channel(
1587 &self,
1588 user_id: UserId,
1589 channel_id: ChannelId,
1590 ) -> Result<bool> {
1591 self.background.simulate_random_delay().await;
1592 Ok(self
1593 .channel_memberships
1594 .lock()
1595 .contains_key(&(channel_id, user_id)))
1596 }
1597
1598 async fn add_channel_member(
1599 &self,
1600 channel_id: ChannelId,
1601 user_id: UserId,
1602 is_admin: bool,
1603 ) -> Result<()> {
1604 self.background.simulate_random_delay().await;
1605 if !self.channels.lock().contains_key(&channel_id) {
1606 return Err(anyhow!("channel does not exist"));
1607 }
1608 if !self.users.lock().contains_key(&user_id) {
1609 return Err(anyhow!("user does not exist"));
1610 }
1611
1612 self.channel_memberships
1613 .lock()
1614 .entry((channel_id, user_id))
1615 .or_insert(is_admin);
1616 Ok(())
1617 }
1618
1619 async fn create_channel_message(
1620 &self,
1621 channel_id: ChannelId,
1622 sender_id: UserId,
1623 body: &str,
1624 timestamp: OffsetDateTime,
1625 nonce: u128,
1626 ) -> Result<MessageId> {
1627 self.background.simulate_random_delay().await;
1628 if !self.channels.lock().contains_key(&channel_id) {
1629 return Err(anyhow!("channel does not exist"));
1630 }
1631 if !self.users.lock().contains_key(&sender_id) {
1632 return Err(anyhow!("user does not exist"));
1633 }
1634
1635 let mut messages = self.channel_messages.lock();
1636 if let Some(message) = messages
1637 .values()
1638 .find(|message| message.nonce.as_u128() == nonce)
1639 {
1640 Ok(message.id)
1641 } else {
1642 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1643 messages.insert(
1644 message_id,
1645 ChannelMessage {
1646 id: message_id,
1647 channel_id,
1648 sender_id,
1649 body: body.to_string(),
1650 sent_at: timestamp,
1651 nonce: Uuid::from_u128(nonce),
1652 },
1653 );
1654 Ok(message_id)
1655 }
1656 }
1657
1658 async fn get_channel_messages(
1659 &self,
1660 channel_id: ChannelId,
1661 count: usize,
1662 before_id: Option<MessageId>,
1663 ) -> Result<Vec<ChannelMessage>> {
1664 let mut messages = self
1665 .channel_messages
1666 .lock()
1667 .values()
1668 .rev()
1669 .filter(|message| {
1670 message.channel_id == channel_id
1671 && message.id < before_id.unwrap_or(MessageId::MAX)
1672 })
1673 .take(count)
1674 .cloned()
1675 .collect::<Vec<_>>();
1676 messages.sort_unstable_by_key(|message| message.id);
1677 Ok(messages)
1678 }
1679
1680 async fn teardown(&self, _: &str) {}
1681
1682 #[cfg(test)]
1683 fn as_fake(&self) -> Option<&FakeDb> {
1684 Some(self)
1685 }
1686 }
1687}