db.rs

   1use anyhow::Context;
   2use anyhow::Result;
   3use async_std::task::{block_on, yield_now};
   4use async_trait::async_trait;
   5use serde::Serialize;
   6pub use sqlx::postgres::PgPoolOptions as DbOptions;
   7use sqlx::{types::Uuid, FromRow};
   8use time::OffsetDateTime;
   9
// Runs a block of async code as the body of a database method.
//
// The tokens are wrapped in an `async` block. When `$self.test_mode` is true
// the macro yields to the executor once and then drives the block to
// completion synchronously with `block_on`; otherwise the block is awaited
// normally. Both branches contain an `.await`, so the macro can only be
// expanded inside an async context.
//
// NOTE(review): presumably the blocking path exists so that real database I/O
// does not interleave with a deterministic test executor — confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        // The caller's statements become one future so either branch can run it.
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // Give other tasks one chance to run before blocking this thread.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}
  23
/// Storage interface for signups, users, access tokens, orgs, channels and
/// channel messages.
///
/// The per-method notes describe what the `PostgresDb` implementation in this
/// file does; other implementors may leave some methods unimplemented.
#[async_trait]
pub trait Db: Send + Sync {
    /// Inserts a signup row and returns its id.
    async fn create_signup(
        &self,
        github_login: &str,
        email_address: &str,
        about: &str,
        wants_releases: bool,
        wants_updates: bool,
        wants_community: bool,
    ) -> Result<SignupId>;
    /// Returns every signup, ordered by GitHub login.
    async fn get_all_signups(&self) -> Result<Vec<Signup>>;
    /// Deletes the signup with the given id.
    async fn destroy_signup(&self, id: SignupId) -> Result<()>;
    /// Inserts a user and returns its id; when the login already exists the
    /// existing user's id is returned instead.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user, ordered by GitHub login.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Looks up one user by id; `None` when no user matches.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up one user by GitHub login; `None` when no user matches.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the `admin` flag of the given user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user's access tokens and then the user itself.
    async fn destroy_user(&self, id: UserId) -> Result<()>;
    /// Stores a new access-token hash for the user and deletes that user's
    /// older hashes beyond the newest `max_access_token_count`.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes, newest (highest id) first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an org by slug; `None` when no org matches.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Inserts an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org; an already existing membership is left as is.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Inserts a channel owned by the given org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the given org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user has a membership in.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Reports whether the user has a membership in the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel; an already existing membership is left as is.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Inserts a message and returns its id; reusing a `nonce` returns the id
    /// of the message already stored under that nonce.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages of the channel with an id below
    /// `before_id` (or the most recent ones when it is `None`), ordered by
    /// ascending id.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup: closes the connections and drops the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, name: &str, url: &str);
}
  88
/// `Db` implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Pool that every query of this implementation executes against.
    pool: sqlx::PgPool,
    /// When true, `test_support!` drives each query with `block_on` instead
    /// of awaiting it.
    test_mode: bool,
}
  93
  94impl PostgresDb {
  95    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
  96        let pool = DbOptions::new()
  97            .max_connections(max_connections)
  98            .connect(url)
  99            .await
 100            .context("failed to connect to postgres database")?;
 101        Ok(Self {
 102            pool,
 103            test_mode: false,
 104        })
 105    }
 106}
 107
 108#[async_trait]
 109impl Db for PostgresDb {
 110    // signups
 111    async fn create_signup(
 112        &self,
 113        github_login: &str,
 114        email_address: &str,
 115        about: &str,
 116        wants_releases: bool,
 117        wants_updates: bool,
 118        wants_community: bool,
 119    ) -> Result<SignupId> {
 120        test_support!(self, {
 121            let query = "
 122                INSERT INTO signups (
 123                    github_login,
 124                    email_address,
 125                    about,
 126                    wants_releases,
 127                    wants_updates,
 128                    wants_community
 129                )
 130                VALUES ($1, $2, $3, $4, $5, $6)
 131                RETURNING id
 132            ";
 133            Ok(sqlx::query_scalar(query)
 134                .bind(github_login)
 135                .bind(email_address)
 136                .bind(about)
 137                .bind(wants_releases)
 138                .bind(wants_updates)
 139                .bind(wants_community)
 140                .fetch_one(&self.pool)
 141                .await
 142                .map(SignupId)?)
 143        })
 144    }
 145
 146    async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 147        test_support!(self, {
 148            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 149            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 150        })
 151    }
 152
 153    async fn destroy_signup(&self, id: SignupId) -> Result<()> {
 154        test_support!(self, {
 155            let query = "DELETE FROM signups WHERE id = $1";
 156            Ok(sqlx::query(query)
 157                .bind(id.0)
 158                .execute(&self.pool)
 159                .await
 160                .map(drop)?)
 161        })
 162    }
 163
 164    // users
 165
 166    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 167        test_support!(self, {
 168            let query = "
 169                INSERT INTO users (github_login, admin)
 170                VALUES ($1, $2)
 171                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 172                RETURNING id
 173            ";
 174            Ok(sqlx::query_scalar(query)
 175                .bind(github_login)
 176                .bind(admin)
 177                .fetch_one(&self.pool)
 178                .await
 179                .map(UserId)?)
 180        })
 181    }
 182
 183    async fn get_all_users(&self) -> Result<Vec<User>> {
 184        test_support!(self, {
 185            let query = "SELECT * FROM users ORDER BY github_login ASC";
 186            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 187        })
 188    }
 189
 190    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 191        let users = self.get_users_by_ids(vec![id]).await?;
 192        Ok(users.into_iter().next())
 193    }
 194
 195    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 196        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 197        test_support!(self, {
 198            let query = "
 199                SELECT users.*
 200                FROM users
 201                WHERE users.id = ANY ($1)
 202            ";
 203
 204            Ok(sqlx::query_as(query)
 205                .bind(&ids)
 206                .fetch_all(&self.pool)
 207                .await?)
 208        })
 209    }
 210
 211    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 212        test_support!(self, {
 213            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 214            Ok(sqlx::query_as(query)
 215                .bind(github_login)
 216                .fetch_optional(&self.pool)
 217                .await?)
 218        })
 219    }
 220
 221    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 222        test_support!(self, {
 223            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 224            Ok(sqlx::query(query)
 225                .bind(is_admin)
 226                .bind(id.0)
 227                .execute(&self.pool)
 228                .await
 229                .map(drop)?)
 230        })
 231    }
 232
 233    async fn destroy_user(&self, id: UserId) -> Result<()> {
 234        test_support!(self, {
 235            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 236            sqlx::query(query)
 237                .bind(id.0)
 238                .execute(&self.pool)
 239                .await
 240                .map(drop)?;
 241            let query = "DELETE FROM users WHERE id = $1;";
 242            Ok(sqlx::query(query)
 243                .bind(id.0)
 244                .execute(&self.pool)
 245                .await
 246                .map(drop)?)
 247        })
 248    }
 249
 250    // access tokens
 251
 252    async fn create_access_token_hash(
 253        &self,
 254        user_id: UserId,
 255        access_token_hash: &str,
 256        max_access_token_count: usize,
 257    ) -> Result<()> {
 258        test_support!(self, {
 259            let insert_query = "
 260                INSERT INTO access_tokens (user_id, hash)
 261                VALUES ($1, $2);
 262            ";
 263            let cleanup_query = "
 264                DELETE FROM access_tokens
 265                WHERE id IN (
 266                    SELECT id from access_tokens
 267                    WHERE user_id = $1
 268                    ORDER BY id DESC
 269                    OFFSET $3
 270                )
 271            ";
 272
 273            let mut tx = self.pool.begin().await?;
 274            sqlx::query(insert_query)
 275                .bind(user_id.0)
 276                .bind(access_token_hash)
 277                .execute(&mut tx)
 278                .await?;
 279            sqlx::query(cleanup_query)
 280                .bind(user_id.0)
 281                .bind(access_token_hash)
 282                .bind(max_access_token_count as u32)
 283                .execute(&mut tx)
 284                .await?;
 285            Ok(tx.commit().await?)
 286        })
 287    }
 288
 289    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 290        test_support!(self, {
 291            let query = "
 292                SELECT hash
 293                FROM access_tokens
 294                WHERE user_id = $1
 295                ORDER BY id DESC
 296            ";
 297            Ok(sqlx::query_scalar(query)
 298                .bind(user_id.0)
 299                .fetch_all(&self.pool)
 300                .await?)
 301        })
 302    }
 303
 304    // orgs
 305
 306    #[allow(unused)] // Help rust-analyzer
 307    #[cfg(any(test, feature = "seed-support"))]
 308    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 309        test_support!(self, {
 310            let query = "
 311                SELECT *
 312                FROM orgs
 313                WHERE slug = $1
 314            ";
 315            Ok(sqlx::query_as(query)
 316                .bind(slug)
 317                .fetch_optional(&self.pool)
 318                .await?)
 319        })
 320    }
 321
 322    #[cfg(any(test, feature = "seed-support"))]
 323    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 324        test_support!(self, {
 325            let query = "
 326                INSERT INTO orgs (name, slug)
 327                VALUES ($1, $2)
 328                RETURNING id
 329            ";
 330            Ok(sqlx::query_scalar(query)
 331                .bind(name)
 332                .bind(slug)
 333                .fetch_one(&self.pool)
 334                .await
 335                .map(OrgId)?)
 336        })
 337    }
 338
 339    #[cfg(any(test, feature = "seed-support"))]
 340    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 341        test_support!(self, {
 342            let query = "
 343                INSERT INTO org_memberships (org_id, user_id, admin)
 344                VALUES ($1, $2, $3)
 345                ON CONFLICT DO NOTHING
 346            ";
 347            Ok(sqlx::query(query)
 348                .bind(org_id.0)
 349                .bind(user_id.0)
 350                .bind(is_admin)
 351                .execute(&self.pool)
 352                .await
 353                .map(drop)?)
 354        })
 355    }
 356
 357    // channels
 358
 359    #[cfg(any(test, feature = "seed-support"))]
 360    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 361        test_support!(self, {
 362            let query = "
 363                INSERT INTO channels (owner_id, owner_is_user, name)
 364                VALUES ($1, false, $2)
 365                RETURNING id
 366            ";
 367            Ok(sqlx::query_scalar(query)
 368                .bind(org_id.0)
 369                .bind(name)
 370                .fetch_one(&self.pool)
 371                .await
 372                .map(ChannelId)?)
 373        })
 374    }
 375
 376    #[allow(unused)] // Help rust-analyzer
 377    #[cfg(any(test, feature = "seed-support"))]
 378    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 379        test_support!(self, {
 380            let query = "
 381                SELECT *
 382                FROM channels
 383                WHERE
 384                    channels.owner_is_user = false AND
 385                    channels.owner_id = $1
 386            ";
 387            Ok(sqlx::query_as(query)
 388                .bind(org_id.0)
 389                .fetch_all(&self.pool)
 390                .await?)
 391        })
 392    }
 393
 394    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 395        test_support!(self, {
 396            let query = "
 397                SELECT
 398                    channels.*
 399                FROM
 400                    channel_memberships, channels
 401                WHERE
 402                    channel_memberships.user_id = $1 AND
 403                    channel_memberships.channel_id = channels.id
 404            ";
 405            Ok(sqlx::query_as(query)
 406                .bind(user_id.0)
 407                .fetch_all(&self.pool)
 408                .await?)
 409        })
 410    }
 411
 412    async fn can_user_access_channel(
 413        &self,
 414        user_id: UserId,
 415        channel_id: ChannelId,
 416    ) -> Result<bool> {
 417        test_support!(self, {
 418            let query = "
 419                SELECT id
 420                FROM channel_memberships
 421                WHERE user_id = $1 AND channel_id = $2
 422                LIMIT 1
 423            ";
 424            Ok(sqlx::query_scalar::<_, i32>(query)
 425                .bind(user_id.0)
 426                .bind(channel_id.0)
 427                .fetch_optional(&self.pool)
 428                .await
 429                .map(|e| e.is_some())?)
 430        })
 431    }
 432
 433    #[cfg(any(test, feature = "seed-support"))]
 434    async fn add_channel_member(
 435        &self,
 436        channel_id: ChannelId,
 437        user_id: UserId,
 438        is_admin: bool,
 439    ) -> Result<()> {
 440        test_support!(self, {
 441            let query = "
 442                INSERT INTO channel_memberships (channel_id, user_id, admin)
 443                VALUES ($1, $2, $3)
 444                ON CONFLICT DO NOTHING
 445            ";
 446            Ok(sqlx::query(query)
 447                .bind(channel_id.0)
 448                .bind(user_id.0)
 449                .bind(is_admin)
 450                .execute(&self.pool)
 451                .await
 452                .map(drop)?)
 453        })
 454    }
 455
 456    // messages
 457
 458    async fn create_channel_message(
 459        &self,
 460        channel_id: ChannelId,
 461        sender_id: UserId,
 462        body: &str,
 463        timestamp: OffsetDateTime,
 464        nonce: u128,
 465    ) -> Result<MessageId> {
 466        test_support!(self, {
 467            let query = "
 468                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 469                VALUES ($1, $2, $3, $4, $5)
 470                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 471                RETURNING id
 472            ";
 473            Ok(sqlx::query_scalar(query)
 474                .bind(channel_id.0)
 475                .bind(sender_id.0)
 476                .bind(body)
 477                .bind(timestamp)
 478                .bind(Uuid::from_u128(nonce))
 479                .fetch_one(&self.pool)
 480                .await
 481                .map(MessageId)?)
 482        })
 483    }
 484
 485    async fn get_channel_messages(
 486        &self,
 487        channel_id: ChannelId,
 488        count: usize,
 489        before_id: Option<MessageId>,
 490    ) -> Result<Vec<ChannelMessage>> {
 491        test_support!(self, {
 492            let query = r#"
 493                SELECT * FROM (
 494                    SELECT
 495                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 496                    FROM
 497                        channel_messages
 498                    WHERE
 499                        channel_id = $1 AND
 500                        id < $2
 501                    ORDER BY id DESC
 502                    LIMIT $3
 503                ) as recent_messages
 504                ORDER BY id ASC
 505            "#;
 506            Ok(sqlx::query_as(query)
 507                .bind(channel_id.0)
 508                .bind(before_id.unwrap_or(MessageId::MAX))
 509                .bind(count as i64)
 510                .fetch_all(&self.pool)
 511                .await?)
 512        })
 513    }
 514
 515    #[cfg(test)]
 516    async fn teardown(&self, name: &str, url: &str) {
 517        use util::ResultExt;
 518
 519        test_support!(self, {
 520            let query = "
 521                SELECT pg_terminate_backend(pg_stat_activity.pid)
 522                FROM pg_stat_activity
 523                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
 524            ";
 525            sqlx::query(query)
 526                .bind(name)
 527                .execute(&self.pool)
 528                .await
 529                .log_err();
 530            self.pool.close().await;
 531            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 532                .await
 533                .log_err();
 534        })
 535    }
 536}
 537
// Declares a public newtype `$name` wrapping an `i32`, for use as a typed
// database id. The generated type is `Copy`, ordered, hashable, and is encoded
// by both sqlx and serde exactly like its inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest id the wrapper can hold.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64`. The `as` cast keeps only the low
            // 32 bits, so values above `i32::MAX` do not round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to a `u64`. The `as` cast sign-extends, so a
            // negative id turns into a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
 563
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Database id of the user.
    pub id: UserId,
    /// GitHub login; user creation treats it as the conflict key.
    pub github_login: String,
    /// Whether the user has the admin flag set.
    pub admin: bool,
}
 571
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    /// Database id of the org.
    pub id: OrgId,
    /// Name of the org.
    pub name: String,
    /// Slug used to look the org up.
    pub slug: String,
}
 579
id_type!(SignupId);
/// A row of the `signups` table.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    /// Database id of the signup.
    pub id: SignupId,
    /// GitHub login supplied at signup.
    pub github_login: String,
    /// Email address supplied at signup.
    pub email_address: String,
    /// Free-form "about" text supplied at signup.
    pub about: String,
    // The three preference flags are optional when read back, although
    // `create_signup` always writes a concrete value.
    // NOTE(review): presumably the columns are nullable for older rows — confirm.
    pub wants_releases: Option<bool>,
    pub wants_updates: Option<bool>,
    pub wants_community: Option<bool>,
}
 591
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    /// Database id of the channel.
    pub id: ChannelId,
    /// Name of the channel.
    pub name: String,
    /// Id of the owner. Org channels store the org id here with
    /// `owner_is_user` set to false.
    /// NOTE(review): presumably a user id when `owner_is_user` is true — confirm.
    pub owner_id: i32,
    /// Distinguishes user-owned from org-owned channels.
    pub owner_is_user: bool,
}
 600
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    /// Database id of the message; paging is done on this value.
    pub id: MessageId,
    /// Channel the message was posted to.
    pub channel_id: ChannelId,
    /// User who sent the message.
    pub sender_id: UserId,
    /// Text of the message.
    pub body: String,
    /// Send time; read back at time zone UTC by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (a `u128` stored as UUID) used to deduplicate inserts.
    pub nonce: Uuid,
}
 611
 612#[cfg(test)]
 613pub mod tests {
 614    use super::*;
 615    use anyhow::anyhow;
 616    use collections::BTreeMap;
 617    use gpui::{executor::Background, TestAppContext};
 618    use lazy_static::lazy_static;
 619    use parking_lot::Mutex;
 620    use rand::prelude::*;
 621    use sqlx::{
 622        migrate::{MigrateDatabase, Migrator},
 623        Postgres,
 624    };
 625    use std::{path::Path, sync::Arc};
 626    use util::post_inc;
 627
 628    #[gpui::test]
 629    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
 630        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 631            let db = test_db.db();
 632
 633            let user = db.create_user("user", false).await.unwrap();
 634            let friend1 = db.create_user("friend-1", false).await.unwrap();
 635            let friend2 = db.create_user("friend-2", false).await.unwrap();
 636            let friend3 = db.create_user("friend-3", false).await.unwrap();
 637
 638            assert_eq!(
 639                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 640                    .await
 641                    .unwrap(),
 642                vec![
 643                    User {
 644                        id: user,
 645                        github_login: "user".to_string(),
 646                        admin: false,
 647                    },
 648                    User {
 649                        id: friend1,
 650                        github_login: "friend-1".to_string(),
 651                        admin: false,
 652                    },
 653                    User {
 654                        id: friend2,
 655                        github_login: "friend-2".to_string(),
 656                        admin: false,
 657                    },
 658                    User {
 659                        id: friend3,
 660                        github_login: "friend-3".to_string(),
 661                        admin: false,
 662                    }
 663                ]
 664            );
 665        }
 666    }
 667
 668    #[gpui::test]
 669    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
 670        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 671            let db = test_db.db();
 672            let user = db.create_user("user", false).await.unwrap();
 673            let org = db.create_org("org", "org").await.unwrap();
 674            let channel = db.create_org_channel(org, "channel").await.unwrap();
 675            for i in 0..10 {
 676                db.create_channel_message(
 677                    channel,
 678                    user,
 679                    &i.to_string(),
 680                    OffsetDateTime::now_utc(),
 681                    i,
 682                )
 683                .await
 684                .unwrap();
 685            }
 686
 687            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 688            assert_eq!(
 689                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 690                ["5", "6", "7", "8", "9"]
 691            );
 692
 693            let prev_messages = db
 694                .get_channel_messages(channel, 4, Some(messages[0].id))
 695                .await
 696                .unwrap();
 697            assert_eq!(
 698                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 699                ["1", "2", "3", "4"]
 700            );
 701        }
 702    }
 703
 704    #[gpui::test]
 705    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
 706        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 707            let db = test_db.db();
 708            let user = db.create_user("user", false).await.unwrap();
 709            let org = db.create_org("org", "org").await.unwrap();
 710            let channel = db.create_org_channel(org, "channel").await.unwrap();
 711
 712            let msg1_id = db
 713                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 714                .await
 715                .unwrap();
 716            let msg2_id = db
 717                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 718                .await
 719                .unwrap();
 720            let msg3_id = db
 721                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 722                .await
 723                .unwrap();
 724            let msg4_id = db
 725                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 726                .await
 727                .unwrap();
 728
 729            assert_ne!(msg1_id, msg2_id);
 730            assert_eq!(msg1_id, msg3_id);
 731            assert_eq!(msg2_id, msg4_id);
 732        }
 733    }
 734
 735    #[gpui::test]
 736    async fn test_create_access_tokens() {
 737        let test_db = TestDb::postgres();
 738        let db = test_db.db();
 739        let user = db.create_user("the-user", false).await.unwrap();
 740
 741        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 742        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 743        assert_eq!(
 744            db.get_access_token_hashes(user).await.unwrap(),
 745            &["h2".to_string(), "h1".to_string()]
 746        );
 747
 748        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 749        assert_eq!(
 750            db.get_access_token_hashes(user).await.unwrap(),
 751            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 752        );
 753
 754        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 755        assert_eq!(
 756            db.get_access_token_hashes(user).await.unwrap(),
 757            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 758        );
 759
 760        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 761        assert_eq!(
 762            db.get_access_token_hashes(user).await.unwrap(),
 763            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 764        );
 765    }
 766
    /// Owns a database handle for the duration of a test and tears it down on drop.
    pub struct TestDb {
        /// The handle; wrapped in `Option` so that `Drop` can take it out.
        pub db: Option<Arc<dyn Db>>,
        /// Database name ("fake" for the in-memory backend).
        pub name: String,
        /// Connection URL ("fake" for the in-memory backend).
        pub url: String,
    }
 772
 773    impl TestDb {
 774        pub fn postgres() -> Self {
 775            lazy_static! {
 776                static ref LOCK: Mutex<()> = Mutex::new(());
 777            }
 778
 779            let _guard = LOCK.lock();
 780            let mut rng = StdRng::from_entropy();
 781            let name = format!("zed-test-{}", rng.gen::<u128>());
 782            let url = format!("postgres://postgres@localhost/{}", name);
 783            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 784            let db = block_on(async {
 785                Postgres::create_database(&url)
 786                    .await
 787                    .expect("failed to create test db");
 788                let mut db = PostgresDb::new(&url, 5).await.unwrap();
 789                db.test_mode = true;
 790                let migrator = Migrator::new(migrations_path).await.unwrap();
 791                migrator.run(&db.pool).await.unwrap();
 792                db
 793            });
 794            Self {
 795                db: Some(Arc::new(db)),
 796                name,
 797                url,
 798            }
 799        }
 800
 801        pub fn fake(background: Arc<Background>) -> Self {
 802            Self {
 803                db: Some(Arc::new(FakeDb::new(background))),
 804                name: "fake".to_string(),
 805                url: "fake".to_string(),
 806            }
 807        }
 808
 809        pub fn db(&self) -> &Arc<dyn Db> {
 810            self.db.as_ref().unwrap()
 811        }
 812    }
 813
 814    impl Drop for TestDb {
 815        fn drop(&mut self) {
 816            if let Some(db) = self.db.take() {
 817                block_on(db.teardown(&self.name, &self.url));
 818            }
 819        }
 820    }
 821
    /// In-memory `Db` implementation for tests. Each table is a mutex-guarded
    /// map, and each `next_*_id` counter supplies the next id for its table.
    pub struct FakeDb {
        /// Executor handle used to inject random delays before operations.
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        /// Keyed by (org, user); the value is a `bool`.
        /// NOTE(review): presumably the admin flag — confirm.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        /// Keyed by (channel, user); the value is a `bool`.
        /// NOTE(review): presumably the admin flag — confirm.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }
 835
 836    impl FakeDb {
 837        pub fn new(background: Arc<Background>) -> Self {
 838            Self {
 839                background,
 840                users: Default::default(),
 841                next_user_id: Mutex::new(1),
 842                orgs: Default::default(),
 843                next_org_id: Mutex::new(1),
 844                org_memberships: Default::default(),
 845                channels: Default::default(),
 846                next_channel_id: Mutex::new(1),
 847                channel_memberships: Default::default(),
 848                channel_messages: Default::default(),
 849                next_channel_message_id: Mutex::new(1),
 850            }
 851        }
 852    }
 853
 854    #[async_trait]
 855    impl Db for FakeDb {
        /// Not supported by the fake backend; calling it panics.
        async fn create_signup(
            &self,
            _github_login: &str,
            _email_address: &str,
            _about: &str,
            _wants_releases: bool,
            _wants_updates: bool,
            _wants_community: bool,
        ) -> Result<SignupId> {
            unimplemented!()
        }
 867
        /// Not supported by the fake backend; calling it panics.
        async fn get_all_signups(&self) -> Result<Vec<Signup>> {
            unimplemented!()
        }
 871
        /// Not supported by the fake backend; calling it panics.
        async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
            unimplemented!()
        }
 875
 876        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 877            self.background.simulate_random_delay().await;
 878
 879            let mut users = self.users.lock();
 880            if let Some(user) = users
 881                .values()
 882                .find(|user| user.github_login == github_login)
 883            {
 884                Ok(user.id)
 885            } else {
 886                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
 887                users.insert(
 888                    user_id,
 889                    User {
 890                        id: user_id,
 891                        github_login: github_login.to_string(),
 892                        admin,
 893                    },
 894                );
 895                Ok(user_id)
 896            }
 897        }
 898
        /// Not supported by the fake backend; calling it panics.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
 902
 903        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 904            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
 905        }
 906
 907        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 908            self.background.simulate_random_delay().await;
 909            let users = self.users.lock();
 910            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
 911        }
 912
 913        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
 914            unimplemented!()
 915        }
 916
        /// Changing a user's admin flag is not modelled by `FakeDb`.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()`; both arguments are ignored.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
 920
        /// User deletion is not modelled by `FakeDb`.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()`; `_id` is ignored.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
 924
        /// Access-token storage is not modelled by `FakeDb`.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()`; all arguments are ignored.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
 933
        /// Access-token lookup is not modelled by `FakeDb`.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()`; `_user_id` is ignored.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
 937
        /// Org lookup by slug is not modelled by `FakeDb`.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()`; `_slug` is ignored.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
 941
 942        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 943            self.background.simulate_random_delay().await;
 944            let mut orgs = self.orgs.lock();
 945            if orgs.values().any(|org| org.slug == slug) {
 946                Err(anyhow!("org already exists"))
 947            } else {
 948                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
 949                orgs.insert(
 950                    org_id,
 951                    Org {
 952                        id: org_id,
 953                        name: name.to_string(),
 954                        slug: slug.to_string(),
 955                    },
 956                );
 957                Ok(org_id)
 958            }
 959        }
 960
 961        async fn add_org_member(
 962            &self,
 963            org_id: OrgId,
 964            user_id: UserId,
 965            is_admin: bool,
 966        ) -> Result<()> {
 967            self.background.simulate_random_delay().await;
 968            if !self.orgs.lock().contains_key(&org_id) {
 969                return Err(anyhow!("org does not exist"));
 970            }
 971            if !self.users.lock().contains_key(&user_id) {
 972                return Err(anyhow!("user does not exist"));
 973            }
 974
 975            self.org_memberships
 976                .lock()
 977                .entry((org_id, user_id))
 978                .or_insert(is_admin);
 979            Ok(())
 980        }
 981
 982        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 983            self.background.simulate_random_delay().await;
 984            if !self.orgs.lock().contains_key(&org_id) {
 985                return Err(anyhow!("org does not exist"));
 986            }
 987
 988            let mut channels = self.channels.lock();
 989            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
 990            channels.insert(
 991                channel_id,
 992                Channel {
 993                    id: channel_id,
 994                    name: name.to_string(),
 995                    owner_id: org_id.0,
 996                    owner_is_user: false,
 997                },
 998            );
 999            Ok(channel_id)
1000        }
1001
1002        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1003            self.background.simulate_random_delay().await;
1004            Ok(self
1005                .channels
1006                .lock()
1007                .values()
1008                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1009                .cloned()
1010                .collect())
1011        }
1012
1013        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1014            self.background.simulate_random_delay().await;
1015            let channels = self.channels.lock();
1016            let memberships = self.channel_memberships.lock();
1017            Ok(channels
1018                .values()
1019                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1020                .cloned()
1021                .collect())
1022        }
1023
1024        async fn can_user_access_channel(
1025            &self,
1026            user_id: UserId,
1027            channel_id: ChannelId,
1028        ) -> Result<bool> {
1029            self.background.simulate_random_delay().await;
1030            Ok(self
1031                .channel_memberships
1032                .lock()
1033                .contains_key(&(channel_id, user_id)))
1034        }
1035
1036        async fn add_channel_member(
1037            &self,
1038            channel_id: ChannelId,
1039            user_id: UserId,
1040            is_admin: bool,
1041        ) -> Result<()> {
1042            self.background.simulate_random_delay().await;
1043            if !self.channels.lock().contains_key(&channel_id) {
1044                return Err(anyhow!("channel does not exist"));
1045            }
1046            if !self.users.lock().contains_key(&user_id) {
1047                return Err(anyhow!("user does not exist"));
1048            }
1049
1050            self.channel_memberships
1051                .lock()
1052                .entry((channel_id, user_id))
1053                .or_insert(is_admin);
1054            Ok(())
1055        }
1056
1057        async fn create_channel_message(
1058            &self,
1059            channel_id: ChannelId,
1060            sender_id: UserId,
1061            body: &str,
1062            timestamp: OffsetDateTime,
1063            nonce: u128,
1064        ) -> Result<MessageId> {
1065            self.background.simulate_random_delay().await;
1066            if !self.channels.lock().contains_key(&channel_id) {
1067                return Err(anyhow!("channel does not exist"));
1068            }
1069            if !self.users.lock().contains_key(&sender_id) {
1070                return Err(anyhow!("user does not exist"));
1071            }
1072
1073            let mut messages = self.channel_messages.lock();
1074            if let Some(message) = messages
1075                .values()
1076                .find(|message| message.nonce.as_u128() == nonce)
1077            {
1078                Ok(message.id)
1079            } else {
1080                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1081                messages.insert(
1082                    message_id,
1083                    ChannelMessage {
1084                        id: message_id,
1085                        channel_id,
1086                        sender_id,
1087                        body: body.to_string(),
1088                        sent_at: timestamp,
1089                        nonce: Uuid::from_u128(nonce),
1090                    },
1091                );
1092                Ok(message_id)
1093            }
1094        }
1095
1096        async fn get_channel_messages(
1097            &self,
1098            channel_id: ChannelId,
1099            count: usize,
1100            before_id: Option<MessageId>,
1101        ) -> Result<Vec<ChannelMessage>> {
1102            let mut messages = self
1103                .channel_messages
1104                .lock()
1105                .values()
1106                .rev()
1107                .filter(|message| {
1108                    message.channel_id == channel_id
1109                        && message.id < before_id.unwrap_or(MessageId::MAX)
1110                })
1111                .take(count)
1112                .cloned()
1113                .collect::<Vec<_>>();
1114            messages.sort_unstable_by_key(|message| message.id);
1115            Ok(messages)
1116        }
1117
        /// No-op for `FakeDb`: the body is empty and both arguments are
        /// ignored.
        async fn teardown(&self, _name: &str, _url: &str) {}
1119    }
1120}