1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow};
10use time::OffsetDateTime;
11
12#[async_trait]
13pub trait Db: Send + Sync {
14 async fn create_user(
15 &self,
16 github_login: &str,
17 email_address: Option<&str>,
18 admin: bool,
19 ) -> Result<UserId>;
20 async fn get_all_users(&self) -> Result<Vec<User>>;
21 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
22 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
23 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
24 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
25 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
26 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
27 async fn destroy_user(&self, id: UserId) -> Result<()>;
28
29 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
30 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
31 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
32 async fn redeem_invite_code(
33 &self,
34 code: &str,
35 login: &str,
36 email_address: Option<&str>,
37 ) -> Result<UserId>;
38
39 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
40 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
41 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
42 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
43 async fn dismiss_contact_notification(
44 &self,
45 responder_id: UserId,
46 requester_id: UserId,
47 ) -> Result<()>;
48 async fn respond_to_contact_request(
49 &self,
50 responder_id: UserId,
51 requester_id: UserId,
52 accept: bool,
53 ) -> Result<()>;
54
55 async fn create_access_token_hash(
56 &self,
57 user_id: UserId,
58 access_token_hash: &str,
59 max_access_token_count: usize,
60 ) -> Result<()>;
61 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
62 #[cfg(any(test, feature = "seed-support"))]
63
64 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
65 #[cfg(any(test, feature = "seed-support"))]
66 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
67 #[cfg(any(test, feature = "seed-support"))]
68 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
69 #[cfg(any(test, feature = "seed-support"))]
70 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
71 #[cfg(any(test, feature = "seed-support"))]
72
73 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
74 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
75 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
76 -> Result<bool>;
77 #[cfg(any(test, feature = "seed-support"))]
78 async fn add_channel_member(
79 &self,
80 channel_id: ChannelId,
81 user_id: UserId,
82 is_admin: bool,
83 ) -> Result<()>;
84 async fn create_channel_message(
85 &self,
86 channel_id: ChannelId,
87 sender_id: UserId,
88 body: &str,
89 timestamp: OffsetDateTime,
90 nonce: u128,
91 ) -> Result<MessageId>;
92 async fn get_channel_messages(
93 &self,
94 channel_id: ChannelId,
95 count: usize,
96 before_id: Option<MessageId>,
97 ) -> Result<Vec<ChannelMessage>>;
98 #[cfg(test)]
99 async fn teardown(&self, url: &str);
100 #[cfg(test)]
101 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
102}
103
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // Every query in the `Db` impl runs on this pool, some inside transactions begun from it.
    pool: sqlx::PgPool,
}
107
108impl PostgresDb {
109 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
110 let pool = DbOptions::new()
111 .max_connections(max_connections)
112 .connect(&url)
113 .await
114 .context("failed to connect to postgres database")?;
115 Ok(Self { pool })
116 }
117}
118
119#[async_trait]
120impl Db for PostgresDb {
121 // users
122
123 async fn create_user(
124 &self,
125 github_login: &str,
126 email_address: Option<&str>,
127 admin: bool,
128 ) -> Result<UserId> {
129 let query = "
130 INSERT INTO users (github_login, email_address, admin)
131 VALUES ($1, $2, $3)
132 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
133 RETURNING id
134 ";
135 Ok(sqlx::query_scalar(query)
136 .bind(github_login)
137 .bind(email_address)
138 .bind(admin)
139 .fetch_one(&self.pool)
140 .await
141 .map(UserId)?)
142 }
143
144 async fn get_all_users(&self) -> Result<Vec<User>> {
145 let query = "SELECT * FROM users ORDER BY github_login ASC";
146 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
147 }
148
149 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
150 let like_string = fuzzy_like_string(name_query);
151 let query = "
152 SELECT users.*
153 FROM users
154 WHERE github_login ILIKE $1
155 ORDER BY github_login <-> $2
156 LIMIT $3
157 ";
158 Ok(sqlx::query_as(query)
159 .bind(like_string)
160 .bind(name_query)
161 .bind(limit)
162 .fetch_all(&self.pool)
163 .await?)
164 }
165
166 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
167 let users = self.get_users_by_ids(vec![id]).await?;
168 Ok(users.into_iter().next())
169 }
170
171 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
172 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
173 let query = "
174 SELECT users.*
175 FROM users
176 WHERE users.id = ANY ($1)
177 ";
178 Ok(sqlx::query_as(query)
179 .bind(&ids)
180 .fetch_all(&self.pool)
181 .await?)
182 }
183
184 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
185 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
186 Ok(sqlx::query_as(query)
187 .bind(github_login)
188 .fetch_optional(&self.pool)
189 .await?)
190 }
191
192 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
193 let query = "UPDATE users SET admin = $1 WHERE id = $2";
194 Ok(sqlx::query(query)
195 .bind(is_admin)
196 .bind(id.0)
197 .execute(&self.pool)
198 .await
199 .map(drop)?)
200 }
201
202 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
203 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
204 Ok(sqlx::query(query)
205 .bind(connected_once)
206 .bind(id.0)
207 .execute(&self.pool)
208 .await
209 .map(drop)?)
210 }
211
212 async fn destroy_user(&self, id: UserId) -> Result<()> {
213 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
214 sqlx::query(query)
215 .bind(id.0)
216 .execute(&self.pool)
217 .await
218 .map(drop)?;
219 let query = "DELETE FROM users WHERE id = $1;";
220 Ok(sqlx::query(query)
221 .bind(id.0)
222 .execute(&self.pool)
223 .await
224 .map(drop)?)
225 }
226
227 // invite codes
228
229 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
230 let mut tx = self.pool.begin().await?;
231 sqlx::query(
232 "
233 UPDATE users
234 SET invite_code = $1
235 WHERE id = $2 AND invite_code IS NULL
236 ",
237 )
238 .bind(nanoid!(16))
239 .bind(id)
240 .execute(&mut tx)
241 .await?;
242 sqlx::query(
243 "
244 UPDATE users
245 SET invite_count = $1
246 WHERE id = $2
247 ",
248 )
249 .bind(count)
250 .bind(id)
251 .execute(&mut tx)
252 .await?;
253 tx.commit().await?;
254 Ok(())
255 }
256
257 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
258 let result: Option<(String, i32)> = sqlx::query_as(
259 "
260 SELECT invite_code, invite_count
261 FROM users
262 WHERE id = $1 AND invite_code IS NOT NULL
263 ",
264 )
265 .bind(id)
266 .fetch_optional(&self.pool)
267 .await?;
268 if let Some((code, count)) = result {
269 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
270 } else {
271 Ok(None)
272 }
273 }
274
275 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
276 sqlx::query_as(
277 "
278 SELECT *
279 FROM users
280 WHERE invite_code = $1
281 ",
282 )
283 .bind(code)
284 .fetch_optional(&self.pool)
285 .await?
286 .ok_or_else(|| {
287 Error::Http(
288 StatusCode::NOT_FOUND,
289 "that invite code does not exist".to_string(),
290 )
291 })
292 }
293
294 async fn redeem_invite_code(
295 &self,
296 code: &str,
297 login: &str,
298 email_address: Option<&str>,
299 ) -> Result<UserId> {
300 let mut tx = self.pool.begin().await?;
301
302 let inviter_id: Option<UserId> = sqlx::query_scalar(
303 "
304 UPDATE users
305 SET invite_count = invite_count - 1
306 WHERE
307 invite_code = $1 AND
308 invite_count > 0
309 RETURNING id
310 ",
311 )
312 .bind(code)
313 .fetch_optional(&mut tx)
314 .await?;
315
316 let inviter_id = match inviter_id {
317 Some(inviter_id) => inviter_id,
318 None => {
319 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
320 .bind(code)
321 .fetch_optional(&mut tx)
322 .await?
323 .is_some()
324 {
325 Err(Error::Http(
326 StatusCode::UNAUTHORIZED,
327 "no invites remaining".to_string(),
328 ))?
329 } else {
330 Err(Error::Http(
331 StatusCode::NOT_FOUND,
332 "invite code not found".to_string(),
333 ))?
334 }
335 }
336 };
337
338 let invitee_id = sqlx::query_scalar(
339 "
340 INSERT INTO users
341 (github_login, email_address, admin, inviter_id)
342 VALUES
343 ($1, $2, 'f', $3)
344 RETURNING id
345 ",
346 )
347 .bind(login)
348 .bind(email_address)
349 .bind(inviter_id)
350 .fetch_one(&mut tx)
351 .await
352 .map(UserId)?;
353
354 sqlx::query(
355 "
356 INSERT INTO contacts
357 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
358 VALUES
359 ($1, $2, 't', 't', 't')
360 ",
361 )
362 .bind(inviter_id)
363 .bind(invitee_id)
364 .execute(&mut tx)
365 .await?;
366
367 tx.commit().await?;
368 Ok(invitee_id)
369 }
370
371 // contacts
372
373 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
374 let query = "
375 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
376 FROM contacts
377 WHERE user_id_a = $1 OR user_id_b = $1;
378 ";
379
380 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
381 .bind(user_id)
382 .fetch(&self.pool);
383
384 let mut contacts = vec![Contact::Accepted {
385 user_id,
386 should_notify: false,
387 }];
388 while let Some(row) = rows.next().await {
389 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
390
391 if user_id_a == user_id {
392 if accepted {
393 contacts.push(Contact::Accepted {
394 user_id: user_id_b,
395 should_notify: should_notify && a_to_b,
396 });
397 } else if a_to_b {
398 contacts.push(Contact::Outgoing { user_id: user_id_b })
399 } else {
400 contacts.push(Contact::Incoming {
401 user_id: user_id_b,
402 should_notify,
403 });
404 }
405 } else {
406 if accepted {
407 contacts.push(Contact::Accepted {
408 user_id: user_id_a,
409 should_notify: should_notify && !a_to_b,
410 });
411 } else if a_to_b {
412 contacts.push(Contact::Incoming {
413 user_id: user_id_a,
414 should_notify,
415 });
416 } else {
417 contacts.push(Contact::Outgoing { user_id: user_id_a });
418 }
419 }
420 }
421
422 contacts.sort_unstable_by_key(|contact| contact.user_id());
423
424 Ok(contacts)
425 }
426
427 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
428 let (id_a, id_b) = if user_id_1 < user_id_2 {
429 (user_id_1, user_id_2)
430 } else {
431 (user_id_2, user_id_1)
432 };
433
434 let query = "
435 SELECT 1 FROM contacts
436 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
437 LIMIT 1
438 ";
439 Ok(sqlx::query_scalar::<_, i32>(query)
440 .bind(id_a.0)
441 .bind(id_b.0)
442 .fetch_optional(&self.pool)
443 .await?
444 .is_some())
445 }
446
447 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
448 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
449 (sender_id, receiver_id, true)
450 } else {
451 (receiver_id, sender_id, false)
452 };
453 let query = "
454 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
455 VALUES ($1, $2, $3, 'f', 't')
456 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
457 SET
458 accepted = 't',
459 should_notify = 'f'
460 WHERE
461 NOT contacts.accepted AND
462 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
463 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
464 ";
465 let result = sqlx::query(query)
466 .bind(id_a.0)
467 .bind(id_b.0)
468 .bind(a_to_b)
469 .execute(&self.pool)
470 .await?;
471
472 if result.rows_affected() == 1 {
473 Ok(())
474 } else {
475 Err(anyhow!("contact already requested"))?
476 }
477 }
478
479 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
480 let (id_a, id_b) = if responder_id < requester_id {
481 (responder_id, requester_id)
482 } else {
483 (requester_id, responder_id)
484 };
485 let query = "
486 DELETE FROM contacts
487 WHERE user_id_a = $1 AND user_id_b = $2;
488 ";
489 let result = sqlx::query(query)
490 .bind(id_a.0)
491 .bind(id_b.0)
492 .execute(&self.pool)
493 .await?;
494
495 if result.rows_affected() == 1 {
496 Ok(())
497 } else {
498 Err(anyhow!("no such contact"))?
499 }
500 }
501
502 async fn dismiss_contact_notification(
503 &self,
504 user_id: UserId,
505 contact_user_id: UserId,
506 ) -> Result<()> {
507 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
508 (user_id, contact_user_id, true)
509 } else {
510 (contact_user_id, user_id, false)
511 };
512
513 let query = "
514 UPDATE contacts
515 SET should_notify = 'f'
516 WHERE
517 user_id_a = $1 AND user_id_b = $2 AND
518 (
519 (a_to_b = $3 AND accepted) OR
520 (a_to_b != $3 AND NOT accepted)
521 );
522 ";
523
524 let result = sqlx::query(query)
525 .bind(id_a.0)
526 .bind(id_b.0)
527 .bind(a_to_b)
528 .execute(&self.pool)
529 .await?;
530
531 if result.rows_affected() == 0 {
532 Err(anyhow!("no such contact request"))?;
533 }
534
535 Ok(())
536 }
537
538 async fn respond_to_contact_request(
539 &self,
540 responder_id: UserId,
541 requester_id: UserId,
542 accept: bool,
543 ) -> Result<()> {
544 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
545 (responder_id, requester_id, false)
546 } else {
547 (requester_id, responder_id, true)
548 };
549 let result = if accept {
550 let query = "
551 UPDATE contacts
552 SET accepted = 't', should_notify = 't'
553 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
554 ";
555 sqlx::query(query)
556 .bind(id_a.0)
557 .bind(id_b.0)
558 .bind(a_to_b)
559 .execute(&self.pool)
560 .await?
561 } else {
562 let query = "
563 DELETE FROM contacts
564 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
565 ";
566 sqlx::query(query)
567 .bind(id_a.0)
568 .bind(id_b.0)
569 .bind(a_to_b)
570 .execute(&self.pool)
571 .await?
572 };
573 if result.rows_affected() == 1 {
574 Ok(())
575 } else {
576 Err(anyhow!("no such contact request"))?
577 }
578 }
579
580 // access tokens
581
582 async fn create_access_token_hash(
583 &self,
584 user_id: UserId,
585 access_token_hash: &str,
586 max_access_token_count: usize,
587 ) -> Result<()> {
588 let insert_query = "
589 INSERT INTO access_tokens (user_id, hash)
590 VALUES ($1, $2);
591 ";
592 let cleanup_query = "
593 DELETE FROM access_tokens
594 WHERE id IN (
595 SELECT id from access_tokens
596 WHERE user_id = $1
597 ORDER BY id DESC
598 OFFSET $3
599 )
600 ";
601
602 let mut tx = self.pool.begin().await?;
603 sqlx::query(insert_query)
604 .bind(user_id.0)
605 .bind(access_token_hash)
606 .execute(&mut tx)
607 .await?;
608 sqlx::query(cleanup_query)
609 .bind(user_id.0)
610 .bind(access_token_hash)
611 .bind(max_access_token_count as u32)
612 .execute(&mut tx)
613 .await?;
614 Ok(tx.commit().await?)
615 }
616
617 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
618 let query = "
619 SELECT hash
620 FROM access_tokens
621 WHERE user_id = $1
622 ORDER BY id DESC
623 ";
624 Ok(sqlx::query_scalar(query)
625 .bind(user_id.0)
626 .fetch_all(&self.pool)
627 .await?)
628 }
629
630 // orgs
631
632 #[allow(unused)] // Help rust-analyzer
633 #[cfg(any(test, feature = "seed-support"))]
634 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
635 let query = "
636 SELECT *
637 FROM orgs
638 WHERE slug = $1
639 ";
640 Ok(sqlx::query_as(query)
641 .bind(slug)
642 .fetch_optional(&self.pool)
643 .await?)
644 }
645
646 #[cfg(any(test, feature = "seed-support"))]
647 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
648 let query = "
649 INSERT INTO orgs (name, slug)
650 VALUES ($1, $2)
651 RETURNING id
652 ";
653 Ok(sqlx::query_scalar(query)
654 .bind(name)
655 .bind(slug)
656 .fetch_one(&self.pool)
657 .await
658 .map(OrgId)?)
659 }
660
661 #[cfg(any(test, feature = "seed-support"))]
662 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
663 let query = "
664 INSERT INTO org_memberships (org_id, user_id, admin)
665 VALUES ($1, $2, $3)
666 ON CONFLICT DO NOTHING
667 ";
668 Ok(sqlx::query(query)
669 .bind(org_id.0)
670 .bind(user_id.0)
671 .bind(is_admin)
672 .execute(&self.pool)
673 .await
674 .map(drop)?)
675 }
676
677 // channels
678
679 #[cfg(any(test, feature = "seed-support"))]
680 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
681 let query = "
682 INSERT INTO channels (owner_id, owner_is_user, name)
683 VALUES ($1, false, $2)
684 RETURNING id
685 ";
686 Ok(sqlx::query_scalar(query)
687 .bind(org_id.0)
688 .bind(name)
689 .fetch_one(&self.pool)
690 .await
691 .map(ChannelId)?)
692 }
693
694 #[allow(unused)] // Help rust-analyzer
695 #[cfg(any(test, feature = "seed-support"))]
696 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
697 let query = "
698 SELECT *
699 FROM channels
700 WHERE
701 channels.owner_is_user = false AND
702 channels.owner_id = $1
703 ";
704 Ok(sqlx::query_as(query)
705 .bind(org_id.0)
706 .fetch_all(&self.pool)
707 .await?)
708 }
709
710 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
711 let query = "
712 SELECT
713 channels.*
714 FROM
715 channel_memberships, channels
716 WHERE
717 channel_memberships.user_id = $1 AND
718 channel_memberships.channel_id = channels.id
719 ";
720 Ok(sqlx::query_as(query)
721 .bind(user_id.0)
722 .fetch_all(&self.pool)
723 .await?)
724 }
725
726 async fn can_user_access_channel(
727 &self,
728 user_id: UserId,
729 channel_id: ChannelId,
730 ) -> Result<bool> {
731 let query = "
732 SELECT id
733 FROM channel_memberships
734 WHERE user_id = $1 AND channel_id = $2
735 LIMIT 1
736 ";
737 Ok(sqlx::query_scalar::<_, i32>(query)
738 .bind(user_id.0)
739 .bind(channel_id.0)
740 .fetch_optional(&self.pool)
741 .await
742 .map(|e| e.is_some())?)
743 }
744
745 #[cfg(any(test, feature = "seed-support"))]
746 async fn add_channel_member(
747 &self,
748 channel_id: ChannelId,
749 user_id: UserId,
750 is_admin: bool,
751 ) -> Result<()> {
752 let query = "
753 INSERT INTO channel_memberships (channel_id, user_id, admin)
754 VALUES ($1, $2, $3)
755 ON CONFLICT DO NOTHING
756 ";
757 Ok(sqlx::query(query)
758 .bind(channel_id.0)
759 .bind(user_id.0)
760 .bind(is_admin)
761 .execute(&self.pool)
762 .await
763 .map(drop)?)
764 }
765
766 // messages
767
768 async fn create_channel_message(
769 &self,
770 channel_id: ChannelId,
771 sender_id: UserId,
772 body: &str,
773 timestamp: OffsetDateTime,
774 nonce: u128,
775 ) -> Result<MessageId> {
776 let query = "
777 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
778 VALUES ($1, $2, $3, $4, $5)
779 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
780 RETURNING id
781 ";
782 Ok(sqlx::query_scalar(query)
783 .bind(channel_id.0)
784 .bind(sender_id.0)
785 .bind(body)
786 .bind(timestamp)
787 .bind(Uuid::from_u128(nonce))
788 .fetch_one(&self.pool)
789 .await
790 .map(MessageId)?)
791 }
792
793 async fn get_channel_messages(
794 &self,
795 channel_id: ChannelId,
796 count: usize,
797 before_id: Option<MessageId>,
798 ) -> Result<Vec<ChannelMessage>> {
799 let query = r#"
800 SELECT * FROM (
801 SELECT
802 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
803 FROM
804 channel_messages
805 WHERE
806 channel_id = $1 AND
807 id < $2
808 ORDER BY id DESC
809 LIMIT $3
810 ) as recent_messages
811 ORDER BY id ASC
812 "#;
813 Ok(sqlx::query_as(query)
814 .bind(channel_id.0)
815 .bind(before_id.unwrap_or(MessageId::MAX))
816 .bind(count as i64)
817 .fetch_all(&self.pool)
818 .await?)
819 }
820
821 #[cfg(test)]
822 async fn teardown(&self, url: &str) {
823 use util::ResultExt;
824
825 let query = "
826 SELECT pg_terminate_backend(pg_stat_activity.pid)
827 FROM pg_stat_activity
828 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
829 ";
830 sqlx::query(query).execute(&self.pool).await.log_err();
831 self.pool.close().await;
832 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
833 .await
834 .log_err();
835 }
836
837 #[cfg(test)]
838 fn as_fake(&self) -> Option<&tests::FakeDb> {
839 None
840 }
841}
842
/// Declares a newtype wrapper around an `i32` database id.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is
/// stored and serialized exactly like the inner `i32`. It also provides:
/// - `MAX`: the wrapper around `i32::MAX` (e.g. `get_channel_messages` uses
///   `MessageId::MAX` as an open upper bound);
/// - `from_proto` / `to_proto`: plain `as` conversions to and from `u64` (these
///   truncate or sign-extend rather than fail);
/// - a `Display` impl that forwards to the inner `i32`.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The id wrapping `i32::MAX`.
            #[allow(unused)]
            pub const MAX: Self = $id(i32::MAX);

            /// Builds an id from a `u64` using an `as` cast to `i32`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $id(raw)
            }

            /// Converts the id to a `u64` using an `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $id(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $id {
            fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, formatter)
            }
        }
    };
}
874
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code other people can redeem to sign up; `None` until one is assigned
    /// (`PostgresDb::set_invite_count` generates it).
    pub invite_code: Option<String>,
    /// Remaining invites. Decoded as a signed `i32`; `get_invite_code_for_user`
    /// converts it to `u32`.
    pub invite_count: i32,
    pub connected_once: bool,
}
886
887id_type!(OrgId);
888#[derive(FromRow)]
889pub struct Org {
890 pub id: OrgId,
891 pub name: String,
892 pub slug: String,
893}
894
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner. When `owner_is_user` is false this is an `OrgId`
    /// (see `create_org_channel`); otherwise presumably a `UserId` — confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
903
id_type!(MessageId);
/// A row of the `channel_messages` table, as selected by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected with `AT TIME ZONE 'UTC'` by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Deduplication key supplied when the message was created.
    pub nonce: Uuid,
}
914
/// One entry of a user's contact list, seen from that user's side (as built by
/// `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The users are contacts. `should_notify` is only set for the user who sent the
    /// original request, until they dismiss the "accepted" notification.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent `user_id` a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending; `should_notify` is
    /// cleared once this user dismisses it.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
929
930impl Contact {
931 pub fn user_id(&self) -> UserId {
932 match self {
933 Contact::Accepted { user_id, .. } => *user_id,
934 Contact::Outgoing { user_id } => *user_id,
935 Contact::Incoming { user_id, .. } => *user_id,
936 }
937 }
938}
939
/// A pending contact request received from `requester_id`, together with whether the
/// recipient should still be notified about it.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
945
/// Builds a SQL `LIKE` pattern that matches any string containing the alphanumeric
/// characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters are dropped, so `LIKE` metacharacters such as `%` and
/// `_` in the input never reach the pattern. For example, `"x y"` becomes `"%x%y%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
957
958#[cfg(test)]
959pub mod tests {
960 use super::*;
961 use anyhow::anyhow;
962 use collections::BTreeMap;
963 use gpui::executor::Background;
964 use lazy_static::lazy_static;
965 use parking_lot::Mutex;
966 use rand::prelude::*;
967 use sqlx::{
968 migrate::{MigrateDatabase, Migrator},
969 Postgres,
970 };
971 use std::{path::Path, sync::Arc};
972 use util::post_inc;
973
974 #[tokio::test(flavor = "multi_thread")]
975 async fn test_get_users_by_ids() {
976 for test_db in [
977 TestDb::postgres().await,
978 TestDb::fake(Arc::new(gpui::executor::Background::new())),
979 ] {
980 let db = test_db.db();
981
982 let user = db.create_user("user", None, false).await.unwrap();
983 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
984 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
985 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
986
987 assert_eq!(
988 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
989 .await
990 .unwrap(),
991 vec![
992 User {
993 id: user,
994 github_login: "user".to_string(),
995 admin: false,
996 ..Default::default()
997 },
998 User {
999 id: friend1,
1000 github_login: "friend-1".to_string(),
1001 admin: false,
1002 ..Default::default()
1003 },
1004 User {
1005 id: friend2,
1006 github_login: "friend-2".to_string(),
1007 admin: false,
1008 ..Default::default()
1009 },
1010 User {
1011 id: friend3,
1012 github_login: "friend-3".to_string(),
1013 admin: false,
1014 ..Default::default()
1015 }
1016 ]
1017 );
1018 }
1019 }
1020
1021 #[tokio::test(flavor = "multi_thread")]
1022 async fn test_recent_channel_messages() {
1023 for test_db in [
1024 TestDb::postgres().await,
1025 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1026 ] {
1027 let db = test_db.db();
1028 let user = db.create_user("user", None, false).await.unwrap();
1029 let org = db.create_org("org", "org").await.unwrap();
1030 let channel = db.create_org_channel(org, "channel").await.unwrap();
1031 for i in 0..10 {
1032 db.create_channel_message(
1033 channel,
1034 user,
1035 &i.to_string(),
1036 OffsetDateTime::now_utc(),
1037 i,
1038 )
1039 .await
1040 .unwrap();
1041 }
1042
1043 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1044 assert_eq!(
1045 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1046 ["5", "6", "7", "8", "9"]
1047 );
1048
1049 let prev_messages = db
1050 .get_channel_messages(channel, 4, Some(messages[0].id))
1051 .await
1052 .unwrap();
1053 assert_eq!(
1054 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1055 ["1", "2", "3", "4"]
1056 );
1057 }
1058 }
1059
1060 #[tokio::test(flavor = "multi_thread")]
1061 async fn test_channel_message_nonces() {
1062 for test_db in [
1063 TestDb::postgres().await,
1064 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1065 ] {
1066 let db = test_db.db();
1067 let user = db.create_user("user", None, false).await.unwrap();
1068 let org = db.create_org("org", "org").await.unwrap();
1069 let channel = db.create_org_channel(org, "channel").await.unwrap();
1070
1071 let msg1_id = db
1072 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1073 .await
1074 .unwrap();
1075 let msg2_id = db
1076 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1077 .await
1078 .unwrap();
1079 let msg3_id = db
1080 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1081 .await
1082 .unwrap();
1083 let msg4_id = db
1084 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1085 .await
1086 .unwrap();
1087
1088 assert_ne!(msg1_id, msg2_id);
1089 assert_eq!(msg1_id, msg3_id);
1090 assert_eq!(msg2_id, msg4_id);
1091 }
1092 }
1093
1094 #[tokio::test(flavor = "multi_thread")]
1095 async fn test_create_access_tokens() {
1096 let test_db = TestDb::postgres().await;
1097 let db = test_db.db();
1098 let user = db.create_user("the-user", None, false).await.unwrap();
1099
1100 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1101 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1102 assert_eq!(
1103 db.get_access_token_hashes(user).await.unwrap(),
1104 &["h2".to_string(), "h1".to_string()]
1105 );
1106
1107 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1108 assert_eq!(
1109 db.get_access_token_hashes(user).await.unwrap(),
1110 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1111 );
1112
1113 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1114 assert_eq!(
1115 db.get_access_token_hashes(user).await.unwrap(),
1116 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1117 );
1118
1119 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1120 assert_eq!(
1121 db.get_access_token_hashes(user).await.unwrap(),
1122 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1123 );
1124 }
1125
1126 #[test]
1127 fn test_fuzzy_like_string() {
1128 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1129 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1130 assert_eq!(fuzzy_like_string(" z "), "%z%");
1131 }
1132
1133 #[tokio::test(flavor = "multi_thread")]
1134 async fn test_fuzzy_search_users() {
1135 let test_db = TestDb::postgres().await;
1136 let db = test_db.db();
1137 for github_login in [
1138 "California",
1139 "colorado",
1140 "oregon",
1141 "washington",
1142 "florida",
1143 "delaware",
1144 "rhode-island",
1145 ] {
1146 db.create_user(github_login, None, false).await.unwrap();
1147 }
1148
1149 assert_eq!(
1150 fuzzy_search_user_names(db, "clr").await,
1151 &["colorado", "California"]
1152 );
1153 assert_eq!(
1154 fuzzy_search_user_names(db, "ro").await,
1155 &["rhode-island", "colorado", "oregon"],
1156 );
1157
1158 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1159 db.fuzzy_search_users(query, 10)
1160 .await
1161 .unwrap()
1162 .into_iter()
1163 .map(|user| user.github_login)
1164 .collect::<Vec<_>>()
1165 }
1166 }
1167
1168 #[tokio::test(flavor = "multi_thread")]
1169 async fn test_add_contacts() {
1170 for test_db in [
1171 TestDb::postgres().await,
1172 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1173 ] {
1174 let db = test_db.db();
1175
1176 let user_1 = db.create_user("user1", None, false).await.unwrap();
1177 let user_2 = db.create_user("user2", None, false).await.unwrap();
1178 let user_3 = db.create_user("user3", None, false).await.unwrap();
1179
1180 // User starts with no contacts
1181 assert_eq!(
1182 db.get_contacts(user_1).await.unwrap(),
1183 vec![Contact::Accepted {
1184 user_id: user_1,
1185 should_notify: false
1186 }],
1187 );
1188
1189 // User requests a contact. Both users see the pending request.
1190 db.send_contact_request(user_1, user_2).await.unwrap();
1191 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1192 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1193 assert_eq!(
1194 db.get_contacts(user_1).await.unwrap(),
1195 &[
1196 Contact::Accepted {
1197 user_id: user_1,
1198 should_notify: false
1199 },
1200 Contact::Outgoing { user_id: user_2 }
1201 ],
1202 );
1203 assert_eq!(
1204 db.get_contacts(user_2).await.unwrap(),
1205 &[
1206 Contact::Incoming {
1207 user_id: user_1,
1208 should_notify: true
1209 },
1210 Contact::Accepted {
1211 user_id: user_2,
1212 should_notify: false
1213 },
1214 ]
1215 );
1216
1217 // User 2 dismisses the contact request notification without accepting or rejecting.
1218 // We shouldn't notify them again.
1219 db.dismiss_contact_notification(user_1, user_2)
1220 .await
1221 .unwrap_err();
1222 db.dismiss_contact_notification(user_2, user_1)
1223 .await
1224 .unwrap();
1225 assert_eq!(
1226 db.get_contacts(user_2).await.unwrap(),
1227 &[
1228 Contact::Incoming {
1229 user_id: user_1,
1230 should_notify: false
1231 },
1232 Contact::Accepted {
1233 user_id: user_2,
1234 should_notify: false
1235 },
1236 ]
1237 );
1238
1239 // User can't accept their own contact request
1240 db.respond_to_contact_request(user_1, user_2, true)
1241 .await
1242 .unwrap_err();
1243
1244 // User accepts a contact request. Both users see the contact.
1245 db.respond_to_contact_request(user_2, user_1, true)
1246 .await
1247 .unwrap();
1248 assert_eq!(
1249 db.get_contacts(user_1).await.unwrap(),
1250 &[
1251 Contact::Accepted {
1252 user_id: user_1,
1253 should_notify: false
1254 },
1255 Contact::Accepted {
1256 user_id: user_2,
1257 should_notify: true
1258 }
1259 ],
1260 );
1261 assert!(db.has_contact(user_1, user_2).await.unwrap());
1262 assert!(db.has_contact(user_2, user_1).await.unwrap());
1263 assert_eq!(
1264 db.get_contacts(user_2).await.unwrap(),
1265 &[
1266 Contact::Accepted {
1267 user_id: user_1,
1268 should_notify: false,
1269 },
1270 Contact::Accepted {
1271 user_id: user_2,
1272 should_notify: false,
1273 },
1274 ]
1275 );
1276
1277 // Users cannot re-request existing contacts.
1278 db.send_contact_request(user_1, user_2).await.unwrap_err();
1279 db.send_contact_request(user_2, user_1).await.unwrap_err();
1280
1281 // Users can't dismiss notifications of them accepting other users' requests.
1282 db.dismiss_contact_notification(user_2, user_1)
1283 .await
1284 .unwrap_err();
1285 assert_eq!(
1286 db.get_contacts(user_1).await.unwrap(),
1287 &[
1288 Contact::Accepted {
1289 user_id: user_1,
1290 should_notify: false
1291 },
1292 Contact::Accepted {
1293 user_id: user_2,
1294 should_notify: true,
1295 },
1296 ]
1297 );
1298
1299 // Users can dismiss notifications of other users accepting their requests.
1300 db.dismiss_contact_notification(user_1, user_2)
1301 .await
1302 .unwrap();
1303 assert_eq!(
1304 db.get_contacts(user_1).await.unwrap(),
1305 &[
1306 Contact::Accepted {
1307 user_id: user_1,
1308 should_notify: false
1309 },
1310 Contact::Accepted {
1311 user_id: user_2,
1312 should_notify: false,
1313 },
1314 ]
1315 );
1316
1317 // Users send each other concurrent contact requests and
1318 // see that they are immediately accepted.
1319 db.send_contact_request(user_1, user_3).await.unwrap();
1320 db.send_contact_request(user_3, user_1).await.unwrap();
1321 assert_eq!(
1322 db.get_contacts(user_1).await.unwrap(),
1323 &[
1324 Contact::Accepted {
1325 user_id: user_1,
1326 should_notify: false
1327 },
1328 Contact::Accepted {
1329 user_id: user_2,
1330 should_notify: false,
1331 },
1332 Contact::Accepted {
1333 user_id: user_3,
1334 should_notify: false
1335 },
1336 ]
1337 );
1338 assert_eq!(
1339 db.get_contacts(user_3).await.unwrap(),
1340 &[
1341 Contact::Accepted {
1342 user_id: user_1,
1343 should_notify: false
1344 },
1345 Contact::Accepted {
1346 user_id: user_3,
1347 should_notify: false
1348 }
1349 ],
1350 );
1351
1352 // User declines a contact request. Both users see that it is gone.
1353 db.send_contact_request(user_2, user_3).await.unwrap();
1354 db.respond_to_contact_request(user_3, user_2, false)
1355 .await
1356 .unwrap();
1357 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1358 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1359 assert_eq!(
1360 db.get_contacts(user_2).await.unwrap(),
1361 &[
1362 Contact::Accepted {
1363 user_id: user_1,
1364 should_notify: false
1365 },
1366 Contact::Accepted {
1367 user_id: user_2,
1368 should_notify: false
1369 }
1370 ]
1371 );
1372 assert_eq!(
1373 db.get_contacts(user_3).await.unwrap(),
1374 &[
1375 Contact::Accepted {
1376 user_id: user_1,
1377 should_notify: false
1378 },
1379 Contact::Accepted {
1380 user_id: user_3,
1381 should_notify: false
1382 }
1383 ],
1384 );
1385 }
1386 }
1387
/// Exercises the invite-code lifecycle against a real Postgres database:
/// issuing a code, redeeming it until its count reaches zero (each redeemer
/// becoming an accepted contact of the issuer), raising the count again
/// without changing the code, and rejecting redemption by an existing login.
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
    let postgres = TestDb::postgres().await;
    let db = postgres.db();
    let user1 = db.create_user("user-1", None, false).await.unwrap();

    // Initially, user 1 has no invite code
    assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

    // User 1 creates an invite code that can be used twice.
    db.set_invite_count(user1, 2).await.unwrap();
    let (invite_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 2);

    // User 2 redeems the invite code and becomes a contact of user 1.
    let user2 = db
        .redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user2).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: false
            }
        ]
    );

    // User 3 redeems the invite code and becomes a contact of user 1.
    let user3 = db
        .redeem_invite_code(&invite_code, "user-3", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 0);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user3).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: false
            },
        ]
    );

    // Trying to redeem the code for the third time results in an error.
    db.redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap_err();

    // Invite count can be updated after the code has been created.
    db.set_invite_count(user1, 2).await.unwrap();
    let (latest_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
    assert_eq!(invite_count, 2);

    // User 4 can now redeem the invite code and becomes a contact of user 1.
    let user4 = db
        .redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user4).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: false
            },
        ]
    );

    // An existing user cannot redeem invite codes.
    db.redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap_err();
    // The failed redemption must not consume a use of the code.
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
}
1536
/// Test-only database handle, backed either by a freshly created Postgres
/// database (`TestDb::postgres`) or by an in-memory `FakeDb` (`TestDb::fake`).
/// The backing database is torn down when the handle is dropped.
pub struct TestDb {
    /// The database in use. Always `Some` until `Drop` takes it for teardown.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL of the Postgres test database; an empty string for the fake backend.
    pub url: String,
}
1541
1542 impl TestDb {
1543 pub async fn postgres() -> Self {
1544 lazy_static! {
1545 static ref LOCK: Mutex<()> = Mutex::new(());
1546 }
1547
1548 let _guard = LOCK.lock();
1549 let mut rng = StdRng::from_entropy();
1550 let name = format!("zed-test-{}", rng.gen::<u128>());
1551 let url = format!("postgres://postgres@localhost/{}", name);
1552 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1553 Postgres::create_database(&url)
1554 .await
1555 .expect("failed to create test db");
1556 let db = PostgresDb::new(&url, 5).await.unwrap();
1557 let migrator = Migrator::new(migrations_path).await.unwrap();
1558 migrator.run(&db.pool).await.unwrap();
1559 Self {
1560 db: Some(Arc::new(db)),
1561 url,
1562 }
1563 }
1564
1565 pub fn fake(background: Arc<Background>) -> Self {
1566 Self {
1567 db: Some(Arc::new(FakeDb::new(background))),
1568 url: Default::default(),
1569 }
1570 }
1571
1572 pub fn db(&self) -> &Arc<dyn Db> {
1573 self.db.as_ref().unwrap()
1574 }
1575 }
1576
1577 impl Drop for TestDb {
1578 fn drop(&mut self) {
1579 if let Some(db) = self.db.take() {
1580 futures::executor::block_on(db.teardown(&self.url));
1581 }
1582 }
1583 }
1584
/// In-memory implementation of the `Db` trait for tests. Every table is a
/// separately locked collection; ids are allocated from the `next_*` counters.
pub struct FakeDb {
    // Executor handle used to inject `simulate_random_delay` pauses into operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Organizations keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org membership; the value is the member's admin flag.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel membership; the value is the member's admin flag.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Channel messages keyed (and therefore ordered) by id.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact relationships and pending contact requests, in insertion order.
    pub contacts: Mutex<Vec<FakeContact>>,
    // Next id to hand out for each kind of record.
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
}
1599
/// A contact relationship, or a pending contact request, stored by `FakeDb`.
#[derive(Debug)]
pub struct FakeContact {
    /// User who sent the contact request.
    pub requester_id: UserId,
    /// User who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted.
    pub accepted: bool,
    /// Pending-notification flag. While the request is unaccepted it is reported
    /// to the responder; once accepted it is reported to the requester.
    pub should_notify: bool,
}
1607
1608 impl FakeDb {
1609 pub fn new(background: Arc<Background>) -> Self {
1610 Self {
1611 background,
1612 users: Default::default(),
1613 next_user_id: Mutex::new(1),
1614 orgs: Default::default(),
1615 next_org_id: Mutex::new(1),
1616 org_memberships: Default::default(),
1617 channels: Default::default(),
1618 next_channel_id: Mutex::new(1),
1619 channel_memberships: Default::default(),
1620 channel_messages: Default::default(),
1621 next_channel_message_id: Mutex::new(1),
1622 contacts: Default::default(),
1623 }
1624 }
1625 }
1626
1627 #[async_trait]
1628 impl Db for FakeDb {
1629 async fn create_user(
1630 &self,
1631 github_login: &str,
1632 email_address: Option<&str>,
1633 admin: bool,
1634 ) -> Result<UserId> {
1635 self.background.simulate_random_delay().await;
1636
1637 let mut users = self.users.lock();
1638 if let Some(user) = users
1639 .values()
1640 .find(|user| user.github_login == github_login)
1641 {
1642 Ok(user.id)
1643 } else {
1644 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1645 users.insert(
1646 user_id,
1647 User {
1648 id: user_id,
1649 github_login: github_login.to_string(),
1650 email_address: email_address.map(str::to_string),
1651 admin,
1652 invite_code: None,
1653 invite_count: 0,
1654 connected_once: false,
1655 },
1656 );
1657 Ok(user_id)
1658 }
1659 }
1660
1661 async fn get_all_users(&self) -> Result<Vec<User>> {
1662 unimplemented!()
1663 }
1664
1665 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1666 unimplemented!()
1667 }
1668
1669 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1670 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1671 }
1672
1673 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1674 self.background.simulate_random_delay().await;
1675 let users = self.users.lock();
1676 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1677 }
1678
1679 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1680 Ok(self
1681 .users
1682 .lock()
1683 .values()
1684 .find(|user| user.github_login == github_login)
1685 .cloned())
1686 }
1687
1688 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1689 unimplemented!()
1690 }
1691
1692 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1693 self.background.simulate_random_delay().await;
1694 let mut users = self.users.lock();
1695 let mut user = users
1696 .get_mut(&id)
1697 .ok_or_else(|| anyhow!("user not found"))?;
1698 user.connected_once = connected_once;
1699 Ok(())
1700 }
1701
1702 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1703 unimplemented!()
1704 }
1705
1706 // invite codes
1707
1708 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1709 unimplemented!()
1710 }
1711
1712 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1713 unimplemented!()
1714 }
1715
1716 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1717 unimplemented!()
1718 }
1719
1720 async fn redeem_invite_code(
1721 &self,
1722 _code: &str,
1723 _login: &str,
1724 _email_address: Option<&str>,
1725 ) -> Result<UserId> {
1726 unimplemented!()
1727 }
1728
1729 // contacts
1730
1731 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1732 self.background.simulate_random_delay().await;
1733 let mut contacts = vec![Contact::Accepted {
1734 user_id: id,
1735 should_notify: false,
1736 }];
1737
1738 for contact in self.contacts.lock().iter() {
1739 if contact.requester_id == id {
1740 if contact.accepted {
1741 contacts.push(Contact::Accepted {
1742 user_id: contact.responder_id,
1743 should_notify: contact.should_notify,
1744 });
1745 } else {
1746 contacts.push(Contact::Outgoing {
1747 user_id: contact.responder_id,
1748 });
1749 }
1750 } else if contact.responder_id == id {
1751 if contact.accepted {
1752 contacts.push(Contact::Accepted {
1753 user_id: contact.requester_id,
1754 should_notify: false,
1755 });
1756 } else {
1757 contacts.push(Contact::Incoming {
1758 user_id: contact.requester_id,
1759 should_notify: contact.should_notify,
1760 });
1761 }
1762 }
1763 }
1764
1765 contacts.sort_unstable_by_key(|contact| contact.user_id());
1766 Ok(contacts)
1767 }
1768
1769 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1770 self.background.simulate_random_delay().await;
1771 Ok(self.contacts.lock().iter().any(|contact| {
1772 contact.accepted
1773 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1774 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1775 }))
1776 }
1777
1778 async fn send_contact_request(
1779 &self,
1780 requester_id: UserId,
1781 responder_id: UserId,
1782 ) -> Result<()> {
1783 let mut contacts = self.contacts.lock();
1784 for contact in contacts.iter_mut() {
1785 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1786 if contact.accepted {
1787 Err(anyhow!("contact already exists"))?;
1788 } else {
1789 Err(anyhow!("contact already requested"))?;
1790 }
1791 }
1792 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1793 if contact.accepted {
1794 Err(anyhow!("contact already exists"))?;
1795 } else {
1796 contact.accepted = true;
1797 contact.should_notify = false;
1798 return Ok(());
1799 }
1800 }
1801 }
1802 contacts.push(FakeContact {
1803 requester_id,
1804 responder_id,
1805 accepted: false,
1806 should_notify: true,
1807 });
1808 Ok(())
1809 }
1810
1811 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1812 self.contacts.lock().retain(|contact| {
1813 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1814 });
1815 Ok(())
1816 }
1817
1818 async fn dismiss_contact_notification(
1819 &self,
1820 user_id: UserId,
1821 contact_user_id: UserId,
1822 ) -> Result<()> {
1823 let mut contacts = self.contacts.lock();
1824 for contact in contacts.iter_mut() {
1825 if contact.requester_id == contact_user_id
1826 && contact.responder_id == user_id
1827 && !contact.accepted
1828 {
1829 contact.should_notify = false;
1830 return Ok(());
1831 }
1832 if contact.requester_id == user_id
1833 && contact.responder_id == contact_user_id
1834 && contact.accepted
1835 {
1836 contact.should_notify = false;
1837 return Ok(());
1838 }
1839 }
1840 Err(anyhow!("no such notification"))?
1841 }
1842
1843 async fn respond_to_contact_request(
1844 &self,
1845 responder_id: UserId,
1846 requester_id: UserId,
1847 accept: bool,
1848 ) -> Result<()> {
1849 let mut contacts = self.contacts.lock();
1850 for (ix, contact) in contacts.iter_mut().enumerate() {
1851 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1852 if contact.accepted {
1853 Err(anyhow!("contact already confirmed"))?;
1854 }
1855 if accept {
1856 contact.accepted = true;
1857 contact.should_notify = true;
1858 } else {
1859 contacts.remove(ix);
1860 }
1861 return Ok(());
1862 }
1863 }
1864 Err(anyhow!("no such contact request"))?
1865 }
1866
1867 async fn create_access_token_hash(
1868 &self,
1869 _user_id: UserId,
1870 _access_token_hash: &str,
1871 _max_access_token_count: usize,
1872 ) -> Result<()> {
1873 unimplemented!()
1874 }
1875
1876 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1877 unimplemented!()
1878 }
1879
1880 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1881 unimplemented!()
1882 }
1883
1884 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1885 self.background.simulate_random_delay().await;
1886 let mut orgs = self.orgs.lock();
1887 if orgs.values().any(|org| org.slug == slug) {
1888 Err(anyhow!("org already exists"))?
1889 } else {
1890 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1891 orgs.insert(
1892 org_id,
1893 Org {
1894 id: org_id,
1895 name: name.to_string(),
1896 slug: slug.to_string(),
1897 },
1898 );
1899 Ok(org_id)
1900 }
1901 }
1902
1903 async fn add_org_member(
1904 &self,
1905 org_id: OrgId,
1906 user_id: UserId,
1907 is_admin: bool,
1908 ) -> Result<()> {
1909 self.background.simulate_random_delay().await;
1910 if !self.orgs.lock().contains_key(&org_id) {
1911 Err(anyhow!("org does not exist"))?;
1912 }
1913 if !self.users.lock().contains_key(&user_id) {
1914 Err(anyhow!("user does not exist"))?;
1915 }
1916
1917 self.org_memberships
1918 .lock()
1919 .entry((org_id, user_id))
1920 .or_insert(is_admin);
1921 Ok(())
1922 }
1923
1924 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1925 self.background.simulate_random_delay().await;
1926 if !self.orgs.lock().contains_key(&org_id) {
1927 Err(anyhow!("org does not exist"))?;
1928 }
1929
1930 let mut channels = self.channels.lock();
1931 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1932 channels.insert(
1933 channel_id,
1934 Channel {
1935 id: channel_id,
1936 name: name.to_string(),
1937 owner_id: org_id.0,
1938 owner_is_user: false,
1939 },
1940 );
1941 Ok(channel_id)
1942 }
1943
1944 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1945 self.background.simulate_random_delay().await;
1946 Ok(self
1947 .channels
1948 .lock()
1949 .values()
1950 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1951 .cloned()
1952 .collect())
1953 }
1954
1955 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1956 self.background.simulate_random_delay().await;
1957 let channels = self.channels.lock();
1958 let memberships = self.channel_memberships.lock();
1959 Ok(channels
1960 .values()
1961 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1962 .cloned()
1963 .collect())
1964 }
1965
1966 async fn can_user_access_channel(
1967 &self,
1968 user_id: UserId,
1969 channel_id: ChannelId,
1970 ) -> Result<bool> {
1971 self.background.simulate_random_delay().await;
1972 Ok(self
1973 .channel_memberships
1974 .lock()
1975 .contains_key(&(channel_id, user_id)))
1976 }
1977
1978 async fn add_channel_member(
1979 &self,
1980 channel_id: ChannelId,
1981 user_id: UserId,
1982 is_admin: bool,
1983 ) -> Result<()> {
1984 self.background.simulate_random_delay().await;
1985 if !self.channels.lock().contains_key(&channel_id) {
1986 Err(anyhow!("channel does not exist"))?;
1987 }
1988 if !self.users.lock().contains_key(&user_id) {
1989 Err(anyhow!("user does not exist"))?;
1990 }
1991
1992 self.channel_memberships
1993 .lock()
1994 .entry((channel_id, user_id))
1995 .or_insert(is_admin);
1996 Ok(())
1997 }
1998
1999 async fn create_channel_message(
2000 &self,
2001 channel_id: ChannelId,
2002 sender_id: UserId,
2003 body: &str,
2004 timestamp: OffsetDateTime,
2005 nonce: u128,
2006 ) -> Result<MessageId> {
2007 self.background.simulate_random_delay().await;
2008 if !self.channels.lock().contains_key(&channel_id) {
2009 Err(anyhow!("channel does not exist"))?;
2010 }
2011 if !self.users.lock().contains_key(&sender_id) {
2012 Err(anyhow!("user does not exist"))?;
2013 }
2014
2015 let mut messages = self.channel_messages.lock();
2016 if let Some(message) = messages
2017 .values()
2018 .find(|message| message.nonce.as_u128() == nonce)
2019 {
2020 Ok(message.id)
2021 } else {
2022 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2023 messages.insert(
2024 message_id,
2025 ChannelMessage {
2026 id: message_id,
2027 channel_id,
2028 sender_id,
2029 body: body.to_string(),
2030 sent_at: timestamp,
2031 nonce: Uuid::from_u128(nonce),
2032 },
2033 );
2034 Ok(message_id)
2035 }
2036 }
2037
2038 async fn get_channel_messages(
2039 &self,
2040 channel_id: ChannelId,
2041 count: usize,
2042 before_id: Option<MessageId>,
2043 ) -> Result<Vec<ChannelMessage>> {
2044 let mut messages = self
2045 .channel_messages
2046 .lock()
2047 .values()
2048 .rev()
2049 .filter(|message| {
2050 message.channel_id == channel_id
2051 && message.id < before_id.unwrap_or(MessageId::MAX)
2052 })
2053 .take(count)
2054 .cloned()
2055 .collect::<Vec<_>>();
2056 messages.sort_unstable_by_key(|message| message.id);
2057 Ok(messages)
2058 }
2059
2060 async fn teardown(&self, _: &str) {}
2061
2062 #[cfg(test)]
2063 fn as_fake(&self) -> Option<&FakeDb> {
2064 Some(self)
2065 }
2066 }
2067}