1use anyhow::Context;
2use anyhow::Result;
3pub use async_sqlx_session::PostgresSessionStore as SessionStore;
4use async_std::task::{block_on, yield_now};
5use async_trait::async_trait;
6use serde::Serialize;
7pub use sqlx::postgres::PgPoolOptions as DbOptions;
8use sqlx::{types::Uuid, FromRow};
9use time::OffsetDateTime;
10
// Wraps a database operation's async body so it can run in two modes.
//
// * Normal mode (`test_mode == false`): the body is simply awaited.
// * Test mode: the caller first yields once, then drives the body to completion
//   synchronously with `async_std::task::block_on`.
//
// NOTE(review): presumably this keeps sqlx's I/O on async-std's runtime instead of
// the deterministic test executor — confirm with the test harness before changing.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if !$self.test_mode {
            body.await
        } else {
            yield_now().await;
            block_on(body)
        }
    }};
}
24
/// Storage interface used by the server.
///
/// `PostgresDb` implements it against a real Postgres database; the test-only
/// `FakeDb` (in the `tests` module) implements it in memory, with several
/// methods left `unimplemented!()`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Inserts a row into `signups` and returns its generated id.
    async fn create_signup(
        &self,
        github_login: &str,
        email_address: &str,
        about: &str,
        wants_releases: bool,
        wants_updates: bool,
        wants_community: bool,
    ) -> Result<SignupId>;
    /// Returns every signup (the Postgres implementation orders them by GitHub login).
    async fn get_all_signups(&self) -> Result<Vec<Signup>>;
    /// Deletes the signup with the given id.
    async fn destroy_signup(&self, id: SignupId) -> Result<()>;
    /// Creates a user and returns its id. If a user with the same GitHub login
    /// already exists, that user's id is returned and `admin` is not applied.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user (the Postgres implementation orders them by GitHub login).
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Looks up a single user by id.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up several users at once; ids without a matching user are skipped.
    ///
    /// NOTE(review): the Postgres query has no `ORDER BY`, so the result order is
    /// not guaranteed to follow `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up a user by GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user together with its access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;
    /// Stores a new access-token hash for `user_id`, then deletes the user's older
    /// tokens so that at most `max_access_token_count` remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes ordered by descending id
    /// (most recently created first).
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Finds the organization with the given slug (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an organization and returns its id (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` to the organization; an existing membership is left
    /// unchanged (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the organization and returns its id
    /// (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the organization (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels `user_id` is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether `user_id` is a member of `channel_id`.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` to the channel; an existing membership is left unchanged
    /// (tests / `seed-support` only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message and returns its id. If a message with the same `nonce`
    /// already exists, that message's id is returned instead of inserting a duplicate.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the channel's messages whose ids are below
    /// `before_id` (the most recent ones when `None`), in ascending id order.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Releases a test database identified by `name` / `url` (test builds only).
    #[cfg(test)]
    async fn teardown(&self, name: &str, url: &str);
}
89
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Pool every query is executed against.
    pool: sqlx::PgPool,
    /// When `true`, `test_support!` yields and then drives each query with
    /// `block_on`; `PostgresDb::new` sets it to `false` and the test harness
    /// flips it to `true`.
    test_mode: bool,
}
94
95impl PostgresDb {
96 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
97 let pool = DbOptions::new()
98 .max_connections(max_connections)
99 .connect(url)
100 .await
101 .context("failed to connect to postgres database")?;
102 Ok(Self {
103 pool,
104 test_mode: false,
105 })
106 }
107}
108
109#[async_trait]
110impl Db for PostgresDb {
111 // signups
112 async fn create_signup(
113 &self,
114 github_login: &str,
115 email_address: &str,
116 about: &str,
117 wants_releases: bool,
118 wants_updates: bool,
119 wants_community: bool,
120 ) -> Result<SignupId> {
121 test_support!(self, {
122 let query = "
123 INSERT INTO signups (
124 github_login,
125 email_address,
126 about,
127 wants_releases,
128 wants_updates,
129 wants_community
130 )
131 VALUES ($1, $2, $3, $4, $5, $6)
132 RETURNING id
133 ";
134 Ok(sqlx::query_scalar(query)
135 .bind(github_login)
136 .bind(email_address)
137 .bind(about)
138 .bind(wants_releases)
139 .bind(wants_updates)
140 .bind(wants_community)
141 .fetch_one(&self.pool)
142 .await
143 .map(SignupId)?)
144 })
145 }
146
147 async fn get_all_signups(&self) -> Result<Vec<Signup>> {
148 test_support!(self, {
149 let query = "SELECT * FROM signups ORDER BY github_login ASC";
150 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
151 })
152 }
153
154 async fn destroy_signup(&self, id: SignupId) -> Result<()> {
155 test_support!(self, {
156 let query = "DELETE FROM signups WHERE id = $1";
157 Ok(sqlx::query(query)
158 .bind(id.0)
159 .execute(&self.pool)
160 .await
161 .map(drop)?)
162 })
163 }
164
165 // users
166
167 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
168 test_support!(self, {
169 let query = "
170 INSERT INTO users (github_login, admin)
171 VALUES ($1, $2)
172 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
173 RETURNING id
174 ";
175 Ok(sqlx::query_scalar(query)
176 .bind(github_login)
177 .bind(admin)
178 .fetch_one(&self.pool)
179 .await
180 .map(UserId)?)
181 })
182 }
183
184 async fn get_all_users(&self) -> Result<Vec<User>> {
185 test_support!(self, {
186 let query = "SELECT * FROM users ORDER BY github_login ASC";
187 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
188 })
189 }
190
191 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
192 let users = self.get_users_by_ids(vec![id]).await?;
193 Ok(users.into_iter().next())
194 }
195
196 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
197 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
198 test_support!(self, {
199 let query = "
200 SELECT users.*
201 FROM users
202 WHERE users.id = ANY ($1)
203 ";
204
205 Ok(sqlx::query_as(query)
206 .bind(&ids)
207 .fetch_all(&self.pool)
208 .await?)
209 })
210 }
211
212 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
213 test_support!(self, {
214 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
215 Ok(sqlx::query_as(query)
216 .bind(github_login)
217 .fetch_optional(&self.pool)
218 .await?)
219 })
220 }
221
222 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
223 test_support!(self, {
224 let query = "UPDATE users SET admin = $1 WHERE id = $2";
225 Ok(sqlx::query(query)
226 .bind(is_admin)
227 .bind(id.0)
228 .execute(&self.pool)
229 .await
230 .map(drop)?)
231 })
232 }
233
234 async fn destroy_user(&self, id: UserId) -> Result<()> {
235 test_support!(self, {
236 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
237 sqlx::query(query)
238 .bind(id.0)
239 .execute(&self.pool)
240 .await
241 .map(drop)?;
242 let query = "DELETE FROM users WHERE id = $1;";
243 Ok(sqlx::query(query)
244 .bind(id.0)
245 .execute(&self.pool)
246 .await
247 .map(drop)?)
248 })
249 }
250
251 // access tokens
252
253 async fn create_access_token_hash(
254 &self,
255 user_id: UserId,
256 access_token_hash: &str,
257 max_access_token_count: usize,
258 ) -> Result<()> {
259 test_support!(self, {
260 let insert_query = "
261 INSERT INTO access_tokens (user_id, hash)
262 VALUES ($1, $2);
263 ";
264 let cleanup_query = "
265 DELETE FROM access_tokens
266 WHERE id IN (
267 SELECT id from access_tokens
268 WHERE user_id = $1
269 ORDER BY id DESC
270 OFFSET $3
271 )
272 ";
273
274 let mut tx = self.pool.begin().await?;
275 sqlx::query(insert_query)
276 .bind(user_id.0)
277 .bind(access_token_hash)
278 .execute(&mut tx)
279 .await?;
280 sqlx::query(cleanup_query)
281 .bind(user_id.0)
282 .bind(access_token_hash)
283 .bind(max_access_token_count as u32)
284 .execute(&mut tx)
285 .await?;
286 Ok(tx.commit().await?)
287 })
288 }
289
290 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
291 test_support!(self, {
292 let query = "
293 SELECT hash
294 FROM access_tokens
295 WHERE user_id = $1
296 ORDER BY id DESC
297 ";
298 Ok(sqlx::query_scalar(query)
299 .bind(user_id.0)
300 .fetch_all(&self.pool)
301 .await?)
302 })
303 }
304
305 // orgs
306
307 #[allow(unused)] // Help rust-analyzer
308 #[cfg(any(test, feature = "seed-support"))]
309 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
310 test_support!(self, {
311 let query = "
312 SELECT *
313 FROM orgs
314 WHERE slug = $1
315 ";
316 Ok(sqlx::query_as(query)
317 .bind(slug)
318 .fetch_optional(&self.pool)
319 .await?)
320 })
321 }
322
323 #[cfg(any(test, feature = "seed-support"))]
324 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
325 test_support!(self, {
326 let query = "
327 INSERT INTO orgs (name, slug)
328 VALUES ($1, $2)
329 RETURNING id
330 ";
331 Ok(sqlx::query_scalar(query)
332 .bind(name)
333 .bind(slug)
334 .fetch_one(&self.pool)
335 .await
336 .map(OrgId)?)
337 })
338 }
339
340 #[cfg(any(test, feature = "seed-support"))]
341 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
342 test_support!(self, {
343 let query = "
344 INSERT INTO org_memberships (org_id, user_id, admin)
345 VALUES ($1, $2, $3)
346 ON CONFLICT DO NOTHING
347 ";
348 Ok(sqlx::query(query)
349 .bind(org_id.0)
350 .bind(user_id.0)
351 .bind(is_admin)
352 .execute(&self.pool)
353 .await
354 .map(drop)?)
355 })
356 }
357
358 // channels
359
360 #[cfg(any(test, feature = "seed-support"))]
361 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
362 test_support!(self, {
363 let query = "
364 INSERT INTO channels (owner_id, owner_is_user, name)
365 VALUES ($1, false, $2)
366 RETURNING id
367 ";
368 Ok(sqlx::query_scalar(query)
369 .bind(org_id.0)
370 .bind(name)
371 .fetch_one(&self.pool)
372 .await
373 .map(ChannelId)?)
374 })
375 }
376
377 #[allow(unused)] // Help rust-analyzer
378 #[cfg(any(test, feature = "seed-support"))]
379 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
380 test_support!(self, {
381 let query = "
382 SELECT *
383 FROM channels
384 WHERE
385 channels.owner_is_user = false AND
386 channels.owner_id = $1
387 ";
388 Ok(sqlx::query_as(query)
389 .bind(org_id.0)
390 .fetch_all(&self.pool)
391 .await?)
392 })
393 }
394
395 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
396 test_support!(self, {
397 let query = "
398 SELECT
399 channels.*
400 FROM
401 channel_memberships, channels
402 WHERE
403 channel_memberships.user_id = $1 AND
404 channel_memberships.channel_id = channels.id
405 ";
406 Ok(sqlx::query_as(query)
407 .bind(user_id.0)
408 .fetch_all(&self.pool)
409 .await?)
410 })
411 }
412
413 async fn can_user_access_channel(
414 &self,
415 user_id: UserId,
416 channel_id: ChannelId,
417 ) -> Result<bool> {
418 test_support!(self, {
419 let query = "
420 SELECT id
421 FROM channel_memberships
422 WHERE user_id = $1 AND channel_id = $2
423 LIMIT 1
424 ";
425 Ok(sqlx::query_scalar::<_, i32>(query)
426 .bind(user_id.0)
427 .bind(channel_id.0)
428 .fetch_optional(&self.pool)
429 .await
430 .map(|e| e.is_some())?)
431 })
432 }
433
434 #[cfg(any(test, feature = "seed-support"))]
435 async fn add_channel_member(
436 &self,
437 channel_id: ChannelId,
438 user_id: UserId,
439 is_admin: bool,
440 ) -> Result<()> {
441 test_support!(self, {
442 let query = "
443 INSERT INTO channel_memberships (channel_id, user_id, admin)
444 VALUES ($1, $2, $3)
445 ON CONFLICT DO NOTHING
446 ";
447 Ok(sqlx::query(query)
448 .bind(channel_id.0)
449 .bind(user_id.0)
450 .bind(is_admin)
451 .execute(&self.pool)
452 .await
453 .map(drop)?)
454 })
455 }
456
457 // messages
458
459 async fn create_channel_message(
460 &self,
461 channel_id: ChannelId,
462 sender_id: UserId,
463 body: &str,
464 timestamp: OffsetDateTime,
465 nonce: u128,
466 ) -> Result<MessageId> {
467 test_support!(self, {
468 let query = "
469 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
470 VALUES ($1, $2, $3, $4, $5)
471 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
472 RETURNING id
473 ";
474 Ok(sqlx::query_scalar(query)
475 .bind(channel_id.0)
476 .bind(sender_id.0)
477 .bind(body)
478 .bind(timestamp)
479 .bind(Uuid::from_u128(nonce))
480 .fetch_one(&self.pool)
481 .await
482 .map(MessageId)?)
483 })
484 }
485
486 async fn get_channel_messages(
487 &self,
488 channel_id: ChannelId,
489 count: usize,
490 before_id: Option<MessageId>,
491 ) -> Result<Vec<ChannelMessage>> {
492 test_support!(self, {
493 let query = r#"
494 SELECT * FROM (
495 SELECT
496 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
497 FROM
498 channel_messages
499 WHERE
500 channel_id = $1 AND
501 id < $2
502 ORDER BY id DESC
503 LIMIT $3
504 ) as recent_messages
505 ORDER BY id ASC
506 "#;
507 Ok(sqlx::query_as(query)
508 .bind(channel_id.0)
509 .bind(before_id.unwrap_or(MessageId::MAX))
510 .bind(count as i64)
511 .fetch_all(&self.pool)
512 .await?)
513 })
514 }
515
516 #[cfg(test)]
517 async fn teardown(&self, name: &str, url: &str) {
518 use util::ResultExt;
519
520 test_support!(self, {
521 let query = "
522 SELECT pg_terminate_backend(pg_stat_activity.pid)
523 FROM pg_stat_activity
524 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
525 ";
526 sqlx::query(query)
527 .bind(name)
528 .execute(&self.pool)
529 .await
530 .log_err();
531 self.pool.close().await;
532 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
533 .await
534 .log_err();
535 })
536 }
537}
538
// Declares a `Copy` newtype around an `i32` primary key. The type is transparent
// to sqlx and serde, and converts to and from the `u64` form used by
// `from_proto` / `to_proto` (presumably protobuf message fields).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id; usable as an open upper bound
            /// (e.g. "all messages with an id below `MAX`").
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its `u64` wire form.
            ///
            /// NOTE(review): this is an `as` cast, so values above `i32::MAX`
            /// are truncated rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to its `u64` wire form.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }
    };
}
564
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Primary key.
    pub id: UserId,
    /// GitHub login; `create_user` treats it as unique.
    pub github_login: String,
    /// Whether the user is an administrator (see `set_user_is_admin`).
    pub admin: bool,
}
572
573id_type!(OrgId);
574#[derive(FromRow)]
575pub struct Org {
576 pub id: OrgId,
577 pub name: String,
578 pub slug: String,
579}
580
id_type!(SignupId);
/// A row of the `signups` table.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    /// Primary key.
    pub id: SignupId,
    /// GitHub login supplied at signup.
    pub github_login: String,
    /// Contact email address.
    pub email_address: String,
    /// Free-form text supplied at signup.
    pub about: String,
    // The three preference flags are written as plain `bool`s by `create_signup`
    // but read back as `Option<bool>`. Presumably the columns are nullable
    // (e.g. rows written before they existed) — confirm against the migrations.
    /// Opted in to release announcements.
    pub wants_releases: Option<bool>,
    /// Opted in to updates.
    pub wants_updates: Option<bool>,
    /// Opted in to community messages.
    pub wants_community: Option<bool>,
}
592
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    /// Primary key.
    pub id: ChannelId,
    /// Channel name.
    pub name: String,
    /// Owner id. This is an `OrgId` value when `owner_is_user` is `false` (see
    /// `create_org_channel`); presumably a `UserId` value otherwise.
    pub owner_id: i32,
    /// Tells which kind of entity `owner_id` refers to.
    pub owner_is_user: bool,
}
601
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    /// Primary key; message listings are ordered by it.
    pub id: MessageId,
    /// Channel the message was posted to.
    pub channel_id: ChannelId,
    /// Author of the message.
    pub sender_id: UserId,
    /// Message text.
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-provided `u128` nonce stored as a UUID. A repeated nonce makes
    /// `create_channel_message` return the existing message instead of inserting.
    pub nonce: Uuid,
}
612
/// Tests for the `Db` implementations, plus the `TestDb` / `FakeDb` helpers other
/// test modules can use.
#[cfg(test)]
pub mod tests {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::{executor::Background, TestAppContext};
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    // Tests that iterate over `[TestDb::postgres(), TestDb::fake(..)]` run the same
    // scenario against a throwaway Postgres database and against `FakeDb`, so the
    // two implementations are checked for agreement.

    #[gpui::test]
    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();

            let user = db.create_user("user", false).await.unwrap();
            let friend1 = db.create_user("friend-1", false).await.unwrap();
            let friend2 = db.create_user("friend-2", false).await.unwrap();
            let friend3 = db.create_user("friend-3", false).await.unwrap();

            // NOTE(review): the Postgres query has no ORDER BY, so this exact-order
            // comparison relies on rows coming back in insertion order.
            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                    }
                ]
            );
        }
    }

    #[gpui::test]
    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Messages "0".."9", each with a distinct nonce equal to its index.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            // The newest five, returned oldest-first.
            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // Paging backwards from the oldest message of the previous page.
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }

    #[gpui::test]
    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            // Reusing nonces 1 and 2 must return the original message ids.
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }

    // Postgres only: `FakeDb` leaves the access-token methods unimplemented.
    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // Past the limit of 3, the oldest hash is pruned on each insert.
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }

    /// A database handle for one test; dropping it tears the database down.
    pub struct TestDb {
        /// `Some` until `Drop` takes it to run `teardown`.
        pub db: Option<Arc<dyn Db>>,
        /// Database name (`"fake"` for `FakeDb`).
        pub name: String,
        /// Connection URL (`"fake"` for `FakeDb`).
        pub url: String,
    }

    impl TestDb {
        /// Creates a randomly named Postgres database on localhost, runs the
        /// crate's `migrations` directory against it, and returns a handle with
        /// `test_mode` enabled. Panics if any step fails.
        pub fn postgres() -> Self {
            // Process-wide lock held while creating and migrating a database.
            // NOTE(review): presumably prevents concurrent tests from racing on
            // CREATE DATABASE / migrations — confirm.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut db = PostgresDb::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });
            Self {
                db: Some(Arc::new(db)),
                name,
                url,
            }
        }

        /// Wraps a fresh in-memory `FakeDb` that simulates delays on `background`.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                name: "fake".to_string(),
                url: "fake".to_string(),
            }
        }

        /// Borrows the wrapped database. Panics if called after teardown.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        // Runs `Db::teardown` synchronously (for Postgres this drops the database).
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                block_on(db.teardown(&self.name, &self.url));
            }
        }
    }

    /// In-memory `Db` used by tests. Most methods await
    /// `Background::simulate_random_delay` before touching state; the
    /// signup, access-token, and several user methods are `unimplemented!()`.
    pub struct FakeDb {
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        // Each `next_*_id` counter holds the next id to hand out (starting at 1).
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        // Value is the membership's admin flag.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        // Value is the membership's admin flag.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }

    impl FakeDb {
        /// Creates an empty fake whose id counters all start at 1.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                next_user_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        async fn create_signup(
            &self,
            _github_login: &str,
            _email_address: &str,
            _about: &str,
            _wants_releases: bool,
            _wants_updates: bool,
            _wants_community: bool,
        ) -> Result<SignupId> {
            unimplemented!()
        }

        async fn get_all_signups(&self) -> Result<Vec<Signup>> {
            unimplemented!()
        }

        async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
            unimplemented!()
        }

        // Mirrors the Postgres upsert: an existing login returns its id unchanged.
        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == github_login)
            {
                Ok(user.id)
            } else {
                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: github_login.to_string(),
                        admin,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        // Returns matches in the order of `ids`, skipping unknown ids.
        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
            unimplemented!()
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        // Unlike the Postgres impl, a duplicate slug is reported as an error here.
        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        // Validates both ids, then inserts only if absent (like ON CONFLICT DO NOTHING).
        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        // Validates both ids, then inserts only if absent (like ON CONFLICT DO NOTHING).
        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        // Nonces are matched across all channels, mirroring `ON CONFLICT (nonce)`.
        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&sender_id) {
                return Err(anyhow!("user does not exist"));
            }

            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        // Walks messages newest-first (BTreeMap is ordered by id), keeps `count`,
        // then re-sorts ascending. NOTE(review): unlike its siblings this does not
        // call `simulate_random_delay`.
        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        // Nothing to release for the in-memory fake.
        async fn teardown(&self, _name: &str, _url: &str) {}
    }
}