1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use futures::StreamExt;
4use nanoid::nanoid;
5use serde::Serialize;
6pub use sqlx::postgres::PgPoolOptions as DbOptions;
7use sqlx::{types::Uuid, FromRow};
8use time::OffsetDateTime;
9
10#[async_trait]
11pub trait Db: Send + Sync {
12 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
13 async fn get_all_users(&self) -> Result<Vec<User>>;
14 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
15 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
16 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
17 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
18 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
19 async fn destroy_user(&self, id: UserId) -> Result<()>;
20
21 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
22 async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
23 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
24
25 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
26 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
27 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
28 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
29 async fn dismiss_contact_notification(
30 &self,
31 responder_id: UserId,
32 requester_id: UserId,
33 ) -> Result<()>;
34 async fn respond_to_contact_request(
35 &self,
36 responder_id: UserId,
37 requester_id: UserId,
38 accept: bool,
39 ) -> Result<()>;
40
41 async fn create_access_token_hash(
42 &self,
43 user_id: UserId,
44 access_token_hash: &str,
45 max_access_token_count: usize,
46 ) -> Result<()>;
47 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
48 #[cfg(any(test, feature = "seed-support"))]
49
50 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
51 #[cfg(any(test, feature = "seed-support"))]
52 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
53 #[cfg(any(test, feature = "seed-support"))]
54 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
55 #[cfg(any(test, feature = "seed-support"))]
56 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
57 #[cfg(any(test, feature = "seed-support"))]
58
59 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
60 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
61 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
62 -> Result<bool>;
63 #[cfg(any(test, feature = "seed-support"))]
64 async fn add_channel_member(
65 &self,
66 channel_id: ChannelId,
67 user_id: UserId,
68 is_admin: bool,
69 ) -> Result<()>;
70 async fn create_channel_message(
71 &self,
72 channel_id: ChannelId,
73 sender_id: UserId,
74 body: &str,
75 timestamp: OffsetDateTime,
76 nonce: u128,
77 ) -> Result<MessageId>;
78 async fn get_channel_messages(
79 &self,
80 channel_id: ChannelId,
81 count: usize,
82 before_id: Option<MessageId>,
83 ) -> Result<Vec<ChannelMessage>>;
84 #[cfg(test)]
85 async fn teardown(&self, url: &str);
86 #[cfg(test)]
87 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
88}
89
/// `Db` implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    pool: sqlx::PgPool,
}
93
94impl PostgresDb {
95 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
96 let pool = DbOptions::new()
97 .max_connections(max_connections)
98 .connect(&url)
99 .await
100 .context("failed to connect to postgres database")?;
101 Ok(Self { pool })
102 }
103}
104
105#[async_trait]
106impl Db for PostgresDb {
107 // users
108
109 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
110 let query = "
111 INSERT INTO users (github_login, admin)
112 VALUES ($1, $2)
113 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
114 RETURNING id
115 ";
116 Ok(sqlx::query_scalar(query)
117 .bind(github_login)
118 .bind(admin)
119 .fetch_one(&self.pool)
120 .await
121 .map(UserId)?)
122 }
123
124 async fn get_all_users(&self) -> Result<Vec<User>> {
125 let query = "SELECT * FROM users ORDER BY github_login ASC";
126 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
127 }
128
129 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
130 let like_string = fuzzy_like_string(name_query);
131 let query = "
132 SELECT users.*
133 FROM users
134 WHERE github_login ILIKE $1
135 ORDER BY github_login <-> $2
136 LIMIT $3
137 ";
138 Ok(sqlx::query_as(query)
139 .bind(like_string)
140 .bind(name_query)
141 .bind(limit)
142 .fetch_all(&self.pool)
143 .await?)
144 }
145
146 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
147 let users = self.get_users_by_ids(vec![id]).await?;
148 Ok(users.into_iter().next())
149 }
150
151 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
152 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
153 let query = "
154 SELECT users.*
155 FROM users
156 WHERE users.id = ANY ($1)
157 ";
158 Ok(sqlx::query_as(query)
159 .bind(&ids)
160 .fetch_all(&self.pool)
161 .await?)
162 }
163
164 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
165 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
166 Ok(sqlx::query_as(query)
167 .bind(github_login)
168 .fetch_optional(&self.pool)
169 .await?)
170 }
171
172 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
173 let query = "UPDATE users SET admin = $1 WHERE id = $2";
174 Ok(sqlx::query(query)
175 .bind(is_admin)
176 .bind(id.0)
177 .execute(&self.pool)
178 .await
179 .map(drop)?)
180 }
181
182 async fn destroy_user(&self, id: UserId) -> Result<()> {
183 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
184 sqlx::query(query)
185 .bind(id.0)
186 .execute(&self.pool)
187 .await
188 .map(drop)?;
189 let query = "DELETE FROM users WHERE id = $1;";
190 Ok(sqlx::query(query)
191 .bind(id.0)
192 .execute(&self.pool)
193 .await
194 .map(drop)?)
195 }
196
197 // invite codes
198
199 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
200 let mut tx = self.pool.begin().await?;
201 sqlx::query(
202 "
203 UPDATE users
204 SET invite_code = $1
205 WHERE id = $2 AND invite_code IS NULL
206 ",
207 )
208 .bind(nanoid!(16))
209 .bind(id)
210 .execute(&mut tx)
211 .await?;
212 sqlx::query(
213 "
214 UPDATE users
215 SET invite_count = $1
216 WHERE id = $2
217 ",
218 )
219 .bind(count)
220 .bind(id)
221 .execute(&mut tx)
222 .await?;
223 tx.commit().await?;
224 Ok(())
225 }
226
227 async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
228 let result: Option<(String, i32)> = sqlx::query_as(
229 "
230 SELECT invite_code, invite_count
231 FROM users
232 WHERE id = $1 AND invite_code IS NOT NULL
233 ",
234 )
235 .bind(id)
236 .fetch_optional(&self.pool)
237 .await?;
238 if let Some((code, count)) = result {
239 Ok(Some((code, count.try_into()?)))
240 } else {
241 Ok(None)
242 }
243 }
244
245 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
246 let mut tx = self.pool.begin().await?;
247
248 let inviter_id: UserId = sqlx::query_scalar(
249 "
250 UPDATE users
251 SET invite_count = invite_count - 1
252 WHERE
253 invite_code = $1 AND
254 invite_count > 0
255 RETURNING id
256 ",
257 )
258 .bind(code)
259 .fetch_optional(&mut tx)
260 .await?
261 .ok_or_else(|| anyhow!("invite code not found"))?;
262 let invitee_id = sqlx::query_scalar(
263 "
264 INSERT INTO users
265 (github_login, admin, inviter_id)
266 VALUES
267 ($1, 'f', $2)
268 RETURNING id
269 ",
270 )
271 .bind(login)
272 .bind(inviter_id)
273 .fetch_one(&mut tx)
274 .await
275 .map(UserId)?;
276
277 sqlx::query(
278 "
279 INSERT INTO contacts
280 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
281 VALUES
282 ($1, $2, 't', 't', 't')
283 ",
284 )
285 .bind(inviter_id)
286 .bind(invitee_id)
287 .execute(&mut tx)
288 .await?;
289
290 tx.commit().await?;
291 Ok(invitee_id)
292 }
293
294 // contacts
295
296 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
297 let query = "
298 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
299 FROM contacts
300 WHERE user_id_a = $1 OR user_id_b = $1;
301 ";
302
303 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
304 .bind(user_id)
305 .fetch(&self.pool);
306
307 let mut contacts = vec![Contact::Accepted {
308 user_id,
309 should_notify: false,
310 }];
311 while let Some(row) = rows.next().await {
312 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
313
314 if user_id_a == user_id {
315 if accepted {
316 contacts.push(Contact::Accepted {
317 user_id: user_id_b,
318 should_notify: should_notify && a_to_b,
319 });
320 } else if a_to_b {
321 contacts.push(Contact::Outgoing { user_id: user_id_b })
322 } else {
323 contacts.push(Contact::Incoming {
324 user_id: user_id_b,
325 should_notify,
326 });
327 }
328 } else {
329 if accepted {
330 contacts.push(Contact::Accepted {
331 user_id: user_id_a,
332 should_notify: should_notify && !a_to_b,
333 });
334 } else if a_to_b {
335 contacts.push(Contact::Incoming {
336 user_id: user_id_a,
337 should_notify,
338 });
339 } else {
340 contacts.push(Contact::Outgoing { user_id: user_id_a });
341 }
342 }
343 }
344
345 contacts.sort_unstable_by_key(|contact| contact.user_id());
346
347 Ok(contacts)
348 }
349
350 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
351 let (id_a, id_b) = if user_id_1 < user_id_2 {
352 (user_id_1, user_id_2)
353 } else {
354 (user_id_2, user_id_1)
355 };
356
357 let query = "
358 SELECT 1 FROM contacts
359 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
360 LIMIT 1
361 ";
362 Ok(sqlx::query_scalar::<_, i32>(query)
363 .bind(id_a.0)
364 .bind(id_b.0)
365 .fetch_optional(&self.pool)
366 .await?
367 .is_some())
368 }
369
370 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
371 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
372 (sender_id, receiver_id, true)
373 } else {
374 (receiver_id, sender_id, false)
375 };
376 let query = "
377 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
378 VALUES ($1, $2, $3, 'f', 't')
379 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
380 SET
381 accepted = 't',
382 should_notify = 'f'
383 WHERE
384 NOT contacts.accepted AND
385 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
386 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
387 ";
388 let result = sqlx::query(query)
389 .bind(id_a.0)
390 .bind(id_b.0)
391 .bind(a_to_b)
392 .execute(&self.pool)
393 .await?;
394
395 if result.rows_affected() == 1 {
396 Ok(())
397 } else {
398 Err(anyhow!("contact already requested"))
399 }
400 }
401
402 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
403 let (id_a, id_b) = if responder_id < requester_id {
404 (responder_id, requester_id)
405 } else {
406 (requester_id, responder_id)
407 };
408 let query = "
409 DELETE FROM contacts
410 WHERE user_id_a = $1 AND user_id_b = $2;
411 ";
412 let result = sqlx::query(query)
413 .bind(id_a.0)
414 .bind(id_b.0)
415 .execute(&self.pool)
416 .await?;
417
418 if result.rows_affected() == 1 {
419 Ok(())
420 } else {
421 Err(anyhow!("no such contact"))
422 }
423 }
424
425 async fn dismiss_contact_notification(
426 &self,
427 user_id: UserId,
428 contact_user_id: UserId,
429 ) -> Result<()> {
430 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
431 (user_id, contact_user_id, true)
432 } else {
433 (contact_user_id, user_id, false)
434 };
435
436 let query = "
437 UPDATE contacts
438 SET should_notify = 'f'
439 WHERE
440 user_id_a = $1 AND user_id_b = $2 AND
441 (
442 (a_to_b = $3 AND accepted) OR
443 (a_to_b != $3 AND NOT accepted)
444 );
445 ";
446
447 let result = sqlx::query(query)
448 .bind(id_a.0)
449 .bind(id_b.0)
450 .bind(a_to_b)
451 .execute(&self.pool)
452 .await?;
453
454 if result.rows_affected() == 0 {
455 Err(anyhow!("no such contact request"))?;
456 }
457
458 Ok(())
459 }
460
461 async fn respond_to_contact_request(
462 &self,
463 responder_id: UserId,
464 requester_id: UserId,
465 accept: bool,
466 ) -> Result<()> {
467 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
468 (responder_id, requester_id, false)
469 } else {
470 (requester_id, responder_id, true)
471 };
472 let result = if accept {
473 let query = "
474 UPDATE contacts
475 SET accepted = 't', should_notify = 't'
476 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
477 ";
478 sqlx::query(query)
479 .bind(id_a.0)
480 .bind(id_b.0)
481 .bind(a_to_b)
482 .execute(&self.pool)
483 .await?
484 } else {
485 let query = "
486 DELETE FROM contacts
487 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
488 ";
489 sqlx::query(query)
490 .bind(id_a.0)
491 .bind(id_b.0)
492 .bind(a_to_b)
493 .execute(&self.pool)
494 .await?
495 };
496 if result.rows_affected() == 1 {
497 Ok(())
498 } else {
499 Err(anyhow!("no such contact request"))
500 }
501 }
502
503 // access tokens
504
505 async fn create_access_token_hash(
506 &self,
507 user_id: UserId,
508 access_token_hash: &str,
509 max_access_token_count: usize,
510 ) -> Result<()> {
511 let insert_query = "
512 INSERT INTO access_tokens (user_id, hash)
513 VALUES ($1, $2);
514 ";
515 let cleanup_query = "
516 DELETE FROM access_tokens
517 WHERE id IN (
518 SELECT id from access_tokens
519 WHERE user_id = $1
520 ORDER BY id DESC
521 OFFSET $3
522 )
523 ";
524
525 let mut tx = self.pool.begin().await?;
526 sqlx::query(insert_query)
527 .bind(user_id.0)
528 .bind(access_token_hash)
529 .execute(&mut tx)
530 .await?;
531 sqlx::query(cleanup_query)
532 .bind(user_id.0)
533 .bind(access_token_hash)
534 .bind(max_access_token_count as u32)
535 .execute(&mut tx)
536 .await?;
537 Ok(tx.commit().await?)
538 }
539
540 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
541 let query = "
542 SELECT hash
543 FROM access_tokens
544 WHERE user_id = $1
545 ORDER BY id DESC
546 ";
547 Ok(sqlx::query_scalar(query)
548 .bind(user_id.0)
549 .fetch_all(&self.pool)
550 .await?)
551 }
552
553 // orgs
554
555 #[allow(unused)] // Help rust-analyzer
556 #[cfg(any(test, feature = "seed-support"))]
557 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
558 let query = "
559 SELECT *
560 FROM orgs
561 WHERE slug = $1
562 ";
563 Ok(sqlx::query_as(query)
564 .bind(slug)
565 .fetch_optional(&self.pool)
566 .await?)
567 }
568
569 #[cfg(any(test, feature = "seed-support"))]
570 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
571 let query = "
572 INSERT INTO orgs (name, slug)
573 VALUES ($1, $2)
574 RETURNING id
575 ";
576 Ok(sqlx::query_scalar(query)
577 .bind(name)
578 .bind(slug)
579 .fetch_one(&self.pool)
580 .await
581 .map(OrgId)?)
582 }
583
584 #[cfg(any(test, feature = "seed-support"))]
585 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
586 let query = "
587 INSERT INTO org_memberships (org_id, user_id, admin)
588 VALUES ($1, $2, $3)
589 ON CONFLICT DO NOTHING
590 ";
591 Ok(sqlx::query(query)
592 .bind(org_id.0)
593 .bind(user_id.0)
594 .bind(is_admin)
595 .execute(&self.pool)
596 .await
597 .map(drop)?)
598 }
599
600 // channels
601
602 #[cfg(any(test, feature = "seed-support"))]
603 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
604 let query = "
605 INSERT INTO channels (owner_id, owner_is_user, name)
606 VALUES ($1, false, $2)
607 RETURNING id
608 ";
609 Ok(sqlx::query_scalar(query)
610 .bind(org_id.0)
611 .bind(name)
612 .fetch_one(&self.pool)
613 .await
614 .map(ChannelId)?)
615 }
616
617 #[allow(unused)] // Help rust-analyzer
618 #[cfg(any(test, feature = "seed-support"))]
619 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
620 let query = "
621 SELECT *
622 FROM channels
623 WHERE
624 channels.owner_is_user = false AND
625 channels.owner_id = $1
626 ";
627 Ok(sqlx::query_as(query)
628 .bind(org_id.0)
629 .fetch_all(&self.pool)
630 .await?)
631 }
632
633 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
634 let query = "
635 SELECT
636 channels.*
637 FROM
638 channel_memberships, channels
639 WHERE
640 channel_memberships.user_id = $1 AND
641 channel_memberships.channel_id = channels.id
642 ";
643 Ok(sqlx::query_as(query)
644 .bind(user_id.0)
645 .fetch_all(&self.pool)
646 .await?)
647 }
648
649 async fn can_user_access_channel(
650 &self,
651 user_id: UserId,
652 channel_id: ChannelId,
653 ) -> Result<bool> {
654 let query = "
655 SELECT id
656 FROM channel_memberships
657 WHERE user_id = $1 AND channel_id = $2
658 LIMIT 1
659 ";
660 Ok(sqlx::query_scalar::<_, i32>(query)
661 .bind(user_id.0)
662 .bind(channel_id.0)
663 .fetch_optional(&self.pool)
664 .await
665 .map(|e| e.is_some())?)
666 }
667
668 #[cfg(any(test, feature = "seed-support"))]
669 async fn add_channel_member(
670 &self,
671 channel_id: ChannelId,
672 user_id: UserId,
673 is_admin: bool,
674 ) -> Result<()> {
675 let query = "
676 INSERT INTO channel_memberships (channel_id, user_id, admin)
677 VALUES ($1, $2, $3)
678 ON CONFLICT DO NOTHING
679 ";
680 Ok(sqlx::query(query)
681 .bind(channel_id.0)
682 .bind(user_id.0)
683 .bind(is_admin)
684 .execute(&self.pool)
685 .await
686 .map(drop)?)
687 }
688
689 // messages
690
691 async fn create_channel_message(
692 &self,
693 channel_id: ChannelId,
694 sender_id: UserId,
695 body: &str,
696 timestamp: OffsetDateTime,
697 nonce: u128,
698 ) -> Result<MessageId> {
699 let query = "
700 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
701 VALUES ($1, $2, $3, $4, $5)
702 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
703 RETURNING id
704 ";
705 Ok(sqlx::query_scalar(query)
706 .bind(channel_id.0)
707 .bind(sender_id.0)
708 .bind(body)
709 .bind(timestamp)
710 .bind(Uuid::from_u128(nonce))
711 .fetch_one(&self.pool)
712 .await
713 .map(MessageId)?)
714 }
715
716 async fn get_channel_messages(
717 &self,
718 channel_id: ChannelId,
719 count: usize,
720 before_id: Option<MessageId>,
721 ) -> Result<Vec<ChannelMessage>> {
722 let query = r#"
723 SELECT * FROM (
724 SELECT
725 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
726 FROM
727 channel_messages
728 WHERE
729 channel_id = $1 AND
730 id < $2
731 ORDER BY id DESC
732 LIMIT $3
733 ) as recent_messages
734 ORDER BY id ASC
735 "#;
736 Ok(sqlx::query_as(query)
737 .bind(channel_id.0)
738 .bind(before_id.unwrap_or(MessageId::MAX))
739 .bind(count as i64)
740 .fetch_all(&self.pool)
741 .await?)
742 }
743
744 #[cfg(test)]
745 async fn teardown(&self, url: &str) {
746 use util::ResultExt;
747
748 let query = "
749 SELECT pg_terminate_backend(pg_stat_activity.pid)
750 FROM pg_stat_activity
751 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
752 ";
753 sqlx::query(query).execute(&self.pool).await.log_err();
754 self.pool.close().await;
755 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
756 .await
757 .log_err();
758 }
759
760 #[cfg(test)]
761 fn as_fake(&self) -> Option<&tests::FakeDb> {
762 None
763 }
764}
765
/// Declares `$name` as a `Copy` newtype over `i32`, used as a strongly typed
/// row id.
///
/// The generated type is `transparent` for both sqlx and serde, so it binds,
/// decodes and serializes exactly like the wrapped `i32`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id, e.g. the default upper bound in
            /// `get_channel_messages`.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its wire representation.
            ///
            /// NOTE(review): the `as` cast silently wraps values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to its wire representation.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so formatter flags (width, fill, ...) still apply.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
797
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub admin: bool,
    /// Shareable invite code; `None` until `set_invite_count` generates one.
    pub invite_code: Option<String>,
    /// Remaining redemptions of `invite_code`; decremented by `redeem_invite_code`.
    pub invite_count: i32,
}
807
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Identifier used by `find_org_by_slug`.
    pub slug: String,
}
815
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owning org when `owner_is_user` is false; presumably a user id
    /// otherwise (only org-owned channels are created in this file).
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
824
id_type!(MessageId);
/// A message as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Sender-supplied `u128` nonce stored as a UUID; used to deduplicate sends.
    pub nonce: Uuid,
}
835
/// One entry in a user's contact list, from that user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutual contact. `should_notify` is set when the other user accepted a
    /// request this user sent and the notification hasn't been dismissed.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request from `user_id` to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
850
851impl Contact {
852 pub fn user_id(&self) -> UserId {
853 match self {
854 Contact::Accepted { user_id, .. } => *user_id,
855 Contact::Outgoing { user_id } => *user_id,
856 Contact::Incoming { user_id, .. } => *user_id,
857 }
858 }
859}
860
/// A pending contact request addressed to some user.
/// NOTE(review): not referenced in the visible part of this file.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    pub should_notify: bool,
}
866
/// Builds an `ILIKE` pattern matching any string that contains the alphanumeric
/// characters of `string`, in order, with anything in between.
///
/// Non-alphanumeric characters (including the `LIKE` metacharacters `%` and
/// `_`) are dropped, e.g. `"x y"` becomes `"%x%y%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = string.chars().filter(|ch| ch.is_alphanumeric()).fold(
        String::with_capacity(string.len() * 2 + 1),
        |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        },
    );
    pattern.push('%');
    pattern
}
878
879#[cfg(test)]
880pub mod tests {
881 use super::*;
882 use anyhow::anyhow;
883 use collections::BTreeMap;
884 use gpui::executor::Background;
885 use lazy_static::lazy_static;
886 use parking_lot::Mutex;
887 use rand::prelude::*;
888 use sqlx::{
889 migrate::{MigrateDatabase, Migrator},
890 Postgres,
891 };
892 use std::{path::Path, sync::Arc};
893 use util::post_inc;
894
895 #[tokio::test(flavor = "multi_thread")]
896 async fn test_get_users_by_ids() {
897 for test_db in [
898 TestDb::postgres().await,
899 TestDb::fake(Arc::new(gpui::executor::Background::new())),
900 ] {
901 let db = test_db.db();
902
903 let user = db.create_user("user", false).await.unwrap();
904 let friend1 = db.create_user("friend-1", false).await.unwrap();
905 let friend2 = db.create_user("friend-2", false).await.unwrap();
906 let friend3 = db.create_user("friend-3", false).await.unwrap();
907
908 assert_eq!(
909 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
910 .await
911 .unwrap(),
912 vec![
913 User {
914 id: user,
915 github_login: "user".to_string(),
916 admin: false,
917 ..Default::default()
918 },
919 User {
920 id: friend1,
921 github_login: "friend-1".to_string(),
922 admin: false,
923 ..Default::default()
924 },
925 User {
926 id: friend2,
927 github_login: "friend-2".to_string(),
928 admin: false,
929 ..Default::default()
930 },
931 User {
932 id: friend3,
933 github_login: "friend-3".to_string(),
934 admin: false,
935 ..Default::default()
936 }
937 ]
938 );
939 }
940 }
941
942 #[tokio::test(flavor = "multi_thread")]
943 async fn test_recent_channel_messages() {
944 for test_db in [
945 TestDb::postgres().await,
946 TestDb::fake(Arc::new(gpui::executor::Background::new())),
947 ] {
948 let db = test_db.db();
949 let user = db.create_user("user", false).await.unwrap();
950 let org = db.create_org("org", "org").await.unwrap();
951 let channel = db.create_org_channel(org, "channel").await.unwrap();
952 for i in 0..10 {
953 db.create_channel_message(
954 channel,
955 user,
956 &i.to_string(),
957 OffsetDateTime::now_utc(),
958 i,
959 )
960 .await
961 .unwrap();
962 }
963
964 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
965 assert_eq!(
966 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
967 ["5", "6", "7", "8", "9"]
968 );
969
970 let prev_messages = db
971 .get_channel_messages(channel, 4, Some(messages[0].id))
972 .await
973 .unwrap();
974 assert_eq!(
975 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
976 ["1", "2", "3", "4"]
977 );
978 }
979 }
980
981 #[tokio::test(flavor = "multi_thread")]
982 async fn test_channel_message_nonces() {
983 for test_db in [
984 TestDb::postgres().await,
985 TestDb::fake(Arc::new(gpui::executor::Background::new())),
986 ] {
987 let db = test_db.db();
988 let user = db.create_user("user", false).await.unwrap();
989 let org = db.create_org("org", "org").await.unwrap();
990 let channel = db.create_org_channel(org, "channel").await.unwrap();
991
992 let msg1_id = db
993 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
994 .await
995 .unwrap();
996 let msg2_id = db
997 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
998 .await
999 .unwrap();
1000 let msg3_id = db
1001 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1002 .await
1003 .unwrap();
1004 let msg4_id = db
1005 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1006 .await
1007 .unwrap();
1008
1009 assert_ne!(msg1_id, msg2_id);
1010 assert_eq!(msg1_id, msg3_id);
1011 assert_eq!(msg2_id, msg4_id);
1012 }
1013 }
1014
1015 #[tokio::test(flavor = "multi_thread")]
1016 async fn test_create_access_tokens() {
1017 let test_db = TestDb::postgres().await;
1018 let db = test_db.db();
1019 let user = db.create_user("the-user", false).await.unwrap();
1020
1021 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1022 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1023 assert_eq!(
1024 db.get_access_token_hashes(user).await.unwrap(),
1025 &["h2".to_string(), "h1".to_string()]
1026 );
1027
1028 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1029 assert_eq!(
1030 db.get_access_token_hashes(user).await.unwrap(),
1031 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1032 );
1033
1034 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1035 assert_eq!(
1036 db.get_access_token_hashes(user).await.unwrap(),
1037 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1038 );
1039
1040 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1041 assert_eq!(
1042 db.get_access_token_hashes(user).await.unwrap(),
1043 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1044 );
1045 }
1046
1047 #[test]
1048 fn test_fuzzy_like_string() {
1049 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1050 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1051 assert_eq!(fuzzy_like_string(" z "), "%z%");
1052 }
1053
1054 #[tokio::test(flavor = "multi_thread")]
1055 async fn test_fuzzy_search_users() {
1056 let test_db = TestDb::postgres().await;
1057 let db = test_db.db();
1058 for github_login in [
1059 "California",
1060 "colorado",
1061 "oregon",
1062 "washington",
1063 "florida",
1064 "delaware",
1065 "rhode-island",
1066 ] {
1067 db.create_user(github_login, false).await.unwrap();
1068 }
1069
1070 assert_eq!(
1071 fuzzy_search_user_names(db, "clr").await,
1072 &["colorado", "California"]
1073 );
1074 assert_eq!(
1075 fuzzy_search_user_names(db, "ro").await,
1076 &["rhode-island", "colorado", "oregon"],
1077 );
1078
1079 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1080 db.fuzzy_search_users(query, 10)
1081 .await
1082 .unwrap()
1083 .into_iter()
1084 .map(|user| user.github_login)
1085 .collect::<Vec<_>>()
1086 }
1087 }
1088
1089 #[tokio::test(flavor = "multi_thread")]
1090 async fn test_add_contacts() {
1091 for test_db in [
1092 TestDb::postgres().await,
1093 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1094 ] {
1095 let db = test_db.db();
1096
1097 let user_1 = db.create_user("user1", false).await.unwrap();
1098 let user_2 = db.create_user("user2", false).await.unwrap();
1099 let user_3 = db.create_user("user3", false).await.unwrap();
1100
1101 // User starts with no contacts
1102 assert_eq!(
1103 db.get_contacts(user_1).await.unwrap(),
1104 vec![Contact::Accepted {
1105 user_id: user_1,
1106 should_notify: false
1107 }],
1108 );
1109
1110 // User requests a contact. Both users see the pending request.
1111 db.send_contact_request(user_1, user_2).await.unwrap();
1112 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1113 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1114 assert_eq!(
1115 db.get_contacts(user_1).await.unwrap(),
1116 &[
1117 Contact::Accepted {
1118 user_id: user_1,
1119 should_notify: false
1120 },
1121 Contact::Outgoing { user_id: user_2 }
1122 ],
1123 );
1124 assert_eq!(
1125 db.get_contacts(user_2).await.unwrap(),
1126 &[
1127 Contact::Incoming {
1128 user_id: user_1,
1129 should_notify: true
1130 },
1131 Contact::Accepted {
1132 user_id: user_2,
1133 should_notify: false
1134 },
1135 ]
1136 );
1137
1138 // User 2 dismisses the contact request notification without accepting or rejecting.
1139 // We shouldn't notify them again.
1140 db.dismiss_contact_notification(user_1, user_2)
1141 .await
1142 .unwrap_err();
1143 db.dismiss_contact_notification(user_2, user_1)
1144 .await
1145 .unwrap();
1146 assert_eq!(
1147 db.get_contacts(user_2).await.unwrap(),
1148 &[
1149 Contact::Incoming {
1150 user_id: user_1,
1151 should_notify: false
1152 },
1153 Contact::Accepted {
1154 user_id: user_2,
1155 should_notify: false
1156 },
1157 ]
1158 );
1159
1160 // User can't accept their own contact request
1161 db.respond_to_contact_request(user_1, user_2, true)
1162 .await
1163 .unwrap_err();
1164
1165 // User accepts a contact request. Both users see the contact.
1166 db.respond_to_contact_request(user_2, user_1, true)
1167 .await
1168 .unwrap();
1169 assert_eq!(
1170 db.get_contacts(user_1).await.unwrap(),
1171 &[
1172 Contact::Accepted {
1173 user_id: user_1,
1174 should_notify: false
1175 },
1176 Contact::Accepted {
1177 user_id: user_2,
1178 should_notify: true
1179 }
1180 ],
1181 );
1182 assert!(db.has_contact(user_1, user_2).await.unwrap());
1183 assert!(db.has_contact(user_2, user_1).await.unwrap());
1184 assert_eq!(
1185 db.get_contacts(user_2).await.unwrap(),
1186 &[
1187 Contact::Accepted {
1188 user_id: user_1,
1189 should_notify: false,
1190 },
1191 Contact::Accepted {
1192 user_id: user_2,
1193 should_notify: false,
1194 },
1195 ]
1196 );
1197
1198 // Users cannot re-request existing contacts.
1199 db.send_contact_request(user_1, user_2).await.unwrap_err();
1200 db.send_contact_request(user_2, user_1).await.unwrap_err();
1201
1202 // Users can't dismiss notifications of them accepting other users' requests.
1203 db.dismiss_contact_notification(user_2, user_1)
1204 .await
1205 .unwrap_err();
1206 assert_eq!(
1207 db.get_contacts(user_1).await.unwrap(),
1208 &[
1209 Contact::Accepted {
1210 user_id: user_1,
1211 should_notify: false
1212 },
1213 Contact::Accepted {
1214 user_id: user_2,
1215 should_notify: true,
1216 },
1217 ]
1218 );
1219
1220 // Users can dismiss notifications of other users accepting their requests.
1221 db.dismiss_contact_notification(user_1, user_2)
1222 .await
1223 .unwrap();
1224 assert_eq!(
1225 db.get_contacts(user_1).await.unwrap(),
1226 &[
1227 Contact::Accepted {
1228 user_id: user_1,
1229 should_notify: false
1230 },
1231 Contact::Accepted {
1232 user_id: user_2,
1233 should_notify: false,
1234 },
1235 ]
1236 );
1237
1238 // Users send each other concurrent contact requests and
1239 // see that they are immediately accepted.
1240 db.send_contact_request(user_1, user_3).await.unwrap();
1241 db.send_contact_request(user_3, user_1).await.unwrap();
1242 assert_eq!(
1243 db.get_contacts(user_1).await.unwrap(),
1244 &[
1245 Contact::Accepted {
1246 user_id: user_1,
1247 should_notify: false
1248 },
1249 Contact::Accepted {
1250 user_id: user_2,
1251 should_notify: false,
1252 },
1253 Contact::Accepted {
1254 user_id: user_3,
1255 should_notify: false
1256 },
1257 ]
1258 );
1259 assert_eq!(
1260 db.get_contacts(user_3).await.unwrap(),
1261 &[
1262 Contact::Accepted {
1263 user_id: user_1,
1264 should_notify: false
1265 },
1266 Contact::Accepted {
1267 user_id: user_3,
1268 should_notify: false
1269 }
1270 ],
1271 );
1272
1273 // User declines a contact request. Both users see that it is gone.
1274 db.send_contact_request(user_2, user_3).await.unwrap();
1275 db.respond_to_contact_request(user_3, user_2, false)
1276 .await
1277 .unwrap();
1278 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1279 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1280 assert_eq!(
1281 db.get_contacts(user_2).await.unwrap(),
1282 &[
1283 Contact::Accepted {
1284 user_id: user_1,
1285 should_notify: false
1286 },
1287 Contact::Accepted {
1288 user_id: user_2,
1289 should_notify: false
1290 }
1291 ]
1292 );
1293 assert_eq!(
1294 db.get_contacts(user_3).await.unwrap(),
1295 &[
1296 Contact::Accepted {
1297 user_id: user_1,
1298 should_notify: false
1299 },
1300 Contact::Accepted {
1301 user_id: user_3,
1302 should_notify: false
1303 }
1304 ],
1305 );
1306 }
1307 }
1308
/// Walks an invite code through its lifecycle against a real Postgres database:
/// creating it with `set_invite_count`, redeeming it until the remaining count hits
/// zero, topping the count back up (the code string itself stays the same), and
/// rejecting redemption by a login that already belongs to a user.
///
/// Only Postgres is exercised here: `FakeDb`'s invite-code methods are
/// `unimplemented!()`.
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
    // Keep the fixture alive for the whole test; dropping it tears the database down.
    let postgres = TestDb::postgres().await;
    let db = postgres.db();
    let user1 = db.create_user("user-1", false).await.unwrap();

    // Initially, user 1 has no invite code.
    assert_eq!(db.get_invite_code(user1).await.unwrap(), None);

    // User 1 creates an invite code that can be used twice.
    db.set_invite_count(user1, 2).await.unwrap();
    let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 2);

    // User 2 redeems the invite code and becomes a contact of user 1. Each redemption
    // decrements the remaining count. The inviter (user 1) sees the new contact flagged
    // for notification; the invitee (user 2) does not.
    let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
    let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user2).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: false
            }
        ]
    );

    // User 3 redeems the invite code and becomes a contact of user 1, using up the
    // last remaining invite.
    let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
    let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 0);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user3).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: false
            },
        ]
    );

    // Trying to redeem the code for the third time results in an error, because the
    // remaining count is now zero.
    db.redeem_invite_code(&invite_code, "user-4")
        .await
        .unwrap_err();

    // Invite count can be updated after the code has been created.
    db.set_invite_count(user1, 2).await.unwrap();
    let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
    assert_eq!(invite_count, 2);

    // User 4 can now redeem the invite code and becomes a contact of user 1.
    let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
    let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user4).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: false
            },
        ]
    );

    // An existing user cannot redeem invite codes, and the failed attempt does not
    // consume an invite.
    db.redeem_invite_code(&invite_code, "user-2")
        .await
        .unwrap_err();
    let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
}
1446
/// Test fixture that owns a database handle, which is either a freshly created
/// Postgres database or an in-memory [`FakeDb`]. The fixture tears the database
/// down when dropped.
pub struct TestDb {
    /// The database under test. It is an `Option` so that `Drop` can `take()` the
    /// handle out before calling `teardown` on it.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL passed to `teardown`; empty for the fake database.
    pub url: String,
}
1451
1452 impl TestDb {
1453 pub async fn postgres() -> Self {
1454 lazy_static! {
1455 static ref LOCK: Mutex<()> = Mutex::new(());
1456 }
1457
1458 let _guard = LOCK.lock();
1459 let mut rng = StdRng::from_entropy();
1460 let name = format!("zed-test-{}", rng.gen::<u128>());
1461 let url = format!("postgres://postgres@localhost/{}", name);
1462 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1463 Postgres::create_database(&url)
1464 .await
1465 .expect("failed to create test db");
1466 let db = PostgresDb::new(&url, 5).await.unwrap();
1467 let migrator = Migrator::new(migrations_path).await.unwrap();
1468 migrator.run(&db.pool).await.unwrap();
1469 Self {
1470 db: Some(Arc::new(db)),
1471 url,
1472 }
1473 }
1474
1475 pub fn fake(background: Arc<Background>) -> Self {
1476 Self {
1477 db: Some(Arc::new(FakeDb::new(background))),
1478 url: Default::default(),
1479 }
1480 }
1481
1482 pub fn db(&self) -> &Arc<dyn Db> {
1483 self.db.as_ref().unwrap()
1484 }
1485 }
1486
1487 impl Drop for TestDb {
1488 fn drop(&mut self) {
1489 if let Some(db) = self.db.take() {
1490 futures::executor::block_on(db.teardown(&self.url));
1491 }
1492 }
1493 }
1494
/// In-memory stand-in for the database, used through `TestDb::fake`.
///
/// Each table is a separately `Mutex`-guarded collection. Most operations first await
/// `background.simulate_random_delay()`.
pub struct FakeDb {
    /// Handle used to inject random delays (`simulate_random_delay`) into operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Orgs keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org memberships; the value is the member's admin flag.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel memberships; the value is the member's admin flag.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Messages keyed by id. The `BTreeMap` ordering lets pagination walk ids in order.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact rows (requester -> responder), kept in insertion order.
    pub contacts: Mutex<Vec<FakeContact>>,
    /// Counter for the next message id (starts at 1; consumed via `post_inc`).
    next_channel_message_id: Mutex<i32>,
    /// Counter for the next user id (starts at 1; consumed via `post_inc`).
    next_user_id: Mutex<i32>,
    /// Counter for the next org id (starts at 1; consumed via `post_inc`).
    next_org_id: Mutex<i32>,
    /// Counter for the next channel id (starts at 1; consumed via `post_inc`).
    next_channel_id: Mutex<i32>,
}
1509
/// One contact relationship stored by [`FakeDb`], directed from requester to responder.
#[derive(Debug)]
pub struct FakeContact {
    /// User who sent the contact request.
    pub requester_id: UserId,
    /// User who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted, either explicitly or because the
    /// responder sent a reciprocal request.
    pub accepted: bool,
    /// Pending-notification flag. On an unaccepted row it notifies the responder of the
    /// incoming request. On an accepted row it notifies the requester of the acceptance.
    pub should_notify: bool,
}
1517
1518 impl FakeDb {
1519 pub fn new(background: Arc<Background>) -> Self {
1520 Self {
1521 background,
1522 users: Default::default(),
1523 next_user_id: Mutex::new(1),
1524 orgs: Default::default(),
1525 next_org_id: Mutex::new(1),
1526 org_memberships: Default::default(),
1527 channels: Default::default(),
1528 next_channel_id: Mutex::new(1),
1529 channel_memberships: Default::default(),
1530 channel_messages: Default::default(),
1531 next_channel_message_id: Mutex::new(1),
1532 contacts: Default::default(),
1533 }
1534 }
1535 }
1536
1537 #[async_trait]
1538 impl Db for FakeDb {
1539 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1540 self.background.simulate_random_delay().await;
1541
1542 let mut users = self.users.lock();
1543 if let Some(user) = users
1544 .values()
1545 .find(|user| user.github_login == github_login)
1546 {
1547 Ok(user.id)
1548 } else {
1549 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1550 users.insert(
1551 user_id,
1552 User {
1553 id: user_id,
1554 github_login: github_login.to_string(),
1555 admin,
1556 invite_code: None,
1557 invite_count: 0,
1558 },
1559 );
1560 Ok(user_id)
1561 }
1562 }
1563
1564 async fn get_all_users(&self) -> Result<Vec<User>> {
1565 unimplemented!()
1566 }
1567
1568 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1569 unimplemented!()
1570 }
1571
1572 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1573 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1574 }
1575
1576 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1577 self.background.simulate_random_delay().await;
1578 let users = self.users.lock();
1579 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1580 }
1581
1582 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1583 Ok(self
1584 .users
1585 .lock()
1586 .values()
1587 .find(|user| user.github_login == github_login)
1588 .cloned())
1589 }
1590
1591 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1592 unimplemented!()
1593 }
1594
1595 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1596 unimplemented!()
1597 }
1598
1599 // invite codes
1600
1601 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1602 unimplemented!()
1603 }
1604
1605 async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1606 unimplemented!()
1607 }
1608
1609 async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
1610 unimplemented!()
1611 }
1612
1613 // contacts
1614
1615 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1616 self.background.simulate_random_delay().await;
1617 let mut contacts = vec![Contact::Accepted {
1618 user_id: id,
1619 should_notify: false,
1620 }];
1621
1622 for contact in self.contacts.lock().iter() {
1623 if contact.requester_id == id {
1624 if contact.accepted {
1625 contacts.push(Contact::Accepted {
1626 user_id: contact.responder_id,
1627 should_notify: contact.should_notify,
1628 });
1629 } else {
1630 contacts.push(Contact::Outgoing {
1631 user_id: contact.responder_id,
1632 });
1633 }
1634 } else if contact.responder_id == id {
1635 if contact.accepted {
1636 contacts.push(Contact::Accepted {
1637 user_id: contact.requester_id,
1638 should_notify: false,
1639 });
1640 } else {
1641 contacts.push(Contact::Incoming {
1642 user_id: contact.requester_id,
1643 should_notify: contact.should_notify,
1644 });
1645 }
1646 }
1647 }
1648
1649 contacts.sort_unstable_by_key(|contact| contact.user_id());
1650 Ok(contacts)
1651 }
1652
1653 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1654 self.background.simulate_random_delay().await;
1655 Ok(self.contacts.lock().iter().any(|contact| {
1656 contact.accepted
1657 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1658 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1659 }))
1660 }
1661
1662 async fn send_contact_request(
1663 &self,
1664 requester_id: UserId,
1665 responder_id: UserId,
1666 ) -> Result<()> {
1667 let mut contacts = self.contacts.lock();
1668 for contact in contacts.iter_mut() {
1669 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1670 if contact.accepted {
1671 Err(anyhow!("contact already exists"))?;
1672 } else {
1673 Err(anyhow!("contact already requested"))?;
1674 }
1675 }
1676 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1677 if contact.accepted {
1678 Err(anyhow!("contact already exists"))?;
1679 } else {
1680 contact.accepted = true;
1681 contact.should_notify = false;
1682 return Ok(());
1683 }
1684 }
1685 }
1686 contacts.push(FakeContact {
1687 requester_id,
1688 responder_id,
1689 accepted: false,
1690 should_notify: true,
1691 });
1692 Ok(())
1693 }
1694
1695 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1696 self.contacts.lock().retain(|contact| {
1697 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1698 });
1699 Ok(())
1700 }
1701
1702 async fn dismiss_contact_notification(
1703 &self,
1704 user_id: UserId,
1705 contact_user_id: UserId,
1706 ) -> Result<()> {
1707 let mut contacts = self.contacts.lock();
1708 for contact in contacts.iter_mut() {
1709 if contact.requester_id == contact_user_id
1710 && contact.responder_id == user_id
1711 && !contact.accepted
1712 {
1713 contact.should_notify = false;
1714 return Ok(());
1715 }
1716 if contact.requester_id == user_id
1717 && contact.responder_id == contact_user_id
1718 && contact.accepted
1719 {
1720 contact.should_notify = false;
1721 return Ok(());
1722 }
1723 }
1724 Err(anyhow!("no such notification"))
1725 }
1726
1727 async fn respond_to_contact_request(
1728 &self,
1729 responder_id: UserId,
1730 requester_id: UserId,
1731 accept: bool,
1732 ) -> Result<()> {
1733 let mut contacts = self.contacts.lock();
1734 for (ix, contact) in contacts.iter_mut().enumerate() {
1735 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1736 if contact.accepted {
1737 return Err(anyhow!("contact already confirmed"));
1738 }
1739 if accept {
1740 contact.accepted = true;
1741 contact.should_notify = true;
1742 } else {
1743 contacts.remove(ix);
1744 }
1745 return Ok(());
1746 }
1747 }
1748 Err(anyhow!("no such contact request"))
1749 }
1750
1751 async fn create_access_token_hash(
1752 &self,
1753 _user_id: UserId,
1754 _access_token_hash: &str,
1755 _max_access_token_count: usize,
1756 ) -> Result<()> {
1757 unimplemented!()
1758 }
1759
1760 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1761 unimplemented!()
1762 }
1763
1764 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1765 unimplemented!()
1766 }
1767
1768 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1769 self.background.simulate_random_delay().await;
1770 let mut orgs = self.orgs.lock();
1771 if orgs.values().any(|org| org.slug == slug) {
1772 Err(anyhow!("org already exists"))
1773 } else {
1774 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1775 orgs.insert(
1776 org_id,
1777 Org {
1778 id: org_id,
1779 name: name.to_string(),
1780 slug: slug.to_string(),
1781 },
1782 );
1783 Ok(org_id)
1784 }
1785 }
1786
1787 async fn add_org_member(
1788 &self,
1789 org_id: OrgId,
1790 user_id: UserId,
1791 is_admin: bool,
1792 ) -> Result<()> {
1793 self.background.simulate_random_delay().await;
1794 if !self.orgs.lock().contains_key(&org_id) {
1795 return Err(anyhow!("org does not exist"));
1796 }
1797 if !self.users.lock().contains_key(&user_id) {
1798 return Err(anyhow!("user does not exist"));
1799 }
1800
1801 self.org_memberships
1802 .lock()
1803 .entry((org_id, user_id))
1804 .or_insert(is_admin);
1805 Ok(())
1806 }
1807
1808 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1809 self.background.simulate_random_delay().await;
1810 if !self.orgs.lock().contains_key(&org_id) {
1811 return Err(anyhow!("org does not exist"));
1812 }
1813
1814 let mut channels = self.channels.lock();
1815 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1816 channels.insert(
1817 channel_id,
1818 Channel {
1819 id: channel_id,
1820 name: name.to_string(),
1821 owner_id: org_id.0,
1822 owner_is_user: false,
1823 },
1824 );
1825 Ok(channel_id)
1826 }
1827
1828 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1829 self.background.simulate_random_delay().await;
1830 Ok(self
1831 .channels
1832 .lock()
1833 .values()
1834 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1835 .cloned()
1836 .collect())
1837 }
1838
1839 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1840 self.background.simulate_random_delay().await;
1841 let channels = self.channels.lock();
1842 let memberships = self.channel_memberships.lock();
1843 Ok(channels
1844 .values()
1845 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1846 .cloned()
1847 .collect())
1848 }
1849
1850 async fn can_user_access_channel(
1851 &self,
1852 user_id: UserId,
1853 channel_id: ChannelId,
1854 ) -> Result<bool> {
1855 self.background.simulate_random_delay().await;
1856 Ok(self
1857 .channel_memberships
1858 .lock()
1859 .contains_key(&(channel_id, user_id)))
1860 }
1861
1862 async fn add_channel_member(
1863 &self,
1864 channel_id: ChannelId,
1865 user_id: UserId,
1866 is_admin: bool,
1867 ) -> Result<()> {
1868 self.background.simulate_random_delay().await;
1869 if !self.channels.lock().contains_key(&channel_id) {
1870 return Err(anyhow!("channel does not exist"));
1871 }
1872 if !self.users.lock().contains_key(&user_id) {
1873 return Err(anyhow!("user does not exist"));
1874 }
1875
1876 self.channel_memberships
1877 .lock()
1878 .entry((channel_id, user_id))
1879 .or_insert(is_admin);
1880 Ok(())
1881 }
1882
1883 async fn create_channel_message(
1884 &self,
1885 channel_id: ChannelId,
1886 sender_id: UserId,
1887 body: &str,
1888 timestamp: OffsetDateTime,
1889 nonce: u128,
1890 ) -> Result<MessageId> {
1891 self.background.simulate_random_delay().await;
1892 if !self.channels.lock().contains_key(&channel_id) {
1893 return Err(anyhow!("channel does not exist"));
1894 }
1895 if !self.users.lock().contains_key(&sender_id) {
1896 return Err(anyhow!("user does not exist"));
1897 }
1898
1899 let mut messages = self.channel_messages.lock();
1900 if let Some(message) = messages
1901 .values()
1902 .find(|message| message.nonce.as_u128() == nonce)
1903 {
1904 Ok(message.id)
1905 } else {
1906 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1907 messages.insert(
1908 message_id,
1909 ChannelMessage {
1910 id: message_id,
1911 channel_id,
1912 sender_id,
1913 body: body.to_string(),
1914 sent_at: timestamp,
1915 nonce: Uuid::from_u128(nonce),
1916 },
1917 );
1918 Ok(message_id)
1919 }
1920 }
1921
1922 async fn get_channel_messages(
1923 &self,
1924 channel_id: ChannelId,
1925 count: usize,
1926 before_id: Option<MessageId>,
1927 ) -> Result<Vec<ChannelMessage>> {
1928 let mut messages = self
1929 .channel_messages
1930 .lock()
1931 .values()
1932 .rev()
1933 .filter(|message| {
1934 message.channel_id == channel_id
1935 && message.id < before_id.unwrap_or(MessageId::MAX)
1936 })
1937 .take(count)
1938 .cloned()
1939 .collect::<Vec<_>>();
1940 messages.sort_unstable_by_key(|message| message.id);
1941 Ok(messages)
1942 }
1943
1944 async fn teardown(&self, _: &str) {}
1945
1946 #[cfg(test)]
1947 fn as_fake(&self) -> Option<&FakeDb> {
1948 Some(self)
1949 }
1950 }
1951}