db.rs

   1use anyhow::Context;
   2use anyhow::Result;
   3pub use async_sqlx_session::PostgresSessionStore as SessionStore;
   4use async_std::task::{block_on, yield_now};
   5use async_trait::async_trait;
   6use serde::Serialize;
   7pub use sqlx::postgres::PgPoolOptions as DbOptions;
   8use sqlx::{types::Uuid, FromRow};
   9use time::OffsetDateTime;
  10
  11macro_rules! test_support {
  12    ($self:ident, { $($token:tt)* }) => {{
  13        let body = async {
  14            $($token)*
  15        };
  16        if $self.test_mode {
  17            yield_now().await;
  18            block_on(body)
  19        } else {
  20            body.await
  21        }
  22    }};
  23}
  24
/// Persistence interface for signups, users, access tokens, orgs, channels and
/// channel messages.
///
/// Implemented by [`PostgresDb`]; tests also use an in-memory `FakeDb`. Methods
/// gated on `#[cfg(any(test, feature = "seed-support"))]` only exist in test
/// builds or when the `seed-support` feature is enabled.
#[async_trait]
pub trait Db: Send + Sync {
    /// Records a new signup and returns its id.
    async fn create_signup(
        &self,
        github_login: &str,
        email_address: &str,
        about: &str,
        wants_releases: bool,
        wants_updates: bool,
        wants_community: bool,
    ) -> Result<SignupId>;
    /// Returns every signup.
    async fn get_all_signups(&self) -> Result<Vec<Signup>>;
    /// Deletes the signup with the given id.
    async fn destroy_signup(&self, id: SignupId) -> Result<()>;
    /// Creates a user, or returns the id of the existing user with the same
    /// `github_login`. In the latter case the existing `admin` flag is left unchanged.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Looks up a single user by id, returning `None` if it does not exist.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up several users by id. Ids with no matching user are skipped.
    ///
    /// NOTE(review): the Postgres implementation has no `ORDER BY`, so callers
    /// should not rely on the order of the returned users.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up a user by GitHub login, returning `None` if there is none.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the user's `admin` flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user together with the user's access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;
    /// Stores a new access-token hash for the user, then prunes the user's oldest
    /// hashes so that at most `max_access_token_count` remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes, newest (highest row id) first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Finds the org with the given slug, if any.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` to the org; an existing membership is left untouched
    /// (`ON CONFLICT DO NOTHING` in the Postgres implementation).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns every channel the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns `true` when the user has a membership in the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` to the channel; an existing membership is left untouched
    /// (`ON CONFLICT DO NOTHING` in the Postgres implementation).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id. If a message with the same
    /// `nonce` already exists, that message's id is returned instead, which makes
    /// resending the same message idempotent.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the most recent messages in the channel whose id is
    /// below `before_id` (unbounded when `None`), ordered oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: releases the database named `name` at `url` after a test.
    #[cfg(test)]
    async fn teardown(&self, name: &str, url: &str);
}
  89
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Connection pool shared by every query.
    pool: sqlx::PgPool,
    /// When set, queries go through the `test_support!` macro's `block_on` path
    /// instead of being awaited directly. The test harness enables this.
    test_mode: bool,
}
  94
  95impl PostgresDb {
  96    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
  97        let pool = DbOptions::new()
  98            .max_connections(max_connections)
  99            .connect(url)
 100            .await
 101            .context("failed to connect to postgres database")?;
 102        Ok(Self {
 103            pool,
 104            test_mode: false,
 105        })
 106    }
 107}
 108
 109#[async_trait]
 110impl Db for PostgresDb {
 111    // signups
 112    async fn create_signup(
 113        &self,
 114        github_login: &str,
 115        email_address: &str,
 116        about: &str,
 117        wants_releases: bool,
 118        wants_updates: bool,
 119        wants_community: bool,
 120    ) -> Result<SignupId> {
 121        test_support!(self, {
 122            let query = "
 123                INSERT INTO signups (
 124                    github_login,
 125                    email_address,
 126                    about,
 127                    wants_releases,
 128                    wants_updates,
 129                    wants_community
 130                )
 131                VALUES ($1, $2, $3, $4, $5, $6)
 132                RETURNING id
 133            ";
 134            Ok(sqlx::query_scalar(query)
 135                .bind(github_login)
 136                .bind(email_address)
 137                .bind(about)
 138                .bind(wants_releases)
 139                .bind(wants_updates)
 140                .bind(wants_community)
 141                .fetch_one(&self.pool)
 142                .await
 143                .map(SignupId)?)
 144        })
 145    }
 146
 147    async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 148        test_support!(self, {
 149            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 150            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 151        })
 152    }
 153
 154    async fn destroy_signup(&self, id: SignupId) -> Result<()> {
 155        test_support!(self, {
 156            let query = "DELETE FROM signups WHERE id = $1";
 157            Ok(sqlx::query(query)
 158                .bind(id.0)
 159                .execute(&self.pool)
 160                .await
 161                .map(drop)?)
 162        })
 163    }
 164
 165    // users
 166
 167    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 168        test_support!(self, {
 169            let query = "
 170                INSERT INTO users (github_login, admin)
 171                VALUES ($1, $2)
 172                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 173                RETURNING id
 174            ";
 175            Ok(sqlx::query_scalar(query)
 176                .bind(github_login)
 177                .bind(admin)
 178                .fetch_one(&self.pool)
 179                .await
 180                .map(UserId)?)
 181        })
 182    }
 183
 184    async fn get_all_users(&self) -> Result<Vec<User>> {
 185        test_support!(self, {
 186            let query = "SELECT * FROM users ORDER BY github_login ASC";
 187            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 188        })
 189    }
 190
 191    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 192        let users = self.get_users_by_ids(vec![id]).await?;
 193        Ok(users.into_iter().next())
 194    }
 195
 196    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 197        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 198        test_support!(self, {
 199            let query = "
 200                SELECT users.*
 201                FROM users
 202                WHERE users.id = ANY ($1)
 203            ";
 204
 205            Ok(sqlx::query_as(query)
 206                .bind(&ids)
 207                .fetch_all(&self.pool)
 208                .await?)
 209        })
 210    }
 211
 212    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 213        test_support!(self, {
 214            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 215            Ok(sqlx::query_as(query)
 216                .bind(github_login)
 217                .fetch_optional(&self.pool)
 218                .await?)
 219        })
 220    }
 221
 222    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 223        test_support!(self, {
 224            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 225            Ok(sqlx::query(query)
 226                .bind(is_admin)
 227                .bind(id.0)
 228                .execute(&self.pool)
 229                .await
 230                .map(drop)?)
 231        })
 232    }
 233
 234    async fn destroy_user(&self, id: UserId) -> Result<()> {
 235        test_support!(self, {
 236            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 237            sqlx::query(query)
 238                .bind(id.0)
 239                .execute(&self.pool)
 240                .await
 241                .map(drop)?;
 242            let query = "DELETE FROM users WHERE id = $1;";
 243            Ok(sqlx::query(query)
 244                .bind(id.0)
 245                .execute(&self.pool)
 246                .await
 247                .map(drop)?)
 248        })
 249    }
 250
 251    // access tokens
 252
 253    async fn create_access_token_hash(
 254        &self,
 255        user_id: UserId,
 256        access_token_hash: &str,
 257        max_access_token_count: usize,
 258    ) -> Result<()> {
 259        test_support!(self, {
 260            let insert_query = "
 261                INSERT INTO access_tokens (user_id, hash)
 262                VALUES ($1, $2);
 263            ";
 264            let cleanup_query = "
 265                DELETE FROM access_tokens
 266                WHERE id IN (
 267                    SELECT id from access_tokens
 268                    WHERE user_id = $1
 269                    ORDER BY id DESC
 270                    OFFSET $3
 271                )
 272            ";
 273
 274            let mut tx = self.pool.begin().await?;
 275            sqlx::query(insert_query)
 276                .bind(user_id.0)
 277                .bind(access_token_hash)
 278                .execute(&mut tx)
 279                .await?;
 280            sqlx::query(cleanup_query)
 281                .bind(user_id.0)
 282                .bind(access_token_hash)
 283                .bind(max_access_token_count as u32)
 284                .execute(&mut tx)
 285                .await?;
 286            Ok(tx.commit().await?)
 287        })
 288    }
 289
 290    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 291        test_support!(self, {
 292            let query = "
 293                SELECT hash
 294                FROM access_tokens
 295                WHERE user_id = $1
 296                ORDER BY id DESC
 297            ";
 298            Ok(sqlx::query_scalar(query)
 299                .bind(user_id.0)
 300                .fetch_all(&self.pool)
 301                .await?)
 302        })
 303    }
 304
 305    // orgs
 306
 307    #[allow(unused)] // Help rust-analyzer
 308    #[cfg(any(test, feature = "seed-support"))]
 309    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 310        test_support!(self, {
 311            let query = "
 312                SELECT *
 313                FROM orgs
 314                WHERE slug = $1
 315            ";
 316            Ok(sqlx::query_as(query)
 317                .bind(slug)
 318                .fetch_optional(&self.pool)
 319                .await?)
 320        })
 321    }
 322
 323    #[cfg(any(test, feature = "seed-support"))]
 324    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 325        test_support!(self, {
 326            let query = "
 327                INSERT INTO orgs (name, slug)
 328                VALUES ($1, $2)
 329                RETURNING id
 330            ";
 331            Ok(sqlx::query_scalar(query)
 332                .bind(name)
 333                .bind(slug)
 334                .fetch_one(&self.pool)
 335                .await
 336                .map(OrgId)?)
 337        })
 338    }
 339
 340    #[cfg(any(test, feature = "seed-support"))]
 341    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 342        test_support!(self, {
 343            let query = "
 344                INSERT INTO org_memberships (org_id, user_id, admin)
 345                VALUES ($1, $2, $3)
 346                ON CONFLICT DO NOTHING
 347            ";
 348            Ok(sqlx::query(query)
 349                .bind(org_id.0)
 350                .bind(user_id.0)
 351                .bind(is_admin)
 352                .execute(&self.pool)
 353                .await
 354                .map(drop)?)
 355        })
 356    }
 357
 358    // channels
 359
 360    #[cfg(any(test, feature = "seed-support"))]
 361    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 362        test_support!(self, {
 363            let query = "
 364                INSERT INTO channels (owner_id, owner_is_user, name)
 365                VALUES ($1, false, $2)
 366                RETURNING id
 367            ";
 368            Ok(sqlx::query_scalar(query)
 369                .bind(org_id.0)
 370                .bind(name)
 371                .fetch_one(&self.pool)
 372                .await
 373                .map(ChannelId)?)
 374        })
 375    }
 376
 377    #[allow(unused)] // Help rust-analyzer
 378    #[cfg(any(test, feature = "seed-support"))]
 379    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 380        test_support!(self, {
 381            let query = "
 382                SELECT *
 383                FROM channels
 384                WHERE
 385                    channels.owner_is_user = false AND
 386                    channels.owner_id = $1
 387            ";
 388            Ok(sqlx::query_as(query)
 389                .bind(org_id.0)
 390                .fetch_all(&self.pool)
 391                .await?)
 392        })
 393    }
 394
 395    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 396        test_support!(self, {
 397            let query = "
 398                SELECT
 399                    channels.*
 400                FROM
 401                    channel_memberships, channels
 402                WHERE
 403                    channel_memberships.user_id = $1 AND
 404                    channel_memberships.channel_id = channels.id
 405            ";
 406            Ok(sqlx::query_as(query)
 407                .bind(user_id.0)
 408                .fetch_all(&self.pool)
 409                .await?)
 410        })
 411    }
 412
 413    async fn can_user_access_channel(
 414        &self,
 415        user_id: UserId,
 416        channel_id: ChannelId,
 417    ) -> Result<bool> {
 418        test_support!(self, {
 419            let query = "
 420                SELECT id
 421                FROM channel_memberships
 422                WHERE user_id = $1 AND channel_id = $2
 423                LIMIT 1
 424            ";
 425            Ok(sqlx::query_scalar::<_, i32>(query)
 426                .bind(user_id.0)
 427                .bind(channel_id.0)
 428                .fetch_optional(&self.pool)
 429                .await
 430                .map(|e| e.is_some())?)
 431        })
 432    }
 433
 434    #[cfg(any(test, feature = "seed-support"))]
 435    async fn add_channel_member(
 436        &self,
 437        channel_id: ChannelId,
 438        user_id: UserId,
 439        is_admin: bool,
 440    ) -> Result<()> {
 441        test_support!(self, {
 442            let query = "
 443                INSERT INTO channel_memberships (channel_id, user_id, admin)
 444                VALUES ($1, $2, $3)
 445                ON CONFLICT DO NOTHING
 446            ";
 447            Ok(sqlx::query(query)
 448                .bind(channel_id.0)
 449                .bind(user_id.0)
 450                .bind(is_admin)
 451                .execute(&self.pool)
 452                .await
 453                .map(drop)?)
 454        })
 455    }
 456
 457    // messages
 458
 459    async fn create_channel_message(
 460        &self,
 461        channel_id: ChannelId,
 462        sender_id: UserId,
 463        body: &str,
 464        timestamp: OffsetDateTime,
 465        nonce: u128,
 466    ) -> Result<MessageId> {
 467        test_support!(self, {
 468            let query = "
 469                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 470                VALUES ($1, $2, $3, $4, $5)
 471                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 472                RETURNING id
 473            ";
 474            Ok(sqlx::query_scalar(query)
 475                .bind(channel_id.0)
 476                .bind(sender_id.0)
 477                .bind(body)
 478                .bind(timestamp)
 479                .bind(Uuid::from_u128(nonce))
 480                .fetch_one(&self.pool)
 481                .await
 482                .map(MessageId)?)
 483        })
 484    }
 485
 486    async fn get_channel_messages(
 487        &self,
 488        channel_id: ChannelId,
 489        count: usize,
 490        before_id: Option<MessageId>,
 491    ) -> Result<Vec<ChannelMessage>> {
 492        test_support!(self, {
 493            let query = r#"
 494                SELECT * FROM (
 495                    SELECT
 496                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 497                    FROM
 498                        channel_messages
 499                    WHERE
 500                        channel_id = $1 AND
 501                        id < $2
 502                    ORDER BY id DESC
 503                    LIMIT $3
 504                ) as recent_messages
 505                ORDER BY id ASC
 506            "#;
 507            Ok(sqlx::query_as(query)
 508                .bind(channel_id.0)
 509                .bind(before_id.unwrap_or(MessageId::MAX))
 510                .bind(count as i64)
 511                .fetch_all(&self.pool)
 512                .await?)
 513        })
 514    }
 515
 516    #[cfg(test)]
 517    async fn teardown(&self, name: &str, url: &str) {
 518        use util::ResultExt;
 519
 520        test_support!(self, {
 521            let query = "
 522                SELECT pg_terminate_backend(pg_stat_activity.pid)
 523                FROM pg_stat_activity
 524                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
 525            ";
 526            sqlx::query(query)
 527                .bind(name)
 528                .execute(&self.pool)
 529                .await
 530                .log_err();
 531            self.pool.close().await;
 532            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 533                .await
 534                .log_err();
 535        })
 536    }
 537}
 538
 539macro_rules! id_type {
 540    ($name:ident) => {
 541        #[derive(
 542            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 543        )]
 544        #[sqlx(transparent)]
 545        #[serde(transparent)]
 546        pub struct $name(pub i32);
 547
 548        impl $name {
 549            #[allow(unused)]
 550            pub const MAX: Self = Self(i32::MAX);
 551
 552            #[allow(unused)]
 553            pub fn from_proto(value: u64) -> Self {
 554                Self(value as i32)
 555            }
 556
 557            #[allow(unused)]
 558            pub fn to_proto(&self) -> u64 {
 559                self.0 as u64
 560            }
 561        }
 562    };
 563}
 564
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub username; `create_user` treats it as unique (`ON CONFLICT (github_login)`).
    pub github_login: String,
    /// Admin flag, changed via `Db::set_user_is_admin`.
    pub admin: bool,
}
 572
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    /// Org name as passed to `Db::create_org`.
    pub name: String,
    /// Slug looked up by `Db::find_org_by_slug`.
    pub slug: String,
}
 580
id_type!(SignupId);
/// A row of the `signups` table.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
    // The preference columns are read as `Option<bool>` even though
    // `Db::create_signup` always writes a value. NOTE(review): presumably the
    // columns are nullable (e.g. for older rows) — confirm against the migrations.
    pub wants_releases: Option<bool>,
    pub wants_updates: Option<bool>,
    pub wants_community: Option<bool>,
}
 592
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owning org when `owner_is_user` is false (see `create_org_channel`).
    /// NOTE(review): presumably a user id when `owner_is_user` is true — confirm.
    pub owner_id: i32,
    /// Discriminates what kind of entity `owner_id` refers to.
    pub owner_is_user: bool,
}
 601
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (a `u128` stored as a UUID) used to de-duplicate resends.
    pub nonce: Uuid,
}
 612
 613#[cfg(test)]
 614pub mod tests {
 615    use super::*;
 616    use anyhow::anyhow;
 617    use collections::BTreeMap;
 618    use gpui::{executor::Background, TestAppContext};
 619    use lazy_static::lazy_static;
 620    use parking_lot::Mutex;
 621    use rand::prelude::*;
 622    use sqlx::{
 623        migrate::{MigrateDatabase, Migrator},
 624        Postgres,
 625    };
 626    use std::{path::Path, sync::Arc};
 627    use util::post_inc;
 628
 629    #[gpui::test]
 630    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
 631        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 632            let db = test_db.db();
 633
 634            let user = db.create_user("user", false).await.unwrap();
 635            let friend1 = db.create_user("friend-1", false).await.unwrap();
 636            let friend2 = db.create_user("friend-2", false).await.unwrap();
 637            let friend3 = db.create_user("friend-3", false).await.unwrap();
 638
 639            assert_eq!(
 640                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 641                    .await
 642                    .unwrap(),
 643                vec![
 644                    User {
 645                        id: user,
 646                        github_login: "user".to_string(),
 647                        admin: false,
 648                    },
 649                    User {
 650                        id: friend1,
 651                        github_login: "friend-1".to_string(),
 652                        admin: false,
 653                    },
 654                    User {
 655                        id: friend2,
 656                        github_login: "friend-2".to_string(),
 657                        admin: false,
 658                    },
 659                    User {
 660                        id: friend3,
 661                        github_login: "friend-3".to_string(),
 662                        admin: false,
 663                    }
 664                ]
 665            );
 666        }
 667    }
 668
 669    #[gpui::test]
 670    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
 671        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 672            let db = test_db.db();
 673            let user = db.create_user("user", false).await.unwrap();
 674            let org = db.create_org("org", "org").await.unwrap();
 675            let channel = db.create_org_channel(org, "channel").await.unwrap();
 676            for i in 0..10 {
 677                db.create_channel_message(
 678                    channel,
 679                    user,
 680                    &i.to_string(),
 681                    OffsetDateTime::now_utc(),
 682                    i,
 683                )
 684                .await
 685                .unwrap();
 686            }
 687
 688            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 689            assert_eq!(
 690                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 691                ["5", "6", "7", "8", "9"]
 692            );
 693
 694            let prev_messages = db
 695                .get_channel_messages(channel, 4, Some(messages[0].id))
 696                .await
 697                .unwrap();
 698            assert_eq!(
 699                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 700                ["1", "2", "3", "4"]
 701            );
 702        }
 703    }
 704
 705    #[gpui::test]
 706    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
 707        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
 708            let db = test_db.db();
 709            let user = db.create_user("user", false).await.unwrap();
 710            let org = db.create_org("org", "org").await.unwrap();
 711            let channel = db.create_org_channel(org, "channel").await.unwrap();
 712
 713            let msg1_id = db
 714                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 715                .await
 716                .unwrap();
 717            let msg2_id = db
 718                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 719                .await
 720                .unwrap();
 721            let msg3_id = db
 722                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 723                .await
 724                .unwrap();
 725            let msg4_id = db
 726                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 727                .await
 728                .unwrap();
 729
 730            assert_ne!(msg1_id, msg2_id);
 731            assert_eq!(msg1_id, msg3_id);
 732            assert_eq!(msg2_id, msg4_id);
 733        }
 734    }
 735
 736    #[gpui::test]
 737    async fn test_create_access_tokens() {
 738        let test_db = TestDb::postgres();
 739        let db = test_db.db();
 740        let user = db.create_user("the-user", false).await.unwrap();
 741
 742        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 743        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 744        assert_eq!(
 745            db.get_access_token_hashes(user).await.unwrap(),
 746            &["h2".to_string(), "h1".to_string()]
 747        );
 748
 749        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 750        assert_eq!(
 751            db.get_access_token_hashes(user).await.unwrap(),
 752            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 753        );
 754
 755        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 756        assert_eq!(
 757            db.get_access_token_hashes(user).await.unwrap(),
 758            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 759        );
 760
 761        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 762        assert_eq!(
 763            db.get_access_token_hashes(user).await.unwrap(),
 764            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 765        );
 766    }
 767
    /// Per-test database handle; its `Drop` impl tears the database down.
    pub struct TestDb {
        /// The database under test. `Drop` takes it, leaving `None`.
        pub db: Option<Arc<dyn Db>>,
        /// Database name (`"fake"` for the in-memory fake).
        pub name: String,
        /// Connection URL (`"fake"` for the in-memory fake).
        pub url: String,
    }
 773
 774    impl TestDb {
 775        pub fn postgres() -> Self {
 776            lazy_static! {
 777                static ref LOCK: Mutex<()> = Mutex::new(());
 778            }
 779
 780            let _guard = LOCK.lock();
 781            let mut rng = StdRng::from_entropy();
 782            let name = format!("zed-test-{}", rng.gen::<u128>());
 783            let url = format!("postgres://postgres@localhost/{}", name);
 784            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 785            let db = block_on(async {
 786                Postgres::create_database(&url)
 787                    .await
 788                    .expect("failed to create test db");
 789                let mut db = PostgresDb::new(&url, 5).await.unwrap();
 790                db.test_mode = true;
 791                let migrator = Migrator::new(migrations_path).await.unwrap();
 792                migrator.run(&db.pool).await.unwrap();
 793                db
 794            });
 795            Self {
 796                db: Some(Arc::new(db)),
 797                name,
 798                url,
 799            }
 800        }
 801
 802        pub fn fake(background: Arc<Background>) -> Self {
 803            Self {
 804                db: Some(Arc::new(FakeDb::new(background))),
 805                name: "fake".to_string(),
 806                url: "fake".to_string(),
 807            }
 808        }
 809
 810        pub fn db(&self) -> &Arc<dyn Db> {
 811            self.db.as_ref().unwrap()
 812        }
 813    }
 814
 815    impl Drop for TestDb {
 816        fn drop(&mut self) {
 817            if let Some(db) = self.db.take() {
 818                block_on(db.teardown(&self.name, &self.url));
 819            }
 820        }
 821    }
 822
    /// In-memory [`Db`] implementation for tests.
    ///
    /// Each "table" is a `BTreeMap` behind its own mutex. Each `next_*_id` counter is
    /// advanced with `post_inc` to allocate new ids.
    pub struct FakeDb {
        /// Executor used to inject random delays (`simulate_random_delay`) into operations.
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        // NOTE(review): the `bool` value is presumably the membership's admin flag — confirm.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        // NOTE(review): the `bool` value is presumably the membership's admin flag — confirm.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }
 836
 837    impl FakeDb {
 838        pub fn new(background: Arc<Background>) -> Self {
 839            Self {
 840                background,
 841                users: Default::default(),
 842                next_user_id: Mutex::new(1),
 843                orgs: Default::default(),
 844                next_org_id: Mutex::new(1),
 845                org_memberships: Default::default(),
 846                channels: Default::default(),
 847                next_channel_id: Mutex::new(1),
 848                channel_memberships: Default::default(),
 849                channel_messages: Default::default(),
 850                next_channel_message_id: Mutex::new(1),
 851            }
 852        }
 853    }
 854
 855    #[async_trait]
 856    impl Db for FakeDb {
 857        async fn create_signup(
 858            &self,
 859            _github_login: &str,
 860            _email_address: &str,
 861            _about: &str,
 862            _wants_releases: bool,
 863            _wants_updates: bool,
 864            _wants_community: bool,
 865        ) -> Result<SignupId> {
 866            unimplemented!()
 867        }
 868
 869        async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 870            unimplemented!()
 871        }
 872
 873        async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
 874            unimplemented!()
 875        }
 876
 877        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 878            self.background.simulate_random_delay().await;
 879
 880            let mut users = self.users.lock();
 881            if let Some(user) = users
 882                .values()
 883                .find(|user| user.github_login == github_login)
 884            {
 885                Ok(user.id)
 886            } else {
 887                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
 888                users.insert(
 889                    user_id,
 890                    User {
 891                        id: user_id,
 892                        github_login: github_login.to_string(),
 893                        admin,
 894                    },
 895                );
 896                Ok(user_id)
 897            }
 898        }
 899
 900        async fn get_all_users(&self) -> Result<Vec<User>> {
 901            unimplemented!()
 902        }
 903
 904        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 905            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
 906        }
 907
 908        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 909            self.background.simulate_random_delay().await;
 910            let users = self.users.lock();
 911            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
 912        }
 913
 914        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
 915            unimplemented!()
 916        }
 917
 918        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
 919            unimplemented!()
 920        }
 921
 922        async fn destroy_user(&self, _id: UserId) -> Result<()> {
 923            unimplemented!()
 924        }
 925
 926        async fn create_access_token_hash(
 927            &self,
 928            _user_id: UserId,
 929            _access_token_hash: &str,
 930            _max_access_token_count: usize,
 931        ) -> Result<()> {
 932            unimplemented!()
 933        }
 934
 935        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
 936            unimplemented!()
 937        }
 938
 939        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
 940            unimplemented!()
 941        }
 942
 943        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 944            self.background.simulate_random_delay().await;
 945            let mut orgs = self.orgs.lock();
 946            if orgs.values().any(|org| org.slug == slug) {
 947                Err(anyhow!("org already exists"))
 948            } else {
 949                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
 950                orgs.insert(
 951                    org_id,
 952                    Org {
 953                        id: org_id,
 954                        name: name.to_string(),
 955                        slug: slug.to_string(),
 956                    },
 957                );
 958                Ok(org_id)
 959            }
 960        }
 961
 962        async fn add_org_member(
 963            &self,
 964            org_id: OrgId,
 965            user_id: UserId,
 966            is_admin: bool,
 967        ) -> Result<()> {
 968            self.background.simulate_random_delay().await;
 969            if !self.orgs.lock().contains_key(&org_id) {
 970                return Err(anyhow!("org does not exist"));
 971            }
 972            if !self.users.lock().contains_key(&user_id) {
 973                return Err(anyhow!("user does not exist"));
 974            }
 975
 976            self.org_memberships
 977                .lock()
 978                .entry((org_id, user_id))
 979                .or_insert(is_admin);
 980            Ok(())
 981        }
 982
 983        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 984            self.background.simulate_random_delay().await;
 985            if !self.orgs.lock().contains_key(&org_id) {
 986                return Err(anyhow!("org does not exist"));
 987            }
 988
 989            let mut channels = self.channels.lock();
 990            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
 991            channels.insert(
 992                channel_id,
 993                Channel {
 994                    id: channel_id,
 995                    name: name.to_string(),
 996                    owner_id: org_id.0,
 997                    owner_is_user: false,
 998                },
 999            );
1000            Ok(channel_id)
1001        }
1002
1003        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1004            self.background.simulate_random_delay().await;
1005            Ok(self
1006                .channels
1007                .lock()
1008                .values()
1009                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1010                .cloned()
1011                .collect())
1012        }
1013
1014        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1015            self.background.simulate_random_delay().await;
1016            let channels = self.channels.lock();
1017            let memberships = self.channel_memberships.lock();
1018            Ok(channels
1019                .values()
1020                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1021                .cloned()
1022                .collect())
1023        }
1024
1025        async fn can_user_access_channel(
1026            &self,
1027            user_id: UserId,
1028            channel_id: ChannelId,
1029        ) -> Result<bool> {
1030            self.background.simulate_random_delay().await;
1031            Ok(self
1032                .channel_memberships
1033                .lock()
1034                .contains_key(&(channel_id, user_id)))
1035        }
1036
1037        async fn add_channel_member(
1038            &self,
1039            channel_id: ChannelId,
1040            user_id: UserId,
1041            is_admin: bool,
1042        ) -> Result<()> {
1043            self.background.simulate_random_delay().await;
1044            if !self.channels.lock().contains_key(&channel_id) {
1045                return Err(anyhow!("channel does not exist"));
1046            }
1047            if !self.users.lock().contains_key(&user_id) {
1048                return Err(anyhow!("user does not exist"));
1049            }
1050
1051            self.channel_memberships
1052                .lock()
1053                .entry((channel_id, user_id))
1054                .or_insert(is_admin);
1055            Ok(())
1056        }
1057
1058        async fn create_channel_message(
1059            &self,
1060            channel_id: ChannelId,
1061            sender_id: UserId,
1062            body: &str,
1063            timestamp: OffsetDateTime,
1064            nonce: u128,
1065        ) -> Result<MessageId> {
1066            self.background.simulate_random_delay().await;
1067            if !self.channels.lock().contains_key(&channel_id) {
1068                return Err(anyhow!("channel does not exist"));
1069            }
1070            if !self.users.lock().contains_key(&sender_id) {
1071                return Err(anyhow!("user does not exist"));
1072            }
1073
1074            let mut messages = self.channel_messages.lock();
1075            if let Some(message) = messages
1076                .values()
1077                .find(|message| message.nonce.as_u128() == nonce)
1078            {
1079                Ok(message.id)
1080            } else {
1081                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1082                messages.insert(
1083                    message_id,
1084                    ChannelMessage {
1085                        id: message_id,
1086                        channel_id,
1087                        sender_id,
1088                        body: body.to_string(),
1089                        sent_at: timestamp,
1090                        nonce: Uuid::from_u128(nonce),
1091                    },
1092                );
1093                Ok(message_id)
1094            }
1095        }
1096
1097        async fn get_channel_messages(
1098            &self,
1099            channel_id: ChannelId,
1100            count: usize,
1101            before_id: Option<MessageId>,
1102        ) -> Result<Vec<ChannelMessage>> {
1103            let mut messages = self
1104                .channel_messages
1105                .lock()
1106                .values()
1107                .rev()
1108                .filter(|message| {
1109                    message.channel_id == channel_id
1110                        && message.id < before_id.unwrap_or(MessageId::MAX)
1111                })
1112                .take(count)
1113                .cloned()
1114                .collect::<Vec<_>>();
1115            messages.sort_unstable_by_key(|message| message.id);
1116            Ok(messages)
1117        }
1118
1119        async fn teardown(&self, _name: &str, _url: &str) {}
1120    }
1121}