1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
10use time::OffsetDateTime;
11
/// Persistence interface for users, invite codes, contacts, access tokens,
/// orgs, channels and channel messages.
///
/// Implemented by [`PostgresDb`]; test builds can additionally expose a
/// `tests::FakeDb` through [`Db::as_fake`].
#[async_trait]
pub trait Db: Send + Sync {
    // users

    /// Creates a user and returns its id. `PostgresDb` returns the existing
    /// user's id when `github_login` is already taken.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Upserts `(github_login, email_address, invite_count)` tuples and returns
    /// the ids of the affected users.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns every user; `PostgresDb` orders them by `github_login`.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns at most `limit` users whose login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids` (missing ids are skipped).
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the user's `admin` flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user and their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // invite codes

    /// Sets how many invites the user has left, assigning an invite code when
    /// `count > 0` and the user has none yet (`PostgresDb` behavior).
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` if
    /// the user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the owner of `code`; `PostgresDb` fails with HTTP 404 when the
    /// code does not exist.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite of `code`, creates the invitee and makes the inviter
    /// and invitee contacts. Returns the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    // contacts

    /// Returns all contacts (accepted, outgoing and incoming) of the user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship (in any state) between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification `responder_id` has about `requester_id`.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the request that `requester_id`
    /// sent to `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // access tokens

    /// Stores a new access token hash, keeping at most `max_access_token_count`
    /// of the user's most recent hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access token hashes; `PostgresDb` returns them in
    /// descending id order.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // orgs (test / seed-support only, except where noted)

    /// Looks up an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` to the org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;

    // channels

    /// Creates a channel owned by the org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` to the channel.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;

    // messages

    /// Stores a channel message; `nonce` de-duplicates retries (`PostgresDb`
    /// returns the existing message id for a repeated nonce).
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages older than `before_id` (or the newest ones
    /// when `None`), oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    /// Test-only: releases the database identified by `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: returns the fake implementation, if this is one.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
104
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// sqlx connection pool shared by every query this type issues.
    pool: sqlx::PgPool,
}
108
109impl PostgresDb {
110 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
111 let pool = DbOptions::new()
112 .max_connections(max_connections)
113 .connect(&url)
114 .await
115 .context("failed to connect to postgres database")?;
116 Ok(Self { pool })
117 }
118}
119
120#[async_trait]
121impl Db for PostgresDb {
122 // users
123
124 async fn create_user(
125 &self,
126 github_login: &str,
127 email_address: Option<&str>,
128 admin: bool,
129 ) -> Result<UserId> {
130 let query = "
131 INSERT INTO users (github_login, email_address, admin)
132 VALUES ($1, $2, $3)
133 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
134 RETURNING id
135 ";
136 Ok(sqlx::query_scalar(query)
137 .bind(github_login)
138 .bind(email_address)
139 .bind(admin)
140 .fetch_one(&self.pool)
141 .await
142 .map(UserId)?)
143 }
144
145 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
146 let mut query = QueryBuilder::new(
147 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
148 );
149 query.push_values(
150 users,
151 |mut query, (github_login, email_address, invite_count)| {
152 query
153 .push_bind(github_login)
154 .push_bind(email_address)
155 .push_bind(false)
156 .push_bind(nanoid!(16))
157 .push_bind(invite_count as u32);
158 },
159 );
160 query.push(
161 "
162 ON CONFLICT (github_login) DO UPDATE SET
163 github_login = excluded.github_login,
164 invite_count = excluded.invite_count,
165 invite_code = CASE WHEN users.invite_code IS NULL
166 THEN excluded.invite_code
167 ELSE users.invite_code
168 END
169 RETURNING id
170 ",
171 );
172
173 let rows = query.build().fetch_all(&self.pool).await?;
174 Ok(rows
175 .into_iter()
176 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
177 .collect())
178 }
179
180 async fn get_all_users(&self) -> Result<Vec<User>> {
181 let query = "SELECT * FROM users ORDER BY github_login ASC";
182 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
183 }
184
185 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
186 let like_string = fuzzy_like_string(name_query);
187 let query = "
188 SELECT users.*
189 FROM users
190 WHERE github_login ILIKE $1
191 ORDER BY github_login <-> $2
192 LIMIT $3
193 ";
194 Ok(sqlx::query_as(query)
195 .bind(like_string)
196 .bind(name_query)
197 .bind(limit)
198 .fetch_all(&self.pool)
199 .await?)
200 }
201
202 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
203 let users = self.get_users_by_ids(vec![id]).await?;
204 Ok(users.into_iter().next())
205 }
206
207 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
208 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
209 let query = "
210 SELECT users.*
211 FROM users
212 WHERE users.id = ANY ($1)
213 ";
214 Ok(sqlx::query_as(query)
215 .bind(&ids)
216 .fetch_all(&self.pool)
217 .await?)
218 }
219
220 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
221 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
222 Ok(sqlx::query_as(query)
223 .bind(github_login)
224 .fetch_optional(&self.pool)
225 .await?)
226 }
227
228 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
229 let query = "UPDATE users SET admin = $1 WHERE id = $2";
230 Ok(sqlx::query(query)
231 .bind(is_admin)
232 .bind(id.0)
233 .execute(&self.pool)
234 .await
235 .map(drop)?)
236 }
237
238 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
239 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
240 Ok(sqlx::query(query)
241 .bind(connected_once)
242 .bind(id.0)
243 .execute(&self.pool)
244 .await
245 .map(drop)?)
246 }
247
248 async fn destroy_user(&self, id: UserId) -> Result<()> {
249 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
250 sqlx::query(query)
251 .bind(id.0)
252 .execute(&self.pool)
253 .await
254 .map(drop)?;
255 let query = "DELETE FROM users WHERE id = $1;";
256 Ok(sqlx::query(query)
257 .bind(id.0)
258 .execute(&self.pool)
259 .await
260 .map(drop)?)
261 }
262
263 // invite codes
264
265 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
266 let mut tx = self.pool.begin().await?;
267 if count > 0 {
268 sqlx::query(
269 "
270 UPDATE users
271 SET invite_code = $1
272 WHERE id = $2 AND invite_code IS NULL
273 ",
274 )
275 .bind(nanoid!(16))
276 .bind(id)
277 .execute(&mut tx)
278 .await?;
279 }
280
281 sqlx::query(
282 "
283 UPDATE users
284 SET invite_count = $1
285 WHERE id = $2
286 ",
287 )
288 .bind(count)
289 .bind(id)
290 .execute(&mut tx)
291 .await?;
292 tx.commit().await?;
293 Ok(())
294 }
295
296 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
297 let result: Option<(String, i32)> = sqlx::query_as(
298 "
299 SELECT invite_code, invite_count
300 FROM users
301 WHERE id = $1 AND invite_code IS NOT NULL
302 ",
303 )
304 .bind(id)
305 .fetch_optional(&self.pool)
306 .await?;
307 if let Some((code, count)) = result {
308 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
309 } else {
310 Ok(None)
311 }
312 }
313
314 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
315 sqlx::query_as(
316 "
317 SELECT *
318 FROM users
319 WHERE invite_code = $1
320 ",
321 )
322 .bind(code)
323 .fetch_optional(&self.pool)
324 .await?
325 .ok_or_else(|| {
326 Error::Http(
327 StatusCode::NOT_FOUND,
328 "that invite code does not exist".to_string(),
329 )
330 })
331 }
332
333 async fn redeem_invite_code(
334 &self,
335 code: &str,
336 login: &str,
337 email_address: Option<&str>,
338 ) -> Result<UserId> {
339 let mut tx = self.pool.begin().await?;
340
341 let inviter_id: Option<UserId> = sqlx::query_scalar(
342 "
343 UPDATE users
344 SET invite_count = invite_count - 1
345 WHERE
346 invite_code = $1 AND
347 invite_count > 0
348 RETURNING id
349 ",
350 )
351 .bind(code)
352 .fetch_optional(&mut tx)
353 .await?;
354
355 let inviter_id = match inviter_id {
356 Some(inviter_id) => inviter_id,
357 None => {
358 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
359 .bind(code)
360 .fetch_optional(&mut tx)
361 .await?
362 .is_some()
363 {
364 Err(Error::Http(
365 StatusCode::UNAUTHORIZED,
366 "no invites remaining".to_string(),
367 ))?
368 } else {
369 Err(Error::Http(
370 StatusCode::NOT_FOUND,
371 "invite code not found".to_string(),
372 ))?
373 }
374 }
375 };
376
377 let invitee_id = sqlx::query_scalar(
378 "
379 INSERT INTO users
380 (github_login, email_address, admin, inviter_id)
381 VALUES
382 ($1, $2, 'f', $3)
383 RETURNING id
384 ",
385 )
386 .bind(login)
387 .bind(email_address)
388 .bind(inviter_id)
389 .fetch_one(&mut tx)
390 .await
391 .map(UserId)?;
392
393 sqlx::query(
394 "
395 INSERT INTO contacts
396 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
397 VALUES
398 ($1, $2, 't', 't', 't')
399 ",
400 )
401 .bind(inviter_id)
402 .bind(invitee_id)
403 .execute(&mut tx)
404 .await?;
405
406 tx.commit().await?;
407 Ok(invitee_id)
408 }
409
410 // contacts
411
412 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
413 let query = "
414 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
415 FROM contacts
416 WHERE user_id_a = $1 OR user_id_b = $1;
417 ";
418
419 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
420 .bind(user_id)
421 .fetch(&self.pool);
422
423 let mut contacts = vec![Contact::Accepted {
424 user_id,
425 should_notify: false,
426 }];
427 while let Some(row) = rows.next().await {
428 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
429
430 if user_id_a == user_id {
431 if accepted {
432 contacts.push(Contact::Accepted {
433 user_id: user_id_b,
434 should_notify: should_notify && a_to_b,
435 });
436 } else if a_to_b {
437 contacts.push(Contact::Outgoing { user_id: user_id_b })
438 } else {
439 contacts.push(Contact::Incoming {
440 user_id: user_id_b,
441 should_notify,
442 });
443 }
444 } else {
445 if accepted {
446 contacts.push(Contact::Accepted {
447 user_id: user_id_a,
448 should_notify: should_notify && !a_to_b,
449 });
450 } else if a_to_b {
451 contacts.push(Contact::Incoming {
452 user_id: user_id_a,
453 should_notify,
454 });
455 } else {
456 contacts.push(Contact::Outgoing { user_id: user_id_a });
457 }
458 }
459 }
460
461 contacts.sort_unstable_by_key(|contact| contact.user_id());
462
463 Ok(contacts)
464 }
465
466 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
467 let (id_a, id_b) = if user_id_1 < user_id_2 {
468 (user_id_1, user_id_2)
469 } else {
470 (user_id_2, user_id_1)
471 };
472
473 let query = "
474 SELECT 1 FROM contacts
475 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
476 LIMIT 1
477 ";
478 Ok(sqlx::query_scalar::<_, i32>(query)
479 .bind(id_a.0)
480 .bind(id_b.0)
481 .fetch_optional(&self.pool)
482 .await?
483 .is_some())
484 }
485
486 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
487 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
488 (sender_id, receiver_id, true)
489 } else {
490 (receiver_id, sender_id, false)
491 };
492 let query = "
493 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
494 VALUES ($1, $2, $3, 'f', 't')
495 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
496 SET
497 accepted = 't',
498 should_notify = 'f'
499 WHERE
500 NOT contacts.accepted AND
501 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
502 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
503 ";
504 let result = sqlx::query(query)
505 .bind(id_a.0)
506 .bind(id_b.0)
507 .bind(a_to_b)
508 .execute(&self.pool)
509 .await?;
510
511 if result.rows_affected() == 1 {
512 Ok(())
513 } else {
514 Err(anyhow!("contact already requested"))?
515 }
516 }
517
518 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
519 let (id_a, id_b) = if responder_id < requester_id {
520 (responder_id, requester_id)
521 } else {
522 (requester_id, responder_id)
523 };
524 let query = "
525 DELETE FROM contacts
526 WHERE user_id_a = $1 AND user_id_b = $2;
527 ";
528 let result = sqlx::query(query)
529 .bind(id_a.0)
530 .bind(id_b.0)
531 .execute(&self.pool)
532 .await?;
533
534 if result.rows_affected() == 1 {
535 Ok(())
536 } else {
537 Err(anyhow!("no such contact"))?
538 }
539 }
540
541 async fn dismiss_contact_notification(
542 &self,
543 user_id: UserId,
544 contact_user_id: UserId,
545 ) -> Result<()> {
546 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
547 (user_id, contact_user_id, true)
548 } else {
549 (contact_user_id, user_id, false)
550 };
551
552 let query = "
553 UPDATE contacts
554 SET should_notify = 'f'
555 WHERE
556 user_id_a = $1 AND user_id_b = $2 AND
557 (
558 (a_to_b = $3 AND accepted) OR
559 (a_to_b != $3 AND NOT accepted)
560 );
561 ";
562
563 let result = sqlx::query(query)
564 .bind(id_a.0)
565 .bind(id_b.0)
566 .bind(a_to_b)
567 .execute(&self.pool)
568 .await?;
569
570 if result.rows_affected() == 0 {
571 Err(anyhow!("no such contact request"))?;
572 }
573
574 Ok(())
575 }
576
577 async fn respond_to_contact_request(
578 &self,
579 responder_id: UserId,
580 requester_id: UserId,
581 accept: bool,
582 ) -> Result<()> {
583 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
584 (responder_id, requester_id, false)
585 } else {
586 (requester_id, responder_id, true)
587 };
588 let result = if accept {
589 let query = "
590 UPDATE contacts
591 SET accepted = 't', should_notify = 't'
592 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
593 ";
594 sqlx::query(query)
595 .bind(id_a.0)
596 .bind(id_b.0)
597 .bind(a_to_b)
598 .execute(&self.pool)
599 .await?
600 } else {
601 let query = "
602 DELETE FROM contacts
603 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
604 ";
605 sqlx::query(query)
606 .bind(id_a.0)
607 .bind(id_b.0)
608 .bind(a_to_b)
609 .execute(&self.pool)
610 .await?
611 };
612 if result.rows_affected() == 1 {
613 Ok(())
614 } else {
615 Err(anyhow!("no such contact request"))?
616 }
617 }
618
619 // access tokens
620
621 async fn create_access_token_hash(
622 &self,
623 user_id: UserId,
624 access_token_hash: &str,
625 max_access_token_count: usize,
626 ) -> Result<()> {
627 let insert_query = "
628 INSERT INTO access_tokens (user_id, hash)
629 VALUES ($1, $2);
630 ";
631 let cleanup_query = "
632 DELETE FROM access_tokens
633 WHERE id IN (
634 SELECT id from access_tokens
635 WHERE user_id = $1
636 ORDER BY id DESC
637 OFFSET $3
638 )
639 ";
640
641 let mut tx = self.pool.begin().await?;
642 sqlx::query(insert_query)
643 .bind(user_id.0)
644 .bind(access_token_hash)
645 .execute(&mut tx)
646 .await?;
647 sqlx::query(cleanup_query)
648 .bind(user_id.0)
649 .bind(access_token_hash)
650 .bind(max_access_token_count as u32)
651 .execute(&mut tx)
652 .await?;
653 Ok(tx.commit().await?)
654 }
655
656 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
657 let query = "
658 SELECT hash
659 FROM access_tokens
660 WHERE user_id = $1
661 ORDER BY id DESC
662 ";
663 Ok(sqlx::query_scalar(query)
664 .bind(user_id.0)
665 .fetch_all(&self.pool)
666 .await?)
667 }
668
669 // orgs
670
671 #[allow(unused)] // Help rust-analyzer
672 #[cfg(any(test, feature = "seed-support"))]
673 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
674 let query = "
675 SELECT *
676 FROM orgs
677 WHERE slug = $1
678 ";
679 Ok(sqlx::query_as(query)
680 .bind(slug)
681 .fetch_optional(&self.pool)
682 .await?)
683 }
684
685 #[cfg(any(test, feature = "seed-support"))]
686 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
687 let query = "
688 INSERT INTO orgs (name, slug)
689 VALUES ($1, $2)
690 RETURNING id
691 ";
692 Ok(sqlx::query_scalar(query)
693 .bind(name)
694 .bind(slug)
695 .fetch_one(&self.pool)
696 .await
697 .map(OrgId)?)
698 }
699
700 #[cfg(any(test, feature = "seed-support"))]
701 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
702 let query = "
703 INSERT INTO org_memberships (org_id, user_id, admin)
704 VALUES ($1, $2, $3)
705 ON CONFLICT DO NOTHING
706 ";
707 Ok(sqlx::query(query)
708 .bind(org_id.0)
709 .bind(user_id.0)
710 .bind(is_admin)
711 .execute(&self.pool)
712 .await
713 .map(drop)?)
714 }
715
716 // channels
717
718 #[cfg(any(test, feature = "seed-support"))]
719 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
720 let query = "
721 INSERT INTO channels (owner_id, owner_is_user, name)
722 VALUES ($1, false, $2)
723 RETURNING id
724 ";
725 Ok(sqlx::query_scalar(query)
726 .bind(org_id.0)
727 .bind(name)
728 .fetch_one(&self.pool)
729 .await
730 .map(ChannelId)?)
731 }
732
733 #[allow(unused)] // Help rust-analyzer
734 #[cfg(any(test, feature = "seed-support"))]
735 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
736 let query = "
737 SELECT *
738 FROM channels
739 WHERE
740 channels.owner_is_user = false AND
741 channels.owner_id = $1
742 ";
743 Ok(sqlx::query_as(query)
744 .bind(org_id.0)
745 .fetch_all(&self.pool)
746 .await?)
747 }
748
749 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
750 let query = "
751 SELECT
752 channels.*
753 FROM
754 channel_memberships, channels
755 WHERE
756 channel_memberships.user_id = $1 AND
757 channel_memberships.channel_id = channels.id
758 ";
759 Ok(sqlx::query_as(query)
760 .bind(user_id.0)
761 .fetch_all(&self.pool)
762 .await?)
763 }
764
765 async fn can_user_access_channel(
766 &self,
767 user_id: UserId,
768 channel_id: ChannelId,
769 ) -> Result<bool> {
770 let query = "
771 SELECT id
772 FROM channel_memberships
773 WHERE user_id = $1 AND channel_id = $2
774 LIMIT 1
775 ";
776 Ok(sqlx::query_scalar::<_, i32>(query)
777 .bind(user_id.0)
778 .bind(channel_id.0)
779 .fetch_optional(&self.pool)
780 .await
781 .map(|e| e.is_some())?)
782 }
783
784 #[cfg(any(test, feature = "seed-support"))]
785 async fn add_channel_member(
786 &self,
787 channel_id: ChannelId,
788 user_id: UserId,
789 is_admin: bool,
790 ) -> Result<()> {
791 let query = "
792 INSERT INTO channel_memberships (channel_id, user_id, admin)
793 VALUES ($1, $2, $3)
794 ON CONFLICT DO NOTHING
795 ";
796 Ok(sqlx::query(query)
797 .bind(channel_id.0)
798 .bind(user_id.0)
799 .bind(is_admin)
800 .execute(&self.pool)
801 .await
802 .map(drop)?)
803 }
804
805 // messages
806
807 async fn create_channel_message(
808 &self,
809 channel_id: ChannelId,
810 sender_id: UserId,
811 body: &str,
812 timestamp: OffsetDateTime,
813 nonce: u128,
814 ) -> Result<MessageId> {
815 let query = "
816 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
817 VALUES ($1, $2, $3, $4, $5)
818 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
819 RETURNING id
820 ";
821 Ok(sqlx::query_scalar(query)
822 .bind(channel_id.0)
823 .bind(sender_id.0)
824 .bind(body)
825 .bind(timestamp)
826 .bind(Uuid::from_u128(nonce))
827 .fetch_one(&self.pool)
828 .await
829 .map(MessageId)?)
830 }
831
832 async fn get_channel_messages(
833 &self,
834 channel_id: ChannelId,
835 count: usize,
836 before_id: Option<MessageId>,
837 ) -> Result<Vec<ChannelMessage>> {
838 let query = r#"
839 SELECT * FROM (
840 SELECT
841 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
842 FROM
843 channel_messages
844 WHERE
845 channel_id = $1 AND
846 id < $2
847 ORDER BY id DESC
848 LIMIT $3
849 ) as recent_messages
850 ORDER BY id ASC
851 "#;
852 Ok(sqlx::query_as(query)
853 .bind(channel_id.0)
854 .bind(before_id.unwrap_or(MessageId::MAX))
855 .bind(count as i64)
856 .fetch_all(&self.pool)
857 .await?)
858 }
859
860 #[cfg(test)]
861 async fn teardown(&self, url: &str) {
862 use util::ResultExt;
863
864 let query = "
865 SELECT pg_terminate_backend(pg_stat_activity.pid)
866 FROM pg_stat_activity
867 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
868 ";
869 sqlx::query(query).execute(&self.pool).await.log_err();
870 self.pool.close().await;
871 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
872 .await
873 .log_err();
874 }
875
876 #[cfg(test)]
877 fn as_fake(&self) -> Option<&tests::FakeDb> {
878 None
879 }
880}
881
/// Declares a newtype database id wrapping an `i32`, with sqlx/serde
/// transparency, protocol conversions and `Display`.
macro_rules! id_type {
    ($name:ident) => {
        /// Strongly-typed database id (a transparent `i32` wrapper).
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id, e.g. used as an open upper bound for
            /// "before this id" queries.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its wire representation. The `as` cast keeps only
            /// the low 32 bits, so out-of-range values wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Returns the wire representation; negative ids sign-extend.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            // Delegate to `i32`'s impl so width/fill/sign flags are honored.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
913
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub username; unique (it is the `ON CONFLICT` target when creating users).
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code others can redeem for an invite; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Invites left; decremented by `redeem_invite_code`.
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`.
    pub connected_once: bool,
}
925
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Identifier looked up by `find_org_by_slug`.
    pub slug: String,
}
933
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owner's id. When `owner_is_user` is false this is an org id, as written by
    /// `create_org_channel`; presumably a user id otherwise.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
942
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` reads it as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied de-duplication key. Re-sending a nonce returns the original
    /// message id instead of creating a new message.
    pub nonce: Uuid,
}
953
/// A contact relationship as seen from one user's perspective.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts. `should_notify` is set for the user whose request
    /// was accepted, until they dismiss it.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent a request to `user_id` that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
968
969impl Contact {
970 pub fn user_id(&self) -> UserId {
971 match self {
972 Contact::Accepted { user_id, .. } => *user_id,
973 Contact::Outgoing { user_id } => *user_id,
974 Contact::Incoming { user_id, .. } => *user_id,
975 }
976 }
977}
978
/// A contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Presumably whether the recipient still needs to be notified — confirm with callers.
    pub should_notify: bool,
}
984
/// Builds a SQL `LIKE`/`ILIKE` pattern that matches any text containing the
/// alphanumeric characters of `string` in the same order, with anything between them.
///
/// Every non-alphanumeric character (spaces, punctuation, `%`, `_`) is dropped, so
/// the input cannot inject wildcards. For example, `"abcd"` becomes `"%a%b%c%d%"`
/// and an input with no alphanumerics becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = string.chars().filter(|ch| ch.is_alphanumeric()).fold(
        String::with_capacity(string.len() * 2 + 1),
        |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        },
    );
    pattern.push('%');
    pattern
}
996
997#[cfg(test)]
998pub mod tests {
999 use super::*;
1000 use anyhow::anyhow;
1001 use collections::BTreeMap;
1002 use gpui::executor::Background;
1003 use lazy_static::lazy_static;
1004 use parking_lot::Mutex;
1005 use rand::prelude::*;
1006 use sqlx::{
1007 migrate::{MigrateDatabase, Migrator},
1008 Postgres,
1009 };
1010 use std::{path::Path, sync::Arc};
1011 use util::post_inc;
1012
1013 #[tokio::test(flavor = "multi_thread")]
1014 async fn test_get_users_by_ids() {
1015 for test_db in [
1016 TestDb::postgres().await,
1017 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1018 ] {
1019 let db = test_db.db();
1020
1021 let user = db.create_user("user", None, false).await.unwrap();
1022 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1023 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1024 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1025
1026 assert_eq!(
1027 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1028 .await
1029 .unwrap(),
1030 vec![
1031 User {
1032 id: user,
1033 github_login: "user".to_string(),
1034 admin: false,
1035 ..Default::default()
1036 },
1037 User {
1038 id: friend1,
1039 github_login: "friend-1".to_string(),
1040 admin: false,
1041 ..Default::default()
1042 },
1043 User {
1044 id: friend2,
1045 github_login: "friend-2".to_string(),
1046 admin: false,
1047 ..Default::default()
1048 },
1049 User {
1050 id: friend3,
1051 github_login: "friend-3".to_string(),
1052 admin: false,
1053 ..Default::default()
1054 }
1055 ]
1056 );
1057 }
1058 }
1059
1060 #[tokio::test(flavor = "multi_thread")]
1061 async fn test_create_users() {
1062 let db = TestDb::postgres().await;
1063 let db = db.db();
1064
1065 // Create the first batch of users, ensuring invite counts are assigned
1066 // correctly and the respective invite codes are unique.
1067 let user_ids_batch_1 = db
1068 .create_users(vec![
1069 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1070 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1071 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1072 ])
1073 .await
1074 .unwrap();
1075 assert_eq!(user_ids_batch_1.len(), 3);
1076
1077 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1078 assert_eq!(users.len(), 3);
1079 assert_eq!(users[0].github_login, "user1");
1080 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1081 assert_eq!(users[0].invite_count, 5);
1082 assert_eq!(users[1].github_login, "user2");
1083 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1084 assert_eq!(users[1].invite_count, 4);
1085 assert_eq!(users[2].github_login, "user3");
1086 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1087 assert_eq!(users[2].invite_count, 3);
1088
1089 let invite_code_1 = users[0].invite_code.clone().unwrap();
1090 let invite_code_2 = users[1].invite_code.clone().unwrap();
1091 let invite_code_3 = users[2].invite_code.clone().unwrap();
1092 assert_ne!(invite_code_1, invite_code_2);
1093 assert_ne!(invite_code_1, invite_code_3);
1094 assert_ne!(invite_code_2, invite_code_3);
1095
1096 // Create the second batch of users and include a user that is already in the database, ensuring
1097 // the invite count for the existing user is updated without changing their invite code.
1098 let user_ids_batch_2 = db
1099 .create_users(vec![
1100 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1101 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1102 ])
1103 .await
1104 .unwrap();
1105 assert_eq!(user_ids_batch_2.len(), 2);
1106 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1107
1108 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1109 assert_eq!(users.len(), 2);
1110 assert_eq!(users[0].github_login, "user2");
1111 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1112 assert_eq!(users[0].invite_count, 10);
1113 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1114 assert_eq!(users[1].github_login, "user4");
1115 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1116 assert_eq!(users[1].invite_count, 2);
1117
1118 let invite_code_4 = users[1].invite_code.clone().unwrap();
1119 assert_ne!(invite_code_4, invite_code_1);
1120 assert_ne!(invite_code_4, invite_code_2);
1121 assert_ne!(invite_code_4, invite_code_3);
1122 }
1123
1124 #[tokio::test(flavor = "multi_thread")]
1125 async fn test_recent_channel_messages() {
1126 for test_db in [
1127 TestDb::postgres().await,
1128 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1129 ] {
1130 let db = test_db.db();
1131 let user = db.create_user("user", None, false).await.unwrap();
1132 let org = db.create_org("org", "org").await.unwrap();
1133 let channel = db.create_org_channel(org, "channel").await.unwrap();
1134 for i in 0..10 {
1135 db.create_channel_message(
1136 channel,
1137 user,
1138 &i.to_string(),
1139 OffsetDateTime::now_utc(),
1140 i,
1141 )
1142 .await
1143 .unwrap();
1144 }
1145
1146 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1147 assert_eq!(
1148 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1149 ["5", "6", "7", "8", "9"]
1150 );
1151
1152 let prev_messages = db
1153 .get_channel_messages(channel, 4, Some(messages[0].id))
1154 .await
1155 .unwrap();
1156 assert_eq!(
1157 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1158 ["1", "2", "3", "4"]
1159 );
1160 }
1161 }
1162
1163 #[tokio::test(flavor = "multi_thread")]
1164 async fn test_channel_message_nonces() {
1165 for test_db in [
1166 TestDb::postgres().await,
1167 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1168 ] {
1169 let db = test_db.db();
1170 let user = db.create_user("user", None, false).await.unwrap();
1171 let org = db.create_org("org", "org").await.unwrap();
1172 let channel = db.create_org_channel(org, "channel").await.unwrap();
1173
1174 let msg1_id = db
1175 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1176 .await
1177 .unwrap();
1178 let msg2_id = db
1179 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1180 .await
1181 .unwrap();
1182 let msg3_id = db
1183 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1184 .await
1185 .unwrap();
1186 let msg4_id = db
1187 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1188 .await
1189 .unwrap();
1190
1191 assert_ne!(msg1_id, msg2_id);
1192 assert_eq!(msg1_id, msg3_id);
1193 assert_eq!(msg2_id, msg4_id);
1194 }
1195 }
1196
1197 #[tokio::test(flavor = "multi_thread")]
1198 async fn test_create_access_tokens() {
1199 let test_db = TestDb::postgres().await;
1200 let db = test_db.db();
1201 let user = db.create_user("the-user", None, false).await.unwrap();
1202
1203 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1204 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1205 assert_eq!(
1206 db.get_access_token_hashes(user).await.unwrap(),
1207 &["h2".to_string(), "h1".to_string()]
1208 );
1209
1210 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1211 assert_eq!(
1212 db.get_access_token_hashes(user).await.unwrap(),
1213 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1214 );
1215
1216 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1217 assert_eq!(
1218 db.get_access_token_hashes(user).await.unwrap(),
1219 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1220 );
1221
1222 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1223 assert_eq!(
1224 db.get_access_token_hashes(user).await.unwrap(),
1225 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1226 );
1227 }
1228
1229 #[test]
1230 fn test_fuzzy_like_string() {
1231 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1232 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1233 assert_eq!(fuzzy_like_string(" z "), "%z%");
1234 }
1235
1236 #[tokio::test(flavor = "multi_thread")]
1237 async fn test_fuzzy_search_users() {
1238 let test_db = TestDb::postgres().await;
1239 let db = test_db.db();
1240 for github_login in [
1241 "California",
1242 "colorado",
1243 "oregon",
1244 "washington",
1245 "florida",
1246 "delaware",
1247 "rhode-island",
1248 ] {
1249 db.create_user(github_login, None, false).await.unwrap();
1250 }
1251
1252 assert_eq!(
1253 fuzzy_search_user_names(db, "clr").await,
1254 &["colorado", "California"]
1255 );
1256 assert_eq!(
1257 fuzzy_search_user_names(db, "ro").await,
1258 &["rhode-island", "colorado", "oregon"],
1259 );
1260
1261 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1262 db.fuzzy_search_users(query, 10)
1263 .await
1264 .unwrap()
1265 .into_iter()
1266 .map(|user| user.github_login)
1267 .collect::<Vec<_>>()
1268 }
1269 }
1270
1271 #[tokio::test(flavor = "multi_thread")]
1272 async fn test_add_contacts() {
1273 for test_db in [
1274 TestDb::postgres().await,
1275 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1276 ] {
1277 let db = test_db.db();
1278
1279 let user_1 = db.create_user("user1", None, false).await.unwrap();
1280 let user_2 = db.create_user("user2", None, false).await.unwrap();
1281 let user_3 = db.create_user("user3", None, false).await.unwrap();
1282
1283 // User starts with no contacts
1284 assert_eq!(
1285 db.get_contacts(user_1).await.unwrap(),
1286 vec![Contact::Accepted {
1287 user_id: user_1,
1288 should_notify: false
1289 }],
1290 );
1291
1292 // User requests a contact. Both users see the pending request.
1293 db.send_contact_request(user_1, user_2).await.unwrap();
1294 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1295 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1296 assert_eq!(
1297 db.get_contacts(user_1).await.unwrap(),
1298 &[
1299 Contact::Accepted {
1300 user_id: user_1,
1301 should_notify: false
1302 },
1303 Contact::Outgoing { user_id: user_2 }
1304 ],
1305 );
1306 assert_eq!(
1307 db.get_contacts(user_2).await.unwrap(),
1308 &[
1309 Contact::Incoming {
1310 user_id: user_1,
1311 should_notify: true
1312 },
1313 Contact::Accepted {
1314 user_id: user_2,
1315 should_notify: false
1316 },
1317 ]
1318 );
1319
1320 // User 2 dismisses the contact request notification without accepting or rejecting.
1321 // We shouldn't notify them again.
1322 db.dismiss_contact_notification(user_1, user_2)
1323 .await
1324 .unwrap_err();
1325 db.dismiss_contact_notification(user_2, user_1)
1326 .await
1327 .unwrap();
1328 assert_eq!(
1329 db.get_contacts(user_2).await.unwrap(),
1330 &[
1331 Contact::Incoming {
1332 user_id: user_1,
1333 should_notify: false
1334 },
1335 Contact::Accepted {
1336 user_id: user_2,
1337 should_notify: false
1338 },
1339 ]
1340 );
1341
1342 // User can't accept their own contact request
1343 db.respond_to_contact_request(user_1, user_2, true)
1344 .await
1345 .unwrap_err();
1346
1347 // User accepts a contact request. Both users see the contact.
1348 db.respond_to_contact_request(user_2, user_1, true)
1349 .await
1350 .unwrap();
1351 assert_eq!(
1352 db.get_contacts(user_1).await.unwrap(),
1353 &[
1354 Contact::Accepted {
1355 user_id: user_1,
1356 should_notify: false
1357 },
1358 Contact::Accepted {
1359 user_id: user_2,
1360 should_notify: true
1361 }
1362 ],
1363 );
1364 assert!(db.has_contact(user_1, user_2).await.unwrap());
1365 assert!(db.has_contact(user_2, user_1).await.unwrap());
1366 assert_eq!(
1367 db.get_contacts(user_2).await.unwrap(),
1368 &[
1369 Contact::Accepted {
1370 user_id: user_1,
1371 should_notify: false,
1372 },
1373 Contact::Accepted {
1374 user_id: user_2,
1375 should_notify: false,
1376 },
1377 ]
1378 );
1379
1380 // Users cannot re-request existing contacts.
1381 db.send_contact_request(user_1, user_2).await.unwrap_err();
1382 db.send_contact_request(user_2, user_1).await.unwrap_err();
1383
1384 // Users can't dismiss notifications of them accepting other users' requests.
1385 db.dismiss_contact_notification(user_2, user_1)
1386 .await
1387 .unwrap_err();
1388 assert_eq!(
1389 db.get_contacts(user_1).await.unwrap(),
1390 &[
1391 Contact::Accepted {
1392 user_id: user_1,
1393 should_notify: false
1394 },
1395 Contact::Accepted {
1396 user_id: user_2,
1397 should_notify: true,
1398 },
1399 ]
1400 );
1401
1402 // Users can dismiss notifications of other users accepting their requests.
1403 db.dismiss_contact_notification(user_1, user_2)
1404 .await
1405 .unwrap();
1406 assert_eq!(
1407 db.get_contacts(user_1).await.unwrap(),
1408 &[
1409 Contact::Accepted {
1410 user_id: user_1,
1411 should_notify: false
1412 },
1413 Contact::Accepted {
1414 user_id: user_2,
1415 should_notify: false,
1416 },
1417 ]
1418 );
1419
1420 // Users send each other concurrent contact requests and
1421 // see that they are immediately accepted.
1422 db.send_contact_request(user_1, user_3).await.unwrap();
1423 db.send_contact_request(user_3, user_1).await.unwrap();
1424 assert_eq!(
1425 db.get_contacts(user_1).await.unwrap(),
1426 &[
1427 Contact::Accepted {
1428 user_id: user_1,
1429 should_notify: false
1430 },
1431 Contact::Accepted {
1432 user_id: user_2,
1433 should_notify: false,
1434 },
1435 Contact::Accepted {
1436 user_id: user_3,
1437 should_notify: false
1438 },
1439 ]
1440 );
1441 assert_eq!(
1442 db.get_contacts(user_3).await.unwrap(),
1443 &[
1444 Contact::Accepted {
1445 user_id: user_1,
1446 should_notify: false
1447 },
1448 Contact::Accepted {
1449 user_id: user_3,
1450 should_notify: false
1451 }
1452 ],
1453 );
1454
1455 // User declines a contact request. Both users see that it is gone.
1456 db.send_contact_request(user_2, user_3).await.unwrap();
1457 db.respond_to_contact_request(user_3, user_2, false)
1458 .await
1459 .unwrap();
1460 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1461 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1462 assert_eq!(
1463 db.get_contacts(user_2).await.unwrap(),
1464 &[
1465 Contact::Accepted {
1466 user_id: user_1,
1467 should_notify: false
1468 },
1469 Contact::Accepted {
1470 user_id: user_2,
1471 should_notify: false
1472 }
1473 ]
1474 );
1475 assert_eq!(
1476 db.get_contacts(user_3).await.unwrap(),
1477 &[
1478 Contact::Accepted {
1479 user_id: user_1,
1480 should_notify: false
1481 },
1482 Contact::Accepted {
1483 user_id: user_3,
1484 should_notify: false
1485 }
1486 ],
1487 );
1488 }
1489 }
1490
1491 #[tokio::test(flavor = "multi_thread")]
1492 async fn test_invite_codes() {
1493 let postgres = TestDb::postgres().await;
1494 let db = postgres.db();
1495 let user1 = db.create_user("user-1", None, false).await.unwrap();
1496
1497 // Initially, user 1 has no invite code
1498 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
1499
1500 // Setting invite count to 0 when no code is assigned does not assign a new code
1501 db.set_invite_count(user1, 0).await.unwrap();
1502 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
1503
1504 // User 1 creates an invite code that can be used twice.
1505 db.set_invite_count(user1, 2).await.unwrap();
1506 let (invite_code, invite_count) =
1507 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1508 assert_eq!(invite_count, 2);
1509
1510 // User 2 redeems the invite code and becomes a contact of user 1.
1511 let user2 = db
1512 .redeem_invite_code(&invite_code, "user-2", None)
1513 .await
1514 .unwrap();
1515 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1516 assert_eq!(invite_count, 1);
1517 assert_eq!(
1518 db.get_contacts(user1).await.unwrap(),
1519 [
1520 Contact::Accepted {
1521 user_id: user1,
1522 should_notify: false
1523 },
1524 Contact::Accepted {
1525 user_id: user2,
1526 should_notify: true
1527 }
1528 ]
1529 );
1530 assert_eq!(
1531 db.get_contacts(user2).await.unwrap(),
1532 [
1533 Contact::Accepted {
1534 user_id: user1,
1535 should_notify: false
1536 },
1537 Contact::Accepted {
1538 user_id: user2,
1539 should_notify: false
1540 }
1541 ]
1542 );
1543
1544 // User 3 redeems the invite code and becomes a contact of user 1.
1545 let user3 = db
1546 .redeem_invite_code(&invite_code, "user-3", None)
1547 .await
1548 .unwrap();
1549 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1550 assert_eq!(invite_count, 0);
1551 assert_eq!(
1552 db.get_contacts(user1).await.unwrap(),
1553 [
1554 Contact::Accepted {
1555 user_id: user1,
1556 should_notify: false
1557 },
1558 Contact::Accepted {
1559 user_id: user2,
1560 should_notify: true
1561 },
1562 Contact::Accepted {
1563 user_id: user3,
1564 should_notify: true
1565 }
1566 ]
1567 );
1568 assert_eq!(
1569 db.get_contacts(user3).await.unwrap(),
1570 [
1571 Contact::Accepted {
1572 user_id: user1,
1573 should_notify: false
1574 },
1575 Contact::Accepted {
1576 user_id: user3,
1577 should_notify: false
1578 },
1579 ]
1580 );
1581
1582 // Trying to reedem the code for the third time results in an error.
1583 db.redeem_invite_code(&invite_code, "user-4", None)
1584 .await
1585 .unwrap_err();
1586
1587 // Invite count can be updated after the code has been created.
1588 db.set_invite_count(user1, 2).await.unwrap();
1589 let (latest_code, invite_count) =
1590 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1591 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
1592 assert_eq!(invite_count, 2);
1593
1594 // User 4 can now redeem the invite code and becomes a contact of user 1.
1595 let user4 = db
1596 .redeem_invite_code(&invite_code, "user-4", None)
1597 .await
1598 .unwrap();
1599 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1600 assert_eq!(invite_count, 1);
1601 assert_eq!(
1602 db.get_contacts(user1).await.unwrap(),
1603 [
1604 Contact::Accepted {
1605 user_id: user1,
1606 should_notify: false
1607 },
1608 Contact::Accepted {
1609 user_id: user2,
1610 should_notify: true
1611 },
1612 Contact::Accepted {
1613 user_id: user3,
1614 should_notify: true
1615 },
1616 Contact::Accepted {
1617 user_id: user4,
1618 should_notify: true
1619 }
1620 ]
1621 );
1622 assert_eq!(
1623 db.get_contacts(user4).await.unwrap(),
1624 [
1625 Contact::Accepted {
1626 user_id: user1,
1627 should_notify: false
1628 },
1629 Contact::Accepted {
1630 user_id: user4,
1631 should_notify: false
1632 },
1633 ]
1634 );
1635
1636 // An existing user cannot redeem invite codes.
1637 db.redeem_invite_code(&invite_code, "user-2", None)
1638 .await
1639 .unwrap_err();
1640 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1641 assert_eq!(invite_count, 1);
1642 }
1643
    /// A database handle for tests: either a freshly created and migrated Postgres
    /// database (`TestDb::postgres`) or an in-memory `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// The wrapped database. It is an `Option` so `Drop` can take it out and run
        /// its teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the test Postgres database; empty for the fake.
        pub url: String,
    }
1648
1649 impl TestDb {
1650 pub async fn postgres() -> Self {
1651 lazy_static! {
1652 static ref LOCK: Mutex<()> = Mutex::new(());
1653 }
1654
1655 let _guard = LOCK.lock();
1656 let mut rng = StdRng::from_entropy();
1657 let name = format!("zed-test-{}", rng.gen::<u128>());
1658 let url = format!("postgres://postgres@localhost/{}", name);
1659 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1660 Postgres::create_database(&url)
1661 .await
1662 .expect("failed to create test db");
1663 let db = PostgresDb::new(&url, 5).await.unwrap();
1664 let migrator = Migrator::new(migrations_path).await.unwrap();
1665 migrator.run(&db.pool).await.unwrap();
1666 Self {
1667 db: Some(Arc::new(db)),
1668 url,
1669 }
1670 }
1671
1672 pub fn fake(background: Arc<Background>) -> Self {
1673 Self {
1674 db: Some(Arc::new(FakeDb::new(background))),
1675 url: Default::default(),
1676 }
1677 }
1678
1679 pub fn db(&self) -> &Arc<dyn Db> {
1680 self.db.as_ref().unwrap()
1681 }
1682 }
1683
1684 impl Drop for TestDb {
1685 fn drop(&mut self) {
1686 if let Some(db) = self.db.take() {
1687 futures::executor::block_on(db.teardown(&self.url));
1688 }
1689 }
1690 }
1691
    /// In-memory implementation of `Db` for tests. Each table is a map or list behind
    /// a mutex, and ids are allocated from per-table counters.
    pub struct FakeDb {
        /// Executor used to inject random delays into operations.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Maps (org, user) to whether the user is an admin of the org.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Maps (channel, user) to whether the user is an admin of the channel.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Pending and accepted contact relationships.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each table (all start at 1 in `FakeDb::new`).
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1706
    /// A contact relationship stored by `FakeDb`, either pending or accepted.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Whether to notify: the responder while the request is pending, or the
        /// requester once it has been accepted.
        pub should_notify: bool,
    }
1714
1715 impl FakeDb {
1716 pub fn new(background: Arc<Background>) -> Self {
1717 Self {
1718 background,
1719 users: Default::default(),
1720 next_user_id: Mutex::new(1),
1721 orgs: Default::default(),
1722 next_org_id: Mutex::new(1),
1723 org_memberships: Default::default(),
1724 channels: Default::default(),
1725 next_channel_id: Mutex::new(1),
1726 channel_memberships: Default::default(),
1727 channel_messages: Default::default(),
1728 next_channel_message_id: Mutex::new(1),
1729 contacts: Default::default(),
1730 }
1731 }
1732 }
1733
1734 #[async_trait]
1735 impl Db for FakeDb {
1736 async fn create_user(
1737 &self,
1738 github_login: &str,
1739 email_address: Option<&str>,
1740 admin: bool,
1741 ) -> Result<UserId> {
1742 self.background.simulate_random_delay().await;
1743
1744 let mut users = self.users.lock();
1745 if let Some(user) = users
1746 .values()
1747 .find(|user| user.github_login == github_login)
1748 {
1749 Ok(user.id)
1750 } else {
1751 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1752 users.insert(
1753 user_id,
1754 User {
1755 id: user_id,
1756 github_login: github_login.to_string(),
1757 email_address: email_address.map(str::to_string),
1758 admin,
1759 invite_code: None,
1760 invite_count: 0,
1761 connected_once: false,
1762 },
1763 );
1764 Ok(user_id)
1765 }
1766 }
1767
1768 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
1769 unimplemented!()
1770 }
1771
1772 async fn get_all_users(&self) -> Result<Vec<User>> {
1773 unimplemented!()
1774 }
1775
1776 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1777 unimplemented!()
1778 }
1779
1780 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1781 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1782 }
1783
1784 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1785 self.background.simulate_random_delay().await;
1786 let users = self.users.lock();
1787 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1788 }
1789
1790 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1791 Ok(self
1792 .users
1793 .lock()
1794 .values()
1795 .find(|user| user.github_login == github_login)
1796 .cloned())
1797 }
1798
1799 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1800 unimplemented!()
1801 }
1802
1803 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1804 self.background.simulate_random_delay().await;
1805 let mut users = self.users.lock();
1806 let mut user = users
1807 .get_mut(&id)
1808 .ok_or_else(|| anyhow!("user not found"))?;
1809 user.connected_once = connected_once;
1810 Ok(())
1811 }
1812
1813 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1814 unimplemented!()
1815 }
1816
1817 // invite codes
1818
1819 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1820 unimplemented!()
1821 }
1822
1823 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1824 Ok(None)
1825 }
1826
1827 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1828 unimplemented!()
1829 }
1830
1831 async fn redeem_invite_code(
1832 &self,
1833 _code: &str,
1834 _login: &str,
1835 _email_address: Option<&str>,
1836 ) -> Result<UserId> {
1837 unimplemented!()
1838 }
1839
1840 // contacts
1841
1842 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1843 self.background.simulate_random_delay().await;
1844 let mut contacts = vec![Contact::Accepted {
1845 user_id: id,
1846 should_notify: false,
1847 }];
1848
1849 for contact in self.contacts.lock().iter() {
1850 if contact.requester_id == id {
1851 if contact.accepted {
1852 contacts.push(Contact::Accepted {
1853 user_id: contact.responder_id,
1854 should_notify: contact.should_notify,
1855 });
1856 } else {
1857 contacts.push(Contact::Outgoing {
1858 user_id: contact.responder_id,
1859 });
1860 }
1861 } else if contact.responder_id == id {
1862 if contact.accepted {
1863 contacts.push(Contact::Accepted {
1864 user_id: contact.requester_id,
1865 should_notify: false,
1866 });
1867 } else {
1868 contacts.push(Contact::Incoming {
1869 user_id: contact.requester_id,
1870 should_notify: contact.should_notify,
1871 });
1872 }
1873 }
1874 }
1875
1876 contacts.sort_unstable_by_key(|contact| contact.user_id());
1877 Ok(contacts)
1878 }
1879
1880 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1881 self.background.simulate_random_delay().await;
1882 Ok(self.contacts.lock().iter().any(|contact| {
1883 contact.accepted
1884 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1885 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1886 }))
1887 }
1888
1889 async fn send_contact_request(
1890 &self,
1891 requester_id: UserId,
1892 responder_id: UserId,
1893 ) -> Result<()> {
1894 let mut contacts = self.contacts.lock();
1895 for contact in contacts.iter_mut() {
1896 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1897 if contact.accepted {
1898 Err(anyhow!("contact already exists"))?;
1899 } else {
1900 Err(anyhow!("contact already requested"))?;
1901 }
1902 }
1903 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1904 if contact.accepted {
1905 Err(anyhow!("contact already exists"))?;
1906 } else {
1907 contact.accepted = true;
1908 contact.should_notify = false;
1909 return Ok(());
1910 }
1911 }
1912 }
1913 contacts.push(FakeContact {
1914 requester_id,
1915 responder_id,
1916 accepted: false,
1917 should_notify: true,
1918 });
1919 Ok(())
1920 }
1921
1922 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1923 self.contacts.lock().retain(|contact| {
1924 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1925 });
1926 Ok(())
1927 }
1928
1929 async fn dismiss_contact_notification(
1930 &self,
1931 user_id: UserId,
1932 contact_user_id: UserId,
1933 ) -> Result<()> {
1934 let mut contacts = self.contacts.lock();
1935 for contact in contacts.iter_mut() {
1936 if contact.requester_id == contact_user_id
1937 && contact.responder_id == user_id
1938 && !contact.accepted
1939 {
1940 contact.should_notify = false;
1941 return Ok(());
1942 }
1943 if contact.requester_id == user_id
1944 && contact.responder_id == contact_user_id
1945 && contact.accepted
1946 {
1947 contact.should_notify = false;
1948 return Ok(());
1949 }
1950 }
1951 Err(anyhow!("no such notification"))?
1952 }
1953
1954 async fn respond_to_contact_request(
1955 &self,
1956 responder_id: UserId,
1957 requester_id: UserId,
1958 accept: bool,
1959 ) -> Result<()> {
1960 let mut contacts = self.contacts.lock();
1961 for (ix, contact) in contacts.iter_mut().enumerate() {
1962 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1963 if contact.accepted {
1964 Err(anyhow!("contact already confirmed"))?;
1965 }
1966 if accept {
1967 contact.accepted = true;
1968 contact.should_notify = true;
1969 } else {
1970 contacts.remove(ix);
1971 }
1972 return Ok(());
1973 }
1974 }
1975 Err(anyhow!("no such contact request"))?
1976 }
1977
1978 async fn create_access_token_hash(
1979 &self,
1980 _user_id: UserId,
1981 _access_token_hash: &str,
1982 _max_access_token_count: usize,
1983 ) -> Result<()> {
1984 unimplemented!()
1985 }
1986
1987 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1988 unimplemented!()
1989 }
1990
1991 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1992 unimplemented!()
1993 }
1994
1995 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1996 self.background.simulate_random_delay().await;
1997 let mut orgs = self.orgs.lock();
1998 if orgs.values().any(|org| org.slug == slug) {
1999 Err(anyhow!("org already exists"))?
2000 } else {
2001 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2002 orgs.insert(
2003 org_id,
2004 Org {
2005 id: org_id,
2006 name: name.to_string(),
2007 slug: slug.to_string(),
2008 },
2009 );
2010 Ok(org_id)
2011 }
2012 }
2013
2014 async fn add_org_member(
2015 &self,
2016 org_id: OrgId,
2017 user_id: UserId,
2018 is_admin: bool,
2019 ) -> Result<()> {
2020 self.background.simulate_random_delay().await;
2021 if !self.orgs.lock().contains_key(&org_id) {
2022 Err(anyhow!("org does not exist"))?;
2023 }
2024 if !self.users.lock().contains_key(&user_id) {
2025 Err(anyhow!("user does not exist"))?;
2026 }
2027
2028 self.org_memberships
2029 .lock()
2030 .entry((org_id, user_id))
2031 .or_insert(is_admin);
2032 Ok(())
2033 }
2034
2035 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2036 self.background.simulate_random_delay().await;
2037 if !self.orgs.lock().contains_key(&org_id) {
2038 Err(anyhow!("org does not exist"))?;
2039 }
2040
2041 let mut channels = self.channels.lock();
2042 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2043 channels.insert(
2044 channel_id,
2045 Channel {
2046 id: channel_id,
2047 name: name.to_string(),
2048 owner_id: org_id.0,
2049 owner_is_user: false,
2050 },
2051 );
2052 Ok(channel_id)
2053 }
2054
2055 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2056 self.background.simulate_random_delay().await;
2057 Ok(self
2058 .channels
2059 .lock()
2060 .values()
2061 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2062 .cloned()
2063 .collect())
2064 }
2065
2066 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2067 self.background.simulate_random_delay().await;
2068 let channels = self.channels.lock();
2069 let memberships = self.channel_memberships.lock();
2070 Ok(channels
2071 .values()
2072 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2073 .cloned()
2074 .collect())
2075 }
2076
2077 async fn can_user_access_channel(
2078 &self,
2079 user_id: UserId,
2080 channel_id: ChannelId,
2081 ) -> Result<bool> {
2082 self.background.simulate_random_delay().await;
2083 Ok(self
2084 .channel_memberships
2085 .lock()
2086 .contains_key(&(channel_id, user_id)))
2087 }
2088
2089 async fn add_channel_member(
2090 &self,
2091 channel_id: ChannelId,
2092 user_id: UserId,
2093 is_admin: bool,
2094 ) -> Result<()> {
2095 self.background.simulate_random_delay().await;
2096 if !self.channels.lock().contains_key(&channel_id) {
2097 Err(anyhow!("channel does not exist"))?;
2098 }
2099 if !self.users.lock().contains_key(&user_id) {
2100 Err(anyhow!("user does not exist"))?;
2101 }
2102
2103 self.channel_memberships
2104 .lock()
2105 .entry((channel_id, user_id))
2106 .or_insert(is_admin);
2107 Ok(())
2108 }
2109
2110 async fn create_channel_message(
2111 &self,
2112 channel_id: ChannelId,
2113 sender_id: UserId,
2114 body: &str,
2115 timestamp: OffsetDateTime,
2116 nonce: u128,
2117 ) -> Result<MessageId> {
2118 self.background.simulate_random_delay().await;
2119 if !self.channels.lock().contains_key(&channel_id) {
2120 Err(anyhow!("channel does not exist"))?;
2121 }
2122 if !self.users.lock().contains_key(&sender_id) {
2123 Err(anyhow!("user does not exist"))?;
2124 }
2125
2126 let mut messages = self.channel_messages.lock();
2127 if let Some(message) = messages
2128 .values()
2129 .find(|message| message.nonce.as_u128() == nonce)
2130 {
2131 Ok(message.id)
2132 } else {
2133 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2134 messages.insert(
2135 message_id,
2136 ChannelMessage {
2137 id: message_id,
2138 channel_id,
2139 sender_id,
2140 body: body.to_string(),
2141 sent_at: timestamp,
2142 nonce: Uuid::from_u128(nonce),
2143 },
2144 );
2145 Ok(message_id)
2146 }
2147 }
2148
2149 async fn get_channel_messages(
2150 &self,
2151 channel_id: ChannelId,
2152 count: usize,
2153 before_id: Option<MessageId>,
2154 ) -> Result<Vec<ChannelMessage>> {
2155 let mut messages = self
2156 .channel_messages
2157 .lock()
2158 .values()
2159 .rev()
2160 .filter(|message| {
2161 message.channel_id == channel_id
2162 && message.id < before_id.unwrap_or(MessageId::MAX)
2163 })
2164 .take(count)
2165 .cloned()
2166 .collect::<Vec<_>>();
2167 messages.sort_unstable_by_key(|message| message.id);
2168 Ok(messages)
2169 }
2170
2171 async fn teardown(&self, _: &str) {}
2172
2173 #[cfg(test)]
2174 fn as_fake(&self) -> Option<&FakeDb> {
2175 Some(self)
2176 }
2177 }
2178}