db.rs

   1use anyhow::{anyhow, Context, Result};
   2use async_trait::async_trait;
   3use futures::StreamExt;
   4use serde::Serialize;
   5pub use sqlx::postgres::PgPoolOptions as DbOptions;
   6use sqlx::{types::Uuid, FromRow};
   7use time::OffsetDateTime;
   8
   9#[async_trait]
  10pub trait Db: Send + Sync {
  11    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
  12    async fn get_all_users(&self) -> Result<Vec<User>>;
  13    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  14    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  15    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  16    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  17    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  18    async fn destroy_user(&self, id: UserId) -> Result<()>;
  19
  20    async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
  21    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  22    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  23    async fn dismiss_contact_request(
  24        &self,
  25        responder_id: UserId,
  26        requester_id: UserId,
  27    ) -> Result<()>;
  28    async fn respond_to_contact_request(
  29        &self,
  30        responder_id: UserId,
  31        requester_id: UserId,
  32        accept: bool,
  33    ) -> Result<()>;
  34
  35    async fn create_access_token_hash(
  36        &self,
  37        user_id: UserId,
  38        access_token_hash: &str,
  39        max_access_token_count: usize,
  40    ) -> Result<()>;
  41    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  42    #[cfg(any(test, feature = "seed-support"))]
  43
  44    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  45    #[cfg(any(test, feature = "seed-support"))]
  46    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  47    #[cfg(any(test, feature = "seed-support"))]
  48    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  49    #[cfg(any(test, feature = "seed-support"))]
  50    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  51    #[cfg(any(test, feature = "seed-support"))]
  52
  53    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  54    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  55    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  56        -> Result<bool>;
  57    #[cfg(any(test, feature = "seed-support"))]
  58    async fn add_channel_member(
  59        &self,
  60        channel_id: ChannelId,
  61        user_id: UserId,
  62        is_admin: bool,
  63    ) -> Result<()>;
  64    async fn create_channel_message(
  65        &self,
  66        channel_id: ChannelId,
  67        sender_id: UserId,
  68        body: &str,
  69        timestamp: OffsetDateTime,
  70        nonce: u128,
  71    ) -> Result<MessageId>;
  72    async fn get_channel_messages(
  73        &self,
  74        channel_id: ChannelId,
  75        count: usize,
  76        before_id: Option<MessageId>,
  77    ) -> Result<Vec<ChannelMessage>>;
  78    #[cfg(test)]
  79    async fn teardown(&self, url: &str);
  80    #[cfg(test)]
  81    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
  82}
  83
/// `Db` implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // Shared sqlx pool; every query in `impl Db for PostgresDb` runs through it.
    pool: sqlx::PgPool,
}
  87
  88impl PostgresDb {
  89    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  90        let pool = DbOptions::new()
  91            .max_connections(max_connections)
  92            .connect(&url)
  93            .await
  94            .context("failed to connect to postgres database")?;
  95        Ok(Self { pool })
  96    }
  97}
  98
  99#[async_trait]
 100impl Db for PostgresDb {
 101    // users
 102
 103    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 104        let query = "
 105            INSERT INTO users (github_login, admin)
 106            VALUES ($1, $2)
 107            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 108            RETURNING id
 109        ";
 110        Ok(sqlx::query_scalar(query)
 111            .bind(github_login)
 112            .bind(admin)
 113            .fetch_one(&self.pool)
 114            .await
 115            .map(UserId)?)
 116    }
 117
 118    async fn get_all_users(&self) -> Result<Vec<User>> {
 119        let query = "SELECT * FROM users ORDER BY github_login ASC";
 120        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 121    }
 122
 123    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 124        let like_string = fuzzy_like_string(name_query);
 125        let query = "
 126            SELECT users.*
 127            FROM users
 128            WHERE github_login ILIKE $1
 129            ORDER BY github_login <-> $2
 130            LIMIT $3
 131        ";
 132        Ok(sqlx::query_as(query)
 133            .bind(like_string)
 134            .bind(name_query)
 135            .bind(limit)
 136            .fetch_all(&self.pool)
 137            .await?)
 138    }
 139
 140    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 141        let users = self.get_users_by_ids(vec![id]).await?;
 142        Ok(users.into_iter().next())
 143    }
 144
 145    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 146        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 147        let query = "
 148            SELECT users.*
 149            FROM users
 150            WHERE users.id = ANY ($1)
 151        ";
 152        Ok(sqlx::query_as(query)
 153            .bind(&ids)
 154            .fetch_all(&self.pool)
 155            .await?)
 156    }
 157
 158    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 159        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 160        Ok(sqlx::query_as(query)
 161            .bind(github_login)
 162            .fetch_optional(&self.pool)
 163            .await?)
 164    }
 165
 166    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 167        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 168        Ok(sqlx::query(query)
 169            .bind(is_admin)
 170            .bind(id.0)
 171            .execute(&self.pool)
 172            .await
 173            .map(drop)?)
 174    }
 175
 176    async fn destroy_user(&self, id: UserId) -> Result<()> {
 177        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 178        sqlx::query(query)
 179            .bind(id.0)
 180            .execute(&self.pool)
 181            .await
 182            .map(drop)?;
 183        let query = "DELETE FROM users WHERE id = $1;";
 184        Ok(sqlx::query(query)
 185            .bind(id.0)
 186            .execute(&self.pool)
 187            .await
 188            .map(drop)?)
 189    }
 190
 191    // contacts
 192
 193    async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
 194        let query = "
 195            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 196            FROM contacts
 197            WHERE user_id_a = $1 OR user_id_b = $1;
 198        ";
 199
 200        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 201            .bind(user_id)
 202            .fetch(&self.pool);
 203
 204        let mut current = vec![user_id];
 205        let mut outgoing_requests = Vec::new();
 206        let mut incoming_requests = Vec::new();
 207        while let Some(row) = rows.next().await {
 208            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 209
 210            if user_id_a == user_id {
 211                if accepted {
 212                    current.push(user_id_b);
 213                } else if a_to_b {
 214                    outgoing_requests.push(user_id_b);
 215                } else {
 216                    incoming_requests.push(IncomingContactRequest {
 217                        requester_id: user_id_b,
 218                        should_notify,
 219                    });
 220                }
 221            } else {
 222                if accepted {
 223                    current.push(user_id_a);
 224                } else if a_to_b {
 225                    incoming_requests.push(IncomingContactRequest {
 226                        requester_id: user_id_a,
 227                        should_notify,
 228                    });
 229                } else {
 230                    outgoing_requests.push(user_id_a);
 231                }
 232            }
 233        }
 234
 235        current.sort_unstable();
 236        outgoing_requests.sort_unstable();
 237        incoming_requests.sort_unstable();
 238
 239        Ok(Contacts {
 240            current,
 241            outgoing_requests,
 242            incoming_requests,
 243        })
 244    }
 245
 246    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 247        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 248            (sender_id, receiver_id, true)
 249        } else {
 250            (receiver_id, sender_id, false)
 251        };
 252        let query = "
 253            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 254            VALUES ($1, $2, $3, 'f', 't')
 255            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 256            SET
 257                accepted = 't'
 258            WHERE
 259                NOT contacts.accepted AND
 260                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 261                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 262        ";
 263        let result = sqlx::query(query)
 264            .bind(id_a.0)
 265            .bind(id_b.0)
 266            .bind(a_to_b)
 267            .execute(&self.pool)
 268            .await?;
 269
 270        if result.rows_affected() == 1 {
 271            Ok(())
 272        } else {
 273            Err(anyhow!("contact already requested"))
 274        }
 275    }
 276
 277    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 278        let (id_a, id_b) = if responder_id < requester_id {
 279            (responder_id, requester_id)
 280        } else {
 281            (requester_id, responder_id)
 282        };
 283        let query = "
 284            DELETE FROM contacts
 285            WHERE user_id_a = $1 AND user_id_b = $2;
 286        ";
 287        let result = sqlx::query(query)
 288            .bind(id_a.0)
 289            .bind(id_b.0)
 290            .execute(&self.pool)
 291            .await?;
 292
 293        if result.rows_affected() == 1 {
 294            Ok(())
 295        } else {
 296            Err(anyhow!("no such contact"))
 297        }
 298    }
 299
 300    async fn dismiss_contact_request(
 301        &self,
 302        responder_id: UserId,
 303        requester_id: UserId,
 304    ) -> Result<()> {
 305        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 306            (responder_id, requester_id, false)
 307        } else {
 308            (requester_id, responder_id, true)
 309        };
 310
 311        let query = "
 312            UPDATE contacts
 313            SET should_notify = 'f'
 314            WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 315        ";
 316
 317        let result = sqlx::query(query)
 318            .bind(id_a.0)
 319            .bind(id_b.0)
 320            .bind(a_to_b)
 321            .execute(&self.pool)
 322            .await?;
 323
 324        if result.rows_affected() == 0 {
 325            Err(anyhow!("no such contact request"))?;
 326        }
 327
 328        Ok(())
 329    }
 330
 331    async fn respond_to_contact_request(
 332        &self,
 333        responder_id: UserId,
 334        requester_id: UserId,
 335        accept: bool,
 336    ) -> Result<()> {
 337        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 338            (responder_id, requester_id, false)
 339        } else {
 340            (requester_id, responder_id, true)
 341        };
 342        let result = if accept {
 343            let query = "
 344                UPDATE contacts
 345                SET accepted = 't', should_notify = 'f'
 346                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 347            ";
 348            sqlx::query(query)
 349                .bind(id_a.0)
 350                .bind(id_b.0)
 351                .bind(a_to_b)
 352                .execute(&self.pool)
 353                .await?
 354        } else {
 355            let query = "
 356                DELETE FROM contacts
 357                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 358            ";
 359            sqlx::query(query)
 360                .bind(id_a.0)
 361                .bind(id_b.0)
 362                .bind(a_to_b)
 363                .execute(&self.pool)
 364                .await?
 365        };
 366        if result.rows_affected() == 1 {
 367            Ok(())
 368        } else {
 369            Err(anyhow!("no such contact request"))
 370        }
 371    }
 372
 373    // access tokens
 374
 375    async fn create_access_token_hash(
 376        &self,
 377        user_id: UserId,
 378        access_token_hash: &str,
 379        max_access_token_count: usize,
 380    ) -> Result<()> {
 381        let insert_query = "
 382            INSERT INTO access_tokens (user_id, hash)
 383            VALUES ($1, $2);
 384        ";
 385        let cleanup_query = "
 386            DELETE FROM access_tokens
 387            WHERE id IN (
 388                SELECT id from access_tokens
 389                WHERE user_id = $1
 390                ORDER BY id DESC
 391                OFFSET $3
 392            )
 393        ";
 394
 395        let mut tx = self.pool.begin().await?;
 396        sqlx::query(insert_query)
 397            .bind(user_id.0)
 398            .bind(access_token_hash)
 399            .execute(&mut tx)
 400            .await?;
 401        sqlx::query(cleanup_query)
 402            .bind(user_id.0)
 403            .bind(access_token_hash)
 404            .bind(max_access_token_count as u32)
 405            .execute(&mut tx)
 406            .await?;
 407        Ok(tx.commit().await?)
 408    }
 409
 410    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 411        let query = "
 412            SELECT hash
 413            FROM access_tokens
 414            WHERE user_id = $1
 415            ORDER BY id DESC
 416        ";
 417        Ok(sqlx::query_scalar(query)
 418            .bind(user_id.0)
 419            .fetch_all(&self.pool)
 420            .await?)
 421    }
 422
 423    // orgs
 424
 425    #[allow(unused)] // Help rust-analyzer
 426    #[cfg(any(test, feature = "seed-support"))]
 427    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 428        let query = "
 429            SELECT *
 430            FROM orgs
 431            WHERE slug = $1
 432        ";
 433        Ok(sqlx::query_as(query)
 434            .bind(slug)
 435            .fetch_optional(&self.pool)
 436            .await?)
 437    }
 438
 439    #[cfg(any(test, feature = "seed-support"))]
 440    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 441        let query = "
 442            INSERT INTO orgs (name, slug)
 443            VALUES ($1, $2)
 444            RETURNING id
 445        ";
 446        Ok(sqlx::query_scalar(query)
 447            .bind(name)
 448            .bind(slug)
 449            .fetch_one(&self.pool)
 450            .await
 451            .map(OrgId)?)
 452    }
 453
 454    #[cfg(any(test, feature = "seed-support"))]
 455    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 456        let query = "
 457            INSERT INTO org_memberships (org_id, user_id, admin)
 458            VALUES ($1, $2, $3)
 459            ON CONFLICT DO NOTHING
 460        ";
 461        Ok(sqlx::query(query)
 462            .bind(org_id.0)
 463            .bind(user_id.0)
 464            .bind(is_admin)
 465            .execute(&self.pool)
 466            .await
 467            .map(drop)?)
 468    }
 469
 470    // channels
 471
 472    #[cfg(any(test, feature = "seed-support"))]
 473    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 474        let query = "
 475            INSERT INTO channels (owner_id, owner_is_user, name)
 476            VALUES ($1, false, $2)
 477            RETURNING id
 478        ";
 479        Ok(sqlx::query_scalar(query)
 480            .bind(org_id.0)
 481            .bind(name)
 482            .fetch_one(&self.pool)
 483            .await
 484            .map(ChannelId)?)
 485    }
 486
 487    #[allow(unused)] // Help rust-analyzer
 488    #[cfg(any(test, feature = "seed-support"))]
 489    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 490        let query = "
 491            SELECT *
 492            FROM channels
 493            WHERE
 494                channels.owner_is_user = false AND
 495                channels.owner_id = $1
 496        ";
 497        Ok(sqlx::query_as(query)
 498            .bind(org_id.0)
 499            .fetch_all(&self.pool)
 500            .await?)
 501    }
 502
 503    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 504        let query = "
 505            SELECT
 506                channels.*
 507            FROM
 508                channel_memberships, channels
 509            WHERE
 510                channel_memberships.user_id = $1 AND
 511                channel_memberships.channel_id = channels.id
 512        ";
 513        Ok(sqlx::query_as(query)
 514            .bind(user_id.0)
 515            .fetch_all(&self.pool)
 516            .await?)
 517    }
 518
 519    async fn can_user_access_channel(
 520        &self,
 521        user_id: UserId,
 522        channel_id: ChannelId,
 523    ) -> Result<bool> {
 524        let query = "
 525            SELECT id
 526            FROM channel_memberships
 527            WHERE user_id = $1 AND channel_id = $2
 528            LIMIT 1
 529        ";
 530        Ok(sqlx::query_scalar::<_, i32>(query)
 531            .bind(user_id.0)
 532            .bind(channel_id.0)
 533            .fetch_optional(&self.pool)
 534            .await
 535            .map(|e| e.is_some())?)
 536    }
 537
 538    #[cfg(any(test, feature = "seed-support"))]
 539    async fn add_channel_member(
 540        &self,
 541        channel_id: ChannelId,
 542        user_id: UserId,
 543        is_admin: bool,
 544    ) -> Result<()> {
 545        let query = "
 546            INSERT INTO channel_memberships (channel_id, user_id, admin)
 547            VALUES ($1, $2, $3)
 548            ON CONFLICT DO NOTHING
 549        ";
 550        Ok(sqlx::query(query)
 551            .bind(channel_id.0)
 552            .bind(user_id.0)
 553            .bind(is_admin)
 554            .execute(&self.pool)
 555            .await
 556            .map(drop)?)
 557    }
 558
 559    // messages
 560
 561    async fn create_channel_message(
 562        &self,
 563        channel_id: ChannelId,
 564        sender_id: UserId,
 565        body: &str,
 566        timestamp: OffsetDateTime,
 567        nonce: u128,
 568    ) -> Result<MessageId> {
 569        let query = "
 570            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 571            VALUES ($1, $2, $3, $4, $5)
 572            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 573            RETURNING id
 574        ";
 575        Ok(sqlx::query_scalar(query)
 576            .bind(channel_id.0)
 577            .bind(sender_id.0)
 578            .bind(body)
 579            .bind(timestamp)
 580            .bind(Uuid::from_u128(nonce))
 581            .fetch_one(&self.pool)
 582            .await
 583            .map(MessageId)?)
 584    }
 585
 586    async fn get_channel_messages(
 587        &self,
 588        channel_id: ChannelId,
 589        count: usize,
 590        before_id: Option<MessageId>,
 591    ) -> Result<Vec<ChannelMessage>> {
 592        let query = r#"
 593            SELECT * FROM (
 594                SELECT
 595                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 596                FROM
 597                    channel_messages
 598                WHERE
 599                    channel_id = $1 AND
 600                    id < $2
 601                ORDER BY id DESC
 602                LIMIT $3
 603            ) as recent_messages
 604            ORDER BY id ASC
 605        "#;
 606        Ok(sqlx::query_as(query)
 607            .bind(channel_id.0)
 608            .bind(before_id.unwrap_or(MessageId::MAX))
 609            .bind(count as i64)
 610            .fetch_all(&self.pool)
 611            .await?)
 612    }
 613
 614    #[cfg(test)]
 615    async fn teardown(&self, url: &str) {
 616        use util::ResultExt;
 617
 618        let query = "
 619            SELECT pg_terminate_backend(pg_stat_activity.pid)
 620            FROM pg_stat_activity
 621            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 622        ";
 623        sqlx::query(query).execute(&self.pool).await.log_err();
 624        self.pool.close().await;
 625        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 626            .await
 627            .log_err();
 628    }
 629
 630    #[cfg(test)]
 631    fn as_fake(&self) -> Option<&tests::FakeDb> {
 632        None
 633    }
 634}
 635
// Defines a `Copy` newtype around an `i32` database id, with a `MAX` constant,
// protobuf-style `u64` conversions and a `Display` impl.
macro_rules! id_type {
    ($name:ident) => {
        // `transparent` makes the id encode/decode (sqlx) and serialize (serde)
        // as the bare inner integer.
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id; e.g. `get_channel_messages` binds
            // `MessageId::MAX` as the upper bound when no `before_id` is given.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` truncates to the low 32 bits, so values above
            // `i32::MAX` silently become other (possibly negative) ids.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE(review): a negative id sign-extends to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        // Displays the bare inner integer.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
 667
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `create_user` treats it as the unique key.
    pub github_login: String,
    /// Admin flag, settable via `set_user_is_admin`.
    pub admin: bool,
}
 675
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Identifier looked up by `find_org_by_slug`.
    pub slug: String,
}
 683
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owning org's id when `owner_is_user` is false (see `create_org_channel`);
    /// presumably a user id otherwise — TODO confirm.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
 692
id_type!(MessageId);
/// A message row as selected by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; the query converts it with `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied de-duplication key (unique per message in the table).
    pub nonce: Uuid,
}
 703
/// A user's contact list as returned by `Db::get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Contacts {
    /// Accepted contacts; `PostgresDb` also includes the user themself.
    pub current: Vec<UserId>,
    /// Pending requests other users have sent to this user.
    pub incoming_requests: Vec<IncomingContactRequest>,
    /// Users this user has sent a still-pending request to.
    pub outgoing_requests: Vec<UserId>,
}
 710
/// A pending contact request received by a user.
///
/// The derived ordering compares `requester_id` first (field order matters).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should still be notified; cleared by
    /// `dismiss_contact_request`.
    pub should_notify: bool,
}
 716
 717fn fuzzy_like_string(string: &str) -> String {
 718    let mut result = String::with_capacity(string.len() * 2 + 1);
 719    for c in string.chars() {
 720        if c.is_alphanumeric() {
 721            result.push('%');
 722            result.push(c);
 723        }
 724    }
 725    result.push('%');
 726    result
 727}
 728
 729#[cfg(test)]
 730pub mod tests {
 731    use super::*;
 732    use anyhow::anyhow;
 733    use collections::BTreeMap;
 734    use gpui::executor::Background;
 735    use lazy_static::lazy_static;
 736    use parking_lot::Mutex;
 737    use rand::prelude::*;
 738    use sqlx::{
 739        migrate::{MigrateDatabase, Migrator},
 740        Postgres,
 741    };
 742    use std::{path::Path, sync::Arc};
 743    use util::post_inc;
 744
 745    #[tokio::test(flavor = "multi_thread")]
 746    async fn test_get_users_by_ids() {
 747        for test_db in [
 748            TestDb::postgres().await,
 749            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 750        ] {
 751            let db = test_db.db();
 752
 753            let user = db.create_user("user", false).await.unwrap();
 754            let friend1 = db.create_user("friend-1", false).await.unwrap();
 755            let friend2 = db.create_user("friend-2", false).await.unwrap();
 756            let friend3 = db.create_user("friend-3", false).await.unwrap();
 757
 758            assert_eq!(
 759                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 760                    .await
 761                    .unwrap(),
 762                vec![
 763                    User {
 764                        id: user,
 765                        github_login: "user".to_string(),
 766                        admin: false,
 767                    },
 768                    User {
 769                        id: friend1,
 770                        github_login: "friend-1".to_string(),
 771                        admin: false,
 772                    },
 773                    User {
 774                        id: friend2,
 775                        github_login: "friend-2".to_string(),
 776                        admin: false,
 777                    },
 778                    User {
 779                        id: friend3,
 780                        github_login: "friend-3".to_string(),
 781                        admin: false,
 782                    }
 783                ]
 784            );
 785        }
 786    }
 787
 788    #[tokio::test(flavor = "multi_thread")]
 789    async fn test_recent_channel_messages() {
 790        for test_db in [
 791            TestDb::postgres().await,
 792            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 793        ] {
 794            let db = test_db.db();
 795            let user = db.create_user("user", false).await.unwrap();
 796            let org = db.create_org("org", "org").await.unwrap();
 797            let channel = db.create_org_channel(org, "channel").await.unwrap();
 798            for i in 0..10 {
 799                db.create_channel_message(
 800                    channel,
 801                    user,
 802                    &i.to_string(),
 803                    OffsetDateTime::now_utc(),
 804                    i,
 805                )
 806                .await
 807                .unwrap();
 808            }
 809
 810            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 811            assert_eq!(
 812                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 813                ["5", "6", "7", "8", "9"]
 814            );
 815
 816            let prev_messages = db
 817                .get_channel_messages(channel, 4, Some(messages[0].id))
 818                .await
 819                .unwrap();
 820            assert_eq!(
 821                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 822                ["1", "2", "3", "4"]
 823            );
 824        }
 825    }
 826
 827    #[tokio::test(flavor = "multi_thread")]
 828    async fn test_channel_message_nonces() {
 829        for test_db in [
 830            TestDb::postgres().await,
 831            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 832        ] {
 833            let db = test_db.db();
 834            let user = db.create_user("user", false).await.unwrap();
 835            let org = db.create_org("org", "org").await.unwrap();
 836            let channel = db.create_org_channel(org, "channel").await.unwrap();
 837
 838            let msg1_id = db
 839                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 840                .await
 841                .unwrap();
 842            let msg2_id = db
 843                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 844                .await
 845                .unwrap();
 846            let msg3_id = db
 847                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
 848                .await
 849                .unwrap();
 850            let msg4_id = db
 851                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
 852                .await
 853                .unwrap();
 854
 855            assert_ne!(msg1_id, msg2_id);
 856            assert_eq!(msg1_id, msg3_id);
 857            assert_eq!(msg2_id, msg4_id);
 858        }
 859    }
 860
 861    #[tokio::test(flavor = "multi_thread")]
 862    async fn test_create_access_tokens() {
 863        let test_db = TestDb::postgres().await;
 864        let db = test_db.db();
 865        let user = db.create_user("the-user", false).await.unwrap();
 866
 867        db.create_access_token_hash(user, "h1", 3).await.unwrap();
 868        db.create_access_token_hash(user, "h2", 3).await.unwrap();
 869        assert_eq!(
 870            db.get_access_token_hashes(user).await.unwrap(),
 871            &["h2".to_string(), "h1".to_string()]
 872        );
 873
 874        db.create_access_token_hash(user, "h3", 3).await.unwrap();
 875        assert_eq!(
 876            db.get_access_token_hashes(user).await.unwrap(),
 877            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
 878        );
 879
 880        db.create_access_token_hash(user, "h4", 3).await.unwrap();
 881        assert_eq!(
 882            db.get_access_token_hashes(user).await.unwrap(),
 883            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
 884        );
 885
 886        db.create_access_token_hash(user, "h5", 3).await.unwrap();
 887        assert_eq!(
 888            db.get_access_token_hashes(user).await.unwrap(),
 889            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
 890        );
 891    }
 892
 893    #[test]
 894    fn test_fuzzy_like_string() {
 895        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
 896        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
 897        assert_eq!(fuzzy_like_string(" z  "), "%z%");
 898    }
 899
 900    #[tokio::test(flavor = "multi_thread")]
 901    async fn test_fuzzy_search_users() {
 902        let test_db = TestDb::postgres().await;
 903        let db = test_db.db();
 904        for github_login in [
 905            "California",
 906            "colorado",
 907            "oregon",
 908            "washington",
 909            "florida",
 910            "delaware",
 911            "rhode-island",
 912        ] {
 913            db.create_user(github_login, false).await.unwrap();
 914        }
 915
 916        assert_eq!(
 917            fuzzy_search_user_names(db, "clr").await,
 918            &["colorado", "California"]
 919        );
 920        assert_eq!(
 921            fuzzy_search_user_names(db, "ro").await,
 922            &["rhode-island", "colorado", "oregon"],
 923        );
 924
 925        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
 926            db.fuzzy_search_users(query, 10)
 927                .await
 928                .unwrap()
 929                .into_iter()
 930                .map(|user| user.github_login)
 931                .collect::<Vec<_>>()
 932        }
 933    }
 934
 935    #[tokio::test(flavor = "multi_thread")]
 936    async fn test_add_contacts() {
 937        for test_db in [
 938            TestDb::postgres().await,
 939            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 940        ] {
 941            let db = test_db.db();
 942
 943            let user_1 = db.create_user("user1", false).await.unwrap();
 944            let user_2 = db.create_user("user2", false).await.unwrap();
 945            let user_3 = db.create_user("user3", false).await.unwrap();
 946
 947            // User starts with no contacts
 948            assert_eq!(
 949                db.get_contacts(user_1).await.unwrap(),
 950                Contacts {
 951                    current: vec![user_1],
 952                    outgoing_requests: vec![],
 953                    incoming_requests: vec![],
 954                },
 955            );
 956
 957            // User requests a contact. Both users see the pending request.
 958            db.send_contact_request(user_1, user_2).await.unwrap();
 959            assert_eq!(
 960                db.get_contacts(user_1).await.unwrap(),
 961                Contacts {
 962                    current: vec![user_1],
 963                    outgoing_requests: vec![user_2],
 964                    incoming_requests: vec![],
 965                },
 966            );
 967            assert_eq!(
 968                db.get_contacts(user_2).await.unwrap(),
 969                Contacts {
 970                    current: vec![user_2],
 971                    outgoing_requests: vec![],
 972                    incoming_requests: vec![IncomingContactRequest {
 973                        requester_id: user_1,
 974                        should_notify: true
 975                    }],
 976                },
 977            );
 978
 979            // User 2 dismisses the contact request notification without accepting or rejecting.
 980            // We shouldn't notify them again.
 981            db.dismiss_contact_request(user_1, user_2)
 982                .await
 983                .unwrap_err();
 984            db.dismiss_contact_request(user_2, user_1).await.unwrap();
 985            assert_eq!(
 986                db.get_contacts(user_2).await.unwrap(),
 987                Contacts {
 988                    current: vec![user_2],
 989                    outgoing_requests: vec![],
 990                    incoming_requests: vec![IncomingContactRequest {
 991                        requester_id: user_1,
 992                        should_notify: false
 993                    }],
 994                },
 995            );
 996
 997            // User can't accept their own contact request
 998            db.respond_to_contact_request(user_1, user_2, true)
 999                .await
1000                .unwrap_err();
1001
1002            // User accepts a contact request. Both users see the contact.
1003            db.respond_to_contact_request(user_2, user_1, true)
1004                .await
1005                .unwrap();
1006            assert_eq!(
1007                db.get_contacts(user_1).await.unwrap(),
1008                Contacts {
1009                    current: vec![user_1, user_2],
1010                    outgoing_requests: vec![],
1011                    incoming_requests: vec![],
1012                },
1013            );
1014            assert_eq!(
1015                db.get_contacts(user_2).await.unwrap(),
1016                Contacts {
1017                    current: vec![user_1, user_2],
1018                    outgoing_requests: vec![],
1019                    incoming_requests: vec![],
1020                },
1021            );
1022
1023            // Users cannot re-request existing contacts.
1024            db.send_contact_request(user_1, user_2).await.unwrap_err();
1025            db.send_contact_request(user_2, user_1).await.unwrap_err();
1026
1027            // Users send each other concurrent contact requests and
1028            // see that they are immediately accepted.
1029            db.send_contact_request(user_1, user_3).await.unwrap();
1030            db.send_contact_request(user_3, user_1).await.unwrap();
1031            assert_eq!(
1032                db.get_contacts(user_1).await.unwrap(),
1033                Contacts {
1034                    current: vec![user_1, user_2, user_3],
1035                    outgoing_requests: vec![],
1036                    incoming_requests: vec![],
1037                },
1038            );
1039            assert_eq!(
1040                db.get_contacts(user_3).await.unwrap(),
1041                Contacts {
1042                    current: vec![user_1, user_3],
1043                    outgoing_requests: vec![],
1044                    incoming_requests: vec![],
1045                },
1046            );
1047
1048            // User declines a contact request. Both users see that it is gone.
1049            db.send_contact_request(user_2, user_3).await.unwrap();
1050            db.respond_to_contact_request(user_3, user_2, false)
1051                .await
1052                .unwrap();
1053            assert_eq!(
1054                db.get_contacts(user_2).await.unwrap(),
1055                Contacts {
1056                    current: vec![user_1, user_2],
1057                    outgoing_requests: vec![],
1058                    incoming_requests: vec![],
1059                },
1060            );
1061            assert_eq!(
1062                db.get_contacts(user_3).await.unwrap(),
1063                Contacts {
1064                    current: vec![user_1, user_3],
1065                    outgoing_requests: vec![],
1066                    incoming_requests: vec![],
1067                },
1068            );
1069        }
1070    }
1071
    /// A database handle owned by a single test: either a throwaway Postgres database
    /// (`TestDb::postgres`) or an in-memory `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// Set to `Some` by both constructors; taken by the `Drop` impl to run teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
1076
1077    impl TestDb {
1078        pub async fn postgres() -> Self {
1079            lazy_static! {
1080                static ref LOCK: Mutex<()> = Mutex::new(());
1081            }
1082
1083            let _guard = LOCK.lock();
1084            let mut rng = StdRng::from_entropy();
1085            let name = format!("zed-test-{}", rng.gen::<u128>());
1086            let url = format!("postgres://postgres@localhost/{}", name);
1087            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1088            Postgres::create_database(&url)
1089                .await
1090                .expect("failed to create test db");
1091            let db = PostgresDb::new(&url, 5).await.unwrap();
1092            let migrator = Migrator::new(migrations_path).await.unwrap();
1093            migrator.run(&db.pool).await.unwrap();
1094            Self {
1095                db: Some(Arc::new(db)),
1096                url,
1097            }
1098        }
1099
1100        pub fn fake(background: Arc<Background>) -> Self {
1101            Self {
1102                db: Some(Arc::new(FakeDb::new(background))),
1103                url: Default::default(),
1104            }
1105        }
1106
1107        pub fn db(&self) -> &Arc<dyn Db> {
1108            self.db.as_ref().unwrap()
1109        }
1110    }
1111
1112    impl Drop for TestDb {
1113        fn drop(&mut self) {
1114            if let Some(db) = self.db.take() {
1115                futures::executor::block_on(db.teardown(&self.url));
1116            }
1117        }
1118    }
1119
    /// In-memory implementation of [`Db`] that tests can use instead of Postgres.
    ///
    /// Each table is a separate `Mutex`-guarded collection. Most methods first await
    /// `background.simulate_random_delay()` before touching them.
    pub struct FakeDb {
        /// Executor whose `simulate_random_delay` the `Db` methods await.
        background: Arc<Background>,
        /// All users, keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// All orgs, keyed by id.
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Org memberships; the value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        /// All channels, keyed by id.
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Channel memberships; the value is the member's admin flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        /// All channel messages, keyed by id (so map iteration is in id order).
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact records: both pending requests and accepted contacts.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each kind of row; `FakeDb::new` starts them all at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1134
    /// One contact record in `FakeDb::contacts`: a request from `requester_id` to
    /// `responder_id`. Once `accepted` is set, `get_contacts` lists the two users as
    /// contacts of each other.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        /// Set when the responder accepts, or when both users request each other.
        pub accepted: bool,
        /// Whether the responder should still be notified of this pending request.
        /// It starts `true` and is cleared by `dismiss_contact_request`.
        pub should_notify: bool,
    }
1142
1143    impl FakeDb {
1144        pub fn new(background: Arc<Background>) -> Self {
1145            Self {
1146                background,
1147                users: Default::default(),
1148                next_user_id: Mutex::new(1),
1149                orgs: Default::default(),
1150                next_org_id: Mutex::new(1),
1151                org_memberships: Default::default(),
1152                channels: Default::default(),
1153                next_channel_id: Mutex::new(1),
1154                channel_memberships: Default::default(),
1155                channel_messages: Default::default(),
1156                next_channel_message_id: Mutex::new(1),
1157                contacts: Default::default(),
1158            }
1159        }
1160    }
1161
1162    #[async_trait]
1163    impl Db for FakeDb {
1164        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1165            self.background.simulate_random_delay().await;
1166
1167            let mut users = self.users.lock();
1168            if let Some(user) = users
1169                .values()
1170                .find(|user| user.github_login == github_login)
1171            {
1172                Ok(user.id)
1173            } else {
1174                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1175                users.insert(
1176                    user_id,
1177                    User {
1178                        id: user_id,
1179                        github_login: github_login.to_string(),
1180                        admin,
1181                    },
1182                );
1183                Ok(user_id)
1184            }
1185        }
1186
1187        async fn get_all_users(&self) -> Result<Vec<User>> {
1188            unimplemented!()
1189        }
1190
1191        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1192            unimplemented!()
1193        }
1194
1195        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1196            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1197        }
1198
1199        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1200            self.background.simulate_random_delay().await;
1201            let users = self.users.lock();
1202            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1203        }
1204
1205        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1206            Ok(self
1207                .users
1208                .lock()
1209                .values()
1210                .find(|user| user.github_login == github_login)
1211                .cloned())
1212        }
1213
1214        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1215            unimplemented!()
1216        }
1217
1218        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1219            unimplemented!()
1220        }
1221
1222        async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
1223            self.background.simulate_random_delay().await;
1224            let mut current = vec![id];
1225            let mut outgoing_requests = Vec::new();
1226            let mut incoming_requests = Vec::new();
1227
1228            for contact in self.contacts.lock().iter() {
1229                if contact.requester_id == id {
1230                    if contact.accepted {
1231                        current.push(contact.responder_id);
1232                    } else {
1233                        outgoing_requests.push(contact.responder_id);
1234                    }
1235                } else if contact.responder_id == id {
1236                    if contact.accepted {
1237                        current.push(contact.requester_id);
1238                    } else {
1239                        incoming_requests.push(IncomingContactRequest {
1240                            requester_id: contact.requester_id,
1241                            should_notify: contact.should_notify,
1242                        });
1243                    }
1244                }
1245            }
1246
1247            current.sort_unstable();
1248            outgoing_requests.sort_unstable();
1249            incoming_requests.sort_unstable();
1250
1251            Ok(Contacts {
1252                current,
1253                outgoing_requests,
1254                incoming_requests,
1255            })
1256        }
1257
1258        async fn send_contact_request(
1259            &self,
1260            requester_id: UserId,
1261            responder_id: UserId,
1262        ) -> Result<()> {
1263            let mut contacts = self.contacts.lock();
1264            for contact in contacts.iter_mut() {
1265                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1266                    if contact.accepted {
1267                        Err(anyhow!("contact already exists"))?;
1268                    } else {
1269                        Err(anyhow!("contact already requested"))?;
1270                    }
1271                }
1272                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1273                    if contact.accepted {
1274                        Err(anyhow!("contact already exists"))?;
1275                    } else {
1276                        contact.accepted = true;
1277                        return Ok(());
1278                    }
1279                }
1280            }
1281            contacts.push(FakeContact {
1282                requester_id,
1283                responder_id,
1284                accepted: false,
1285                should_notify: true,
1286            });
1287            Ok(())
1288        }
1289
1290        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1291            self.contacts.lock().retain(|contact| {
1292                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1293            });
1294            Ok(())
1295        }
1296
1297        async fn dismiss_contact_request(
1298            &self,
1299            responder_id: UserId,
1300            requester_id: UserId,
1301        ) -> Result<()> {
1302            let mut contacts = self.contacts.lock();
1303            for contact in contacts.iter_mut() {
1304                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1305                    if contact.accepted {
1306                        return Err(anyhow!("contact already confirmed"));
1307                    }
1308                    contact.should_notify = false;
1309                    return Ok(());
1310                }
1311            }
1312            Err(anyhow!("no such contact request"))
1313        }
1314
1315        async fn respond_to_contact_request(
1316            &self,
1317            responder_id: UserId,
1318            requester_id: UserId,
1319            accept: bool,
1320        ) -> Result<()> {
1321            let mut contacts = self.contacts.lock();
1322            for (ix, contact) in contacts.iter_mut().enumerate() {
1323                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1324                    if contact.accepted {
1325                        return Err(anyhow!("contact already confirmed"));
1326                    }
1327                    if accept {
1328                        contact.accepted = true;
1329                    } else {
1330                        contacts.remove(ix);
1331                    }
1332                    return Ok(());
1333                }
1334            }
1335            Err(anyhow!("no such contact request"))
1336        }
1337
1338        async fn create_access_token_hash(
1339            &self,
1340            _user_id: UserId,
1341            _access_token_hash: &str,
1342            _max_access_token_count: usize,
1343        ) -> Result<()> {
1344            unimplemented!()
1345        }
1346
1347        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1348            unimplemented!()
1349        }
1350
1351        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1352            unimplemented!()
1353        }
1354
1355        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1356            self.background.simulate_random_delay().await;
1357            let mut orgs = self.orgs.lock();
1358            if orgs.values().any(|org| org.slug == slug) {
1359                Err(anyhow!("org already exists"))
1360            } else {
1361                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1362                orgs.insert(
1363                    org_id,
1364                    Org {
1365                        id: org_id,
1366                        name: name.to_string(),
1367                        slug: slug.to_string(),
1368                    },
1369                );
1370                Ok(org_id)
1371            }
1372        }
1373
1374        async fn add_org_member(
1375            &self,
1376            org_id: OrgId,
1377            user_id: UserId,
1378            is_admin: bool,
1379        ) -> Result<()> {
1380            self.background.simulate_random_delay().await;
1381            if !self.orgs.lock().contains_key(&org_id) {
1382                return Err(anyhow!("org does not exist"));
1383            }
1384            if !self.users.lock().contains_key(&user_id) {
1385                return Err(anyhow!("user does not exist"));
1386            }
1387
1388            self.org_memberships
1389                .lock()
1390                .entry((org_id, user_id))
1391                .or_insert(is_admin);
1392            Ok(())
1393        }
1394
1395        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1396            self.background.simulate_random_delay().await;
1397            if !self.orgs.lock().contains_key(&org_id) {
1398                return Err(anyhow!("org does not exist"));
1399            }
1400
1401            let mut channels = self.channels.lock();
1402            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1403            channels.insert(
1404                channel_id,
1405                Channel {
1406                    id: channel_id,
1407                    name: name.to_string(),
1408                    owner_id: org_id.0,
1409                    owner_is_user: false,
1410                },
1411            );
1412            Ok(channel_id)
1413        }
1414
1415        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1416            self.background.simulate_random_delay().await;
1417            Ok(self
1418                .channels
1419                .lock()
1420                .values()
1421                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1422                .cloned()
1423                .collect())
1424        }
1425
1426        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1427            self.background.simulate_random_delay().await;
1428            let channels = self.channels.lock();
1429            let memberships = self.channel_memberships.lock();
1430            Ok(channels
1431                .values()
1432                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1433                .cloned()
1434                .collect())
1435        }
1436
1437        async fn can_user_access_channel(
1438            &self,
1439            user_id: UserId,
1440            channel_id: ChannelId,
1441        ) -> Result<bool> {
1442            self.background.simulate_random_delay().await;
1443            Ok(self
1444                .channel_memberships
1445                .lock()
1446                .contains_key(&(channel_id, user_id)))
1447        }
1448
1449        async fn add_channel_member(
1450            &self,
1451            channel_id: ChannelId,
1452            user_id: UserId,
1453            is_admin: bool,
1454        ) -> Result<()> {
1455            self.background.simulate_random_delay().await;
1456            if !self.channels.lock().contains_key(&channel_id) {
1457                return Err(anyhow!("channel does not exist"));
1458            }
1459            if !self.users.lock().contains_key(&user_id) {
1460                return Err(anyhow!("user does not exist"));
1461            }
1462
1463            self.channel_memberships
1464                .lock()
1465                .entry((channel_id, user_id))
1466                .or_insert(is_admin);
1467            Ok(())
1468        }
1469
1470        async fn create_channel_message(
1471            &self,
1472            channel_id: ChannelId,
1473            sender_id: UserId,
1474            body: &str,
1475            timestamp: OffsetDateTime,
1476            nonce: u128,
1477        ) -> Result<MessageId> {
1478            self.background.simulate_random_delay().await;
1479            if !self.channels.lock().contains_key(&channel_id) {
1480                return Err(anyhow!("channel does not exist"));
1481            }
1482            if !self.users.lock().contains_key(&sender_id) {
1483                return Err(anyhow!("user does not exist"));
1484            }
1485
1486            let mut messages = self.channel_messages.lock();
1487            if let Some(message) = messages
1488                .values()
1489                .find(|message| message.nonce.as_u128() == nonce)
1490            {
1491                Ok(message.id)
1492            } else {
1493                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1494                messages.insert(
1495                    message_id,
1496                    ChannelMessage {
1497                        id: message_id,
1498                        channel_id,
1499                        sender_id,
1500                        body: body.to_string(),
1501                        sent_at: timestamp,
1502                        nonce: Uuid::from_u128(nonce),
1503                    },
1504                );
1505                Ok(message_id)
1506            }
1507        }
1508
1509        async fn get_channel_messages(
1510            &self,
1511            channel_id: ChannelId,
1512            count: usize,
1513            before_id: Option<MessageId>,
1514        ) -> Result<Vec<ChannelMessage>> {
1515            let mut messages = self
1516                .channel_messages
1517                .lock()
1518                .values()
1519                .rev()
1520                .filter(|message| {
1521                    message.channel_id == channel_id
1522                        && message.id < before_id.unwrap_or(MessageId::MAX)
1523                })
1524                .take(count)
1525                .cloned()
1526                .collect::<Vec<_>>();
1527            messages.sort_unstable_by_key(|message| message.id);
1528            Ok(messages)
1529        }
1530
1531        async fn teardown(&self, _: &str) {}
1532
1533        #[cfg(test)]
1534        fn as_fake(&self) -> Option<&FakeDb> {
1535            Some(self)
1536        }
1537    }
1538}