1use crate::Result;
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use futures::StreamExt;
5use nanoid::nanoid;
6use serde::Serialize;
7pub use sqlx::postgres::PgPoolOptions as DbOptions;
8use sqlx::{types::Uuid, FromRow};
9use time::OffsetDateTime;
10
11#[async_trait]
12pub trait Db: Send + Sync {
13 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
14 async fn get_all_users(&self) -> Result<Vec<User>>;
15 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
16 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
17 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
18 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
19 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
20 async fn destroy_user(&self, id: UserId) -> Result<()>;
21
22 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
23 async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
24 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
25
26 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
27 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
28 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
29 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
30 async fn dismiss_contact_notification(
31 &self,
32 responder_id: UserId,
33 requester_id: UserId,
34 ) -> Result<()>;
35 async fn respond_to_contact_request(
36 &self,
37 responder_id: UserId,
38 requester_id: UserId,
39 accept: bool,
40 ) -> Result<()>;
41
42 async fn create_access_token_hash(
43 &self,
44 user_id: UserId,
45 access_token_hash: &str,
46 max_access_token_count: usize,
47 ) -> Result<()>;
48 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
49 #[cfg(any(test, feature = "seed-support"))]
50
51 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
52 #[cfg(any(test, feature = "seed-support"))]
53 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
54 #[cfg(any(test, feature = "seed-support"))]
55 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
56 #[cfg(any(test, feature = "seed-support"))]
57 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
58 #[cfg(any(test, feature = "seed-support"))]
59
60 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
61 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
62 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
63 -> Result<bool>;
64 #[cfg(any(test, feature = "seed-support"))]
65 async fn add_channel_member(
66 &self,
67 channel_id: ChannelId,
68 user_id: UserId,
69 is_admin: bool,
70 ) -> Result<()>;
71 async fn create_channel_message(
72 &self,
73 channel_id: ChannelId,
74 sender_id: UserId,
75 body: &str,
76 timestamp: OffsetDateTime,
77 nonce: u128,
78 ) -> Result<MessageId>;
79 async fn get_channel_messages(
80 &self,
81 channel_id: ChannelId,
82 count: usize,
83 before_id: Option<MessageId>,
84 ) -> Result<Vec<ChannelMessage>>;
85 #[cfg(test)]
86 async fn teardown(&self, url: &str);
87 #[cfg(test)]
88 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
89}
90
91pub struct PostgresDb {
92 pool: sqlx::PgPool,
93}
94
95impl PostgresDb {
96 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
97 let pool = DbOptions::new()
98 .max_connections(max_connections)
99 .connect(&url)
100 .await
101 .context("failed to connect to postgres database")?;
102 Ok(Self { pool })
103 }
104}
105
106#[async_trait]
107impl Db for PostgresDb {
108 // users
109
110 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
111 let query = "
112 INSERT INTO users (github_login, admin)
113 VALUES ($1, $2)
114 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
115 RETURNING id
116 ";
117 Ok(sqlx::query_scalar(query)
118 .bind(github_login)
119 .bind(admin)
120 .fetch_one(&self.pool)
121 .await
122 .map(UserId)?)
123 }
124
125 async fn get_all_users(&self) -> Result<Vec<User>> {
126 let query = "SELECT * FROM users ORDER BY github_login ASC";
127 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
128 }
129
130 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
131 let like_string = fuzzy_like_string(name_query);
132 let query = "
133 SELECT users.*
134 FROM users
135 WHERE github_login ILIKE $1
136 ORDER BY github_login <-> $2
137 LIMIT $3
138 ";
139 Ok(sqlx::query_as(query)
140 .bind(like_string)
141 .bind(name_query)
142 .bind(limit)
143 .fetch_all(&self.pool)
144 .await?)
145 }
146
147 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
148 let users = self.get_users_by_ids(vec![id]).await?;
149 Ok(users.into_iter().next())
150 }
151
152 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
153 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
154 let query = "
155 SELECT users.*
156 FROM users
157 WHERE users.id = ANY ($1)
158 ";
159 Ok(sqlx::query_as(query)
160 .bind(&ids)
161 .fetch_all(&self.pool)
162 .await?)
163 }
164
165 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
166 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
167 Ok(sqlx::query_as(query)
168 .bind(github_login)
169 .fetch_optional(&self.pool)
170 .await?)
171 }
172
173 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
174 let query = "UPDATE users SET admin = $1 WHERE id = $2";
175 Ok(sqlx::query(query)
176 .bind(is_admin)
177 .bind(id.0)
178 .execute(&self.pool)
179 .await
180 .map(drop)?)
181 }
182
183 async fn destroy_user(&self, id: UserId) -> Result<()> {
184 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
185 sqlx::query(query)
186 .bind(id.0)
187 .execute(&self.pool)
188 .await
189 .map(drop)?;
190 let query = "DELETE FROM users WHERE id = $1;";
191 Ok(sqlx::query(query)
192 .bind(id.0)
193 .execute(&self.pool)
194 .await
195 .map(drop)?)
196 }
197
198 // invite codes
199
200 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
201 let mut tx = self.pool.begin().await?;
202 sqlx::query(
203 "
204 UPDATE users
205 SET invite_code = $1
206 WHERE id = $2 AND invite_code IS NULL
207 ",
208 )
209 .bind(nanoid!(16))
210 .bind(id)
211 .execute(&mut tx)
212 .await?;
213 sqlx::query(
214 "
215 UPDATE users
216 SET invite_count = $1
217 WHERE id = $2
218 ",
219 )
220 .bind(count)
221 .bind(id)
222 .execute(&mut tx)
223 .await?;
224 tx.commit().await?;
225 Ok(())
226 }
227
228 async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
229 let result: Option<(String, i32)> = sqlx::query_as(
230 "
231 SELECT invite_code, invite_count
232 FROM users
233 WHERE id = $1 AND invite_code IS NOT NULL
234 ",
235 )
236 .bind(id)
237 .fetch_optional(&self.pool)
238 .await?;
239 if let Some((code, count)) = result {
240 Ok(Some((code, count.try_into()?)))
241 } else {
242 Ok(None)
243 }
244 }
245
246 async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
247 let mut tx = self.pool.begin().await?;
248
249 let inviter_id: UserId = sqlx::query_scalar(
250 "
251 UPDATE users
252 SET invite_count = invite_count - 1
253 WHERE
254 invite_code = $1 AND
255 invite_count > 0
256 RETURNING id
257 ",
258 )
259 .bind(code)
260 .fetch_optional(&mut tx)
261 .await?
262 .ok_or_else(|| anyhow!("invite code not found"))?;
263 let invitee_id = sqlx::query_scalar(
264 "
265 INSERT INTO users
266 (github_login, admin, inviter_id)
267 VALUES
268 ($1, 'f', $2)
269 RETURNING id
270 ",
271 )
272 .bind(login)
273 .bind(inviter_id)
274 .fetch_one(&mut tx)
275 .await
276 .map(UserId)?;
277
278 sqlx::query(
279 "
280 INSERT INTO contacts
281 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
282 VALUES
283 ($1, $2, 't', 't', 't')
284 ",
285 )
286 .bind(inviter_id)
287 .bind(invitee_id)
288 .execute(&mut tx)
289 .await?;
290
291 tx.commit().await?;
292 Ok(invitee_id)
293 }
294
295 // contacts
296
297 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
298 let query = "
299 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
300 FROM contacts
301 WHERE user_id_a = $1 OR user_id_b = $1;
302 ";
303
304 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
305 .bind(user_id)
306 .fetch(&self.pool);
307
308 let mut contacts = vec![Contact::Accepted {
309 user_id,
310 should_notify: false,
311 }];
312 while let Some(row) = rows.next().await {
313 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
314
315 if user_id_a == user_id {
316 if accepted {
317 contacts.push(Contact::Accepted {
318 user_id: user_id_b,
319 should_notify: should_notify && a_to_b,
320 });
321 } else if a_to_b {
322 contacts.push(Contact::Outgoing { user_id: user_id_b })
323 } else {
324 contacts.push(Contact::Incoming {
325 user_id: user_id_b,
326 should_notify,
327 });
328 }
329 } else {
330 if accepted {
331 contacts.push(Contact::Accepted {
332 user_id: user_id_a,
333 should_notify: should_notify && !a_to_b,
334 });
335 } else if a_to_b {
336 contacts.push(Contact::Incoming {
337 user_id: user_id_a,
338 should_notify,
339 });
340 } else {
341 contacts.push(Contact::Outgoing { user_id: user_id_a });
342 }
343 }
344 }
345
346 contacts.sort_unstable_by_key(|contact| contact.user_id());
347
348 Ok(contacts)
349 }
350
351 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
352 let (id_a, id_b) = if user_id_1 < user_id_2 {
353 (user_id_1, user_id_2)
354 } else {
355 (user_id_2, user_id_1)
356 };
357
358 let query = "
359 SELECT 1 FROM contacts
360 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
361 LIMIT 1
362 ";
363 Ok(sqlx::query_scalar::<_, i32>(query)
364 .bind(id_a.0)
365 .bind(id_b.0)
366 .fetch_optional(&self.pool)
367 .await?
368 .is_some())
369 }
370
371 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
372 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
373 (sender_id, receiver_id, true)
374 } else {
375 (receiver_id, sender_id, false)
376 };
377 let query = "
378 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
379 VALUES ($1, $2, $3, 'f', 't')
380 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
381 SET
382 accepted = 't',
383 should_notify = 'f'
384 WHERE
385 NOT contacts.accepted AND
386 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
387 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
388 ";
389 let result = sqlx::query(query)
390 .bind(id_a.0)
391 .bind(id_b.0)
392 .bind(a_to_b)
393 .execute(&self.pool)
394 .await?;
395
396 if result.rows_affected() == 1 {
397 Ok(())
398 } else {
399 Err(anyhow!("contact already requested"))?
400 }
401 }
402
403 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
404 let (id_a, id_b) = if responder_id < requester_id {
405 (responder_id, requester_id)
406 } else {
407 (requester_id, responder_id)
408 };
409 let query = "
410 DELETE FROM contacts
411 WHERE user_id_a = $1 AND user_id_b = $2;
412 ";
413 let result = sqlx::query(query)
414 .bind(id_a.0)
415 .bind(id_b.0)
416 .execute(&self.pool)
417 .await?;
418
419 if result.rows_affected() == 1 {
420 Ok(())
421 } else {
422 Err(anyhow!("no such contact"))?
423 }
424 }
425
426 async fn dismiss_contact_notification(
427 &self,
428 user_id: UserId,
429 contact_user_id: UserId,
430 ) -> Result<()> {
431 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
432 (user_id, contact_user_id, true)
433 } else {
434 (contact_user_id, user_id, false)
435 };
436
437 let query = "
438 UPDATE contacts
439 SET should_notify = 'f'
440 WHERE
441 user_id_a = $1 AND user_id_b = $2 AND
442 (
443 (a_to_b = $3 AND accepted) OR
444 (a_to_b != $3 AND NOT accepted)
445 );
446 ";
447
448 let result = sqlx::query(query)
449 .bind(id_a.0)
450 .bind(id_b.0)
451 .bind(a_to_b)
452 .execute(&self.pool)
453 .await?;
454
455 if result.rows_affected() == 0 {
456 Err(anyhow!("no such contact request"))?;
457 }
458
459 Ok(())
460 }
461
462 async fn respond_to_contact_request(
463 &self,
464 responder_id: UserId,
465 requester_id: UserId,
466 accept: bool,
467 ) -> Result<()> {
468 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
469 (responder_id, requester_id, false)
470 } else {
471 (requester_id, responder_id, true)
472 };
473 let result = if accept {
474 let query = "
475 UPDATE contacts
476 SET accepted = 't', should_notify = 't'
477 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
478 ";
479 sqlx::query(query)
480 .bind(id_a.0)
481 .bind(id_b.0)
482 .bind(a_to_b)
483 .execute(&self.pool)
484 .await?
485 } else {
486 let query = "
487 DELETE FROM contacts
488 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
489 ";
490 sqlx::query(query)
491 .bind(id_a.0)
492 .bind(id_b.0)
493 .bind(a_to_b)
494 .execute(&self.pool)
495 .await?
496 };
497 if result.rows_affected() == 1 {
498 Ok(())
499 } else {
500 Err(anyhow!("no such contact request"))?
501 }
502 }
503
504 // access tokens
505
506 async fn create_access_token_hash(
507 &self,
508 user_id: UserId,
509 access_token_hash: &str,
510 max_access_token_count: usize,
511 ) -> Result<()> {
512 let insert_query = "
513 INSERT INTO access_tokens (user_id, hash)
514 VALUES ($1, $2);
515 ";
516 let cleanup_query = "
517 DELETE FROM access_tokens
518 WHERE id IN (
519 SELECT id from access_tokens
520 WHERE user_id = $1
521 ORDER BY id DESC
522 OFFSET $3
523 )
524 ";
525
526 let mut tx = self.pool.begin().await?;
527 sqlx::query(insert_query)
528 .bind(user_id.0)
529 .bind(access_token_hash)
530 .execute(&mut tx)
531 .await?;
532 sqlx::query(cleanup_query)
533 .bind(user_id.0)
534 .bind(access_token_hash)
535 .bind(max_access_token_count as u32)
536 .execute(&mut tx)
537 .await?;
538 Ok(tx.commit().await?)
539 }
540
541 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
542 let query = "
543 SELECT hash
544 FROM access_tokens
545 WHERE user_id = $1
546 ORDER BY id DESC
547 ";
548 Ok(sqlx::query_scalar(query)
549 .bind(user_id.0)
550 .fetch_all(&self.pool)
551 .await?)
552 }
553
554 // orgs
555
556 #[allow(unused)] // Help rust-analyzer
557 #[cfg(any(test, feature = "seed-support"))]
558 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
559 let query = "
560 SELECT *
561 FROM orgs
562 WHERE slug = $1
563 ";
564 Ok(sqlx::query_as(query)
565 .bind(slug)
566 .fetch_optional(&self.pool)
567 .await?)
568 }
569
570 #[cfg(any(test, feature = "seed-support"))]
571 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
572 let query = "
573 INSERT INTO orgs (name, slug)
574 VALUES ($1, $2)
575 RETURNING id
576 ";
577 Ok(sqlx::query_scalar(query)
578 .bind(name)
579 .bind(slug)
580 .fetch_one(&self.pool)
581 .await
582 .map(OrgId)?)
583 }
584
585 #[cfg(any(test, feature = "seed-support"))]
586 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
587 let query = "
588 INSERT INTO org_memberships (org_id, user_id, admin)
589 VALUES ($1, $2, $3)
590 ON CONFLICT DO NOTHING
591 ";
592 Ok(sqlx::query(query)
593 .bind(org_id.0)
594 .bind(user_id.0)
595 .bind(is_admin)
596 .execute(&self.pool)
597 .await
598 .map(drop)?)
599 }
600
601 // channels
602
603 #[cfg(any(test, feature = "seed-support"))]
604 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
605 let query = "
606 INSERT INTO channels (owner_id, owner_is_user, name)
607 VALUES ($1, false, $2)
608 RETURNING id
609 ";
610 Ok(sqlx::query_scalar(query)
611 .bind(org_id.0)
612 .bind(name)
613 .fetch_one(&self.pool)
614 .await
615 .map(ChannelId)?)
616 }
617
618 #[allow(unused)] // Help rust-analyzer
619 #[cfg(any(test, feature = "seed-support"))]
620 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
621 let query = "
622 SELECT *
623 FROM channels
624 WHERE
625 channels.owner_is_user = false AND
626 channels.owner_id = $1
627 ";
628 Ok(sqlx::query_as(query)
629 .bind(org_id.0)
630 .fetch_all(&self.pool)
631 .await?)
632 }
633
634 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
635 let query = "
636 SELECT
637 channels.*
638 FROM
639 channel_memberships, channels
640 WHERE
641 channel_memberships.user_id = $1 AND
642 channel_memberships.channel_id = channels.id
643 ";
644 Ok(sqlx::query_as(query)
645 .bind(user_id.0)
646 .fetch_all(&self.pool)
647 .await?)
648 }
649
650 async fn can_user_access_channel(
651 &self,
652 user_id: UserId,
653 channel_id: ChannelId,
654 ) -> Result<bool> {
655 let query = "
656 SELECT id
657 FROM channel_memberships
658 WHERE user_id = $1 AND channel_id = $2
659 LIMIT 1
660 ";
661 Ok(sqlx::query_scalar::<_, i32>(query)
662 .bind(user_id.0)
663 .bind(channel_id.0)
664 .fetch_optional(&self.pool)
665 .await
666 .map(|e| e.is_some())?)
667 }
668
669 #[cfg(any(test, feature = "seed-support"))]
670 async fn add_channel_member(
671 &self,
672 channel_id: ChannelId,
673 user_id: UserId,
674 is_admin: bool,
675 ) -> Result<()> {
676 let query = "
677 INSERT INTO channel_memberships (channel_id, user_id, admin)
678 VALUES ($1, $2, $3)
679 ON CONFLICT DO NOTHING
680 ";
681 Ok(sqlx::query(query)
682 .bind(channel_id.0)
683 .bind(user_id.0)
684 .bind(is_admin)
685 .execute(&self.pool)
686 .await
687 .map(drop)?)
688 }
689
690 // messages
691
692 async fn create_channel_message(
693 &self,
694 channel_id: ChannelId,
695 sender_id: UserId,
696 body: &str,
697 timestamp: OffsetDateTime,
698 nonce: u128,
699 ) -> Result<MessageId> {
700 let query = "
701 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
702 VALUES ($1, $2, $3, $4, $5)
703 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
704 RETURNING id
705 ";
706 Ok(sqlx::query_scalar(query)
707 .bind(channel_id.0)
708 .bind(sender_id.0)
709 .bind(body)
710 .bind(timestamp)
711 .bind(Uuid::from_u128(nonce))
712 .fetch_one(&self.pool)
713 .await
714 .map(MessageId)?)
715 }
716
717 async fn get_channel_messages(
718 &self,
719 channel_id: ChannelId,
720 count: usize,
721 before_id: Option<MessageId>,
722 ) -> Result<Vec<ChannelMessage>> {
723 let query = r#"
724 SELECT * FROM (
725 SELECT
726 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
727 FROM
728 channel_messages
729 WHERE
730 channel_id = $1 AND
731 id < $2
732 ORDER BY id DESC
733 LIMIT $3
734 ) as recent_messages
735 ORDER BY id ASC
736 "#;
737 Ok(sqlx::query_as(query)
738 .bind(channel_id.0)
739 .bind(before_id.unwrap_or(MessageId::MAX))
740 .bind(count as i64)
741 .fetch_all(&self.pool)
742 .await?)
743 }
744
745 #[cfg(test)]
746 async fn teardown(&self, url: &str) {
747 use util::ResultExt;
748
749 let query = "
750 SELECT pg_terminate_backend(pg_stat_activity.pid)
751 FROM pg_stat_activity
752 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
753 ";
754 sqlx::query(query).execute(&self.pool).await.log_err();
755 self.pool.close().await;
756 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
757 .await
758 .log_err();
759 }
760
761 #[cfg(test)]
762 fn as_fake(&self) -> Option<&tests::FakeDb> {
763 None
764 }
765}
766
/// Declares `$name`, a newtype around an `i32` database key.
///
/// The type is transparent to both sqlx and serde. It also gets a `MAX`
/// sentinel, lossy conversions to and from the `u64` ids used on the wire,
/// and a `Display` impl that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id; used as an open upper bound.
            #[allow(unused)] // not every id type uses every helper
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a wire value; `as` truncates values that do
            /// not fit in `i32`.
            #[allow(unused)] // not every id type uses every helper
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Widens the id to a wire value (negative ids sign-extend).
            #[allow(unused)] // not every id type uses every helper
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
798
799id_type!(UserId);
800#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
801pub struct User {
802 pub id: UserId,
803 pub github_login: String,
804 pub admin: bool,
805 pub invite_code: Option<String>,
806 pub invite_count: i32,
807}
808
809id_type!(OrgId);
810#[derive(FromRow)]
811pub struct Org {
812 pub id: OrgId,
813 pub name: String,
814 pub slug: String,
815}
816
817id_type!(ChannelId);
818#[derive(Clone, Debug, FromRow, Serialize)]
819pub struct Channel {
820 pub id: ChannelId,
821 pub name: String,
822 pub owner_id: i32,
823 pub owner_is_user: bool,
824}
825
826id_type!(MessageId);
827#[derive(Clone, Debug, FromRow)]
828pub struct ChannelMessage {
829 pub id: MessageId,
830 pub channel_id: ChannelId,
831 pub sender_id: UserId,
832 pub body: String,
833 pub sent_at: OffsetDateTime,
834 pub nonce: Uuid,
835}
836
837#[derive(Clone, Debug, PartialEq, Eq)]
838pub enum Contact {
839 Accepted {
840 user_id: UserId,
841 should_notify: bool,
842 },
843 Outgoing {
844 user_id: UserId,
845 },
846 Incoming {
847 user_id: UserId,
848 should_notify: bool,
849 },
850}
851
852impl Contact {
853 pub fn user_id(&self) -> UserId {
854 match self {
855 Contact::Accepted { user_id, .. } => *user_id,
856 Contact::Outgoing { user_id } => *user_id,
857 Contact::Incoming { user_id, .. } => *user_id,
858 }
859 }
860}
861
862#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
863pub struct IncomingContactRequest {
864 pub requester_id: UserId,
865 pub should_notify: bool,
866}
867
/// Builds an SQL `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string` in order, with anything between them.
///
/// Non-alphanumeric characters (including the `LIKE` wildcards `%` and `_`)
/// are dropped. For example, `"x y"` becomes `"%x%y%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
879
880#[cfg(test)]
881pub mod tests {
882 use super::*;
883 use anyhow::anyhow;
884 use collections::BTreeMap;
885 use gpui::executor::Background;
886 use lazy_static::lazy_static;
887 use parking_lot::Mutex;
888 use rand::prelude::*;
889 use sqlx::{
890 migrate::{MigrateDatabase, Migrator},
891 Postgres,
892 };
893 use std::{path::Path, sync::Arc};
894 use util::post_inc;
895
896 #[tokio::test(flavor = "multi_thread")]
897 async fn test_get_users_by_ids() {
898 for test_db in [
899 TestDb::postgres().await,
900 TestDb::fake(Arc::new(gpui::executor::Background::new())),
901 ] {
902 let db = test_db.db();
903
904 let user = db.create_user("user", false).await.unwrap();
905 let friend1 = db.create_user("friend-1", false).await.unwrap();
906 let friend2 = db.create_user("friend-2", false).await.unwrap();
907 let friend3 = db.create_user("friend-3", false).await.unwrap();
908
909 assert_eq!(
910 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
911 .await
912 .unwrap(),
913 vec![
914 User {
915 id: user,
916 github_login: "user".to_string(),
917 admin: false,
918 ..Default::default()
919 },
920 User {
921 id: friend1,
922 github_login: "friend-1".to_string(),
923 admin: false,
924 ..Default::default()
925 },
926 User {
927 id: friend2,
928 github_login: "friend-2".to_string(),
929 admin: false,
930 ..Default::default()
931 },
932 User {
933 id: friend3,
934 github_login: "friend-3".to_string(),
935 admin: false,
936 ..Default::default()
937 }
938 ]
939 );
940 }
941 }
942
943 #[tokio::test(flavor = "multi_thread")]
944 async fn test_recent_channel_messages() {
945 for test_db in [
946 TestDb::postgres().await,
947 TestDb::fake(Arc::new(gpui::executor::Background::new())),
948 ] {
949 let db = test_db.db();
950 let user = db.create_user("user", false).await.unwrap();
951 let org = db.create_org("org", "org").await.unwrap();
952 let channel = db.create_org_channel(org, "channel").await.unwrap();
953 for i in 0..10 {
954 db.create_channel_message(
955 channel,
956 user,
957 &i.to_string(),
958 OffsetDateTime::now_utc(),
959 i,
960 )
961 .await
962 .unwrap();
963 }
964
965 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
966 assert_eq!(
967 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
968 ["5", "6", "7", "8", "9"]
969 );
970
971 let prev_messages = db
972 .get_channel_messages(channel, 4, Some(messages[0].id))
973 .await
974 .unwrap();
975 assert_eq!(
976 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
977 ["1", "2", "3", "4"]
978 );
979 }
980 }
981
982 #[tokio::test(flavor = "multi_thread")]
983 async fn test_channel_message_nonces() {
984 for test_db in [
985 TestDb::postgres().await,
986 TestDb::fake(Arc::new(gpui::executor::Background::new())),
987 ] {
988 let db = test_db.db();
989 let user = db.create_user("user", false).await.unwrap();
990 let org = db.create_org("org", "org").await.unwrap();
991 let channel = db.create_org_channel(org, "channel").await.unwrap();
992
993 let msg1_id = db
994 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
995 .await
996 .unwrap();
997 let msg2_id = db
998 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
999 .await
1000 .unwrap();
1001 let msg3_id = db
1002 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1003 .await
1004 .unwrap();
1005 let msg4_id = db
1006 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1007 .await
1008 .unwrap();
1009
1010 assert_ne!(msg1_id, msg2_id);
1011 assert_eq!(msg1_id, msg3_id);
1012 assert_eq!(msg2_id, msg4_id);
1013 }
1014 }
1015
1016 #[tokio::test(flavor = "multi_thread")]
1017 async fn test_create_access_tokens() {
1018 let test_db = TestDb::postgres().await;
1019 let db = test_db.db();
1020 let user = db.create_user("the-user", false).await.unwrap();
1021
1022 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1023 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1024 assert_eq!(
1025 db.get_access_token_hashes(user).await.unwrap(),
1026 &["h2".to_string(), "h1".to_string()]
1027 );
1028
1029 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1030 assert_eq!(
1031 db.get_access_token_hashes(user).await.unwrap(),
1032 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1033 );
1034
1035 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1036 assert_eq!(
1037 db.get_access_token_hashes(user).await.unwrap(),
1038 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1039 );
1040
1041 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1042 assert_eq!(
1043 db.get_access_token_hashes(user).await.unwrap(),
1044 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1045 );
1046 }
1047
#[test]
fn test_fuzzy_like_string() {
    let cases = [("abcd", "%a%b%c%d%"), ("x y", "%x%y%"), (" z ", "%z%")];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
1054
1055 #[tokio::test(flavor = "multi_thread")]
1056 async fn test_fuzzy_search_users() {
1057 let test_db = TestDb::postgres().await;
1058 let db = test_db.db();
1059 for github_login in [
1060 "California",
1061 "colorado",
1062 "oregon",
1063 "washington",
1064 "florida",
1065 "delaware",
1066 "rhode-island",
1067 ] {
1068 db.create_user(github_login, false).await.unwrap();
1069 }
1070
1071 assert_eq!(
1072 fuzzy_search_user_names(db, "clr").await,
1073 &["colorado", "California"]
1074 );
1075 assert_eq!(
1076 fuzzy_search_user_names(db, "ro").await,
1077 &["rhode-island", "colorado", "oregon"],
1078 );
1079
1080 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1081 db.fuzzy_search_users(query, 10)
1082 .await
1083 .unwrap()
1084 .into_iter()
1085 .map(|user| user.github_login)
1086 .collect::<Vec<_>>()
1087 }
1088 }
1089
1090 #[tokio::test(flavor = "multi_thread")]
1091 async fn test_add_contacts() {
1092 for test_db in [
1093 TestDb::postgres().await,
1094 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1095 ] {
1096 let db = test_db.db();
1097
1098 let user_1 = db.create_user("user1", false).await.unwrap();
1099 let user_2 = db.create_user("user2", false).await.unwrap();
1100 let user_3 = db.create_user("user3", false).await.unwrap();
1101
1102 // User starts with no contacts
1103 assert_eq!(
1104 db.get_contacts(user_1).await.unwrap(),
1105 vec![Contact::Accepted {
1106 user_id: user_1,
1107 should_notify: false
1108 }],
1109 );
1110
1111 // User requests a contact. Both users see the pending request.
1112 db.send_contact_request(user_1, user_2).await.unwrap();
1113 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1114 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1115 assert_eq!(
1116 db.get_contacts(user_1).await.unwrap(),
1117 &[
1118 Contact::Accepted {
1119 user_id: user_1,
1120 should_notify: false
1121 },
1122 Contact::Outgoing { user_id: user_2 }
1123 ],
1124 );
1125 assert_eq!(
1126 db.get_contacts(user_2).await.unwrap(),
1127 &[
1128 Contact::Incoming {
1129 user_id: user_1,
1130 should_notify: true
1131 },
1132 Contact::Accepted {
1133 user_id: user_2,
1134 should_notify: false
1135 },
1136 ]
1137 );
1138
1139 // User 2 dismisses the contact request notification without accepting or rejecting.
1140 // We shouldn't notify them again.
1141 db.dismiss_contact_notification(user_1, user_2)
1142 .await
1143 .unwrap_err();
1144 db.dismiss_contact_notification(user_2, user_1)
1145 .await
1146 .unwrap();
1147 assert_eq!(
1148 db.get_contacts(user_2).await.unwrap(),
1149 &[
1150 Contact::Incoming {
1151 user_id: user_1,
1152 should_notify: false
1153 },
1154 Contact::Accepted {
1155 user_id: user_2,
1156 should_notify: false
1157 },
1158 ]
1159 );
1160
1161 // User can't accept their own contact request
1162 db.respond_to_contact_request(user_1, user_2, true)
1163 .await
1164 .unwrap_err();
1165
1166 // User accepts a contact request. Both users see the contact.
1167 db.respond_to_contact_request(user_2, user_1, true)
1168 .await
1169 .unwrap();
1170 assert_eq!(
1171 db.get_contacts(user_1).await.unwrap(),
1172 &[
1173 Contact::Accepted {
1174 user_id: user_1,
1175 should_notify: false
1176 },
1177 Contact::Accepted {
1178 user_id: user_2,
1179 should_notify: true
1180 }
1181 ],
1182 );
1183 assert!(db.has_contact(user_1, user_2).await.unwrap());
1184 assert!(db.has_contact(user_2, user_1).await.unwrap());
1185 assert_eq!(
1186 db.get_contacts(user_2).await.unwrap(),
1187 &[
1188 Contact::Accepted {
1189 user_id: user_1,
1190 should_notify: false,
1191 },
1192 Contact::Accepted {
1193 user_id: user_2,
1194 should_notify: false,
1195 },
1196 ]
1197 );
1198
1199 // Users cannot re-request existing contacts.
1200 db.send_contact_request(user_1, user_2).await.unwrap_err();
1201 db.send_contact_request(user_2, user_1).await.unwrap_err();
1202
1203 // Users can't dismiss notifications of them accepting other users' requests.
1204 db.dismiss_contact_notification(user_2, user_1)
1205 .await
1206 .unwrap_err();
1207 assert_eq!(
1208 db.get_contacts(user_1).await.unwrap(),
1209 &[
1210 Contact::Accepted {
1211 user_id: user_1,
1212 should_notify: false
1213 },
1214 Contact::Accepted {
1215 user_id: user_2,
1216 should_notify: true,
1217 },
1218 ]
1219 );
1220
1221 // Users can dismiss notifications of other users accepting their requests.
1222 db.dismiss_contact_notification(user_1, user_2)
1223 .await
1224 .unwrap();
1225 assert_eq!(
1226 db.get_contacts(user_1).await.unwrap(),
1227 &[
1228 Contact::Accepted {
1229 user_id: user_1,
1230 should_notify: false
1231 },
1232 Contact::Accepted {
1233 user_id: user_2,
1234 should_notify: false,
1235 },
1236 ]
1237 );
1238
1239 // Users send each other concurrent contact requests and
1240 // see that they are immediately accepted.
1241 db.send_contact_request(user_1, user_3).await.unwrap();
1242 db.send_contact_request(user_3, user_1).await.unwrap();
1243 assert_eq!(
1244 db.get_contacts(user_1).await.unwrap(),
1245 &[
1246 Contact::Accepted {
1247 user_id: user_1,
1248 should_notify: false
1249 },
1250 Contact::Accepted {
1251 user_id: user_2,
1252 should_notify: false,
1253 },
1254 Contact::Accepted {
1255 user_id: user_3,
1256 should_notify: false
1257 },
1258 ]
1259 );
1260 assert_eq!(
1261 db.get_contacts(user_3).await.unwrap(),
1262 &[
1263 Contact::Accepted {
1264 user_id: user_1,
1265 should_notify: false
1266 },
1267 Contact::Accepted {
1268 user_id: user_3,
1269 should_notify: false
1270 }
1271 ],
1272 );
1273
1274 // User declines a contact request. Both users see that it is gone.
1275 db.send_contact_request(user_2, user_3).await.unwrap();
1276 db.respond_to_contact_request(user_3, user_2, false)
1277 .await
1278 .unwrap();
1279 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1280 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1281 assert_eq!(
1282 db.get_contacts(user_2).await.unwrap(),
1283 &[
1284 Contact::Accepted {
1285 user_id: user_1,
1286 should_notify: false
1287 },
1288 Contact::Accepted {
1289 user_id: user_2,
1290 should_notify: false
1291 }
1292 ]
1293 );
1294 assert_eq!(
1295 db.get_contacts(user_3).await.unwrap(),
1296 &[
1297 Contact::Accepted {
1298 user_id: user_1,
1299 should_notify: false
1300 },
1301 Contact::Accepted {
1302 user_id: user_3,
1303 should_notify: false
1304 }
1305 ],
1306 );
1307 }
1308 }
1309
1310 #[tokio::test(flavor = "multi_thread")]
1311 async fn test_invite_codes() {
1312 let postgres = TestDb::postgres().await;
1313 let db = postgres.db();
1314 let user1 = db.create_user("user-1", false).await.unwrap();
1315
1316 // Initially, user 1 has no invite code
1317 assert_eq!(db.get_invite_code(user1).await.unwrap(), None);
1318
1319 // User 1 creates an invite code that can be used twice.
1320 db.set_invite_count(user1, 2).await.unwrap();
1321 let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1322 assert_eq!(invite_count, 2);
1323
1324 // User 2 redeems the invite code and becomes a contact of user 1.
1325 let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
1326 let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1327 assert_eq!(invite_count, 1);
1328 assert_eq!(
1329 db.get_contacts(user1).await.unwrap(),
1330 [
1331 Contact::Accepted {
1332 user_id: user1,
1333 should_notify: false
1334 },
1335 Contact::Accepted {
1336 user_id: user2,
1337 should_notify: true
1338 }
1339 ]
1340 );
1341 assert_eq!(
1342 db.get_contacts(user2).await.unwrap(),
1343 [
1344 Contact::Accepted {
1345 user_id: user1,
1346 should_notify: false
1347 },
1348 Contact::Accepted {
1349 user_id: user2,
1350 should_notify: false
1351 }
1352 ]
1353 );
1354
1355 // User 3 redeems the invite code and becomes a contact of user 1.
1356 let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
1357 let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1358 assert_eq!(invite_count, 0);
1359 assert_eq!(
1360 db.get_contacts(user1).await.unwrap(),
1361 [
1362 Contact::Accepted {
1363 user_id: user1,
1364 should_notify: false
1365 },
1366 Contact::Accepted {
1367 user_id: user2,
1368 should_notify: true
1369 },
1370 Contact::Accepted {
1371 user_id: user3,
1372 should_notify: true
1373 }
1374 ]
1375 );
1376 assert_eq!(
1377 db.get_contacts(user3).await.unwrap(),
1378 [
1379 Contact::Accepted {
1380 user_id: user1,
1381 should_notify: false
1382 },
1383 Contact::Accepted {
1384 user_id: user3,
1385 should_notify: false
1386 },
1387 ]
1388 );
1389
1390 // Trying to reedem the code for the third time results in an error.
1391 db.redeem_invite_code(&invite_code, "user-4")
1392 .await
1393 .unwrap_err();
1394
1395 // Invite count can be updated after the code has been created.
1396 db.set_invite_count(user1, 2).await.unwrap();
1397 let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1398 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
1399 assert_eq!(invite_count, 2);
1400
1401 // User 4 can now redeem the invite code and becomes a contact of user 1.
1402 let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
1403 let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1404 assert_eq!(invite_count, 1);
1405 assert_eq!(
1406 db.get_contacts(user1).await.unwrap(),
1407 [
1408 Contact::Accepted {
1409 user_id: user1,
1410 should_notify: false
1411 },
1412 Contact::Accepted {
1413 user_id: user2,
1414 should_notify: true
1415 },
1416 Contact::Accepted {
1417 user_id: user3,
1418 should_notify: true
1419 },
1420 Contact::Accepted {
1421 user_id: user4,
1422 should_notify: true
1423 }
1424 ]
1425 );
1426 assert_eq!(
1427 db.get_contacts(user4).await.unwrap(),
1428 [
1429 Contact::Accepted {
1430 user_id: user1,
1431 should_notify: false
1432 },
1433 Contact::Accepted {
1434 user_id: user4,
1435 should_notify: false
1436 },
1437 ]
1438 );
1439
1440 // An existing user cannot redeem invite codes.
1441 db.redeem_invite_code(&invite_code, "user-2")
1442 .await
1443 .unwrap_err();
1444 let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
1445 assert_eq!(invite_count, 1);
1446 }
1447
1448 pub struct TestDb {
1449 pub db: Option<Arc<dyn Db>>,
1450 pub url: String,
1451 }
1452
1453 impl TestDb {
1454 pub async fn postgres() -> Self {
1455 lazy_static! {
1456 static ref LOCK: Mutex<()> = Mutex::new(());
1457 }
1458
1459 let _guard = LOCK.lock();
1460 let mut rng = StdRng::from_entropy();
1461 let name = format!("zed-test-{}", rng.gen::<u128>());
1462 let url = format!("postgres://postgres@localhost/{}", name);
1463 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1464 Postgres::create_database(&url)
1465 .await
1466 .expect("failed to create test db");
1467 let db = PostgresDb::new(&url, 5).await.unwrap();
1468 let migrator = Migrator::new(migrations_path).await.unwrap();
1469 migrator.run(&db.pool).await.unwrap();
1470 Self {
1471 db: Some(Arc::new(db)),
1472 url,
1473 }
1474 }
1475
1476 pub fn fake(background: Arc<Background>) -> Self {
1477 Self {
1478 db: Some(Arc::new(FakeDb::new(background))),
1479 url: Default::default(),
1480 }
1481 }
1482
1483 pub fn db(&self) -> &Arc<dyn Db> {
1484 self.db.as_ref().unwrap()
1485 }
1486 }
1487
1488 impl Drop for TestDb {
1489 fn drop(&mut self) {
1490 if let Some(db) = self.db.take() {
1491 futures::executor::block_on(db.teardown(&self.url));
1492 }
1493 }
1494 }
1495
1496 pub struct FakeDb {
1497 background: Arc<Background>,
1498 pub users: Mutex<BTreeMap<UserId, User>>,
1499 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1500 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1501 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1502 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1503 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1504 pub contacts: Mutex<Vec<FakeContact>>,
1505 next_channel_message_id: Mutex<i32>,
1506 next_user_id: Mutex<i32>,
1507 next_org_id: Mutex<i32>,
1508 next_channel_id: Mutex<i32>,
1509 }
1510
1511 #[derive(Debug)]
1512 pub struct FakeContact {
1513 pub requester_id: UserId,
1514 pub responder_id: UserId,
1515 pub accepted: bool,
1516 pub should_notify: bool,
1517 }
1518
1519 impl FakeDb {
1520 pub fn new(background: Arc<Background>) -> Self {
1521 Self {
1522 background,
1523 users: Default::default(),
1524 next_user_id: Mutex::new(1),
1525 orgs: Default::default(),
1526 next_org_id: Mutex::new(1),
1527 org_memberships: Default::default(),
1528 channels: Default::default(),
1529 next_channel_id: Mutex::new(1),
1530 channel_memberships: Default::default(),
1531 channel_messages: Default::default(),
1532 next_channel_message_id: Mutex::new(1),
1533 contacts: Default::default(),
1534 }
1535 }
1536 }
1537
1538 #[async_trait]
1539 impl Db for FakeDb {
1540 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1541 self.background.simulate_random_delay().await;
1542
1543 let mut users = self.users.lock();
1544 if let Some(user) = users
1545 .values()
1546 .find(|user| user.github_login == github_login)
1547 {
1548 Ok(user.id)
1549 } else {
1550 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1551 users.insert(
1552 user_id,
1553 User {
1554 id: user_id,
1555 github_login: github_login.to_string(),
1556 admin,
1557 invite_code: None,
1558 invite_count: 0,
1559 },
1560 );
1561 Ok(user_id)
1562 }
1563 }
1564
1565 async fn get_all_users(&self) -> Result<Vec<User>> {
1566 unimplemented!()
1567 }
1568
1569 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1570 unimplemented!()
1571 }
1572
1573 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1574 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1575 }
1576
1577 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1578 self.background.simulate_random_delay().await;
1579 let users = self.users.lock();
1580 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1581 }
1582
1583 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1584 Ok(self
1585 .users
1586 .lock()
1587 .values()
1588 .find(|user| user.github_login == github_login)
1589 .cloned())
1590 }
1591
1592 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1593 unimplemented!()
1594 }
1595
1596 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1597 unimplemented!()
1598 }
1599
1600 // invite codes
1601
1602 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1603 unimplemented!()
1604 }
1605
1606 async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1607 unimplemented!()
1608 }
1609
1610 async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
1611 unimplemented!()
1612 }
1613
1614 // contacts
1615
1616 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1617 self.background.simulate_random_delay().await;
1618 let mut contacts = vec![Contact::Accepted {
1619 user_id: id,
1620 should_notify: false,
1621 }];
1622
1623 for contact in self.contacts.lock().iter() {
1624 if contact.requester_id == id {
1625 if contact.accepted {
1626 contacts.push(Contact::Accepted {
1627 user_id: contact.responder_id,
1628 should_notify: contact.should_notify,
1629 });
1630 } else {
1631 contacts.push(Contact::Outgoing {
1632 user_id: contact.responder_id,
1633 });
1634 }
1635 } else if contact.responder_id == id {
1636 if contact.accepted {
1637 contacts.push(Contact::Accepted {
1638 user_id: contact.requester_id,
1639 should_notify: false,
1640 });
1641 } else {
1642 contacts.push(Contact::Incoming {
1643 user_id: contact.requester_id,
1644 should_notify: contact.should_notify,
1645 });
1646 }
1647 }
1648 }
1649
1650 contacts.sort_unstable_by_key(|contact| contact.user_id());
1651 Ok(contacts)
1652 }
1653
1654 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1655 self.background.simulate_random_delay().await;
1656 Ok(self.contacts.lock().iter().any(|contact| {
1657 contact.accepted
1658 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1659 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1660 }))
1661 }
1662
1663 async fn send_contact_request(
1664 &self,
1665 requester_id: UserId,
1666 responder_id: UserId,
1667 ) -> Result<()> {
1668 let mut contacts = self.contacts.lock();
1669 for contact in contacts.iter_mut() {
1670 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1671 if contact.accepted {
1672 Err(anyhow!("contact already exists"))?;
1673 } else {
1674 Err(anyhow!("contact already requested"))?;
1675 }
1676 }
1677 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1678 if contact.accepted {
1679 Err(anyhow!("contact already exists"))?;
1680 } else {
1681 contact.accepted = true;
1682 contact.should_notify = false;
1683 return Ok(());
1684 }
1685 }
1686 }
1687 contacts.push(FakeContact {
1688 requester_id,
1689 responder_id,
1690 accepted: false,
1691 should_notify: true,
1692 });
1693 Ok(())
1694 }
1695
1696 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1697 self.contacts.lock().retain(|contact| {
1698 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1699 });
1700 Ok(())
1701 }
1702
1703 async fn dismiss_contact_notification(
1704 &self,
1705 user_id: UserId,
1706 contact_user_id: UserId,
1707 ) -> Result<()> {
1708 let mut contacts = self.contacts.lock();
1709 for contact in contacts.iter_mut() {
1710 if contact.requester_id == contact_user_id
1711 && contact.responder_id == user_id
1712 && !contact.accepted
1713 {
1714 contact.should_notify = false;
1715 return Ok(());
1716 }
1717 if contact.requester_id == user_id
1718 && contact.responder_id == contact_user_id
1719 && contact.accepted
1720 {
1721 contact.should_notify = false;
1722 return Ok(());
1723 }
1724 }
1725 Err(anyhow!("no such notification"))?
1726 }
1727
1728 async fn respond_to_contact_request(
1729 &self,
1730 responder_id: UserId,
1731 requester_id: UserId,
1732 accept: bool,
1733 ) -> Result<()> {
1734 let mut contacts = self.contacts.lock();
1735 for (ix, contact) in contacts.iter_mut().enumerate() {
1736 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1737 if contact.accepted {
1738 Err(anyhow!("contact already confirmed"))?;
1739 }
1740 if accept {
1741 contact.accepted = true;
1742 contact.should_notify = true;
1743 } else {
1744 contacts.remove(ix);
1745 }
1746 return Ok(());
1747 }
1748 }
1749 Err(anyhow!("no such contact request"))?
1750 }
1751
1752 async fn create_access_token_hash(
1753 &self,
1754 _user_id: UserId,
1755 _access_token_hash: &str,
1756 _max_access_token_count: usize,
1757 ) -> Result<()> {
1758 unimplemented!()
1759 }
1760
1761 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1762 unimplemented!()
1763 }
1764
1765 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1766 unimplemented!()
1767 }
1768
1769 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1770 self.background.simulate_random_delay().await;
1771 let mut orgs = self.orgs.lock();
1772 if orgs.values().any(|org| org.slug == slug) {
1773 Err(anyhow!("org already exists"))?
1774 } else {
1775 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1776 orgs.insert(
1777 org_id,
1778 Org {
1779 id: org_id,
1780 name: name.to_string(),
1781 slug: slug.to_string(),
1782 },
1783 );
1784 Ok(org_id)
1785 }
1786 }
1787
1788 async fn add_org_member(
1789 &self,
1790 org_id: OrgId,
1791 user_id: UserId,
1792 is_admin: bool,
1793 ) -> Result<()> {
1794 self.background.simulate_random_delay().await;
1795 if !self.orgs.lock().contains_key(&org_id) {
1796 Err(anyhow!("org does not exist"))?;
1797 }
1798 if !self.users.lock().contains_key(&user_id) {
1799 Err(anyhow!("user does not exist"))?;
1800 }
1801
1802 self.org_memberships
1803 .lock()
1804 .entry((org_id, user_id))
1805 .or_insert(is_admin);
1806 Ok(())
1807 }
1808
1809 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1810 self.background.simulate_random_delay().await;
1811 if !self.orgs.lock().contains_key(&org_id) {
1812 Err(anyhow!("org does not exist"))?;
1813 }
1814
1815 let mut channels = self.channels.lock();
1816 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1817 channels.insert(
1818 channel_id,
1819 Channel {
1820 id: channel_id,
1821 name: name.to_string(),
1822 owner_id: org_id.0,
1823 owner_is_user: false,
1824 },
1825 );
1826 Ok(channel_id)
1827 }
1828
1829 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1830 self.background.simulate_random_delay().await;
1831 Ok(self
1832 .channels
1833 .lock()
1834 .values()
1835 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1836 .cloned()
1837 .collect())
1838 }
1839
1840 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1841 self.background.simulate_random_delay().await;
1842 let channels = self.channels.lock();
1843 let memberships = self.channel_memberships.lock();
1844 Ok(channels
1845 .values()
1846 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1847 .cloned()
1848 .collect())
1849 }
1850
1851 async fn can_user_access_channel(
1852 &self,
1853 user_id: UserId,
1854 channel_id: ChannelId,
1855 ) -> Result<bool> {
1856 self.background.simulate_random_delay().await;
1857 Ok(self
1858 .channel_memberships
1859 .lock()
1860 .contains_key(&(channel_id, user_id)))
1861 }
1862
1863 async fn add_channel_member(
1864 &self,
1865 channel_id: ChannelId,
1866 user_id: UserId,
1867 is_admin: bool,
1868 ) -> Result<()> {
1869 self.background.simulate_random_delay().await;
1870 if !self.channels.lock().contains_key(&channel_id) {
1871 Err(anyhow!("channel does not exist"))?;
1872 }
1873 if !self.users.lock().contains_key(&user_id) {
1874 Err(anyhow!("user does not exist"))?;
1875 }
1876
1877 self.channel_memberships
1878 .lock()
1879 .entry((channel_id, user_id))
1880 .or_insert(is_admin);
1881 Ok(())
1882 }
1883
1884 async fn create_channel_message(
1885 &self,
1886 channel_id: ChannelId,
1887 sender_id: UserId,
1888 body: &str,
1889 timestamp: OffsetDateTime,
1890 nonce: u128,
1891 ) -> Result<MessageId> {
1892 self.background.simulate_random_delay().await;
1893 if !self.channels.lock().contains_key(&channel_id) {
1894 Err(anyhow!("channel does not exist"))?;
1895 }
1896 if !self.users.lock().contains_key(&sender_id) {
1897 Err(anyhow!("user does not exist"))?;
1898 }
1899
1900 let mut messages = self.channel_messages.lock();
1901 if let Some(message) = messages
1902 .values()
1903 .find(|message| message.nonce.as_u128() == nonce)
1904 {
1905 Ok(message.id)
1906 } else {
1907 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1908 messages.insert(
1909 message_id,
1910 ChannelMessage {
1911 id: message_id,
1912 channel_id,
1913 sender_id,
1914 body: body.to_string(),
1915 sent_at: timestamp,
1916 nonce: Uuid::from_u128(nonce),
1917 },
1918 );
1919 Ok(message_id)
1920 }
1921 }
1922
1923 async fn get_channel_messages(
1924 &self,
1925 channel_id: ChannelId,
1926 count: usize,
1927 before_id: Option<MessageId>,
1928 ) -> Result<Vec<ChannelMessage>> {
1929 let mut messages = self
1930 .channel_messages
1931 .lock()
1932 .values()
1933 .rev()
1934 .filter(|message| {
1935 message.channel_id == channel_id
1936 && message.id < before_id.unwrap_or(MessageId::MAX)
1937 })
1938 .take(count)
1939 .cloned()
1940 .collect::<Vec<_>>();
1941 messages.sort_unstable_by_key(|message| message.id);
1942 Ok(messages)
1943 }
1944
1945 async fn teardown(&self, _: &str) {}
1946
1947 #[cfg(test)]
1948 fn as_fake(&self) -> Option<&FakeDb> {
1949 Some(self)
1950 }
1951 }
1952}