db.rs

   1use anyhow::{anyhow, Context, Result};
   2use async_trait::async_trait;
   3use futures::StreamExt;
   4use nanoid::nanoid;
   5use serde::Serialize;
   6pub use sqlx::postgres::PgPoolOptions as DbOptions;
   7use sqlx::{types::Uuid, FromRow};
   8use time::OffsetDateTime;
   9
/// Persistence operations for users, invite codes, contacts, access tokens, orgs,
/// channels, and channel messages.
///
/// `PostgresDb` implements this trait against Postgres. Under `cfg(test)` there is also a
/// `tests::FakeDb` implementation, which `as_fake` can downcast to. The `Send + Sync`
/// bound lets a `dyn Db` be shared between threads/tasks.
///
/// Several methods are only compiled for tests or the `seed-support` feature; the
/// `#[cfg(...)]` attribute directly above each such method controls that.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user (or returns the existing user's id if the login is taken).
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns all users.
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Returns at most `limit` users whose GitHub login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, or `None` if there is none.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids are in `ids`; unknown ids are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, or `None` if there is none.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the admin flag on a user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes a user along with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user has left, giving them an invite code if needed.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, if they have a code.
    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Consumes one invite from `code` and creates a user for `login`, returning its id.
    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;

    /// Returns every contact entry (accepted, outgoing, incoming) of the given user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Records a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship (pending or accepted) between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification that `responder_id` has about `requester_id`.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines a contact request sent by `requester_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access-token hash, keeping at most `max_access_token_count` per user.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's stored access-token hashes.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by an org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id; `nonce` identifies re-sends.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the most recent messages before `before_id`, oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: releases the backing store for the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: downcasts to the fake implementation, if this is one.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
  89
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // Shared sqlx pool; every query borrows a connection from it.
    pool: sqlx::PgPool,
}
  93
  94impl PostgresDb {
  95    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  96        let pool = DbOptions::new()
  97            .max_connections(max_connections)
  98            .connect(&url)
  99            .await
 100            .context("failed to connect to postgres database")?;
 101        Ok(Self { pool })
 102    }
 103}
 104
 105#[async_trait]
 106impl Db for PostgresDb {
 107    // users
 108
 109    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 110        let query = "
 111            INSERT INTO users (github_login, admin)
 112            VALUES ($1, $2)
 113            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 114            RETURNING id
 115        ";
 116        Ok(sqlx::query_scalar(query)
 117            .bind(github_login)
 118            .bind(admin)
 119            .fetch_one(&self.pool)
 120            .await
 121            .map(UserId)?)
 122    }
 123
 124    async fn get_all_users(&self) -> Result<Vec<User>> {
 125        let query = "SELECT * FROM users ORDER BY github_login ASC";
 126        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 127    }
 128
 129    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 130        let like_string = fuzzy_like_string(name_query);
 131        let query = "
 132            SELECT users.*
 133            FROM users
 134            WHERE github_login ILIKE $1
 135            ORDER BY github_login <-> $2
 136            LIMIT $3
 137        ";
 138        Ok(sqlx::query_as(query)
 139            .bind(like_string)
 140            .bind(name_query)
 141            .bind(limit)
 142            .fetch_all(&self.pool)
 143            .await?)
 144    }
 145
 146    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 147        let users = self.get_users_by_ids(vec![id]).await?;
 148        Ok(users.into_iter().next())
 149    }
 150
 151    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 152        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 153        let query = "
 154            SELECT users.*
 155            FROM users
 156            WHERE users.id = ANY ($1)
 157        ";
 158        Ok(sqlx::query_as(query)
 159            .bind(&ids)
 160            .fetch_all(&self.pool)
 161            .await?)
 162    }
 163
 164    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 165        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 166        Ok(sqlx::query_as(query)
 167            .bind(github_login)
 168            .fetch_optional(&self.pool)
 169            .await?)
 170    }
 171
 172    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 173        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 174        Ok(sqlx::query(query)
 175            .bind(is_admin)
 176            .bind(id.0)
 177            .execute(&self.pool)
 178            .await
 179            .map(drop)?)
 180    }
 181
 182    async fn destroy_user(&self, id: UserId) -> Result<()> {
 183        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 184        sqlx::query(query)
 185            .bind(id.0)
 186            .execute(&self.pool)
 187            .await
 188            .map(drop)?;
 189        let query = "DELETE FROM users WHERE id = $1;";
 190        Ok(sqlx::query(query)
 191            .bind(id.0)
 192            .execute(&self.pool)
 193            .await
 194            .map(drop)?)
 195    }
 196
 197    // invite codes
 198
 199    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 200        let mut tx = self.pool.begin().await?;
 201        sqlx::query(
 202            "
 203                UPDATE users
 204                SET invite_code = $1
 205                WHERE id = $2 AND invite_code IS NULL
 206            ",
 207        )
 208        .bind(nanoid!(16))
 209        .bind(id)
 210        .execute(&mut tx)
 211        .await?;
 212        sqlx::query(
 213            "
 214            UPDATE users
 215            SET invite_count = $1
 216            WHERE id = $2
 217            ",
 218        )
 219        .bind(count)
 220        .bind(id)
 221        .execute(&mut tx)
 222        .await?;
 223        tx.commit().await?;
 224        Ok(())
 225    }
 226
 227    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
 228        let result: Option<(String, i32)> = sqlx::query_as(
 229            "
 230                SELECT invite_code, invite_count
 231                FROM users
 232                WHERE id = $1 AND invite_code IS NOT NULL 
 233            ",
 234        )
 235        .bind(id)
 236        .fetch_optional(&self.pool)
 237        .await?;
 238        if let Some((code, count)) = result {
 239            Ok(Some((code, count.try_into()?)))
 240        } else {
 241            Ok(None)
 242        }
 243    }
 244
 245    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
 246        let mut tx = self.pool.begin().await?;
 247
 248        let inviter_id: UserId = sqlx::query_scalar(
 249            "
 250                UPDATE users
 251                SET invite_count = invite_count - 1
 252                WHERE
 253                    invite_code = $1 AND
 254                    invite_count > 0
 255                RETURNING id
 256            ",
 257        )
 258        .bind(code)
 259        .fetch_optional(&mut tx)
 260        .await?
 261        .ok_or_else(|| anyhow!("invite code not found"))?;
 262        let invitee_id = sqlx::query_scalar(
 263            "
 264                INSERT INTO users
 265                    (github_login, admin, inviter_id)
 266                VALUES
 267                    ($1, 'f', $2)
 268                RETURNING id
 269            ",
 270        )
 271        .bind(login)
 272        .bind(inviter_id)
 273        .fetch_one(&mut tx)
 274        .await
 275        .map(UserId)?;
 276
 277        sqlx::query(
 278            "
 279                INSERT INTO contacts
 280                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 281                VALUES
 282                    ($1, $2, 't', 't', 't')
 283            ",
 284        )
 285        .bind(inviter_id)
 286        .bind(invitee_id)
 287        .execute(&mut tx)
 288        .await?;
 289
 290        tx.commit().await?;
 291        Ok(invitee_id)
 292    }
 293
 294    // contacts
 295
 296    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 297        let query = "
 298            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 299            FROM contacts
 300            WHERE user_id_a = $1 OR user_id_b = $1;
 301        ";
 302
 303        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 304            .bind(user_id)
 305            .fetch(&self.pool);
 306
 307        let mut contacts = vec![Contact::Accepted {
 308            user_id,
 309            should_notify: false,
 310        }];
 311        while let Some(row) = rows.next().await {
 312            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 313
 314            if user_id_a == user_id {
 315                if accepted {
 316                    contacts.push(Contact::Accepted {
 317                        user_id: user_id_b,
 318                        should_notify: should_notify && a_to_b,
 319                    });
 320                } else if a_to_b {
 321                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 322                } else {
 323                    contacts.push(Contact::Incoming {
 324                        user_id: user_id_b,
 325                        should_notify,
 326                    });
 327                }
 328            } else {
 329                if accepted {
 330                    contacts.push(Contact::Accepted {
 331                        user_id: user_id_a,
 332                        should_notify: should_notify && !a_to_b,
 333                    });
 334                } else if a_to_b {
 335                    contacts.push(Contact::Incoming {
 336                        user_id: user_id_a,
 337                        should_notify,
 338                    });
 339                } else {
 340                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 341                }
 342            }
 343        }
 344
 345        contacts.sort_unstable_by_key(|contact| contact.user_id());
 346
 347        Ok(contacts)
 348    }
 349
 350    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 351        let (id_a, id_b) = if user_id_1 < user_id_2 {
 352            (user_id_1, user_id_2)
 353        } else {
 354            (user_id_2, user_id_1)
 355        };
 356
 357        let query = "
 358            SELECT 1 FROM contacts
 359            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 360            LIMIT 1
 361        ";
 362        Ok(sqlx::query_scalar::<_, i32>(query)
 363            .bind(id_a.0)
 364            .bind(id_b.0)
 365            .fetch_optional(&self.pool)
 366            .await?
 367            .is_some())
 368    }
 369
 370    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 371        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 372            (sender_id, receiver_id, true)
 373        } else {
 374            (receiver_id, sender_id, false)
 375        };
 376        let query = "
 377            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 378            VALUES ($1, $2, $3, 'f', 't')
 379            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 380            SET
 381                accepted = 't',
 382                should_notify = 'f'
 383            WHERE
 384                NOT contacts.accepted AND
 385                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 386                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 387        ";
 388        let result = sqlx::query(query)
 389            .bind(id_a.0)
 390            .bind(id_b.0)
 391            .bind(a_to_b)
 392            .execute(&self.pool)
 393            .await?;
 394
 395        if result.rows_affected() == 1 {
 396            Ok(())
 397        } else {
 398            Err(anyhow!("contact already requested"))
 399        }
 400    }
 401
 402    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 403        let (id_a, id_b) = if responder_id < requester_id {
 404            (responder_id, requester_id)
 405        } else {
 406            (requester_id, responder_id)
 407        };
 408        let query = "
 409            DELETE FROM contacts
 410            WHERE user_id_a = $1 AND user_id_b = $2;
 411        ";
 412        let result = sqlx::query(query)
 413            .bind(id_a.0)
 414            .bind(id_b.0)
 415            .execute(&self.pool)
 416            .await?;
 417
 418        if result.rows_affected() == 1 {
 419            Ok(())
 420        } else {
 421            Err(anyhow!("no such contact"))
 422        }
 423    }
 424
 425    async fn dismiss_contact_notification(
 426        &self,
 427        user_id: UserId,
 428        contact_user_id: UserId,
 429    ) -> Result<()> {
 430        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 431            (user_id, contact_user_id, true)
 432        } else {
 433            (contact_user_id, user_id, false)
 434        };
 435
 436        let query = "
 437            UPDATE contacts
 438            SET should_notify = 'f'
 439            WHERE
 440                user_id_a = $1 AND user_id_b = $2 AND
 441                (
 442                    (a_to_b = $3 AND accepted) OR
 443                    (a_to_b != $3 AND NOT accepted)
 444                );
 445        ";
 446
 447        let result = sqlx::query(query)
 448            .bind(id_a.0)
 449            .bind(id_b.0)
 450            .bind(a_to_b)
 451            .execute(&self.pool)
 452            .await?;
 453
 454        if result.rows_affected() == 0 {
 455            Err(anyhow!("no such contact request"))?;
 456        }
 457
 458        Ok(())
 459    }
 460
 461    async fn respond_to_contact_request(
 462        &self,
 463        responder_id: UserId,
 464        requester_id: UserId,
 465        accept: bool,
 466    ) -> Result<()> {
 467        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 468            (responder_id, requester_id, false)
 469        } else {
 470            (requester_id, responder_id, true)
 471        };
 472        let result = if accept {
 473            let query = "
 474                UPDATE contacts
 475                SET accepted = 't', should_notify = 't'
 476                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 477            ";
 478            sqlx::query(query)
 479                .bind(id_a.0)
 480                .bind(id_b.0)
 481                .bind(a_to_b)
 482                .execute(&self.pool)
 483                .await?
 484        } else {
 485            let query = "
 486                DELETE FROM contacts
 487                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 488            ";
 489            sqlx::query(query)
 490                .bind(id_a.0)
 491                .bind(id_b.0)
 492                .bind(a_to_b)
 493                .execute(&self.pool)
 494                .await?
 495        };
 496        if result.rows_affected() == 1 {
 497            Ok(())
 498        } else {
 499            Err(anyhow!("no such contact request"))
 500        }
 501    }
 502
 503    // access tokens
 504
 505    async fn create_access_token_hash(
 506        &self,
 507        user_id: UserId,
 508        access_token_hash: &str,
 509        max_access_token_count: usize,
 510    ) -> Result<()> {
 511        let insert_query = "
 512            INSERT INTO access_tokens (user_id, hash)
 513            VALUES ($1, $2);
 514        ";
 515        let cleanup_query = "
 516            DELETE FROM access_tokens
 517            WHERE id IN (
 518                SELECT id from access_tokens
 519                WHERE user_id = $1
 520                ORDER BY id DESC
 521                OFFSET $3
 522            )
 523        ";
 524
 525        let mut tx = self.pool.begin().await?;
 526        sqlx::query(insert_query)
 527            .bind(user_id.0)
 528            .bind(access_token_hash)
 529            .execute(&mut tx)
 530            .await?;
 531        sqlx::query(cleanup_query)
 532            .bind(user_id.0)
 533            .bind(access_token_hash)
 534            .bind(max_access_token_count as u32)
 535            .execute(&mut tx)
 536            .await?;
 537        Ok(tx.commit().await?)
 538    }
 539
 540    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 541        let query = "
 542            SELECT hash
 543            FROM access_tokens
 544            WHERE user_id = $1
 545            ORDER BY id DESC
 546        ";
 547        Ok(sqlx::query_scalar(query)
 548            .bind(user_id.0)
 549            .fetch_all(&self.pool)
 550            .await?)
 551    }
 552
 553    // orgs
 554
 555    #[allow(unused)] // Help rust-analyzer
 556    #[cfg(any(test, feature = "seed-support"))]
 557    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 558        let query = "
 559            SELECT *
 560            FROM orgs
 561            WHERE slug = $1
 562        ";
 563        Ok(sqlx::query_as(query)
 564            .bind(slug)
 565            .fetch_optional(&self.pool)
 566            .await?)
 567    }
 568
 569    #[cfg(any(test, feature = "seed-support"))]
 570    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 571        let query = "
 572            INSERT INTO orgs (name, slug)
 573            VALUES ($1, $2)
 574            RETURNING id
 575        ";
 576        Ok(sqlx::query_scalar(query)
 577            .bind(name)
 578            .bind(slug)
 579            .fetch_one(&self.pool)
 580            .await
 581            .map(OrgId)?)
 582    }
 583
 584    #[cfg(any(test, feature = "seed-support"))]
 585    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 586        let query = "
 587            INSERT INTO org_memberships (org_id, user_id, admin)
 588            VALUES ($1, $2, $3)
 589            ON CONFLICT DO NOTHING
 590        ";
 591        Ok(sqlx::query(query)
 592            .bind(org_id.0)
 593            .bind(user_id.0)
 594            .bind(is_admin)
 595            .execute(&self.pool)
 596            .await
 597            .map(drop)?)
 598    }
 599
 600    // channels
 601
 602    #[cfg(any(test, feature = "seed-support"))]
 603    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 604        let query = "
 605            INSERT INTO channels (owner_id, owner_is_user, name)
 606            VALUES ($1, false, $2)
 607            RETURNING id
 608        ";
 609        Ok(sqlx::query_scalar(query)
 610            .bind(org_id.0)
 611            .bind(name)
 612            .fetch_one(&self.pool)
 613            .await
 614            .map(ChannelId)?)
 615    }
 616
 617    #[allow(unused)] // Help rust-analyzer
 618    #[cfg(any(test, feature = "seed-support"))]
 619    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 620        let query = "
 621            SELECT *
 622            FROM channels
 623            WHERE
 624                channels.owner_is_user = false AND
 625                channels.owner_id = $1
 626        ";
 627        Ok(sqlx::query_as(query)
 628            .bind(org_id.0)
 629            .fetch_all(&self.pool)
 630            .await?)
 631    }
 632
 633    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 634        let query = "
 635            SELECT
 636                channels.*
 637            FROM
 638                channel_memberships, channels
 639            WHERE
 640                channel_memberships.user_id = $1 AND
 641                channel_memberships.channel_id = channels.id
 642        ";
 643        Ok(sqlx::query_as(query)
 644            .bind(user_id.0)
 645            .fetch_all(&self.pool)
 646            .await?)
 647    }
 648
 649    async fn can_user_access_channel(
 650        &self,
 651        user_id: UserId,
 652        channel_id: ChannelId,
 653    ) -> Result<bool> {
 654        let query = "
 655            SELECT id
 656            FROM channel_memberships
 657            WHERE user_id = $1 AND channel_id = $2
 658            LIMIT 1
 659        ";
 660        Ok(sqlx::query_scalar::<_, i32>(query)
 661            .bind(user_id.0)
 662            .bind(channel_id.0)
 663            .fetch_optional(&self.pool)
 664            .await
 665            .map(|e| e.is_some())?)
 666    }
 667
 668    #[cfg(any(test, feature = "seed-support"))]
 669    async fn add_channel_member(
 670        &self,
 671        channel_id: ChannelId,
 672        user_id: UserId,
 673        is_admin: bool,
 674    ) -> Result<()> {
 675        let query = "
 676            INSERT INTO channel_memberships (channel_id, user_id, admin)
 677            VALUES ($1, $2, $3)
 678            ON CONFLICT DO NOTHING
 679        ";
 680        Ok(sqlx::query(query)
 681            .bind(channel_id.0)
 682            .bind(user_id.0)
 683            .bind(is_admin)
 684            .execute(&self.pool)
 685            .await
 686            .map(drop)?)
 687    }
 688
 689    // messages
 690
 691    async fn create_channel_message(
 692        &self,
 693        channel_id: ChannelId,
 694        sender_id: UserId,
 695        body: &str,
 696        timestamp: OffsetDateTime,
 697        nonce: u128,
 698    ) -> Result<MessageId> {
 699        let query = "
 700            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 701            VALUES ($1, $2, $3, $4, $5)
 702            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 703            RETURNING id
 704        ";
 705        Ok(sqlx::query_scalar(query)
 706            .bind(channel_id.0)
 707            .bind(sender_id.0)
 708            .bind(body)
 709            .bind(timestamp)
 710            .bind(Uuid::from_u128(nonce))
 711            .fetch_one(&self.pool)
 712            .await
 713            .map(MessageId)?)
 714    }
 715
 716    async fn get_channel_messages(
 717        &self,
 718        channel_id: ChannelId,
 719        count: usize,
 720        before_id: Option<MessageId>,
 721    ) -> Result<Vec<ChannelMessage>> {
 722        let query = r#"
 723            SELECT * FROM (
 724                SELECT
 725                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 726                FROM
 727                    channel_messages
 728                WHERE
 729                    channel_id = $1 AND
 730                    id < $2
 731                ORDER BY id DESC
 732                LIMIT $3
 733            ) as recent_messages
 734            ORDER BY id ASC
 735        "#;
 736        Ok(sqlx::query_as(query)
 737            .bind(channel_id.0)
 738            .bind(before_id.unwrap_or(MessageId::MAX))
 739            .bind(count as i64)
 740            .fetch_all(&self.pool)
 741            .await?)
 742    }
 743
 744    #[cfg(test)]
 745    async fn teardown(&self, url: &str) {
 746        use util::ResultExt;
 747
 748        let query = "
 749            SELECT pg_terminate_backend(pg_stat_activity.pid)
 750            FROM pg_stat_activity
 751            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 752        ";
 753        sqlx::query(query).execute(&self.pool).await.log_err();
 754        self.pool.close().await;
 755        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 756            .await
 757            .log_err();
 758    }
 759
 760    #[cfg(test)]
 761    fn as_fake(&self) -> Option<&tests::FakeDb> {
 762        None
 763    }
 764}
 765
 766macro_rules! id_type {
 767    ($name:ident) => {
 768        #[derive(
 769            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 770        )]
 771        #[sqlx(transparent)]
 772        #[serde(transparent)]
 773        pub struct $name(pub i32);
 774
 775        impl $name {
 776            #[allow(unused)]
 777            pub const MAX: Self = Self(i32::MAX);
 778
 779            #[allow(unused)]
 780            pub fn from_proto(value: u64) -> Self {
 781                Self(value as i32)
 782            }
 783
 784            #[allow(unused)]
 785            pub fn to_proto(&self) -> u64 {
 786                self.0 as u64
 787            }
 788        }
 789
 790        impl std::fmt::Display for $name {
 791            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 792                self.0.fmt(f)
 793            }
 794        }
 795    };
 796}
 797
 798id_type!(UserId);
 799#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 800pub struct User {
 801    pub id: UserId,
 802    pub github_login: String,
 803    pub admin: bool,
 804    pub invite_code: Option<String>,
 805    pub invite_count: i32,
 806}
 807
 808id_type!(OrgId);
 809#[derive(FromRow)]
 810pub struct Org {
 811    pub id: OrgId,
 812    pub name: String,
 813    pub slug: String,
 814}
 815
 816id_type!(ChannelId);
 817#[derive(Clone, Debug, FromRow, Serialize)]
 818pub struct Channel {
 819    pub id: ChannelId,
 820    pub name: String,
 821    pub owner_id: i32,
 822    pub owner_is_user: bool,
 823}
 824
 825id_type!(MessageId);
 826#[derive(Clone, Debug, FromRow)]
 827pub struct ChannelMessage {
 828    pub id: MessageId,
 829    pub channel_id: ChannelId,
 830    pub sender_id: UserId,
 831    pub body: String,
 832    pub sent_at: OffsetDateTime,
 833    pub nonce: Uuid,
 834}
 835
 836#[derive(Clone, Debug, PartialEq, Eq)]
 837pub enum Contact {
 838    Accepted {
 839        user_id: UserId,
 840        should_notify: bool,
 841    },
 842    Outgoing {
 843        user_id: UserId,
 844    },
 845    Incoming {
 846        user_id: UserId,
 847        should_notify: bool,
 848    },
 849}
 850
 851impl Contact {
 852    pub fn user_id(&self) -> UserId {
 853        match self {
 854            Contact::Accepted { user_id, .. } => *user_id,
 855            Contact::Outgoing { user_id } => *user_id,
 856            Contact::Incoming { user_id, .. } => *user_id,
 857        }
 858    }
 859}
 860
 861#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 862pub struct IncomingContactRequest {
 863    pub requester_id: UserId,
 864    pub should_notify: bool,
 865}
 866
 867fn fuzzy_like_string(string: &str) -> String {
 868    let mut result = String::with_capacity(string.len() * 2 + 1);
 869    for c in string.chars() {
 870        if c.is_alphanumeric() {
 871            result.push('%');
 872            result.push(c);
 873        }
 874    }
 875    result.push('%');
 876    result
 877}
 878
 879#[cfg(test)]
 880pub mod tests {
 881    use super::*;
 882    use anyhow::anyhow;
 883    use collections::BTreeMap;
 884    use gpui::executor::Background;
 885    use lazy_static::lazy_static;
 886    use parking_lot::Mutex;
 887    use rand::prelude::*;
 888    use sqlx::{
 889        migrate::{MigrateDatabase, Migrator},
 890        Postgres,
 891    };
 892    use std::{path::Path, sync::Arc};
 893    use util::post_inc;
 894
 895    #[tokio::test(flavor = "multi_thread")]
 896    async fn test_get_users_by_ids() {
 897        for test_db in [
 898            TestDb::postgres().await,
 899            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 900        ] {
 901            let db = test_db.db();
 902
 903            let user = db.create_user("user", false).await.unwrap();
 904            let friend1 = db.create_user("friend-1", false).await.unwrap();
 905            let friend2 = db.create_user("friend-2", false).await.unwrap();
 906            let friend3 = db.create_user("friend-3", false).await.unwrap();
 907
 908            assert_eq!(
 909                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 910                    .await
 911                    .unwrap(),
 912                vec![
 913                    User {
 914                        id: user,
 915                        github_login: "user".to_string(),
 916                        admin: false,
 917                        ..Default::default()
 918                    },
 919                    User {
 920                        id: friend1,
 921                        github_login: "friend-1".to_string(),
 922                        admin: false,
 923                        ..Default::default()
 924                    },
 925                    User {
 926                        id: friend2,
 927                        github_login: "friend-2".to_string(),
 928                        admin: false,
 929                        ..Default::default()
 930                    },
 931                    User {
 932                        id: friend3,
 933                        github_login: "friend-3".to_string(),
 934                        admin: false,
 935                        ..Default::default()
 936                    }
 937                ]
 938            );
 939        }
 940    }
 941
 942    #[tokio::test(flavor = "multi_thread")]
 943    async fn test_recent_channel_messages() {
 944        for test_db in [
 945            TestDb::postgres().await,
 946            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 947        ] {
 948            let db = test_db.db();
 949            let user = db.create_user("user", false).await.unwrap();
 950            let org = db.create_org("org", "org").await.unwrap();
 951            let channel = db.create_org_channel(org, "channel").await.unwrap();
 952            for i in 0..10 {
 953                db.create_channel_message(
 954                    channel,
 955                    user,
 956                    &i.to_string(),
 957                    OffsetDateTime::now_utc(),
 958                    i,
 959                )
 960                .await
 961                .unwrap();
 962            }
 963
 964            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 965            assert_eq!(
 966                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 967                ["5", "6", "7", "8", "9"]
 968            );
 969
 970            let prev_messages = db
 971                .get_channel_messages(channel, 4, Some(messages[0].id))
 972                .await
 973                .unwrap();
 974            assert_eq!(
 975                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 976                ["1", "2", "3", "4"]
 977            );
 978        }
 979    }
 980
 981    #[tokio::test(flavor = "multi_thread")]
 982    async fn test_channel_message_nonces() {
 983        for test_db in [
 984            TestDb::postgres().await,
 985            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 986        ] {
 987            let db = test_db.db();
 988            let user = db.create_user("user", false).await.unwrap();
 989            let org = db.create_org("org", "org").await.unwrap();
 990            let channel = db.create_org_channel(org, "channel").await.unwrap();
 991
 992            let msg1_id = db
 993                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 994                .await
 995                .unwrap();
 996            let msg2_id = db
 997                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 998                .await
 999                .unwrap();
1000            let msg3_id = db
1001                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1002                .await
1003                .unwrap();
1004            let msg4_id = db
1005                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1006                .await
1007                .unwrap();
1008
1009            assert_ne!(msg1_id, msg2_id);
1010            assert_eq!(msg1_id, msg3_id);
1011            assert_eq!(msg2_id, msg4_id);
1012        }
1013    }
1014
1015    #[tokio::test(flavor = "multi_thread")]
1016    async fn test_create_access_tokens() {
1017        let test_db = TestDb::postgres().await;
1018        let db = test_db.db();
1019        let user = db.create_user("the-user", false).await.unwrap();
1020
1021        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1022        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1023        assert_eq!(
1024            db.get_access_token_hashes(user).await.unwrap(),
1025            &["h2".to_string(), "h1".to_string()]
1026        );
1027
1028        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1029        assert_eq!(
1030            db.get_access_token_hashes(user).await.unwrap(),
1031            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1032        );
1033
1034        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1035        assert_eq!(
1036            db.get_access_token_hashes(user).await.unwrap(),
1037            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1038        );
1039
1040        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1041        assert_eq!(
1042            db.get_access_token_hashes(user).await.unwrap(),
1043            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1044        );
1045    }
1046
1047    #[test]
1048    fn test_fuzzy_like_string() {
1049        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1050        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1051        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1052    }
1053
1054    #[tokio::test(flavor = "multi_thread")]
1055    async fn test_fuzzy_search_users() {
1056        let test_db = TestDb::postgres().await;
1057        let db = test_db.db();
1058        for github_login in [
1059            "California",
1060            "colorado",
1061            "oregon",
1062            "washington",
1063            "florida",
1064            "delaware",
1065            "rhode-island",
1066        ] {
1067            db.create_user(github_login, false).await.unwrap();
1068        }
1069
1070        assert_eq!(
1071            fuzzy_search_user_names(db, "clr").await,
1072            &["colorado", "California"]
1073        );
1074        assert_eq!(
1075            fuzzy_search_user_names(db, "ro").await,
1076            &["rhode-island", "colorado", "oregon"],
1077        );
1078
1079        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1080            db.fuzzy_search_users(query, 10)
1081                .await
1082                .unwrap()
1083                .into_iter()
1084                .map(|user| user.github_login)
1085                .collect::<Vec<_>>()
1086        }
1087    }
1088
    /// Walks the contact lifecycle (request, dismiss, accept, re-request, crossing
    /// requests, decline) step by step against both the Postgres and the fake backend.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", false).await.unwrap();
            let user_2 = db.create_user("user2", false).await.unwrap();
            let user_3 = db.create_user("user3", false).await.unwrap();

            // A user starts with no contacts other than themselves: every user is
            // always listed as their own accepted contact.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // The requester (user 1) has no notification for its own pending request, so
            // dismissing from that side fails.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact; only the
            // requester (user 1) is notified about the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts, in either direction.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Two users send each other crossing contact requests; the second request
            // accepts the first immediately, with no notification for either side.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1308
    /// Invite-code lifecycle: creation via `set_invite_count`, redemption (which makes the
    /// new user an accepted contact of the inviter), exhaustion, and top-up. Postgres only:
    /// the fake's invite-code methods are `unimplemented!()`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code(user1).await.unwrap(), None);

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4")
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Raising the count again (here from 0) keeps the same code
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt does not
        // consume a use.
        db.redeem_invite_code(&invite_code, "user-2")
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1446
    /// A database for a single test: either a throwaway Postgres database or an
    /// in-memory [`FakeDb`]. Dropping it tears the database down.
    pub struct TestDb {
        /// The database under test. `Drop` takes it out (leaving `None`) to tear it down.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the throwaway Postgres database; empty for the fake.
        pub url: String,
    }
1451
1452    impl TestDb {
1453        pub async fn postgres() -> Self {
1454            lazy_static! {
1455                static ref LOCK: Mutex<()> = Mutex::new(());
1456            }
1457
1458            let _guard = LOCK.lock();
1459            let mut rng = StdRng::from_entropy();
1460            let name = format!("zed-test-{}", rng.gen::<u128>());
1461            let url = format!("postgres://postgres@localhost/{}", name);
1462            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1463            Postgres::create_database(&url)
1464                .await
1465                .expect("failed to create test db");
1466            let db = PostgresDb::new(&url, 5).await.unwrap();
1467            let migrator = Migrator::new(migrations_path).await.unwrap();
1468            migrator.run(&db.pool).await.unwrap();
1469            Self {
1470                db: Some(Arc::new(db)),
1471                url,
1472            }
1473        }
1474
1475        pub fn fake(background: Arc<Background>) -> Self {
1476            Self {
1477                db: Some(Arc::new(FakeDb::new(background))),
1478                url: Default::default(),
1479            }
1480        }
1481
1482        pub fn db(&self) -> &Arc<dyn Db> {
1483            self.db.as_ref().unwrap()
1484        }
1485    }
1486
1487    impl Drop for TestDb {
1488        fn drop(&mut self) {
1489            if let Some(db) = self.db.take() {
1490                futures::executor::block_on(db.teardown(&self.url));
1491            }
1492        }
1493    }
1494
    /// In-memory implementation of [`Db`] for tests. Each table lives behind its own lock.
    pub struct FakeDb {
        /// Used by the async methods to call `simulate_random_delay`.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Keyed by (org, user); the value is the `is_admin` flag given to `add_org_member`.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Keyed by (channel, user). NOTE(review): the `bool` presumably mirrors
        /// `org_memberships` (an admin flag) — confirm against the channel methods.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// One entry per pending or accepted contact, stored in the direction it was requested.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each entity kind; all start at 1 (see `FakeDb::new`).
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1509
    /// A contact relationship between two users in [`FakeDb`], stored in the direction
    /// it was requested.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// `false` while the request is pending, `true` once the users are contacts.
        pub accepted: bool,
        /// Whether a notification is outstanding. While pending it is reported to the
        /// responder (as `Incoming`); once accepted it is reported to the requester.
        pub should_notify: bool,
    }
1517
1518    impl FakeDb {
1519        pub fn new(background: Arc<Background>) -> Self {
1520            Self {
1521                background,
1522                users: Default::default(),
1523                next_user_id: Mutex::new(1),
1524                orgs: Default::default(),
1525                next_org_id: Mutex::new(1),
1526                org_memberships: Default::default(),
1527                channels: Default::default(),
1528                next_channel_id: Mutex::new(1),
1529                channel_memberships: Default::default(),
1530                channel_messages: Default::default(),
1531                next_channel_message_id: Mutex::new(1),
1532                contacts: Default::default(),
1533            }
1534        }
1535    }
1536
1537    #[async_trait]
1538    impl Db for FakeDb {
1539        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1540            self.background.simulate_random_delay().await;
1541
1542            let mut users = self.users.lock();
1543            if let Some(user) = users
1544                .values()
1545                .find(|user| user.github_login == github_login)
1546            {
1547                Ok(user.id)
1548            } else {
1549                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1550                users.insert(
1551                    user_id,
1552                    User {
1553                        id: user_id,
1554                        github_login: github_login.to_string(),
1555                        admin,
1556                        invite_code: None,
1557                        invite_count: 0,
1558                    },
1559                );
1560                Ok(user_id)
1561            }
1562        }
1563
        /// Not supported by the fake database; panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
1567
1568        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1569            unimplemented!()
1570        }
1571
1572        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1573            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1574        }
1575
1576        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1577            self.background.simulate_random_delay().await;
1578            let users = self.users.lock();
1579            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1580        }
1581
1582        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1583            Ok(self
1584                .users
1585                .lock()
1586                .values()
1587                .find(|user| user.github_login == github_login)
1588                .cloned())
1589        }
1590
1591        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1592            unimplemented!()
1593        }
1594
1595        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1596            unimplemented!()
1597        }
1598
1599        // invite codes
1600
1601        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1602            unimplemented!()
1603        }
1604
1605        async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1606            unimplemented!()
1607        }
1608
1609        async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
1610            unimplemented!()
1611        }
1612
1613        // contacts
1614
1615        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1616            self.background.simulate_random_delay().await;
1617            let mut contacts = vec![Contact::Accepted {
1618                user_id: id,
1619                should_notify: false,
1620            }];
1621
1622            for contact in self.contacts.lock().iter() {
1623                if contact.requester_id == id {
1624                    if contact.accepted {
1625                        contacts.push(Contact::Accepted {
1626                            user_id: contact.responder_id,
1627                            should_notify: contact.should_notify,
1628                        });
1629                    } else {
1630                        contacts.push(Contact::Outgoing {
1631                            user_id: contact.responder_id,
1632                        });
1633                    }
1634                } else if contact.responder_id == id {
1635                    if contact.accepted {
1636                        contacts.push(Contact::Accepted {
1637                            user_id: contact.requester_id,
1638                            should_notify: false,
1639                        });
1640                    } else {
1641                        contacts.push(Contact::Incoming {
1642                            user_id: contact.requester_id,
1643                            should_notify: contact.should_notify,
1644                        });
1645                    }
1646                }
1647            }
1648
1649            contacts.sort_unstable_by_key(|contact| contact.user_id());
1650            Ok(contacts)
1651        }
1652
1653        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1654            self.background.simulate_random_delay().await;
1655            Ok(self.contacts.lock().iter().any(|contact| {
1656                contact.accepted
1657                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1658                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1659            }))
1660        }
1661
1662        async fn send_contact_request(
1663            &self,
1664            requester_id: UserId,
1665            responder_id: UserId,
1666        ) -> Result<()> {
1667            let mut contacts = self.contacts.lock();
1668            for contact in contacts.iter_mut() {
1669                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1670                    if contact.accepted {
1671                        Err(anyhow!("contact already exists"))?;
1672                    } else {
1673                        Err(anyhow!("contact already requested"))?;
1674                    }
1675                }
1676                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1677                    if contact.accepted {
1678                        Err(anyhow!("contact already exists"))?;
1679                    } else {
1680                        contact.accepted = true;
1681                        contact.should_notify = false;
1682                        return Ok(());
1683                    }
1684                }
1685            }
1686            contacts.push(FakeContact {
1687                requester_id,
1688                responder_id,
1689                accepted: false,
1690                should_notify: true,
1691            });
1692            Ok(())
1693        }
1694
1695        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1696            self.contacts.lock().retain(|contact| {
1697                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1698            });
1699            Ok(())
1700        }
1701
1702        async fn dismiss_contact_notification(
1703            &self,
1704            user_id: UserId,
1705            contact_user_id: UserId,
1706        ) -> Result<()> {
1707            let mut contacts = self.contacts.lock();
1708            for contact in contacts.iter_mut() {
1709                if contact.requester_id == contact_user_id
1710                    && contact.responder_id == user_id
1711                    && !contact.accepted
1712                {
1713                    contact.should_notify = false;
1714                    return Ok(());
1715                }
1716                if contact.requester_id == user_id
1717                    && contact.responder_id == contact_user_id
1718                    && contact.accepted
1719                {
1720                    contact.should_notify = false;
1721                    return Ok(());
1722                }
1723            }
1724            Err(anyhow!("no such notification"))
1725        }
1726
1727        async fn respond_to_contact_request(
1728            &self,
1729            responder_id: UserId,
1730            requester_id: UserId,
1731            accept: bool,
1732        ) -> Result<()> {
1733            let mut contacts = self.contacts.lock();
1734            for (ix, contact) in contacts.iter_mut().enumerate() {
1735                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1736                    if contact.accepted {
1737                        return Err(anyhow!("contact already confirmed"));
1738                    }
1739                    if accept {
1740                        contact.accepted = true;
1741                        contact.should_notify = true;
1742                    } else {
1743                        contacts.remove(ix);
1744                    }
1745                    return Ok(());
1746                }
1747            }
1748            Err(anyhow!("no such contact request"))
1749        }
1750
        /// Not implemented for the fake database.
        ///
        /// # Panics
        /// Always panics via `unimplemented!()`. All arguments are ignored.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1759
        /// Not implemented for the fake database.
        ///
        /// # Panics
        /// Always panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
1763
        /// Not implemented for the fake database.
        ///
        /// # Panics
        /// Always panics via `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
1767
1768        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1769            self.background.simulate_random_delay().await;
1770            let mut orgs = self.orgs.lock();
1771            if orgs.values().any(|org| org.slug == slug) {
1772                Err(anyhow!("org already exists"))
1773            } else {
1774                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1775                orgs.insert(
1776                    org_id,
1777                    Org {
1778                        id: org_id,
1779                        name: name.to_string(),
1780                        slug: slug.to_string(),
1781                    },
1782                );
1783                Ok(org_id)
1784            }
1785        }
1786
1787        async fn add_org_member(
1788            &self,
1789            org_id: OrgId,
1790            user_id: UserId,
1791            is_admin: bool,
1792        ) -> Result<()> {
1793            self.background.simulate_random_delay().await;
1794            if !self.orgs.lock().contains_key(&org_id) {
1795                return Err(anyhow!("org does not exist"));
1796            }
1797            if !self.users.lock().contains_key(&user_id) {
1798                return Err(anyhow!("user does not exist"));
1799            }
1800
1801            self.org_memberships
1802                .lock()
1803                .entry((org_id, user_id))
1804                .or_insert(is_admin);
1805            Ok(())
1806        }
1807
1808        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1809            self.background.simulate_random_delay().await;
1810            if !self.orgs.lock().contains_key(&org_id) {
1811                return Err(anyhow!("org does not exist"));
1812            }
1813
1814            let mut channels = self.channels.lock();
1815            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1816            channels.insert(
1817                channel_id,
1818                Channel {
1819                    id: channel_id,
1820                    name: name.to_string(),
1821                    owner_id: org_id.0,
1822                    owner_is_user: false,
1823                },
1824            );
1825            Ok(channel_id)
1826        }
1827
1828        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1829            self.background.simulate_random_delay().await;
1830            Ok(self
1831                .channels
1832                .lock()
1833                .values()
1834                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1835                .cloned()
1836                .collect())
1837        }
1838
1839        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1840            self.background.simulate_random_delay().await;
1841            let channels = self.channels.lock();
1842            let memberships = self.channel_memberships.lock();
1843            Ok(channels
1844                .values()
1845                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1846                .cloned()
1847                .collect())
1848        }
1849
1850        async fn can_user_access_channel(
1851            &self,
1852            user_id: UserId,
1853            channel_id: ChannelId,
1854        ) -> Result<bool> {
1855            self.background.simulate_random_delay().await;
1856            Ok(self
1857                .channel_memberships
1858                .lock()
1859                .contains_key(&(channel_id, user_id)))
1860        }
1861
1862        async fn add_channel_member(
1863            &self,
1864            channel_id: ChannelId,
1865            user_id: UserId,
1866            is_admin: bool,
1867        ) -> Result<()> {
1868            self.background.simulate_random_delay().await;
1869            if !self.channels.lock().contains_key(&channel_id) {
1870                return Err(anyhow!("channel does not exist"));
1871            }
1872            if !self.users.lock().contains_key(&user_id) {
1873                return Err(anyhow!("user does not exist"));
1874            }
1875
1876            self.channel_memberships
1877                .lock()
1878                .entry((channel_id, user_id))
1879                .or_insert(is_admin);
1880            Ok(())
1881        }
1882
1883        async fn create_channel_message(
1884            &self,
1885            channel_id: ChannelId,
1886            sender_id: UserId,
1887            body: &str,
1888            timestamp: OffsetDateTime,
1889            nonce: u128,
1890        ) -> Result<MessageId> {
1891            self.background.simulate_random_delay().await;
1892            if !self.channels.lock().contains_key(&channel_id) {
1893                return Err(anyhow!("channel does not exist"));
1894            }
1895            if !self.users.lock().contains_key(&sender_id) {
1896                return Err(anyhow!("user does not exist"));
1897            }
1898
1899            let mut messages = self.channel_messages.lock();
1900            if let Some(message) = messages
1901                .values()
1902                .find(|message| message.nonce.as_u128() == nonce)
1903            {
1904                Ok(message.id)
1905            } else {
1906                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1907                messages.insert(
1908                    message_id,
1909                    ChannelMessage {
1910                        id: message_id,
1911                        channel_id,
1912                        sender_id,
1913                        body: body.to_string(),
1914                        sent_at: timestamp,
1915                        nonce: Uuid::from_u128(nonce),
1916                    },
1917                );
1918                Ok(message_id)
1919            }
1920        }
1921
1922        async fn get_channel_messages(
1923            &self,
1924            channel_id: ChannelId,
1925            count: usize,
1926            before_id: Option<MessageId>,
1927        ) -> Result<Vec<ChannelMessage>> {
1928            let mut messages = self
1929                .channel_messages
1930                .lock()
1931                .values()
1932                .rev()
1933                .filter(|message| {
1934                    message.channel_id == channel_id
1935                        && message.id < before_id.unwrap_or(MessageId::MAX)
1936                })
1937                .take(count)
1938                .cloned()
1939                .collect::<Vec<_>>();
1940            messages.sort_unstable_by_key(|message| message.id);
1941            Ok(messages)
1942        }
1943
        /// Does nothing for the fake database; the string argument is ignored.
        async fn teardown(&self, _: &str) {}
1945
        /// Returns `Some(self)`, because this implementation is the `FakeDb`.
        /// Only compiled under `cfg(test)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
1950    }
1951}