1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow};
10use time::OffsetDateTime;
11
12#[async_trait]
13pub trait Db: Send + Sync {
14 async fn create_user(
15 &self,
16 github_login: &str,
17 email_address: Option<&str>,
18 admin: bool,
19 ) -> Result<UserId>;
20 async fn get_all_users(&self) -> Result<Vec<User>>;
21 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
22 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
23 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
24 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
25 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
26 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
27 async fn destroy_user(&self, id: UserId) -> Result<()>;
28
29 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
30 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
31 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
32 async fn redeem_invite_code(
33 &self,
34 code: &str,
35 login: &str,
36 email_address: Option<&str>,
37 ) -> Result<UserId>;
38
39 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
40 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
41 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
42 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
43 async fn dismiss_contact_notification(
44 &self,
45 responder_id: UserId,
46 requester_id: UserId,
47 ) -> Result<()>;
48 async fn respond_to_contact_request(
49 &self,
50 responder_id: UserId,
51 requester_id: UserId,
52 accept: bool,
53 ) -> Result<()>;
54
55 async fn create_access_token_hash(
56 &self,
57 user_id: UserId,
58 access_token_hash: &str,
59 max_access_token_count: usize,
60 ) -> Result<()>;
61 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
62 #[cfg(any(test, feature = "seed-support"))]
63
64 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
65 #[cfg(any(test, feature = "seed-support"))]
66 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
67 #[cfg(any(test, feature = "seed-support"))]
68 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
69 #[cfg(any(test, feature = "seed-support"))]
70 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
71 #[cfg(any(test, feature = "seed-support"))]
72
73 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
74 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
75 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
76 -> Result<bool>;
77 #[cfg(any(test, feature = "seed-support"))]
78 async fn add_channel_member(
79 &self,
80 channel_id: ChannelId,
81 user_id: UserId,
82 is_admin: bool,
83 ) -> Result<()>;
84 async fn create_channel_message(
85 &self,
86 channel_id: ChannelId,
87 sender_id: UserId,
88 body: &str,
89 timestamp: OffsetDateTime,
90 nonce: u128,
91 ) -> Result<MessageId>;
92 async fn get_channel_messages(
93 &self,
94 channel_id: ChannelId,
95 count: usize,
96 before_id: Option<MessageId>,
97 ) -> Result<Vec<ChannelMessage>>;
98 #[cfg(test)]
99 async fn teardown(&self, url: &str);
100 #[cfg(test)]
101 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
102}
103
104pub struct PostgresDb {
105 pool: sqlx::PgPool,
106}
107
108impl PostgresDb {
109 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
110 let pool = DbOptions::new()
111 .max_connections(max_connections)
112 .connect(&url)
113 .await
114 .context("failed to connect to postgres database")?;
115 Ok(Self { pool })
116 }
117}
118
119#[async_trait]
120impl Db for PostgresDb {
121 // users
122
123 async fn create_user(
124 &self,
125 github_login: &str,
126 email_address: Option<&str>,
127 admin: bool,
128 ) -> Result<UserId> {
129 let query = "
130 INSERT INTO users (github_login, email_address, admin)
131 VALUES ($1, $2, $3)
132 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
133 RETURNING id
134 ";
135 Ok(sqlx::query_scalar(query)
136 .bind(github_login)
137 .bind(email_address)
138 .bind(admin)
139 .fetch_one(&self.pool)
140 .await
141 .map(UserId)?)
142 }
143
144 async fn get_all_users(&self) -> Result<Vec<User>> {
145 let query = "SELECT * FROM users ORDER BY github_login ASC";
146 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
147 }
148
149 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
150 let like_string = fuzzy_like_string(name_query);
151 let query = "
152 SELECT users.*
153 FROM users
154 WHERE github_login ILIKE $1
155 ORDER BY github_login <-> $2
156 LIMIT $3
157 ";
158 Ok(sqlx::query_as(query)
159 .bind(like_string)
160 .bind(name_query)
161 .bind(limit)
162 .fetch_all(&self.pool)
163 .await?)
164 }
165
166 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
167 let users = self.get_users_by_ids(vec![id]).await?;
168 Ok(users.into_iter().next())
169 }
170
171 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
172 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
173 let query = "
174 SELECT users.*
175 FROM users
176 WHERE users.id = ANY ($1)
177 ";
178 Ok(sqlx::query_as(query)
179 .bind(&ids)
180 .fetch_all(&self.pool)
181 .await?)
182 }
183
184 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
185 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
186 Ok(sqlx::query_as(query)
187 .bind(github_login)
188 .fetch_optional(&self.pool)
189 .await?)
190 }
191
192 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
193 let query = "UPDATE users SET admin = $1 WHERE id = $2";
194 Ok(sqlx::query(query)
195 .bind(is_admin)
196 .bind(id.0)
197 .execute(&self.pool)
198 .await
199 .map(drop)?)
200 }
201
202 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
203 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
204 Ok(sqlx::query(query)
205 .bind(connected_once)
206 .bind(id.0)
207 .execute(&self.pool)
208 .await
209 .map(drop)?)
210 }
211
212 async fn destroy_user(&self, id: UserId) -> Result<()> {
213 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
214 sqlx::query(query)
215 .bind(id.0)
216 .execute(&self.pool)
217 .await
218 .map(drop)?;
219 let query = "DELETE FROM users WHERE id = $1;";
220 Ok(sqlx::query(query)
221 .bind(id.0)
222 .execute(&self.pool)
223 .await
224 .map(drop)?)
225 }
226
227 // invite codes
228
229 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
230 let mut tx = self.pool.begin().await?;
231 if count > 0 {
232 sqlx::query(
233 "
234 UPDATE users
235 SET invite_code = $1
236 WHERE id = $2 AND invite_code IS NULL
237 ",
238 )
239 .bind(nanoid!(16))
240 .bind(id)
241 .execute(&mut tx)
242 .await?;
243 }
244
245 sqlx::query(
246 "
247 UPDATE users
248 SET invite_count = $1
249 WHERE id = $2
250 ",
251 )
252 .bind(count)
253 .bind(id)
254 .execute(&mut tx)
255 .await?;
256 tx.commit().await?;
257 Ok(())
258 }
259
260 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
261 let result: Option<(String, i32)> = sqlx::query_as(
262 "
263 SELECT invite_code, invite_count
264 FROM users
265 WHERE id = $1 AND invite_code IS NOT NULL
266 ",
267 )
268 .bind(id)
269 .fetch_optional(&self.pool)
270 .await?;
271 if let Some((code, count)) = result {
272 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
273 } else {
274 Ok(None)
275 }
276 }
277
278 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
279 sqlx::query_as(
280 "
281 SELECT *
282 FROM users
283 WHERE invite_code = $1
284 ",
285 )
286 .bind(code)
287 .fetch_optional(&self.pool)
288 .await?
289 .ok_or_else(|| {
290 Error::Http(
291 StatusCode::NOT_FOUND,
292 "that invite code does not exist".to_string(),
293 )
294 })
295 }
296
297 async fn redeem_invite_code(
298 &self,
299 code: &str,
300 login: &str,
301 email_address: Option<&str>,
302 ) -> Result<UserId> {
303 let mut tx = self.pool.begin().await?;
304
305 let inviter_id: Option<UserId> = sqlx::query_scalar(
306 "
307 UPDATE users
308 SET invite_count = invite_count - 1
309 WHERE
310 invite_code = $1 AND
311 invite_count > 0
312 RETURNING id
313 ",
314 )
315 .bind(code)
316 .fetch_optional(&mut tx)
317 .await?;
318
319 let inviter_id = match inviter_id {
320 Some(inviter_id) => inviter_id,
321 None => {
322 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
323 .bind(code)
324 .fetch_optional(&mut tx)
325 .await?
326 .is_some()
327 {
328 Err(Error::Http(
329 StatusCode::UNAUTHORIZED,
330 "no invites remaining".to_string(),
331 ))?
332 } else {
333 Err(Error::Http(
334 StatusCode::NOT_FOUND,
335 "invite code not found".to_string(),
336 ))?
337 }
338 }
339 };
340
341 let invitee_id = sqlx::query_scalar(
342 "
343 INSERT INTO users
344 (github_login, email_address, admin, inviter_id)
345 VALUES
346 ($1, $2, 'f', $3)
347 RETURNING id
348 ",
349 )
350 .bind(login)
351 .bind(email_address)
352 .bind(inviter_id)
353 .fetch_one(&mut tx)
354 .await
355 .map(UserId)?;
356
357 sqlx::query(
358 "
359 INSERT INTO contacts
360 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
361 VALUES
362 ($1, $2, 't', 't', 't')
363 ",
364 )
365 .bind(inviter_id)
366 .bind(invitee_id)
367 .execute(&mut tx)
368 .await?;
369
370 tx.commit().await?;
371 Ok(invitee_id)
372 }
373
374 // contacts
375
376 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
377 let query = "
378 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
379 FROM contacts
380 WHERE user_id_a = $1 OR user_id_b = $1;
381 ";
382
383 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
384 .bind(user_id)
385 .fetch(&self.pool);
386
387 let mut contacts = vec![Contact::Accepted {
388 user_id,
389 should_notify: false,
390 }];
391 while let Some(row) = rows.next().await {
392 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
393
394 if user_id_a == user_id {
395 if accepted {
396 contacts.push(Contact::Accepted {
397 user_id: user_id_b,
398 should_notify: should_notify && a_to_b,
399 });
400 } else if a_to_b {
401 contacts.push(Contact::Outgoing { user_id: user_id_b })
402 } else {
403 contacts.push(Contact::Incoming {
404 user_id: user_id_b,
405 should_notify,
406 });
407 }
408 } else {
409 if accepted {
410 contacts.push(Contact::Accepted {
411 user_id: user_id_a,
412 should_notify: should_notify && !a_to_b,
413 });
414 } else if a_to_b {
415 contacts.push(Contact::Incoming {
416 user_id: user_id_a,
417 should_notify,
418 });
419 } else {
420 contacts.push(Contact::Outgoing { user_id: user_id_a });
421 }
422 }
423 }
424
425 contacts.sort_unstable_by_key(|contact| contact.user_id());
426
427 Ok(contacts)
428 }
429
430 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
431 let (id_a, id_b) = if user_id_1 < user_id_2 {
432 (user_id_1, user_id_2)
433 } else {
434 (user_id_2, user_id_1)
435 };
436
437 let query = "
438 SELECT 1 FROM contacts
439 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
440 LIMIT 1
441 ";
442 Ok(sqlx::query_scalar::<_, i32>(query)
443 .bind(id_a.0)
444 .bind(id_b.0)
445 .fetch_optional(&self.pool)
446 .await?
447 .is_some())
448 }
449
450 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
451 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
452 (sender_id, receiver_id, true)
453 } else {
454 (receiver_id, sender_id, false)
455 };
456 let query = "
457 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
458 VALUES ($1, $2, $3, 'f', 't')
459 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
460 SET
461 accepted = 't',
462 should_notify = 'f'
463 WHERE
464 NOT contacts.accepted AND
465 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
466 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
467 ";
468 let result = sqlx::query(query)
469 .bind(id_a.0)
470 .bind(id_b.0)
471 .bind(a_to_b)
472 .execute(&self.pool)
473 .await?;
474
475 if result.rows_affected() == 1 {
476 Ok(())
477 } else {
478 Err(anyhow!("contact already requested"))?
479 }
480 }
481
482 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
483 let (id_a, id_b) = if responder_id < requester_id {
484 (responder_id, requester_id)
485 } else {
486 (requester_id, responder_id)
487 };
488 let query = "
489 DELETE FROM contacts
490 WHERE user_id_a = $1 AND user_id_b = $2;
491 ";
492 let result = sqlx::query(query)
493 .bind(id_a.0)
494 .bind(id_b.0)
495 .execute(&self.pool)
496 .await?;
497
498 if result.rows_affected() == 1 {
499 Ok(())
500 } else {
501 Err(anyhow!("no such contact"))?
502 }
503 }
504
505 async fn dismiss_contact_notification(
506 &self,
507 user_id: UserId,
508 contact_user_id: UserId,
509 ) -> Result<()> {
510 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
511 (user_id, contact_user_id, true)
512 } else {
513 (contact_user_id, user_id, false)
514 };
515
516 let query = "
517 UPDATE contacts
518 SET should_notify = 'f'
519 WHERE
520 user_id_a = $1 AND user_id_b = $2 AND
521 (
522 (a_to_b = $3 AND accepted) OR
523 (a_to_b != $3 AND NOT accepted)
524 );
525 ";
526
527 let result = sqlx::query(query)
528 .bind(id_a.0)
529 .bind(id_b.0)
530 .bind(a_to_b)
531 .execute(&self.pool)
532 .await?;
533
534 if result.rows_affected() == 0 {
535 Err(anyhow!("no such contact request"))?;
536 }
537
538 Ok(())
539 }
540
541 async fn respond_to_contact_request(
542 &self,
543 responder_id: UserId,
544 requester_id: UserId,
545 accept: bool,
546 ) -> Result<()> {
547 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
548 (responder_id, requester_id, false)
549 } else {
550 (requester_id, responder_id, true)
551 };
552 let result = if accept {
553 let query = "
554 UPDATE contacts
555 SET accepted = 't', should_notify = 't'
556 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
557 ";
558 sqlx::query(query)
559 .bind(id_a.0)
560 .bind(id_b.0)
561 .bind(a_to_b)
562 .execute(&self.pool)
563 .await?
564 } else {
565 let query = "
566 DELETE FROM contacts
567 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
568 ";
569 sqlx::query(query)
570 .bind(id_a.0)
571 .bind(id_b.0)
572 .bind(a_to_b)
573 .execute(&self.pool)
574 .await?
575 };
576 if result.rows_affected() == 1 {
577 Ok(())
578 } else {
579 Err(anyhow!("no such contact request"))?
580 }
581 }
582
583 // access tokens
584
585 async fn create_access_token_hash(
586 &self,
587 user_id: UserId,
588 access_token_hash: &str,
589 max_access_token_count: usize,
590 ) -> Result<()> {
591 let insert_query = "
592 INSERT INTO access_tokens (user_id, hash)
593 VALUES ($1, $2);
594 ";
595 let cleanup_query = "
596 DELETE FROM access_tokens
597 WHERE id IN (
598 SELECT id from access_tokens
599 WHERE user_id = $1
600 ORDER BY id DESC
601 OFFSET $3
602 )
603 ";
604
605 let mut tx = self.pool.begin().await?;
606 sqlx::query(insert_query)
607 .bind(user_id.0)
608 .bind(access_token_hash)
609 .execute(&mut tx)
610 .await?;
611 sqlx::query(cleanup_query)
612 .bind(user_id.0)
613 .bind(access_token_hash)
614 .bind(max_access_token_count as u32)
615 .execute(&mut tx)
616 .await?;
617 Ok(tx.commit().await?)
618 }
619
620 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
621 let query = "
622 SELECT hash
623 FROM access_tokens
624 WHERE user_id = $1
625 ORDER BY id DESC
626 ";
627 Ok(sqlx::query_scalar(query)
628 .bind(user_id.0)
629 .fetch_all(&self.pool)
630 .await?)
631 }
632
633 // orgs
634
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org: Option<Org> = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
648
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let raw_id: i32 = sqlx::query_scalar(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await?;
    Ok(OrgId(raw_id))
}
663
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    // `ON CONFLICT DO NOTHING`: re-adding an existing member is a no-op and does not change
    // their admin flag.
    sqlx::query(
        "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(org_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
679
680 // channels
681
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    // Org-owned channels store the org id in `owner_id` with `owner_is_user = false`.
    let raw_id: i32 = sqlx::query_scalar(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await?;
    Ok(ChannelId(raw_id))
}
696
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    // Matches the ownership encoding written by `create_org_channel`.
    let channels: Vec<Channel> = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
712
713 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
714 let query = "
715 SELECT
716 channels.*
717 FROM
718 channel_memberships, channels
719 WHERE
720 channel_memberships.user_id = $1 AND
721 channel_memberships.channel_id = channels.id
722 ";
723 Ok(sqlx::query_as(query)
724 .bind(user_id.0)
725 .fetch_all(&self.pool)
726 .await?)
727 }
728
729 async fn can_user_access_channel(
730 &self,
731 user_id: UserId,
732 channel_id: ChannelId,
733 ) -> Result<bool> {
734 let query = "
735 SELECT id
736 FROM channel_memberships
737 WHERE user_id = $1 AND channel_id = $2
738 LIMIT 1
739 ";
740 Ok(sqlx::query_scalar::<_, i32>(query)
741 .bind(user_id.0)
742 .bind(channel_id.0)
743 .fetch_optional(&self.pool)
744 .await
745 .map(|e| e.is_some())?)
746 }
747
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    // `ON CONFLICT DO NOTHING`: re-adding an existing member is a no-op and does not change
    // their admin flag.
    sqlx::query(
        "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(channel_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
768
769 // messages
770
771 async fn create_channel_message(
772 &self,
773 channel_id: ChannelId,
774 sender_id: UserId,
775 body: &str,
776 timestamp: OffsetDateTime,
777 nonce: u128,
778 ) -> Result<MessageId> {
779 let query = "
780 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
781 VALUES ($1, $2, $3, $4, $5)
782 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
783 RETURNING id
784 ";
785 Ok(sqlx::query_scalar(query)
786 .bind(channel_id.0)
787 .bind(sender_id.0)
788 .bind(body)
789 .bind(timestamp)
790 .bind(Uuid::from_u128(nonce))
791 .fetch_one(&self.pool)
792 .await
793 .map(MessageId)?)
794 }
795
796 async fn get_channel_messages(
797 &self,
798 channel_id: ChannelId,
799 count: usize,
800 before_id: Option<MessageId>,
801 ) -> Result<Vec<ChannelMessage>> {
802 let query = r#"
803 SELECT * FROM (
804 SELECT
805 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
806 FROM
807 channel_messages
808 WHERE
809 channel_id = $1 AND
810 id < $2
811 ORDER BY id DESC
812 LIMIT $3
813 ) as recent_messages
814 ORDER BY id ASC
815 "#;
816 Ok(sqlx::query_as(query)
817 .bind(channel_id.0)
818 .bind(before_id.unwrap_or(MessageId::MAX))
819 .bind(count as i64)
820 .fetch_all(&self.pool)
821 .await?)
822 }
823
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use util::ResultExt;

    // Terminate every other backend connected to this database so it can be dropped.
    // Each step is best-effort: failures are logged, not propagated.
    let terminate_others = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(terminate_others)
        .execute(&self.pool)
        .await
        .log_err();
    self.pool.close().await;

    let dropped = <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
    dropped.log_err();
}
839
#[cfg(test)]
fn as_fake(&self) -> Option<&tests::FakeDb> {
    // The real Postgres backend is never the fake implementation.
    Option::None
}
844}
845
/// Declares `$name`, a newtype over an `i32` database id.
///
/// The type is transparent for both sqlx and serde. It also gets `MAX` and conversions to and
/// from the `u64` protobuf representation.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // Not every id type uses every helper, hence the `allow(unused)`s.
        impl $name {
            /// Largest representable id (e.g. `MessageId::MAX` is the default paging cursor).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts from the protobuf `u64`; values above `i32::MAX` are truncated by `as`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts to the protobuf `u64` via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
877
878id_type!(UserId);
879#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
880pub struct User {
881 pub id: UserId,
882 pub github_login: String,
883 pub email_address: Option<String>,
884 pub admin: bool,
885 pub invite_code: Option<String>,
886 pub invite_count: i32,
887 pub connected_once: bool,
888}
889
890id_type!(OrgId);
891#[derive(FromRow)]
892pub struct Org {
893 pub id: OrgId,
894 pub name: String,
895 pub slug: String,
896}
897
898id_type!(ChannelId);
899#[derive(Clone, Debug, FromRow, Serialize)]
900pub struct Channel {
901 pub id: ChannelId,
902 pub name: String,
903 pub owner_id: i32,
904 pub owner_is_user: bool,
905}
906
907id_type!(MessageId);
908#[derive(Clone, Debug, FromRow)]
909pub struct ChannelMessage {
910 pub id: MessageId,
911 pub channel_id: ChannelId,
912 pub sender_id: UserId,
913 pub body: String,
914 pub sent_at: OffsetDateTime,
915 pub nonce: Uuid,
916}
917
918#[derive(Clone, Debug, PartialEq, Eq)]
919pub enum Contact {
920 Accepted {
921 user_id: UserId,
922 should_notify: bool,
923 },
924 Outgoing {
925 user_id: UserId,
926 },
927 Incoming {
928 user_id: UserId,
929 should_notify: bool,
930 },
931}
932
933impl Contact {
934 pub fn user_id(&self) -> UserId {
935 match self {
936 Contact::Accepted { user_id, .. } => *user_id,
937 Contact::Outgoing { user_id } => *user_id,
938 Contact::Incoming { user_id, .. } => *user_id,
939 }
940 }
941}
942
943#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
944pub struct IncomingContactRequest {
945 pub requester_id: UserId,
946 pub should_notify: bool,
947}
948
/// Builds a `LIKE`/`ILIKE` pattern that matches any string containing the alphanumeric
/// characters of `string` in order, with anything in between (e.g. `"clr"` → `"%c%l%r%"`).
///
/// Non-alphanumeric characters are dropped, which also keeps the wildcards `%` and `_` (and
/// `\`) out of the pattern.
fn fuzzy_like_string(string: &str) -> String {
    let pattern = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .fold(String::with_capacity(string.len() * 2 + 1), |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        });
    pattern + "%"
}
960
961#[cfg(test)]
962pub mod tests {
963 use super::*;
964 use anyhow::anyhow;
965 use collections::BTreeMap;
966 use gpui::executor::Background;
967 use lazy_static::lazy_static;
968 use parking_lot::Mutex;
969 use rand::prelude::*;
970 use sqlx::{
971 migrate::{MigrateDatabase, Migrator},
972 Postgres,
973 };
974 use std::{path::Path, sync::Arc};
975 use util::post_inc;
976
977 #[tokio::test(flavor = "multi_thread")]
978 async fn test_get_users_by_ids() {
979 for test_db in [
980 TestDb::postgres().await,
981 TestDb::fake(Arc::new(gpui::executor::Background::new())),
982 ] {
983 let db = test_db.db();
984
985 let user = db.create_user("user", None, false).await.unwrap();
986 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
987 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
988 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
989
990 assert_eq!(
991 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
992 .await
993 .unwrap(),
994 vec![
995 User {
996 id: user,
997 github_login: "user".to_string(),
998 admin: false,
999 ..Default::default()
1000 },
1001 User {
1002 id: friend1,
1003 github_login: "friend-1".to_string(),
1004 admin: false,
1005 ..Default::default()
1006 },
1007 User {
1008 id: friend2,
1009 github_login: "friend-2".to_string(),
1010 admin: false,
1011 ..Default::default()
1012 },
1013 User {
1014 id: friend3,
1015 github_login: "friend-3".to_string(),
1016 admin: false,
1017 ..Default::default()
1018 }
1019 ]
1020 );
1021 }
1022 }
1023
1024 #[tokio::test(flavor = "multi_thread")]
1025 async fn test_recent_channel_messages() {
1026 for test_db in [
1027 TestDb::postgres().await,
1028 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1029 ] {
1030 let db = test_db.db();
1031 let user = db.create_user("user", None, false).await.unwrap();
1032 let org = db.create_org("org", "org").await.unwrap();
1033 let channel = db.create_org_channel(org, "channel").await.unwrap();
1034 for i in 0..10 {
1035 db.create_channel_message(
1036 channel,
1037 user,
1038 &i.to_string(),
1039 OffsetDateTime::now_utc(),
1040 i,
1041 )
1042 .await
1043 .unwrap();
1044 }
1045
1046 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1047 assert_eq!(
1048 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1049 ["5", "6", "7", "8", "9"]
1050 );
1051
1052 let prev_messages = db
1053 .get_channel_messages(channel, 4, Some(messages[0].id))
1054 .await
1055 .unwrap();
1056 assert_eq!(
1057 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1058 ["1", "2", "3", "4"]
1059 );
1060 }
1061 }
1062
1063 #[tokio::test(flavor = "multi_thread")]
1064 async fn test_channel_message_nonces() {
1065 for test_db in [
1066 TestDb::postgres().await,
1067 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1068 ] {
1069 let db = test_db.db();
1070 let user = db.create_user("user", None, false).await.unwrap();
1071 let org = db.create_org("org", "org").await.unwrap();
1072 let channel = db.create_org_channel(org, "channel").await.unwrap();
1073
1074 let msg1_id = db
1075 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1076 .await
1077 .unwrap();
1078 let msg2_id = db
1079 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1080 .await
1081 .unwrap();
1082 let msg3_id = db
1083 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1084 .await
1085 .unwrap();
1086 let msg4_id = db
1087 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1088 .await
1089 .unwrap();
1090
1091 assert_ne!(msg1_id, msg2_id);
1092 assert_eq!(msg1_id, msg3_id);
1093 assert_eq!(msg2_id, msg4_id);
1094 }
1095 }
1096
1097 #[tokio::test(flavor = "multi_thread")]
1098 async fn test_create_access_tokens() {
1099 let test_db = TestDb::postgres().await;
1100 let db = test_db.db();
1101 let user = db.create_user("the-user", None, false).await.unwrap();
1102
1103 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1104 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1105 assert_eq!(
1106 db.get_access_token_hashes(user).await.unwrap(),
1107 &["h2".to_string(), "h1".to_string()]
1108 );
1109
1110 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1111 assert_eq!(
1112 db.get_access_token_hashes(user).await.unwrap(),
1113 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1114 );
1115
1116 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1117 assert_eq!(
1118 db.get_access_token_hashes(user).await.unwrap(),
1119 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1120 );
1121
1122 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1123 assert_eq!(
1124 db.get_access_token_hashes(user).await.unwrap(),
1125 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1126 );
1127 }
1128
#[test]
fn test_fuzzy_like_string() {
    let cases = [("abcd", "%a%b%c%d%"), ("x y", "%x%y%"), (" z ", "%z%")];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
1135
1136 #[tokio::test(flavor = "multi_thread")]
1137 async fn test_fuzzy_search_users() {
1138 let test_db = TestDb::postgres().await;
1139 let db = test_db.db();
1140 for github_login in [
1141 "California",
1142 "colorado",
1143 "oregon",
1144 "washington",
1145 "florida",
1146 "delaware",
1147 "rhode-island",
1148 ] {
1149 db.create_user(github_login, None, false).await.unwrap();
1150 }
1151
1152 assert_eq!(
1153 fuzzy_search_user_names(db, "clr").await,
1154 &["colorado", "California"]
1155 );
1156 assert_eq!(
1157 fuzzy_search_user_names(db, "ro").await,
1158 &["rhode-island", "colorado", "oregon"],
1159 );
1160
1161 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1162 db.fuzzy_search_users(query, 10)
1163 .await
1164 .unwrap()
1165 .into_iter()
1166 .map(|user| user.github_login)
1167 .collect::<Vec<_>>()
1168 }
1169 }
1170
1171 #[tokio::test(flavor = "multi_thread")]
1172 async fn test_add_contacts() {
1173 for test_db in [
1174 TestDb::postgres().await,
1175 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1176 ] {
1177 let db = test_db.db();
1178
1179 let user_1 = db.create_user("user1", None, false).await.unwrap();
1180 let user_2 = db.create_user("user2", None, false).await.unwrap();
1181 let user_3 = db.create_user("user3", None, false).await.unwrap();
1182
1183 // User starts with no contacts
1184 assert_eq!(
1185 db.get_contacts(user_1).await.unwrap(),
1186 vec![Contact::Accepted {
1187 user_id: user_1,
1188 should_notify: false
1189 }],
1190 );
1191
1192 // User requests a contact. Both users see the pending request.
1193 db.send_contact_request(user_1, user_2).await.unwrap();
1194 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1195 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1196 assert_eq!(
1197 db.get_contacts(user_1).await.unwrap(),
1198 &[
1199 Contact::Accepted {
1200 user_id: user_1,
1201 should_notify: false
1202 },
1203 Contact::Outgoing { user_id: user_2 }
1204 ],
1205 );
1206 assert_eq!(
1207 db.get_contacts(user_2).await.unwrap(),
1208 &[
1209 Contact::Incoming {
1210 user_id: user_1,
1211 should_notify: true
1212 },
1213 Contact::Accepted {
1214 user_id: user_2,
1215 should_notify: false
1216 },
1217 ]
1218 );
1219
1220 // User 2 dismisses the contact request notification without accepting or rejecting.
1221 // We shouldn't notify them again.
1222 db.dismiss_contact_notification(user_1, user_2)
1223 .await
1224 .unwrap_err();
1225 db.dismiss_contact_notification(user_2, user_1)
1226 .await
1227 .unwrap();
1228 assert_eq!(
1229 db.get_contacts(user_2).await.unwrap(),
1230 &[
1231 Contact::Incoming {
1232 user_id: user_1,
1233 should_notify: false
1234 },
1235 Contact::Accepted {
1236 user_id: user_2,
1237 should_notify: false
1238 },
1239 ]
1240 );
1241
1242 // User can't accept their own contact request
1243 db.respond_to_contact_request(user_1, user_2, true)
1244 .await
1245 .unwrap_err();
1246
1247 // User accepts a contact request. Both users see the contact.
1248 db.respond_to_contact_request(user_2, user_1, true)
1249 .await
1250 .unwrap();
1251 assert_eq!(
1252 db.get_contacts(user_1).await.unwrap(),
1253 &[
1254 Contact::Accepted {
1255 user_id: user_1,
1256 should_notify: false
1257 },
1258 Contact::Accepted {
1259 user_id: user_2,
1260 should_notify: true
1261 }
1262 ],
1263 );
1264 assert!(db.has_contact(user_1, user_2).await.unwrap());
1265 assert!(db.has_contact(user_2, user_1).await.unwrap());
1266 assert_eq!(
1267 db.get_contacts(user_2).await.unwrap(),
1268 &[
1269 Contact::Accepted {
1270 user_id: user_1,
1271 should_notify: false,
1272 },
1273 Contact::Accepted {
1274 user_id: user_2,
1275 should_notify: false,
1276 },
1277 ]
1278 );
1279
1280 // Users cannot re-request existing contacts.
1281 db.send_contact_request(user_1, user_2).await.unwrap_err();
1282 db.send_contact_request(user_2, user_1).await.unwrap_err();
1283
1284 // Users can't dismiss notifications of them accepting other users' requests.
1285 db.dismiss_contact_notification(user_2, user_1)
1286 .await
1287 .unwrap_err();
1288 assert_eq!(
1289 db.get_contacts(user_1).await.unwrap(),
1290 &[
1291 Contact::Accepted {
1292 user_id: user_1,
1293 should_notify: false
1294 },
1295 Contact::Accepted {
1296 user_id: user_2,
1297 should_notify: true,
1298 },
1299 ]
1300 );
1301
1302 // Users can dismiss notifications of other users accepting their requests.
1303 db.dismiss_contact_notification(user_1, user_2)
1304 .await
1305 .unwrap();
1306 assert_eq!(
1307 db.get_contacts(user_1).await.unwrap(),
1308 &[
1309 Contact::Accepted {
1310 user_id: user_1,
1311 should_notify: false
1312 },
1313 Contact::Accepted {
1314 user_id: user_2,
1315 should_notify: false,
1316 },
1317 ]
1318 );
1319
1320 // Users send each other concurrent contact requests and
1321 // see that they are immediately accepted.
1322 db.send_contact_request(user_1, user_3).await.unwrap();
1323 db.send_contact_request(user_3, user_1).await.unwrap();
1324 assert_eq!(
1325 db.get_contacts(user_1).await.unwrap(),
1326 &[
1327 Contact::Accepted {
1328 user_id: user_1,
1329 should_notify: false
1330 },
1331 Contact::Accepted {
1332 user_id: user_2,
1333 should_notify: false,
1334 },
1335 Contact::Accepted {
1336 user_id: user_3,
1337 should_notify: false
1338 },
1339 ]
1340 );
1341 assert_eq!(
1342 db.get_contacts(user_3).await.unwrap(),
1343 &[
1344 Contact::Accepted {
1345 user_id: user_1,
1346 should_notify: false
1347 },
1348 Contact::Accepted {
1349 user_id: user_3,
1350 should_notify: false
1351 }
1352 ],
1353 );
1354
1355 // User declines a contact request. Both users see that it is gone.
1356 db.send_contact_request(user_2, user_3).await.unwrap();
1357 db.respond_to_contact_request(user_3, user_2, false)
1358 .await
1359 .unwrap();
1360 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1361 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1362 assert_eq!(
1363 db.get_contacts(user_2).await.unwrap(),
1364 &[
1365 Contact::Accepted {
1366 user_id: user_1,
1367 should_notify: false
1368 },
1369 Contact::Accepted {
1370 user_id: user_2,
1371 should_notify: false
1372 }
1373 ]
1374 );
1375 assert_eq!(
1376 db.get_contacts(user_3).await.unwrap(),
1377 &[
1378 Contact::Accepted {
1379 user_id: user_1,
1380 should_notify: false
1381 },
1382 Contact::Accepted {
1383 user_id: user_3,
1384 should_notify: false
1385 }
1386 ],
1387 );
1388 }
1389 }
1390
    /// Exercises the invite-code lifecycle against a Postgres-backed `TestDb`:
    /// assigning a code via `set_invite_count`, redeeming it until its count is
    /// exhausted, raising the count again (which keeps the same code), and
    /// rejecting redemption for a login that already belongs to a user.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        // Each redemption decrements the inviter's remaining count.
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        // The inviter (user 1) sees the new contact with should_notify: true...
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        // ...while the invitee (user 2) sees it with should_notify: false.
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        // The failed redemption must not consume a use of the code.
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1543
1544 pub struct TestDb {
1545 pub db: Option<Arc<dyn Db>>,
1546 pub url: String,
1547 }
1548
1549 impl TestDb {
1550 pub async fn postgres() -> Self {
1551 lazy_static! {
1552 static ref LOCK: Mutex<()> = Mutex::new(());
1553 }
1554
1555 let _guard = LOCK.lock();
1556 let mut rng = StdRng::from_entropy();
1557 let name = format!("zed-test-{}", rng.gen::<u128>());
1558 let url = format!("postgres://postgres@localhost/{}", name);
1559 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1560 Postgres::create_database(&url)
1561 .await
1562 .expect("failed to create test db");
1563 let db = PostgresDb::new(&url, 5).await.unwrap();
1564 let migrator = Migrator::new(migrations_path).await.unwrap();
1565 migrator.run(&db.pool).await.unwrap();
1566 Self {
1567 db: Some(Arc::new(db)),
1568 url,
1569 }
1570 }
1571
1572 pub fn fake(background: Arc<Background>) -> Self {
1573 Self {
1574 db: Some(Arc::new(FakeDb::new(background))),
1575 url: Default::default(),
1576 }
1577 }
1578
1579 pub fn db(&self) -> &Arc<dyn Db> {
1580 self.db.as_ref().unwrap()
1581 }
1582 }
1583
1584 impl Drop for TestDb {
1585 fn drop(&mut self) {
1586 if let Some(db) = self.db.take() {
1587 futures::executor::block_on(db.teardown(&self.url));
1588 }
1589 }
1590 }
1591
1592 pub struct FakeDb {
1593 background: Arc<Background>,
1594 pub users: Mutex<BTreeMap<UserId, User>>,
1595 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1596 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1597 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1598 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1599 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1600 pub contacts: Mutex<Vec<FakeContact>>,
1601 next_channel_message_id: Mutex<i32>,
1602 next_user_id: Mutex<i32>,
1603 next_org_id: Mutex<i32>,
1604 next_channel_id: Mutex<i32>,
1605 }
1606
1607 #[derive(Debug)]
1608 pub struct FakeContact {
1609 pub requester_id: UserId,
1610 pub responder_id: UserId,
1611 pub accepted: bool,
1612 pub should_notify: bool,
1613 }
1614
1615 impl FakeDb {
1616 pub fn new(background: Arc<Background>) -> Self {
1617 Self {
1618 background,
1619 users: Default::default(),
1620 next_user_id: Mutex::new(1),
1621 orgs: Default::default(),
1622 next_org_id: Mutex::new(1),
1623 org_memberships: Default::default(),
1624 channels: Default::default(),
1625 next_channel_id: Mutex::new(1),
1626 channel_memberships: Default::default(),
1627 channel_messages: Default::default(),
1628 next_channel_message_id: Mutex::new(1),
1629 contacts: Default::default(),
1630 }
1631 }
1632 }
1633
1634 #[async_trait]
1635 impl Db for FakeDb {
1636 async fn create_user(
1637 &self,
1638 github_login: &str,
1639 email_address: Option<&str>,
1640 admin: bool,
1641 ) -> Result<UserId> {
1642 self.background.simulate_random_delay().await;
1643
1644 let mut users = self.users.lock();
1645 if let Some(user) = users
1646 .values()
1647 .find(|user| user.github_login == github_login)
1648 {
1649 Ok(user.id)
1650 } else {
1651 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1652 users.insert(
1653 user_id,
1654 User {
1655 id: user_id,
1656 github_login: github_login.to_string(),
1657 email_address: email_address.map(str::to_string),
1658 admin,
1659 invite_code: None,
1660 invite_count: 0,
1661 connected_once: false,
1662 },
1663 );
1664 Ok(user_id)
1665 }
1666 }
1667
1668 async fn get_all_users(&self) -> Result<Vec<User>> {
1669 unimplemented!()
1670 }
1671
1672 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1673 unimplemented!()
1674 }
1675
1676 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1677 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1678 }
1679
1680 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1681 self.background.simulate_random_delay().await;
1682 let users = self.users.lock();
1683 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1684 }
1685
1686 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1687 Ok(self
1688 .users
1689 .lock()
1690 .values()
1691 .find(|user| user.github_login == github_login)
1692 .cloned())
1693 }
1694
1695 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1696 unimplemented!()
1697 }
1698
1699 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1700 self.background.simulate_random_delay().await;
1701 let mut users = self.users.lock();
1702 let mut user = users
1703 .get_mut(&id)
1704 .ok_or_else(|| anyhow!("user not found"))?;
1705 user.connected_once = connected_once;
1706 Ok(())
1707 }
1708
1709 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1710 unimplemented!()
1711 }
1712
1713 // invite codes
1714
1715 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1716 unimplemented!()
1717 }
1718
1719 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1720 Ok(None)
1721 }
1722
1723 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1724 unimplemented!()
1725 }
1726
1727 async fn redeem_invite_code(
1728 &self,
1729 _code: &str,
1730 _login: &str,
1731 _email_address: Option<&str>,
1732 ) -> Result<UserId> {
1733 unimplemented!()
1734 }
1735
1736 // contacts
1737
1738 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1739 self.background.simulate_random_delay().await;
1740 let mut contacts = vec![Contact::Accepted {
1741 user_id: id,
1742 should_notify: false,
1743 }];
1744
1745 for contact in self.contacts.lock().iter() {
1746 if contact.requester_id == id {
1747 if contact.accepted {
1748 contacts.push(Contact::Accepted {
1749 user_id: contact.responder_id,
1750 should_notify: contact.should_notify,
1751 });
1752 } else {
1753 contacts.push(Contact::Outgoing {
1754 user_id: contact.responder_id,
1755 });
1756 }
1757 } else if contact.responder_id == id {
1758 if contact.accepted {
1759 contacts.push(Contact::Accepted {
1760 user_id: contact.requester_id,
1761 should_notify: false,
1762 });
1763 } else {
1764 contacts.push(Contact::Incoming {
1765 user_id: contact.requester_id,
1766 should_notify: contact.should_notify,
1767 });
1768 }
1769 }
1770 }
1771
1772 contacts.sort_unstable_by_key(|contact| contact.user_id());
1773 Ok(contacts)
1774 }
1775
1776 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1777 self.background.simulate_random_delay().await;
1778 Ok(self.contacts.lock().iter().any(|contact| {
1779 contact.accepted
1780 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1781 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1782 }))
1783 }
1784
1785 async fn send_contact_request(
1786 &self,
1787 requester_id: UserId,
1788 responder_id: UserId,
1789 ) -> Result<()> {
1790 let mut contacts = self.contacts.lock();
1791 for contact in contacts.iter_mut() {
1792 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1793 if contact.accepted {
1794 Err(anyhow!("contact already exists"))?;
1795 } else {
1796 Err(anyhow!("contact already requested"))?;
1797 }
1798 }
1799 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1800 if contact.accepted {
1801 Err(anyhow!("contact already exists"))?;
1802 } else {
1803 contact.accepted = true;
1804 contact.should_notify = false;
1805 return Ok(());
1806 }
1807 }
1808 }
1809 contacts.push(FakeContact {
1810 requester_id,
1811 responder_id,
1812 accepted: false,
1813 should_notify: true,
1814 });
1815 Ok(())
1816 }
1817
1818 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1819 self.contacts.lock().retain(|contact| {
1820 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1821 });
1822 Ok(())
1823 }
1824
1825 async fn dismiss_contact_notification(
1826 &self,
1827 user_id: UserId,
1828 contact_user_id: UserId,
1829 ) -> Result<()> {
1830 let mut contacts = self.contacts.lock();
1831 for contact in contacts.iter_mut() {
1832 if contact.requester_id == contact_user_id
1833 && contact.responder_id == user_id
1834 && !contact.accepted
1835 {
1836 contact.should_notify = false;
1837 return Ok(());
1838 }
1839 if contact.requester_id == user_id
1840 && contact.responder_id == contact_user_id
1841 && contact.accepted
1842 {
1843 contact.should_notify = false;
1844 return Ok(());
1845 }
1846 }
1847 Err(anyhow!("no such notification"))?
1848 }
1849
1850 async fn respond_to_contact_request(
1851 &self,
1852 responder_id: UserId,
1853 requester_id: UserId,
1854 accept: bool,
1855 ) -> Result<()> {
1856 let mut contacts = self.contacts.lock();
1857 for (ix, contact) in contacts.iter_mut().enumerate() {
1858 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1859 if contact.accepted {
1860 Err(anyhow!("contact already confirmed"))?;
1861 }
1862 if accept {
1863 contact.accepted = true;
1864 contact.should_notify = true;
1865 } else {
1866 contacts.remove(ix);
1867 }
1868 return Ok(());
1869 }
1870 }
1871 Err(anyhow!("no such contact request"))?
1872 }
1873
1874 async fn create_access_token_hash(
1875 &self,
1876 _user_id: UserId,
1877 _access_token_hash: &str,
1878 _max_access_token_count: usize,
1879 ) -> Result<()> {
1880 unimplemented!()
1881 }
1882
1883 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1884 unimplemented!()
1885 }
1886
1887 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1888 unimplemented!()
1889 }
1890
1891 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1892 self.background.simulate_random_delay().await;
1893 let mut orgs = self.orgs.lock();
1894 if orgs.values().any(|org| org.slug == slug) {
1895 Err(anyhow!("org already exists"))?
1896 } else {
1897 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1898 orgs.insert(
1899 org_id,
1900 Org {
1901 id: org_id,
1902 name: name.to_string(),
1903 slug: slug.to_string(),
1904 },
1905 );
1906 Ok(org_id)
1907 }
1908 }
1909
1910 async fn add_org_member(
1911 &self,
1912 org_id: OrgId,
1913 user_id: UserId,
1914 is_admin: bool,
1915 ) -> Result<()> {
1916 self.background.simulate_random_delay().await;
1917 if !self.orgs.lock().contains_key(&org_id) {
1918 Err(anyhow!("org does not exist"))?;
1919 }
1920 if !self.users.lock().contains_key(&user_id) {
1921 Err(anyhow!("user does not exist"))?;
1922 }
1923
1924 self.org_memberships
1925 .lock()
1926 .entry((org_id, user_id))
1927 .or_insert(is_admin);
1928 Ok(())
1929 }
1930
1931 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1932 self.background.simulate_random_delay().await;
1933 if !self.orgs.lock().contains_key(&org_id) {
1934 Err(anyhow!("org does not exist"))?;
1935 }
1936
1937 let mut channels = self.channels.lock();
1938 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1939 channels.insert(
1940 channel_id,
1941 Channel {
1942 id: channel_id,
1943 name: name.to_string(),
1944 owner_id: org_id.0,
1945 owner_is_user: false,
1946 },
1947 );
1948 Ok(channel_id)
1949 }
1950
1951 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1952 self.background.simulate_random_delay().await;
1953 Ok(self
1954 .channels
1955 .lock()
1956 .values()
1957 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1958 .cloned()
1959 .collect())
1960 }
1961
1962 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1963 self.background.simulate_random_delay().await;
1964 let channels = self.channels.lock();
1965 let memberships = self.channel_memberships.lock();
1966 Ok(channels
1967 .values()
1968 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1969 .cloned()
1970 .collect())
1971 }
1972
1973 async fn can_user_access_channel(
1974 &self,
1975 user_id: UserId,
1976 channel_id: ChannelId,
1977 ) -> Result<bool> {
1978 self.background.simulate_random_delay().await;
1979 Ok(self
1980 .channel_memberships
1981 .lock()
1982 .contains_key(&(channel_id, user_id)))
1983 }
1984
1985 async fn add_channel_member(
1986 &self,
1987 channel_id: ChannelId,
1988 user_id: UserId,
1989 is_admin: bool,
1990 ) -> Result<()> {
1991 self.background.simulate_random_delay().await;
1992 if !self.channels.lock().contains_key(&channel_id) {
1993 Err(anyhow!("channel does not exist"))?;
1994 }
1995 if !self.users.lock().contains_key(&user_id) {
1996 Err(anyhow!("user does not exist"))?;
1997 }
1998
1999 self.channel_memberships
2000 .lock()
2001 .entry((channel_id, user_id))
2002 .or_insert(is_admin);
2003 Ok(())
2004 }
2005
2006 async fn create_channel_message(
2007 &self,
2008 channel_id: ChannelId,
2009 sender_id: UserId,
2010 body: &str,
2011 timestamp: OffsetDateTime,
2012 nonce: u128,
2013 ) -> Result<MessageId> {
2014 self.background.simulate_random_delay().await;
2015 if !self.channels.lock().contains_key(&channel_id) {
2016 Err(anyhow!("channel does not exist"))?;
2017 }
2018 if !self.users.lock().contains_key(&sender_id) {
2019 Err(anyhow!("user does not exist"))?;
2020 }
2021
2022 let mut messages = self.channel_messages.lock();
2023 if let Some(message) = messages
2024 .values()
2025 .find(|message| message.nonce.as_u128() == nonce)
2026 {
2027 Ok(message.id)
2028 } else {
2029 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2030 messages.insert(
2031 message_id,
2032 ChannelMessage {
2033 id: message_id,
2034 channel_id,
2035 sender_id,
2036 body: body.to_string(),
2037 sent_at: timestamp,
2038 nonce: Uuid::from_u128(nonce),
2039 },
2040 );
2041 Ok(message_id)
2042 }
2043 }
2044
2045 async fn get_channel_messages(
2046 &self,
2047 channel_id: ChannelId,
2048 count: usize,
2049 before_id: Option<MessageId>,
2050 ) -> Result<Vec<ChannelMessage>> {
2051 let mut messages = self
2052 .channel_messages
2053 .lock()
2054 .values()
2055 .rev()
2056 .filter(|message| {
2057 message.channel_id == channel_id
2058 && message.id < before_id.unwrap_or(MessageId::MAX)
2059 })
2060 .take(count)
2061 .cloned()
2062 .collect::<Vec<_>>();
2063 messages.sort_unstable_by_key(|message| message.id);
2064 Ok(messages)
2065 }
2066
2067 async fn teardown(&self, _: &str) {}
2068
2069 #[cfg(test)]
2070 fn as_fake(&self) -> Option<&FakeDb> {
2071 Some(self)
2072 }
2073 }
2074}