db.rs

   1use anyhow::{anyhow, Context, Result};
   2use async_trait::async_trait;
   3use futures::StreamExt;
   4use serde::Serialize;
   5pub use sqlx::postgres::PgPoolOptions as DbOptions;
   6use sqlx::{types::Uuid, FromRow};
   7use time::OffsetDateTime;
   8
/// Persistence interface for users, contacts, access tokens, orgs, channels, and
/// channel messages.
///
/// [`PostgresDb`] implements it against Postgres; tests also use `tests::FakeDb`
/// (reachable through `as_fake`). Methods marked
/// `#[cfg(any(test, feature = "seed-support"))]` only exist in test builds or when the
/// `seed-support` feature is enabled.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and admin flag and returns its id.
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns all users.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns at most `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids without a user are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the admin flag of user `id` to `is_admin`.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes user `id` (in `PostgresDb`, together with their access tokens).
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Returns the contact list of user `id`, including pending requests in both
    /// directions.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users (in `PostgresDb`,
    /// whether it is accepted or still pending).
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the contact notification that the first user has about the second.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the request that `requester_id` sent to
    /// `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access-token hash for `user_id`, keeping at most
    /// `max_access_token_count` of that user's most recent hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the stored access-token hashes of `user_id`.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    /// Returns the org with the given slug, if any.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` as a member of `org_id`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by `org_id` and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;

    /// Returns the channels owned by `org_id`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels `user_id` is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether `user_id` is a member of `channel_id`.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` as a member of `channel_id`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message and returns its id. In `PostgresDb`, a repeated `nonce` yields
    /// the id of the message already stored with that nonce.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the newest messages in `channel_id` with ids below
    /// `before_id` (unbounded when `None`), oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: disposes of the test database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: returns `Some` when `self` is a `tests::FakeDb`.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
  84
/// [`Db`] implementation backed by a sqlx Postgres connection pool.
pub struct PostgresDb {
    // Shared pool used by every query; transactions are started from it as needed.
    pool: sqlx::PgPool,
}
  88
  89impl PostgresDb {
  90    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  91        let pool = DbOptions::new()
  92            .max_connections(max_connections)
  93            .connect(&url)
  94            .await
  95            .context("failed to connect to postgres database")?;
  96        Ok(Self { pool })
  97    }
  98}
  99
 100#[async_trait]
 101impl Db for PostgresDb {
 102    // users
 103
 104    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 105        let query = "
 106            INSERT INTO users (github_login, admin)
 107            VALUES ($1, $2)
 108            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 109            RETURNING id
 110        ";
 111        Ok(sqlx::query_scalar(query)
 112            .bind(github_login)
 113            .bind(admin)
 114            .fetch_one(&self.pool)
 115            .await
 116            .map(UserId)?)
 117    }
 118
 119    async fn get_all_users(&self) -> Result<Vec<User>> {
 120        let query = "SELECT * FROM users ORDER BY github_login ASC";
 121        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 122    }
 123
 124    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 125        let like_string = fuzzy_like_string(name_query);
 126        let query = "
 127            SELECT users.*
 128            FROM users
 129            WHERE github_login ILIKE $1
 130            ORDER BY github_login <-> $2
 131            LIMIT $3
 132        ";
 133        Ok(sqlx::query_as(query)
 134            .bind(like_string)
 135            .bind(name_query)
 136            .bind(limit)
 137            .fetch_all(&self.pool)
 138            .await?)
 139    }
 140
 141    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 142        let users = self.get_users_by_ids(vec![id]).await?;
 143        Ok(users.into_iter().next())
 144    }
 145
 146    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 147        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 148        let query = "
 149            SELECT users.*
 150            FROM users
 151            WHERE users.id = ANY ($1)
 152        ";
 153        Ok(sqlx::query_as(query)
 154            .bind(&ids)
 155            .fetch_all(&self.pool)
 156            .await?)
 157    }
 158
 159    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 160        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 161        Ok(sqlx::query_as(query)
 162            .bind(github_login)
 163            .fetch_optional(&self.pool)
 164            .await?)
 165    }
 166
 167    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 168        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 169        Ok(sqlx::query(query)
 170            .bind(is_admin)
 171            .bind(id.0)
 172            .execute(&self.pool)
 173            .await
 174            .map(drop)?)
 175    }
 176
 177    async fn destroy_user(&self, id: UserId) -> Result<()> {
 178        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 179        sqlx::query(query)
 180            .bind(id.0)
 181            .execute(&self.pool)
 182            .await
 183            .map(drop)?;
 184        let query = "DELETE FROM users WHERE id = $1;";
 185        Ok(sqlx::query(query)
 186            .bind(id.0)
 187            .execute(&self.pool)
 188            .await
 189            .map(drop)?)
 190    }
 191
 192    // contacts
 193
 194    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 195        let query = "
 196            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 197            FROM contacts
 198            WHERE user_id_a = $1 OR user_id_b = $1;
 199        ";
 200
 201        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 202            .bind(user_id)
 203            .fetch(&self.pool);
 204
 205        let mut contacts = vec![Contact::Accepted {
 206            user_id,
 207            should_notify: false,
 208        }];
 209        while let Some(row) = rows.next().await {
 210            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 211
 212            if user_id_a == user_id {
 213                if accepted {
 214                    contacts.push(Contact::Accepted {
 215                        user_id: user_id_b,
 216                        should_notify: should_notify && a_to_b,
 217                    });
 218                } else if a_to_b {
 219                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 220                } else {
 221                    contacts.push(Contact::Incoming {
 222                        user_id: user_id_b,
 223                        should_notify,
 224                    });
 225                }
 226            } else {
 227                if accepted {
 228                    contacts.push(Contact::Accepted {
 229                        user_id: user_id_a,
 230                        should_notify: should_notify && !a_to_b,
 231                    });
 232                } else if a_to_b {
 233                    contacts.push(Contact::Incoming {
 234                        user_id: user_id_a,
 235                        should_notify,
 236                    });
 237                } else {
 238                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 239                }
 240            }
 241        }
 242
 243        contacts.sort_unstable_by_key(|contact| contact.user_id());
 244
 245        Ok(contacts)
 246    }
 247
 248    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 249        let (id_a, id_b) = if user_id_1 < user_id_2 {
 250            (user_id_1, user_id_2)
 251        } else {
 252            (user_id_2, user_id_1)
 253        };
 254
 255        let query = "
 256            SELECT 1 FROM contacts
 257            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 258            LIMIT 1
 259        ";
 260        Ok(sqlx::query_scalar::<_, i32>(query)
 261            .bind(id_a.0)
 262            .bind(id_b.0)
 263            .fetch_optional(&self.pool)
 264            .await?
 265            .is_some())
 266    }
 267
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// Rows are keyed by `(user_id_a, user_id_b)` with the smaller id in `user_id_a`.
    /// `a_to_b` records whether the request went from `user_id_a` to `user_id_b`.
    ///
    /// A new pair is inserted as a pending request (`accepted = 'f'`) with
    /// `should_notify = 't'`. If a row for the pair already exists, the `ON CONFLICT`
    /// update only fires when that row is not yet accepted and the WHERE clause holds.
    /// In practice that means the stored request points the other way, i.e. the
    /// receiver had already asked the sender. The existing request is then marked
    /// accepted and its `should_notify` cleared.
    /// NOTE(review): the conflict target forces `contacts.user_id_a = excluded.user_id_a`,
    /// so the first WHERE disjunct can only match when `user_id_a = user_id_b`
    /// (a self-request). Confirm whether that case is intended.
    ///
    /// # Errors
    /// Returns "contact already requested" when the statement neither inserted nor
    /// updated exactly one row, e.g. the same request was already pending or the users
    /// are already contacts.
    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        // Normalize to (smaller id, larger id); `a_to_b` is true when the sender is `a`.
        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
            (sender_id, receiver_id, true)
        } else {
            (receiver_id, sender_id, false)
        };
        let query = "
            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
            VALUES ($1, $2, $3, 'f', 't')
            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
            SET
                accepted = 't',
                should_notify = 'f'
            WHERE
                NOT contacts.accepted AND
                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
        ";
        let result = sqlx::query(query)
            .bind(id_a.0)
            .bind(id_b.0)
            .bind(a_to_b)
            .execute(&self.pool)
            .await?;

        // An insert or a firing conflict-update each affect exactly one row; a
        // suppressed update affects none.
        if result.rows_affected() == 1 {
            Ok(())
        } else {
            Err(anyhow!("contact already requested"))
        }
    }
 299
 300    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 301        let (id_a, id_b) = if responder_id < requester_id {
 302            (responder_id, requester_id)
 303        } else {
 304            (requester_id, responder_id)
 305        };
 306        let query = "
 307            DELETE FROM contacts
 308            WHERE user_id_a = $1 AND user_id_b = $2;
 309        ";
 310        let result = sqlx::query(query)
 311            .bind(id_a.0)
 312            .bind(id_b.0)
 313            .execute(&self.pool)
 314            .await?;
 315
 316        if result.rows_affected() == 1 {
 317            Ok(())
 318        } else {
 319            Err(anyhow!("no such contact"))
 320        }
 321    }
 322
 323    async fn dismiss_contact_notification(
 324        &self,
 325        user_id: UserId,
 326        contact_user_id: UserId,
 327    ) -> Result<()> {
 328        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 329            (user_id, contact_user_id, true)
 330        } else {
 331            (contact_user_id, user_id, false)
 332        };
 333
 334        let query = "
 335            UPDATE contacts
 336            SET should_notify = 'f'
 337            WHERE
 338                user_id_a = $1 AND user_id_b = $2 AND
 339                (
 340                    (a_to_b = $3 AND accepted) OR
 341                    (a_to_b != $3 AND NOT accepted)
 342                );
 343        ";
 344
 345        let result = sqlx::query(query)
 346            .bind(id_a.0)
 347            .bind(id_b.0)
 348            .bind(a_to_b)
 349            .execute(&self.pool)
 350            .await?;
 351
 352        if result.rows_affected() == 0 {
 353            Err(anyhow!("no such contact request"))?;
 354        }
 355
 356        Ok(())
 357    }
 358
 359    async fn respond_to_contact_request(
 360        &self,
 361        responder_id: UserId,
 362        requester_id: UserId,
 363        accept: bool,
 364    ) -> Result<()> {
 365        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 366            (responder_id, requester_id, false)
 367        } else {
 368            (requester_id, responder_id, true)
 369        };
 370        let result = if accept {
 371            let query = "
 372                UPDATE contacts
 373                SET accepted = 't', should_notify = 't'
 374                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 375            ";
 376            sqlx::query(query)
 377                .bind(id_a.0)
 378                .bind(id_b.0)
 379                .bind(a_to_b)
 380                .execute(&self.pool)
 381                .await?
 382        } else {
 383            let query = "
 384                DELETE FROM contacts
 385                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 386            ";
 387            sqlx::query(query)
 388                .bind(id_a.0)
 389                .bind(id_b.0)
 390                .bind(a_to_b)
 391                .execute(&self.pool)
 392                .await?
 393        };
 394        if result.rows_affected() == 1 {
 395            Ok(())
 396        } else {
 397            Err(anyhow!("no such contact request"))
 398        }
 399    }
 400
 401    // access tokens
 402
 403    async fn create_access_token_hash(
 404        &self,
 405        user_id: UserId,
 406        access_token_hash: &str,
 407        max_access_token_count: usize,
 408    ) -> Result<()> {
 409        let insert_query = "
 410            INSERT INTO access_tokens (user_id, hash)
 411            VALUES ($1, $2);
 412        ";
 413        let cleanup_query = "
 414            DELETE FROM access_tokens
 415            WHERE id IN (
 416                SELECT id from access_tokens
 417                WHERE user_id = $1
 418                ORDER BY id DESC
 419                OFFSET $3
 420            )
 421        ";
 422
 423        let mut tx = self.pool.begin().await?;
 424        sqlx::query(insert_query)
 425            .bind(user_id.0)
 426            .bind(access_token_hash)
 427            .execute(&mut tx)
 428            .await?;
 429        sqlx::query(cleanup_query)
 430            .bind(user_id.0)
 431            .bind(access_token_hash)
 432            .bind(max_access_token_count as u32)
 433            .execute(&mut tx)
 434            .await?;
 435        Ok(tx.commit().await?)
 436    }
 437
 438    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 439        let query = "
 440            SELECT hash
 441            FROM access_tokens
 442            WHERE user_id = $1
 443            ORDER BY id DESC
 444        ";
 445        Ok(sqlx::query_scalar(query)
 446            .bind(user_id.0)
 447            .fetch_all(&self.pool)
 448            .await?)
 449    }
 450
 451    // orgs
 452
 453    #[allow(unused)] // Help rust-analyzer
 454    #[cfg(any(test, feature = "seed-support"))]
 455    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 456        let query = "
 457            SELECT *
 458            FROM orgs
 459            WHERE slug = $1
 460        ";
 461        Ok(sqlx::query_as(query)
 462            .bind(slug)
 463            .fetch_optional(&self.pool)
 464            .await?)
 465    }
 466
 467    #[cfg(any(test, feature = "seed-support"))]
 468    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 469        let query = "
 470            INSERT INTO orgs (name, slug)
 471            VALUES ($1, $2)
 472            RETURNING id
 473        ";
 474        Ok(sqlx::query_scalar(query)
 475            .bind(name)
 476            .bind(slug)
 477            .fetch_one(&self.pool)
 478            .await
 479            .map(OrgId)?)
 480    }
 481
 482    #[cfg(any(test, feature = "seed-support"))]
 483    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 484        let query = "
 485            INSERT INTO org_memberships (org_id, user_id, admin)
 486            VALUES ($1, $2, $3)
 487            ON CONFLICT DO NOTHING
 488        ";
 489        Ok(sqlx::query(query)
 490            .bind(org_id.0)
 491            .bind(user_id.0)
 492            .bind(is_admin)
 493            .execute(&self.pool)
 494            .await
 495            .map(drop)?)
 496    }
 497
 498    // channels
 499
 500    #[cfg(any(test, feature = "seed-support"))]
 501    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 502        let query = "
 503            INSERT INTO channels (owner_id, owner_is_user, name)
 504            VALUES ($1, false, $2)
 505            RETURNING id
 506        ";
 507        Ok(sqlx::query_scalar(query)
 508            .bind(org_id.0)
 509            .bind(name)
 510            .fetch_one(&self.pool)
 511            .await
 512            .map(ChannelId)?)
 513    }
 514
 515    #[allow(unused)] // Help rust-analyzer
 516    #[cfg(any(test, feature = "seed-support"))]
 517    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 518        let query = "
 519            SELECT *
 520            FROM channels
 521            WHERE
 522                channels.owner_is_user = false AND
 523                channels.owner_id = $1
 524        ";
 525        Ok(sqlx::query_as(query)
 526            .bind(org_id.0)
 527            .fetch_all(&self.pool)
 528            .await?)
 529    }
 530
 531    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 532        let query = "
 533            SELECT
 534                channels.*
 535            FROM
 536                channel_memberships, channels
 537            WHERE
 538                channel_memberships.user_id = $1 AND
 539                channel_memberships.channel_id = channels.id
 540        ";
 541        Ok(sqlx::query_as(query)
 542            .bind(user_id.0)
 543            .fetch_all(&self.pool)
 544            .await?)
 545    }
 546
 547    async fn can_user_access_channel(
 548        &self,
 549        user_id: UserId,
 550        channel_id: ChannelId,
 551    ) -> Result<bool> {
 552        let query = "
 553            SELECT id
 554            FROM channel_memberships
 555            WHERE user_id = $1 AND channel_id = $2
 556            LIMIT 1
 557        ";
 558        Ok(sqlx::query_scalar::<_, i32>(query)
 559            .bind(user_id.0)
 560            .bind(channel_id.0)
 561            .fetch_optional(&self.pool)
 562            .await
 563            .map(|e| e.is_some())?)
 564    }
 565
 566    #[cfg(any(test, feature = "seed-support"))]
 567    async fn add_channel_member(
 568        &self,
 569        channel_id: ChannelId,
 570        user_id: UserId,
 571        is_admin: bool,
 572    ) -> Result<()> {
 573        let query = "
 574            INSERT INTO channel_memberships (channel_id, user_id, admin)
 575            VALUES ($1, $2, $3)
 576            ON CONFLICT DO NOTHING
 577        ";
 578        Ok(sqlx::query(query)
 579            .bind(channel_id.0)
 580            .bind(user_id.0)
 581            .bind(is_admin)
 582            .execute(&self.pool)
 583            .await
 584            .map(drop)?)
 585    }
 586
 587    // messages
 588
 589    async fn create_channel_message(
 590        &self,
 591        channel_id: ChannelId,
 592        sender_id: UserId,
 593        body: &str,
 594        timestamp: OffsetDateTime,
 595        nonce: u128,
 596    ) -> Result<MessageId> {
 597        let query = "
 598            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 599            VALUES ($1, $2, $3, $4, $5)
 600            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 601            RETURNING id
 602        ";
 603        Ok(sqlx::query_scalar(query)
 604            .bind(channel_id.0)
 605            .bind(sender_id.0)
 606            .bind(body)
 607            .bind(timestamp)
 608            .bind(Uuid::from_u128(nonce))
 609            .fetch_one(&self.pool)
 610            .await
 611            .map(MessageId)?)
 612    }
 613
 614    async fn get_channel_messages(
 615        &self,
 616        channel_id: ChannelId,
 617        count: usize,
 618        before_id: Option<MessageId>,
 619    ) -> Result<Vec<ChannelMessage>> {
 620        let query = r#"
 621            SELECT * FROM (
 622                SELECT
 623                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 624                FROM
 625                    channel_messages
 626                WHERE
 627                    channel_id = $1 AND
 628                    id < $2
 629                ORDER BY id DESC
 630                LIMIT $3
 631            ) as recent_messages
 632            ORDER BY id ASC
 633        "#;
 634        Ok(sqlx::query_as(query)
 635            .bind(channel_id.0)
 636            .bind(before_id.unwrap_or(MessageId::MAX))
 637            .bind(count as i64)
 638            .fetch_all(&self.pool)
 639            .await?)
 640    }
 641
 642    #[cfg(test)]
 643    async fn teardown(&self, url: &str) {
 644        use util::ResultExt;
 645
 646        let query = "
 647            SELECT pg_terminate_backend(pg_stat_activity.pid)
 648            FROM pg_stat_activity
 649            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 650        ";
 651        sqlx::query(query).execute(&self.pool).await.log_err();
 652        self.pool.close().await;
 653        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 654            .await
 655            .log_err();
 656    }
 657
 658    #[cfg(test)]
 659    fn as_fake(&self) -> Option<&tests::FakeDb> {
 660        None
 661    }
 662}
 663
 664macro_rules! id_type {
 665    ($name:ident) => {
 666        #[derive(
 667            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 668        )]
 669        #[sqlx(transparent)]
 670        #[serde(transparent)]
 671        pub struct $name(pub i32);
 672
 673        impl $name {
 674            #[allow(unused)]
 675            pub const MAX: Self = Self(i32::MAX);
 676
 677            #[allow(unused)]
 678            pub fn from_proto(value: u64) -> Self {
 679                Self(value as i32)
 680            }
 681
 682            #[allow(unused)]
 683            pub fn to_proto(&self) -> u64 {
 684                self.0 as u64
 685            }
 686        }
 687
 688        impl std::fmt::Display for $name {
 689            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 690                self.0.fmt(f)
 691            }
 692        }
 693    };
 694}
 695
id_type!(UserId);
/// A row of the `users` table (loaded with `SELECT * FROM users ...`).
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// The `admin` column, written by `create_user` and `set_user_is_admin`.
    pub admin: bool,
}
 703
id_type!(OrgId);
/// A row of the `orgs` table (see `find_org_by_slug` / `create_org`).
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
 711
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// The owning org's id when `owner_is_user` is false (see `create_org_channel`);
    /// presumably a user id otherwise. TODO confirm.
    pub owner_id: i32,
    /// Presumably whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
 720
id_type!(MessageId);
/// A message row as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected as `sent_at AT TIME ZONE 'UTC'` by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (`Uuid::from_u128` of the `u128` passed to
    /// `create_channel_message`), used to de-duplicate inserts.
    pub nonce: Uuid,
}
 731
 732#[derive(Clone, Debug, PartialEq, Eq)]
 733pub enum Contact {
 734    Accepted {
 735        user_id: UserId,
 736        should_notify: bool,
 737    },
 738    Outgoing {
 739        user_id: UserId,
 740    },
 741    Incoming {
 742        user_id: UserId,
 743        should_notify: bool,
 744    },
 745}
 746
 747impl Contact {
 748    pub fn user_id(&self) -> UserId {
 749        match self {
 750            Contact::Accepted { user_id, .. } => *user_id,
 751            Contact::Outgoing { user_id } => *user_id,
 752            Contact::Incoming { user_id, .. } => *user_id,
 753        }
 754    }
 755}
 756
/// Describes a contact request from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
/// NOTE(review): this type is not constructed in the visible part of this file;
/// `should_notify` presumably mirrors `contacts.should_notify`. Confirm at call sites.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
 762
 763fn fuzzy_like_string(string: &str) -> String {
 764    let mut result = String::with_capacity(string.len() * 2 + 1);
 765    for c in string.chars() {
 766        if c.is_alphanumeric() {
 767            result.push('%');
 768            result.push(c);
 769        }
 770    }
 771    result.push('%');
 772    result
 773}
 774
 775#[cfg(test)]
 776pub mod tests {
 777    use super::*;
 778    use anyhow::anyhow;
 779    use collections::BTreeMap;
 780    use gpui::executor::Background;
 781    use lazy_static::lazy_static;
 782    use parking_lot::Mutex;
 783    use rand::prelude::*;
 784    use sqlx::{
 785        migrate::{MigrateDatabase, Migrator},
 786        Postgres,
 787    };
 788    use std::{path::Path, sync::Arc};
 789    use util::post_inc;
 790
 791    #[tokio::test(flavor = "multi_thread")]
 792    async fn test_get_users_by_ids() {
 793        for test_db in [
 794            TestDb::postgres().await,
 795            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 796        ] {
 797            let db = test_db.db();
 798
 799            let user = db.create_user("user", false).await.unwrap();
 800            let friend1 = db.create_user("friend-1", false).await.unwrap();
 801            let friend2 = db.create_user("friend-2", false).await.unwrap();
 802            let friend3 = db.create_user("friend-3", false).await.unwrap();
 803
 804            assert_eq!(
 805                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 806                    .await
 807                    .unwrap(),
 808                vec![
 809                    User {
 810                        id: user,
 811                        github_login: "user".to_string(),
 812                        admin: false,
 813                    },
 814                    User {
 815                        id: friend1,
 816                        github_login: "friend-1".to_string(),
 817                        admin: false,
 818                    },
 819                    User {
 820                        id: friend2,
 821                        github_login: "friend-2".to_string(),
 822                        admin: false,
 823                    },
 824                    User {
 825                        id: friend3,
 826                        github_login: "friend-3".to_string(),
 827                        admin: false,
 828                    }
 829                ]
 830            );
 831        }
 832    }
 833
 834    #[tokio::test(flavor = "multi_thread")]
 835    async fn test_recent_channel_messages() {
 836        for test_db in [
 837            TestDb::postgres().await,
 838            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 839        ] {
 840            let db = test_db.db();
 841            let user = db.create_user("user", false).await.unwrap();
 842            let org = db.create_org("org", "org").await.unwrap();
 843            let channel = db.create_org_channel(org, "channel").await.unwrap();
 844            for i in 0..10 {
 845                db.create_channel_message(
 846                    channel,
 847                    user,
 848                    &i.to_string(),
 849                    OffsetDateTime::now_utc(),
 850                    i,
 851                )
 852                .await
 853                .unwrap();
 854            }
 855
 856            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 857            assert_eq!(
 858                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 859                ["5", "6", "7", "8", "9"]
 860            );
 861
 862            let prev_messages = db
 863                .get_channel_messages(channel, 4, Some(messages[0].id))
 864                .await
 865                .unwrap();
 866            assert_eq!(
 867                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 868                ["1", "2", "3", "4"]
 869            );
 870        }
 871    }
 872
 873    #[tokio::test(flavor = "multi_thread")]
 874    async fn test_channel_message_nonces() {
 875        for test_db in [
 876            TestDb::postgres().await,
 877            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 878        ] {
 879            let db = test_db.db();
 880            let user = db.create_user("user", false).await.unwrap();
 881            let org = db.create_org("org", "org").await.unwrap();
 882            let channel = db.create_org_channel(org, "channel").await.unwrap();
 883
 884            let msg1_id = db
 885                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 886                .await
 887                .unwrap();
 888            let msg2_id = db
 889                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 890                .await
 891                .unwrap();
 892            let msg3_id = db
 893                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 894                .await
 895                .unwrap();
 896            let msg4_id = db
 897                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 898                .await
 899                .unwrap();
 900
 901            assert_ne!(msg1_id, msg2_id);
 902            assert_eq!(msg1_id, msg3_id);
 903            assert_eq!(msg2_id, msg4_id);
 904        }
 905    }
 906
 907    #[tokio::test(flavor = "multi_thread")]
 908    async fn test_create_access_tokens() {
 909        let test_db = TestDb::postgres().await;
 910        let db = test_db.db();
 911        let user = db.create_user("the-user", false).await.unwrap();
 912
 913        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 914        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 915        assert_eq!(
 916            db.get_access_token_hashes(user).await.unwrap(),
 917            &["h2".to_string(), "h1".to_string()]
 918        );
 919
 920        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 921        assert_eq!(
 922            db.get_access_token_hashes(user).await.unwrap(),
 923            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 924        );
 925
 926        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 927        assert_eq!(
 928            db.get_access_token_hashes(user).await.unwrap(),
 929            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 930        );
 931
 932        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 933        assert_eq!(
 934            db.get_access_token_hashes(user).await.unwrap(),
 935            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 936        );
 937    }
 938
 939    #[test]
 940    fn test_fuzzy_like_string() {
 941        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
 942        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
 943        assert_eq!(fuzzy_like_string(" z  "), "%z%");
 944    }
 945
 946    #[tokio::test(flavor = "multi_thread")]
 947    async fn test_fuzzy_search_users() {
 948        let test_db = TestDb::postgres().await;
 949        let db = test_db.db();
 950        for github_login in [
 951            "California",
 952            "colorado",
 953            "oregon",
 954            "washington",
 955            "florida",
 956            "delaware",
 957            "rhode-island",
 958        ] {
 959            db.create_user(github_login, false).await.unwrap();
 960        }
 961
 962        assert_eq!(
 963            fuzzy_search_user_names(db, "clr").await,
 964            &["colorado", "California"]
 965        );
 966        assert_eq!(
 967            fuzzy_search_user_names(db, "ro").await,
 968            &["rhode-island", "colorado", "oregon"],
 969        );
 970
 971        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
 972            db.fuzzy_search_users(query, 10)
 973                .await
 974                .unwrap()
 975                .into_iter()
 976                .map(|user| user.github_login)
 977                .collect::<Vec<_>>()
 978        }
 979    }
 980
 981    #[tokio::test(flavor = "multi_thread")]
 982    async fn test_add_contacts() {
 983        for test_db in [
 984            TestDb::postgres().await,
 985            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 986        ] {
 987            let db = test_db.db();
 988
 989            let user_1 = db.create_user("user1", false).await.unwrap();
 990            let user_2 = db.create_user("user2", false).await.unwrap();
 991            let user_3 = db.create_user("user3", false).await.unwrap();
 992
 993            // User starts with no contacts
 994            assert_eq!(
 995                db.get_contacts(user_1).await.unwrap(),
 996                vec![Contact::Accepted {
 997                    user_id: user_1,
 998                    should_notify: false
 999                }],
1000            );
1001
1002            // User requests a contact. Both users see the pending request.
1003            db.send_contact_request(user_1, user_2).await.unwrap();
1004            assert!(!db.has_contact(user_1, user_2).await.unwrap());
1005            assert!(!db.has_contact(user_2, user_1).await.unwrap());
1006            assert_eq!(
1007                db.get_contacts(user_1).await.unwrap(),
1008                &[
1009                    Contact::Accepted {
1010                        user_id: user_1,
1011                        should_notify: false
1012                    },
1013                    Contact::Outgoing { user_id: user_2 }
1014                ],
1015            );
1016            assert_eq!(
1017                db.get_contacts(user_2).await.unwrap(),
1018                &[
1019                    Contact::Incoming {
1020                        user_id: user_1,
1021                        should_notify: true
1022                    },
1023                    Contact::Accepted {
1024                        user_id: user_2,
1025                        should_notify: false
1026                    },
1027                ]
1028            );
1029
1030            // User 2 dismisses the contact request notification without accepting or rejecting.
1031            // We shouldn't notify them again.
1032            db.dismiss_contact_notification(user_1, user_2)
1033                .await
1034                .unwrap_err();
1035            db.dismiss_contact_notification(user_2, user_1)
1036                .await
1037                .unwrap();
1038            assert_eq!(
1039                db.get_contacts(user_2).await.unwrap(),
1040                &[
1041                    Contact::Incoming {
1042                        user_id: user_1,
1043                        should_notify: false
1044                    },
1045                    Contact::Accepted {
1046                        user_id: user_2,
1047                        should_notify: false
1048                    },
1049                ]
1050            );
1051
1052            // User can't accept their own contact request
1053            db.respond_to_contact_request(user_1, user_2, true)
1054                .await
1055                .unwrap_err();
1056
1057            // User accepts a contact request. Both users see the contact.
1058            db.respond_to_contact_request(user_2, user_1, true)
1059                .await
1060                .unwrap();
1061            assert_eq!(
1062                db.get_contacts(user_1).await.unwrap(),
1063                &[
1064                    Contact::Accepted {
1065                        user_id: user_1,
1066                        should_notify: false
1067                    },
1068                    Contact::Accepted {
1069                        user_id: user_2,
1070                        should_notify: true
1071                    }
1072                ],
1073            );
1074            assert!(db.has_contact(user_1, user_2).await.unwrap());
1075            assert!(db.has_contact(user_2, user_1).await.unwrap());
1076            assert_eq!(
1077                db.get_contacts(user_2).await.unwrap(),
1078                &[
1079                    Contact::Accepted {
1080                        user_id: user_1,
1081                        should_notify: false,
1082                    },
1083                    Contact::Accepted {
1084                        user_id: user_2,
1085                        should_notify: false,
1086                    },
1087                ]
1088            );
1089
1090            // Users cannot re-request existing contacts.
1091            db.send_contact_request(user_1, user_2).await.unwrap_err();
1092            db.send_contact_request(user_2, user_1).await.unwrap_err();
1093
1094            // Users can't dismiss notifications of them accepting other users' requests.
1095            db.dismiss_contact_notification(user_2, user_1)
1096                .await
1097                .unwrap_err();
1098            assert_eq!(
1099                db.get_contacts(user_1).await.unwrap(),
1100                &[
1101                    Contact::Accepted {
1102                        user_id: user_1,
1103                        should_notify: false
1104                    },
1105                    Contact::Accepted {
1106                        user_id: user_2,
1107                        should_notify: true,
1108                    },
1109                ]
1110            );
1111
1112            // Users can dismiss notifications of other users accepting their requests.
1113            db.dismiss_contact_notification(user_1, user_2)
1114                .await
1115                .unwrap();
1116            assert_eq!(
1117                db.get_contacts(user_1).await.unwrap(),
1118                &[
1119                    Contact::Accepted {
1120                        user_id: user_1,
1121                        should_notify: false
1122                    },
1123                    Contact::Accepted {
1124                        user_id: user_2,
1125                        should_notify: false,
1126                    },
1127                ]
1128            );
1129
1130            // Users send each other concurrent contact requests and
1131            // see that they are immediately accepted.
1132            db.send_contact_request(user_1, user_3).await.unwrap();
1133            db.send_contact_request(user_3, user_1).await.unwrap();
1134            assert_eq!(
1135                db.get_contacts(user_1).await.unwrap(),
1136                &[
1137                    Contact::Accepted {
1138                        user_id: user_1,
1139                        should_notify: false
1140                    },
1141                    Contact::Accepted {
1142                        user_id: user_2,
1143                        should_notify: false,
1144                    },
1145                    Contact::Accepted {
1146                        user_id: user_3,
1147                        should_notify: false
1148                    },
1149                ]
1150            );
1151            assert_eq!(
1152                db.get_contacts(user_3).await.unwrap(),
1153                &[
1154                    Contact::Accepted {
1155                        user_id: user_1,
1156                        should_notify: false
1157                    },
1158                    Contact::Accepted {
1159                        user_id: user_3,
1160                        should_notify: false
1161                    }
1162                ],
1163            );
1164
1165            // User declines a contact request. Both users see that it is gone.
1166            db.send_contact_request(user_2, user_3).await.unwrap();
1167            db.respond_to_contact_request(user_3, user_2, false)
1168                .await
1169                .unwrap();
1170            assert!(!db.has_contact(user_2, user_3).await.unwrap());
1171            assert!(!db.has_contact(user_3, user_2).await.unwrap());
1172            assert_eq!(
1173                db.get_contacts(user_2).await.unwrap(),
1174                &[
1175                    Contact::Accepted {
1176                        user_id: user_1,
1177                        should_notify: false
1178                    },
1179                    Contact::Accepted {
1180                        user_id: user_2,
1181                        should_notify: false
1182                    }
1183                ]
1184            );
1185            assert_eq!(
1186                db.get_contacts(user_3).await.unwrap(),
1187                &[
1188                    Contact::Accepted {
1189                        user_id: user_1,
1190                        should_notify: false
1191                    },
1192                    Contact::Accepted {
1193                        user_id: user_3,
1194                        should_notify: false
1195                    }
1196                ],
1197            );
1198        }
1199    }
1200
    /// A database handle for tests, backed either by a freshly created, randomly named
    /// Postgres database (`TestDb::postgres`) or by an in-memory [`FakeDb`] (`TestDb::fake`).
    pub struct TestDb {
        /// Set to `Some` by both constructors; `Drop` takes it out to run `Db::teardown`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake backend.
        pub url: String,
    }
1205
1206    impl TestDb {
1207        pub async fn postgres() -> Self {
1208            lazy_static! {
1209                static ref LOCK: Mutex<()> = Mutex::new(());
1210            }
1211
1212            let _guard = LOCK.lock();
1213            let mut rng = StdRng::from_entropy();
1214            let name = format!("zed-test-{}", rng.gen::<u128>());
1215            let url = format!("postgres://postgres@localhost/{}", name);
1216            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1217            Postgres::create_database(&url)
1218                .await
1219                .expect("failed to create test db");
1220            let db = PostgresDb::new(&url, 5).await.unwrap();
1221            let migrator = Migrator::new(migrations_path).await.unwrap();
1222            migrator.run(&db.pool).await.unwrap();
1223            Self {
1224                db: Some(Arc::new(db)),
1225                url,
1226            }
1227        }
1228
1229        pub fn fake(background: Arc<Background>) -> Self {
1230            Self {
1231                db: Some(Arc::new(FakeDb::new(background))),
1232                url: Default::default(),
1233            }
1234        }
1235
1236        pub fn db(&self) -> &Arc<dyn Db> {
1237            self.db.as_ref().unwrap()
1238        }
1239    }
1240
1241    impl Drop for TestDb {
1242        fn drop(&mut self) {
1243            if let Some(db) = self.db.take() {
1244                futures::executor::block_on(db.teardown(&self.url));
1245            }
1246        }
1247    }
1248
    /// In-memory [`Db`] implementation for tests. Each table is a separately locked map,
    /// and ids are assigned from per-table counters (initialized to 1 by `FakeDb::new`).
    pub struct FakeDb {
        /// Executor whose `simulate_random_delay` is awaited by many of the operations.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Keyed by (org, user); the value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Keyed by (channel, user); the value is the member's admin flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact requests and accepted contacts, in insertion order.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each table.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1263
    /// One contact relationship (pending or accepted) stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the original request.
        pub requester_id: UserId,
        /// User the request was sent to.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// While pending: the responder should be notified of the incoming request.
        /// Once accepted: the requester should be notified of the acceptance.
        pub should_notify: bool,
    }
1271
1272    impl FakeDb {
1273        pub fn new(background: Arc<Background>) -> Self {
1274            Self {
1275                background,
1276                users: Default::default(),
1277                next_user_id: Mutex::new(1),
1278                orgs: Default::default(),
1279                next_org_id: Mutex::new(1),
1280                org_memberships: Default::default(),
1281                channels: Default::default(),
1282                next_channel_id: Mutex::new(1),
1283                channel_memberships: Default::default(),
1284                channel_messages: Default::default(),
1285                next_channel_message_id: Mutex::new(1),
1286                contacts: Default::default(),
1287            }
1288        }
1289    }
1290
1291    #[async_trait]
1292    impl Db for FakeDb {
1293        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1294            self.background.simulate_random_delay().await;
1295
1296            let mut users = self.users.lock();
1297            if let Some(user) = users
1298                .values()
1299                .find(|user| user.github_login == github_login)
1300            {
1301                Ok(user.id)
1302            } else {
1303                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1304                users.insert(
1305                    user_id,
1306                    User {
1307                        id: user_id,
1308                        github_login: github_login.to_string(),
1309                        admin,
1310                    },
1311                );
1312                Ok(user_id)
1313            }
1314        }
1315
1316        async fn get_all_users(&self) -> Result<Vec<User>> {
1317            unimplemented!()
1318        }
1319
1320        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1321            unimplemented!()
1322        }
1323
1324        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1325            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1326        }
1327
1328        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1329            self.background.simulate_random_delay().await;
1330            let users = self.users.lock();
1331            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1332        }
1333
1334        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1335            Ok(self
1336                .users
1337                .lock()
1338                .values()
1339                .find(|user| user.github_login == github_login)
1340                .cloned())
1341        }
1342
1343        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1344            unimplemented!()
1345        }
1346
1347        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1348            unimplemented!()
1349        }
1350
1351        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1352            self.background.simulate_random_delay().await;
1353            let mut contacts = vec![Contact::Accepted {
1354                user_id: id,
1355                should_notify: false,
1356            }];
1357
1358            for contact in self.contacts.lock().iter() {
1359                if contact.requester_id == id {
1360                    if contact.accepted {
1361                        contacts.push(Contact::Accepted {
1362                            user_id: contact.responder_id,
1363                            should_notify: contact.should_notify,
1364                        });
1365                    } else {
1366                        contacts.push(Contact::Outgoing {
1367                            user_id: contact.responder_id,
1368                        });
1369                    }
1370                } else if contact.responder_id == id {
1371                    if contact.accepted {
1372                        contacts.push(Contact::Accepted {
1373                            user_id: contact.requester_id,
1374                            should_notify: false,
1375                        });
1376                    } else {
1377                        contacts.push(Contact::Incoming {
1378                            user_id: contact.requester_id,
1379                            should_notify: contact.should_notify,
1380                        });
1381                    }
1382                }
1383            }
1384
1385            contacts.sort_unstable_by_key(|contact| contact.user_id());
1386            Ok(contacts)
1387        }
1388
1389        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1390            self.background.simulate_random_delay().await;
1391            Ok(self.contacts.lock().iter().any(|contact| {
1392                contact.accepted
1393                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1394                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1395            }))
1396        }
1397
1398        async fn send_contact_request(
1399            &self,
1400            requester_id: UserId,
1401            responder_id: UserId,
1402        ) -> Result<()> {
1403            let mut contacts = self.contacts.lock();
1404            for contact in contacts.iter_mut() {
1405                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1406                    if contact.accepted {
1407                        Err(anyhow!("contact already exists"))?;
1408                    } else {
1409                        Err(anyhow!("contact already requested"))?;
1410                    }
1411                }
1412                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1413                    if contact.accepted {
1414                        Err(anyhow!("contact already exists"))?;
1415                    } else {
1416                        contact.accepted = true;
1417                        contact.should_notify = false;
1418                        return Ok(());
1419                    }
1420                }
1421            }
1422            contacts.push(FakeContact {
1423                requester_id,
1424                responder_id,
1425                accepted: false,
1426                should_notify: true,
1427            });
1428            Ok(())
1429        }
1430
1431        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1432            self.contacts.lock().retain(|contact| {
1433                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1434            });
1435            Ok(())
1436        }
1437
1438        async fn dismiss_contact_notification(
1439            &self,
1440            user_id: UserId,
1441            contact_user_id: UserId,
1442        ) -> Result<()> {
1443            let mut contacts = self.contacts.lock();
1444            for contact in contacts.iter_mut() {
1445                if contact.requester_id == contact_user_id
1446                    && contact.responder_id == user_id
1447                    && !contact.accepted
1448                {
1449                    contact.should_notify = false;
1450                    return Ok(());
1451                }
1452                if contact.requester_id == user_id
1453                    && contact.responder_id == contact_user_id
1454                    && contact.accepted
1455                {
1456                    contact.should_notify = false;
1457                    return Ok(());
1458                }
1459            }
1460            Err(anyhow!("no such notification"))
1461        }
1462
1463        async fn respond_to_contact_request(
1464            &self,
1465            responder_id: UserId,
1466            requester_id: UserId,
1467            accept: bool,
1468        ) -> Result<()> {
1469            let mut contacts = self.contacts.lock();
1470            for (ix, contact) in contacts.iter_mut().enumerate() {
1471                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1472                    if contact.accepted {
1473                        return Err(anyhow!("contact already confirmed"));
1474                    }
1475                    if accept {
1476                        contact.accepted = true;
1477                        contact.should_notify = true;
1478                    } else {
1479                        contacts.remove(ix);
1480                    }
1481                    return Ok(());
1482                }
1483            }
1484            Err(anyhow!("no such contact request"))
1485        }
1486
1487        async fn create_access_token_hash(
1488            &self,
1489            _user_id: UserId,
1490            _access_token_hash: &str,
1491            _max_access_token_count: usize,
1492        ) -> Result<()> {
1493            unimplemented!()
1494        }
1495
1496        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1497            unimplemented!()
1498        }
1499
1500        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1501            unimplemented!()
1502        }
1503
1504        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1505            self.background.simulate_random_delay().await;
1506            let mut orgs = self.orgs.lock();
1507            if orgs.values().any(|org| org.slug == slug) {
1508                Err(anyhow!("org already exists"))
1509            } else {
1510                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1511                orgs.insert(
1512                    org_id,
1513                    Org {
1514                        id: org_id,
1515                        name: name.to_string(),
1516                        slug: slug.to_string(),
1517                    },
1518                );
1519                Ok(org_id)
1520            }
1521        }
1522
1523        async fn add_org_member(
1524            &self,
1525            org_id: OrgId,
1526            user_id: UserId,
1527            is_admin: bool,
1528        ) -> Result<()> {
1529            self.background.simulate_random_delay().await;
1530            if !self.orgs.lock().contains_key(&org_id) {
1531                return Err(anyhow!("org does not exist"));
1532            }
1533            if !self.users.lock().contains_key(&user_id) {
1534                return Err(anyhow!("user does not exist"));
1535            }
1536
1537            self.org_memberships
1538                .lock()
1539                .entry((org_id, user_id))
1540                .or_insert(is_admin);
1541            Ok(())
1542        }
1543
1544        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1545            self.background.simulate_random_delay().await;
1546            if !self.orgs.lock().contains_key(&org_id) {
1547                return Err(anyhow!("org does not exist"));
1548            }
1549
1550            let mut channels = self.channels.lock();
1551            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1552            channels.insert(
1553                channel_id,
1554                Channel {
1555                    id: channel_id,
1556                    name: name.to_string(),
1557                    owner_id: org_id.0,
1558                    owner_is_user: false,
1559                },
1560            );
1561            Ok(channel_id)
1562        }
1563
1564        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1565            self.background.simulate_random_delay().await;
1566            Ok(self
1567                .channels
1568                .lock()
1569                .values()
1570                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1571                .cloned()
1572                .collect())
1573        }
1574
1575        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1576            self.background.simulate_random_delay().await;
1577            let channels = self.channels.lock();
1578            let memberships = self.channel_memberships.lock();
1579            Ok(channels
1580                .values()
1581                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1582                .cloned()
1583                .collect())
1584        }
1585
1586        async fn can_user_access_channel(
1587            &self,
1588            user_id: UserId,
1589            channel_id: ChannelId,
1590        ) -> Result<bool> {
1591            self.background.simulate_random_delay().await;
1592            Ok(self
1593                .channel_memberships
1594                .lock()
1595                .contains_key(&(channel_id, user_id)))
1596        }
1597
1598        async fn add_channel_member(
1599            &self,
1600            channel_id: ChannelId,
1601            user_id: UserId,
1602            is_admin: bool,
1603        ) -> Result<()> {
1604            self.background.simulate_random_delay().await;
1605            if !self.channels.lock().contains_key(&channel_id) {
1606                return Err(anyhow!("channel does not exist"));
1607            }
1608            if !self.users.lock().contains_key(&user_id) {
1609                return Err(anyhow!("user does not exist"));
1610            }
1611
1612            self.channel_memberships
1613                .lock()
1614                .entry((channel_id, user_id))
1615                .or_insert(is_admin);
1616            Ok(())
1617        }
1618
1619        async fn create_channel_message(
1620            &self,
1621            channel_id: ChannelId,
1622            sender_id: UserId,
1623            body: &str,
1624            timestamp: OffsetDateTime,
1625            nonce: u128,
1626        ) -> Result<MessageId> {
1627            self.background.simulate_random_delay().await;
1628            if !self.channels.lock().contains_key(&channel_id) {
1629                return Err(anyhow!("channel does not exist"));
1630            }
1631            if !self.users.lock().contains_key(&sender_id) {
1632                return Err(anyhow!("user does not exist"));
1633            }
1634
1635            let mut messages = self.channel_messages.lock();
1636            if let Some(message) = messages
1637                .values()
1638                .find(|message| message.nonce.as_u128() == nonce)
1639            {
1640                Ok(message.id)
1641            } else {
1642                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1643                messages.insert(
1644                    message_id,
1645                    ChannelMessage {
1646                        id: message_id,
1647                        channel_id,
1648                        sender_id,
1649                        body: body.to_string(),
1650                        sent_at: timestamp,
1651                        nonce: Uuid::from_u128(nonce),
1652                    },
1653                );
1654                Ok(message_id)
1655            }
1656        }
1657
1658        async fn get_channel_messages(
1659            &self,
1660            channel_id: ChannelId,
1661            count: usize,
1662            before_id: Option<MessageId>,
1663        ) -> Result<Vec<ChannelMessage>> {
1664            let mut messages = self
1665                .channel_messages
1666                .lock()
1667                .values()
1668                .rev()
1669                .filter(|message| {
1670                    message.channel_id == channel_id
1671                        && message.id < before_id.unwrap_or(MessageId::MAX)
1672                })
1673                .take(count)
1674                .cloned()
1675                .collect::<Vec<_>>();
1676            messages.sort_unstable_by_key(|message| message.id);
1677            Ok(messages)
1678        }
1679
1680        async fn teardown(&self, _: &str) {}
1681
1682        #[cfg(test)]
1683        fn as_fake(&self) -> Option<&FakeDb> {
1684            Some(self)
1685        }
1686    }
1687}