1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow};
10use time::OffsetDateTime;
11
12#[async_trait]
13pub trait Db: Send + Sync {
14 async fn create_user(
15 &self,
16 github_login: &str,
17 email_address: Option<&str>,
18 admin: bool,
19 ) -> Result<UserId>;
20 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
21 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
22 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
23 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
24 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
25 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
26 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
27 async fn destroy_user(&self, id: UserId) -> Result<()>;
28
29 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
30 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
31 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
32 async fn redeem_invite_code(
33 &self,
34 code: &str,
35 login: &str,
36 email_address: Option<&str>,
37 ) -> Result<UserId>;
38
39 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
40 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
41 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
42 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
43 async fn dismiss_contact_notification(
44 &self,
45 responder_id: UserId,
46 requester_id: UserId,
47 ) -> Result<()>;
48 async fn respond_to_contact_request(
49 &self,
50 responder_id: UserId,
51 requester_id: UserId,
52 accept: bool,
53 ) -> Result<()>;
54
55 async fn create_access_token_hash(
56 &self,
57 user_id: UserId,
58 access_token_hash: &str,
59 max_access_token_count: usize,
60 ) -> Result<()>;
61 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
62 #[cfg(any(test, feature = "seed-support"))]
63
64 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
65 #[cfg(any(test, feature = "seed-support"))]
66 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
67 #[cfg(any(test, feature = "seed-support"))]
68 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
69 #[cfg(any(test, feature = "seed-support"))]
70 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
71 #[cfg(any(test, feature = "seed-support"))]
72
73 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
74 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
75 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
76 -> Result<bool>;
77 #[cfg(any(test, feature = "seed-support"))]
78 async fn add_channel_member(
79 &self,
80 channel_id: ChannelId,
81 user_id: UserId,
82 is_admin: bool,
83 ) -> Result<()>;
84 async fn create_channel_message(
85 &self,
86 channel_id: ChannelId,
87 sender_id: UserId,
88 body: &str,
89 timestamp: OffsetDateTime,
90 nonce: u128,
91 ) -> Result<MessageId>;
92 async fn get_channel_messages(
93 &self,
94 channel_id: ChannelId,
95 count: usize,
96 before_id: Option<MessageId>,
97 ) -> Result<Vec<ChannelMessage>>;
98 #[cfg(test)]
99 async fn teardown(&self, url: &str);
100 #[cfg(test)]
101 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
102}
103
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // Every query runs on a connection from this pool, or in a transaction begun from it.
    pool: sqlx::PgPool,
}
107
108impl PostgresDb {
109 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
110 let pool = DbOptions::new()
111 .max_connections(max_connections)
112 .connect(&url)
113 .await
114 .context("failed to connect to postgres database")?;
115 Ok(Self { pool })
116 }
117}
118
119#[async_trait]
120impl Db for PostgresDb {
121 // users
122
123 async fn create_user(
124 &self,
125 github_login: &str,
126 email_address: Option<&str>,
127 admin: bool,
128 ) -> Result<UserId> {
129 let query = "
130 INSERT INTO users (github_login, email_address, admin)
131 VALUES ($1, $2, $3)
132 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
133 RETURNING id
134 ";
135 Ok(sqlx::query_scalar(query)
136 .bind(github_login)
137 .bind(email_address)
138 .bind(admin)
139 .fetch_one(&self.pool)
140 .await
141 .map(UserId)?)
142 }
143
144 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
145 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
146 Ok(sqlx::query_as(query)
147 .bind(limit)
148 .bind(page * limit)
149 .fetch_all(&self.pool)
150 .await?)
151 }
152
153 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
154 let like_string = fuzzy_like_string(name_query);
155 let query = "
156 SELECT users.*
157 FROM users
158 WHERE github_login ILIKE $1
159 ORDER BY github_login <-> $2
160 LIMIT $3
161 ";
162 Ok(sqlx::query_as(query)
163 .bind(like_string)
164 .bind(name_query)
165 .bind(limit)
166 .fetch_all(&self.pool)
167 .await?)
168 }
169
170 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
171 let users = self.get_users_by_ids(vec![id]).await?;
172 Ok(users.into_iter().next())
173 }
174
175 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
176 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
177 let query = "
178 SELECT users.*
179 FROM users
180 WHERE users.id = ANY ($1)
181 ";
182 Ok(sqlx::query_as(query)
183 .bind(&ids)
184 .fetch_all(&self.pool)
185 .await?)
186 }
187
188 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
189 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
190 Ok(sqlx::query_as(query)
191 .bind(github_login)
192 .fetch_optional(&self.pool)
193 .await?)
194 }
195
196 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
197 let query = "UPDATE users SET admin = $1 WHERE id = $2";
198 Ok(sqlx::query(query)
199 .bind(is_admin)
200 .bind(id.0)
201 .execute(&self.pool)
202 .await
203 .map(drop)?)
204 }
205
206 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
207 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
208 Ok(sqlx::query(query)
209 .bind(connected_once)
210 .bind(id.0)
211 .execute(&self.pool)
212 .await
213 .map(drop)?)
214 }
215
216 async fn destroy_user(&self, id: UserId) -> Result<()> {
217 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
218 sqlx::query(query)
219 .bind(id.0)
220 .execute(&self.pool)
221 .await
222 .map(drop)?;
223 let query = "DELETE FROM users WHERE id = $1;";
224 Ok(sqlx::query(query)
225 .bind(id.0)
226 .execute(&self.pool)
227 .await
228 .map(drop)?)
229 }
230
231 // invite codes
232
233 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
234 let mut tx = self.pool.begin().await?;
235 if count > 0 {
236 sqlx::query(
237 "
238 UPDATE users
239 SET invite_code = $1
240 WHERE id = $2 AND invite_code IS NULL
241 ",
242 )
243 .bind(nanoid!(16))
244 .bind(id)
245 .execute(&mut tx)
246 .await?;
247 }
248
249 sqlx::query(
250 "
251 UPDATE users
252 SET invite_count = $1
253 WHERE id = $2
254 ",
255 )
256 .bind(count)
257 .bind(id)
258 .execute(&mut tx)
259 .await?;
260 tx.commit().await?;
261 Ok(())
262 }
263
264 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
265 let result: Option<(String, i32)> = sqlx::query_as(
266 "
267 SELECT invite_code, invite_count
268 FROM users
269 WHERE id = $1 AND invite_code IS NOT NULL
270 ",
271 )
272 .bind(id)
273 .fetch_optional(&self.pool)
274 .await?;
275 if let Some((code, count)) = result {
276 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
277 } else {
278 Ok(None)
279 }
280 }
281
282 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
283 sqlx::query_as(
284 "
285 SELECT *
286 FROM users
287 WHERE invite_code = $1
288 ",
289 )
290 .bind(code)
291 .fetch_optional(&self.pool)
292 .await?
293 .ok_or_else(|| {
294 Error::Http(
295 StatusCode::NOT_FOUND,
296 "that invite code does not exist".to_string(),
297 )
298 })
299 }
300
301 async fn redeem_invite_code(
302 &self,
303 code: &str,
304 login: &str,
305 email_address: Option<&str>,
306 ) -> Result<UserId> {
307 let mut tx = self.pool.begin().await?;
308
309 let inviter_id: Option<UserId> = sqlx::query_scalar(
310 "
311 UPDATE users
312 SET invite_count = invite_count - 1
313 WHERE
314 invite_code = $1 AND
315 invite_count > 0
316 RETURNING id
317 ",
318 )
319 .bind(code)
320 .fetch_optional(&mut tx)
321 .await?;
322
323 let inviter_id = match inviter_id {
324 Some(inviter_id) => inviter_id,
325 None => {
326 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
327 .bind(code)
328 .fetch_optional(&mut tx)
329 .await?
330 .is_some()
331 {
332 Err(Error::Http(
333 StatusCode::UNAUTHORIZED,
334 "no invites remaining".to_string(),
335 ))?
336 } else {
337 Err(Error::Http(
338 StatusCode::NOT_FOUND,
339 "invite code not found".to_string(),
340 ))?
341 }
342 }
343 };
344
345 let invitee_id = sqlx::query_scalar(
346 "
347 INSERT INTO users
348 (github_login, email_address, admin, inviter_id)
349 VALUES
350 ($1, $2, 'f', $3)
351 RETURNING id
352 ",
353 )
354 .bind(login)
355 .bind(email_address)
356 .bind(inviter_id)
357 .fetch_one(&mut tx)
358 .await
359 .map(UserId)?;
360
361 sqlx::query(
362 "
363 INSERT INTO contacts
364 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
365 VALUES
366 ($1, $2, 't', 't', 't')
367 ",
368 )
369 .bind(inviter_id)
370 .bind(invitee_id)
371 .execute(&mut tx)
372 .await?;
373
374 tx.commit().await?;
375 Ok(invitee_id)
376 }
377
378 // contacts
379
380 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
381 let query = "
382 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
383 FROM contacts
384 WHERE user_id_a = $1 OR user_id_b = $1;
385 ";
386
387 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
388 .bind(user_id)
389 .fetch(&self.pool);
390
391 let mut contacts = vec![Contact::Accepted {
392 user_id,
393 should_notify: false,
394 }];
395 while let Some(row) = rows.next().await {
396 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
397
398 if user_id_a == user_id {
399 if accepted {
400 contacts.push(Contact::Accepted {
401 user_id: user_id_b,
402 should_notify: should_notify && a_to_b,
403 });
404 } else if a_to_b {
405 contacts.push(Contact::Outgoing { user_id: user_id_b })
406 } else {
407 contacts.push(Contact::Incoming {
408 user_id: user_id_b,
409 should_notify,
410 });
411 }
412 } else {
413 if accepted {
414 contacts.push(Contact::Accepted {
415 user_id: user_id_a,
416 should_notify: should_notify && !a_to_b,
417 });
418 } else if a_to_b {
419 contacts.push(Contact::Incoming {
420 user_id: user_id_a,
421 should_notify,
422 });
423 } else {
424 contacts.push(Contact::Outgoing { user_id: user_id_a });
425 }
426 }
427 }
428
429 contacts.sort_unstable_by_key(|contact| contact.user_id());
430
431 Ok(contacts)
432 }
433
434 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
435 let (id_a, id_b) = if user_id_1 < user_id_2 {
436 (user_id_1, user_id_2)
437 } else {
438 (user_id_2, user_id_1)
439 };
440
441 let query = "
442 SELECT 1 FROM contacts
443 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
444 LIMIT 1
445 ";
446 Ok(sqlx::query_scalar::<_, i32>(query)
447 .bind(id_a.0)
448 .bind(id_b.0)
449 .fetch_optional(&self.pool)
450 .await?
451 .is_some())
452 }
453
454 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
455 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
456 (sender_id, receiver_id, true)
457 } else {
458 (receiver_id, sender_id, false)
459 };
460 let query = "
461 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
462 VALUES ($1, $2, $3, 'f', 't')
463 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
464 SET
465 accepted = 't',
466 should_notify = 'f'
467 WHERE
468 NOT contacts.accepted AND
469 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
470 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
471 ";
472 let result = sqlx::query(query)
473 .bind(id_a.0)
474 .bind(id_b.0)
475 .bind(a_to_b)
476 .execute(&self.pool)
477 .await?;
478
479 if result.rows_affected() == 1 {
480 Ok(())
481 } else {
482 Err(anyhow!("contact already requested"))?
483 }
484 }
485
486 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
487 let (id_a, id_b) = if responder_id < requester_id {
488 (responder_id, requester_id)
489 } else {
490 (requester_id, responder_id)
491 };
492 let query = "
493 DELETE FROM contacts
494 WHERE user_id_a = $1 AND user_id_b = $2;
495 ";
496 let result = sqlx::query(query)
497 .bind(id_a.0)
498 .bind(id_b.0)
499 .execute(&self.pool)
500 .await?;
501
502 if result.rows_affected() == 1 {
503 Ok(())
504 } else {
505 Err(anyhow!("no such contact"))?
506 }
507 }
508
509 async fn dismiss_contact_notification(
510 &self,
511 user_id: UserId,
512 contact_user_id: UserId,
513 ) -> Result<()> {
514 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
515 (user_id, contact_user_id, true)
516 } else {
517 (contact_user_id, user_id, false)
518 };
519
520 let query = "
521 UPDATE contacts
522 SET should_notify = 'f'
523 WHERE
524 user_id_a = $1 AND user_id_b = $2 AND
525 (
526 (a_to_b = $3 AND accepted) OR
527 (a_to_b != $3 AND NOT accepted)
528 );
529 ";
530
531 let result = sqlx::query(query)
532 .bind(id_a.0)
533 .bind(id_b.0)
534 .bind(a_to_b)
535 .execute(&self.pool)
536 .await?;
537
538 if result.rows_affected() == 0 {
539 Err(anyhow!("no such contact request"))?;
540 }
541
542 Ok(())
543 }
544
545 async fn respond_to_contact_request(
546 &self,
547 responder_id: UserId,
548 requester_id: UserId,
549 accept: bool,
550 ) -> Result<()> {
551 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
552 (responder_id, requester_id, false)
553 } else {
554 (requester_id, responder_id, true)
555 };
556 let result = if accept {
557 let query = "
558 UPDATE contacts
559 SET accepted = 't', should_notify = 't'
560 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
561 ";
562 sqlx::query(query)
563 .bind(id_a.0)
564 .bind(id_b.0)
565 .bind(a_to_b)
566 .execute(&self.pool)
567 .await?
568 } else {
569 let query = "
570 DELETE FROM contacts
571 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
572 ";
573 sqlx::query(query)
574 .bind(id_a.0)
575 .bind(id_b.0)
576 .bind(a_to_b)
577 .execute(&self.pool)
578 .await?
579 };
580 if result.rows_affected() == 1 {
581 Ok(())
582 } else {
583 Err(anyhow!("no such contact request"))?
584 }
585 }
586
587 // access tokens
588
589 async fn create_access_token_hash(
590 &self,
591 user_id: UserId,
592 access_token_hash: &str,
593 max_access_token_count: usize,
594 ) -> Result<()> {
595 let insert_query = "
596 INSERT INTO access_tokens (user_id, hash)
597 VALUES ($1, $2);
598 ";
599 let cleanup_query = "
600 DELETE FROM access_tokens
601 WHERE id IN (
602 SELECT id from access_tokens
603 WHERE user_id = $1
604 ORDER BY id DESC
605 OFFSET $3
606 )
607 ";
608
609 let mut tx = self.pool.begin().await?;
610 sqlx::query(insert_query)
611 .bind(user_id.0)
612 .bind(access_token_hash)
613 .execute(&mut tx)
614 .await?;
615 sqlx::query(cleanup_query)
616 .bind(user_id.0)
617 .bind(access_token_hash)
618 .bind(max_access_token_count as u32)
619 .execute(&mut tx)
620 .await?;
621 Ok(tx.commit().await?)
622 }
623
624 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
625 let query = "
626 SELECT hash
627 FROM access_tokens
628 WHERE user_id = $1
629 ORDER BY id DESC
630 ";
631 Ok(sqlx::query_scalar(query)
632 .bind(user_id.0)
633 .fetch_all(&self.pool)
634 .await?)
635 }
636
637 // orgs
638
639 #[allow(unused)] // Help rust-analyzer
640 #[cfg(any(test, feature = "seed-support"))]
641 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
642 let query = "
643 SELECT *
644 FROM orgs
645 WHERE slug = $1
646 ";
647 Ok(sqlx::query_as(query)
648 .bind(slug)
649 .fetch_optional(&self.pool)
650 .await?)
651 }
652
653 #[cfg(any(test, feature = "seed-support"))]
654 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
655 let query = "
656 INSERT INTO orgs (name, slug)
657 VALUES ($1, $2)
658 RETURNING id
659 ";
660 Ok(sqlx::query_scalar(query)
661 .bind(name)
662 .bind(slug)
663 .fetch_one(&self.pool)
664 .await
665 .map(OrgId)?)
666 }
667
668 #[cfg(any(test, feature = "seed-support"))]
669 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
670 let query = "
671 INSERT INTO org_memberships (org_id, user_id, admin)
672 VALUES ($1, $2, $3)
673 ON CONFLICT DO NOTHING
674 ";
675 Ok(sqlx::query(query)
676 .bind(org_id.0)
677 .bind(user_id.0)
678 .bind(is_admin)
679 .execute(&self.pool)
680 .await
681 .map(drop)?)
682 }
683
684 // channels
685
686 #[cfg(any(test, feature = "seed-support"))]
687 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
688 let query = "
689 INSERT INTO channels (owner_id, owner_is_user, name)
690 VALUES ($1, false, $2)
691 RETURNING id
692 ";
693 Ok(sqlx::query_scalar(query)
694 .bind(org_id.0)
695 .bind(name)
696 .fetch_one(&self.pool)
697 .await
698 .map(ChannelId)?)
699 }
700
701 #[allow(unused)] // Help rust-analyzer
702 #[cfg(any(test, feature = "seed-support"))]
703 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
704 let query = "
705 SELECT *
706 FROM channels
707 WHERE
708 channels.owner_is_user = false AND
709 channels.owner_id = $1
710 ";
711 Ok(sqlx::query_as(query)
712 .bind(org_id.0)
713 .fetch_all(&self.pool)
714 .await?)
715 }
716
717 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
718 let query = "
719 SELECT
720 channels.*
721 FROM
722 channel_memberships, channels
723 WHERE
724 channel_memberships.user_id = $1 AND
725 channel_memberships.channel_id = channels.id
726 ";
727 Ok(sqlx::query_as(query)
728 .bind(user_id.0)
729 .fetch_all(&self.pool)
730 .await?)
731 }
732
733 async fn can_user_access_channel(
734 &self,
735 user_id: UserId,
736 channel_id: ChannelId,
737 ) -> Result<bool> {
738 let query = "
739 SELECT id
740 FROM channel_memberships
741 WHERE user_id = $1 AND channel_id = $2
742 LIMIT 1
743 ";
744 Ok(sqlx::query_scalar::<_, i32>(query)
745 .bind(user_id.0)
746 .bind(channel_id.0)
747 .fetch_optional(&self.pool)
748 .await
749 .map(|e| e.is_some())?)
750 }
751
752 #[cfg(any(test, feature = "seed-support"))]
753 async fn add_channel_member(
754 &self,
755 channel_id: ChannelId,
756 user_id: UserId,
757 is_admin: bool,
758 ) -> Result<()> {
759 let query = "
760 INSERT INTO channel_memberships (channel_id, user_id, admin)
761 VALUES ($1, $2, $3)
762 ON CONFLICT DO NOTHING
763 ";
764 Ok(sqlx::query(query)
765 .bind(channel_id.0)
766 .bind(user_id.0)
767 .bind(is_admin)
768 .execute(&self.pool)
769 .await
770 .map(drop)?)
771 }
772
773 // messages
774
775 async fn create_channel_message(
776 &self,
777 channel_id: ChannelId,
778 sender_id: UserId,
779 body: &str,
780 timestamp: OffsetDateTime,
781 nonce: u128,
782 ) -> Result<MessageId> {
783 let query = "
784 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
785 VALUES ($1, $2, $3, $4, $5)
786 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
787 RETURNING id
788 ";
789 Ok(sqlx::query_scalar(query)
790 .bind(channel_id.0)
791 .bind(sender_id.0)
792 .bind(body)
793 .bind(timestamp)
794 .bind(Uuid::from_u128(nonce))
795 .fetch_one(&self.pool)
796 .await
797 .map(MessageId)?)
798 }
799
800 async fn get_channel_messages(
801 &self,
802 channel_id: ChannelId,
803 count: usize,
804 before_id: Option<MessageId>,
805 ) -> Result<Vec<ChannelMessage>> {
806 let query = r#"
807 SELECT * FROM (
808 SELECT
809 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
810 FROM
811 channel_messages
812 WHERE
813 channel_id = $1 AND
814 id < $2
815 ORDER BY id DESC
816 LIMIT $3
817 ) as recent_messages
818 ORDER BY id ASC
819 "#;
820 Ok(sqlx::query_as(query)
821 .bind(channel_id.0)
822 .bind(before_id.unwrap_or(MessageId::MAX))
823 .bind(count as i64)
824 .fetch_all(&self.pool)
825 .await?)
826 }
827
828 #[cfg(test)]
829 async fn teardown(&self, url: &str) {
830 use util::ResultExt;
831
832 let query = "
833 SELECT pg_terminate_backend(pg_stat_activity.pid)
834 FROM pg_stat_activity
835 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
836 ";
837 sqlx::query(query).execute(&self.pool).await.log_err();
838 self.pool.close().await;
839 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
840 .await
841 .log_err();
842 }
843
844 #[cfg(test)]
845 fn as_fake(&self) -> Option<&tests::FakeDb> {
846 None
847 }
848}
849
/// Defines `$name` as a `Copy` newtype over an `i32` database id.
///
/// Because of `#[sqlx(transparent)]` and `#[serde(transparent)]`, the generated type is
/// bound, decoded and serialized exactly like its inner `i32`. It also gets a `MAX`
/// constant, protobuf conversions (`from_proto` / `to_proto`), and a `Display` impl that
/// formats the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // The `allow(unused)`s below exist because not every generated id type uses
            // every helper.

            // Largest representable id, e.g. the open upper bound for message paging.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` silently truncates values above `i32::MAX`; this assumes
            // callers only pass ids that originally came from `to_proto` — confirm.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // A negative id would sign-extend to a huge `u64` here; database ids are
            // presumably never negative — TODO confirm.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
881
// Id of a row in `users`.
id_type!(UserId);
/// A row of the `users` table (decoded by column name via `FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// Unique login; `create_user` upserts on it.
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag, written by `set_user_is_admin`.
    pub admin: bool,
    /// Code others can redeem to sign up as invited by this user; generated by
    /// `set_invite_count` when the count is positive.
    pub invite_code: Option<String>,
    /// Remaining redemptions of `invite_code`; decremented by `redeem_invite_code`.
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`.
    pub connected_once: bool,
}
893
// Id of a row in `orgs`.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
901
// Id of a row in `channels`.
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner: an org id when `owner_is_user` is false (as written by
    /// `create_org_channel`), presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
910
// Id of a row in `channel_messages`; ids increase with insertion order for paging.
id_type!(MessageId);
/// A message row as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; the query selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied deduplication key: re-sending a nonce yields the original message id.
    pub nonce: Uuid,
}
921
/// One entry of a user's contact list, as produced by `Db::get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact. `should_notify` is true when the listing user sent the
    /// request and has not yet dismissed the notice that it was accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request the listing user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to the listing user; `should_notify` is cleared
    /// once the listing user dismisses it.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
936
937impl Contact {
938 pub fn user_id(&self) -> UserId {
939 match self {
940 Contact::Accepted { user_id, .. } => *user_id,
941 Contact::Outgoing { user_id } => *user_id,
942 Contact::Incoming { user_id, .. } => *user_id,
943 }
944 }
945}
946
/// A contact request received from `requester_id`.
// Field order matters: the derived `Ord` compares `requester_id` first, then
// `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
952
/// Builds an `ILIKE` pattern matching any text that contains the alphanumeric characters
/// of `string` in the same order, with anything in between.
///
/// Every non-alphanumeric character is dropped, including the LIKE metacharacters `%`,
/// `_` and `\`, so input cannot inject wildcards. For example, `"x y"` becomes `"%x%y%"`.
/// Input without alphanumerics becomes `"%"`, which matches everything.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char contributes `%` plus itself; a final `%` closes the pattern.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
964
965#[cfg(test)]
966pub mod tests {
967 use super::*;
968 use anyhow::anyhow;
969 use collections::BTreeMap;
970 use gpui::executor::Background;
971 use lazy_static::lazy_static;
972 use parking_lot::Mutex;
973 use rand::prelude::*;
974 use sqlx::{
975 migrate::{MigrateDatabase, Migrator},
976 Postgres,
977 };
978 use std::{path::Path, sync::Arc};
979 use util::post_inc;
980
981 #[tokio::test(flavor = "multi_thread")]
982 async fn test_get_users_by_ids() {
983 for test_db in [
984 TestDb::postgres().await,
985 TestDb::fake(Arc::new(gpui::executor::Background::new())),
986 ] {
987 let db = test_db.db();
988
989 let user = db.create_user("user", None, false).await.unwrap();
990 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
991 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
992 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
993
994 assert_eq!(
995 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
996 .await
997 .unwrap(),
998 vec![
999 User {
1000 id: user,
1001 github_login: "user".to_string(),
1002 admin: false,
1003 ..Default::default()
1004 },
1005 User {
1006 id: friend1,
1007 github_login: "friend-1".to_string(),
1008 admin: false,
1009 ..Default::default()
1010 },
1011 User {
1012 id: friend2,
1013 github_login: "friend-2".to_string(),
1014 admin: false,
1015 ..Default::default()
1016 },
1017 User {
1018 id: friend3,
1019 github_login: "friend-3".to_string(),
1020 admin: false,
1021 ..Default::default()
1022 }
1023 ]
1024 );
1025 }
1026 }
1027
1028 #[tokio::test(flavor = "multi_thread")]
1029 async fn test_recent_channel_messages() {
1030 for test_db in [
1031 TestDb::postgres().await,
1032 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1033 ] {
1034 let db = test_db.db();
1035 let user = db.create_user("user", None, false).await.unwrap();
1036 let org = db.create_org("org", "org").await.unwrap();
1037 let channel = db.create_org_channel(org, "channel").await.unwrap();
1038 for i in 0..10 {
1039 db.create_channel_message(
1040 channel,
1041 user,
1042 &i.to_string(),
1043 OffsetDateTime::now_utc(),
1044 i,
1045 )
1046 .await
1047 .unwrap();
1048 }
1049
1050 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1051 assert_eq!(
1052 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1053 ["5", "6", "7", "8", "9"]
1054 );
1055
1056 let prev_messages = db
1057 .get_channel_messages(channel, 4, Some(messages[0].id))
1058 .await
1059 .unwrap();
1060 assert_eq!(
1061 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1062 ["1", "2", "3", "4"]
1063 );
1064 }
1065 }
1066
1067 #[tokio::test(flavor = "multi_thread")]
1068 async fn test_channel_message_nonces() {
1069 for test_db in [
1070 TestDb::postgres().await,
1071 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1072 ] {
1073 let db = test_db.db();
1074 let user = db.create_user("user", None, false).await.unwrap();
1075 let org = db.create_org("org", "org").await.unwrap();
1076 let channel = db.create_org_channel(org, "channel").await.unwrap();
1077
1078 let msg1_id = db
1079 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1080 .await
1081 .unwrap();
1082 let msg2_id = db
1083 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1084 .await
1085 .unwrap();
1086 let msg3_id = db
1087 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1088 .await
1089 .unwrap();
1090 let msg4_id = db
1091 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1092 .await
1093 .unwrap();
1094
1095 assert_ne!(msg1_id, msg2_id);
1096 assert_eq!(msg1_id, msg3_id);
1097 assert_eq!(msg2_id, msg4_id);
1098 }
1099 }
1100
1101 #[tokio::test(flavor = "multi_thread")]
1102 async fn test_create_access_tokens() {
1103 let test_db = TestDb::postgres().await;
1104 let db = test_db.db();
1105 let user = db.create_user("the-user", None, false).await.unwrap();
1106
1107 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1108 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1109 assert_eq!(
1110 db.get_access_token_hashes(user).await.unwrap(),
1111 &["h2".to_string(), "h1".to_string()]
1112 );
1113
1114 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1115 assert_eq!(
1116 db.get_access_token_hashes(user).await.unwrap(),
1117 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1118 );
1119
1120 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1121 assert_eq!(
1122 db.get_access_token_hashes(user).await.unwrap(),
1123 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1124 );
1125
1126 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1127 assert_eq!(
1128 db.get_access_token_hashes(user).await.unwrap(),
1129 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1130 );
1131 }
1132
1133 #[test]
1134 fn test_fuzzy_like_string() {
1135 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1136 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1137 assert_eq!(fuzzy_like_string(" z "), "%z%");
1138 }
1139
1140 #[tokio::test(flavor = "multi_thread")]
1141 async fn test_fuzzy_search_users() {
1142 let test_db = TestDb::postgres().await;
1143 let db = test_db.db();
1144 for github_login in [
1145 "California",
1146 "colorado",
1147 "oregon",
1148 "washington",
1149 "florida",
1150 "delaware",
1151 "rhode-island",
1152 ] {
1153 db.create_user(github_login, None, false).await.unwrap();
1154 }
1155
1156 assert_eq!(
1157 fuzzy_search_user_names(db, "clr").await,
1158 &["colorado", "California"]
1159 );
1160 assert_eq!(
1161 fuzzy_search_user_names(db, "ro").await,
1162 &["rhode-island", "colorado", "oregon"],
1163 );
1164
1165 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1166 db.fuzzy_search_users(query, 10)
1167 .await
1168 .unwrap()
1169 .into_iter()
1170 .map(|user| user.github_login)
1171 .collect::<Vec<_>>()
1172 }
1173 }
1174
1175 #[tokio::test(flavor = "multi_thread")]
1176 async fn test_add_contacts() {
1177 for test_db in [
1178 TestDb::postgres().await,
1179 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1180 ] {
1181 let db = test_db.db();
1182
1183 let user_1 = db.create_user("user1", None, false).await.unwrap();
1184 let user_2 = db.create_user("user2", None, false).await.unwrap();
1185 let user_3 = db.create_user("user3", None, false).await.unwrap();
1186
1187 // User starts with no contacts
1188 assert_eq!(
1189 db.get_contacts(user_1).await.unwrap(),
1190 vec![Contact::Accepted {
1191 user_id: user_1,
1192 should_notify: false
1193 }],
1194 );
1195
1196 // User requests a contact. Both users see the pending request.
1197 db.send_contact_request(user_1, user_2).await.unwrap();
1198 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1199 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1200 assert_eq!(
1201 db.get_contacts(user_1).await.unwrap(),
1202 &[
1203 Contact::Accepted {
1204 user_id: user_1,
1205 should_notify: false
1206 },
1207 Contact::Outgoing { user_id: user_2 }
1208 ],
1209 );
1210 assert_eq!(
1211 db.get_contacts(user_2).await.unwrap(),
1212 &[
1213 Contact::Incoming {
1214 user_id: user_1,
1215 should_notify: true
1216 },
1217 Contact::Accepted {
1218 user_id: user_2,
1219 should_notify: false
1220 },
1221 ]
1222 );
1223
1224 // User 2 dismisses the contact request notification without accepting or rejecting.
1225 // We shouldn't notify them again.
1226 db.dismiss_contact_notification(user_1, user_2)
1227 .await
1228 .unwrap_err();
1229 db.dismiss_contact_notification(user_2, user_1)
1230 .await
1231 .unwrap();
1232 assert_eq!(
1233 db.get_contacts(user_2).await.unwrap(),
1234 &[
1235 Contact::Incoming {
1236 user_id: user_1,
1237 should_notify: false
1238 },
1239 Contact::Accepted {
1240 user_id: user_2,
1241 should_notify: false
1242 },
1243 ]
1244 );
1245
1246 // User can't accept their own contact request
1247 db.respond_to_contact_request(user_1, user_2, true)
1248 .await
1249 .unwrap_err();
1250
1251 // User accepts a contact request. Both users see the contact.
1252 db.respond_to_contact_request(user_2, user_1, true)
1253 .await
1254 .unwrap();
1255 assert_eq!(
1256 db.get_contacts(user_1).await.unwrap(),
1257 &[
1258 Contact::Accepted {
1259 user_id: user_1,
1260 should_notify: false
1261 },
1262 Contact::Accepted {
1263 user_id: user_2,
1264 should_notify: true
1265 }
1266 ],
1267 );
1268 assert!(db.has_contact(user_1, user_2).await.unwrap());
1269 assert!(db.has_contact(user_2, user_1).await.unwrap());
1270 assert_eq!(
1271 db.get_contacts(user_2).await.unwrap(),
1272 &[
1273 Contact::Accepted {
1274 user_id: user_1,
1275 should_notify: false,
1276 },
1277 Contact::Accepted {
1278 user_id: user_2,
1279 should_notify: false,
1280 },
1281 ]
1282 );
1283
1284 // Users cannot re-request existing contacts.
1285 db.send_contact_request(user_1, user_2).await.unwrap_err();
1286 db.send_contact_request(user_2, user_1).await.unwrap_err();
1287
1288 // Users can't dismiss notifications of them accepting other users' requests.
1289 db.dismiss_contact_notification(user_2, user_1)
1290 .await
1291 .unwrap_err();
1292 assert_eq!(
1293 db.get_contacts(user_1).await.unwrap(),
1294 &[
1295 Contact::Accepted {
1296 user_id: user_1,
1297 should_notify: false
1298 },
1299 Contact::Accepted {
1300 user_id: user_2,
1301 should_notify: true,
1302 },
1303 ]
1304 );
1305
1306 // Users can dismiss notifications of other users accepting their requests.
1307 db.dismiss_contact_notification(user_1, user_2)
1308 .await
1309 .unwrap();
1310 assert_eq!(
1311 db.get_contacts(user_1).await.unwrap(),
1312 &[
1313 Contact::Accepted {
1314 user_id: user_1,
1315 should_notify: false
1316 },
1317 Contact::Accepted {
1318 user_id: user_2,
1319 should_notify: false,
1320 },
1321 ]
1322 );
1323
1324 // Users send each other concurrent contact requests and
1325 // see that they are immediately accepted.
1326 db.send_contact_request(user_1, user_3).await.unwrap();
1327 db.send_contact_request(user_3, user_1).await.unwrap();
1328 assert_eq!(
1329 db.get_contacts(user_1).await.unwrap(),
1330 &[
1331 Contact::Accepted {
1332 user_id: user_1,
1333 should_notify: false
1334 },
1335 Contact::Accepted {
1336 user_id: user_2,
1337 should_notify: false,
1338 },
1339 Contact::Accepted {
1340 user_id: user_3,
1341 should_notify: false
1342 },
1343 ]
1344 );
1345 assert_eq!(
1346 db.get_contacts(user_3).await.unwrap(),
1347 &[
1348 Contact::Accepted {
1349 user_id: user_1,
1350 should_notify: false
1351 },
1352 Contact::Accepted {
1353 user_id: user_3,
1354 should_notify: false
1355 }
1356 ],
1357 );
1358
1359 // User declines a contact request. Both users see that it is gone.
1360 db.send_contact_request(user_2, user_3).await.unwrap();
1361 db.respond_to_contact_request(user_3, user_2, false)
1362 .await
1363 .unwrap();
1364 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1365 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1366 assert_eq!(
1367 db.get_contacts(user_2).await.unwrap(),
1368 &[
1369 Contact::Accepted {
1370 user_id: user_1,
1371 should_notify: false
1372 },
1373 Contact::Accepted {
1374 user_id: user_2,
1375 should_notify: false
1376 }
1377 ]
1378 );
1379 assert_eq!(
1380 db.get_contacts(user_3).await.unwrap(),
1381 &[
1382 Contact::Accepted {
1383 user_id: user_1,
1384 should_notify: false
1385 },
1386 Contact::Accepted {
1387 user_id: user_3,
1388 should_notify: false
1389 }
1390 ],
1391 );
1392 }
1393 }
1394
/// Exercises the invite-code lifecycle against a real Postgres test database:
/// code assignment via `set_invite_count`, redemption by new users (which also
/// makes them accepted contacts of the inviter), exhaustion of the remaining
/// count, topping the count back up, and rejection of existing users.
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
    // Requires a Postgres server reachable at postgres://postgres@localhost.
    let postgres = TestDb::postgres().await;
    let db = postgres.db();
    let user1 = db.create_user("user-1", None, false).await.unwrap();

    // Initially, user 1 has no invite code
    assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

    // Setting invite count to 0 when no code is assigned does not assign a new code
    db.set_invite_count(user1, 0).await.unwrap();
    assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

    // User 1 creates an invite code that can be used twice.
    db.set_invite_count(user1, 2).await.unwrap();
    let (invite_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 2);

    // User 2 redeems the invite code and becomes a contact of user 1.
    let user2 = db
        .redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap();
    // Each successful redemption consumes one use of the code.
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    // The inviter is expected to be notified about the new contact; the invitee is not.
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user2).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: false
            }
        ]
    );

    // User 3 redeems the invite code and becomes a contact of user 1.
    let user3 = db
        .redeem_invite_code(&invite_code, "user-3", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 0);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user3).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: false
            },
        ]
    );

    // Trying to redeem the code for the third time results in an error.
    db.redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap_err();

    // Invite count can be updated after the code has been created.
    db.set_invite_count(user1, 2).await.unwrap();
    let (latest_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
    assert_eq!(invite_count, 2);

    // User 4 can now redeem the invite code and becomes a contact of user 1.
    let user4 = db
        .redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user4).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: false
            },
        ]
    );

    // An existing user cannot redeem invite codes.
    db.redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap_err();
    // The failed redemption must not consume a use of the code.
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
}
1547
/// A database handle for tests, paired with the URL of the database backing it.
///
/// `db` is an `Option` so that the `Drop` impl can take ownership of the handle
/// and tear the database down.
pub struct TestDb {
    /// The live database. `TestDb`'s constructors always set this to `Some`;
    /// it becomes `None` once taken (e.g. during drop).
    pub db: Option<Arc<dyn Db>>,
    /// URL of the backing database; empty for the in-memory fake.
    pub url: String,
}
1552
1553 impl TestDb {
1554 pub async fn postgres() -> Self {
1555 lazy_static! {
1556 static ref LOCK: Mutex<()> = Mutex::new(());
1557 }
1558
1559 let _guard = LOCK.lock();
1560 let mut rng = StdRng::from_entropy();
1561 let name = format!("zed-test-{}", rng.gen::<u128>());
1562 let url = format!("postgres://postgres@localhost/{}", name);
1563 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1564 Postgres::create_database(&url)
1565 .await
1566 .expect("failed to create test db");
1567 let db = PostgresDb::new(&url, 5).await.unwrap();
1568 let migrator = Migrator::new(migrations_path).await.unwrap();
1569 migrator.run(&db.pool).await.unwrap();
1570 Self {
1571 db: Some(Arc::new(db)),
1572 url,
1573 }
1574 }
1575
1576 pub fn fake(background: Arc<Background>) -> Self {
1577 Self {
1578 db: Some(Arc::new(FakeDb::new(background))),
1579 url: Default::default(),
1580 }
1581 }
1582
1583 pub fn db(&self) -> &Arc<dyn Db> {
1584 self.db.as_ref().unwrap()
1585 }
1586 }
1587
1588 impl Drop for TestDb {
1589 fn drop(&mut self) {
1590 if let Some(db) = self.db.take() {
1591 futures::executor::block_on(db.teardown(&self.url));
1592 }
1593 }
1594 }
1595
1596 pub struct FakeDb {
1597 background: Arc<Background>,
1598 pub users: Mutex<BTreeMap<UserId, User>>,
1599 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1600 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1601 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1602 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1603 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1604 pub contacts: Mutex<Vec<FakeContact>>,
1605 next_channel_message_id: Mutex<i32>,
1606 next_user_id: Mutex<i32>,
1607 next_org_id: Mutex<i32>,
1608 next_channel_id: Mutex<i32>,
1609 }
1610
/// One contact relationship in `FakeDb`, recorded in the direction it was requested.
#[derive(Debug)]
pub struct FakeContact {
    /// User who sent the contact request.
    pub requester_id: UserId,
    /// User who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted (by a response or by a reciprocal request).
    pub accepted: bool,
    /// Whether a notification is pending: for the responder while the request is
    /// still pending, and for the requester once it has been accepted.
    pub should_notify: bool,
}
1618
1619 impl FakeDb {
1620 pub fn new(background: Arc<Background>) -> Self {
1621 Self {
1622 background,
1623 users: Default::default(),
1624 next_user_id: Mutex::new(1),
1625 orgs: Default::default(),
1626 next_org_id: Mutex::new(1),
1627 org_memberships: Default::default(),
1628 channels: Default::default(),
1629 next_channel_id: Mutex::new(1),
1630 channel_memberships: Default::default(),
1631 channel_messages: Default::default(),
1632 next_channel_message_id: Mutex::new(1),
1633 contacts: Default::default(),
1634 }
1635 }
1636 }
1637
1638 #[async_trait]
1639 impl Db for FakeDb {
1640 async fn create_user(
1641 &self,
1642 github_login: &str,
1643 email_address: Option<&str>,
1644 admin: bool,
1645 ) -> Result<UserId> {
1646 self.background.simulate_random_delay().await;
1647
1648 let mut users = self.users.lock();
1649 if let Some(user) = users
1650 .values()
1651 .find(|user| user.github_login == github_login)
1652 {
1653 Ok(user.id)
1654 } else {
1655 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1656 users.insert(
1657 user_id,
1658 User {
1659 id: user_id,
1660 github_login: github_login.to_string(),
1661 email_address: email_address.map(str::to_string),
1662 admin,
1663 invite_code: None,
1664 invite_count: 0,
1665 connected_once: false,
1666 },
1667 );
1668 Ok(user_id)
1669 }
1670 }
1671
1672 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1673 unimplemented!()
1674 }
1675
1676 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1677 unimplemented!()
1678 }
1679
1680 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1681 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1682 }
1683
1684 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1685 self.background.simulate_random_delay().await;
1686 let users = self.users.lock();
1687 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1688 }
1689
1690 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1691 Ok(self
1692 .users
1693 .lock()
1694 .values()
1695 .find(|user| user.github_login == github_login)
1696 .cloned())
1697 }
1698
1699 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1700 unimplemented!()
1701 }
1702
1703 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1704 self.background.simulate_random_delay().await;
1705 let mut users = self.users.lock();
1706 let mut user = users
1707 .get_mut(&id)
1708 .ok_or_else(|| anyhow!("user not found"))?;
1709 user.connected_once = connected_once;
1710 Ok(())
1711 }
1712
1713 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1714 unimplemented!()
1715 }
1716
1717 // invite codes
1718
1719 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1720 unimplemented!()
1721 }
1722
1723 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1724 Ok(None)
1725 }
1726
1727 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1728 unimplemented!()
1729 }
1730
1731 async fn redeem_invite_code(
1732 &self,
1733 _code: &str,
1734 _login: &str,
1735 _email_address: Option<&str>,
1736 ) -> Result<UserId> {
1737 unimplemented!()
1738 }
1739
1740 // contacts
1741
1742 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1743 self.background.simulate_random_delay().await;
1744 let mut contacts = vec![Contact::Accepted {
1745 user_id: id,
1746 should_notify: false,
1747 }];
1748
1749 for contact in self.contacts.lock().iter() {
1750 if contact.requester_id == id {
1751 if contact.accepted {
1752 contacts.push(Contact::Accepted {
1753 user_id: contact.responder_id,
1754 should_notify: contact.should_notify,
1755 });
1756 } else {
1757 contacts.push(Contact::Outgoing {
1758 user_id: contact.responder_id,
1759 });
1760 }
1761 } else if contact.responder_id == id {
1762 if contact.accepted {
1763 contacts.push(Contact::Accepted {
1764 user_id: contact.requester_id,
1765 should_notify: false,
1766 });
1767 } else {
1768 contacts.push(Contact::Incoming {
1769 user_id: contact.requester_id,
1770 should_notify: contact.should_notify,
1771 });
1772 }
1773 }
1774 }
1775
1776 contacts.sort_unstable_by_key(|contact| contact.user_id());
1777 Ok(contacts)
1778 }
1779
1780 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1781 self.background.simulate_random_delay().await;
1782 Ok(self.contacts.lock().iter().any(|contact| {
1783 contact.accepted
1784 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1785 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1786 }))
1787 }
1788
1789 async fn send_contact_request(
1790 &self,
1791 requester_id: UserId,
1792 responder_id: UserId,
1793 ) -> Result<()> {
1794 let mut contacts = self.contacts.lock();
1795 for contact in contacts.iter_mut() {
1796 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1797 if contact.accepted {
1798 Err(anyhow!("contact already exists"))?;
1799 } else {
1800 Err(anyhow!("contact already requested"))?;
1801 }
1802 }
1803 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1804 if contact.accepted {
1805 Err(anyhow!("contact already exists"))?;
1806 } else {
1807 contact.accepted = true;
1808 contact.should_notify = false;
1809 return Ok(());
1810 }
1811 }
1812 }
1813 contacts.push(FakeContact {
1814 requester_id,
1815 responder_id,
1816 accepted: false,
1817 should_notify: true,
1818 });
1819 Ok(())
1820 }
1821
1822 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1823 self.contacts.lock().retain(|contact| {
1824 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1825 });
1826 Ok(())
1827 }
1828
1829 async fn dismiss_contact_notification(
1830 &self,
1831 user_id: UserId,
1832 contact_user_id: UserId,
1833 ) -> Result<()> {
1834 let mut contacts = self.contacts.lock();
1835 for contact in contacts.iter_mut() {
1836 if contact.requester_id == contact_user_id
1837 && contact.responder_id == user_id
1838 && !contact.accepted
1839 {
1840 contact.should_notify = false;
1841 return Ok(());
1842 }
1843 if contact.requester_id == user_id
1844 && contact.responder_id == contact_user_id
1845 && contact.accepted
1846 {
1847 contact.should_notify = false;
1848 return Ok(());
1849 }
1850 }
1851 Err(anyhow!("no such notification"))?
1852 }
1853
1854 async fn respond_to_contact_request(
1855 &self,
1856 responder_id: UserId,
1857 requester_id: UserId,
1858 accept: bool,
1859 ) -> Result<()> {
1860 let mut contacts = self.contacts.lock();
1861 for (ix, contact) in contacts.iter_mut().enumerate() {
1862 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1863 if contact.accepted {
1864 Err(anyhow!("contact already confirmed"))?;
1865 }
1866 if accept {
1867 contact.accepted = true;
1868 contact.should_notify = true;
1869 } else {
1870 contacts.remove(ix);
1871 }
1872 return Ok(());
1873 }
1874 }
1875 Err(anyhow!("no such contact request"))?
1876 }
1877
1878 async fn create_access_token_hash(
1879 &self,
1880 _user_id: UserId,
1881 _access_token_hash: &str,
1882 _max_access_token_count: usize,
1883 ) -> Result<()> {
1884 unimplemented!()
1885 }
1886
1887 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1888 unimplemented!()
1889 }
1890
1891 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1892 unimplemented!()
1893 }
1894
1895 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1896 self.background.simulate_random_delay().await;
1897 let mut orgs = self.orgs.lock();
1898 if orgs.values().any(|org| org.slug == slug) {
1899 Err(anyhow!("org already exists"))?
1900 } else {
1901 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1902 orgs.insert(
1903 org_id,
1904 Org {
1905 id: org_id,
1906 name: name.to_string(),
1907 slug: slug.to_string(),
1908 },
1909 );
1910 Ok(org_id)
1911 }
1912 }
1913
1914 async fn add_org_member(
1915 &self,
1916 org_id: OrgId,
1917 user_id: UserId,
1918 is_admin: bool,
1919 ) -> Result<()> {
1920 self.background.simulate_random_delay().await;
1921 if !self.orgs.lock().contains_key(&org_id) {
1922 Err(anyhow!("org does not exist"))?;
1923 }
1924 if !self.users.lock().contains_key(&user_id) {
1925 Err(anyhow!("user does not exist"))?;
1926 }
1927
1928 self.org_memberships
1929 .lock()
1930 .entry((org_id, user_id))
1931 .or_insert(is_admin);
1932 Ok(())
1933 }
1934
1935 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1936 self.background.simulate_random_delay().await;
1937 if !self.orgs.lock().contains_key(&org_id) {
1938 Err(anyhow!("org does not exist"))?;
1939 }
1940
1941 let mut channels = self.channels.lock();
1942 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1943 channels.insert(
1944 channel_id,
1945 Channel {
1946 id: channel_id,
1947 name: name.to_string(),
1948 owner_id: org_id.0,
1949 owner_is_user: false,
1950 },
1951 );
1952 Ok(channel_id)
1953 }
1954
1955 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1956 self.background.simulate_random_delay().await;
1957 Ok(self
1958 .channels
1959 .lock()
1960 .values()
1961 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1962 .cloned()
1963 .collect())
1964 }
1965
1966 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1967 self.background.simulate_random_delay().await;
1968 let channels = self.channels.lock();
1969 let memberships = self.channel_memberships.lock();
1970 Ok(channels
1971 .values()
1972 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1973 .cloned()
1974 .collect())
1975 }
1976
1977 async fn can_user_access_channel(
1978 &self,
1979 user_id: UserId,
1980 channel_id: ChannelId,
1981 ) -> Result<bool> {
1982 self.background.simulate_random_delay().await;
1983 Ok(self
1984 .channel_memberships
1985 .lock()
1986 .contains_key(&(channel_id, user_id)))
1987 }
1988
1989 async fn add_channel_member(
1990 &self,
1991 channel_id: ChannelId,
1992 user_id: UserId,
1993 is_admin: bool,
1994 ) -> Result<()> {
1995 self.background.simulate_random_delay().await;
1996 if !self.channels.lock().contains_key(&channel_id) {
1997 Err(anyhow!("channel does not exist"))?;
1998 }
1999 if !self.users.lock().contains_key(&user_id) {
2000 Err(anyhow!("user does not exist"))?;
2001 }
2002
2003 self.channel_memberships
2004 .lock()
2005 .entry((channel_id, user_id))
2006 .or_insert(is_admin);
2007 Ok(())
2008 }
2009
2010 async fn create_channel_message(
2011 &self,
2012 channel_id: ChannelId,
2013 sender_id: UserId,
2014 body: &str,
2015 timestamp: OffsetDateTime,
2016 nonce: u128,
2017 ) -> Result<MessageId> {
2018 self.background.simulate_random_delay().await;
2019 if !self.channels.lock().contains_key(&channel_id) {
2020 Err(anyhow!("channel does not exist"))?;
2021 }
2022 if !self.users.lock().contains_key(&sender_id) {
2023 Err(anyhow!("user does not exist"))?;
2024 }
2025
2026 let mut messages = self.channel_messages.lock();
2027 if let Some(message) = messages
2028 .values()
2029 .find(|message| message.nonce.as_u128() == nonce)
2030 {
2031 Ok(message.id)
2032 } else {
2033 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2034 messages.insert(
2035 message_id,
2036 ChannelMessage {
2037 id: message_id,
2038 channel_id,
2039 sender_id,
2040 body: body.to_string(),
2041 sent_at: timestamp,
2042 nonce: Uuid::from_u128(nonce),
2043 },
2044 );
2045 Ok(message_id)
2046 }
2047 }
2048
2049 async fn get_channel_messages(
2050 &self,
2051 channel_id: ChannelId,
2052 count: usize,
2053 before_id: Option<MessageId>,
2054 ) -> Result<Vec<ChannelMessage>> {
2055 let mut messages = self
2056 .channel_messages
2057 .lock()
2058 .values()
2059 .rev()
2060 .filter(|message| {
2061 message.channel_id == channel_id
2062 && message.id < before_id.unwrap_or(MessageId::MAX)
2063 })
2064 .take(count)
2065 .cloned()
2066 .collect::<Vec<_>>();
2067 messages.sort_unstable_by_key(|message| message.id);
2068 Ok(messages)
2069 }
2070
2071 async fn teardown(&self, _: &str) {}
2072
2073 #[cfg(test)]
2074 fn as_fake(&self) -> Option<&FakeDb> {
2075 Some(self)
2076 }
2077 }
2078}