db.rs

   1use anyhow::Context;
   2use anyhow::Result;
   3use async_std::task::{block_on, yield_now};
   4use async_trait::async_trait;
   5use serde::Serialize;
   6pub use sqlx::postgres::PgPoolOptions as DbOptions;
   7use sqlx::{types::Uuid, FromRow};
   8use time::OffsetDateTime;
   9
  10macro_rules! test_support {
  11    ($self:ident, { $($token:tt)* }) => {{
  12        let body = async {
  13            $($token)*
  14        };
  15        if $self.test_mode {
  16            yield_now().await;
  17            block_on(body)
  18        } else {
  19            body.await
  20        }
  21    }};
  22}
  23
/// Async storage interface for users, access tokens, orgs, channels and channel messages.
///
/// Methods gated on `any(test, feature = "seed-support")` are only compiled for tests and
/// seeding; `teardown` only exists in test builds.
#[async_trait]
pub trait Db: Send + Sync {
    // users

    /// Creates a user with the given GitHub login and admin flag and returns its id.
    /// The Postgres implementation returns the existing user's id when the login is taken.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns the user with the given id, or `None` if there is none.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids without a matching user are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the admin flag of the user with the given id.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user; the Postgres implementation also deletes the user's access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // access tokens

    /// Stores an access token hash for `user_id`, keeping at most `max_access_token_count`
    /// of that user's most recently created hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the stored access token hashes for `user_id`, newest first in the
    /// Postgres implementation.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // orgs

    /// Looks up an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` to `org_id`; an existing membership is left as is.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;

    // channels

    /// Creates a channel owned by `org_id` and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by `org_id`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels `user_id` is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether `user_id` is a member of `channel_id`.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` to `channel_id`; an existing membership is left as is.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;

    // messages

    /// Stores a channel message and returns its id. A message whose `nonce` was already
    /// used yields the id of the existing message instead of creating a new one.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the channel's most recent messages whose ids are below
    /// `before_id` (unbounded when `None`), ordered by ascending id.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: releases the test database identified by `name` / `url`.
    #[cfg(test)]
    async fn teardown(&self, name: &str, url: &str);
}
  77
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Pool every query is executed against.
    pool: sqlx::PgPool,
    /// When `true`, each query body is driven through `test_support!`'s
    /// `yield_now` + `block_on` path instead of being awaited directly.
    test_mode: bool,
}
  82
  83impl PostgresDb {
  84    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
  85        let pool = DbOptions::new()
  86            .max_connections(max_connections)
  87            .connect(url)
  88            .await
  89            .context("failed to connect to postgres database")?;
  90        Ok(Self {
  91            pool,
  92            test_mode: false,
  93        })
  94    }
  95}
  96
  97#[async_trait]
  98impl Db for PostgresDb {
  99    // users
 100
 101    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 102        test_support!(self, {
 103            let query = "
 104                INSERT INTO users (github_login, admin)
 105                VALUES ($1, $2)
 106                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 107                RETURNING id
 108            ";
 109            Ok(sqlx::query_scalar(query)
 110                .bind(github_login)
 111                .bind(admin)
 112                .fetch_one(&self.pool)
 113                .await
 114                .map(UserId)?)
 115        })
 116    }
 117
 118    async fn get_all_users(&self) -> Result<Vec<User>> {
 119        test_support!(self, {
 120            let query = "SELECT * FROM users ORDER BY github_login ASC";
 121            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 122        })
 123    }
 124
 125    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 126        let users = self.get_users_by_ids(vec![id]).await?;
 127        Ok(users.into_iter().next())
 128    }
 129
 130    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 131        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 132        test_support!(self, {
 133            let query = "
 134                SELECT users.*
 135                FROM users
 136                WHERE users.id = ANY ($1)
 137            ";
 138
 139            Ok(sqlx::query_as(query)
 140                .bind(&ids)
 141                .fetch_all(&self.pool)
 142                .await?)
 143        })
 144    }
 145
 146    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 147        test_support!(self, {
 148            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 149            Ok(sqlx::query_as(query)
 150                .bind(github_login)
 151                .fetch_optional(&self.pool)
 152                .await?)
 153        })
 154    }
 155
 156    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 157        test_support!(self, {
 158            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 159            Ok(sqlx::query(query)
 160                .bind(is_admin)
 161                .bind(id.0)
 162                .execute(&self.pool)
 163                .await
 164                .map(drop)?)
 165        })
 166    }
 167
 168    async fn destroy_user(&self, id: UserId) -> Result<()> {
 169        test_support!(self, {
 170            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 171            sqlx::query(query)
 172                .bind(id.0)
 173                .execute(&self.pool)
 174                .await
 175                .map(drop)?;
 176            let query = "DELETE FROM users WHERE id = $1;";
 177            Ok(sqlx::query(query)
 178                .bind(id.0)
 179                .execute(&self.pool)
 180                .await
 181                .map(drop)?)
 182        })
 183    }
 184
 185    // access tokens
 186
 187    async fn create_access_token_hash(
 188        &self,
 189        user_id: UserId,
 190        access_token_hash: &str,
 191        max_access_token_count: usize,
 192    ) -> Result<()> {
 193        test_support!(self, {
 194            let insert_query = "
 195                INSERT INTO access_tokens (user_id, hash)
 196                VALUES ($1, $2);
 197            ";
 198            let cleanup_query = "
 199                DELETE FROM access_tokens
 200                WHERE id IN (
 201                    SELECT id from access_tokens
 202                    WHERE user_id = $1
 203                    ORDER BY id DESC
 204                    OFFSET $3
 205                )
 206            ";
 207
 208            let mut tx = self.pool.begin().await?;
 209            sqlx::query(insert_query)
 210                .bind(user_id.0)
 211                .bind(access_token_hash)
 212                .execute(&mut tx)
 213                .await?;
 214            sqlx::query(cleanup_query)
 215                .bind(user_id.0)
 216                .bind(access_token_hash)
 217                .bind(max_access_token_count as u32)
 218                .execute(&mut tx)
 219                .await?;
 220            Ok(tx.commit().await?)
 221        })
 222    }
 223
 224    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 225        test_support!(self, {
 226            let query = "
 227                SELECT hash
 228                FROM access_tokens
 229                WHERE user_id = $1
 230                ORDER BY id DESC
 231            ";
 232            Ok(sqlx::query_scalar(query)
 233                .bind(user_id.0)
 234                .fetch_all(&self.pool)
 235                .await?)
 236        })
 237    }
 238
 239    // orgs
 240
 241    #[allow(unused)] // Help rust-analyzer
 242    #[cfg(any(test, feature = "seed-support"))]
 243    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT *
 247                FROM orgs
 248                WHERE slug = $1
 249            ";
 250            Ok(sqlx::query_as(query)
 251                .bind(slug)
 252                .fetch_optional(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    #[cfg(any(test, feature = "seed-support"))]
 258    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 259        test_support!(self, {
 260            let query = "
 261                INSERT INTO orgs (name, slug)
 262                VALUES ($1, $2)
 263                RETURNING id
 264            ";
 265            Ok(sqlx::query_scalar(query)
 266                .bind(name)
 267                .bind(slug)
 268                .fetch_one(&self.pool)
 269                .await
 270                .map(OrgId)?)
 271        })
 272    }
 273
 274    #[cfg(any(test, feature = "seed-support"))]
 275    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 276        test_support!(self, {
 277            let query = "
 278                INSERT INTO org_memberships (org_id, user_id, admin)
 279                VALUES ($1, $2, $3)
 280                ON CONFLICT DO NOTHING
 281            ";
 282            Ok(sqlx::query(query)
 283                .bind(org_id.0)
 284                .bind(user_id.0)
 285                .bind(is_admin)
 286                .execute(&self.pool)
 287                .await
 288                .map(drop)?)
 289        })
 290    }
 291
 292    // channels
 293
 294    #[cfg(any(test, feature = "seed-support"))]
 295    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 296        test_support!(self, {
 297            let query = "
 298                INSERT INTO channels (owner_id, owner_is_user, name)
 299                VALUES ($1, false, $2)
 300                RETURNING id
 301            ";
 302            Ok(sqlx::query_scalar(query)
 303                .bind(org_id.0)
 304                .bind(name)
 305                .fetch_one(&self.pool)
 306                .await
 307                .map(ChannelId)?)
 308        })
 309    }
 310
 311    #[allow(unused)] // Help rust-analyzer
 312    #[cfg(any(test, feature = "seed-support"))]
 313    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 314        test_support!(self, {
 315            let query = "
 316                SELECT *
 317                FROM channels
 318                WHERE
 319                    channels.owner_is_user = false AND
 320                    channels.owner_id = $1
 321            ";
 322            Ok(sqlx::query_as(query)
 323                .bind(org_id.0)
 324                .fetch_all(&self.pool)
 325                .await?)
 326        })
 327    }
 328
 329    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 330        test_support!(self, {
 331            let query = "
 332                SELECT
 333                    channels.*
 334                FROM
 335                    channel_memberships, channels
 336                WHERE
 337                    channel_memberships.user_id = $1 AND
 338                    channel_memberships.channel_id = channels.id
 339            ";
 340            Ok(sqlx::query_as(query)
 341                .bind(user_id.0)
 342                .fetch_all(&self.pool)
 343                .await?)
 344        })
 345    }
 346
 347    async fn can_user_access_channel(
 348        &self,
 349        user_id: UserId,
 350        channel_id: ChannelId,
 351    ) -> Result<bool> {
 352        test_support!(self, {
 353            let query = "
 354                SELECT id
 355                FROM channel_memberships
 356                WHERE user_id = $1 AND channel_id = $2
 357                LIMIT 1
 358            ";
 359            Ok(sqlx::query_scalar::<_, i32>(query)
 360                .bind(user_id.0)
 361                .bind(channel_id.0)
 362                .fetch_optional(&self.pool)
 363                .await
 364                .map(|e| e.is_some())?)
 365        })
 366    }
 367
 368    #[cfg(any(test, feature = "seed-support"))]
 369    async fn add_channel_member(
 370        &self,
 371        channel_id: ChannelId,
 372        user_id: UserId,
 373        is_admin: bool,
 374    ) -> Result<()> {
 375        test_support!(self, {
 376            let query = "
 377                INSERT INTO channel_memberships (channel_id, user_id, admin)
 378                VALUES ($1, $2, $3)
 379                ON CONFLICT DO NOTHING
 380            ";
 381            Ok(sqlx::query(query)
 382                .bind(channel_id.0)
 383                .bind(user_id.0)
 384                .bind(is_admin)
 385                .execute(&self.pool)
 386                .await
 387                .map(drop)?)
 388        })
 389    }
 390
 391    // messages
 392
 393    async fn create_channel_message(
 394        &self,
 395        channel_id: ChannelId,
 396        sender_id: UserId,
 397        body: &str,
 398        timestamp: OffsetDateTime,
 399        nonce: u128,
 400    ) -> Result<MessageId> {
 401        test_support!(self, {
 402            let query = "
 403                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 404                VALUES ($1, $2, $3, $4, $5)
 405                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 406                RETURNING id
 407            ";
 408            Ok(sqlx::query_scalar(query)
 409                .bind(channel_id.0)
 410                .bind(sender_id.0)
 411                .bind(body)
 412                .bind(timestamp)
 413                .bind(Uuid::from_u128(nonce))
 414                .fetch_one(&self.pool)
 415                .await
 416                .map(MessageId)?)
 417        })
 418    }
 419
 420    async fn get_channel_messages(
 421        &self,
 422        channel_id: ChannelId,
 423        count: usize,
 424        before_id: Option<MessageId>,
 425    ) -> Result<Vec<ChannelMessage>> {
 426        test_support!(self, {
 427            let query = r#"
 428                SELECT * FROM (
 429                    SELECT
 430                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 431                    FROM
 432                        channel_messages
 433                    WHERE
 434                        channel_id = $1 AND
 435                        id < $2
 436                    ORDER BY id DESC
 437                    LIMIT $3
 438                ) as recent_messages
 439                ORDER BY id ASC
 440            "#;
 441            Ok(sqlx::query_as(query)
 442                .bind(channel_id.0)
 443                .bind(before_id.unwrap_or(MessageId::MAX))
 444                .bind(count as i64)
 445                .fetch_all(&self.pool)
 446                .await?)
 447        })
 448    }
 449
 450    #[cfg(test)]
 451    async fn teardown(&self, name: &str, url: &str) {
 452        use util::ResultExt;
 453
 454        test_support!(self, {
 455            let query = "
 456                SELECT pg_terminate_backend(pg_stat_activity.pid)
 457                FROM pg_stat_activity
 458                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
 459            ";
 460            sqlx::query(query)
 461                .bind(name)
 462                .execute(&self.pool)
 463                .await
 464                .log_err();
 465            self.pool.close().await;
 466            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 467                .await
 468                .log_err();
 469        })
 470    }
 471}
 472
 473macro_rules! id_type {
 474    ($name:ident) => {
 475        #[derive(
 476            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 477        )]
 478        #[sqlx(transparent)]
 479        #[serde(transparent)]
 480        pub struct $name(pub i32);
 481
 482        impl $name {
 483            #[allow(unused)]
 484            pub const MAX: Self = Self(i32::MAX);
 485
 486            #[allow(unused)]
 487            pub fn from_proto(value: u64) -> Self {
 488                Self(value as i32)
 489            }
 490
 491            #[allow(unused)]
 492            pub fn to_proto(&self) -> u64 {
 493                self.0 as u64
 494            }
 495        }
 496    };
 497}
 498
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `create_user` returns the existing id when the login is taken.
    pub github_login: String,
    /// The `admin` flag, as set by `create_user` / `set_user_is_admin`.
    pub admin: bool,
}
 506
 507id_type!(OrgId);
 508#[derive(FromRow)]
 509pub struct Org {
 510    pub id: OrgId,
 511    pub name: String,
 512    pub slug: String,
 513}
 514
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the channel's owner; `create_org_channel` stores an org id here.
    pub owner_id: i32,
    /// `false` for org-owned channels (as written by `create_org_channel`);
    /// presumably `true` when `owner_id` is a user id — confirm.
    pub owner_is_user: bool,
}
 523
id_type!(MessageId);
/// A message row as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// The `u128` nonce given to `create_channel_message`, stored as a UUID.
    /// A repeated nonce makes `create_channel_message` return the existing id.
    pub nonce: Uuid,
}
 534
 535#[cfg(test)]
 536pub mod tests {
 537    use super::*;
 538    use anyhow::anyhow;
 539    use collections::BTreeMap;
 540    use gpui::{executor::Background, TestAppContext};
 541    use lazy_static::lazy_static;
 542    use parking_lot::Mutex;
 543    use rand::prelude::*;
 544    use sqlx::{
 545        migrate::{MigrateDatabase, Migrator},
 546        Postgres,
 547    };
 548    use std::{path::Path, sync::Arc};
 549    use util::post_inc;
 550
 551    #[gpui::test]
 552    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
 553        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 554            let db = test_db.db();
 555
 556            let user = db.create_user("user", false).await.unwrap();
 557            let friend1 = db.create_user("friend-1", false).await.unwrap();
 558            let friend2 = db.create_user("friend-2", false).await.unwrap();
 559            let friend3 = db.create_user("friend-3", false).await.unwrap();
 560
 561            assert_eq!(
 562                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 563                    .await
 564                    .unwrap(),
 565                vec![
 566                    User {
 567                        id: user,
 568                        github_login: "user".to_string(),
 569                        admin: false,
 570                    },
 571                    User {
 572                        id: friend1,
 573                        github_login: "friend-1".to_string(),
 574                        admin: false,
 575                    },
 576                    User {
 577                        id: friend2,
 578                        github_login: "friend-2".to_string(),
 579                        admin: false,
 580                    },
 581                    User {
 582                        id: friend3,
 583                        github_login: "friend-3".to_string(),
 584                        admin: false,
 585                    }
 586                ]
 587            );
 588        }
 589    }
 590
 591    #[gpui::test]
 592    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
 593        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 594            let db = test_db.db();
 595            let user = db.create_user("user", false).await.unwrap();
 596            let org = db.create_org("org", "org").await.unwrap();
 597            let channel = db.create_org_channel(org, "channel").await.unwrap();
 598            for i in 0..10 {
 599                db.create_channel_message(
 600                    channel,
 601                    user,
 602                    &i.to_string(),
 603                    OffsetDateTime::now_utc(),
 604                    i,
 605                )
 606                .await
 607                .unwrap();
 608            }
 609
 610            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 611            assert_eq!(
 612                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 613                ["5", "6", "7", "8", "9"]
 614            );
 615
 616            let prev_messages = db
 617                .get_channel_messages(channel, 4, Some(messages[0].id))
 618                .await
 619                .unwrap();
 620            assert_eq!(
 621                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 622                ["1", "2", "3", "4"]
 623            );
 624        }
 625    }
 626
 627    #[gpui::test]
 628    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
 629        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 630            let db = test_db.db();
 631            let user = db.create_user("user", false).await.unwrap();
 632            let org = db.create_org("org", "org").await.unwrap();
 633            let channel = db.create_org_channel(org, "channel").await.unwrap();
 634
 635            let msg1_id = db
 636                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 637                .await
 638                .unwrap();
 639            let msg2_id = db
 640                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 641                .await
 642                .unwrap();
 643            let msg3_id = db
 644                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 645                .await
 646                .unwrap();
 647            let msg4_id = db
 648                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 649                .await
 650                .unwrap();
 651
 652            assert_ne!(msg1_id, msg2_id);
 653            assert_eq!(msg1_id, msg3_id);
 654            assert_eq!(msg2_id, msg4_id);
 655        }
 656    }
 657
 658    #[gpui::test]
 659    async fn test_create_access_tokens() {
 660        let test_db = TestDb::postgres();
 661        let db = test_db.db();
 662        let user = db.create_user("the-user", false).await.unwrap();
 663
 664        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 665        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 666        assert_eq!(
 667            db.get_access_token_hashes(user).await.unwrap(),
 668            &["h2".to_string(), "h1".to_string()]
 669        );
 670
 671        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 672        assert_eq!(
 673            db.get_access_token_hashes(user).await.unwrap(),
 674            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 675        );
 676
 677        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 678        assert_eq!(
 679            db.get_access_token_hashes(user).await.unwrap(),
 680            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 681        );
 682
 683        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 684        assert_eq!(
 685            db.get_access_token_hashes(user).await.unwrap(),
 686            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 687        );
 688    }
 689
    /// A database used by tests; `Drop` calls `teardown` on it.
    pub struct TestDb {
        /// The database under test. `Drop` takes it (leaving `None`) before teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Database name (`"fake"` for the in-memory fake).
        pub name: String,
        /// Connection URL (`"fake"` for the in-memory fake).
        pub url: String,
    }
 695
 696    impl TestDb {
 697        pub fn postgres() -> Self {
 698            lazy_static! {
 699                static ref LOCK: Mutex<()> = Mutex::new(());
 700            }
 701
 702            let _guard = LOCK.lock();
 703            let mut rng = StdRng::from_entropy();
 704            let name = format!("zed-test-{}", rng.gen::<u128>());
 705            let url = format!("postgres://postgres@localhost/{}", name);
 706            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 707            let db = block_on(async {
 708                Postgres::create_database(&url)
 709                    .await
 710                    .expect("failed to create test db");
 711                let mut db = PostgresDb::new(&url, 5).await.unwrap();
 712                db.test_mode = true;
 713                let migrator = Migrator::new(migrations_path).await.unwrap();
 714                migrator.run(&db.pool).await.unwrap();
 715                db
 716            });
 717            Self {
 718                db: Some(Arc::new(db)),
 719                name,
 720                url,
 721            }
 722        }
 723
 724        pub fn fake(background: Arc<Background>) -> Self {
 725            Self {
 726                db: Some(Arc::new(FakeDb::new(background))),
 727                name: "fake".to_string(),
 728                url: "fake".to_string(),
 729            }
 730        }
 731
 732        pub fn db(&self) -> &Arc<dyn Db> {
 733            self.db.as_ref().unwrap()
 734        }
 735    }
 736
 737    impl Drop for TestDb {
 738        fn drop(&mut self) {
 739            if let Some(db) = self.db.take() {
 740                block_on(db.teardown(&self.name, &self.url));
 741            }
 742        }
 743    }
 744
    /// In-memory [`Db`] for tests. Operations such as `create_user` and `create_org`
    /// first await `background.simulate_random_delay()`.
    pub struct FakeDb {
        /// Executor used to inject random delays.
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        /// Id handed to the next created user (incremented with `post_inc`).
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Id handed to the next created org.
        next_org_id: Mutex<i32>,
        /// `is_admin` flag keyed by `(org, user)`, as written by `add_org_member`.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Id handed to the next created channel.
        next_channel_id: Mutex<i32>,
        /// Membership flag keyed by `(channel, user)`; presumably `is_admin` — confirm.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Id handed to the next created channel message.
        next_channel_message_id: Mutex<i32>,
    }
 758
 759    impl FakeDb {
 760        pub fn new(background: Arc<Background>) -> Self {
 761            Self {
 762                background,
 763                users: Default::default(),
 764                next_user_id: Mutex::new(1),
 765                orgs: Default::default(),
 766                next_org_id: Mutex::new(1),
 767                org_memberships: Default::default(),
 768                channels: Default::default(),
 769                next_channel_id: Mutex::new(1),
 770                channel_memberships: Default::default(),
 771                channel_messages: Default::default(),
 772                next_channel_message_id: Mutex::new(1),
 773            }
 774        }
 775    }
 776
 777    #[async_trait]
 778    impl Db for FakeDb {
 779        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 780            self.background.simulate_random_delay().await;
 781
 782            let mut users = self.users.lock();
 783            if let Some(user) = users
 784                .values()
 785                .find(|user| user.github_login == github_login)
 786            {
 787                Ok(user.id)
 788            } else {
 789                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
 790                users.insert(
 791                    user_id,
 792                    User {
 793                        id: user_id,
 794                        github_login: github_login.to_string(),
 795                        admin,
 796                    },
 797                );
 798                Ok(user_id)
 799            }
 800        }
 801
 802        async fn get_all_users(&self) -> Result<Vec<User>> {
 803            unimplemented!()
 804        }
 805
 806        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 807            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
 808        }
 809
 810        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 811            self.background.simulate_random_delay().await;
 812            let users = self.users.lock();
 813            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
 814        }
 815
 816        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
 817            unimplemented!()
 818        }
 819
 820        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
 821            unimplemented!()
 822        }
 823
 824        async fn destroy_user(&self, _id: UserId) -> Result<()> {
 825            unimplemented!()
 826        }
 827
 828        async fn create_access_token_hash(
 829            &self,
 830            _user_id: UserId,
 831            _access_token_hash: &str,
 832            _max_access_token_count: usize,
 833        ) -> Result<()> {
 834            unimplemented!()
 835        }
 836
 837        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
 838            unimplemented!()
 839        }
 840
 841        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
 842            unimplemented!()
 843        }
 844
 845        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 846            self.background.simulate_random_delay().await;
 847            let mut orgs = self.orgs.lock();
 848            if orgs.values().any(|org| org.slug == slug) {
 849                Err(anyhow!("org already exists"))
 850            } else {
 851                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
 852                orgs.insert(
 853                    org_id,
 854                    Org {
 855                        id: org_id,
 856                        name: name.to_string(),
 857                        slug: slug.to_string(),
 858                    },
 859                );
 860                Ok(org_id)
 861            }
 862        }
 863
 864        async fn add_org_member(
 865            &self,
 866            org_id: OrgId,
 867            user_id: UserId,
 868            is_admin: bool,
 869        ) -> Result<()> {
 870            self.background.simulate_random_delay().await;
 871            if !self.orgs.lock().contains_key(&org_id) {
 872                return Err(anyhow!("org does not exist"));
 873            }
 874            if !self.users.lock().contains_key(&user_id) {
 875                return Err(anyhow!("user does not exist"));
 876            }
 877
 878            self.org_memberships
 879                .lock()
 880                .entry((org_id, user_id))
 881                .or_insert(is_admin);
 882            Ok(())
 883        }
 884
 885        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 886            self.background.simulate_random_delay().await;
 887            if !self.orgs.lock().contains_key(&org_id) {
 888                return Err(anyhow!("org does not exist"));
 889            }
 890
 891            let mut channels = self.channels.lock();
 892            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
 893            channels.insert(
 894                channel_id,
 895                Channel {
 896                    id: channel_id,
 897                    name: name.to_string(),
 898                    owner_id: org_id.0,
 899                    owner_is_user: false,
 900                },
 901            );
 902            Ok(channel_id)
 903        }
 904
 905        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 906            self.background.simulate_random_delay().await;
 907            Ok(self
 908                .channels
 909                .lock()
 910                .values()
 911                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
 912                .cloned()
 913                .collect())
 914        }
 915
 916        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 917            self.background.simulate_random_delay().await;
 918            let channels = self.channels.lock();
 919            let memberships = self.channel_memberships.lock();
 920            Ok(channels
 921                .values()
 922                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
 923                .cloned()
 924                .collect())
 925        }
 926
 927        async fn can_user_access_channel(
 928            &self,
 929            user_id: UserId,
 930            channel_id: ChannelId,
 931        ) -> Result<bool> {
 932            self.background.simulate_random_delay().await;
 933            Ok(self
 934                .channel_memberships
 935                .lock()
 936                .contains_key(&(channel_id, user_id)))
 937        }
 938
 939        async fn add_channel_member(
 940            &self,
 941            channel_id: ChannelId,
 942            user_id: UserId,
 943            is_admin: bool,
 944        ) -> Result<()> {
 945            self.background.simulate_random_delay().await;
 946            if !self.channels.lock().contains_key(&channel_id) {
 947                return Err(anyhow!("channel does not exist"));
 948            }
 949            if !self.users.lock().contains_key(&user_id) {
 950                return Err(anyhow!("user does not exist"));
 951            }
 952
 953            self.channel_memberships
 954                .lock()
 955                .entry((channel_id, user_id))
 956                .or_insert(is_admin);
 957            Ok(())
 958        }
 959
 960        async fn create_channel_message(
 961            &self,
 962            channel_id: ChannelId,
 963            sender_id: UserId,
 964            body: &str,
 965            timestamp: OffsetDateTime,
 966            nonce: u128,
 967        ) -> Result<MessageId> {
 968            self.background.simulate_random_delay().await;
 969            if !self.channels.lock().contains_key(&channel_id) {
 970                return Err(anyhow!("channel does not exist"));
 971            }
 972            if !self.users.lock().contains_key(&sender_id) {
 973                return Err(anyhow!("user does not exist"));
 974            }
 975
 976            let mut messages = self.channel_messages.lock();
 977            if let Some(message) = messages
 978                .values()
 979                .find(|message| message.nonce.as_u128() == nonce)
 980            {
 981                Ok(message.id)
 982            } else {
 983                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
 984                messages.insert(
 985                    message_id,
 986                    ChannelMessage {
 987                        id: message_id,
 988                        channel_id,
 989                        sender_id,
 990                        body: body.to_string(),
 991                        sent_at: timestamp,
 992                        nonce: Uuid::from_u128(nonce),
 993                    },
 994                );
 995                Ok(message_id)
 996            }
 997        }
 998
 999        async fn get_channel_messages(
1000            &self,
1001            channel_id: ChannelId,
1002            count: usize,
1003            before_id: Option<MessageId>,
1004        ) -> Result<Vec<ChannelMessage>> {
1005            let mut messages = self
1006                .channel_messages
1007                .lock()
1008                .values()
1009                .rev()
1010                .filter(|message| {
1011                    message.channel_id == channel_id
1012                        && message.id < before_id.unwrap_or(MessageId::MAX)
1013                })
1014                .take(count)
1015                .cloned()
1016                .collect::<Vec<_>>();
1017            messages.sort_unstable_by_key(|message| message.id);
1018            Ok(messages)
1019        }
1020
        /// No-op for the fake: nothing is cleaned up and both arguments are ignored.
        async fn teardown(&self, _name: &str, _url: &str) {}
1022    }
1023}