db.rs

   1use crate::Result;
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use futures::StreamExt;
   5use nanoid::nanoid;
   6use serde::Serialize;
   7pub use sqlx::postgres::PgPoolOptions as DbOptions;
   8use sqlx::{types::Uuid, FromRow};
   9use time::OffsetDateTime;
  10
  11#[async_trait]
  12pub trait Db: Send + Sync {
  13    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
  14    async fn get_all_users(&self) -> Result<Vec<User>>;
  15    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  16    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  17    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  18    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  19    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  20    async fn destroy_user(&self, id: UserId) -> Result<()>;
  21
  22    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  23    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
  24    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
  25
  26    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  27    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  28    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  29    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  30    async fn dismiss_contact_notification(
  31        &self,
  32        responder_id: UserId,
  33        requester_id: UserId,
  34    ) -> Result<()>;
  35    async fn respond_to_contact_request(
  36        &self,
  37        responder_id: UserId,
  38        requester_id: UserId,
  39        accept: bool,
  40    ) -> Result<()>;
  41
  42    async fn create_access_token_hash(
  43        &self,
  44        user_id: UserId,
  45        access_token_hash: &str,
  46        max_access_token_count: usize,
  47    ) -> Result<()>;
  48    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  49    #[cfg(any(test, feature = "seed-support"))]
  50
  51    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  52    #[cfg(any(test, feature = "seed-support"))]
  53    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  54    #[cfg(any(test, feature = "seed-support"))]
  55    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  56    #[cfg(any(test, feature = "seed-support"))]
  57    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  58    #[cfg(any(test, feature = "seed-support"))]
  59
  60    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  61    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  62    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  63        -> Result<bool>;
  64    #[cfg(any(test, feature = "seed-support"))]
  65    async fn add_channel_member(
  66        &self,
  67        channel_id: ChannelId,
  68        user_id: UserId,
  69        is_admin: bool,
  70    ) -> Result<()>;
  71    async fn create_channel_message(
  72        &self,
  73        channel_id: ChannelId,
  74        sender_id: UserId,
  75        body: &str,
  76        timestamp: OffsetDateTime,
  77        nonce: u128,
  78    ) -> Result<MessageId>;
  79    async fn get_channel_messages(
  80        &self,
  81        channel_id: ChannelId,
  82        count: usize,
  83        before_id: Option<MessageId>,
  84    ) -> Result<Vec<ChannelMessage>>;
  85    #[cfg(test)]
  86    async fn teardown(&self, url: &str);
  87    #[cfg(test)]
  88    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
  89}
  90
  91pub struct PostgresDb {
  92    pool: sqlx::PgPool,
  93}
  94
  95impl PostgresDb {
  96    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  97        let pool = DbOptions::new()
  98            .max_connections(max_connections)
  99            .connect(&url)
 100            .await
 101            .context("failed to connect to postgres database")?;
 102        Ok(Self { pool })
 103    }
 104}
 105
 106#[async_trait]
 107impl Db for PostgresDb {
 108    // users
 109
 110    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 111        let query = "
 112            INSERT INTO users (github_login, admin)
 113            VALUES ($1, $2)
 114            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 115            RETURNING id
 116        ";
 117        Ok(sqlx::query_scalar(query)
 118            .bind(github_login)
 119            .bind(admin)
 120            .fetch_one(&self.pool)
 121            .await
 122            .map(UserId)?)
 123    }
 124
 125    async fn get_all_users(&self) -> Result<Vec<User>> {
 126        let query = "SELECT * FROM users ORDER BY github_login ASC";
 127        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 128    }
 129
 130    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 131        let like_string = fuzzy_like_string(name_query);
 132        let query = "
 133            SELECT users.*
 134            FROM users
 135            WHERE github_login ILIKE $1
 136            ORDER BY github_login <-> $2
 137            LIMIT $3
 138        ";
 139        Ok(sqlx::query_as(query)
 140            .bind(like_string)
 141            .bind(name_query)
 142            .bind(limit)
 143            .fetch_all(&self.pool)
 144            .await?)
 145    }
 146
 147    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 148        let users = self.get_users_by_ids(vec![id]).await?;
 149        Ok(users.into_iter().next())
 150    }
 151
 152    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 153        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 154        let query = "
 155            SELECT users.*
 156            FROM users
 157            WHERE users.id = ANY ($1)
 158        ";
 159        Ok(sqlx::query_as(query)
 160            .bind(&ids)
 161            .fetch_all(&self.pool)
 162            .await?)
 163    }
 164
 165    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 166        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 167        Ok(sqlx::query_as(query)
 168            .bind(github_login)
 169            .fetch_optional(&self.pool)
 170            .await?)
 171    }
 172
 173    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 174        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 175        Ok(sqlx::query(query)
 176            .bind(is_admin)
 177            .bind(id.0)
 178            .execute(&self.pool)
 179            .await
 180            .map(drop)?)
 181    }
 182
 183    async fn destroy_user(&self, id: UserId) -> Result<()> {
 184        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 185        sqlx::query(query)
 186            .bind(id.0)
 187            .execute(&self.pool)
 188            .await
 189            .map(drop)?;
 190        let query = "DELETE FROM users WHERE id = $1;";
 191        Ok(sqlx::query(query)
 192            .bind(id.0)
 193            .execute(&self.pool)
 194            .await
 195            .map(drop)?)
 196    }
 197
 198    // invite codes
 199
 200    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 201        let mut tx = self.pool.begin().await?;
 202        sqlx::query(
 203            "
 204                UPDATE users
 205                SET invite_code = $1
 206                WHERE id = $2 AND invite_code IS NULL
 207            ",
 208        )
 209        .bind(nanoid!(16))
 210        .bind(id)
 211        .execute(&mut tx)
 212        .await?;
 213        sqlx::query(
 214            "
 215            UPDATE users
 216            SET invite_count = $1
 217            WHERE id = $2
 218            ",
 219        )
 220        .bind(count)
 221        .bind(id)
 222        .execute(&mut tx)
 223        .await?;
 224        tx.commit().await?;
 225        Ok(())
 226    }
 227
 228    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
 229        let result: Option<(String, i32)> = sqlx::query_as(
 230            "
 231                SELECT invite_code, invite_count
 232                FROM users
 233                WHERE id = $1 AND invite_code IS NOT NULL 
 234            ",
 235        )
 236        .bind(id)
 237        .fetch_optional(&self.pool)
 238        .await?;
 239        if let Some((code, count)) = result {
 240            Ok(Some((code, count.try_into()?)))
 241        } else {
 242            Ok(None)
 243        }
 244    }
 245
 246    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
 247        let mut tx = self.pool.begin().await?;
 248
 249        let inviter_id: UserId = sqlx::query_scalar(
 250            "
 251                UPDATE users
 252                SET invite_count = invite_count - 1
 253                WHERE
 254                    invite_code = $1 AND
 255                    invite_count > 0
 256                RETURNING id
 257            ",
 258        )
 259        .bind(code)
 260        .fetch_optional(&mut tx)
 261        .await?
 262        .ok_or_else(|| anyhow!("invite code not found"))?;
 263        let invitee_id = sqlx::query_scalar(
 264            "
 265                INSERT INTO users
 266                    (github_login, admin, inviter_id)
 267                VALUES
 268                    ($1, 'f', $2)
 269                RETURNING id
 270            ",
 271        )
 272        .bind(login)
 273        .bind(inviter_id)
 274        .fetch_one(&mut tx)
 275        .await
 276        .map(UserId)?;
 277
 278        sqlx::query(
 279            "
 280                INSERT INTO contacts
 281                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 282                VALUES
 283                    ($1, $2, 't', 't', 't')
 284            ",
 285        )
 286        .bind(inviter_id)
 287        .bind(invitee_id)
 288        .execute(&mut tx)
 289        .await?;
 290
 291        tx.commit().await?;
 292        Ok(invitee_id)
 293    }
 294
 295    // contacts
 296
 297    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 298        let query = "
 299            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 300            FROM contacts
 301            WHERE user_id_a = $1 OR user_id_b = $1;
 302        ";
 303
 304        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 305            .bind(user_id)
 306            .fetch(&self.pool);
 307
 308        let mut contacts = vec![Contact::Accepted {
 309            user_id,
 310            should_notify: false,
 311        }];
 312        while let Some(row) = rows.next().await {
 313            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 314
 315            if user_id_a == user_id {
 316                if accepted {
 317                    contacts.push(Contact::Accepted {
 318                        user_id: user_id_b,
 319                        should_notify: should_notify && a_to_b,
 320                    });
 321                } else if a_to_b {
 322                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 323                } else {
 324                    contacts.push(Contact::Incoming {
 325                        user_id: user_id_b,
 326                        should_notify,
 327                    });
 328                }
 329            } else {
 330                if accepted {
 331                    contacts.push(Contact::Accepted {
 332                        user_id: user_id_a,
 333                        should_notify: should_notify && !a_to_b,
 334                    });
 335                } else if a_to_b {
 336                    contacts.push(Contact::Incoming {
 337                        user_id: user_id_a,
 338                        should_notify,
 339                    });
 340                } else {
 341                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 342                }
 343            }
 344        }
 345
 346        contacts.sort_unstable_by_key(|contact| contact.user_id());
 347
 348        Ok(contacts)
 349    }
 350
 351    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 352        let (id_a, id_b) = if user_id_1 < user_id_2 {
 353            (user_id_1, user_id_2)
 354        } else {
 355            (user_id_2, user_id_1)
 356        };
 357
 358        let query = "
 359            SELECT 1 FROM contacts
 360            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 361            LIMIT 1
 362        ";
 363        Ok(sqlx::query_scalar::<_, i32>(query)
 364            .bind(id_a.0)
 365            .bind(id_b.0)
 366            .fetch_optional(&self.pool)
 367            .await?
 368            .is_some())
 369    }
 370
 371    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 372        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 373            (sender_id, receiver_id, true)
 374        } else {
 375            (receiver_id, sender_id, false)
 376        };
 377        let query = "
 378            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 379            VALUES ($1, $2, $3, 'f', 't')
 380            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 381            SET
 382                accepted = 't',
 383                should_notify = 'f'
 384            WHERE
 385                NOT contacts.accepted AND
 386                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 387                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 388        ";
 389        let result = sqlx::query(query)
 390            .bind(id_a.0)
 391            .bind(id_b.0)
 392            .bind(a_to_b)
 393            .execute(&self.pool)
 394            .await?;
 395
 396        if result.rows_affected() == 1 {
 397            Ok(())
 398        } else {
 399            Err(anyhow!("contact already requested"))?
 400        }
 401    }
 402
 403    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 404        let (id_a, id_b) = if responder_id < requester_id {
 405            (responder_id, requester_id)
 406        } else {
 407            (requester_id, responder_id)
 408        };
 409        let query = "
 410            DELETE FROM contacts
 411            WHERE user_id_a = $1 AND user_id_b = $2;
 412        ";
 413        let result = sqlx::query(query)
 414            .bind(id_a.0)
 415            .bind(id_b.0)
 416            .execute(&self.pool)
 417            .await?;
 418
 419        if result.rows_affected() == 1 {
 420            Ok(())
 421        } else {
 422            Err(anyhow!("no such contact"))?
 423        }
 424    }
 425
 426    async fn dismiss_contact_notification(
 427        &self,
 428        user_id: UserId,
 429        contact_user_id: UserId,
 430    ) -> Result<()> {
 431        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 432            (user_id, contact_user_id, true)
 433        } else {
 434            (contact_user_id, user_id, false)
 435        };
 436
 437        let query = "
 438            UPDATE contacts
 439            SET should_notify = 'f'
 440            WHERE
 441                user_id_a = $1 AND user_id_b = $2 AND
 442                (
 443                    (a_to_b = $3 AND accepted) OR
 444                    (a_to_b != $3 AND NOT accepted)
 445                );
 446        ";
 447
 448        let result = sqlx::query(query)
 449            .bind(id_a.0)
 450            .bind(id_b.0)
 451            .bind(a_to_b)
 452            .execute(&self.pool)
 453            .await?;
 454
 455        if result.rows_affected() == 0 {
 456            Err(anyhow!("no such contact request"))?;
 457        }
 458
 459        Ok(())
 460    }
 461
 462    async fn respond_to_contact_request(
 463        &self,
 464        responder_id: UserId,
 465        requester_id: UserId,
 466        accept: bool,
 467    ) -> Result<()> {
 468        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 469            (responder_id, requester_id, false)
 470        } else {
 471            (requester_id, responder_id, true)
 472        };
 473        let result = if accept {
 474            let query = "
 475                UPDATE contacts
 476                SET accepted = 't', should_notify = 't'
 477                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 478            ";
 479            sqlx::query(query)
 480                .bind(id_a.0)
 481                .bind(id_b.0)
 482                .bind(a_to_b)
 483                .execute(&self.pool)
 484                .await?
 485        } else {
 486            let query = "
 487                DELETE FROM contacts
 488                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 489            ";
 490            sqlx::query(query)
 491                .bind(id_a.0)
 492                .bind(id_b.0)
 493                .bind(a_to_b)
 494                .execute(&self.pool)
 495                .await?
 496        };
 497        if result.rows_affected() == 1 {
 498            Ok(())
 499        } else {
 500            Err(anyhow!("no such contact request"))?
 501        }
 502    }
 503
 504    // access tokens
 505
 506    async fn create_access_token_hash(
 507        &self,
 508        user_id: UserId,
 509        access_token_hash: &str,
 510        max_access_token_count: usize,
 511    ) -> Result<()> {
 512        let insert_query = "
 513            INSERT INTO access_tokens (user_id, hash)
 514            VALUES ($1, $2);
 515        ";
 516        let cleanup_query = "
 517            DELETE FROM access_tokens
 518            WHERE id IN (
 519                SELECT id from access_tokens
 520                WHERE user_id = $1
 521                ORDER BY id DESC
 522                OFFSET $3
 523            )
 524        ";
 525
 526        let mut tx = self.pool.begin().await?;
 527        sqlx::query(insert_query)
 528            .bind(user_id.0)
 529            .bind(access_token_hash)
 530            .execute(&mut tx)
 531            .await?;
 532        sqlx::query(cleanup_query)
 533            .bind(user_id.0)
 534            .bind(access_token_hash)
 535            .bind(max_access_token_count as u32)
 536            .execute(&mut tx)
 537            .await?;
 538        Ok(tx.commit().await?)
 539    }
 540
 541    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 542        let query = "
 543            SELECT hash
 544            FROM access_tokens
 545            WHERE user_id = $1
 546            ORDER BY id DESC
 547        ";
 548        Ok(sqlx::query_scalar(query)
 549            .bind(user_id.0)
 550            .fetch_all(&self.pool)
 551            .await?)
 552    }
 553
 554    // orgs
 555
 556    #[allow(unused)] // Help rust-analyzer
 557    #[cfg(any(test, feature = "seed-support"))]
 558    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 559        let query = "
 560            SELECT *
 561            FROM orgs
 562            WHERE slug = $1
 563        ";
 564        Ok(sqlx::query_as(query)
 565            .bind(slug)
 566            .fetch_optional(&self.pool)
 567            .await?)
 568    }
 569
 570    #[cfg(any(test, feature = "seed-support"))]
 571    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 572        let query = "
 573            INSERT INTO orgs (name, slug)
 574            VALUES ($1, $2)
 575            RETURNING id
 576        ";
 577        Ok(sqlx::query_scalar(query)
 578            .bind(name)
 579            .bind(slug)
 580            .fetch_one(&self.pool)
 581            .await
 582            .map(OrgId)?)
 583    }
 584
 585    #[cfg(any(test, feature = "seed-support"))]
 586    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 587        let query = "
 588            INSERT INTO org_memberships (org_id, user_id, admin)
 589            VALUES ($1, $2, $3)
 590            ON CONFLICT DO NOTHING
 591        ";
 592        Ok(sqlx::query(query)
 593            .bind(org_id.0)
 594            .bind(user_id.0)
 595            .bind(is_admin)
 596            .execute(&self.pool)
 597            .await
 598            .map(drop)?)
 599    }
 600
 601    // channels
 602
 603    #[cfg(any(test, feature = "seed-support"))]
 604    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 605        let query = "
 606            INSERT INTO channels (owner_id, owner_is_user, name)
 607            VALUES ($1, false, $2)
 608            RETURNING id
 609        ";
 610        Ok(sqlx::query_scalar(query)
 611            .bind(org_id.0)
 612            .bind(name)
 613            .fetch_one(&self.pool)
 614            .await
 615            .map(ChannelId)?)
 616    }
 617
 618    #[allow(unused)] // Help rust-analyzer
 619    #[cfg(any(test, feature = "seed-support"))]
 620    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 621        let query = "
 622            SELECT *
 623            FROM channels
 624            WHERE
 625                channels.owner_is_user = false AND
 626                channels.owner_id = $1
 627        ";
 628        Ok(sqlx::query_as(query)
 629            .bind(org_id.0)
 630            .fetch_all(&self.pool)
 631            .await?)
 632    }
 633
 634    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 635        let query = "
 636            SELECT
 637                channels.*
 638            FROM
 639                channel_memberships, channels
 640            WHERE
 641                channel_memberships.user_id = $1 AND
 642                channel_memberships.channel_id = channels.id
 643        ";
 644        Ok(sqlx::query_as(query)
 645            .bind(user_id.0)
 646            .fetch_all(&self.pool)
 647            .await?)
 648    }
 649
 650    async fn can_user_access_channel(
 651        &self,
 652        user_id: UserId,
 653        channel_id: ChannelId,
 654    ) -> Result<bool> {
 655        let query = "
 656            SELECT id
 657            FROM channel_memberships
 658            WHERE user_id = $1 AND channel_id = $2
 659            LIMIT 1
 660        ";
 661        Ok(sqlx::query_scalar::<_, i32>(query)
 662            .bind(user_id.0)
 663            .bind(channel_id.0)
 664            .fetch_optional(&self.pool)
 665            .await
 666            .map(|e| e.is_some())?)
 667    }
 668
 669    #[cfg(any(test, feature = "seed-support"))]
 670    async fn add_channel_member(
 671        &self,
 672        channel_id: ChannelId,
 673        user_id: UserId,
 674        is_admin: bool,
 675    ) -> Result<()> {
 676        let query = "
 677            INSERT INTO channel_memberships (channel_id, user_id, admin)
 678            VALUES ($1, $2, $3)
 679            ON CONFLICT DO NOTHING
 680        ";
 681        Ok(sqlx::query(query)
 682            .bind(channel_id.0)
 683            .bind(user_id.0)
 684            .bind(is_admin)
 685            .execute(&self.pool)
 686            .await
 687            .map(drop)?)
 688    }
 689
 690    // messages
 691
 692    async fn create_channel_message(
 693        &self,
 694        channel_id: ChannelId,
 695        sender_id: UserId,
 696        body: &str,
 697        timestamp: OffsetDateTime,
 698        nonce: u128,
 699    ) -> Result<MessageId> {
 700        let query = "
 701            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 702            VALUES ($1, $2, $3, $4, $5)
 703            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 704            RETURNING id
 705        ";
 706        Ok(sqlx::query_scalar(query)
 707            .bind(channel_id.0)
 708            .bind(sender_id.0)
 709            .bind(body)
 710            .bind(timestamp)
 711            .bind(Uuid::from_u128(nonce))
 712            .fetch_one(&self.pool)
 713            .await
 714            .map(MessageId)?)
 715    }
 716
 717    async fn get_channel_messages(
 718        &self,
 719        channel_id: ChannelId,
 720        count: usize,
 721        before_id: Option<MessageId>,
 722    ) -> Result<Vec<ChannelMessage>> {
 723        let query = r#"
 724            SELECT * FROM (
 725                SELECT
 726                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 727                FROM
 728                    channel_messages
 729                WHERE
 730                    channel_id = $1 AND
 731                    id < $2
 732                ORDER BY id DESC
 733                LIMIT $3
 734            ) as recent_messages
 735            ORDER BY id ASC
 736        "#;
 737        Ok(sqlx::query_as(query)
 738            .bind(channel_id.0)
 739            .bind(before_id.unwrap_or(MessageId::MAX))
 740            .bind(count as i64)
 741            .fetch_all(&self.pool)
 742            .await?)
 743    }
 744
 745    #[cfg(test)]
 746    async fn teardown(&self, url: &str) {
 747        use util::ResultExt;
 748
 749        let query = "
 750            SELECT pg_terminate_backend(pg_stat_activity.pid)
 751            FROM pg_stat_activity
 752            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 753        ";
 754        sqlx::query(query).execute(&self.pool).await.log_err();
 755        self.pool.close().await;
 756        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 757            .await
 758            .log_err();
 759    }
 760
 761    #[cfg(test)]
 762    fn as_fake(&self) -> Option<&tests::FakeDb> {
 763        None
 764    }
 765}
 766
 767macro_rules! id_type {
 768    ($name:ident) => {
 769        #[derive(
 770            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 771        )]
 772        #[sqlx(transparent)]
 773        #[serde(transparent)]
 774        pub struct $name(pub i32);
 775
 776        impl $name {
 777            #[allow(unused)]
 778            pub const MAX: Self = Self(i32::MAX);
 779
 780            #[allow(unused)]
 781            pub fn from_proto(value: u64) -> Self {
 782                Self(value as i32)
 783            }
 784
 785            #[allow(unused)]
 786            pub fn to_proto(&self) -> u64 {
 787                self.0 as u64
 788            }
 789        }
 790
 791        impl std::fmt::Display for $name {
 792            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 793                self.0.fmt(f)
 794            }
 795        }
 796    };
 797}
 798
// Id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; unique (`create_user` relies on `ON CONFLICT (github_login)`).
    pub github_login: String,
    pub admin: bool,
    /// Code other people can redeem; `set_invite_count` assigns one when it is NULL.
    pub invite_code: Option<String>,
    /// Invites left; decremented by `redeem_invite_code`.
    pub invite_count: i32,
}
 808
// Id of a row in the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Identifier used by `find_org_by_slug`.
    pub slug: String,
}
 816
// Id of a row in the `channels` table.
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner; an org id when `owner_is_user` is false
    /// (presumably a user id otherwise — confirm).
    pub owner_id: i32,
    /// `create_org_channel` sets this to `false` for org-owned channels.
    pub owner_is_user: bool,
}
 825
// Id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A message as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; the query selects it as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client nonce (a `u128` stored as a UUID) used to deduplicate sends.
    pub nonce: Uuid,
}
 836
/// One user's relationship to another, from that user's point of view
/// (as built by `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The users are contacts. `should_notify` is only set when this user sent
    /// the request and the acceptance notification has not been dismissed.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent `user_id` a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
 851
 852impl Contact {
 853    pub fn user_id(&self) -> UserId {
 854        match self {
 855            Contact::Accepted { user_id, .. } => *user_id,
 856            Contact::Outgoing { user_id } => *user_id,
 857            Contact::Incoming { user_id, .. } => *user_id,
 858        }
 859    }
 860}
 861
/// A pending contact request addressed to some user.
// NOTE(review): not used by any code visible in this chunk — confirm it is still needed.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Presumably whether the recipient should still be notified — confirm.
    pub should_notify: bool,
}
 867
 868fn fuzzy_like_string(string: &str) -> String {
 869    let mut result = String::with_capacity(string.len() * 2 + 1);
 870    for c in string.chars() {
 871        if c.is_alphanumeric() {
 872            result.push('%');
 873            result.push(c);
 874        }
 875    }
 876    result.push('%');
 877    result
 878}
 879
 880#[cfg(test)]
 881pub mod tests {
 882    use super::*;
 883    use anyhow::anyhow;
 884    use collections::BTreeMap;
 885    use gpui::executor::Background;
 886    use lazy_static::lazy_static;
 887    use parking_lot::Mutex;
 888    use rand::prelude::*;
 889    use sqlx::{
 890        migrate::{MigrateDatabase, Migrator},
 891        Postgres,
 892    };
 893    use std::{path::Path, sync::Arc};
 894    use util::post_inc;
 895
 896    #[tokio::test(flavor = "multi_thread")]
 897    async fn test_get_users_by_ids() {
 898        for test_db in [
 899            TestDb::postgres().await,
 900            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 901        ] {
 902            let db = test_db.db();
 903
 904            let user = db.create_user("user", false).await.unwrap();
 905            let friend1 = db.create_user("friend-1", false).await.unwrap();
 906            let friend2 = db.create_user("friend-2", false).await.unwrap();
 907            let friend3 = db.create_user("friend-3", false).await.unwrap();
 908
 909            assert_eq!(
 910                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 911                    .await
 912                    .unwrap(),
 913                vec![
 914                    User {
 915                        id: user,
 916                        github_login: "user".to_string(),
 917                        admin: false,
 918                        ..Default::default()
 919                    },
 920                    User {
 921                        id: friend1,
 922                        github_login: "friend-1".to_string(),
 923                        admin: false,
 924                        ..Default::default()
 925                    },
 926                    User {
 927                        id: friend2,
 928                        github_login: "friend-2".to_string(),
 929                        admin: false,
 930                        ..Default::default()
 931                    },
 932                    User {
 933                        id: friend3,
 934                        github_login: "friend-3".to_string(),
 935                        admin: false,
 936                        ..Default::default()
 937                    }
 938                ]
 939            );
 940        }
 941    }
 942
 943    #[tokio::test(flavor = "multi_thread")]
 944    async fn test_recent_channel_messages() {
 945        for test_db in [
 946            TestDb::postgres().await,
 947            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 948        ] {
 949            let db = test_db.db();
 950            let user = db.create_user("user", false).await.unwrap();
 951            let org = db.create_org("org", "org").await.unwrap();
 952            let channel = db.create_org_channel(org, "channel").await.unwrap();
 953            for i in 0..10 {
 954                db.create_channel_message(
 955                    channel,
 956                    user,
 957                    &i.to_string(),
 958                    OffsetDateTime::now_utc(),
 959                    i,
 960                )
 961                .await
 962                .unwrap();
 963            }
 964
 965            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
 966            assert_eq!(
 967                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 968                ["5", "6", "7", "8", "9"]
 969            );
 970
 971            let prev_messages = db
 972                .get_channel_messages(channel, 4, Some(messages[0].id))
 973                .await
 974                .unwrap();
 975            assert_eq!(
 976                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
 977                ["1", "2", "3", "4"]
 978            );
 979        }
 980    }
 981
 982    #[tokio::test(flavor = "multi_thread")]
 983    async fn test_channel_message_nonces() {
 984        for test_db in [
 985            TestDb::postgres().await,
 986            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 987        ] {
 988            let db = test_db.db();
 989            let user = db.create_user("user", false).await.unwrap();
 990            let org = db.create_org("org", "org").await.unwrap();
 991            let channel = db.create_org_channel(org, "channel").await.unwrap();
 992
 993            let msg1_id = db
 994                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
 995                .await
 996                .unwrap();
 997            let msg2_id = db
 998                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
 999                .await
1000                .unwrap();
1001            let msg3_id = db
1002                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1003                .await
1004                .unwrap();
1005            let msg4_id = db
1006                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1007                .await
1008                .unwrap();
1009
1010            assert_ne!(msg1_id, msg2_id);
1011            assert_eq!(msg1_id, msg3_id);
1012            assert_eq!(msg2_id, msg4_id);
1013        }
1014    }
1015
1016    #[tokio::test(flavor = "multi_thread")]
1017    async fn test_create_access_tokens() {
1018        let test_db = TestDb::postgres().await;
1019        let db = test_db.db();
1020        let user = db.create_user("the-user", false).await.unwrap();
1021
1022        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1023        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1024        assert_eq!(
1025            db.get_access_token_hashes(user).await.unwrap(),
1026            &["h2".to_string(), "h1".to_string()]
1027        );
1028
1029        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1030        assert_eq!(
1031            db.get_access_token_hashes(user).await.unwrap(),
1032            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1033        );
1034
1035        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1036        assert_eq!(
1037            db.get_access_token_hashes(user).await.unwrap(),
1038            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1039        );
1040
1041        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1042        assert_eq!(
1043            db.get_access_token_hashes(user).await.unwrap(),
1044            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1045        );
1046    }
1047
1048    #[test]
1049    fn test_fuzzy_like_string() {
1050        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1051        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1052        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1053    }
1054
1055    #[tokio::test(flavor = "multi_thread")]
1056    async fn test_fuzzy_search_users() {
1057        let test_db = TestDb::postgres().await;
1058        let db = test_db.db();
1059        for github_login in [
1060            "California",
1061            "colorado",
1062            "oregon",
1063            "washington",
1064            "florida",
1065            "delaware",
1066            "rhode-island",
1067        ] {
1068            db.create_user(github_login, false).await.unwrap();
1069        }
1070
1071        assert_eq!(
1072            fuzzy_search_user_names(db, "clr").await,
1073            &["colorado", "California"]
1074        );
1075        assert_eq!(
1076            fuzzy_search_user_names(db, "ro").await,
1077            &["rhode-island", "colorado", "oregon"],
1078        );
1079
1080        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1081            db.fuzzy_search_users(query, 10)
1082                .await
1083                .unwrap()
1084                .into_iter()
1085                .map(|user| user.github_login)
1086                .collect::<Vec<_>>()
1087        }
1088    }
1089
    /// Walks three users through the contact lifecycle: requesting, dismissing
    /// the request notification, accepting, dismissing the acceptance
    /// notification, mutual simultaneous requests, and declining. Runs on both
    /// Postgres and `FakeDb`.
    ///
    /// Each step relies on the state left by the previous one, so the steps
    /// must stay in this order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", false).await.unwrap();
            let user_2 = db.create_user("user2", false).await.unwrap();
            let user_3 = db.create_user("user3", false).await.unwrap();

            // User starts with no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. The requester (user 1) has no notification to
            // dismiss while the request is pending, so their attempt fails.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1309
    /// Exercises invite codes: creating one, redeeming it until its count is
    /// exhausted, raising the count again, and rejecting redemption by an
    /// existing user. Each redemption makes the new user a contact of the
    /// inviter.
    ///
    /// Postgres only: `FakeDb`'s invite-code methods are unimplemented. The
    /// steps build on each other and must stay in this order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code(user1).await.unwrap(), None);

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4")
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt
        // does not consume a use of the code.
        db.redeem_invite_code(&invite_code, "user-2")
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1447
    /// A database handle owned by a single test; tearing it down happens on drop.
    pub struct TestDb {
        // Held in an `Option` so `Drop` can take ownership for teardown. Both
        // constructors set it to `Some`.
        pub db: Option<Arc<dyn Db>>,
        // Connection URL of the per-test Postgres database; empty for the fake.
        pub url: String,
    }
1452
1453    impl TestDb {
1454        pub async fn postgres() -> Self {
1455            lazy_static! {
1456                static ref LOCK: Mutex<()> = Mutex::new(());
1457            }
1458
1459            let _guard = LOCK.lock();
1460            let mut rng = StdRng::from_entropy();
1461            let name = format!("zed-test-{}", rng.gen::<u128>());
1462            let url = format!("postgres://postgres@localhost/{}", name);
1463            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1464            Postgres::create_database(&url)
1465                .await
1466                .expect("failed to create test db");
1467            let db = PostgresDb::new(&url, 5).await.unwrap();
1468            let migrator = Migrator::new(migrations_path).await.unwrap();
1469            migrator.run(&db.pool).await.unwrap();
1470            Self {
1471                db: Some(Arc::new(db)),
1472                url,
1473            }
1474        }
1475
1476        pub fn fake(background: Arc<Background>) -> Self {
1477            Self {
1478                db: Some(Arc::new(FakeDb::new(background))),
1479                url: Default::default(),
1480            }
1481        }
1482
1483        pub fn db(&self) -> &Arc<dyn Db> {
1484            self.db.as_ref().unwrap()
1485        }
1486    }
1487
1488    impl Drop for TestDb {
1489        fn drop(&mut self) {
1490            if let Some(db) = self.db.take() {
1491                futures::executor::block_on(db.teardown(&self.url));
1492            }
1493        }
1494    }
1495
    /// In-memory implementation of `Db` used by tests in place of `PostgresDb`.
    ///
    /// Each table lives behind its own `Mutex`. Operations the fake does not
    /// support panic with `unimplemented!()`.
    pub struct FakeDb {
        // Used to inject random delays into operations (`simulate_random_delay`).
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // Value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to assign for each entity kind; `FakeDb::new` starts them all at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1510
    /// A single contact relationship stored by `FakeDb`, recorded from the
    /// perspective of the user who sent the request.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        // `false` while the request is still pending.
        pub accepted: bool,
        // Whether the other party has an undismissed notification: the responder
        // while the request is pending, the requester once it has been accepted.
        pub should_notify: bool,
    }
1518
1519    impl FakeDb {
1520        pub fn new(background: Arc<Background>) -> Self {
1521            Self {
1522                background,
1523                users: Default::default(),
1524                next_user_id: Mutex::new(1),
1525                orgs: Default::default(),
1526                next_org_id: Mutex::new(1),
1527                org_memberships: Default::default(),
1528                channels: Default::default(),
1529                next_channel_id: Mutex::new(1),
1530                channel_memberships: Default::default(),
1531                channel_messages: Default::default(),
1532                next_channel_message_id: Mutex::new(1),
1533                contacts: Default::default(),
1534            }
1535        }
1536    }
1537
1538    #[async_trait]
1539    impl Db for FakeDb {
1540        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1541            self.background.simulate_random_delay().await;
1542
1543            let mut users = self.users.lock();
1544            if let Some(user) = users
1545                .values()
1546                .find(|user| user.github_login == github_login)
1547            {
1548                Ok(user.id)
1549            } else {
1550                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1551                users.insert(
1552                    user_id,
1553                    User {
1554                        id: user_id,
1555                        github_login: github_login.to_string(),
1556                        admin,
1557                        invite_code: None,
1558                        invite_count: 0,
1559                    },
1560                );
1561                Ok(user_id)
1562            }
1563        }
1564
        // Not supported by the fake database; panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
1568
        // Not supported by the fake database; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1572
1573        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1574            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1575        }
1576
1577        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1578            self.background.simulate_random_delay().await;
1579            let users = self.users.lock();
1580            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1581        }
1582
1583        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1584            Ok(self
1585                .users
1586                .lock()
1587                .values()
1588                .find(|user| user.github_login == github_login)
1589                .cloned())
1590        }
1591
        // Not supported by the fake database; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1595
        // Not supported by the fake database; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1599
1600        // invite codes
1601
        // Not supported by the fake database; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
1605
        // Not supported by the fake database; panics if called.
        async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            unimplemented!()
        }
1609
        // Not supported by the fake database; panics if called.
        async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
            unimplemented!()
        }
1613
1614        // contacts
1615
1616        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1617            self.background.simulate_random_delay().await;
1618            let mut contacts = vec![Contact::Accepted {
1619                user_id: id,
1620                should_notify: false,
1621            }];
1622
1623            for contact in self.contacts.lock().iter() {
1624                if contact.requester_id == id {
1625                    if contact.accepted {
1626                        contacts.push(Contact::Accepted {
1627                            user_id: contact.responder_id,
1628                            should_notify: contact.should_notify,
1629                        });
1630                    } else {
1631                        contacts.push(Contact::Outgoing {
1632                            user_id: contact.responder_id,
1633                        });
1634                    }
1635                } else if contact.responder_id == id {
1636                    if contact.accepted {
1637                        contacts.push(Contact::Accepted {
1638                            user_id: contact.requester_id,
1639                            should_notify: false,
1640                        });
1641                    } else {
1642                        contacts.push(Contact::Incoming {
1643                            user_id: contact.requester_id,
1644                            should_notify: contact.should_notify,
1645                        });
1646                    }
1647                }
1648            }
1649
1650            contacts.sort_unstable_by_key(|contact| contact.user_id());
1651            Ok(contacts)
1652        }
1653
1654        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1655            self.background.simulate_random_delay().await;
1656            Ok(self.contacts.lock().iter().any(|contact| {
1657                contact.accepted
1658                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1659                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1660            }))
1661        }
1662
1663        async fn send_contact_request(
1664            &self,
1665            requester_id: UserId,
1666            responder_id: UserId,
1667        ) -> Result<()> {
1668            let mut contacts = self.contacts.lock();
1669            for contact in contacts.iter_mut() {
1670                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1671                    if contact.accepted {
1672                        Err(anyhow!("contact already exists"))?;
1673                    } else {
1674                        Err(anyhow!("contact already requested"))?;
1675                    }
1676                }
1677                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1678                    if contact.accepted {
1679                        Err(anyhow!("contact already exists"))?;
1680                    } else {
1681                        contact.accepted = true;
1682                        contact.should_notify = false;
1683                        return Ok(());
1684                    }
1685                }
1686            }
1687            contacts.push(FakeContact {
1688                requester_id,
1689                responder_id,
1690                accepted: false,
1691                should_notify: true,
1692            });
1693            Ok(())
1694        }
1695
1696        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1697            self.contacts.lock().retain(|contact| {
1698                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1699            });
1700            Ok(())
1701        }
1702
1703        async fn dismiss_contact_notification(
1704            &self,
1705            user_id: UserId,
1706            contact_user_id: UserId,
1707        ) -> Result<()> {
1708            let mut contacts = self.contacts.lock();
1709            for contact in contacts.iter_mut() {
1710                if contact.requester_id == contact_user_id
1711                    && contact.responder_id == user_id
1712                    && !contact.accepted
1713                {
1714                    contact.should_notify = false;
1715                    return Ok(());
1716                }
1717                if contact.requester_id == user_id
1718                    && contact.responder_id == contact_user_id
1719                    && contact.accepted
1720                {
1721                    contact.should_notify = false;
1722                    return Ok(());
1723                }
1724            }
1725            Err(anyhow!("no such notification"))?
1726        }
1727
1728        async fn respond_to_contact_request(
1729            &self,
1730            responder_id: UserId,
1731            requester_id: UserId,
1732            accept: bool,
1733        ) -> Result<()> {
1734            let mut contacts = self.contacts.lock();
1735            for (ix, contact) in contacts.iter_mut().enumerate() {
1736                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1737                    if contact.accepted {
1738                        Err(anyhow!("contact already confirmed"))?;
1739                    }
1740                    if accept {
1741                        contact.accepted = true;
1742                        contact.should_notify = true;
1743                    } else {
1744                        contacts.remove(ix);
1745                    }
1746                    return Ok(());
1747                }
1748            }
1749            Err(anyhow!("no such contact request"))?
1750        }
1751
        // Not supported by the fake database; panics if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1760
1761        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1762            unimplemented!()
1763        }
1764
1765        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1766            unimplemented!()
1767        }
1768
1769        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1770            self.background.simulate_random_delay().await;
1771            let mut orgs = self.orgs.lock();
1772            if orgs.values().any(|org| org.slug == slug) {
1773                Err(anyhow!("org already exists"))?
1774            } else {
1775                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1776                orgs.insert(
1777                    org_id,
1778                    Org {
1779                        id: org_id,
1780                        name: name.to_string(),
1781                        slug: slug.to_string(),
1782                    },
1783                );
1784                Ok(org_id)
1785            }
1786        }
1787
1788        async fn add_org_member(
1789            &self,
1790            org_id: OrgId,
1791            user_id: UserId,
1792            is_admin: bool,
1793        ) -> Result<()> {
1794            self.background.simulate_random_delay().await;
1795            if !self.orgs.lock().contains_key(&org_id) {
1796                Err(anyhow!("org does not exist"))?;
1797            }
1798            if !self.users.lock().contains_key(&user_id) {
1799                Err(anyhow!("user does not exist"))?;
1800            }
1801
1802            self.org_memberships
1803                .lock()
1804                .entry((org_id, user_id))
1805                .or_insert(is_admin);
1806            Ok(())
1807        }
1808
1809        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1810            self.background.simulate_random_delay().await;
1811            if !self.orgs.lock().contains_key(&org_id) {
1812                Err(anyhow!("org does not exist"))?;
1813            }
1814
1815            let mut channels = self.channels.lock();
1816            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1817            channels.insert(
1818                channel_id,
1819                Channel {
1820                    id: channel_id,
1821                    name: name.to_string(),
1822                    owner_id: org_id.0,
1823                    owner_is_user: false,
1824                },
1825            );
1826            Ok(channel_id)
1827        }
1828
1829        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1830            self.background.simulate_random_delay().await;
1831            Ok(self
1832                .channels
1833                .lock()
1834                .values()
1835                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1836                .cloned()
1837                .collect())
1838        }
1839
1840        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1841            self.background.simulate_random_delay().await;
1842            let channels = self.channels.lock();
1843            let memberships = self.channel_memberships.lock();
1844            Ok(channels
1845                .values()
1846                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1847                .cloned()
1848                .collect())
1849        }
1850
1851        async fn can_user_access_channel(
1852            &self,
1853            user_id: UserId,
1854            channel_id: ChannelId,
1855        ) -> Result<bool> {
1856            self.background.simulate_random_delay().await;
1857            Ok(self
1858                .channel_memberships
1859                .lock()
1860                .contains_key(&(channel_id, user_id)))
1861        }
1862
1863        async fn add_channel_member(
1864            &self,
1865            channel_id: ChannelId,
1866            user_id: UserId,
1867            is_admin: bool,
1868        ) -> Result<()> {
1869            self.background.simulate_random_delay().await;
1870            if !self.channels.lock().contains_key(&channel_id) {
1871                Err(anyhow!("channel does not exist"))?;
1872            }
1873            if !self.users.lock().contains_key(&user_id) {
1874                Err(anyhow!("user does not exist"))?;
1875            }
1876
1877            self.channel_memberships
1878                .lock()
1879                .entry((channel_id, user_id))
1880                .or_insert(is_admin);
1881            Ok(())
1882        }
1883
1884        async fn create_channel_message(
1885            &self,
1886            channel_id: ChannelId,
1887            sender_id: UserId,
1888            body: &str,
1889            timestamp: OffsetDateTime,
1890            nonce: u128,
1891        ) -> Result<MessageId> {
1892            self.background.simulate_random_delay().await;
1893            if !self.channels.lock().contains_key(&channel_id) {
1894                Err(anyhow!("channel does not exist"))?;
1895            }
1896            if !self.users.lock().contains_key(&sender_id) {
1897                Err(anyhow!("user does not exist"))?;
1898            }
1899
1900            let mut messages = self.channel_messages.lock();
1901            if let Some(message) = messages
1902                .values()
1903                .find(|message| message.nonce.as_u128() == nonce)
1904            {
1905                Ok(message.id)
1906            } else {
1907                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1908                messages.insert(
1909                    message_id,
1910                    ChannelMessage {
1911                        id: message_id,
1912                        channel_id,
1913                        sender_id,
1914                        body: body.to_string(),
1915                        sent_at: timestamp,
1916                        nonce: Uuid::from_u128(nonce),
1917                    },
1918                );
1919                Ok(message_id)
1920            }
1921        }
1922
1923        async fn get_channel_messages(
1924            &self,
1925            channel_id: ChannelId,
1926            count: usize,
1927            before_id: Option<MessageId>,
1928        ) -> Result<Vec<ChannelMessage>> {
1929            let mut messages = self
1930                .channel_messages
1931                .lock()
1932                .values()
1933                .rev()
1934                .filter(|message| {
1935                    message.channel_id == channel_id
1936                        && message.id < before_id.unwrap_or(MessageId::MAX)
1937                })
1938                .take(count)
1939                .cloned()
1940                .collect::<Vec<_>>();
1941            messages.sort_unstable_by_key(|message| message.id);
1942            Ok(messages)
1943        }
1944
1945        async fn teardown(&self, _: &str) {}
1946
1947        #[cfg(test)]
1948        fn as_fake(&self) -> Option<&FakeDb> {
1949            Some(self)
1950        }
1951    }
1952}