db.rs

   1use anyhow::Context;
   2use anyhow::Result;
   3use async_trait::async_trait;
   4use futures::executor::block_on;
   5use serde::Serialize;
   6pub use sqlx::postgres::PgPoolOptions as DbOptions;
   7use sqlx::{types::Uuid, FromRow};
   8use time::OffsetDateTime;
   9use tokio::task::yield_now;
  10
  11macro_rules! test_support {
  12    ($self:ident, { $($token:tt)* }) => {{
  13        let body = async {
  14            $($token)*
  15        };
  16        if $self.test_mode {
  17            yield_now().await;
  18            block_on(body)
  19        } else {
  20            body.await
  21        }
  22    }};
  23}
  24
  25#[async_trait]
  26pub trait Db: Send + Sync {
  27    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
  28    async fn get_all_users(&self) -> Result<Vec<User>>;
  29    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  30    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  31    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  32    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  33    async fn destroy_user(&self, id: UserId) -> Result<()>;
  34    async fn create_access_token_hash(
  35        &self,
  36        user_id: UserId,
  37        access_token_hash: &str,
  38        max_access_token_count: usize,
  39    ) -> Result<()>;
  40    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  41    #[cfg(any(test, feature = "seed-support"))]
  42    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  43    #[cfg(any(test, feature = "seed-support"))]
  44    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  45    #[cfg(any(test, feature = "seed-support"))]
  46    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  47    #[cfg(any(test, feature = "seed-support"))]
  48    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  49    #[cfg(any(test, feature = "seed-support"))]
  50    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  51    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  52    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  53        -> Result<bool>;
  54    #[cfg(any(test, feature = "seed-support"))]
  55    async fn add_channel_member(
  56        &self,
  57        channel_id: ChannelId,
  58        user_id: UserId,
  59        is_admin: bool,
  60    ) -> Result<()>;
  61    async fn create_channel_message(
  62        &self,
  63        channel_id: ChannelId,
  64        sender_id: UserId,
  65        body: &str,
  66        timestamp: OffsetDateTime,
  67        nonce: u128,
  68    ) -> Result<MessageId>;
  69    async fn get_channel_messages(
  70        &self,
  71        channel_id: ChannelId,
  72        count: usize,
  73        before_id: Option<MessageId>,
  74    ) -> Result<Vec<ChannelMessage>>;
  75    #[cfg(test)]
  76    async fn teardown(&self, name: &str, url: &str);
  77}
  78
  79pub struct PostgresDb {
  80    pool: sqlx::PgPool,
  81    test_mode: bool,
  82}
  83
  84impl PostgresDb {
  85    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  86        let pool = DbOptions::new()
  87            .max_connections(max_connections)
  88            .connect(url)
  89            .await
  90            .context("failed to connect to postgres database")?;
  91        Ok(Self {
  92            pool,
  93            test_mode: false,
  94        })
  95    }
  96}
  97
  98#[async_trait]
  99impl Db for PostgresDb {
 100    // users
 101
 102    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 103        test_support!(self, {
 104            let query = "
 105                INSERT INTO users (github_login, admin)
 106                VALUES ($1, $2)
 107                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 108                RETURNING id
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(github_login)
 112                .bind(admin)
 113                .fetch_one(&self.pool)
 114                .await
 115                .map(UserId)?)
 116        })
 117    }
 118
 119    async fn get_all_users(&self) -> Result<Vec<User>> {
 120        test_support!(self, {
 121            let query = "SELECT * FROM users ORDER BY github_login ASC";
 122            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 123        })
 124    }
 125
 126    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 127        let users = self.get_users_by_ids(vec![id]).await?;
 128        Ok(users.into_iter().next())
 129    }
 130
 131    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 132        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 133        test_support!(self, {
 134            let query = "
 135                SELECT users.*
 136                FROM users
 137                WHERE users.id = ANY ($1)
 138            ";
 139
 140            Ok(sqlx::query_as(query)
 141                .bind(&ids)
 142                .fetch_all(&self.pool)
 143                .await?)
 144        })
 145    }
 146
 147    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 148        test_support!(self, {
 149            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 150            Ok(sqlx::query_as(query)
 151                .bind(github_login)
 152                .fetch_optional(&self.pool)
 153                .await?)
 154        })
 155    }
 156
 157    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 158        test_support!(self, {
 159            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 160            Ok(sqlx::query(query)
 161                .bind(is_admin)
 162                .bind(id.0)
 163                .execute(&self.pool)
 164                .await
 165                .map(drop)?)
 166        })
 167    }
 168
 169    async fn destroy_user(&self, id: UserId) -> Result<()> {
 170        test_support!(self, {
 171            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 172            sqlx::query(query)
 173                .bind(id.0)
 174                .execute(&self.pool)
 175                .await
 176                .map(drop)?;
 177            let query = "DELETE FROM users WHERE id = $1;";
 178            Ok(sqlx::query(query)
 179                .bind(id.0)
 180                .execute(&self.pool)
 181                .await
 182                .map(drop)?)
 183        })
 184    }
 185
 186    // access tokens
 187
 188    async fn create_access_token_hash(
 189        &self,
 190        user_id: UserId,
 191        access_token_hash: &str,
 192        max_access_token_count: usize,
 193    ) -> Result<()> {
 194        test_support!(self, {
 195            let insert_query = "
 196                INSERT INTO access_tokens (user_id, hash)
 197                VALUES ($1, $2);
 198            ";
 199            let cleanup_query = "
 200                DELETE FROM access_tokens
 201                WHERE id IN (
 202                    SELECT id from access_tokens
 203                    WHERE user_id = $1
 204                    ORDER BY id DESC
 205                    OFFSET $3
 206                )
 207            ";
 208
 209            let mut tx = self.pool.begin().await?;
 210            sqlx::query(insert_query)
 211                .bind(user_id.0)
 212                .bind(access_token_hash)
 213                .execute(&mut tx)
 214                .await?;
 215            sqlx::query(cleanup_query)
 216                .bind(user_id.0)
 217                .bind(access_token_hash)
 218                .bind(max_access_token_count as u32)
 219                .execute(&mut tx)
 220                .await?;
 221            Ok(tx.commit().await?)
 222        })
 223    }
 224
 225    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 226        test_support!(self, {
 227            let query = "
 228                SELECT hash
 229                FROM access_tokens
 230                WHERE user_id = $1
 231                ORDER BY id DESC
 232            ";
 233            Ok(sqlx::query_scalar(query)
 234                .bind(user_id.0)
 235                .fetch_all(&self.pool)
 236                .await?)
 237        })
 238    }
 239
 240    // orgs
 241
 242    #[allow(unused)] // Help rust-analyzer
 243    #[cfg(any(test, feature = "seed-support"))]
 244    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 245        test_support!(self, {
 246            let query = "
 247                SELECT *
 248                FROM orgs
 249                WHERE slug = $1
 250            ";
 251            Ok(sqlx::query_as(query)
 252                .bind(slug)
 253                .fetch_optional(&self.pool)
 254                .await?)
 255        })
 256    }
 257
 258    #[cfg(any(test, feature = "seed-support"))]
 259    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 260        test_support!(self, {
 261            let query = "
 262                INSERT INTO orgs (name, slug)
 263                VALUES ($1, $2)
 264                RETURNING id
 265            ";
 266            Ok(sqlx::query_scalar(query)
 267                .bind(name)
 268                .bind(slug)
 269                .fetch_one(&self.pool)
 270                .await
 271                .map(OrgId)?)
 272        })
 273    }
 274
 275    #[cfg(any(test, feature = "seed-support"))]
 276    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 277        test_support!(self, {
 278            let query = "
 279                INSERT INTO org_memberships (org_id, user_id, admin)
 280                VALUES ($1, $2, $3)
 281                ON CONFLICT DO NOTHING
 282            ";
 283            Ok(sqlx::query(query)
 284                .bind(org_id.0)
 285                .bind(user_id.0)
 286                .bind(is_admin)
 287                .execute(&self.pool)
 288                .await
 289                .map(drop)?)
 290        })
 291    }
 292
 293    // channels
 294
 295    #[cfg(any(test, feature = "seed-support"))]
 296    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 297        test_support!(self, {
 298            let query = "
 299                INSERT INTO channels (owner_id, owner_is_user, name)
 300                VALUES ($1, false, $2)
 301                RETURNING id
 302            ";
 303            Ok(sqlx::query_scalar(query)
 304                .bind(org_id.0)
 305                .bind(name)
 306                .fetch_one(&self.pool)
 307                .await
 308                .map(ChannelId)?)
 309        })
 310    }
 311
 312    #[allow(unused)] // Help rust-analyzer
 313    #[cfg(any(test, feature = "seed-support"))]
 314    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 315        test_support!(self, {
 316            let query = "
 317                SELECT *
 318                FROM channels
 319                WHERE
 320                    channels.owner_is_user = false AND
 321                    channels.owner_id = $1
 322            ";
 323            Ok(sqlx::query_as(query)
 324                .bind(org_id.0)
 325                .fetch_all(&self.pool)
 326                .await?)
 327        })
 328    }
 329
 330    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 331        test_support!(self, {
 332            let query = "
 333                SELECT
 334                    channels.*
 335                FROM
 336                    channel_memberships, channels
 337                WHERE
 338                    channel_memberships.user_id = $1 AND
 339                    channel_memberships.channel_id = channels.id
 340            ";
 341            Ok(sqlx::query_as(query)
 342                .bind(user_id.0)
 343                .fetch_all(&self.pool)
 344                .await?)
 345        })
 346    }
 347
 348    async fn can_user_access_channel(
 349        &self,
 350        user_id: UserId,
 351        channel_id: ChannelId,
 352    ) -> Result<bool> {
 353        test_support!(self, {
 354            let query = "
 355                SELECT id
 356                FROM channel_memberships
 357                WHERE user_id = $1 AND channel_id = $2
 358                LIMIT 1
 359            ";
 360            Ok(sqlx::query_scalar::<_, i32>(query)
 361                .bind(user_id.0)
 362                .bind(channel_id.0)
 363                .fetch_optional(&self.pool)
 364                .await
 365                .map(|e| e.is_some())?)
 366        })
 367    }
 368
 369    #[cfg(any(test, feature = "seed-support"))]
 370    async fn add_channel_member(
 371        &self,
 372        channel_id: ChannelId,
 373        user_id: UserId,
 374        is_admin: bool,
 375    ) -> Result<()> {
 376        test_support!(self, {
 377            let query = "
 378                INSERT INTO channel_memberships (channel_id, user_id, admin)
 379                VALUES ($1, $2, $3)
 380                ON CONFLICT DO NOTHING
 381            ";
 382            Ok(sqlx::query(query)
 383                .bind(channel_id.0)
 384                .bind(user_id.0)
 385                .bind(is_admin)
 386                .execute(&self.pool)
 387                .await
 388                .map(drop)?)
 389        })
 390    }
 391
 392    // messages
 393
 394    async fn create_channel_message(
 395        &self,
 396        channel_id: ChannelId,
 397        sender_id: UserId,
 398        body: &str,
 399        timestamp: OffsetDateTime,
 400        nonce: u128,
 401    ) -> Result<MessageId> {
 402        test_support!(self, {
 403            let query = "
 404                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 405                VALUES ($1, $2, $3, $4, $5)
 406                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 407                RETURNING id
 408            ";
 409            Ok(sqlx::query_scalar(query)
 410                .bind(channel_id.0)
 411                .bind(sender_id.0)
 412                .bind(body)
 413                .bind(timestamp)
 414                .bind(Uuid::from_u128(nonce))
 415                .fetch_one(&self.pool)
 416                .await
 417                .map(MessageId)?)
 418        })
 419    }
 420
 421    async fn get_channel_messages(
 422        &self,
 423        channel_id: ChannelId,
 424        count: usize,
 425        before_id: Option<MessageId>,
 426    ) -> Result<Vec<ChannelMessage>> {
 427        test_support!(self, {
 428            let query = r#"
 429                SELECT * FROM (
 430                    SELECT
 431                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 432                    FROM
 433                        channel_messages
 434                    WHERE
 435                        channel_id = $1 AND
 436                        id < $2
 437                    ORDER BY id DESC
 438                    LIMIT $3
 439                ) as recent_messages
 440                ORDER BY id ASC
 441            "#;
 442            Ok(sqlx::query_as(query)
 443                .bind(channel_id.0)
 444                .bind(before_id.unwrap_or(MessageId::MAX))
 445                .bind(count as i64)
 446                .fetch_all(&self.pool)
 447                .await?)
 448        })
 449    }
 450
 451    #[cfg(test)]
 452    async fn teardown(&self, name: &str, url: &str) {
 453        use util::ResultExt;
 454
 455        test_support!(self, {
 456            let query = "
 457                SELECT pg_terminate_backend(pg_stat_activity.pid)
 458                FROM pg_stat_activity
 459                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
 460            ";
 461            sqlx::query(query)
 462                .bind(name)
 463                .execute(&self.pool)
 464                .await
 465                .log_err();
 466            self.pool.close().await;
 467            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 468                .await
 469                .log_err();
 470        })
 471    }
 472}
 473
 474macro_rules! id_type {
 475    ($name:ident) => {
 476        #[derive(
 477            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 478        )]
 479        #[sqlx(transparent)]
 480        #[serde(transparent)]
 481        pub struct $name(pub i32);
 482
 483        impl $name {
 484            #[allow(unused)]
 485            pub const MAX: Self = Self(i32::MAX);
 486
 487            #[allow(unused)]
 488            pub fn from_proto(value: u64) -> Self {
 489                Self(value as i32)
 490            }
 491
 492            #[allow(unused)]
 493            pub fn to_proto(&self) -> u64 {
 494                self.0 as u64
 495            }
 496        }
 497    };
 498}
 499
 500id_type!(UserId);
 501#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
 502pub struct User {
 503    pub id: UserId,
 504    pub github_login: String,
 505    pub admin: bool,
 506}
 507
 508id_type!(OrgId);
 509#[derive(FromRow)]
 510pub struct Org {
 511    pub id: OrgId,
 512    pub name: String,
 513    pub slug: String,
 514}
 515
 516id_type!(ChannelId);
 517#[derive(Clone, Debug, FromRow, Serialize)]
 518pub struct Channel {
 519    pub id: ChannelId,
 520    pub name: String,
 521    pub owner_id: i32,
 522    pub owner_is_user: bool,
 523}
 524
 525id_type!(MessageId);
 526#[derive(Clone, Debug, FromRow)]
 527pub struct ChannelMessage {
 528    pub id: MessageId,
 529    pub channel_id: ChannelId,
 530    pub sender_id: UserId,
 531    pub body: String,
 532    pub sent_at: OffsetDateTime,
 533    pub nonce: Uuid,
 534}
 535
 536#[cfg(test)]
 537pub mod tests {
 538    use super::*;
 539    use anyhow::anyhow;
 540    use collections::BTreeMap;
 541    use gpui::{executor::Background, TestAppContext};
 542    use lazy_static::lazy_static;
 543    use parking_lot::Mutex;
 544    use rand::prelude::*;
 545    use sqlx::{
 546        migrate::{MigrateDatabase, Migrator},
 547        Postgres,
 548    };
 549    use std::{path::Path, sync::Arc};
 550    use util::post_inc;
 551
 552    #[gpui::test]
 553    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
 554        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 555            let db = test_db.db();
 556
 557            let user = db.create_user("user", false).await.unwrap();
 558            let friend1 = db.create_user("friend-1", false).await.unwrap();
 559            let friend2 = db.create_user("friend-2", false).await.unwrap();
 560            let friend3 = db.create_user("friend-3", false).await.unwrap();
 561
 562            assert_eq!(
 563                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 564                    .await
 565                    .unwrap(),
 566                vec![
 567                    User {
 568                        id: user,
 569                        github_login: "user".to_string(),
 570                        admin: false,
 571                    },
 572                    User {
 573                        id: friend1,
 574                        github_login: "friend-1".to_string(),
 575                        admin: false,
 576                    },
 577                    User {
 578                        id: friend2,
 579                        github_login: "friend-2".to_string(),
 580                        admin: false,
 581                    },
 582                    User {
 583                        id: friend3,
 584                        github_login: "friend-3".to_string(),
 585                        admin: false,
 586                    }
 587                ]
 588            );
 589        }
 590    }
 591
 592    #[gpui::test]
 593    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
 594        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 595            let db = test_db.db();
 596            let user = db.create_user("user", false).await.unwrap();
 597            let org = db.create_org("org", "org").await.unwrap();
 598            let channel = db.create_org_channel(org, "channel").await.unwrap();
 599            for i in 0..10 {
 600                db.create_channel_message(
 601                    channel,
 602                    user,
 603                    &i.to_string(),
 604                    OffsetDateTime::now_utc(),
 605                    i,
 606                )
 607                .await
 608                .unwrap();
 609            }
 610
 611            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 612            assert_eq!(
 613                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 614                ["5", "6", "7", "8", "9"]
 615            );
 616
 617            let prev_messages = db
 618                .get_channel_messages(channel, 4, Some(messages[0].id))
 619                .await
 620                .unwrap();
 621            assert_eq!(
 622                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 623                ["1", "2", "3", "4"]
 624            );
 625        }
 626    }
 627
 628    #[gpui::test]
 629    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
 630        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 631            let db = test_db.db();
 632            let user = db.create_user("user", false).await.unwrap();
 633            let org = db.create_org("org", "org").await.unwrap();
 634            let channel = db.create_org_channel(org, "channel").await.unwrap();
 635
 636            let msg1_id = db
 637                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 638                .await
 639                .unwrap();
 640            let msg2_id = db
 641                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 642                .await
 643                .unwrap();
 644            let msg3_id = db
 645                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 646                .await
 647                .unwrap();
 648            let msg4_id = db
 649                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 650                .await
 651                .unwrap();
 652
 653            assert_ne!(msg1_id, msg2_id);
 654            assert_eq!(msg1_id, msg3_id);
 655            assert_eq!(msg2_id, msg4_id);
 656        }
 657    }
 658
 659    #[gpui::test]
 660    async fn test_create_access_tokens() {
 661        let test_db = TestDb::postgres();
 662        let db = test_db.db();
 663        let user = db.create_user("the-user", false).await.unwrap();
 664
 665        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 666        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 667        assert_eq!(
 668            db.get_access_token_hashes(user).await.unwrap(),
 669            &["h2".to_string(), "h1".to_string()]
 670        );
 671
 672        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 673        assert_eq!(
 674            db.get_access_token_hashes(user).await.unwrap(),
 675            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 676        );
 677
 678        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 679        assert_eq!(
 680            db.get_access_token_hashes(user).await.unwrap(),
 681            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 682        );
 683
 684        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 685        assert_eq!(
 686            db.get_access_token_hashes(user).await.unwrap(),
 687            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 688        );
 689    }
 690
 691    pub struct TestDb {
 692        pub db: Option<Arc<dyn Db>>,
 693        pub name: String,
 694        pub url: String,
 695    }
 696
 697    impl TestDb {
 698        pub fn postgres() -> Self {
 699            lazy_static! {
 700                static ref LOCK: Mutex<()> = Mutex::new(());
 701            }
 702
 703            let _guard = LOCK.lock();
 704            let mut rng = StdRng::from_entropy();
 705            let name = format!("zed-test-{}", rng.gen::<u128>());
 706            let url = format!("postgres://postgres@localhost/{}", name);
 707            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 708            let db = block_on(async {
 709                Postgres::create_database(&url)
 710                    .await
 711                    .expect("failed to create test db");
 712                let mut db = PostgresDb::new(&url, 5).await.unwrap();
 713                db.test_mode = true;
 714                let migrator = Migrator::new(migrations_path).await.unwrap();
 715                migrator.run(&db.pool).await.unwrap();
 716                db
 717            });
 718            Self {
 719                db: Some(Arc::new(db)),
 720                name,
 721                url,
 722            }
 723        }
 724
 725        pub fn fake(background: Arc<Background>) -> Self {
 726            Self {
 727                db: Some(Arc::new(FakeDb::new(background))),
 728                name: "fake".to_string(),
 729                url: "fake".to_string(),
 730            }
 731        }
 732
 733        pub fn db(&self) -> &Arc<dyn Db> {
 734            self.db.as_ref().unwrap()
 735        }
 736    }
 737
 738    impl Drop for TestDb {
 739        fn drop(&mut self) {
 740            if let Some(db) = self.db.take() {
 741                block_on(db.teardown(&self.name, &self.url));
 742            }
 743        }
 744    }
 745
 746    pub struct FakeDb {
 747        background: Arc<Background>,
 748        users: Mutex<BTreeMap<UserId, User>>,
 749        next_user_id: Mutex<i32>,
 750        orgs: Mutex<BTreeMap<OrgId, Org>>,
 751        next_org_id: Mutex<i32>,
 752        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
 753        channels: Mutex<BTreeMap<ChannelId, Channel>>,
 754        next_channel_id: Mutex<i32>,
 755        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
 756        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
 757        next_channel_message_id: Mutex<i32>,
 758    }
 759
 760    impl FakeDb {
 761        pub fn new(background: Arc<Background>) -> Self {
 762            Self {
 763                background,
 764                users: Default::default(),
 765                next_user_id: Mutex::new(1),
 766                orgs: Default::default(),
 767                next_org_id: Mutex::new(1),
 768                org_memberships: Default::default(),
 769                channels: Default::default(),
 770                next_channel_id: Mutex::new(1),
 771                channel_memberships: Default::default(),
 772                channel_messages: Default::default(),
 773                next_channel_message_id: Mutex::new(1),
 774            }
 775        }
 776    }
 777
 778    #[async_trait]
 779    impl Db for FakeDb {
 780        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 781            self.background.simulate_random_delay().await;
 782
 783            let mut users = self.users.lock();
 784            if let Some(user) = users
 785                .values()
 786                .find(|user| user.github_login == github_login)
 787            {
 788                Ok(user.id)
 789            } else {
 790                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
 791                users.insert(
 792                    user_id,
 793                    User {
 794                        id: user_id,
 795                        github_login: github_login.to_string(),
 796                        admin,
 797                    },
 798                );
 799                Ok(user_id)
 800            }
 801        }
 802
 803        async fn get_all_users(&self) -> Result<Vec<User>> {
 804            unimplemented!()
 805        }
 806
 807        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 808            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
 809        }
 810
 811        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 812            self.background.simulate_random_delay().await;
 813            let users = self.users.lock();
 814            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
 815        }
 816
 817        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
 818            unimplemented!()
 819        }
 820
 821        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
 822            unimplemented!()
 823        }
 824
 825        async fn destroy_user(&self, _id: UserId) -> Result<()> {
 826            unimplemented!()
 827        }
 828
 829        async fn create_access_token_hash(
 830            &self,
 831            _user_id: UserId,
 832            _access_token_hash: &str,
 833            _max_access_token_count: usize,
 834        ) -> Result<()> {
 835            unimplemented!()
 836        }
 837
 838        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
 839            unimplemented!()
 840        }
 841
 842        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
 843            unimplemented!()
 844        }
 845
 846        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 847            self.background.simulate_random_delay().await;
 848            let mut orgs = self.orgs.lock();
 849            if orgs.values().any(|org| org.slug == slug) {
 850                Err(anyhow!("org already exists"))
 851            } else {
 852                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
 853                orgs.insert(
 854                    org_id,
 855                    Org {
 856                        id: org_id,
 857                        name: name.to_string(),
 858                        slug: slug.to_string(),
 859                    },
 860                );
 861                Ok(org_id)
 862            }
 863        }
 864
 865        async fn add_org_member(
 866            &self,
 867            org_id: OrgId,
 868            user_id: UserId,
 869            is_admin: bool,
 870        ) -> Result<()> {
 871            self.background.simulate_random_delay().await;
 872            if !self.orgs.lock().contains_key(&org_id) {
 873                return Err(anyhow!("org does not exist"));
 874            }
 875            if !self.users.lock().contains_key(&user_id) {
 876                return Err(anyhow!("user does not exist"));
 877            }
 878
 879            self.org_memberships
 880                .lock()
 881                .entry((org_id, user_id))
 882                .or_insert(is_admin);
 883            Ok(())
 884        }
 885
 886        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 887            self.background.simulate_random_delay().await;
 888            if !self.orgs.lock().contains_key(&org_id) {
 889                return Err(anyhow!("org does not exist"));
 890            }
 891
 892            let mut channels = self.channels.lock();
 893            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
 894            channels.insert(
 895                channel_id,
 896                Channel {
 897                    id: channel_id,
 898                    name: name.to_string(),
 899                    owner_id: org_id.0,
 900                    owner_is_user: false,
 901                },
 902            );
 903            Ok(channel_id)
 904        }
 905
 906        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 907            self.background.simulate_random_delay().await;
 908            Ok(self
 909                .channels
 910                .lock()
 911                .values()
 912                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
 913                .cloned()
 914                .collect())
 915        }
 916
 917        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 918            self.background.simulate_random_delay().await;
 919            let channels = self.channels.lock();
 920            let memberships = self.channel_memberships.lock();
 921            Ok(channels
 922                .values()
 923                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
 924                .cloned()
 925                .collect())
 926        }
 927
 928        async fn can_user_access_channel(
 929            &self,
 930            user_id: UserId,
 931            channel_id: ChannelId,
 932        ) -> Result<bool> {
 933            self.background.simulate_random_delay().await;
 934            Ok(self
 935                .channel_memberships
 936                .lock()
 937                .contains_key(&(channel_id, user_id)))
 938        }
 939
 940        async fn add_channel_member(
 941            &self,
 942            channel_id: ChannelId,
 943            user_id: UserId,
 944            is_admin: bool,
 945        ) -> Result<()> {
 946            self.background.simulate_random_delay().await;
 947            if !self.channels.lock().contains_key(&channel_id) {
 948                return Err(anyhow!("channel does not exist"));
 949            }
 950            if !self.users.lock().contains_key(&user_id) {
 951                return Err(anyhow!("user does not exist"));
 952            }
 953
 954            self.channel_memberships
 955                .lock()
 956                .entry((channel_id, user_id))
 957                .or_insert(is_admin);
 958            Ok(())
 959        }
 960
 961        async fn create_channel_message(
 962            &self,
 963            channel_id: ChannelId,
 964            sender_id: UserId,
 965            body: &str,
 966            timestamp: OffsetDateTime,
 967            nonce: u128,
 968        ) -> Result<MessageId> {
 969            self.background.simulate_random_delay().await;
 970            if !self.channels.lock().contains_key(&channel_id) {
 971                return Err(anyhow!("channel does not exist"));
 972            }
 973            if !self.users.lock().contains_key(&sender_id) {
 974                return Err(anyhow!("user does not exist"));
 975            }
 976
 977            let mut messages = self.channel_messages.lock();
 978            if let Some(message) = messages
 979                .values()
 980                .find(|message| message.nonce.as_u128() == nonce)
 981            {
 982                Ok(message.id)
 983            } else {
 984                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
 985                messages.insert(
 986                    message_id,
 987                    ChannelMessage {
 988                        id: message_id,
 989                        channel_id,
 990                        sender_id,
 991                        body: body.to_string(),
 992                        sent_at: timestamp,
 993                        nonce: Uuid::from_u128(nonce),
 994                    },
 995                );
 996                Ok(message_id)
 997            }
 998        }
 999
1000        async fn get_channel_messages(
1001            &self,
1002            channel_id: ChannelId,
1003            count: usize,
1004            before_id: Option<MessageId>,
1005        ) -> Result<Vec<ChannelMessage>> {
1006            let mut messages = self
1007                .channel_messages
1008                .lock()
1009                .values()
1010                .rev()
1011                .filter(|message| {
1012                    message.channel_id == channel_id
1013                        && message.id < before_id.unwrap_or(MessageId::MAX)
1014                })
1015                .take(count)
1016                .cloned()
1017                .collect::<Vec<_>>();
1018            messages.sort_unstable_by_key(|message| message.id);
1019            Ok(messages)
1020        }
1021
        /// Does nothing for this fake; both the `_name` and `_url` arguments are ignored.
        async fn teardown(&self, _name: &str, _url: &str) {}
1023    }
1024}