1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use futures::StreamExt;
6use nanoid::nanoid;
7use serde::Serialize;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
10use time::OffsetDateTime;
11
/// Persistence interface for users, invite codes, contacts, access tokens,
/// orgs, channels and channel messages.
///
/// Implementors must be `Send + Sync`. `PostgresDb` is the implementation
/// visible in this file; test builds can also reach a `tests::FakeDb` through
/// `as_fake`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login, optional email address and
    /// admin flag, returning the user's id.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users, at most `limit` entries long.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Creates several users at once and returns their ids. `PostgresDb`
    /// reads each tuple as `(github_login, email_address, invite_count)`.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user by id; `None` when there is no such user.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up all users whose id appears in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up a single user by GitHub login; `None` when there is no match.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user has available.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code together with their invite count, or
    /// `None` when there is nothing to return for that user.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user that owns the given invite code.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Redeems `code` to create a new user with the given login and optional
    /// email address, returning the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Returns the contacts (accepted, outgoing and incoming) of a user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Records a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification on the contact entry shared by the two
    /// users.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the contact request that
    /// `requester_id` sent to `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access token hash for the user. `max_access_token_count`
    /// bounds how many hashes are kept; `PostgresDb` deletes the oldest ones
    /// beyond that count.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the access token hashes stored for the user.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    #[cfg(any(test, feature = "seed-support"))]

    /// Looks up an org by its slug (test / `seed-support` builds only).
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by an org and returns its id (test /
    /// `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]

    /// Lists the channels owned by an org (test / `seed-support` builds only).
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Lists the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id. `PostgresDb` treats
    /// `nonce` as a uniqueness key: re-sending a nonce yields the id of the
    /// message that already carries it.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages of a channel, restricted to ids below
    /// `before_id` when one is given.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup hook; `url` identifies the database to tear down.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only downcast: returns the `FakeDb` behind this handle, if any.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
104
/// [`Db`] implementation that runs every operation against a Postgres
/// database through a `sqlx` connection pool.
pub struct PostgresDb {
    /// Connection pool that all queries are executed on.
    pool: sqlx::PgPool,
}
108
109impl PostgresDb {
110 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
111 let pool = DbOptions::new()
112 .max_connections(max_connections)
113 .connect(&url)
114 .await
115 .context("failed to connect to postgres database")?;
116 Ok(Self { pool })
117 }
118}
119
120#[async_trait]
121impl Db for PostgresDb {
122 // users
123
124 async fn create_user(
125 &self,
126 github_login: &str,
127 email_address: Option<&str>,
128 admin: bool,
129 ) -> Result<UserId> {
130 let query = "
131 INSERT INTO users (github_login, email_address, admin)
132 VALUES ($1, $2, $3)
133 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
134 RETURNING id
135 ";
136 Ok(sqlx::query_scalar(query)
137 .bind(github_login)
138 .bind(email_address)
139 .bind(admin)
140 .fetch_one(&self.pool)
141 .await
142 .map(UserId)?)
143 }
144
145 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
146 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
147 Ok(sqlx::query_as(query)
148 .bind(limit)
149 .bind(page * limit)
150 .fetch_all(&self.pool)
151 .await?)
152 }
153
154 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
155 let mut query = QueryBuilder::new(
156 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
157 );
158 query.push_values(
159 users,
160 |mut query, (github_login, email_address, invite_count)| {
161 query
162 .push_bind(github_login)
163 .push_bind(email_address)
164 .push_bind(false)
165 .push_bind(nanoid!(16))
166 .push_bind(invite_count as u32);
167 },
168 );
169 query.push(
170 "
171 ON CONFLICT (github_login) DO UPDATE SET
172 github_login = excluded.github_login,
173 invite_count = excluded.invite_count,
174 invite_code = CASE WHEN users.invite_code IS NULL
175 THEN excluded.invite_code
176 ELSE users.invite_code
177 END
178 RETURNING id
179 ",
180 );
181
182 let rows = query.build().fetch_all(&self.pool).await?;
183 Ok(rows
184 .into_iter()
185 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
186 .collect())
187 }
188
189 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
190 let like_string = fuzzy_like_string(name_query);
191 let query = "
192 SELECT users.*
193 FROM users
194 WHERE github_login ILIKE $1
195 ORDER BY github_login <-> $2
196 LIMIT $3
197 ";
198 Ok(sqlx::query_as(query)
199 .bind(like_string)
200 .bind(name_query)
201 .bind(limit)
202 .fetch_all(&self.pool)
203 .await?)
204 }
205
206 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
207 let users = self.get_users_by_ids(vec![id]).await?;
208 Ok(users.into_iter().next())
209 }
210
211 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
212 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE users.id = ANY ($1)
217 ";
218 Ok(sqlx::query_as(query)
219 .bind(&ids)
220 .fetch_all(&self.pool)
221 .await?)
222 }
223
224 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
225 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
226 Ok(sqlx::query_as(query)
227 .bind(github_login)
228 .fetch_optional(&self.pool)
229 .await?)
230 }
231
232 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
233 let query = "UPDATE users SET admin = $1 WHERE id = $2";
234 Ok(sqlx::query(query)
235 .bind(is_admin)
236 .bind(id.0)
237 .execute(&self.pool)
238 .await
239 .map(drop)?)
240 }
241
242 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
243 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
244 Ok(sqlx::query(query)
245 .bind(connected_once)
246 .bind(id.0)
247 .execute(&self.pool)
248 .await
249 .map(drop)?)
250 }
251
252 async fn destroy_user(&self, id: UserId) -> Result<()> {
253 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
254 sqlx::query(query)
255 .bind(id.0)
256 .execute(&self.pool)
257 .await
258 .map(drop)?;
259 let query = "DELETE FROM users WHERE id = $1;";
260 Ok(sqlx::query(query)
261 .bind(id.0)
262 .execute(&self.pool)
263 .await
264 .map(drop)?)
265 }
266
267 // invite codes
268
269 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
270 let mut tx = self.pool.begin().await?;
271 if count > 0 {
272 sqlx::query(
273 "
274 UPDATE users
275 SET invite_code = $1
276 WHERE id = $2 AND invite_code IS NULL
277 ",
278 )
279 .bind(nanoid!(16))
280 .bind(id)
281 .execute(&mut tx)
282 .await?;
283 }
284
285 sqlx::query(
286 "
287 UPDATE users
288 SET invite_count = $1
289 WHERE id = $2
290 ",
291 )
292 .bind(count)
293 .bind(id)
294 .execute(&mut tx)
295 .await?;
296 tx.commit().await?;
297 Ok(())
298 }
299
300 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
301 let result: Option<(String, i32)> = sqlx::query_as(
302 "
303 SELECT invite_code, invite_count
304 FROM users
305 WHERE id = $1 AND invite_code IS NOT NULL
306 ",
307 )
308 .bind(id)
309 .fetch_optional(&self.pool)
310 .await?;
311 if let Some((code, count)) = result {
312 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
313 } else {
314 Ok(None)
315 }
316 }
317
318 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
319 sqlx::query_as(
320 "
321 SELECT *
322 FROM users
323 WHERE invite_code = $1
324 ",
325 )
326 .bind(code)
327 .fetch_optional(&self.pool)
328 .await?
329 .ok_or_else(|| {
330 Error::Http(
331 StatusCode::NOT_FOUND,
332 "that invite code does not exist".to_string(),
333 )
334 })
335 }
336
337 async fn redeem_invite_code(
338 &self,
339 code: &str,
340 login: &str,
341 email_address: Option<&str>,
342 ) -> Result<UserId> {
343 let mut tx = self.pool.begin().await?;
344
345 let inviter_id: Option<UserId> = sqlx::query_scalar(
346 "
347 UPDATE users
348 SET invite_count = invite_count - 1
349 WHERE
350 invite_code = $1 AND
351 invite_count > 0
352 RETURNING id
353 ",
354 )
355 .bind(code)
356 .fetch_optional(&mut tx)
357 .await?;
358
359 let inviter_id = match inviter_id {
360 Some(inviter_id) => inviter_id,
361 None => {
362 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
363 .bind(code)
364 .fetch_optional(&mut tx)
365 .await?
366 .is_some()
367 {
368 Err(Error::Http(
369 StatusCode::UNAUTHORIZED,
370 "no invites remaining".to_string(),
371 ))?
372 } else {
373 Err(Error::Http(
374 StatusCode::NOT_FOUND,
375 "invite code not found".to_string(),
376 ))?
377 }
378 }
379 };
380
381 let invitee_id = sqlx::query_scalar(
382 "
383 INSERT INTO users
384 (github_login, email_address, admin, inviter_id)
385 VALUES
386 ($1, $2, 'f', $3)
387 RETURNING id
388 ",
389 )
390 .bind(login)
391 .bind(email_address)
392 .bind(inviter_id)
393 .fetch_one(&mut tx)
394 .await
395 .map(UserId)?;
396
397 sqlx::query(
398 "
399 INSERT INTO contacts
400 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
401 VALUES
402 ($1, $2, 't', 't', 't')
403 ",
404 )
405 .bind(inviter_id)
406 .bind(invitee_id)
407 .execute(&mut tx)
408 .await?;
409
410 tx.commit().await?;
411 Ok(invitee_id)
412 }
413
414 // contacts
415
416 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
417 let query = "
418 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
419 FROM contacts
420 WHERE user_id_a = $1 OR user_id_b = $1;
421 ";
422
423 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
424 .bind(user_id)
425 .fetch(&self.pool);
426
427 let mut contacts = vec![Contact::Accepted {
428 user_id,
429 should_notify: false,
430 }];
431 while let Some(row) = rows.next().await {
432 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
433
434 if user_id_a == user_id {
435 if accepted {
436 contacts.push(Contact::Accepted {
437 user_id: user_id_b,
438 should_notify: should_notify && a_to_b,
439 });
440 } else if a_to_b {
441 contacts.push(Contact::Outgoing { user_id: user_id_b })
442 } else {
443 contacts.push(Contact::Incoming {
444 user_id: user_id_b,
445 should_notify,
446 });
447 }
448 } else {
449 if accepted {
450 contacts.push(Contact::Accepted {
451 user_id: user_id_a,
452 should_notify: should_notify && !a_to_b,
453 });
454 } else if a_to_b {
455 contacts.push(Contact::Incoming {
456 user_id: user_id_a,
457 should_notify,
458 });
459 } else {
460 contacts.push(Contact::Outgoing { user_id: user_id_a });
461 }
462 }
463 }
464
465 contacts.sort_unstable_by_key(|contact| contact.user_id());
466
467 Ok(contacts)
468 }
469
470 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
471 let (id_a, id_b) = if user_id_1 < user_id_2 {
472 (user_id_1, user_id_2)
473 } else {
474 (user_id_2, user_id_1)
475 };
476
477 let query = "
478 SELECT 1 FROM contacts
479 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
480 LIMIT 1
481 ";
482 Ok(sqlx::query_scalar::<_, i32>(query)
483 .bind(id_a.0)
484 .bind(id_b.0)
485 .fetch_optional(&self.pool)
486 .await?
487 .is_some())
488 }
489
490 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
491 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
492 (sender_id, receiver_id, true)
493 } else {
494 (receiver_id, sender_id, false)
495 };
496 let query = "
497 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
498 VALUES ($1, $2, $3, 'f', 't')
499 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
500 SET
501 accepted = 't',
502 should_notify = 'f'
503 WHERE
504 NOT contacts.accepted AND
505 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
506 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
507 ";
508 let result = sqlx::query(query)
509 .bind(id_a.0)
510 .bind(id_b.0)
511 .bind(a_to_b)
512 .execute(&self.pool)
513 .await?;
514
515 if result.rows_affected() == 1 {
516 Ok(())
517 } else {
518 Err(anyhow!("contact already requested"))?
519 }
520 }
521
522 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
523 let (id_a, id_b) = if responder_id < requester_id {
524 (responder_id, requester_id)
525 } else {
526 (requester_id, responder_id)
527 };
528 let query = "
529 DELETE FROM contacts
530 WHERE user_id_a = $1 AND user_id_b = $2;
531 ";
532 let result = sqlx::query(query)
533 .bind(id_a.0)
534 .bind(id_b.0)
535 .execute(&self.pool)
536 .await?;
537
538 if result.rows_affected() == 1 {
539 Ok(())
540 } else {
541 Err(anyhow!("no such contact"))?
542 }
543 }
544
545 async fn dismiss_contact_notification(
546 &self,
547 user_id: UserId,
548 contact_user_id: UserId,
549 ) -> Result<()> {
550 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
551 (user_id, contact_user_id, true)
552 } else {
553 (contact_user_id, user_id, false)
554 };
555
556 let query = "
557 UPDATE contacts
558 SET should_notify = 'f'
559 WHERE
560 user_id_a = $1 AND user_id_b = $2 AND
561 (
562 (a_to_b = $3 AND accepted) OR
563 (a_to_b != $3 AND NOT accepted)
564 );
565 ";
566
567 let result = sqlx::query(query)
568 .bind(id_a.0)
569 .bind(id_b.0)
570 .bind(a_to_b)
571 .execute(&self.pool)
572 .await?;
573
574 if result.rows_affected() == 0 {
575 Err(anyhow!("no such contact request"))?;
576 }
577
578 Ok(())
579 }
580
581 async fn respond_to_contact_request(
582 &self,
583 responder_id: UserId,
584 requester_id: UserId,
585 accept: bool,
586 ) -> Result<()> {
587 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
588 (responder_id, requester_id, false)
589 } else {
590 (requester_id, responder_id, true)
591 };
592 let result = if accept {
593 let query = "
594 UPDATE contacts
595 SET accepted = 't', should_notify = 't'
596 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
597 ";
598 sqlx::query(query)
599 .bind(id_a.0)
600 .bind(id_b.0)
601 .bind(a_to_b)
602 .execute(&self.pool)
603 .await?
604 } else {
605 let query = "
606 DELETE FROM contacts
607 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
608 ";
609 sqlx::query(query)
610 .bind(id_a.0)
611 .bind(id_b.0)
612 .bind(a_to_b)
613 .execute(&self.pool)
614 .await?
615 };
616 if result.rows_affected() == 1 {
617 Ok(())
618 } else {
619 Err(anyhow!("no such contact request"))?
620 }
621 }
622
623 // access tokens
624
625 async fn create_access_token_hash(
626 &self,
627 user_id: UserId,
628 access_token_hash: &str,
629 max_access_token_count: usize,
630 ) -> Result<()> {
631 let insert_query = "
632 INSERT INTO access_tokens (user_id, hash)
633 VALUES ($1, $2);
634 ";
635 let cleanup_query = "
636 DELETE FROM access_tokens
637 WHERE id IN (
638 SELECT id from access_tokens
639 WHERE user_id = $1
640 ORDER BY id DESC
641 OFFSET $3
642 )
643 ";
644
645 let mut tx = self.pool.begin().await?;
646 sqlx::query(insert_query)
647 .bind(user_id.0)
648 .bind(access_token_hash)
649 .execute(&mut tx)
650 .await?;
651 sqlx::query(cleanup_query)
652 .bind(user_id.0)
653 .bind(access_token_hash)
654 .bind(max_access_token_count as u32)
655 .execute(&mut tx)
656 .await?;
657 Ok(tx.commit().await?)
658 }
659
660 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
661 let query = "
662 SELECT hash
663 FROM access_tokens
664 WHERE user_id = $1
665 ORDER BY id DESC
666 ";
667 Ok(sqlx::query_scalar(query)
668 .bind(user_id.0)
669 .fetch_all(&self.pool)
670 .await?)
671 }
672
673 // orgs
674
675 #[allow(unused)] // Help rust-analyzer
676 #[cfg(any(test, feature = "seed-support"))]
677 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
678 let query = "
679 SELECT *
680 FROM orgs
681 WHERE slug = $1
682 ";
683 Ok(sqlx::query_as(query)
684 .bind(slug)
685 .fetch_optional(&self.pool)
686 .await?)
687 }
688
689 #[cfg(any(test, feature = "seed-support"))]
690 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
691 let query = "
692 INSERT INTO orgs (name, slug)
693 VALUES ($1, $2)
694 RETURNING id
695 ";
696 Ok(sqlx::query_scalar(query)
697 .bind(name)
698 .bind(slug)
699 .fetch_one(&self.pool)
700 .await
701 .map(OrgId)?)
702 }
703
704 #[cfg(any(test, feature = "seed-support"))]
705 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
706 let query = "
707 INSERT INTO org_memberships (org_id, user_id, admin)
708 VALUES ($1, $2, $3)
709 ON CONFLICT DO NOTHING
710 ";
711 Ok(sqlx::query(query)
712 .bind(org_id.0)
713 .bind(user_id.0)
714 .bind(is_admin)
715 .execute(&self.pool)
716 .await
717 .map(drop)?)
718 }
719
720 // channels
721
722 #[cfg(any(test, feature = "seed-support"))]
723 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
724 let query = "
725 INSERT INTO channels (owner_id, owner_is_user, name)
726 VALUES ($1, false, $2)
727 RETURNING id
728 ";
729 Ok(sqlx::query_scalar(query)
730 .bind(org_id.0)
731 .bind(name)
732 .fetch_one(&self.pool)
733 .await
734 .map(ChannelId)?)
735 }
736
737 #[allow(unused)] // Help rust-analyzer
738 #[cfg(any(test, feature = "seed-support"))]
739 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
740 let query = "
741 SELECT *
742 FROM channels
743 WHERE
744 channels.owner_is_user = false AND
745 channels.owner_id = $1
746 ";
747 Ok(sqlx::query_as(query)
748 .bind(org_id.0)
749 .fetch_all(&self.pool)
750 .await?)
751 }
752
753 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
754 let query = "
755 SELECT
756 channels.*
757 FROM
758 channel_memberships, channels
759 WHERE
760 channel_memberships.user_id = $1 AND
761 channel_memberships.channel_id = channels.id
762 ";
763 Ok(sqlx::query_as(query)
764 .bind(user_id.0)
765 .fetch_all(&self.pool)
766 .await?)
767 }
768
769 async fn can_user_access_channel(
770 &self,
771 user_id: UserId,
772 channel_id: ChannelId,
773 ) -> Result<bool> {
774 let query = "
775 SELECT id
776 FROM channel_memberships
777 WHERE user_id = $1 AND channel_id = $2
778 LIMIT 1
779 ";
780 Ok(sqlx::query_scalar::<_, i32>(query)
781 .bind(user_id.0)
782 .bind(channel_id.0)
783 .fetch_optional(&self.pool)
784 .await
785 .map(|e| e.is_some())?)
786 }
787
788 #[cfg(any(test, feature = "seed-support"))]
789 async fn add_channel_member(
790 &self,
791 channel_id: ChannelId,
792 user_id: UserId,
793 is_admin: bool,
794 ) -> Result<()> {
795 let query = "
796 INSERT INTO channel_memberships (channel_id, user_id, admin)
797 VALUES ($1, $2, $3)
798 ON CONFLICT DO NOTHING
799 ";
800 Ok(sqlx::query(query)
801 .bind(channel_id.0)
802 .bind(user_id.0)
803 .bind(is_admin)
804 .execute(&self.pool)
805 .await
806 .map(drop)?)
807 }
808
809 // messages
810
811 async fn create_channel_message(
812 &self,
813 channel_id: ChannelId,
814 sender_id: UserId,
815 body: &str,
816 timestamp: OffsetDateTime,
817 nonce: u128,
818 ) -> Result<MessageId> {
819 let query = "
820 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
821 VALUES ($1, $2, $3, $4, $5)
822 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
823 RETURNING id
824 ";
825 Ok(sqlx::query_scalar(query)
826 .bind(channel_id.0)
827 .bind(sender_id.0)
828 .bind(body)
829 .bind(timestamp)
830 .bind(Uuid::from_u128(nonce))
831 .fetch_one(&self.pool)
832 .await
833 .map(MessageId)?)
834 }
835
836 async fn get_channel_messages(
837 &self,
838 channel_id: ChannelId,
839 count: usize,
840 before_id: Option<MessageId>,
841 ) -> Result<Vec<ChannelMessage>> {
842 let query = r#"
843 SELECT * FROM (
844 SELECT
845 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
846 FROM
847 channel_messages
848 WHERE
849 channel_id = $1 AND
850 id < $2
851 ORDER BY id DESC
852 LIMIT $3
853 ) as recent_messages
854 ORDER BY id ASC
855 "#;
856 Ok(sqlx::query_as(query)
857 .bind(channel_id.0)
858 .bind(before_id.unwrap_or(MessageId::MAX))
859 .bind(count as i64)
860 .fetch_all(&self.pool)
861 .await?)
862 }
863
864 #[cfg(test)]
865 async fn teardown(&self, url: &str) {
866 use util::ResultExt;
867
868 let query = "
869 SELECT pg_terminate_backend(pg_stat_activity.pid)
870 FROM pg_stat_activity
871 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
872 ";
873 sqlx::query(query).execute(&self.pool).await.log_err();
874 self.pool.close().await;
875 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
876 .await
877 .log_err();
878 }
879
    /// Test-only downcast hook. `PostgresDb` is never a `FakeDb`, so this
    /// always returns `None`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb> {
        None
    }
884}
885
// Declares a newtype id `$name` wrapping an `i32`.
//
// The generated type is `Copy`, ordered, hashable, transparent to both sqlx
// and serde (it encodes exactly like the inner integer), and displays as the
// inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts a `u64` id into this type. The `as` cast keeps only
            // the low 32 bits, so values above `i32::MAX` do not round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts this id into a `u64`. A negative inner value is
            // sign-extended by the `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so formatter flags apply to the inner integer.
                self.0.fmt(f)
            }
        }
    };
}
917
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Primary key.
    pub id: UserId,
    /// GitHub login; the conflict target of the user-creating inserts.
    pub github_login: String,
    /// Email address, if one was provided.
    pub email_address: Option<String>,
    /// Whether the user is an admin.
    pub admin: bool,
    /// Invite code others can redeem; `None` until one has been generated.
    pub invite_code: Option<String>,
    /// Number of invites left; decremented each time the code is redeemed.
    pub invite_count: i32,
    /// Flag maintained through `set_user_connected_once`.
    pub connected_once: bool,
}
929
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    /// Primary key.
    pub id: OrgId,
    /// Name of the org.
    pub name: String,
    /// Slug used to look the org up (see `find_org_by_slug`).
    pub slug: String,
}
937
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    /// Primary key.
    pub id: ChannelId,
    /// Name of the channel.
    pub name: String,
    /// Id of the owner; an org id when `owner_is_user` is `false`.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
946
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    /// Primary key; messages are paged and ordered by this id.
    pub id: MessageId,
    /// Channel the message was posted in.
    pub channel_id: ChannelId,
    /// User who sent the message.
    pub sender_id: UserId,
    /// Message text.
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce that deduplicates re-sent messages.
    pub nonce: Uuid,
}
957
/// One entry of a user's contact list, as produced by `get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An established contact.
    Accepted {
        /// The other user.
        user_id: UserId,
        /// Whether the list's owner still has a notification pending for it.
        should_notify: bool,
    },
    /// A pending request that the list's owner sent to `user_id`.
    Outgoing {
        /// The user the request was sent to.
        user_id: UserId,
    },
    /// A pending request that `user_id` sent to the list's owner.
    Incoming {
        /// The user who sent the request.
        user_id: UserId,
        /// Whether a notification for the request is still pending.
        should_notify: bool,
    },
}
972
973impl Contact {
974 pub fn user_id(&self) -> UserId {
975 match self {
976 Contact::Accepted { user_id, .. } => *user_id,
977 Contact::Outgoing { user_id } => *user_id,
978 Contact::Incoming { user_id, .. } => *user_id,
979 }
980 }
981}
982
/// A contact request received from another user.
// NOTE(review): nothing in the visible part of this file constructs this
// type — confirm it is still used elsewhere.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// User who sent the request.
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about it.
    pub should_notify: bool,
}
988
/// Builds a `LIKE`/`ILIKE` pattern that matches any text containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Every alphanumeric character is prefixed with `%` and a final `%` is
/// appended; all other characters (including `%` and `_`) are dropped, so
/// `"x y"` becomes `"%x%y%"` and the empty string becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept character contributes itself plus one '%', plus the trailer.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
1000
1001#[cfg(test)]
1002pub mod tests {
1003 use super::*;
1004 use anyhow::anyhow;
1005 use collections::BTreeMap;
1006 use gpui::executor::Background;
1007 use lazy_static::lazy_static;
1008 use parking_lot::Mutex;
1009 use rand::prelude::*;
1010 use sqlx::{
1011 migrate::{MigrateDatabase, Migrator},
1012 Postgres,
1013 };
1014 use std::{path::Path, sync::Arc};
1015 use util::post_inc;
1016
1017 #[tokio::test(flavor = "multi_thread")]
1018 async fn test_get_users_by_ids() {
1019 for test_db in [
1020 TestDb::postgres().await,
1021 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1022 ] {
1023 let db = test_db.db();
1024
1025 let user = db.create_user("user", None, false).await.unwrap();
1026 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1027 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1028 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1029
1030 assert_eq!(
1031 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1032 .await
1033 .unwrap(),
1034 vec![
1035 User {
1036 id: user,
1037 github_login: "user".to_string(),
1038 admin: false,
1039 ..Default::default()
1040 },
1041 User {
1042 id: friend1,
1043 github_login: "friend-1".to_string(),
1044 admin: false,
1045 ..Default::default()
1046 },
1047 User {
1048 id: friend2,
1049 github_login: "friend-2".to_string(),
1050 admin: false,
1051 ..Default::default()
1052 },
1053 User {
1054 id: friend3,
1055 github_login: "friend-3".to_string(),
1056 admin: false,
1057 ..Default::default()
1058 }
1059 ]
1060 );
1061 }
1062 }
1063
1064 #[tokio::test(flavor = "multi_thread")]
1065 async fn test_create_users() {
1066 let db = TestDb::postgres().await;
1067 let db = db.db();
1068
1069 // Create the first batch of users, ensuring invite counts are assigned
1070 // correctly and the respective invite codes are unique.
1071 let user_ids_batch_1 = db
1072 .create_users(vec![
1073 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1074 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1075 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1076 ])
1077 .await
1078 .unwrap();
1079 assert_eq!(user_ids_batch_1.len(), 3);
1080
1081 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1082 assert_eq!(users.len(), 3);
1083 assert_eq!(users[0].github_login, "user1");
1084 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1085 assert_eq!(users[0].invite_count, 5);
1086 assert_eq!(users[1].github_login, "user2");
1087 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1088 assert_eq!(users[1].invite_count, 4);
1089 assert_eq!(users[2].github_login, "user3");
1090 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1091 assert_eq!(users[2].invite_count, 3);
1092
1093 let invite_code_1 = users[0].invite_code.clone().unwrap();
1094 let invite_code_2 = users[1].invite_code.clone().unwrap();
1095 let invite_code_3 = users[2].invite_code.clone().unwrap();
1096 assert_ne!(invite_code_1, invite_code_2);
1097 assert_ne!(invite_code_1, invite_code_3);
1098 assert_ne!(invite_code_2, invite_code_3);
1099
1100 // Create the second batch of users and include a user that is already in the database, ensuring
1101 // the invite count for the existing user is updated without changing their invite code.
1102 let user_ids_batch_2 = db
1103 .create_users(vec![
1104 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1105 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1106 ])
1107 .await
1108 .unwrap();
1109 assert_eq!(user_ids_batch_2.len(), 2);
1110 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1111
1112 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1113 assert_eq!(users.len(), 2);
1114 assert_eq!(users[0].github_login, "user2");
1115 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1116 assert_eq!(users[0].invite_count, 10);
1117 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1118 assert_eq!(users[1].github_login, "user4");
1119 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1120 assert_eq!(users[1].invite_count, 2);
1121
1122 let invite_code_4 = users[1].invite_code.clone().unwrap();
1123 assert_ne!(invite_code_4, invite_code_1);
1124 assert_ne!(invite_code_4, invite_code_2);
1125 assert_ne!(invite_code_4, invite_code_3);
1126 }
1127
1128 #[tokio::test(flavor = "multi_thread")]
1129 async fn test_recent_channel_messages() {
1130 for test_db in [
1131 TestDb::postgres().await,
1132 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1133 ] {
1134 let db = test_db.db();
1135 let user = db.create_user("user", None, false).await.unwrap();
1136 let org = db.create_org("org", "org").await.unwrap();
1137 let channel = db.create_org_channel(org, "channel").await.unwrap();
1138 for i in 0..10 {
1139 db.create_channel_message(
1140 channel,
1141 user,
1142 &i.to_string(),
1143 OffsetDateTime::now_utc(),
1144 i,
1145 )
1146 .await
1147 .unwrap();
1148 }
1149
1150 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1151 assert_eq!(
1152 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1153 ["5", "6", "7", "8", "9"]
1154 );
1155
1156 let prev_messages = db
1157 .get_channel_messages(channel, 4, Some(messages[0].id))
1158 .await
1159 .unwrap();
1160 assert_eq!(
1161 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1162 ["1", "2", "3", "4"]
1163 );
1164 }
1165 }
1166
1167 #[tokio::test(flavor = "multi_thread")]
1168 async fn test_channel_message_nonces() {
1169 for test_db in [
1170 TestDb::postgres().await,
1171 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1172 ] {
1173 let db = test_db.db();
1174 let user = db.create_user("user", None, false).await.unwrap();
1175 let org = db.create_org("org", "org").await.unwrap();
1176 let channel = db.create_org_channel(org, "channel").await.unwrap();
1177
1178 let msg1_id = db
1179 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1180 .await
1181 .unwrap();
1182 let msg2_id = db
1183 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1184 .await
1185 .unwrap();
1186 let msg3_id = db
1187 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1188 .await
1189 .unwrap();
1190 let msg4_id = db
1191 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1192 .await
1193 .unwrap();
1194
1195 assert_ne!(msg1_id, msg2_id);
1196 assert_eq!(msg1_id, msg3_id);
1197 assert_eq!(msg2_id, msg4_id);
1198 }
1199 }
1200
1201 #[tokio::test(flavor = "multi_thread")]
1202 async fn test_create_access_tokens() {
1203 let test_db = TestDb::postgres().await;
1204 let db = test_db.db();
1205 let user = db.create_user("the-user", None, false).await.unwrap();
1206
1207 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1208 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1209 assert_eq!(
1210 db.get_access_token_hashes(user).await.unwrap(),
1211 &["h2".to_string(), "h1".to_string()]
1212 );
1213
1214 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1215 assert_eq!(
1216 db.get_access_token_hashes(user).await.unwrap(),
1217 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1218 );
1219
1220 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1221 assert_eq!(
1222 db.get_access_token_hashes(user).await.unwrap(),
1223 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1224 );
1225
1226 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1227 assert_eq!(
1228 db.get_access_token_hashes(user).await.unwrap(),
1229 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1230 );
1231 }
1232
/// Checks the LIKE pattern produced by `fuzzy_like_string` for a few inputs,
/// including inputs that contain whitespace.
#[test]
fn test_fuzzy_like_string() {
    let cases = [
        ("abcd", "%a%b%c%d%"),
        ("x y", "%x%y%"),
        (" z ", "%z%"),
    ];
    for (query, expected_pattern) in cases {
        assert_eq!(fuzzy_like_string(query), expected_pattern);
    }
}
1239
1240 #[tokio::test(flavor = "multi_thread")]
1241 async fn test_fuzzy_search_users() {
1242 let test_db = TestDb::postgres().await;
1243 let db = test_db.db();
1244 for github_login in [
1245 "California",
1246 "colorado",
1247 "oregon",
1248 "washington",
1249 "florida",
1250 "delaware",
1251 "rhode-island",
1252 ] {
1253 db.create_user(github_login, None, false).await.unwrap();
1254 }
1255
1256 assert_eq!(
1257 fuzzy_search_user_names(db, "clr").await,
1258 &["colorado", "California"]
1259 );
1260 assert_eq!(
1261 fuzzy_search_user_names(db, "ro").await,
1262 &["rhode-island", "colorado", "oregon"],
1263 );
1264
1265 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1266 db.fuzzy_search_users(query, 10)
1267 .await
1268 .unwrap()
1269 .into_iter()
1270 .map(|user| user.github_login)
1271 .collect::<Vec<_>>()
1272 }
1273 }
1274
1275 #[tokio::test(flavor = "multi_thread")]
1276 async fn test_add_contacts() {
1277 for test_db in [
1278 TestDb::postgres().await,
1279 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1280 ] {
1281 let db = test_db.db();
1282
1283 let user_1 = db.create_user("user1", None, false).await.unwrap();
1284 let user_2 = db.create_user("user2", None, false).await.unwrap();
1285 let user_3 = db.create_user("user3", None, false).await.unwrap();
1286
1287 // User starts with no contacts
1288 assert_eq!(
1289 db.get_contacts(user_1).await.unwrap(),
1290 vec![Contact::Accepted {
1291 user_id: user_1,
1292 should_notify: false
1293 }],
1294 );
1295
1296 // User requests a contact. Both users see the pending request.
1297 db.send_contact_request(user_1, user_2).await.unwrap();
1298 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1299 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1300 assert_eq!(
1301 db.get_contacts(user_1).await.unwrap(),
1302 &[
1303 Contact::Accepted {
1304 user_id: user_1,
1305 should_notify: false
1306 },
1307 Contact::Outgoing { user_id: user_2 }
1308 ],
1309 );
1310 assert_eq!(
1311 db.get_contacts(user_2).await.unwrap(),
1312 &[
1313 Contact::Incoming {
1314 user_id: user_1,
1315 should_notify: true
1316 },
1317 Contact::Accepted {
1318 user_id: user_2,
1319 should_notify: false
1320 },
1321 ]
1322 );
1323
1324 // User 2 dismisses the contact request notification without accepting or rejecting.
1325 // We shouldn't notify them again.
1326 db.dismiss_contact_notification(user_1, user_2)
1327 .await
1328 .unwrap_err();
1329 db.dismiss_contact_notification(user_2, user_1)
1330 .await
1331 .unwrap();
1332 assert_eq!(
1333 db.get_contacts(user_2).await.unwrap(),
1334 &[
1335 Contact::Incoming {
1336 user_id: user_1,
1337 should_notify: false
1338 },
1339 Contact::Accepted {
1340 user_id: user_2,
1341 should_notify: false
1342 },
1343 ]
1344 );
1345
1346 // User can't accept their own contact request
1347 db.respond_to_contact_request(user_1, user_2, true)
1348 .await
1349 .unwrap_err();
1350
1351 // User accepts a contact request. Both users see the contact.
1352 db.respond_to_contact_request(user_2, user_1, true)
1353 .await
1354 .unwrap();
1355 assert_eq!(
1356 db.get_contacts(user_1).await.unwrap(),
1357 &[
1358 Contact::Accepted {
1359 user_id: user_1,
1360 should_notify: false
1361 },
1362 Contact::Accepted {
1363 user_id: user_2,
1364 should_notify: true
1365 }
1366 ],
1367 );
1368 assert!(db.has_contact(user_1, user_2).await.unwrap());
1369 assert!(db.has_contact(user_2, user_1).await.unwrap());
1370 assert_eq!(
1371 db.get_contacts(user_2).await.unwrap(),
1372 &[
1373 Contact::Accepted {
1374 user_id: user_1,
1375 should_notify: false,
1376 },
1377 Contact::Accepted {
1378 user_id: user_2,
1379 should_notify: false,
1380 },
1381 ]
1382 );
1383
1384 // Users cannot re-request existing contacts.
1385 db.send_contact_request(user_1, user_2).await.unwrap_err();
1386 db.send_contact_request(user_2, user_1).await.unwrap_err();
1387
1388 // Users can't dismiss notifications of them accepting other users' requests.
1389 db.dismiss_contact_notification(user_2, user_1)
1390 .await
1391 .unwrap_err();
1392 assert_eq!(
1393 db.get_contacts(user_1).await.unwrap(),
1394 &[
1395 Contact::Accepted {
1396 user_id: user_1,
1397 should_notify: false
1398 },
1399 Contact::Accepted {
1400 user_id: user_2,
1401 should_notify: true,
1402 },
1403 ]
1404 );
1405
1406 // Users can dismiss notifications of other users accepting their requests.
1407 db.dismiss_contact_notification(user_1, user_2)
1408 .await
1409 .unwrap();
1410 assert_eq!(
1411 db.get_contacts(user_1).await.unwrap(),
1412 &[
1413 Contact::Accepted {
1414 user_id: user_1,
1415 should_notify: false
1416 },
1417 Contact::Accepted {
1418 user_id: user_2,
1419 should_notify: false,
1420 },
1421 ]
1422 );
1423
1424 // Users send each other concurrent contact requests and
1425 // see that they are immediately accepted.
1426 db.send_contact_request(user_1, user_3).await.unwrap();
1427 db.send_contact_request(user_3, user_1).await.unwrap();
1428 assert_eq!(
1429 db.get_contacts(user_1).await.unwrap(),
1430 &[
1431 Contact::Accepted {
1432 user_id: user_1,
1433 should_notify: false
1434 },
1435 Contact::Accepted {
1436 user_id: user_2,
1437 should_notify: false,
1438 },
1439 Contact::Accepted {
1440 user_id: user_3,
1441 should_notify: false
1442 },
1443 ]
1444 );
1445 assert_eq!(
1446 db.get_contacts(user_3).await.unwrap(),
1447 &[
1448 Contact::Accepted {
1449 user_id: user_1,
1450 should_notify: false
1451 },
1452 Contact::Accepted {
1453 user_id: user_3,
1454 should_notify: false
1455 }
1456 ],
1457 );
1458
1459 // User declines a contact request. Both users see that it is gone.
1460 db.send_contact_request(user_2, user_3).await.unwrap();
1461 db.respond_to_contact_request(user_3, user_2, false)
1462 .await
1463 .unwrap();
1464 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1465 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1466 assert_eq!(
1467 db.get_contacts(user_2).await.unwrap(),
1468 &[
1469 Contact::Accepted {
1470 user_id: user_1,
1471 should_notify: false
1472 },
1473 Contact::Accepted {
1474 user_id: user_2,
1475 should_notify: false
1476 }
1477 ]
1478 );
1479 assert_eq!(
1480 db.get_contacts(user_3).await.unwrap(),
1481 &[
1482 Contact::Accepted {
1483 user_id: user_1,
1484 should_notify: false
1485 },
1486 Contact::Accepted {
1487 user_id: user_3,
1488 should_notify: false
1489 }
1490 ],
1491 );
1492 }
1493 }
1494
1495 #[tokio::test(flavor = "multi_thread")]
1496 async fn test_invite_codes() {
1497 let postgres = TestDb::postgres().await;
1498 let db = postgres.db();
1499 let user1 = db.create_user("user-1", None, false).await.unwrap();
1500
1501 // Initially, user 1 has no invite code
1502 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
1503
1504 // Setting invite count to 0 when no code is assigned does not assign a new code
1505 db.set_invite_count(user1, 0).await.unwrap();
1506 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
1507
1508 // User 1 creates an invite code that can be used twice.
1509 db.set_invite_count(user1, 2).await.unwrap();
1510 let (invite_code, invite_count) =
1511 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1512 assert_eq!(invite_count, 2);
1513
1514 // User 2 redeems the invite code and becomes a contact of user 1.
1515 let user2 = db
1516 .redeem_invite_code(&invite_code, "user-2", None)
1517 .await
1518 .unwrap();
1519 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1520 assert_eq!(invite_count, 1);
1521 assert_eq!(
1522 db.get_contacts(user1).await.unwrap(),
1523 [
1524 Contact::Accepted {
1525 user_id: user1,
1526 should_notify: false
1527 },
1528 Contact::Accepted {
1529 user_id: user2,
1530 should_notify: true
1531 }
1532 ]
1533 );
1534 assert_eq!(
1535 db.get_contacts(user2).await.unwrap(),
1536 [
1537 Contact::Accepted {
1538 user_id: user1,
1539 should_notify: false
1540 },
1541 Contact::Accepted {
1542 user_id: user2,
1543 should_notify: false
1544 }
1545 ]
1546 );
1547
1548 // User 3 redeems the invite code and becomes a contact of user 1.
1549 let user3 = db
1550 .redeem_invite_code(&invite_code, "user-3", None)
1551 .await
1552 .unwrap();
1553 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1554 assert_eq!(invite_count, 0);
1555 assert_eq!(
1556 db.get_contacts(user1).await.unwrap(),
1557 [
1558 Contact::Accepted {
1559 user_id: user1,
1560 should_notify: false
1561 },
1562 Contact::Accepted {
1563 user_id: user2,
1564 should_notify: true
1565 },
1566 Contact::Accepted {
1567 user_id: user3,
1568 should_notify: true
1569 }
1570 ]
1571 );
1572 assert_eq!(
1573 db.get_contacts(user3).await.unwrap(),
1574 [
1575 Contact::Accepted {
1576 user_id: user1,
1577 should_notify: false
1578 },
1579 Contact::Accepted {
1580 user_id: user3,
1581 should_notify: false
1582 },
1583 ]
1584 );
1585
1586 // Trying to reedem the code for the third time results in an error.
1587 db.redeem_invite_code(&invite_code, "user-4", None)
1588 .await
1589 .unwrap_err();
1590
1591 // Invite count can be updated after the code has been created.
1592 db.set_invite_count(user1, 2).await.unwrap();
1593 let (latest_code, invite_count) =
1594 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1595 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
1596 assert_eq!(invite_count, 2);
1597
1598 // User 4 can now redeem the invite code and becomes a contact of user 1.
1599 let user4 = db
1600 .redeem_invite_code(&invite_code, "user-4", None)
1601 .await
1602 .unwrap();
1603 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1604 assert_eq!(invite_count, 1);
1605 assert_eq!(
1606 db.get_contacts(user1).await.unwrap(),
1607 [
1608 Contact::Accepted {
1609 user_id: user1,
1610 should_notify: false
1611 },
1612 Contact::Accepted {
1613 user_id: user2,
1614 should_notify: true
1615 },
1616 Contact::Accepted {
1617 user_id: user3,
1618 should_notify: true
1619 },
1620 Contact::Accepted {
1621 user_id: user4,
1622 should_notify: true
1623 }
1624 ]
1625 );
1626 assert_eq!(
1627 db.get_contacts(user4).await.unwrap(),
1628 [
1629 Contact::Accepted {
1630 user_id: user1,
1631 should_notify: false
1632 },
1633 Contact::Accepted {
1634 user_id: user4,
1635 should_notify: false
1636 },
1637 ]
1638 );
1639
1640 // An existing user cannot redeem invite codes.
1641 db.redeem_invite_code(&invite_code, "user-2", None)
1642 .await
1643 .unwrap_err();
1644 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1645 assert_eq!(invite_count, 1);
1646 }
1647
/// Handle to a database used by tests, backed either by a freshly created
/// Postgres database or by an in-memory `FakeDb`.
///
/// Dropping the handle calls `teardown` on the wrapped database.
pub struct TestDb {
    /// The wrapped database. Held in an `Option` so that `Drop` can take it
    /// out of `self` before tearing it down.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL passed to `teardown`. Left as the empty default string
    /// for the fake backend.
    pub url: String,
}
1652
1653 impl TestDb {
1654 pub async fn postgres() -> Self {
1655 lazy_static! {
1656 static ref LOCK: Mutex<()> = Mutex::new(());
1657 }
1658
1659 let _guard = LOCK.lock();
1660 let mut rng = StdRng::from_entropy();
1661 let name = format!("zed-test-{}", rng.gen::<u128>());
1662 let url = format!("postgres://postgres@localhost/{}", name);
1663 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1664 Postgres::create_database(&url)
1665 .await
1666 .expect("failed to create test db");
1667 let db = PostgresDb::new(&url, 5).await.unwrap();
1668 let migrator = Migrator::new(migrations_path).await.unwrap();
1669 migrator.run(&db.pool).await.unwrap();
1670 Self {
1671 db: Some(Arc::new(db)),
1672 url,
1673 }
1674 }
1675
1676 pub fn fake(background: Arc<Background>) -> Self {
1677 Self {
1678 db: Some(Arc::new(FakeDb::new(background))),
1679 url: Default::default(),
1680 }
1681 }
1682
1683 pub fn db(&self) -> &Arc<dyn Db> {
1684 self.db.as_ref().unwrap()
1685 }
1686 }
1687
1688 impl Drop for TestDb {
1689 fn drop(&mut self) {
1690 if let Some(db) = self.db.take() {
1691 futures::executor::block_on(db.teardown(&self.url));
1692 }
1693 }
1694 }
1695
/// In-memory implementation of `Db` for tests.
///
/// All state lives in mutex-guarded collections. Ids are handed out from
/// per-entity counters that `FakeDb::new` starts at 1.
pub struct FakeDb {
    /// Executor used to insert simulated random delays into many operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Organizations keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org memberships; the value is the `is_admin` flag given when the
    /// member was first added.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel memberships; the value is the `is_admin` flag given when the
    /// member was first added.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Channel messages keyed by id.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact requests and accepted contacts, in insertion order.
    pub contacts: Mutex<Vec<FakeContact>>,
    /// Next id to assign to a channel message.
    next_channel_message_id: Mutex<i32>,
    /// Next id to assign to a user.
    next_user_id: Mutex<i32>,
    /// Next id to assign to an org.
    next_org_id: Mutex<i32>,
    /// Next id to assign to a channel.
    next_channel_id: Mutex<i32>,
}
1710
/// One directed contact relationship stored by `FakeDb`.
#[derive(Debug)]
pub struct FakeContact {
    /// The user who sent the contact request.
    pub requester_id: UserId,
    /// The user who received the contact request.
    pub responder_id: UserId,
    /// `false` while the request is pending, `true` once it has been accepted.
    pub accepted: bool,
    /// Pending notification flag. `FakeDb::get_contacts` reports it to the
    /// responder while the request is pending and to the requester once the
    /// contact has been accepted.
    pub should_notify: bool,
}
1718
1719 impl FakeDb {
1720 pub fn new(background: Arc<Background>) -> Self {
1721 Self {
1722 background,
1723 users: Default::default(),
1724 next_user_id: Mutex::new(1),
1725 orgs: Default::default(),
1726 next_org_id: Mutex::new(1),
1727 org_memberships: Default::default(),
1728 channels: Default::default(),
1729 next_channel_id: Mutex::new(1),
1730 channel_memberships: Default::default(),
1731 channel_messages: Default::default(),
1732 next_channel_message_id: Mutex::new(1),
1733 contacts: Default::default(),
1734 }
1735 }
1736 }
1737
1738 #[async_trait]
1739 impl Db for FakeDb {
1740 async fn create_user(
1741 &self,
1742 github_login: &str,
1743 email_address: Option<&str>,
1744 admin: bool,
1745 ) -> Result<UserId> {
1746 self.background.simulate_random_delay().await;
1747
1748 let mut users = self.users.lock();
1749 if let Some(user) = users
1750 .values()
1751 .find(|user| user.github_login == github_login)
1752 {
1753 Ok(user.id)
1754 } else {
1755 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1756 users.insert(
1757 user_id,
1758 User {
1759 id: user_id,
1760 github_login: github_login.to_string(),
1761 email_address: email_address.map(str::to_string),
1762 admin,
1763 invite_code: None,
1764 invite_count: 0,
1765 connected_once: false,
1766 },
1767 );
1768 Ok(user_id)
1769 }
1770 }
1771
1772 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1773 unimplemented!()
1774 }
1775
1776 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
1777 unimplemented!()
1778 }
1779
1780 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1781 unimplemented!()
1782 }
1783
1784 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1785 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1786 }
1787
1788 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1789 self.background.simulate_random_delay().await;
1790 let users = self.users.lock();
1791 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1792 }
1793
1794 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1795 Ok(self
1796 .users
1797 .lock()
1798 .values()
1799 .find(|user| user.github_login == github_login)
1800 .cloned())
1801 }
1802
1803 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1804 unimplemented!()
1805 }
1806
1807 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1808 self.background.simulate_random_delay().await;
1809 let mut users = self.users.lock();
1810 let mut user = users
1811 .get_mut(&id)
1812 .ok_or_else(|| anyhow!("user not found"))?;
1813 user.connected_once = connected_once;
1814 Ok(())
1815 }
1816
1817 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1818 unimplemented!()
1819 }
1820
1821 // invite codes
1822
1823 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1824 unimplemented!()
1825 }
1826
1827 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1828 Ok(None)
1829 }
1830
1831 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1832 unimplemented!()
1833 }
1834
1835 async fn redeem_invite_code(
1836 &self,
1837 _code: &str,
1838 _login: &str,
1839 _email_address: Option<&str>,
1840 ) -> Result<UserId> {
1841 unimplemented!()
1842 }
1843
1844 // contacts
1845
1846 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1847 self.background.simulate_random_delay().await;
1848 let mut contacts = vec![Contact::Accepted {
1849 user_id: id,
1850 should_notify: false,
1851 }];
1852
1853 for contact in self.contacts.lock().iter() {
1854 if contact.requester_id == id {
1855 if contact.accepted {
1856 contacts.push(Contact::Accepted {
1857 user_id: contact.responder_id,
1858 should_notify: contact.should_notify,
1859 });
1860 } else {
1861 contacts.push(Contact::Outgoing {
1862 user_id: contact.responder_id,
1863 });
1864 }
1865 } else if contact.responder_id == id {
1866 if contact.accepted {
1867 contacts.push(Contact::Accepted {
1868 user_id: contact.requester_id,
1869 should_notify: false,
1870 });
1871 } else {
1872 contacts.push(Contact::Incoming {
1873 user_id: contact.requester_id,
1874 should_notify: contact.should_notify,
1875 });
1876 }
1877 }
1878 }
1879
1880 contacts.sort_unstable_by_key(|contact| contact.user_id());
1881 Ok(contacts)
1882 }
1883
1884 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1885 self.background.simulate_random_delay().await;
1886 Ok(self.contacts.lock().iter().any(|contact| {
1887 contact.accepted
1888 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1889 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1890 }))
1891 }
1892
1893 async fn send_contact_request(
1894 &self,
1895 requester_id: UserId,
1896 responder_id: UserId,
1897 ) -> Result<()> {
1898 let mut contacts = self.contacts.lock();
1899 for contact in contacts.iter_mut() {
1900 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1901 if contact.accepted {
1902 Err(anyhow!("contact already exists"))?;
1903 } else {
1904 Err(anyhow!("contact already requested"))?;
1905 }
1906 }
1907 if contact.responder_id == requester_id && contact.requester_id == responder_id {
1908 if contact.accepted {
1909 Err(anyhow!("contact already exists"))?;
1910 } else {
1911 contact.accepted = true;
1912 contact.should_notify = false;
1913 return Ok(());
1914 }
1915 }
1916 }
1917 contacts.push(FakeContact {
1918 requester_id,
1919 responder_id,
1920 accepted: false,
1921 should_notify: true,
1922 });
1923 Ok(())
1924 }
1925
1926 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1927 self.contacts.lock().retain(|contact| {
1928 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1929 });
1930 Ok(())
1931 }
1932
1933 async fn dismiss_contact_notification(
1934 &self,
1935 user_id: UserId,
1936 contact_user_id: UserId,
1937 ) -> Result<()> {
1938 let mut contacts = self.contacts.lock();
1939 for contact in contacts.iter_mut() {
1940 if contact.requester_id == contact_user_id
1941 && contact.responder_id == user_id
1942 && !contact.accepted
1943 {
1944 contact.should_notify = false;
1945 return Ok(());
1946 }
1947 if contact.requester_id == user_id
1948 && contact.responder_id == contact_user_id
1949 && contact.accepted
1950 {
1951 contact.should_notify = false;
1952 return Ok(());
1953 }
1954 }
1955 Err(anyhow!("no such notification"))?
1956 }
1957
1958 async fn respond_to_contact_request(
1959 &self,
1960 responder_id: UserId,
1961 requester_id: UserId,
1962 accept: bool,
1963 ) -> Result<()> {
1964 let mut contacts = self.contacts.lock();
1965 for (ix, contact) in contacts.iter_mut().enumerate() {
1966 if contact.requester_id == requester_id && contact.responder_id == responder_id {
1967 if contact.accepted {
1968 Err(anyhow!("contact already confirmed"))?;
1969 }
1970 if accept {
1971 contact.accepted = true;
1972 contact.should_notify = true;
1973 } else {
1974 contacts.remove(ix);
1975 }
1976 return Ok(());
1977 }
1978 }
1979 Err(anyhow!("no such contact request"))?
1980 }
1981
1982 async fn create_access_token_hash(
1983 &self,
1984 _user_id: UserId,
1985 _access_token_hash: &str,
1986 _max_access_token_count: usize,
1987 ) -> Result<()> {
1988 unimplemented!()
1989 }
1990
1991 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1992 unimplemented!()
1993 }
1994
1995 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1996 unimplemented!()
1997 }
1998
1999 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2000 self.background.simulate_random_delay().await;
2001 let mut orgs = self.orgs.lock();
2002 if orgs.values().any(|org| org.slug == slug) {
2003 Err(anyhow!("org already exists"))?
2004 } else {
2005 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2006 orgs.insert(
2007 org_id,
2008 Org {
2009 id: org_id,
2010 name: name.to_string(),
2011 slug: slug.to_string(),
2012 },
2013 );
2014 Ok(org_id)
2015 }
2016 }
2017
2018 async fn add_org_member(
2019 &self,
2020 org_id: OrgId,
2021 user_id: UserId,
2022 is_admin: bool,
2023 ) -> Result<()> {
2024 self.background.simulate_random_delay().await;
2025 if !self.orgs.lock().contains_key(&org_id) {
2026 Err(anyhow!("org does not exist"))?;
2027 }
2028 if !self.users.lock().contains_key(&user_id) {
2029 Err(anyhow!("user does not exist"))?;
2030 }
2031
2032 self.org_memberships
2033 .lock()
2034 .entry((org_id, user_id))
2035 .or_insert(is_admin);
2036 Ok(())
2037 }
2038
2039 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2040 self.background.simulate_random_delay().await;
2041 if !self.orgs.lock().contains_key(&org_id) {
2042 Err(anyhow!("org does not exist"))?;
2043 }
2044
2045 let mut channels = self.channels.lock();
2046 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2047 channels.insert(
2048 channel_id,
2049 Channel {
2050 id: channel_id,
2051 name: name.to_string(),
2052 owner_id: org_id.0,
2053 owner_is_user: false,
2054 },
2055 );
2056 Ok(channel_id)
2057 }
2058
2059 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2060 self.background.simulate_random_delay().await;
2061 Ok(self
2062 .channels
2063 .lock()
2064 .values()
2065 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2066 .cloned()
2067 .collect())
2068 }
2069
2070 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2071 self.background.simulate_random_delay().await;
2072 let channels = self.channels.lock();
2073 let memberships = self.channel_memberships.lock();
2074 Ok(channels
2075 .values()
2076 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2077 .cloned()
2078 .collect())
2079 }
2080
2081 async fn can_user_access_channel(
2082 &self,
2083 user_id: UserId,
2084 channel_id: ChannelId,
2085 ) -> Result<bool> {
2086 self.background.simulate_random_delay().await;
2087 Ok(self
2088 .channel_memberships
2089 .lock()
2090 .contains_key(&(channel_id, user_id)))
2091 }
2092
2093 async fn add_channel_member(
2094 &self,
2095 channel_id: ChannelId,
2096 user_id: UserId,
2097 is_admin: bool,
2098 ) -> Result<()> {
2099 self.background.simulate_random_delay().await;
2100 if !self.channels.lock().contains_key(&channel_id) {
2101 Err(anyhow!("channel does not exist"))?;
2102 }
2103 if !self.users.lock().contains_key(&user_id) {
2104 Err(anyhow!("user does not exist"))?;
2105 }
2106
2107 self.channel_memberships
2108 .lock()
2109 .entry((channel_id, user_id))
2110 .or_insert(is_admin);
2111 Ok(())
2112 }
2113
2114 async fn create_channel_message(
2115 &self,
2116 channel_id: ChannelId,
2117 sender_id: UserId,
2118 body: &str,
2119 timestamp: OffsetDateTime,
2120 nonce: u128,
2121 ) -> Result<MessageId> {
2122 self.background.simulate_random_delay().await;
2123 if !self.channels.lock().contains_key(&channel_id) {
2124 Err(anyhow!("channel does not exist"))?;
2125 }
2126 if !self.users.lock().contains_key(&sender_id) {
2127 Err(anyhow!("user does not exist"))?;
2128 }
2129
2130 let mut messages = self.channel_messages.lock();
2131 if let Some(message) = messages
2132 .values()
2133 .find(|message| message.nonce.as_u128() == nonce)
2134 {
2135 Ok(message.id)
2136 } else {
2137 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2138 messages.insert(
2139 message_id,
2140 ChannelMessage {
2141 id: message_id,
2142 channel_id,
2143 sender_id,
2144 body: body.to_string(),
2145 sent_at: timestamp,
2146 nonce: Uuid::from_u128(nonce),
2147 },
2148 );
2149 Ok(message_id)
2150 }
2151 }
2152
2153 async fn get_channel_messages(
2154 &self,
2155 channel_id: ChannelId,
2156 count: usize,
2157 before_id: Option<MessageId>,
2158 ) -> Result<Vec<ChannelMessage>> {
2159 let mut messages = self
2160 .channel_messages
2161 .lock()
2162 .values()
2163 .rev()
2164 .filter(|message| {
2165 message.channel_id == channel_id
2166 && message.id < before_id.unwrap_or(MessageId::MAX)
2167 })
2168 .take(count)
2169 .cloned()
2170 .collect::<Vec<_>>();
2171 messages.sort_unstable_by_key(|message| message.id);
2172 Ok(messages)
2173 }
2174
2175 async fn teardown(&self, _: &str) {}
2176
2177 #[cfg(test)]
2178 fn as_fake(&self) -> Option<&FakeDb> {
2179 Some(self)
2180 }
2181 }
2182}