1use anyhow::Context;
2use anyhow::Result;
3use async_std::task::{block_on, yield_now};
4use async_trait::async_trait;
5use serde::Serialize;
6pub use sqlx::postgres::PgPoolOptions as DbOptions;
7use sqlx::{types::Uuid, FromRow};
8use time::OffsetDateTime;
9
/// Evaluates the braced statements as an `async` block on behalf of `$self`.
///
/// Outside test mode the block is simply awaited. When `$self.test_mode` is set,
/// the current task yields once and then the block is driven to completion with
/// async-std's `block_on` instead of being awaited on the caller's executor.
/// NOTE(review): presumably this keeps sqlx I/O off the deterministic test
/// executor — confirm before relying on it.
macro_rules! test_support {
    ($self:ident, { $($stmt:tt)* }) => {{
        let future = async { $($stmt)* };
        if !$self.test_mode {
            future.await
        } else {
            yield_now().await;
            block_on(future)
        }
    }};
}
23
24#[async_trait]
25pub trait Db: Send + Sync {
26 async fn create_signup(
27 &self,
28 github_login: &str,
29 email_address: &str,
30 about: &str,
31 wants_releases: bool,
32 wants_updates: bool,
33 wants_community: bool,
34 ) -> Result<SignupId>;
35 async fn get_all_signups(&self) -> Result<Vec<Signup>>;
36 async fn destroy_signup(&self, id: SignupId) -> Result<()>;
37 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
38 async fn get_all_users(&self) -> Result<Vec<User>>;
39 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
40 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
41 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
42 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
43 async fn destroy_user(&self, id: UserId) -> Result<()>;
44 async fn create_access_token_hash(
45 &self,
46 user_id: UserId,
47 access_token_hash: &str,
48 max_access_token_count: usize,
49 ) -> Result<()>;
50 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
51 #[cfg(any(test, feature = "seed-support"))]
52 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
53 #[cfg(any(test, feature = "seed-support"))]
54 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
55 #[cfg(any(test, feature = "seed-support"))]
56 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
57 #[cfg(any(test, feature = "seed-support"))]
58 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
59 #[cfg(any(test, feature = "seed-support"))]
60 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
61 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
62 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
63 -> Result<bool>;
64 #[cfg(any(test, feature = "seed-support"))]
65 async fn add_channel_member(
66 &self,
67 channel_id: ChannelId,
68 user_id: UserId,
69 is_admin: bool,
70 ) -> Result<()>;
71 async fn create_channel_message(
72 &self,
73 channel_id: ChannelId,
74 sender_id: UserId,
75 body: &str,
76 timestamp: OffsetDateTime,
77 nonce: u128,
78 ) -> Result<MessageId>;
79 async fn get_channel_messages(
80 &self,
81 channel_id: ChannelId,
82 count: usize,
83 before_id: Option<MessageId>,
84 ) -> Result<Vec<ChannelMessage>>;
85 #[cfg(test)]
86 async fn teardown(&self, name: &str, url: &str);
87}
88
89pub struct PostgresDb {
90 pool: sqlx::PgPool,
91 test_mode: bool,
92}
93
94impl PostgresDb {
95 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
96 let pool = DbOptions::new()
97 .max_connections(max_connections)
98 .connect(url)
99 .await
100 .context("failed to connect to postgres database")?;
101 Ok(Self {
102 pool,
103 test_mode: false,
104 })
105 }
106}
107
108#[async_trait]
109impl Db for PostgresDb {
110 // signups
111 async fn create_signup(
112 &self,
113 github_login: &str,
114 email_address: &str,
115 about: &str,
116 wants_releases: bool,
117 wants_updates: bool,
118 wants_community: bool,
119 ) -> Result<SignupId> {
120 test_support!(self, {
121 let query = "
122 INSERT INTO signups (
123 github_login,
124 email_address,
125 about,
126 wants_releases,
127 wants_updates,
128 wants_community
129 )
130 VALUES ($1, $2, $3, $4, $5, $6)
131 RETURNING id
132 ";
133 Ok(sqlx::query_scalar(query)
134 .bind(github_login)
135 .bind(email_address)
136 .bind(about)
137 .bind(wants_releases)
138 .bind(wants_updates)
139 .bind(wants_community)
140 .fetch_one(&self.pool)
141 .await
142 .map(SignupId)?)
143 })
144 }
145
146 async fn get_all_signups(&self) -> Result<Vec<Signup>> {
147 test_support!(self, {
148 let query = "SELECT * FROM signups ORDER BY github_login ASC";
149 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
150 })
151 }
152
153 async fn destroy_signup(&self, id: SignupId) -> Result<()> {
154 test_support!(self, {
155 let query = "DELETE FROM signups WHERE id = $1";
156 Ok(sqlx::query(query)
157 .bind(id.0)
158 .execute(&self.pool)
159 .await
160 .map(drop)?)
161 })
162 }
163
164 // users
165
166 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
167 test_support!(self, {
168 let query = "
169 INSERT INTO users (github_login, admin)
170 VALUES ($1, $2)
171 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
172 RETURNING id
173 ";
174 Ok(sqlx::query_scalar(query)
175 .bind(github_login)
176 .bind(admin)
177 .fetch_one(&self.pool)
178 .await
179 .map(UserId)?)
180 })
181 }
182
183 async fn get_all_users(&self) -> Result<Vec<User>> {
184 test_support!(self, {
185 let query = "SELECT * FROM users ORDER BY github_login ASC";
186 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
187 })
188 }
189
190 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
191 let users = self.get_users_by_ids(vec![id]).await?;
192 Ok(users.into_iter().next())
193 }
194
195 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
196 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
197 test_support!(self, {
198 let query = "
199 SELECT users.*
200 FROM users
201 WHERE users.id = ANY ($1)
202 ";
203
204 Ok(sqlx::query_as(query)
205 .bind(&ids)
206 .fetch_all(&self.pool)
207 .await?)
208 })
209 }
210
211 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
212 test_support!(self, {
213 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
214 Ok(sqlx::query_as(query)
215 .bind(github_login)
216 .fetch_optional(&self.pool)
217 .await?)
218 })
219 }
220
221 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
222 test_support!(self, {
223 let query = "UPDATE users SET admin = $1 WHERE id = $2";
224 Ok(sqlx::query(query)
225 .bind(is_admin)
226 .bind(id.0)
227 .execute(&self.pool)
228 .await
229 .map(drop)?)
230 })
231 }
232
233 async fn destroy_user(&self, id: UserId) -> Result<()> {
234 test_support!(self, {
235 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
236 sqlx::query(query)
237 .bind(id.0)
238 .execute(&self.pool)
239 .await
240 .map(drop)?;
241 let query = "DELETE FROM users WHERE id = $1;";
242 Ok(sqlx::query(query)
243 .bind(id.0)
244 .execute(&self.pool)
245 .await
246 .map(drop)?)
247 })
248 }
249
250 // access tokens
251
252 async fn create_access_token_hash(
253 &self,
254 user_id: UserId,
255 access_token_hash: &str,
256 max_access_token_count: usize,
257 ) -> Result<()> {
258 test_support!(self, {
259 let insert_query = "
260 INSERT INTO access_tokens (user_id, hash)
261 VALUES ($1, $2);
262 ";
263 let cleanup_query = "
264 DELETE FROM access_tokens
265 WHERE id IN (
266 SELECT id from access_tokens
267 WHERE user_id = $1
268 ORDER BY id DESC
269 OFFSET $3
270 )
271 ";
272
273 let mut tx = self.pool.begin().await?;
274 sqlx::query(insert_query)
275 .bind(user_id.0)
276 .bind(access_token_hash)
277 .execute(&mut tx)
278 .await?;
279 sqlx::query(cleanup_query)
280 .bind(user_id.0)
281 .bind(access_token_hash)
282 .bind(max_access_token_count as u32)
283 .execute(&mut tx)
284 .await?;
285 Ok(tx.commit().await?)
286 })
287 }
288
289 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
290 test_support!(self, {
291 let query = "
292 SELECT hash
293 FROM access_tokens
294 WHERE user_id = $1
295 ORDER BY id DESC
296 ";
297 Ok(sqlx::query_scalar(query)
298 .bind(user_id.0)
299 .fetch_all(&self.pool)
300 .await?)
301 })
302 }
303
304 // orgs
305
306 #[allow(unused)] // Help rust-analyzer
307 #[cfg(any(test, feature = "seed-support"))]
308 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
309 test_support!(self, {
310 let query = "
311 SELECT *
312 FROM orgs
313 WHERE slug = $1
314 ";
315 Ok(sqlx::query_as(query)
316 .bind(slug)
317 .fetch_optional(&self.pool)
318 .await?)
319 })
320 }
321
322 #[cfg(any(test, feature = "seed-support"))]
323 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
324 test_support!(self, {
325 let query = "
326 INSERT INTO orgs (name, slug)
327 VALUES ($1, $2)
328 RETURNING id
329 ";
330 Ok(sqlx::query_scalar(query)
331 .bind(name)
332 .bind(slug)
333 .fetch_one(&self.pool)
334 .await
335 .map(OrgId)?)
336 })
337 }
338
339 #[cfg(any(test, feature = "seed-support"))]
340 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
341 test_support!(self, {
342 let query = "
343 INSERT INTO org_memberships (org_id, user_id, admin)
344 VALUES ($1, $2, $3)
345 ON CONFLICT DO NOTHING
346 ";
347 Ok(sqlx::query(query)
348 .bind(org_id.0)
349 .bind(user_id.0)
350 .bind(is_admin)
351 .execute(&self.pool)
352 .await
353 .map(drop)?)
354 })
355 }
356
357 // channels
358
359 #[cfg(any(test, feature = "seed-support"))]
360 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
361 test_support!(self, {
362 let query = "
363 INSERT INTO channels (owner_id, owner_is_user, name)
364 VALUES ($1, false, $2)
365 RETURNING id
366 ";
367 Ok(sqlx::query_scalar(query)
368 .bind(org_id.0)
369 .bind(name)
370 .fetch_one(&self.pool)
371 .await
372 .map(ChannelId)?)
373 })
374 }
375
376 #[allow(unused)] // Help rust-analyzer
377 #[cfg(any(test, feature = "seed-support"))]
378 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
379 test_support!(self, {
380 let query = "
381 SELECT *
382 FROM channels
383 WHERE
384 channels.owner_is_user = false AND
385 channels.owner_id = $1
386 ";
387 Ok(sqlx::query_as(query)
388 .bind(org_id.0)
389 .fetch_all(&self.pool)
390 .await?)
391 })
392 }
393
394 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
395 test_support!(self, {
396 let query = "
397 SELECT
398 channels.*
399 FROM
400 channel_memberships, channels
401 WHERE
402 channel_memberships.user_id = $1 AND
403 channel_memberships.channel_id = channels.id
404 ";
405 Ok(sqlx::query_as(query)
406 .bind(user_id.0)
407 .fetch_all(&self.pool)
408 .await?)
409 })
410 }
411
412 async fn can_user_access_channel(
413 &self,
414 user_id: UserId,
415 channel_id: ChannelId,
416 ) -> Result<bool> {
417 test_support!(self, {
418 let query = "
419 SELECT id
420 FROM channel_memberships
421 WHERE user_id = $1 AND channel_id = $2
422 LIMIT 1
423 ";
424 Ok(sqlx::query_scalar::<_, i32>(query)
425 .bind(user_id.0)
426 .bind(channel_id.0)
427 .fetch_optional(&self.pool)
428 .await
429 .map(|e| e.is_some())?)
430 })
431 }
432
433 #[cfg(any(test, feature = "seed-support"))]
434 async fn add_channel_member(
435 &self,
436 channel_id: ChannelId,
437 user_id: UserId,
438 is_admin: bool,
439 ) -> Result<()> {
440 test_support!(self, {
441 let query = "
442 INSERT INTO channel_memberships (channel_id, user_id, admin)
443 VALUES ($1, $2, $3)
444 ON CONFLICT DO NOTHING
445 ";
446 Ok(sqlx::query(query)
447 .bind(channel_id.0)
448 .bind(user_id.0)
449 .bind(is_admin)
450 .execute(&self.pool)
451 .await
452 .map(drop)?)
453 })
454 }
455
456 // messages
457
458 async fn create_channel_message(
459 &self,
460 channel_id: ChannelId,
461 sender_id: UserId,
462 body: &str,
463 timestamp: OffsetDateTime,
464 nonce: u128,
465 ) -> Result<MessageId> {
466 test_support!(self, {
467 let query = "
468 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
469 VALUES ($1, $2, $3, $4, $5)
470 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
471 RETURNING id
472 ";
473 Ok(sqlx::query_scalar(query)
474 .bind(channel_id.0)
475 .bind(sender_id.0)
476 .bind(body)
477 .bind(timestamp)
478 .bind(Uuid::from_u128(nonce))
479 .fetch_one(&self.pool)
480 .await
481 .map(MessageId)?)
482 })
483 }
484
485 async fn get_channel_messages(
486 &self,
487 channel_id: ChannelId,
488 count: usize,
489 before_id: Option<MessageId>,
490 ) -> Result<Vec<ChannelMessage>> {
491 test_support!(self, {
492 let query = r#"
493 SELECT * FROM (
494 SELECT
495 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
496 FROM
497 channel_messages
498 WHERE
499 channel_id = $1 AND
500 id < $2
501 ORDER BY id DESC
502 LIMIT $3
503 ) as recent_messages
504 ORDER BY id ASC
505 "#;
506 Ok(sqlx::query_as(query)
507 .bind(channel_id.0)
508 .bind(before_id.unwrap_or(MessageId::MAX))
509 .bind(count as i64)
510 .fetch_all(&self.pool)
511 .await?)
512 })
513 }
514
515 #[cfg(test)]
516 async fn teardown(&self, name: &str, url: &str) {
517 use util::ResultExt;
518
519 test_support!(self, {
520 let query = "
521 SELECT pg_terminate_backend(pg_stat_activity.pid)
522 FROM pg_stat_activity
523 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
524 ";
525 sqlx::query(query)
526 .bind(name)
527 .execute(&self.pool)
528 .await
529 .log_err();
530 self.pool.close().await;
531 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
532 .await
533 .log_err();
534 })
535 }
536}
537
/// Declares `pub struct $name(pub i32)`, a newtype over an `i32` row id.
///
/// The type is `Copy`, totally ordered and hashable. It is encoded transparently
/// by both sqlx (`#[sqlx(transparent)]`) and serde (`#[serde(transparent)]`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
        #[derive(sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a protocol `u64`.
            /// NOTE: plain `as` cast — values above `i32::MAX` are truncated to the
            /// low 32 bits.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Widens the id to a protocol `u64` via an `as` cast (negative ids
            /// sign-extend).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $name(raw) = *self;
                raw as u64
            }
        }
    };
}
563
564id_type!(UserId);
565#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
566pub struct User {
567 pub id: UserId,
568 pub github_login: String,
569 pub admin: bool,
570}
571
572id_type!(OrgId);
573#[derive(FromRow)]
574pub struct Org {
575 pub id: OrgId,
576 pub name: String,
577 pub slug: String,
578}
579
580id_type!(SignupId);
581#[derive(Debug, FromRow, Serialize)]
582pub struct Signup {
583 pub id: SignupId,
584 pub github_login: String,
585 pub email_address: String,
586 pub about: String,
587 pub wants_releases: Option<bool>,
588 pub wants_updates: Option<bool>,
589 pub wants_community: Option<bool>,
590}
591
592id_type!(ChannelId);
593#[derive(Clone, Debug, FromRow, Serialize)]
594pub struct Channel {
595 pub id: ChannelId,
596 pub name: String,
597 pub owner_id: i32,
598 pub owner_is_user: bool,
599}
600
601id_type!(MessageId);
602#[derive(Clone, Debug, FromRow)]
603pub struct ChannelMessage {
604 pub id: MessageId,
605 pub channel_id: ChannelId,
606 pub sender_id: UserId,
607 pub body: String,
608 pub sent_at: OffsetDateTime,
609 pub nonce: Uuid,
610}
611
#[cfg(test)]
pub mod tests {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::{executor::Background, TestAppContext};
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    /// `get_users_by_ids` returns every requested user, on both backends.
    #[gpui::test]
    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();

            let user = db.create_user("user", false).await.unwrap();
            let friend1 = db.create_user("friend-1", false).await.unwrap();
            let friend2 = db.create_user("friend-2", false).await.unwrap();
            let friend3 = db.create_user("friend-3", false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                    }
                ]
            );
        }
    }

    /// Pagination returns the newest `count` messages below `before_id`, oldest first.
    #[gpui::test]
    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }

    /// Re-sending a message with a known nonce returns the original message id.
    #[gpui::test]
    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }

    /// Only the newest `max_access_token_count` hashes survive each insertion
    /// (Postgres only; the fake does not implement access tokens).
    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }

    /// A test database handle. On drop it calls [`Db::teardown`] with its name and URL.
    pub struct TestDb {
        pub db: Option<Arc<dyn Db>>,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on the local Postgres server, runs the
        /// crate's migrations, and returns a test-mode [`PostgresDb`] for it.
        pub fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Serializes database creation and migration across concurrently running tests.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut db = PostgresDb::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });
            Self {
                db: Some(Arc::new(db)),
                name,
                url,
            }
        }

        /// Wraps an in-memory [`FakeDb`] driven by the given background executor.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                name: "fake".to_string(),
                url: "fake".to_string(),
            }
        }

        /// Borrows the database; panics if it has already been torn down.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                block_on(db.teardown(&self.name, &self.url));
            }
        }
    }

    /// In-memory [`Db`] for tests. Implemented operations first await
    /// `simulate_random_delay` on the background executor; unsupported operations
    /// panic with `unimplemented!()`.
    pub struct FakeDb {
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }

    impl FakeDb {
        /// Creates an empty fake; all id counters start at 1.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                next_user_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        async fn create_signup(
            &self,
            _github_login: &str,
            _email_address: &str,
            _about: &str,
            _wants_releases: bool,
            _wants_updates: bool,
            _wants_community: bool,
        ) -> Result<SignupId> {
            unimplemented!()
        }

        async fn get_all_signups(&self) -> Result<Vec<Signup>> {
            unimplemented!()
        }

        async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
            unimplemented!()
        }

        /// Mirrors the Postgres upsert: an existing login returns its current id.
        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == github_login)
            {
                Ok(user.id)
            } else {
                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: github_login.to_string(),
                        admin,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        /// Returns known users in the order of `ids`, skipping unknown ids.
        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
            unimplemented!()
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        /// Fails if an org with the same slug already exists.
        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        /// Fails for an unknown org or user; an existing membership is kept as is.
        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        /// Nonces are unique across all channels, mirroring the Postgres
        /// `ON CONFLICT (nonce)` upsert.
        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&sender_id) {
                return Err(anyhow!("user does not exist"));
            }

            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        /// Walks messages newest-first, takes up to `count` matches, then re-sorts
        /// them ascending by id (same contract as the Postgres query).
        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            // Simulate latency like every other implemented fake operation, so
            // randomized tests also interleave around message fetches.
            self.background.simulate_random_delay().await;
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        async fn teardown(&self, _name: &str, _url: &str) {}
    }
}