db.rs

   1use anyhow::Context;
   2use anyhow::Result;
   3use async_trait::async_trait;
   4use serde::Serialize;
   5pub use sqlx::postgres::PgPoolOptions as DbOptions;
   6use sqlx::{types::Uuid, FromRow};
   7use time::OffsetDateTime;
   8
/// Async storage interface for users, access tokens, orgs, channels and
/// channel messages.
///
/// Implemented by [`PostgresDb`] and, under `cfg(test)`, by an in-memory fake.
/// Methods marked `#[cfg(any(test, feature = "seed-support"))]` are only
/// compiled for tests or when the `seed-support` feature is enabled.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login, or returns the id of the
    /// existing user with that login (the existing row is not modified).
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user (the Postgres implementation orders by `github_login`).
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns up to `limit` users whose login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user; `Ok(None)` if no such user exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up several users; ids with no matching user are simply omitted.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up a user by exact GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the `admin` flag of the given user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;
    /// Stores a new access-token hash for `user_id`, keeping only the
    /// `max_access_token_count` most recent hashes for that user.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes, most recently created first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Finds an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org; an existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the given org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the given org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel; an existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Inserts a message. If a message with the same `nonce` already exists,
    /// returns that message's id instead of inserting a duplicate.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the most recent messages whose id is below
    /// `before_id` (no upper bound when `None`), in ascending id order.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: releases connections and drops the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
}
  63
  64pub struct PostgresDb {
  65    pool: sqlx::PgPool,
  66}
  67
  68impl PostgresDb {
  69    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  70        let pool = DbOptions::new()
  71            .max_connections(max_connections)
  72            .connect(&url)
  73            .await
  74            .context("failed to connect to postgres database")?;
  75        Ok(Self { pool })
  76    }
  77}
  78
  79#[async_trait]
  80impl Db for PostgresDb {
  81    // users
  82
  83    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
  84        let query = "
  85            INSERT INTO users (github_login, admin)
  86            VALUES ($1, $2)
  87            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
  88            RETURNING id
  89        ";
  90        Ok(sqlx::query_scalar(query)
  91            .bind(github_login)
  92            .bind(admin)
  93            .fetch_one(&self.pool)
  94            .await
  95            .map(UserId)?)
  96    }
  97
  98    async fn get_all_users(&self) -> Result<Vec<User>> {
  99        let query = "SELECT * FROM users ORDER BY github_login ASC";
 100        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 101    }
 102
 103    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 104        let like_string = fuzzy_like_string(name_query);
 105        let query = "
 106            SELECT users.*
 107            FROM users
 108            WHERE github_login like $1
 109            ORDER BY github_login <-> $2
 110            LIMIT $3
 111        ";
 112        Ok(sqlx::query_as(query)
 113            .bind(like_string)
 114            .bind(name_query)
 115            .bind(limit)
 116            .fetch_all(&self.pool)
 117            .await?)
 118    }
 119
 120    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 121        let users = self.get_users_by_ids(vec![id]).await?;
 122        Ok(users.into_iter().next())
 123    }
 124
 125    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 126        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 127        let query = "
 128            SELECT users.*
 129            FROM users
 130            WHERE users.id = ANY ($1)
 131        ";
 132        Ok(sqlx::query_as(query)
 133            .bind(&ids)
 134            .fetch_all(&self.pool)
 135            .await?)
 136    }
 137
 138    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 139        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 140        Ok(sqlx::query_as(query)
 141            .bind(github_login)
 142            .fetch_optional(&self.pool)
 143            .await?)
 144    }
 145
 146    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 147        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 148        Ok(sqlx::query(query)
 149            .bind(is_admin)
 150            .bind(id.0)
 151            .execute(&self.pool)
 152            .await
 153            .map(drop)?)
 154    }
 155
 156    async fn destroy_user(&self, id: UserId) -> Result<()> {
 157        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 158        sqlx::query(query)
 159            .bind(id.0)
 160            .execute(&self.pool)
 161            .await
 162            .map(drop)?;
 163        let query = "DELETE FROM users WHERE id = $1;";
 164        Ok(sqlx::query(query)
 165            .bind(id.0)
 166            .execute(&self.pool)
 167            .await
 168            .map(drop)?)
 169    }
 170
 171    // access tokens
 172
 173    async fn create_access_token_hash(
 174        &self,
 175        user_id: UserId,
 176        access_token_hash: &str,
 177        max_access_token_count: usize,
 178    ) -> Result<()> {
 179        let insert_query = "
 180            INSERT INTO access_tokens (user_id, hash)
 181            VALUES ($1, $2);
 182        ";
 183        let cleanup_query = "
 184            DELETE FROM access_tokens
 185            WHERE id IN (
 186                SELECT id from access_tokens
 187                WHERE user_id = $1
 188                ORDER BY id DESC
 189                OFFSET $3
 190            )
 191        ";
 192
 193        let mut tx = self.pool.begin().await?;
 194        sqlx::query(insert_query)
 195            .bind(user_id.0)
 196            .bind(access_token_hash)
 197            .execute(&mut tx)
 198            .await?;
 199        sqlx::query(cleanup_query)
 200            .bind(user_id.0)
 201            .bind(access_token_hash)
 202            .bind(max_access_token_count as u32)
 203            .execute(&mut tx)
 204            .await?;
 205        Ok(tx.commit().await?)
 206    }
 207
 208    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 209        let query = "
 210            SELECT hash
 211            FROM access_tokens
 212            WHERE user_id = $1
 213            ORDER BY id DESC
 214        ";
 215        Ok(sqlx::query_scalar(query)
 216            .bind(user_id.0)
 217            .fetch_all(&self.pool)
 218            .await?)
 219    }
 220
 221    // orgs
 222
 223    #[allow(unused)] // Help rust-analyzer
 224    #[cfg(any(test, feature = "seed-support"))]
 225    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 226        let query = "
 227            SELECT *
 228            FROM orgs
 229            WHERE slug = $1
 230        ";
 231        Ok(sqlx::query_as(query)
 232            .bind(slug)
 233            .fetch_optional(&self.pool)
 234            .await?)
 235    }
 236
 237    #[cfg(any(test, feature = "seed-support"))]
 238    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 239        let query = "
 240            INSERT INTO orgs (name, slug)
 241            VALUES ($1, $2)
 242            RETURNING id
 243        ";
 244        Ok(sqlx::query_scalar(query)
 245            .bind(name)
 246            .bind(slug)
 247            .fetch_one(&self.pool)
 248            .await
 249            .map(OrgId)?)
 250    }
 251
 252    #[cfg(any(test, feature = "seed-support"))]
 253    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 254        let query = "
 255            INSERT INTO org_memberships (org_id, user_id, admin)
 256            VALUES ($1, $2, $3)
 257            ON CONFLICT DO NOTHING
 258        ";
 259        Ok(sqlx::query(query)
 260            .bind(org_id.0)
 261            .bind(user_id.0)
 262            .bind(is_admin)
 263            .execute(&self.pool)
 264            .await
 265            .map(drop)?)
 266    }
 267
 268    // channels
 269
 270    #[cfg(any(test, feature = "seed-support"))]
 271    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 272        let query = "
 273            INSERT INTO channels (owner_id, owner_is_user, name)
 274            VALUES ($1, false, $2)
 275            RETURNING id
 276        ";
 277        Ok(sqlx::query_scalar(query)
 278            .bind(org_id.0)
 279            .bind(name)
 280            .fetch_one(&self.pool)
 281            .await
 282            .map(ChannelId)?)
 283    }
 284
 285    #[allow(unused)] // Help rust-analyzer
 286    #[cfg(any(test, feature = "seed-support"))]
 287    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 288        let query = "
 289            SELECT *
 290            FROM channels
 291            WHERE
 292                channels.owner_is_user = false AND
 293                channels.owner_id = $1
 294        ";
 295        Ok(sqlx::query_as(query)
 296            .bind(org_id.0)
 297            .fetch_all(&self.pool)
 298            .await?)
 299    }
 300
 301    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 302        let query = "
 303            SELECT
 304                channels.*
 305            FROM
 306                channel_memberships, channels
 307            WHERE
 308                channel_memberships.user_id = $1 AND
 309                channel_memberships.channel_id = channels.id
 310        ";
 311        Ok(sqlx::query_as(query)
 312            .bind(user_id.0)
 313            .fetch_all(&self.pool)
 314            .await?)
 315    }
 316
 317    async fn can_user_access_channel(
 318        &self,
 319        user_id: UserId,
 320        channel_id: ChannelId,
 321    ) -> Result<bool> {
 322        let query = "
 323            SELECT id
 324            FROM channel_memberships
 325            WHERE user_id = $1 AND channel_id = $2
 326            LIMIT 1
 327        ";
 328        Ok(sqlx::query_scalar::<_, i32>(query)
 329            .bind(user_id.0)
 330            .bind(channel_id.0)
 331            .fetch_optional(&self.pool)
 332            .await
 333            .map(|e| e.is_some())?)
 334    }
 335
 336    #[cfg(any(test, feature = "seed-support"))]
 337    async fn add_channel_member(
 338        &self,
 339        channel_id: ChannelId,
 340        user_id: UserId,
 341        is_admin: bool,
 342    ) -> Result<()> {
 343        let query = "
 344            INSERT INTO channel_memberships (channel_id, user_id, admin)
 345            VALUES ($1, $2, $3)
 346            ON CONFLICT DO NOTHING
 347        ";
 348        Ok(sqlx::query(query)
 349            .bind(channel_id.0)
 350            .bind(user_id.0)
 351            .bind(is_admin)
 352            .execute(&self.pool)
 353            .await
 354            .map(drop)?)
 355    }
 356
 357    // messages
 358
 359    async fn create_channel_message(
 360        &self,
 361        channel_id: ChannelId,
 362        sender_id: UserId,
 363        body: &str,
 364        timestamp: OffsetDateTime,
 365        nonce: u128,
 366    ) -> Result<MessageId> {
 367        let query = "
 368            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 369            VALUES ($1, $2, $3, $4, $5)
 370            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 371            RETURNING id
 372        ";
 373        Ok(sqlx::query_scalar(query)
 374            .bind(channel_id.0)
 375            .bind(sender_id.0)
 376            .bind(body)
 377            .bind(timestamp)
 378            .bind(Uuid::from_u128(nonce))
 379            .fetch_one(&self.pool)
 380            .await
 381            .map(MessageId)?)
 382    }
 383
 384    async fn get_channel_messages(
 385        &self,
 386        channel_id: ChannelId,
 387        count: usize,
 388        before_id: Option<MessageId>,
 389    ) -> Result<Vec<ChannelMessage>> {
 390        let query = r#"
 391            SELECT * FROM (
 392                SELECT
 393                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 394                FROM
 395                    channel_messages
 396                WHERE
 397                    channel_id = $1 AND
 398                    id < $2
 399                ORDER BY id DESC
 400                LIMIT $3
 401            ) as recent_messages
 402            ORDER BY id ASC
 403        "#;
 404        Ok(sqlx::query_as(query)
 405            .bind(channel_id.0)
 406            .bind(before_id.unwrap_or(MessageId::MAX))
 407            .bind(count as i64)
 408            .fetch_all(&self.pool)
 409            .await?)
 410    }
 411
 412    #[cfg(test)]
 413    async fn teardown(&self, url: &str) {
 414        use util::ResultExt;
 415
 416        let query = "
 417            SELECT pg_terminate_backend(pg_stat_activity.pid)
 418            FROM pg_stat_activity
 419            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 420        ";
 421        sqlx::query(query).execute(&self.pool).await.log_err();
 422        self.pool.close().await;
 423        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 424            .await
 425            .log_err();
 426    }
 427}
 428
 429macro_rules! id_type {
 430    ($name:ident) => {
 431        #[derive(
 432            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 433        )]
 434        #[sqlx(transparent)]
 435        #[serde(transparent)]
 436        pub struct $name(pub i32);
 437
 438        impl $name {
 439            #[allow(unused)]
 440            pub const MAX: Self = Self(i32::MAX);
 441
 442            #[allow(unused)]
 443            pub fn from_proto(value: u64) -> Self {
 444                Self(value as i32)
 445            }
 446
 447            #[allow(unused)]
 448            pub fn to_proto(&self) -> u64 {
 449                self.0 as u64
 450            }
 451        }
 452
 453        impl std::fmt::Display for $name {
 454            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 455                self.0.fmt(f)
 456            }
 457        }
 458    };
 459}
 460
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `create_user` relies on `ON CONFLICT (github_login)`,
    /// i.e. a uniqueness constraint on this column.
    pub github_login: String,
    /// Value of the `admin` column (set by `create_user` / `set_user_is_admin`).
    pub admin: bool,
}
 468
 469id_type!(OrgId);
 470#[derive(FromRow)]
 471pub struct Org {
 472    pub id: OrgId,
 473    pub name: String,
 474    pub slug: String,
 475}
 476
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner: a user id when `owner_is_user` is true, otherwise an
    /// org id (see `create_org_channel` / `get_org_channels`).
    pub owner_id: i32,
    pub owner_is_user: bool,
}
 485
id_type!(MessageId);
/// A row of the `channel_messages` table, as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce used by `create_channel_message` to deduplicate
    /// inserts (`ON CONFLICT (nonce)`).
    pub nonce: Uuid,
}
 496
 497fn fuzzy_like_string(string: &str) -> String {
 498    let mut result = String::with_capacity(string.len() * 2 + 1);
 499    for c in string.chars() {
 500        if c.is_alphanumeric() {
 501            result.push('%');
 502            result.push(c);
 503        }
 504    }
 505    result.push('%');
 506    result
 507}
 508
 509#[cfg(test)]
 510pub mod tests {
 511    use super::*;
 512    use anyhow::anyhow;
 513    use collections::BTreeMap;
 514    use gpui::executor::Background;
 515    use lazy_static::lazy_static;
 516    use parking_lot::Mutex;
 517    use rand::prelude::*;
 518    use sqlx::{
 519        migrate::{MigrateDatabase, Migrator},
 520        Postgres,
 521    };
 522    use std::{path::Path, sync::Arc};
 523    use util::post_inc;
 524
 525    #[tokio::test(flavor = "multi_thread")]
 526    async fn test_get_users_by_ids() {
 527        for test_db in [
 528            TestDb::postgres().await,
 529            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 530        ] {
 531            let db = test_db.db();
 532
 533            let user = db.create_user("user", false).await.unwrap();
 534            let friend1 = db.create_user("friend-1", false).await.unwrap();
 535            let friend2 = db.create_user("friend-2", false).await.unwrap();
 536            let friend3 = db.create_user("friend-3", false).await.unwrap();
 537
 538            assert_eq!(
 539                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 540                    .await
 541                    .unwrap(),
 542                vec![
 543                    User {
 544                        id: user,
 545                        github_login: "user".to_string(),
 546                        admin: false,
 547                    },
 548                    User {
 549                        id: friend1,
 550                        github_login: "friend-1".to_string(),
 551                        admin: false,
 552                    },
 553                    User {
 554                        id: friend2,
 555                        github_login: "friend-2".to_string(),
 556                        admin: false,
 557                    },
 558                    User {
 559                        id: friend3,
 560                        github_login: "friend-3".to_string(),
 561                        admin: false,
 562                    }
 563                ]
 564            );
 565        }
 566    }
 567
 568    #[tokio::test(flavor = "multi_thread")]
 569    async fn test_recent_channel_messages() {
 570        for test_db in [
 571            TestDb::postgres().await,
 572            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 573        ] {
 574            let db = test_db.db();
 575            let user = db.create_user("user", false).await.unwrap();
 576            let org = db.create_org("org", "org").await.unwrap();
 577            let channel = db.create_org_channel(org, "channel").await.unwrap();
 578            for i in 0..10 {
 579                db.create_channel_message(
 580                    channel,
 581                    user,
 582                    &i.to_string(),
 583                    OffsetDateTime::now_utc(),
 584                    i,
 585                )
 586                .await
 587                .unwrap();
 588            }
 589
 590            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 591            assert_eq!(
 592                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 593                ["5", "6", "7", "8", "9"]
 594            );
 595
 596            let prev_messages = db
 597                .get_channel_messages(channel, 4, Some(messages[0].id))
 598                .await
 599                .unwrap();
 600            assert_eq!(
 601                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 602                ["1", "2", "3", "4"]
 603            );
 604        }
 605    }
 606
 607    #[tokio::test(flavor = "multi_thread")]
 608    async fn test_channel_message_nonces() {
 609        for test_db in [
 610            TestDb::postgres().await,
 611            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 612        ] {
 613            let db = test_db.db();
 614            let user = db.create_user("user", false).await.unwrap();
 615            let org = db.create_org("org", "org").await.unwrap();
 616            let channel = db.create_org_channel(org, "channel").await.unwrap();
 617
 618            let msg1_id = db
 619                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 620                .await
 621                .unwrap();
 622            let msg2_id = db
 623                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 624                .await
 625                .unwrap();
 626            let msg3_id = db
 627                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 628                .await
 629                .unwrap();
 630            let msg4_id = db
 631                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 632                .await
 633                .unwrap();
 634
 635            assert_ne!(msg1_id, msg2_id);
 636            assert_eq!(msg1_id, msg3_id);
 637            assert_eq!(msg2_id, msg4_id);
 638        }
 639    }
 640
 641    #[tokio::test(flavor = "multi_thread")]
 642    async fn test_create_access_tokens() {
 643        let test_db = TestDb::postgres().await;
 644        let db = test_db.db();
 645        let user = db.create_user("the-user", false).await.unwrap();
 646
 647        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 648        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 649        assert_eq!(
 650            db.get_access_token_hashes(user).await.unwrap(),
 651            &["h2".to_string(), "h1".to_string()]
 652        );
 653
 654        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 655        assert_eq!(
 656            db.get_access_token_hashes(user).await.unwrap(),
 657            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 658        );
 659
 660        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 661        assert_eq!(
 662            db.get_access_token_hashes(user).await.unwrap(),
 663            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 664        );
 665
 666        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 667        assert_eq!(
 668            db.get_access_token_hashes(user).await.unwrap(),
 669            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 670        );
 671    }
 672
 673    #[test]
 674    fn test_fuzzy_like_string() {
 675        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
 676        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
 677        assert_eq!(fuzzy_like_string(" z  "), "%z%");
 678    }
 679
 680    #[tokio::test(flavor = "multi_thread")]
 681    async fn test_fuzzy_search_users() {
 682        let test_db = TestDb::postgres().await;
 683        let db = test_db.db();
 684        for github_login in [
 685            "california",
 686            "colorado",
 687            "oregon",
 688            "washington",
 689            "florida",
 690            "delaware",
 691            "rhode-island",
 692        ] {
 693            db.create_user(github_login, false).await.unwrap();
 694        }
 695
 696        assert_eq!(
 697            fuzzy_search_user_names(db, "clr").await,
 698            &["colorado", "california"]
 699        );
 700        assert_eq!(
 701            fuzzy_search_user_names(db, "ro").await,
 702            &["rhode-island", "colorado", "oregon"],
 703        );
 704
 705        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
 706            db.fuzzy_search_users(query, 10)
 707                .await
 708                .unwrap()
 709                .into_iter()
 710                .map(|user| user.github_login)
 711                .collect::<Vec<_>>()
 712        }
 713    }
 714
    /// A database handle for tests, backed by either a fresh Postgres database
    /// (`TestDb::postgres`) or a `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// Wrapped in `Option` so that `Drop` can `take()` it for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
 719
 720    impl TestDb {
 721        pub async fn postgres() -> Self {
 722            lazy_static! {
 723                static ref LOCK: Mutex<()> = Mutex::new(());
 724            }
 725
 726            let _guard = LOCK.lock();
 727            let mut rng = StdRng::from_entropy();
 728            let name = format!("zed-test-{}", rng.gen::<u128>());
 729            let url = format!("postgres://postgres@localhost/{}", name);
 730            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 731            Postgres::create_database(&url)
 732                .await
 733                .expect("failed to create test db");
 734            let db = PostgresDb::new(&url, 5).await.unwrap();
 735            let migrator = Migrator::new(migrations_path).await.unwrap();
 736            migrator.run(&db.pool).await.unwrap();
 737            Self {
 738                db: Some(Arc::new(db)),
 739                url,
 740            }
 741        }
 742
 743        pub fn fake(background: Arc<Background>) -> Self {
 744            Self {
 745                db: Some(Arc::new(FakeDb::new(background))),
 746                url: Default::default(),
 747            }
 748        }
 749
 750        pub fn db(&self) -> &Arc<dyn Db> {
 751            self.db.as_ref().unwrap()
 752        }
 753    }
 754
    impl Drop for TestDb {
        // `Drop` cannot be async, so teardown is driven to completion with
        // `futures::executor::block_on`. `take()` ensures it runs at most once.
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
 762
    /// In-memory [`Db`] implementation for tests.
    ///
    /// Each "table" is a `BTreeMap` behind its own `Mutex`; the `next_*_id`
    /// counters hand out sequential ids (starting at 1, see `FakeDb::new`).
    /// The implemented methods await `background.simulate_random_delay()`
    /// before touching state, presumably to exercise task interleavings.
    pub struct FakeDb {
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        // Value is the `is_admin` flag passed to `add_org_member`.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        // Presumably the member's admin flag, mirroring `org_memberships` — TODO confirm.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }
 776
 777    impl FakeDb {
 778        pub fn new(background: Arc<Background>) -> Self {
 779            Self {
 780                background,
 781                users: Default::default(),
 782                next_user_id: Mutex::new(1),
 783                orgs: Default::default(),
 784                next_org_id: Mutex::new(1),
 785                org_memberships: Default::default(),
 786                channels: Default::default(),
 787                next_channel_id: Mutex::new(1),
 788                channel_memberships: Default::default(),
 789                channel_messages: Default::default(),
 790                next_channel_message_id: Mutex::new(1),
 791            }
 792        }
 793    }
 794
 795    #[async_trait]
 796    impl Db for FakeDb {
 797        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 798            self.background.simulate_random_delay().await;
 799
 800            let mut users = self.users.lock();
 801            if let Some(user) = users
 802                .values()
 803                .find(|user| user.github_login == github_login)
 804            {
 805                Ok(user.id)
 806            } else {
 807                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
 808                users.insert(
 809                    user_id,
 810                    User {
 811                        id: user_id,
 812                        github_login: github_login.to_string(),
 813                        admin,
 814                    },
 815                );
 816                Ok(user_id)
 817            }
 818        }
 819
        // Not implemented by the fake: panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        // Not implemented by the fake: panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
 827
 828        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 829            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
 830        }
 831
 832        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 833            self.background.simulate_random_delay().await;
 834            let users = self.users.lock();
 835            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
 836        }
 837
        // The methods below are not implemented by the fake: each panics if called.

        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
            unimplemented!()
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
 866
 867        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 868            self.background.simulate_random_delay().await;
 869            let mut orgs = self.orgs.lock();
 870            if orgs.values().any(|org| org.slug == slug) {
 871                Err(anyhow!("org already exists"))
 872            } else {
 873                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
 874                orgs.insert(
 875                    org_id,
 876                    Org {
 877                        id: org_id,
 878                        name: name.to_string(),
 879                        slug: slug.to_string(),
 880                    },
 881                );
 882                Ok(org_id)
 883            }
 884        }
 885
 886        async fn add_org_member(
 887            &self,
 888            org_id: OrgId,
 889            user_id: UserId,
 890            is_admin: bool,
 891        ) -> Result<()> {
 892            self.background.simulate_random_delay().await;
 893            if !self.orgs.lock().contains_key(&org_id) {
 894                return Err(anyhow!("org does not exist"));
 895            }
 896            if !self.users.lock().contains_key(&user_id) {
 897                return Err(anyhow!("user does not exist"));
 898            }
 899
 900            self.org_memberships
 901                .lock()
 902                .entry((org_id, user_id))
 903                .or_insert(is_admin);
 904            Ok(())
 905        }
 906
 907        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 908            self.background.simulate_random_delay().await;
 909            if !self.orgs.lock().contains_key(&org_id) {
 910                return Err(anyhow!("org does not exist"));
 911            }
 912
 913            let mut channels = self.channels.lock();
 914            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
 915            channels.insert(
 916                channel_id,
 917                Channel {
 918                    id: channel_id,
 919                    name: name.to_string(),
 920                    owner_id: org_id.0,
 921                    owner_is_user: false,
 922                },
 923            );
 924            Ok(channel_id)
 925        }
 926
 927        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 928            self.background.simulate_random_delay().await;
 929            Ok(self
 930                .channels
 931                .lock()
 932                .values()
 933                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
 934                .cloned()
 935                .collect())
 936        }
 937
 938        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 939            self.background.simulate_random_delay().await;
 940            let channels = self.channels.lock();
 941            let memberships = self.channel_memberships.lock();
 942            Ok(channels
 943                .values()
 944                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
 945                .cloned()
 946                .collect())
 947        }
 948
 949        async fn can_user_access_channel(
 950            &self,
 951            user_id: UserId,
 952            channel_id: ChannelId,
 953        ) -> Result<bool> {
 954            self.background.simulate_random_delay().await;
 955            Ok(self
 956                .channel_memberships
 957                .lock()
 958                .contains_key(&(channel_id, user_id)))
 959        }
 960
 961        async fn add_channel_member(
 962            &self,
 963            channel_id: ChannelId,
 964            user_id: UserId,
 965            is_admin: bool,
 966        ) -> Result<()> {
 967            self.background.simulate_random_delay().await;
 968            if !self.channels.lock().contains_key(&channel_id) {
 969                return Err(anyhow!("channel does not exist"));
 970            }
 971            if !self.users.lock().contains_key(&user_id) {
 972                return Err(anyhow!("user does not exist"));
 973            }
 974
 975            self.channel_memberships
 976                .lock()
 977                .entry((channel_id, user_id))
 978                .or_insert(is_admin);
 979            Ok(())
 980        }
 981
 982        async fn create_channel_message(
 983            &self,
 984            channel_id: ChannelId,
 985            sender_id: UserId,
 986            body: &str,
 987            timestamp: OffsetDateTime,
 988            nonce: u128,
 989        ) -> Result<MessageId> {
 990            self.background.simulate_random_delay().await;
 991            if !self.channels.lock().contains_key(&channel_id) {
 992                return Err(anyhow!("channel does not exist"));
 993            }
 994            if !self.users.lock().contains_key(&sender_id) {
 995                return Err(anyhow!("user does not exist"));
 996            }
 997
 998            let mut messages = self.channel_messages.lock();
 999            if let Some(message) = messages
1000                .values()
1001                .find(|message| message.nonce.as_u128() == nonce)
1002            {
1003                Ok(message.id)
1004            } else {
1005                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1006                messages.insert(
1007                    message_id,
1008                    ChannelMessage {
1009                        id: message_id,
1010                        channel_id,
1011                        sender_id,
1012                        body: body.to_string(),
1013                        sent_at: timestamp,
1014                        nonce: Uuid::from_u128(nonce),
1015                    },
1016                );
1017                Ok(message_id)
1018            }
1019        }
1020
1021        async fn get_channel_messages(
1022            &self,
1023            channel_id: ChannelId,
1024            count: usize,
1025            before_id: Option<MessageId>,
1026        ) -> Result<Vec<ChannelMessage>> {
1027            let mut messages = self
1028                .channel_messages
1029                .lock()
1030                .values()
1031                .rev()
1032                .filter(|message| {
1033                    message.channel_id == channel_id
1034                        && message.id < before_id.unwrap_or(MessageId::MAX)
1035                })
1036                .take(count)
1037                .cloned()
1038                .collect::<Vec<_>>();
1039            messages.sort_unstable_by_key(|message| message.id);
1040            Ok(messages)
1041        }
1042
        // NOTE(review): presumably a no-op because all state lives in in-memory maps owned
        // by `self`, which need no external cleanup — confirm.
        /// Does nothing; the string argument is ignored.
        async fn teardown(&self, _: &str) {}
1044    }
1045}