1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow};
10use time::OffsetDateTime;
11
12#[async_trait]
13pub trait Db: Send + Sync {
14 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
15 async fn get_all_users(&self) -> Result<Vec<User>>;
16 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
17 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
18 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
19 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
20 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
21 async fn destroy_user(&self, id: UserId) -> Result<()>;
22
23 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
24 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
25 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
26 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
27
28 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
29 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
30 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
31 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
32 async fn dismiss_contact_notification(
33 &self,
34 responder_id: UserId,
35 requester_id: UserId,
36 ) -> Result<()>;
37 async fn respond_to_contact_request(
38 &self,
39 responder_id: UserId,
40 requester_id: UserId,
41 accept: bool,
42 ) -> Result<()>;
43
44 async fn create_access_token_hash(
45 &self,
46 user_id: UserId,
47 access_token_hash: &str,
48 max_access_token_count: usize,
49 ) -> Result<()>;
50 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
51 #[cfg(any(test, feature = "seed-support"))]
52
53 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
54 #[cfg(any(test, feature = "seed-support"))]
55 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
56 #[cfg(any(test, feature = "seed-support"))]
57 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
58 #[cfg(any(test, feature = "seed-support"))]
59 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
60 #[cfg(any(test, feature = "seed-support"))]
61
62 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
63 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
64 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
65 -> Result<bool>;
66 #[cfg(any(test, feature = "seed-support"))]
67 async fn add_channel_member(
68 &self,
69 channel_id: ChannelId,
70 user_id: UserId,
71 is_admin: bool,
72 ) -> Result<()>;
73 async fn create_channel_message(
74 &self,
75 channel_id: ChannelId,
76 sender_id: UserId,
77 body: &str,
78 timestamp: OffsetDateTime,
79 nonce: u128,
80 ) -> Result<MessageId>;
81 async fn get_channel_messages(
82 &self,
83 channel_id: ChannelId,
84 count: usize,
85 before_id: Option<MessageId>,
86 ) -> Result<Vec<ChannelMessage>>;
87 #[cfg(test)]
88 async fn teardown(&self, url: &str);
89 #[cfg(test)]
90 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
91}
92
93pub struct PostgresDb {
94 pool: sqlx::PgPool,
95}
96
97impl PostgresDb {
98 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
99 let pool = DbOptions::new()
100 .max_connections(max_connections)
101 .connect(&url)
102 .await
103 .context("failed to connect to postgres database")?;
104 Ok(Self { pool })
105 }
106}
107
108#[async_trait]
109impl Db for PostgresDb {
110 // users
111
112 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
113 let query = "
114 INSERT INTO users (github_login, admin)
115 VALUES ($1, $2)
116 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
117 RETURNING id
118 ";
119 Ok(sqlx::query_scalar(query)
120 .bind(github_login)
121 .bind(admin)
122 .fetch_one(&self.pool)
123 .await
124 .map(UserId)?)
125 }
126
127 async fn get_all_users(&self) -> Result<Vec<User>> {
128 let query = "SELECT * FROM users ORDER BY github_login ASC";
129 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
130 }
131
132 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
133 let like_string = fuzzy_like_string(name_query);
134 let query = "
135 SELECT users.*
136 FROM users
137 WHERE github_login ILIKE $1
138 ORDER BY github_login <-> $2
139 LIMIT $3
140 ";
141 Ok(sqlx::query_as(query)
142 .bind(like_string)
143 .bind(name_query)
144 .bind(limit)
145 .fetch_all(&self.pool)
146 .await?)
147 }
148
149 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
150 let users = self.get_users_by_ids(vec![id]).await?;
151 Ok(users.into_iter().next())
152 }
153
154 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
155 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
156 let query = "
157 SELECT users.*
158 FROM users
159 WHERE users.id = ANY ($1)
160 ";
161 Ok(sqlx::query_as(query)
162 .bind(&ids)
163 .fetch_all(&self.pool)
164 .await?)
165 }
166
167 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
168 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
169 Ok(sqlx::query_as(query)
170 .bind(github_login)
171 .fetch_optional(&self.pool)
172 .await?)
173 }
174
175 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
176 let query = "UPDATE users SET admin = $1 WHERE id = $2";
177 Ok(sqlx::query(query)
178 .bind(is_admin)
179 .bind(id.0)
180 .execute(&self.pool)
181 .await
182 .map(drop)?)
183 }
184
185 async fn destroy_user(&self, id: UserId) -> Result<()> {
186 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
187 sqlx::query(query)
188 .bind(id.0)
189 .execute(&self.pool)
190 .await
191 .map(drop)?;
192 let query = "DELETE FROM users WHERE id = $1;";
193 Ok(sqlx::query(query)
194 .bind(id.0)
195 .execute(&self.pool)
196 .await
197 .map(drop)?)
198 }
199
200 // invite codes
201
202 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
203 let mut tx = self.pool.begin().await?;
204 sqlx::query(
205 "
206 UPDATE users
207 SET invite_code = $1
208 WHERE id = $2 AND invite_code IS NULL
209 ",
210 )
211 .bind(nanoid!(16))
212 .bind(id)
213 .execute(&mut tx)
214 .await?;
215 sqlx::query(
216 "
217 UPDATE users
218 SET invite_count = $1
219 WHERE id = $2
220 ",
221 )
222 .bind(count)
223 .bind(id)
224 .execute(&mut tx)
225 .await?;
226 tx.commit().await?;
227 Ok(())
228 }
229
230 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
231 let result: Option<(String, i32)> = sqlx::query_as(
232 "
233 SELECT invite_code, invite_count
234 FROM users
235 WHERE id = $1 AND invite_code IS NOT NULL
236 ",
237 )
238 .bind(id)
239 .fetch_optional(&self.pool)
240 .await?;
241 if let Some((code, count)) = result {
242 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
243 } else {
244 Ok(None)
245 }
246 }
247
248 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
249 sqlx::query_as(
250 "
251 SELECT *
252 FROM users
253 WHERE invite_code = $1
254 ",
255 )
256 .bind(code)
257 .fetch_optional(&self.pool)
258 .await?
259 .ok_or_else(|| {
260 Error::Http(
261 StatusCode::NOT_FOUND,
262 "that invite code does not exist".to_string(),
263 )
264 })
265 }
266
267 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
268 let mut tx = self.pool.begin().await?;
269
270 let inviter_id: Option<UserId> = sqlx::query_scalar(
271 "
272 UPDATE users
273 SET invite_count = invite_count - 1
274 WHERE
275 invite_code = $1 AND
276 invite_count > 0
277 RETURNING id
278 ",
279 )
280 .bind(code)
281 .fetch_optional(&mut tx)
282 .await?;
283
284 let inviter_id = match inviter_id {
285 Some(inviter_id) => inviter_id,
286 None => {
287 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
288 .bind(code)
289 .fetch_optional(&mut tx)
290 .await?
291 .is_some()
292 {
293 Err(Error::Http(
294 StatusCode::UNAUTHORIZED,
295 "no invites remaining".to_string(),
296 ))?
297 } else {
298 Err(Error::Http(
299 StatusCode::NOT_FOUND,
300 "invite code not found".to_string(),
301 ))?
302 }
303 }
304 };
305
306 let invitee_id = sqlx::query_scalar(
307 "
308 INSERT INTO users
309 (github_login, admin, inviter_id)
310 VALUES
311 ($1, 'f', $2)
312 RETURNING id
313 ",
314 )
315 .bind(login)
316 .bind(inviter_id)
317 .fetch_one(&mut tx)
318 .await
319 .map(UserId)?;
320
321 sqlx::query(
322 "
323 INSERT INTO contacts
324 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
325 VALUES
326 ($1, $2, 't', 't', 't')
327 ",
328 )
329 .bind(inviter_id)
330 .bind(invitee_id)
331 .execute(&mut tx)
332 .await?;
333
334 tx.commit().await?;
335 Ok(invitee_id)
336 }
337
338 // contacts
339
340 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
341 let query = "
342 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
343 FROM contacts
344 WHERE user_id_a = $1 OR user_id_b = $1;
345 ";
346
347 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
348 .bind(user_id)
349 .fetch(&self.pool);
350
351 let mut contacts = vec![Contact::Accepted {
352 user_id,
353 should_notify: false,
354 }];
355 while let Some(row) = rows.next().await {
356 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
357
358 if user_id_a == user_id {
359 if accepted {
360 contacts.push(Contact::Accepted {
361 user_id: user_id_b,
362 should_notify: should_notify && a_to_b,
363 });
364 } else if a_to_b {
365 contacts.push(Contact::Outgoing { user_id: user_id_b })
366 } else {
367 contacts.push(Contact::Incoming {
368 user_id: user_id_b,
369 should_notify,
370 });
371 }
372 } else {
373 if accepted {
374 contacts.push(Contact::Accepted {
375 user_id: user_id_a,
376 should_notify: should_notify && !a_to_b,
377 });
378 } else if a_to_b {
379 contacts.push(Contact::Incoming {
380 user_id: user_id_a,
381 should_notify,
382 });
383 } else {
384 contacts.push(Contact::Outgoing { user_id: user_id_a });
385 }
386 }
387 }
388
389 contacts.sort_unstable_by_key(|contact| contact.user_id());
390
391 Ok(contacts)
392 }
393
394 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
395 let (id_a, id_b) = if user_id_1 < user_id_2 {
396 (user_id_1, user_id_2)
397 } else {
398 (user_id_2, user_id_1)
399 };
400
401 let query = "
402 SELECT 1 FROM contacts
403 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
404 LIMIT 1
405 ";
406 Ok(sqlx::query_scalar::<_, i32>(query)
407 .bind(id_a.0)
408 .bind(id_b.0)
409 .fetch_optional(&self.pool)
410 .await?
411 .is_some())
412 }
413
414 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
415 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
416 (sender_id, receiver_id, true)
417 } else {
418 (receiver_id, sender_id, false)
419 };
420 let query = "
421 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
422 VALUES ($1, $2, $3, 'f', 't')
423 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
424 SET
425 accepted = 't',
426 should_notify = 'f'
427 WHERE
428 NOT contacts.accepted AND
429 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
430 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
431 ";
432 let result = sqlx::query(query)
433 .bind(id_a.0)
434 .bind(id_b.0)
435 .bind(a_to_b)
436 .execute(&self.pool)
437 .await?;
438
439 if result.rows_affected() == 1 {
440 Ok(())
441 } else {
442 Err(anyhow!("contact already requested"))?
443 }
444 }
445
446 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
447 let (id_a, id_b) = if responder_id < requester_id {
448 (responder_id, requester_id)
449 } else {
450 (requester_id, responder_id)
451 };
452 let query = "
453 DELETE FROM contacts
454 WHERE user_id_a = $1 AND user_id_b = $2;
455 ";
456 let result = sqlx::query(query)
457 .bind(id_a.0)
458 .bind(id_b.0)
459 .execute(&self.pool)
460 .await?;
461
462 if result.rows_affected() == 1 {
463 Ok(())
464 } else {
465 Err(anyhow!("no such contact"))?
466 }
467 }
468
469 async fn dismiss_contact_notification(
470 &self,
471 user_id: UserId,
472 contact_user_id: UserId,
473 ) -> Result<()> {
474 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
475 (user_id, contact_user_id, true)
476 } else {
477 (contact_user_id, user_id, false)
478 };
479
480 let query = "
481 UPDATE contacts
482 SET should_notify = 'f'
483 WHERE
484 user_id_a = $1 AND user_id_b = $2 AND
485 (
486 (a_to_b = $3 AND accepted) OR
487 (a_to_b != $3 AND NOT accepted)
488 );
489 ";
490
491 let result = sqlx::query(query)
492 .bind(id_a.0)
493 .bind(id_b.0)
494 .bind(a_to_b)
495 .execute(&self.pool)
496 .await?;
497
498 if result.rows_affected() == 0 {
499 Err(anyhow!("no such contact request"))?;
500 }
501
502 Ok(())
503 }
504
505 async fn respond_to_contact_request(
506 &self,
507 responder_id: UserId,
508 requester_id: UserId,
509 accept: bool,
510 ) -> Result<()> {
511 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
512 (responder_id, requester_id, false)
513 } else {
514 (requester_id, responder_id, true)
515 };
516 let result = if accept {
517 let query = "
518 UPDATE contacts
519 SET accepted = 't', should_notify = 't'
520 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
521 ";
522 sqlx::query(query)
523 .bind(id_a.0)
524 .bind(id_b.0)
525 .bind(a_to_b)
526 .execute(&self.pool)
527 .await?
528 } else {
529 let query = "
530 DELETE FROM contacts
531 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
532 ";
533 sqlx::query(query)
534 .bind(id_a.0)
535 .bind(id_b.0)
536 .bind(a_to_b)
537 .execute(&self.pool)
538 .await?
539 };
540 if result.rows_affected() == 1 {
541 Ok(())
542 } else {
543 Err(anyhow!("no such contact request"))?
544 }
545 }
546
547 // access tokens
548
549 async fn create_access_token_hash(
550 &self,
551 user_id: UserId,
552 access_token_hash: &str,
553 max_access_token_count: usize,
554 ) -> Result<()> {
555 let insert_query = "
556 INSERT INTO access_tokens (user_id, hash)
557 VALUES ($1, $2);
558 ";
559 let cleanup_query = "
560 DELETE FROM access_tokens
561 WHERE id IN (
562 SELECT id from access_tokens
563 WHERE user_id = $1
564 ORDER BY id DESC
565 OFFSET $3
566 )
567 ";
568
569 let mut tx = self.pool.begin().await?;
570 sqlx::query(insert_query)
571 .bind(user_id.0)
572 .bind(access_token_hash)
573 .execute(&mut tx)
574 .await?;
575 sqlx::query(cleanup_query)
576 .bind(user_id.0)
577 .bind(access_token_hash)
578 .bind(max_access_token_count as u32)
579 .execute(&mut tx)
580 .await?;
581 Ok(tx.commit().await?)
582 }
583
584 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
585 let query = "
586 SELECT hash
587 FROM access_tokens
588 WHERE user_id = $1
589 ORDER BY id DESC
590 ";
591 Ok(sqlx::query_scalar(query)
592 .bind(user_id.0)
593 .fetch_all(&self.pool)
594 .await?)
595 }
596
597 // orgs
598
599 #[allow(unused)] // Help rust-analyzer
600 #[cfg(any(test, feature = "seed-support"))]
601 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
602 let query = "
603 SELECT *
604 FROM orgs
605 WHERE slug = $1
606 ";
607 Ok(sqlx::query_as(query)
608 .bind(slug)
609 .fetch_optional(&self.pool)
610 .await?)
611 }
612
613 #[cfg(any(test, feature = "seed-support"))]
614 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
615 let query = "
616 INSERT INTO orgs (name, slug)
617 VALUES ($1, $2)
618 RETURNING id
619 ";
620 Ok(sqlx::query_scalar(query)
621 .bind(name)
622 .bind(slug)
623 .fetch_one(&self.pool)
624 .await
625 .map(OrgId)?)
626 }
627
628 #[cfg(any(test, feature = "seed-support"))]
629 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
630 let query = "
631 INSERT INTO org_memberships (org_id, user_id, admin)
632 VALUES ($1, $2, $3)
633 ON CONFLICT DO NOTHING
634 ";
635 Ok(sqlx::query(query)
636 .bind(org_id.0)
637 .bind(user_id.0)
638 .bind(is_admin)
639 .execute(&self.pool)
640 .await
641 .map(drop)?)
642 }
643
644 // channels
645
646 #[cfg(any(test, feature = "seed-support"))]
647 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
648 let query = "
649 INSERT INTO channels (owner_id, owner_is_user, name)
650 VALUES ($1, false, $2)
651 RETURNING id
652 ";
653 Ok(sqlx::query_scalar(query)
654 .bind(org_id.0)
655 .bind(name)
656 .fetch_one(&self.pool)
657 .await
658 .map(ChannelId)?)
659 }
660
661 #[allow(unused)] // Help rust-analyzer
662 #[cfg(any(test, feature = "seed-support"))]
663 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
664 let query = "
665 SELECT *
666 FROM channels
667 WHERE
668 channels.owner_is_user = false AND
669 channels.owner_id = $1
670 ";
671 Ok(sqlx::query_as(query)
672 .bind(org_id.0)
673 .fetch_all(&self.pool)
674 .await?)
675 }
676
677 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
678 let query = "
679 SELECT
680 channels.*
681 FROM
682 channel_memberships, channels
683 WHERE
684 channel_memberships.user_id = $1 AND
685 channel_memberships.channel_id = channels.id
686 ";
687 Ok(sqlx::query_as(query)
688 .bind(user_id.0)
689 .fetch_all(&self.pool)
690 .await?)
691 }
692
693 async fn can_user_access_channel(
694 &self,
695 user_id: UserId,
696 channel_id: ChannelId,
697 ) -> Result<bool> {
698 let query = "
699 SELECT id
700 FROM channel_memberships
701 WHERE user_id = $1 AND channel_id = $2
702 LIMIT 1
703 ";
704 Ok(sqlx::query_scalar::<_, i32>(query)
705 .bind(user_id.0)
706 .bind(channel_id.0)
707 .fetch_optional(&self.pool)
708 .await
709 .map(|e| e.is_some())?)
710 }
711
712 #[cfg(any(test, feature = "seed-support"))]
713 async fn add_channel_member(
714 &self,
715 channel_id: ChannelId,
716 user_id: UserId,
717 is_admin: bool,
718 ) -> Result<()> {
719 let query = "
720 INSERT INTO channel_memberships (channel_id, user_id, admin)
721 VALUES ($1, $2, $3)
722 ON CONFLICT DO NOTHING
723 ";
724 Ok(sqlx::query(query)
725 .bind(channel_id.0)
726 .bind(user_id.0)
727 .bind(is_admin)
728 .execute(&self.pool)
729 .await
730 .map(drop)?)
731 }
732
733 // messages
734
735 async fn create_channel_message(
736 &self,
737 channel_id: ChannelId,
738 sender_id: UserId,
739 body: &str,
740 timestamp: OffsetDateTime,
741 nonce: u128,
742 ) -> Result<MessageId> {
743 let query = "
744 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
745 VALUES ($1, $2, $3, $4, $5)
746 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
747 RETURNING id
748 ";
749 Ok(sqlx::query_scalar(query)
750 .bind(channel_id.0)
751 .bind(sender_id.0)
752 .bind(body)
753 .bind(timestamp)
754 .bind(Uuid::from_u128(nonce))
755 .fetch_one(&self.pool)
756 .await
757 .map(MessageId)?)
758 }
759
760 async fn get_channel_messages(
761 &self,
762 channel_id: ChannelId,
763 count: usize,
764 before_id: Option<MessageId>,
765 ) -> Result<Vec<ChannelMessage>> {
766 let query = r#"
767 SELECT * FROM (
768 SELECT
769 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
770 FROM
771 channel_messages
772 WHERE
773 channel_id = $1 AND
774 id < $2
775 ORDER BY id DESC
776 LIMIT $3
777 ) as recent_messages
778 ORDER BY id ASC
779 "#;
780 Ok(sqlx::query_as(query)
781 .bind(channel_id.0)
782 .bind(before_id.unwrap_or(MessageId::MAX))
783 .bind(count as i64)
784 .fetch_all(&self.pool)
785 .await?)
786 }
787
788 #[cfg(test)]
789 async fn teardown(&self, url: &str) {
790 use util::ResultExt;
791
792 let query = "
793 SELECT pg_terminate_backend(pg_stat_activity.pid)
794 FROM pg_stat_activity
795 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
796 ";
797 sqlx::query(query).execute(&self.pool).await.log_err();
798 self.pool.close().await;
799 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
800 .await
801 .log_err();
802 }
803
804 #[cfg(test)]
805 fn as_fake(&self) -> Option<&tests::FakeDb> {
806 None
807 }
808}
809
810macro_rules! id_type {
811 ($name:ident) => {
812 #[derive(
813 Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
814 )]
815 #[sqlx(transparent)]
816 #[serde(transparent)]
817 pub struct $name(pub i32);
818
819 impl $name {
820 #[allow(unused)]
821 pub const MAX: Self = Self(i32::MAX);
822
823 #[allow(unused)]
824 pub fn from_proto(value: u64) -> Self {
825 Self(value as i32)
826 }
827
828 #[allow(unused)]
829 pub fn to_proto(&self) -> u64 {
830 self.0 as u64
831 }
832 }
833
834 impl std::fmt::Display for $name {
835 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
836 self.0.fmt(f)
837 }
838 }
839 };
840}
841
842id_type!(UserId);
843#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
844pub struct User {
845 pub id: UserId,
846 pub github_login: String,
847 pub admin: bool,
848 pub invite_code: Option<String>,
849 pub invite_count: i32,
850}
851
852id_type!(OrgId);
853#[derive(FromRow)]
854pub struct Org {
855 pub id: OrgId,
856 pub name: String,
857 pub slug: String,
858}
859
860id_type!(ChannelId);
861#[derive(Clone, Debug, FromRow, Serialize)]
862pub struct Channel {
863 pub id: ChannelId,
864 pub name: String,
865 pub owner_id: i32,
866 pub owner_is_user: bool,
867}
868
869id_type!(MessageId);
870#[derive(Clone, Debug, FromRow)]
871pub struct ChannelMessage {
872 pub id: MessageId,
873 pub channel_id: ChannelId,
874 pub sender_id: UserId,
875 pub body: String,
876 pub sent_at: OffsetDateTime,
877 pub nonce: Uuid,
878}
879
880#[derive(Clone, Debug, PartialEq, Eq)]
881pub enum Contact {
882 Accepted {
883 user_id: UserId,
884 should_notify: bool,
885 },
886 Outgoing {
887 user_id: UserId,
888 },
889 Incoming {
890 user_id: UserId,
891 should_notify: bool,
892 },
893}
894
895impl Contact {
896 pub fn user_id(&self) -> UserId {
897 match self {
898 Contact::Accepted { user_id, .. } => *user_id,
899 Contact::Outgoing { user_id } => *user_id,
900 Contact::Incoming { user_id, .. } => *user_id,
901 }
902 }
903}
904
905#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
906pub struct IncomingContactRequest {
907 pub requester_id: UserId,
908 pub should_notify: bool,
909}
910
911fn fuzzy_like_string(string: &str) -> String {
912 let mut result = String::with_capacity(string.len() * 2 + 1);
913 for c in string.chars() {
914 if c.is_alphanumeric() {
915 result.push('%');
916 result.push(c);
917 }
918 }
919 result.push('%');
920 result
921}
922
923#[cfg(test)]
924pub mod tests {
925 use super::*;
926 use anyhow::anyhow;
927 use collections::BTreeMap;
928 use gpui::executor::Background;
929 use lazy_static::lazy_static;
930 use parking_lot::Mutex;
931 use rand::prelude::*;
932 use sqlx::{
933 migrate::{MigrateDatabase, Migrator},
934 Postgres,
935 };
936 use std::{path::Path, sync::Arc};
937 use util::post_inc;
938
939 #[tokio::test(flavor = "multi_thread")]
940 async fn test_get_users_by_ids() {
941 for test_db in [
942 TestDb::postgres().await,
943 TestDb::fake(Arc::new(gpui::executor::Background::new())),
944 ] {
945 let db = test_db.db();
946
947 let user = db.create_user("user", false).await.unwrap();
948 let friend1 = db.create_user("friend-1", false).await.unwrap();
949 let friend2 = db.create_user("friend-2", false).await.unwrap();
950 let friend3 = db.create_user("friend-3", false).await.unwrap();
951
952 assert_eq!(
953 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
954 .await
955 .unwrap(),
956 vec![
957 User {
958 id: user,
959 github_login: "user".to_string(),
960 admin: false,
961 ..Default::default()
962 },
963 User {
964 id: friend1,
965 github_login: "friend-1".to_string(),
966 admin: false,
967 ..Default::default()
968 },
969 User {
970 id: friend2,
971 github_login: "friend-2".to_string(),
972 admin: false,
973 ..Default::default()
974 },
975 User {
976 id: friend3,
977 github_login: "friend-3".to_string(),
978 admin: false,
979 ..Default::default()
980 }
981 ]
982 );
983 }
984 }
985
986 #[tokio::test(flavor = "multi_thread")]
987 async fn test_recent_channel_messages() {
988 for test_db in [
989 TestDb::postgres().await,
990 TestDb::fake(Arc::new(gpui::executor::Background::new())),
991 ] {
992 let db = test_db.db();
993 let user = db.create_user("user", false).await.unwrap();
994 let org = db.create_org("org", "org").await.unwrap();
995 let channel = db.create_org_channel(org, "channel").await.unwrap();
996 for i in 0..10 {
997 db.create_channel_message(
998 channel,
999 user,
1000 &i.to_string(),
1001 OffsetDateTime::now_utc(),
1002 i,
1003 )
1004 .await
1005 .unwrap();
1006 }
1007
1008 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1009 assert_eq!(
1010 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1011 ["5", "6", "7", "8", "9"]
1012 );
1013
1014 let prev_messages = db
1015 .get_channel_messages(channel, 4, Some(messages[0].id))
1016 .await
1017 .unwrap();
1018 assert_eq!(
1019 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1020 ["1", "2", "3", "4"]
1021 );
1022 }
1023 }
1024
1025 #[tokio::test(flavor = "multi_thread")]
1026 async fn test_channel_message_nonces() {
1027 for test_db in [
1028 TestDb::postgres().await,
1029 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1030 ] {
1031 let db = test_db.db();
1032 let user = db.create_user("user", false).await.unwrap();
1033 let org = db.create_org("org", "org").await.unwrap();
1034 let channel = db.create_org_channel(org, "channel").await.unwrap();
1035
1036 let msg1_id = db
1037 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1038 .await
1039 .unwrap();
1040 let msg2_id = db
1041 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1042 .await
1043 .unwrap();
1044 let msg3_id = db
1045 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1046 .await
1047 .unwrap();
1048 let msg4_id = db
1049 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1050 .await
1051 .unwrap();
1052
1053 assert_ne!(msg1_id, msg2_id);
1054 assert_eq!(msg1_id, msg3_id);
1055 assert_eq!(msg2_id, msg4_id);
1056 }
1057 }
1058
1059 #[tokio::test(flavor = "multi_thread")]
1060 async fn test_create_access_tokens() {
1061 let test_db = TestDb::postgres().await;
1062 let db = test_db.db();
1063 let user = db.create_user("the-user", false).await.unwrap();
1064
1065 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1066 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1067 assert_eq!(
1068 db.get_access_token_hashes(user).await.unwrap(),
1069 &["h2".to_string(), "h1".to_string()]
1070 );
1071
1072 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1073 assert_eq!(
1074 db.get_access_token_hashes(user).await.unwrap(),
1075 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1076 );
1077
1078 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1079 assert_eq!(
1080 db.get_access_token_hashes(user).await.unwrap(),
1081 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1082 );
1083
1084 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1085 assert_eq!(
1086 db.get_access_token_hashes(user).await.unwrap(),
1087 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1088 );
1089 }
1090
1091 #[test]
1092 fn test_fuzzy_like_string() {
1093 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1094 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1095 assert_eq!(fuzzy_like_string(" z "), "%z%");
1096 }
1097
1098 #[tokio::test(flavor = "multi_thread")]
1099 async fn test_fuzzy_search_users() {
1100 let test_db = TestDb::postgres().await;
1101 let db = test_db.db();
1102 for github_login in [
1103 "California",
1104 "colorado",
1105 "oregon",
1106 "washington",
1107 "florida",
1108 "delaware",
1109 "rhode-island",
1110 ] {
1111 db.create_user(github_login, false).await.unwrap();
1112 }
1113
1114 assert_eq!(
1115 fuzzy_search_user_names(db, "clr").await,
1116 &["colorado", "California"]
1117 );
1118 assert_eq!(
1119 fuzzy_search_user_names(db, "ro").await,
1120 &["rhode-island", "colorado", "oregon"],
1121 );
1122
1123 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1124 db.fuzzy_search_users(query, 10)
1125 .await
1126 .unwrap()
1127 .into_iter()
1128 .map(|user| user.github_login)
1129 .collect::<Vec<_>>()
1130 }
1131 }
1132
1133 #[tokio::test(flavor = "multi_thread")]
1134 async fn test_add_contacts() {
1135 for test_db in [
1136 TestDb::postgres().await,
1137 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1138 ] {
1139 let db = test_db.db();
1140
1141 let user_1 = db.create_user("user1", false).await.unwrap();
1142 let user_2 = db.create_user("user2", false).await.unwrap();
1143 let user_3 = db.create_user("user3", false).await.unwrap();
1144
1145 // User starts with no contacts
1146 assert_eq!(
1147 db.get_contacts(user_1).await.unwrap(),
1148 vec![Contact::Accepted {
1149 user_id: user_1,
1150 should_notify: false
1151 }],
1152 );
1153
1154 // User requests a contact. Both users see the pending request.
1155 db.send_contact_request(user_1, user_2).await.unwrap();
1156 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1157 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1158 assert_eq!(
1159 db.get_contacts(user_1).await.unwrap(),
1160 &[
1161 Contact::Accepted {
1162 user_id: user_1,
1163 should_notify: false
1164 },
1165 Contact::Outgoing { user_id: user_2 }
1166 ],
1167 );
1168 assert_eq!(
1169 db.get_contacts(user_2).await.unwrap(),
1170 &[
1171 Contact::Incoming {
1172 user_id: user_1,
1173 should_notify: true
1174 },
1175 Contact::Accepted {
1176 user_id: user_2,
1177 should_notify: false
1178 },
1179 ]
1180 );
1181
1182 // User 2 dismisses the contact request notification without accepting or rejecting.
1183 // We shouldn't notify them again.
1184 db.dismiss_contact_notification(user_1, user_2)
1185 .await
1186 .unwrap_err();
1187 db.dismiss_contact_notification(user_2, user_1)
1188 .await
1189 .unwrap();
1190 assert_eq!(
1191 db.get_contacts(user_2).await.unwrap(),
1192 &[
1193 Contact::Incoming {
1194 user_id: user_1,
1195 should_notify: false
1196 },
1197 Contact::Accepted {
1198 user_id: user_2,
1199 should_notify: false
1200 },
1201 ]
1202 );
1203
1204 // User can't accept their own contact request
1205 db.respond_to_contact_request(user_1, user_2, true)
1206 .await
1207 .unwrap_err();
1208
1209 // User accepts a contact request. Both users see the contact.
1210 db.respond_to_contact_request(user_2, user_1, true)
1211 .await
1212 .unwrap();
1213 assert_eq!(
1214 db.get_contacts(user_1).await.unwrap(),
1215 &[
1216 Contact::Accepted {
1217 user_id: user_1,
1218 should_notify: false
1219 },
1220 Contact::Accepted {
1221 user_id: user_2,
1222 should_notify: true
1223 }
1224 ],
1225 );
1226 assert!(db.has_contact(user_1, user_2).await.unwrap());
1227 assert!(db.has_contact(user_2, user_1).await.unwrap());
1228 assert_eq!(
1229 db.get_contacts(user_2).await.unwrap(),
1230 &[
1231 Contact::Accepted {
1232 user_id: user_1,
1233 should_notify: false,
1234 },
1235 Contact::Accepted {
1236 user_id: user_2,
1237 should_notify: false,
1238 },
1239 ]
1240 );
1241
1242 // Users cannot re-request existing contacts.
1243 db.send_contact_request(user_1, user_2).await.unwrap_err();
1244 db.send_contact_request(user_2, user_1).await.unwrap_err();
1245
1246 // Users can't dismiss notifications of them accepting other users' requests.
1247 db.dismiss_contact_notification(user_2, user_1)
1248 .await
1249 .unwrap_err();
1250 assert_eq!(
1251 db.get_contacts(user_1).await.unwrap(),
1252 &[
1253 Contact::Accepted {
1254 user_id: user_1,
1255 should_notify: false
1256 },
1257 Contact::Accepted {
1258 user_id: user_2,
1259 should_notify: true,
1260 },
1261 ]
1262 );
1263
1264 // Users can dismiss notifications of other users accepting their requests.
1265 db.dismiss_contact_notification(user_1, user_2)
1266 .await
1267 .unwrap();
1268 assert_eq!(
1269 db.get_contacts(user_1).await.unwrap(),
1270 &[
1271 Contact::Accepted {
1272 user_id: user_1,
1273 should_notify: false
1274 },
1275 Contact::Accepted {
1276 user_id: user_2,
1277 should_notify: false,
1278 },
1279 ]
1280 );
1281
1282 // Users send each other concurrent contact requests and
1283 // see that they are immediately accepted.
1284 db.send_contact_request(user_1, user_3).await.unwrap();
1285 db.send_contact_request(user_3, user_1).await.unwrap();
1286 assert_eq!(
1287 db.get_contacts(user_1).await.unwrap(),
1288 &[
1289 Contact::Accepted {
1290 user_id: user_1,
1291 should_notify: false
1292 },
1293 Contact::Accepted {
1294 user_id: user_2,
1295 should_notify: false,
1296 },
1297 Contact::Accepted {
1298 user_id: user_3,
1299 should_notify: false
1300 },
1301 ]
1302 );
1303 assert_eq!(
1304 db.get_contacts(user_3).await.unwrap(),
1305 &[
1306 Contact::Accepted {
1307 user_id: user_1,
1308 should_notify: false
1309 },
1310 Contact::Accepted {
1311 user_id: user_3,
1312 should_notify: false
1313 }
1314 ],
1315 );
1316
1317 // User declines a contact request. Both users see that it is gone.
1318 db.send_contact_request(user_2, user_3).await.unwrap();
1319 db.respond_to_contact_request(user_3, user_2, false)
1320 .await
1321 .unwrap();
1322 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1323 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1324 assert_eq!(
1325 db.get_contacts(user_2).await.unwrap(),
1326 &[
1327 Contact::Accepted {
1328 user_id: user_1,
1329 should_notify: false
1330 },
1331 Contact::Accepted {
1332 user_id: user_2,
1333 should_notify: false
1334 }
1335 ]
1336 );
1337 assert_eq!(
1338 db.get_contacts(user_3).await.unwrap(),
1339 &[
1340 Contact::Accepted {
1341 user_id: user_1,
1342 should_notify: false
1343 },
1344 Contact::Accepted {
1345 user_id: user_3,
1346 should_notify: false
1347 }
1348 ],
1349 );
1350 }
1351 }
1352
/// Exercises the invite-code lifecycle against a freshly created Postgres test
/// database: creating a code, redeeming it until its count is exhausted,
/// topping the count back up, and rejecting redemption by an existing login.
///
/// Runs only against Postgres because `FakeDb`'s invite-code methods are
/// `unimplemented!()`.
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
    let postgres = TestDb::postgres().await;
    let db = postgres.db();
    let user1 = db.create_user("user-1", false).await.unwrap();

    // Initially, user 1 has no invite code
    assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

    // User 1 creates an invite code that can be used twice.
    db.set_invite_count(user1, 2).await.unwrap();
    let (invite_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 2);

    // User 2 redeems the invite code and becomes a contact of user 1.
    // Each successful redemption consumes one use of the code.
    let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    // The inviter (user 1) is notified about the new contact...
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            }
        ]
    );
    // ...while the invitee (user 2) is not.
    assert_eq!(
        db.get_contacts(user2).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: false
            }
        ]
    );

    // User 3 redeems the invite code and becomes a contact of user 1.
    let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 0);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user3).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: false
            },
        ]
    );

    // Trying to redeem the code for the third time results in an error.
    db.redeem_invite_code(&invite_code, "user-4")
        .await
        .unwrap_err();

    // Invite count can be updated after the code has been created.
    db.set_invite_count(user1, 2).await.unwrap();
    let (latest_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
    assert_eq!(invite_count, 2);

    // User 4 can now redeem the invite code and becomes a contact of user 1.
    let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user4).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: false
            },
        ]
    );

    // An existing user cannot redeem invite codes.
    db.redeem_invite_code(&invite_code, "user-2")
        .await
        .unwrap_err();
    // The rejected redemption must not have consumed a use of the code.
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
}
1492
/// A database handle for tests. It is backed either by a freshly created,
/// uniquely named Postgres database or by an in-memory [`FakeDb`].
pub struct TestDb {
    /// The database under test. It is held in an `Option` so that `Drop` can
    /// take ownership of it to run teardown.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL of the Postgres database; empty when backed by the fake.
    pub url: String,
}
1497
1498 impl TestDb {
1499 pub async fn postgres() -> Self {
1500 lazy_static! {
1501 static ref LOCK: Mutex<()> = Mutex::new(());
1502 }
1503
1504 let _guard = LOCK.lock();
1505 let mut rng = StdRng::from_entropy();
1506 let name = format!("zed-test-{}", rng.gen::<u128>());
1507 let url = format!("postgres://postgres@localhost/{}", name);
1508 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1509 Postgres::create_database(&url)
1510 .await
1511 .expect("failed to create test db");
1512 let db = PostgresDb::new(&url, 5).await.unwrap();
1513 let migrator = Migrator::new(migrations_path).await.unwrap();
1514 migrator.run(&db.pool).await.unwrap();
1515 Self {
1516 db: Some(Arc::new(db)),
1517 url,
1518 }
1519 }
1520
1521 pub fn fake(background: Arc<Background>) -> Self {
1522 Self {
1523 db: Some(Arc::new(FakeDb::new(background))),
1524 url: Default::default(),
1525 }
1526 }
1527
1528 pub fn db(&self) -> &Arc<dyn Db> {
1529 self.db.as_ref().unwrap()
1530 }
1531 }
1532
1533 impl Drop for TestDb {
1534 fn drop(&mut self) {
1535 if let Some(db) = self.db.take() {
1536 futures::executor::block_on(db.teardown(&self.url));
1537 }
1538 }
1539 }
1540
/// In-memory implementation of [`Db`] for tests.
///
/// Each table is a separately locked collection. New ids are taken from the
/// `next_*` counters, which are advanced with `post_inc`.
pub struct FakeDb {
    /// Used to inject random delays into operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Organizations keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Organization memberships; the value is the member's admin flag.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel memberships; the value is the member's admin flag.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Channel messages keyed by id, across all channels.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact relationships, one entry per requested pair.
    pub contacts: Mutex<Vec<FakeContact>>,
    // Id counters for the corresponding tables.
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
}
1555
/// A contact relationship stored by [`FakeDb`], recorded in the direction in
/// which it was requested.
#[derive(Debug)]
pub struct FakeContact {
    /// The user who sent the contact request.
    pub requester_id: UserId,
    /// The user who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted; `false` while it is pending.
    pub accepted: bool,
    /// Whether a notification is outstanding. While the request is pending it
    /// is shown to the responder; once accepted it is shown to the requester.
    pub should_notify: bool,
}
1563
1564 impl FakeDb {
1565 pub fn new(background: Arc<Background>) -> Self {
1566 Self {
1567 background,
1568 users: Default::default(),
1569 next_user_id: Mutex::new(1),
1570 orgs: Default::default(),
1571 next_org_id: Mutex::new(1),
1572 org_memberships: Default::default(),
1573 channels: Default::default(),
1574 next_channel_id: Mutex::new(1),
1575 channel_memberships: Default::default(),
1576 channel_messages: Default::default(),
1577 next_channel_message_id: Mutex::new(1),
1578 contacts: Default::default(),
1579 }
1580 }
1581 }
1582
1583 #[async_trait]
1584 impl Db for FakeDb {
1585 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1586 self.background.simulate_random_delay().await;
1587
1588 let mut users = self.users.lock();
1589 if let Some(user) = users
1590 .values()
1591 .find(|user| user.github_login == github_login)
1592 {
1593 Ok(user.id)
1594 } else {
1595 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1596 users.insert(
1597 user_id,
1598 User {
1599 id: user_id,
1600 github_login: github_login.to_string(),
1601 admin,
1602 invite_code: None,
1603 invite_count: 0,
1604 },
1605 );
1606 Ok(user_id)
1607 }
1608 }
1609
1610 async fn get_all_users(&self) -> Result<Vec<User>> {
1611 unimplemented!()
1612 }
1613
1614 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1615 unimplemented!()
1616 }
1617
1618 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1619 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1620 }
1621
1622 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1623 self.background.simulate_random_delay().await;
1624 let users = self.users.lock();
1625 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1626 }
1627
1628 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1629 Ok(self
1630 .users
1631 .lock()
1632 .values()
1633 .find(|user| user.github_login == github_login)
1634 .cloned())
1635 }
1636
1637 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1638 unimplemented!()
1639 }
1640
1641 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1642 unimplemented!()
1643 }
1644
1645 // invite codes
1646
1647 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1648 unimplemented!()
1649 }
1650
1651 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1652 unimplemented!()
1653 }
1654
1655 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1656 unimplemented!()
1657 }
1658
1659 async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
1660 unimplemented!()
1661 }
1662
1663 // contacts
1664
1665 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1666 self.background.simulate_random_delay().await;
1667 let mut contacts = vec![Contact::Accepted {
1668 user_id: id,
1669 should_notify: false,
1670 }];
1671
1672 for contact in self.contacts.lock().iter() {
1673 if contact.requester_id == id {
1674 if contact.accepted {
1675 contacts.push(Contact::Accepted {
1676 user_id: contact.responder_id,
1677 should_notify: contact.should_notify,
1678 });
1679 } else {
1680 contacts.push(Contact::Outgoing {
1681 user_id: contact.responder_id,
1682 });
1683 }
1684 } else if contact.responder_id == id {
1685 if contact.accepted {
1686 contacts.push(Contact::Accepted {
1687 user_id: contact.requester_id,
1688 should_notify: false,
1689 });
1690 } else {
1691 contacts.push(Contact::Incoming {
1692 user_id: contact.requester_id,
1693 should_notify: contact.should_notify,
1694 });
1695 }
1696 }
1697 }
1698
1699 contacts.sort_unstable_by_key(|contact| contact.user_id());
1700 Ok(contacts)
1701 }
1702
1703 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1704 self.background.simulate_random_delay().await;
1705 Ok(self.contacts.lock().iter().any(|contact| {
1706 contact.accepted
1707 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1708 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1709 }))
1710 }
1711
1712 async fn send_contact_request(
1713 &self,
1714 requester_id: UserId,
1715 responder_id: UserId,
1716 ) -> Result<()> {
1717 let mut contacts = self.contacts.lock();
1718 for contact in contacts.iter_mut() {
1719 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1720 if contact.accepted {
1721 Err(anyhow!("contact already exists"))?;
1722 } else {
1723 Err(anyhow!("contact already requested"))?;
1724 }
1725 }
1726 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1727 if contact.accepted {
1728 Err(anyhow!("contact already exists"))?;
1729 } else {
1730 contact.accepted = true;
1731 contact.should_notify = false;
1732 return Ok(());
1733 }
1734 }
1735 }
1736 contacts.push(FakeContact {
1737 requester_id,
1738 responder_id,
1739 accepted: false,
1740 should_notify: true,
1741 });
1742 Ok(())
1743 }
1744
1745 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1746 self.contacts.lock().retain(|contact| {
1747 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1748 });
1749 Ok(())
1750 }
1751
1752 async fn dismiss_contact_notification(
1753 &self,
1754 user_id: UserId,
1755 contact_user_id: UserId,
1756 ) -> Result<()> {
1757 let mut contacts = self.contacts.lock();
1758 for contact in contacts.iter_mut() {
1759 if contact.requester_id == contact_user_id
1760 && contact.responder_id == user_id
1761 && !contact.accepted
1762 {
1763 contact.should_notify = false;
1764 return Ok(());
1765 }
1766 if contact.requester_id == user_id
1767 && contact.responder_id == contact_user_id
1768 && contact.accepted
1769 {
1770 contact.should_notify = false;
1771 return Ok(());
1772 }
1773 }
1774 Err(anyhow!("no such notification"))?
1775 }
1776
1777 async fn respond_to_contact_request(
1778 &self,
1779 responder_id: UserId,
1780 requester_id: UserId,
1781 accept: bool,
1782 ) -> Result<()> {
1783 let mut contacts = self.contacts.lock();
1784 for (ix, contact) in contacts.iter_mut().enumerate() {
1785 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1786 if contact.accepted {
1787 Err(anyhow!("contact already confirmed"))?;
1788 }
1789 if accept {
1790 contact.accepted = true;
1791 contact.should_notify = true;
1792 } else {
1793 contacts.remove(ix);
1794 }
1795 return Ok(());
1796 }
1797 }
1798 Err(anyhow!("no such contact request"))?
1799 }
1800
1801 async fn create_access_token_hash(
1802 &self,
1803 _user_id: UserId,
1804 _access_token_hash: &str,
1805 _max_access_token_count: usize,
1806 ) -> Result<()> {
1807 unimplemented!()
1808 }
1809
1810 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1811 unimplemented!()
1812 }
1813
1814 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1815 unimplemented!()
1816 }
1817
1818 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1819 self.background.simulate_random_delay().await;
1820 let mut orgs = self.orgs.lock();
1821 if orgs.values().any(|org| org.slug == slug) {
1822 Err(anyhow!("org already exists"))?
1823 } else {
1824 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1825 orgs.insert(
1826 org_id,
1827 Org {
1828 id: org_id,
1829 name: name.to_string(),
1830 slug: slug.to_string(),
1831 },
1832 );
1833 Ok(org_id)
1834 }
1835 }
1836
1837 async fn add_org_member(
1838 &self,
1839 org_id: OrgId,
1840 user_id: UserId,
1841 is_admin: bool,
1842 ) -> Result<()> {
1843 self.background.simulate_random_delay().await;
1844 if !self.orgs.lock().contains_key(&org_id) {
1845 Err(anyhow!("org does not exist"))?;
1846 }
1847 if !self.users.lock().contains_key(&user_id) {
1848 Err(anyhow!("user does not exist"))?;
1849 }
1850
1851 self.org_memberships
1852 .lock()
1853 .entry((org_id, user_id))
1854 .or_insert(is_admin);
1855 Ok(())
1856 }
1857
1858 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1859 self.background.simulate_random_delay().await;
1860 if !self.orgs.lock().contains_key(&org_id) {
1861 Err(anyhow!("org does not exist"))?;
1862 }
1863
1864 let mut channels = self.channels.lock();
1865 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1866 channels.insert(
1867 channel_id,
1868 Channel {
1869 id: channel_id,
1870 name: name.to_string(),
1871 owner_id: org_id.0,
1872 owner_is_user: false,
1873 },
1874 );
1875 Ok(channel_id)
1876 }
1877
1878 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1879 self.background.simulate_random_delay().await;
1880 Ok(self
1881 .channels
1882 .lock()
1883 .values()
1884 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1885 .cloned()
1886 .collect())
1887 }
1888
1889 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1890 self.background.simulate_random_delay().await;
1891 let channels = self.channels.lock();
1892 let memberships = self.channel_memberships.lock();
1893 Ok(channels
1894 .values()
1895 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1896 .cloned()
1897 .collect())
1898 }
1899
1900 async fn can_user_access_channel(
1901 &self,
1902 user_id: UserId,
1903 channel_id: ChannelId,
1904 ) -> Result<bool> {
1905 self.background.simulate_random_delay().await;
1906 Ok(self
1907 .channel_memberships
1908 .lock()
1909 .contains_key(&(channel_id, user_id)))
1910 }
1911
1912 async fn add_channel_member(
1913 &self,
1914 channel_id: ChannelId,
1915 user_id: UserId,
1916 is_admin: bool,
1917 ) -> Result<()> {
1918 self.background.simulate_random_delay().await;
1919 if !self.channels.lock().contains_key(&channel_id) {
1920 Err(anyhow!("channel does not exist"))?;
1921 }
1922 if !self.users.lock().contains_key(&user_id) {
1923 Err(anyhow!("user does not exist"))?;
1924 }
1925
1926 self.channel_memberships
1927 .lock()
1928 .entry((channel_id, user_id))
1929 .or_insert(is_admin);
1930 Ok(())
1931 }
1932
1933 async fn create_channel_message(
1934 &self,
1935 channel_id: ChannelId,
1936 sender_id: UserId,
1937 body: &str,
1938 timestamp: OffsetDateTime,
1939 nonce: u128,
1940 ) -> Result<MessageId> {
1941 self.background.simulate_random_delay().await;
1942 if !self.channels.lock().contains_key(&channel_id) {
1943 Err(anyhow!("channel does not exist"))?;
1944 }
1945 if !self.users.lock().contains_key(&sender_id) {
1946 Err(anyhow!("user does not exist"))?;
1947 }
1948
1949 let mut messages = self.channel_messages.lock();
1950 if let Some(message) = messages
1951 .values()
1952 .find(|message| message.nonce.as_u128() == nonce)
1953 {
1954 Ok(message.id)
1955 } else {
1956 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1957 messages.insert(
1958 message_id,
1959 ChannelMessage {
1960 id: message_id,
1961 channel_id,
1962 sender_id,
1963 body: body.to_string(),
1964 sent_at: timestamp,
1965 nonce: Uuid::from_u128(nonce),
1966 },
1967 );
1968 Ok(message_id)
1969 }
1970 }
1971
1972 async fn get_channel_messages(
1973 &self,
1974 channel_id: ChannelId,
1975 count: usize,
1976 before_id: Option<MessageId>,
1977 ) -> Result<Vec<ChannelMessage>> {
1978 let mut messages = self
1979 .channel_messages
1980 .lock()
1981 .values()
1982 .rev()
1983 .filter(|message| {
1984 message.channel_id == channel_id
1985 && message.id < before_id.unwrap_or(MessageId::MAX)
1986 })
1987 .take(count)
1988 .cloned()
1989 .collect::<Vec<_>>();
1990 messages.sort_unstable_by_key(|message| message.id);
1991 Ok(messages)
1992 }
1993
1994 async fn teardown(&self, _: &str) {}
1995
1996 #[cfg(test)]
1997 fn as_fake(&self) -> Option<&FakeDb> {
1998 Some(self)
1999 }
2000 }
2001}