db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
  15    async fn get_all_users(&self) -> Result<Vec<User>>;
  16    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  17    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  18    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  19    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  20    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  21    async fn destroy_user(&self, id: UserId) -> Result<()>;
  22
  23    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  24    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  25    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  26    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
  27
  28    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  29    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  30    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  31    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  32    async fn dismiss_contact_notification(
  33        &self,
  34        responder_id: UserId,
  35        requester_id: UserId,
  36    ) -> Result<()>;
  37    async fn respond_to_contact_request(
  38        &self,
  39        responder_id: UserId,
  40        requester_id: UserId,
  41        accept: bool,
  42    ) -> Result<()>;
  43
  44    async fn create_access_token_hash(
  45        &self,
  46        user_id: UserId,
  47        access_token_hash: &str,
  48        max_access_token_count: usize,
  49    ) -> Result<()>;
  50    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  51    #[cfg(any(test, feature = "seed-support"))]
  52
  53    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  54    #[cfg(any(test, feature = "seed-support"))]
  55    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  56    #[cfg(any(test, feature = "seed-support"))]
  57    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  58    #[cfg(any(test, feature = "seed-support"))]
  59    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  60    #[cfg(any(test, feature = "seed-support"))]
  61
  62    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  63    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  64    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  65        -> Result<bool>;
  66    #[cfg(any(test, feature = "seed-support"))]
  67    async fn add_channel_member(
  68        &self,
  69        channel_id: ChannelId,
  70        user_id: UserId,
  71        is_admin: bool,
  72    ) -> Result<()>;
  73    async fn create_channel_message(
  74        &self,
  75        channel_id: ChannelId,
  76        sender_id: UserId,
  77        body: &str,
  78        timestamp: OffsetDateTime,
  79        nonce: u128,
  80    ) -> Result<MessageId>;
  81    async fn get_channel_messages(
  82        &self,
  83        channel_id: ChannelId,
  84        count: usize,
  85        before_id: Option<MessageId>,
  86    ) -> Result<Vec<ChannelMessage>>;
  87    #[cfg(test)]
  88    async fn teardown(&self, url: &str);
  89    #[cfg(test)]
  90    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
  91}
  92
  93pub struct PostgresDb {
  94    pool: sqlx::PgPool,
  95}
  96
  97impl PostgresDb {
  98    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  99        let pool = DbOptions::new()
 100            .max_connections(max_connections)
 101            .connect(&url)
 102            .await
 103            .context("failed to connect to postgres database")?;
 104        Ok(Self { pool })
 105    }
 106}
 107
 108#[async_trait]
 109impl Db for PostgresDb {
 110    // users
 111
 112    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 113        let query = "
 114            INSERT INTO users (github_login, admin)
 115            VALUES ($1, $2)
 116            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 117            RETURNING id
 118        ";
 119        Ok(sqlx::query_scalar(query)
 120            .bind(github_login)
 121            .bind(admin)
 122            .fetch_one(&self.pool)
 123            .await
 124            .map(UserId)?)
 125    }
 126
 127    async fn get_all_users(&self) -> Result<Vec<User>> {
 128        let query = "SELECT * FROM users ORDER BY github_login ASC";
 129        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 130    }
 131
 132    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 133        let like_string = fuzzy_like_string(name_query);
 134        let query = "
 135            SELECT users.*
 136            FROM users
 137            WHERE github_login ILIKE $1
 138            ORDER BY github_login <-> $2
 139            LIMIT $3
 140        ";
 141        Ok(sqlx::query_as(query)
 142            .bind(like_string)
 143            .bind(name_query)
 144            .bind(limit)
 145            .fetch_all(&self.pool)
 146            .await?)
 147    }
 148
 149    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 150        let users = self.get_users_by_ids(vec![id]).await?;
 151        Ok(users.into_iter().next())
 152    }
 153
 154    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 155        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 156        let query = "
 157            SELECT users.*
 158            FROM users
 159            WHERE users.id = ANY ($1)
 160        ";
 161        Ok(sqlx::query_as(query)
 162            .bind(&ids)
 163            .fetch_all(&self.pool)
 164            .await?)
 165    }
 166
 167    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 168        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 169        Ok(sqlx::query_as(query)
 170            .bind(github_login)
 171            .fetch_optional(&self.pool)
 172            .await?)
 173    }
 174
 175    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 176        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 177        Ok(sqlx::query(query)
 178            .bind(is_admin)
 179            .bind(id.0)
 180            .execute(&self.pool)
 181            .await
 182            .map(drop)?)
 183    }
 184
 185    async fn destroy_user(&self, id: UserId) -> Result<()> {
 186        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 187        sqlx::query(query)
 188            .bind(id.0)
 189            .execute(&self.pool)
 190            .await
 191            .map(drop)?;
 192        let query = "DELETE FROM users WHERE id = $1;";
 193        Ok(sqlx::query(query)
 194            .bind(id.0)
 195            .execute(&self.pool)
 196            .await
 197            .map(drop)?)
 198    }
 199
 200    // invite codes
 201
 202    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 203        let mut tx = self.pool.begin().await?;
 204        sqlx::query(
 205            "
 206                UPDATE users
 207                SET invite_code = $1
 208                WHERE id = $2 AND invite_code IS NULL
 209            ",
 210        )
 211        .bind(nanoid!(16))
 212        .bind(id)
 213        .execute(&mut tx)
 214        .await?;
 215        sqlx::query(
 216            "
 217            UPDATE users
 218            SET invite_count = $1
 219            WHERE id = $2
 220            ",
 221        )
 222        .bind(count)
 223        .bind(id)
 224        .execute(&mut tx)
 225        .await?;
 226        tx.commit().await?;
 227        Ok(())
 228    }
 229
 230    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 231        let result: Option<(String, i32)> = sqlx::query_as(
 232            "
 233                SELECT invite_code, invite_count
 234                FROM users
 235                WHERE id = $1 AND invite_code IS NOT NULL 
 236            ",
 237        )
 238        .bind(id)
 239        .fetch_optional(&self.pool)
 240        .await?;
 241        if let Some((code, count)) = result {
 242            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 243        } else {
 244            Ok(None)
 245        }
 246    }
 247
 248    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 249        sqlx::query_as(
 250            "
 251                SELECT *
 252                FROM users
 253                WHERE invite_code = $1
 254            ",
 255        )
 256        .bind(code)
 257        .fetch_optional(&self.pool)
 258        .await?
 259        .ok_or_else(|| {
 260            Error::Http(
 261                StatusCode::NOT_FOUND,
 262                "that invite code does not exist".to_string(),
 263            )
 264        })
 265    }
 266
 267    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
 268        let mut tx = self.pool.begin().await?;
 269
 270        let inviter_id: Option<UserId> = sqlx::query_scalar(
 271            "
 272                UPDATE users
 273                SET invite_count = invite_count - 1
 274                WHERE
 275                    invite_code = $1 AND
 276                    invite_count > 0
 277                RETURNING id
 278            ",
 279        )
 280        .bind(code)
 281        .fetch_optional(&mut tx)
 282        .await?;
 283
 284        let inviter_id = match inviter_id {
 285            Some(inviter_id) => inviter_id,
 286            None => {
 287                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 288                    .bind(code)
 289                    .fetch_optional(&mut tx)
 290                    .await?
 291                    .is_some()
 292                {
 293                    Err(Error::Http(
 294                        StatusCode::UNAUTHORIZED,
 295                        "no invites remaining".to_string(),
 296                    ))?
 297                } else {
 298                    Err(Error::Http(
 299                        StatusCode::NOT_FOUND,
 300                        "invite code not found".to_string(),
 301                    ))?
 302                }
 303            }
 304        };
 305
 306        let invitee_id = sqlx::query_scalar(
 307            "
 308                INSERT INTO users
 309                    (github_login, admin, inviter_id)
 310                VALUES
 311                    ($1, 'f', $2)
 312                RETURNING id
 313            ",
 314        )
 315        .bind(login)
 316        .bind(inviter_id)
 317        .fetch_one(&mut tx)
 318        .await
 319        .map(UserId)?;
 320
 321        sqlx::query(
 322            "
 323                INSERT INTO contacts
 324                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 325                VALUES
 326                    ($1, $2, 't', 't', 't')
 327            ",
 328        )
 329        .bind(inviter_id)
 330        .bind(invitee_id)
 331        .execute(&mut tx)
 332        .await?;
 333
 334        tx.commit().await?;
 335        Ok(invitee_id)
 336    }
 337
 338    // contacts
 339
 340    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 341        let query = "
 342            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 343            FROM contacts
 344            WHERE user_id_a = $1 OR user_id_b = $1;
 345        ";
 346
 347        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 348            .bind(user_id)
 349            .fetch(&self.pool);
 350
 351        let mut contacts = vec![Contact::Accepted {
 352            user_id,
 353            should_notify: false,
 354        }];
 355        while let Some(row) = rows.next().await {
 356            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 357
 358            if user_id_a == user_id {
 359                if accepted {
 360                    contacts.push(Contact::Accepted {
 361                        user_id: user_id_b,
 362                        should_notify: should_notify && a_to_b,
 363                    });
 364                } else if a_to_b {
 365                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 366                } else {
 367                    contacts.push(Contact::Incoming {
 368                        user_id: user_id_b,
 369                        should_notify,
 370                    });
 371                }
 372            } else {
 373                if accepted {
 374                    contacts.push(Contact::Accepted {
 375                        user_id: user_id_a,
 376                        should_notify: should_notify && !a_to_b,
 377                    });
 378                } else if a_to_b {
 379                    contacts.push(Contact::Incoming {
 380                        user_id: user_id_a,
 381                        should_notify,
 382                    });
 383                } else {
 384                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 385                }
 386            }
 387        }
 388
 389        contacts.sort_unstable_by_key(|contact| contact.user_id());
 390
 391        Ok(contacts)
 392    }
 393
 394    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 395        let (id_a, id_b) = if user_id_1 < user_id_2 {
 396            (user_id_1, user_id_2)
 397        } else {
 398            (user_id_2, user_id_1)
 399        };
 400
 401        let query = "
 402            SELECT 1 FROM contacts
 403            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 404            LIMIT 1
 405        ";
 406        Ok(sqlx::query_scalar::<_, i32>(query)
 407            .bind(id_a.0)
 408            .bind(id_b.0)
 409            .fetch_optional(&self.pool)
 410            .await?
 411            .is_some())
 412    }
 413
 414    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 415        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 416            (sender_id, receiver_id, true)
 417        } else {
 418            (receiver_id, sender_id, false)
 419        };
 420        let query = "
 421            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 422            VALUES ($1, $2, $3, 'f', 't')
 423            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 424            SET
 425                accepted = 't',
 426                should_notify = 'f'
 427            WHERE
 428                NOT contacts.accepted AND
 429                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 430                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 431        ";
 432        let result = sqlx::query(query)
 433            .bind(id_a.0)
 434            .bind(id_b.0)
 435            .bind(a_to_b)
 436            .execute(&self.pool)
 437            .await?;
 438
 439        if result.rows_affected() == 1 {
 440            Ok(())
 441        } else {
 442            Err(anyhow!("contact already requested"))?
 443        }
 444    }
 445
 446    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 447        let (id_a, id_b) = if responder_id < requester_id {
 448            (responder_id, requester_id)
 449        } else {
 450            (requester_id, responder_id)
 451        };
 452        let query = "
 453            DELETE FROM contacts
 454            WHERE user_id_a = $1 AND user_id_b = $2;
 455        ";
 456        let result = sqlx::query(query)
 457            .bind(id_a.0)
 458            .bind(id_b.0)
 459            .execute(&self.pool)
 460            .await?;
 461
 462        if result.rows_affected() == 1 {
 463            Ok(())
 464        } else {
 465            Err(anyhow!("no such contact"))?
 466        }
 467    }
 468
 469    async fn dismiss_contact_notification(
 470        &self,
 471        user_id: UserId,
 472        contact_user_id: UserId,
 473    ) -> Result<()> {
 474        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 475            (user_id, contact_user_id, true)
 476        } else {
 477            (contact_user_id, user_id, false)
 478        };
 479
 480        let query = "
 481            UPDATE contacts
 482            SET should_notify = 'f'
 483            WHERE
 484                user_id_a = $1 AND user_id_b = $2 AND
 485                (
 486                    (a_to_b = $3 AND accepted) OR
 487                    (a_to_b != $3 AND NOT accepted)
 488                );
 489        ";
 490
 491        let result = sqlx::query(query)
 492            .bind(id_a.0)
 493            .bind(id_b.0)
 494            .bind(a_to_b)
 495            .execute(&self.pool)
 496            .await?;
 497
 498        if result.rows_affected() == 0 {
 499            Err(anyhow!("no such contact request"))?;
 500        }
 501
 502        Ok(())
 503    }
 504
 505    async fn respond_to_contact_request(
 506        &self,
 507        responder_id: UserId,
 508        requester_id: UserId,
 509        accept: bool,
 510    ) -> Result<()> {
 511        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 512            (responder_id, requester_id, false)
 513        } else {
 514            (requester_id, responder_id, true)
 515        };
 516        let result = if accept {
 517            let query = "
 518                UPDATE contacts
 519                SET accepted = 't', should_notify = 't'
 520                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 521            ";
 522            sqlx::query(query)
 523                .bind(id_a.0)
 524                .bind(id_b.0)
 525                .bind(a_to_b)
 526                .execute(&self.pool)
 527                .await?
 528        } else {
 529            let query = "
 530                DELETE FROM contacts
 531                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 532            ";
 533            sqlx::query(query)
 534                .bind(id_a.0)
 535                .bind(id_b.0)
 536                .bind(a_to_b)
 537                .execute(&self.pool)
 538                .await?
 539        };
 540        if result.rows_affected() == 1 {
 541            Ok(())
 542        } else {
 543            Err(anyhow!("no such contact request"))?
 544        }
 545    }
 546
 547    // access tokens
 548
 549    async fn create_access_token_hash(
 550        &self,
 551        user_id: UserId,
 552        access_token_hash: &str,
 553        max_access_token_count: usize,
 554    ) -> Result<()> {
 555        let insert_query = "
 556            INSERT INTO access_tokens (user_id, hash)
 557            VALUES ($1, $2);
 558        ";
 559        let cleanup_query = "
 560            DELETE FROM access_tokens
 561            WHERE id IN (
 562                SELECT id from access_tokens
 563                WHERE user_id = $1
 564                ORDER BY id DESC
 565                OFFSET $3
 566            )
 567        ";
 568
 569        let mut tx = self.pool.begin().await?;
 570        sqlx::query(insert_query)
 571            .bind(user_id.0)
 572            .bind(access_token_hash)
 573            .execute(&mut tx)
 574            .await?;
 575        sqlx::query(cleanup_query)
 576            .bind(user_id.0)
 577            .bind(access_token_hash)
 578            .bind(max_access_token_count as u32)
 579            .execute(&mut tx)
 580            .await?;
 581        Ok(tx.commit().await?)
 582    }
 583
 584    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 585        let query = "
 586            SELECT hash
 587            FROM access_tokens
 588            WHERE user_id = $1
 589            ORDER BY id DESC
 590        ";
 591        Ok(sqlx::query_scalar(query)
 592            .bind(user_id.0)
 593            .fetch_all(&self.pool)
 594            .await?)
 595    }
 596
 597    // orgs
 598
 599    #[allow(unused)] // Help rust-analyzer
 600    #[cfg(any(test, feature = "seed-support"))]
 601    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 602        let query = "
 603            SELECT *
 604            FROM orgs
 605            WHERE slug = $1
 606        ";
 607        Ok(sqlx::query_as(query)
 608            .bind(slug)
 609            .fetch_optional(&self.pool)
 610            .await?)
 611    }
 612
 613    #[cfg(any(test, feature = "seed-support"))]
 614    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 615        let query = "
 616            INSERT INTO orgs (name, slug)
 617            VALUES ($1, $2)
 618            RETURNING id
 619        ";
 620        Ok(sqlx::query_scalar(query)
 621            .bind(name)
 622            .bind(slug)
 623            .fetch_one(&self.pool)
 624            .await
 625            .map(OrgId)?)
 626    }
 627
 628    #[cfg(any(test, feature = "seed-support"))]
 629    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 630        let query = "
 631            INSERT INTO org_memberships (org_id, user_id, admin)
 632            VALUES ($1, $2, $3)
 633            ON CONFLICT DO NOTHING
 634        ";
 635        Ok(sqlx::query(query)
 636            .bind(org_id.0)
 637            .bind(user_id.0)
 638            .bind(is_admin)
 639            .execute(&self.pool)
 640            .await
 641            .map(drop)?)
 642    }
 643
 644    // channels
 645
 646    #[cfg(any(test, feature = "seed-support"))]
 647    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 648        let query = "
 649            INSERT INTO channels (owner_id, owner_is_user, name)
 650            VALUES ($1, false, $2)
 651            RETURNING id
 652        ";
 653        Ok(sqlx::query_scalar(query)
 654            .bind(org_id.0)
 655            .bind(name)
 656            .fetch_one(&self.pool)
 657            .await
 658            .map(ChannelId)?)
 659    }
 660
 661    #[allow(unused)] // Help rust-analyzer
 662    #[cfg(any(test, feature = "seed-support"))]
 663    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 664        let query = "
 665            SELECT *
 666            FROM channels
 667            WHERE
 668                channels.owner_is_user = false AND
 669                channels.owner_id = $1
 670        ";
 671        Ok(sqlx::query_as(query)
 672            .bind(org_id.0)
 673            .fetch_all(&self.pool)
 674            .await?)
 675    }
 676
 677    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 678        let query = "
 679            SELECT
 680                channels.*
 681            FROM
 682                channel_memberships, channels
 683            WHERE
 684                channel_memberships.user_id = $1 AND
 685                channel_memberships.channel_id = channels.id
 686        ";
 687        Ok(sqlx::query_as(query)
 688            .bind(user_id.0)
 689            .fetch_all(&self.pool)
 690            .await?)
 691    }
 692
 693    async fn can_user_access_channel(
 694        &self,
 695        user_id: UserId,
 696        channel_id: ChannelId,
 697    ) -> Result<bool> {
 698        let query = "
 699            SELECT id
 700            FROM channel_memberships
 701            WHERE user_id = $1 AND channel_id = $2
 702            LIMIT 1
 703        ";
 704        Ok(sqlx::query_scalar::<_, i32>(query)
 705            .bind(user_id.0)
 706            .bind(channel_id.0)
 707            .fetch_optional(&self.pool)
 708            .await
 709            .map(|e| e.is_some())?)
 710    }
 711
 712    #[cfg(any(test, feature = "seed-support"))]
 713    async fn add_channel_member(
 714        &self,
 715        channel_id: ChannelId,
 716        user_id: UserId,
 717        is_admin: bool,
 718    ) -> Result<()> {
 719        let query = "
 720            INSERT INTO channel_memberships (channel_id, user_id, admin)
 721            VALUES ($1, $2, $3)
 722            ON CONFLICT DO NOTHING
 723        ";
 724        Ok(sqlx::query(query)
 725            .bind(channel_id.0)
 726            .bind(user_id.0)
 727            .bind(is_admin)
 728            .execute(&self.pool)
 729            .await
 730            .map(drop)?)
 731    }
 732
 733    // messages
 734
 735    async fn create_channel_message(
 736        &self,
 737        channel_id: ChannelId,
 738        sender_id: UserId,
 739        body: &str,
 740        timestamp: OffsetDateTime,
 741        nonce: u128,
 742    ) -> Result<MessageId> {
 743        let query = "
 744            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 745            VALUES ($1, $2, $3, $4, $5)
 746            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 747            RETURNING id
 748        ";
 749        Ok(sqlx::query_scalar(query)
 750            .bind(channel_id.0)
 751            .bind(sender_id.0)
 752            .bind(body)
 753            .bind(timestamp)
 754            .bind(Uuid::from_u128(nonce))
 755            .fetch_one(&self.pool)
 756            .await
 757            .map(MessageId)?)
 758    }
 759
 760    async fn get_channel_messages(
 761        &self,
 762        channel_id: ChannelId,
 763        count: usize,
 764        before_id: Option<MessageId>,
 765    ) -> Result<Vec<ChannelMessage>> {
 766        let query = r#"
 767            SELECT * FROM (
 768                SELECT
 769                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 770                FROM
 771                    channel_messages
 772                WHERE
 773                    channel_id = $1 AND
 774                    id < $2
 775                ORDER BY id DESC
 776                LIMIT $3
 777            ) as recent_messages
 778            ORDER BY id ASC
 779        "#;
 780        Ok(sqlx::query_as(query)
 781            .bind(channel_id.0)
 782            .bind(before_id.unwrap_or(MessageId::MAX))
 783            .bind(count as i64)
 784            .fetch_all(&self.pool)
 785            .await?)
 786    }
 787
 788    #[cfg(test)]
 789    async fn teardown(&self, url: &str) {
 790        use util::ResultExt;
 791
 792        let query = "
 793            SELECT pg_terminate_backend(pg_stat_activity.pid)
 794            FROM pg_stat_activity
 795            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 796        ";
 797        sqlx::query(query).execute(&self.pool).await.log_err();
 798        self.pool.close().await;
 799        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 800            .await
 801            .log_err();
 802    }
 803
 804    #[cfg(test)]
 805    fn as_fake(&self) -> Option<&tests::FakeDb> {
 806        None
 807    }
 808}
 809
 810macro_rules! id_type {
 811    ($name:ident) => {
 812        #[derive(
 813            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 814        )]
 815        #[sqlx(transparent)]
 816        #[serde(transparent)]
 817        pub struct $name(pub i32);
 818
 819        impl $name {
 820            #[allow(unused)]
 821            pub const MAX: Self = Self(i32::MAX);
 822
 823            #[allow(unused)]
 824            pub fn from_proto(value: u64) -> Self {
 825                Self(value as i32)
 826            }
 827
 828            #[allow(unused)]
 829            pub fn to_proto(&self) -> u64 {
 830                self.0 as u64
 831            }
 832        }
 833
 834        impl std::fmt::Display for $name {
 835            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 836                self.0.fmt(f)
 837            }
 838        }
 839    };
 840}
 841
 842id_type!(UserId);
 843#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 844pub struct User {
 845    pub id: UserId,
 846    pub github_login: String,
 847    pub admin: bool,
 848    pub invite_code: Option<String>,
 849    pub invite_count: i32,
 850}
 851
 852id_type!(OrgId);
 853#[derive(FromRow)]
 854pub struct Org {
 855    pub id: OrgId,
 856    pub name: String,
 857    pub slug: String,
 858}
 859
 860id_type!(ChannelId);
 861#[derive(Clone, Debug, FromRow, Serialize)]
 862pub struct Channel {
 863    pub id: ChannelId,
 864    pub name: String,
 865    pub owner_id: i32,
 866    pub owner_is_user: bool,
 867}
 868
 869id_type!(MessageId);
 870#[derive(Clone, Debug, FromRow)]
 871pub struct ChannelMessage {
 872    pub id: MessageId,
 873    pub channel_id: ChannelId,
 874    pub sender_id: UserId,
 875    pub body: String,
 876    pub sent_at: OffsetDateTime,
 877    pub nonce: Uuid,
 878}
 879
 880#[derive(Clone, Debug, PartialEq, Eq)]
 881pub enum Contact {
 882    Accepted {
 883        user_id: UserId,
 884        should_notify: bool,
 885    },
 886    Outgoing {
 887        user_id: UserId,
 888    },
 889    Incoming {
 890        user_id: UserId,
 891        should_notify: bool,
 892    },
 893}
 894
 895impl Contact {
 896    pub fn user_id(&self) -> UserId {
 897        match self {
 898            Contact::Accepted { user_id, .. } => *user_id,
 899            Contact::Outgoing { user_id } => *user_id,
 900            Contact::Incoming { user_id, .. } => *user_id,
 901        }
 902    }
 903}
 904
 905#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 906pub struct IncomingContactRequest {
 907    pub requester_id: UserId,
 908    pub should_notify: bool,
 909}
 910
 911fn fuzzy_like_string(string: &str) -> String {
 912    let mut result = String::with_capacity(string.len() * 2 + 1);
 913    for c in string.chars() {
 914        if c.is_alphanumeric() {
 915            result.push('%');
 916            result.push(c);
 917        }
 918    }
 919    result.push('%');
 920    result
 921}
 922
 923#[cfg(test)]
 924pub mod tests {
 925    use super::*;
 926    use anyhow::anyhow;
 927    use collections::BTreeMap;
 928    use gpui::executor::Background;
 929    use lazy_static::lazy_static;
 930    use parking_lot::Mutex;
 931    use rand::prelude::*;
 932    use sqlx::{
 933        migrate::{MigrateDatabase, Migrator},
 934        Postgres,
 935    };
 936    use std::{path::Path, sync::Arc};
 937    use util::post_inc;
 938
 939    #[tokio::test(flavor = "multi_thread")]
 940    async fn test_get_users_by_ids() {
 941        for test_db in [
 942            TestDb::postgres().await,
 943            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 944        ] {
 945            let db = test_db.db();
 946
 947            let user = db.create_user("user", false).await.unwrap();
 948            let friend1 = db.create_user("friend-1", false).await.unwrap();
 949            let friend2 = db.create_user("friend-2", false).await.unwrap();
 950            let friend3 = db.create_user("friend-3", false).await.unwrap();
 951
 952            assert_eq!(
 953                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 954                    .await
 955                    .unwrap(),
 956                vec![
 957                    User {
 958                        id: user,
 959                        github_login: "user".to_string(),
 960                        admin: false,
 961                        ..Default::default()
 962                    },
 963                    User {
 964                        id: friend1,
 965                        github_login: "friend-1".to_string(),
 966                        admin: false,
 967                        ..Default::default()
 968                    },
 969                    User {
 970                        id: friend2,
 971                        github_login: "friend-2".to_string(),
 972                        admin: false,
 973                        ..Default::default()
 974                    },
 975                    User {
 976                        id: friend3,
 977                        github_login: "friend-3".to_string(),
 978                        admin: false,
 979                        ..Default::default()
 980                    }
 981                ]
 982            );
 983        }
 984    }
 985
 986    #[tokio::test(flavor = "multi_thread")]
 987    async fn test_recent_channel_messages() {
 988        for test_db in [
 989            TestDb::postgres().await,
 990            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 991        ] {
 992            let db = test_db.db();
 993            let user = db.create_user("user", false).await.unwrap();
 994            let org = db.create_org("org", "org").await.unwrap();
 995            let channel = db.create_org_channel(org, "channel").await.unwrap();
 996            for i in 0..10 {
 997                db.create_channel_message(
 998                    channel,
 999                    user,
1000                    &i.to_string(),
1001                    OffsetDateTime::now_utc(),
1002                    i,
1003                )
1004                .await
1005                .unwrap();
1006            }
1007
1008            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1009            assert_eq!(
1010                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1011                ["5", "6", "7", "8", "9"]
1012            );
1013
1014            let prev_messages = db
1015                .get_channel_messages(channel, 4, Some(messages[0].id))
1016                .await
1017                .unwrap();
1018            assert_eq!(
1019                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1020                ["1", "2", "3", "4"]
1021            );
1022        }
1023    }
1024
1025    #[tokio::test(flavor = "multi_thread")]
1026    async fn test_channel_message_nonces() {
1027        for test_db in [
1028            TestDb::postgres().await,
1029            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1030        ] {
1031            let db = test_db.db();
1032            let user = db.create_user("user", false).await.unwrap();
1033            let org = db.create_org("org", "org").await.unwrap();
1034            let channel = db.create_org_channel(org, "channel").await.unwrap();
1035
1036            let msg1_id = db
1037                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1038                .await
1039                .unwrap();
1040            let msg2_id = db
1041                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1042                .await
1043                .unwrap();
1044            let msg3_id = db
1045                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1046                .await
1047                .unwrap();
1048            let msg4_id = db
1049                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1050                .await
1051                .unwrap();
1052
1053            assert_ne!(msg1_id, msg2_id);
1054            assert_eq!(msg1_id, msg3_id);
1055            assert_eq!(msg2_id, msg4_id);
1056        }
1057    }
1058
1059    #[tokio::test(flavor = "multi_thread")]
1060    async fn test_create_access_tokens() {
1061        let test_db = TestDb::postgres().await;
1062        let db = test_db.db();
1063        let user = db.create_user("the-user", false).await.unwrap();
1064
1065        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1066        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1067        assert_eq!(
1068            db.get_access_token_hashes(user).await.unwrap(),
1069            &["h2".to_string(), "h1".to_string()]
1070        );
1071
1072        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1073        assert_eq!(
1074            db.get_access_token_hashes(user).await.unwrap(),
1075            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1076        );
1077
1078        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1079        assert_eq!(
1080            db.get_access_token_hashes(user).await.unwrap(),
1081            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1082        );
1083
1084        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1085        assert_eq!(
1086            db.get_access_token_hashes(user).await.unwrap(),
1087            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1088        );
1089    }
1090
1091    #[test]
1092    fn test_fuzzy_like_string() {
1093        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1094        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1095        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1096    }
1097
1098    #[tokio::test(flavor = "multi_thread")]
1099    async fn test_fuzzy_search_users() {
1100        let test_db = TestDb::postgres().await;
1101        let db = test_db.db();
1102        for github_login in [
1103            "California",
1104            "colorado",
1105            "oregon",
1106            "washington",
1107            "florida",
1108            "delaware",
1109            "rhode-island",
1110        ] {
1111            db.create_user(github_login, false).await.unwrap();
1112        }
1113
1114        assert_eq!(
1115            fuzzy_search_user_names(db, "clr").await,
1116            &["colorado", "California"]
1117        );
1118        assert_eq!(
1119            fuzzy_search_user_names(db, "ro").await,
1120            &["rhode-island", "colorado", "oregon"],
1121        );
1122
1123        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1124            db.fuzzy_search_users(query, 10)
1125                .await
1126                .unwrap()
1127                .into_iter()
1128                .map(|user| user.github_login)
1129                .collect::<Vec<_>>()
1130        }
1131    }
1132
    /// Walks the contact-request lifecycle (request, dismiss, accept, mutual request, decline)
    /// against both the Postgres and the fake backend, checking each user's contact list.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", false).await.unwrap();
            let user_2 = db.create_user("user2", false).await.unwrap();
            let user_3 = db.create_user("user3", false).await.unwrap();

            // A new user's only contact is themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // User 1 sent the request, so they have no pending-request notification to dismiss.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the requester (user 1) gets notified about the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1352
    /// Exercises invite-code creation, redemption, exhaustion and top-up.
    ///
    /// Runs against Postgres only, because `FakeDb` leaves the invite-code methods unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4")
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and a failed attempt doesn't consume one.
        db.redeem_invite_code(&invite_code, "user-2")
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1492
    /// A database for tests, backed either by a freshly created Postgres database
    /// (`TestDb::postgres`) or by an in-memory `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// The wrapped database. It becomes `None` once `Drop` has taken it for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database. It is empty for the fake backend.
        pub url: String,
    }
1497
1498    impl TestDb {
1499        pub async fn postgres() -> Self {
1500            lazy_static! {
1501                static ref LOCK: Mutex<()> = Mutex::new(());
1502            }
1503
1504            let _guard = LOCK.lock();
1505            let mut rng = StdRng::from_entropy();
1506            let name = format!("zed-test-{}", rng.gen::<u128>());
1507            let url = format!("postgres://postgres@localhost/{}", name);
1508            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1509            Postgres::create_database(&url)
1510                .await
1511                .expect("failed to create test db");
1512            let db = PostgresDb::new(&url, 5).await.unwrap();
1513            let migrator = Migrator::new(migrations_path).await.unwrap();
1514            migrator.run(&db.pool).await.unwrap();
1515            Self {
1516                db: Some(Arc::new(db)),
1517                url,
1518            }
1519        }
1520
1521        pub fn fake(background: Arc<Background>) -> Self {
1522            Self {
1523                db: Some(Arc::new(FakeDb::new(background))),
1524                url: Default::default(),
1525            }
1526        }
1527
1528        pub fn db(&self) -> &Arc<dyn Db> {
1529            self.db.as_ref().unwrap()
1530        }
1531    }
1532
1533    impl Drop for TestDb {
1534        fn drop(&mut self) {
1535            if let Some(db) = self.db.take() {
1536                futures::executor::block_on(db.teardown(&self.url));
1537            }
1538        }
1539    }
1540
    /// In-memory implementation of [`Db`] used by tests.
    ///
    /// Each table sits behind its own `Mutex`, and ids are handed out from the `next_*_id`
    /// counters.
    pub struct FakeDb {
        // Source of the random delays that several async methods simulate.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the `bool` presumably marks admin membership; confirm against the org methods.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): the `bool` presumably marks admin membership; confirm against the channel methods.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        // Contact rows in insertion order; see `FakeContact`.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next ids to hand out. `FakeDb::new` starts each of them at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1555
    /// A single contact row in `FakeDb`: a request sent by `requester_id` to `responder_id`.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        /// Set when the responder accepts, or when both users request each other.
        pub accepted: bool,
        /// While the row is pending: whether the responder should be notified about the request.
        /// Once accepted: whether the requester should be notified about the acceptance.
        pub should_notify: bool,
    }
1563
1564    impl FakeDb {
1565        pub fn new(background: Arc<Background>) -> Self {
1566            Self {
1567                background,
1568                users: Default::default(),
1569                next_user_id: Mutex::new(1),
1570                orgs: Default::default(),
1571                next_org_id: Mutex::new(1),
1572                org_memberships: Default::default(),
1573                channels: Default::default(),
1574                next_channel_id: Mutex::new(1),
1575                channel_memberships: Default::default(),
1576                channel_messages: Default::default(),
1577                next_channel_message_id: Mutex::new(1),
1578                contacts: Default::default(),
1579            }
1580        }
1581    }
1582
1583    #[async_trait]
1584    impl Db for FakeDb {
1585        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
1586            self.background.simulate_random_delay().await;
1587
1588            let mut users = self.users.lock();
1589            if let Some(user) = users
1590                .values()
1591                .find(|user| user.github_login == github_login)
1592            {
1593                Ok(user.id)
1594            } else {
1595                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1596                users.insert(
1597                    user_id,
1598                    User {
1599                        id: user_id,
1600                        github_login: github_login.to_string(),
1601                        admin,
1602                        invite_code: None,
1603                        invite_count: 0,
1604                    },
1605                );
1606                Ok(user_id)
1607            }
1608        }
1609
        // Not supported by `FakeDb`; panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        // Not supported by `FakeDb`; panics if called. `test_fuzzy_search_users` only runs
        // against Postgres.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1617
1618        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1619            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1620        }
1621
1622        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1623            self.background.simulate_random_delay().await;
1624            let users = self.users.lock();
1625            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1626        }
1627
1628        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1629            Ok(self
1630                .users
1631                .lock()
1632                .values()
1633                .find(|user| user.github_login == github_login)
1634                .cloned())
1635        }
1636
        // Not supported by `FakeDb`; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        // Not supported by `FakeDb`; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // invite codes: none of these are supported by `FakeDb`, and every one panics if called.
        // `test_invite_codes` only runs against Postgres.

        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            unimplemented!()
        }

        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
            unimplemented!()
        }
1662
1663        // contacts
1664
1665        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1666            self.background.simulate_random_delay().await;
1667            let mut contacts = vec![Contact::Accepted {
1668                user_id: id,
1669                should_notify: false,
1670            }];
1671
1672            for contact in self.contacts.lock().iter() {
1673                if contact.requester_id == id {
1674                    if contact.accepted {
1675                        contacts.push(Contact::Accepted {
1676                            user_id: contact.responder_id,
1677                            should_notify: contact.should_notify,
1678                        });
1679                    } else {
1680                        contacts.push(Contact::Outgoing {
1681                            user_id: contact.responder_id,
1682                        });
1683                    }
1684                } else if contact.responder_id == id {
1685                    if contact.accepted {
1686                        contacts.push(Contact::Accepted {
1687                            user_id: contact.requester_id,
1688                            should_notify: false,
1689                        });
1690                    } else {
1691                        contacts.push(Contact::Incoming {
1692                            user_id: contact.requester_id,
1693                            should_notify: contact.should_notify,
1694                        });
1695                    }
1696                }
1697            }
1698
1699            contacts.sort_unstable_by_key(|contact| contact.user_id());
1700            Ok(contacts)
1701        }
1702
1703        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1704            self.background.simulate_random_delay().await;
1705            Ok(self.contacts.lock().iter().any(|contact| {
1706                contact.accepted
1707                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1708                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1709            }))
1710        }
1711
1712        async fn send_contact_request(
1713            &self,
1714            requester_id: UserId,
1715            responder_id: UserId,
1716        ) -> Result<()> {
1717            let mut contacts = self.contacts.lock();
1718            for contact in contacts.iter_mut() {
1719                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1720                    if contact.accepted {
1721                        Err(anyhow!("contact already exists"))?;
1722                    } else {
1723                        Err(anyhow!("contact already requested"))?;
1724                    }
1725                }
1726                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1727                    if contact.accepted {
1728                        Err(anyhow!("contact already exists"))?;
1729                    } else {
1730                        contact.accepted = true;
1731                        contact.should_notify = false;
1732                        return Ok(());
1733                    }
1734                }
1735            }
1736            contacts.push(FakeContact {
1737                requester_id,
1738                responder_id,
1739                accepted: false,
1740                should_notify: true,
1741            });
1742            Ok(())
1743        }
1744
1745        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1746            self.contacts.lock().retain(|contact| {
1747                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1748            });
1749            Ok(())
1750        }
1751
1752        async fn dismiss_contact_notification(
1753            &self,
1754            user_id: UserId,
1755            contact_user_id: UserId,
1756        ) -> Result<()> {
1757            let mut contacts = self.contacts.lock();
1758            for contact in contacts.iter_mut() {
1759                if contact.requester_id == contact_user_id
1760                    && contact.responder_id == user_id
1761                    && !contact.accepted
1762                {
1763                    contact.should_notify = false;
1764                    return Ok(());
1765                }
1766                if contact.requester_id == user_id
1767                    && contact.responder_id == contact_user_id
1768                    && contact.accepted
1769                {
1770                    contact.should_notify = false;
1771                    return Ok(());
1772                }
1773            }
1774            Err(anyhow!("no such notification"))?
1775        }
1776
1777        async fn respond_to_contact_request(
1778            &self,
1779            responder_id: UserId,
1780            requester_id: UserId,
1781            accept: bool,
1782        ) -> Result<()> {
1783            let mut contacts = self.contacts.lock();
1784            for (ix, contact) in contacts.iter_mut().enumerate() {
1785                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1786                    if contact.accepted {
1787                        Err(anyhow!("contact already confirmed"))?;
1788                    }
1789                    if accept {
1790                        contact.accepted = true;
1791                        contact.should_notify = true;
1792                    } else {
1793                        contacts.remove(ix);
1794                    }
1795                    return Ok(());
1796                }
1797            }
1798            Err(anyhow!("no such contact request"))?
1799        }
1800
        /// Not supported by the fake database: always panics via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1809
        /// Not supported by the fake database: always panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
1813
1814        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1815            unimplemented!()
1816        }
1817
1818        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1819            self.background.simulate_random_delay().await;
1820            let mut orgs = self.orgs.lock();
1821            if orgs.values().any(|org| org.slug == slug) {
1822                Err(anyhow!("org already exists"))?
1823            } else {
1824                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1825                orgs.insert(
1826                    org_id,
1827                    Org {
1828                        id: org_id,
1829                        name: name.to_string(),
1830                        slug: slug.to_string(),
1831                    },
1832                );
1833                Ok(org_id)
1834            }
1835        }
1836
1837        async fn add_org_member(
1838            &self,
1839            org_id: OrgId,
1840            user_id: UserId,
1841            is_admin: bool,
1842        ) -> Result<()> {
1843            self.background.simulate_random_delay().await;
1844            if !self.orgs.lock().contains_key(&org_id) {
1845                Err(anyhow!("org does not exist"))?;
1846            }
1847            if !self.users.lock().contains_key(&user_id) {
1848                Err(anyhow!("user does not exist"))?;
1849            }
1850
1851            self.org_memberships
1852                .lock()
1853                .entry((org_id, user_id))
1854                .or_insert(is_admin);
1855            Ok(())
1856        }
1857
1858        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1859            self.background.simulate_random_delay().await;
1860            if !self.orgs.lock().contains_key(&org_id) {
1861                Err(anyhow!("org does not exist"))?;
1862            }
1863
1864            let mut channels = self.channels.lock();
1865            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1866            channels.insert(
1867                channel_id,
1868                Channel {
1869                    id: channel_id,
1870                    name: name.to_string(),
1871                    owner_id: org_id.0,
1872                    owner_is_user: false,
1873                },
1874            );
1875            Ok(channel_id)
1876        }
1877
1878        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1879            self.background.simulate_random_delay().await;
1880            Ok(self
1881                .channels
1882                .lock()
1883                .values()
1884                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1885                .cloned()
1886                .collect())
1887        }
1888
1889        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1890            self.background.simulate_random_delay().await;
1891            let channels = self.channels.lock();
1892            let memberships = self.channel_memberships.lock();
1893            Ok(channels
1894                .values()
1895                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1896                .cloned()
1897                .collect())
1898        }
1899
1900        async fn can_user_access_channel(
1901            &self,
1902            user_id: UserId,
1903            channel_id: ChannelId,
1904        ) -> Result<bool> {
1905            self.background.simulate_random_delay().await;
1906            Ok(self
1907                .channel_memberships
1908                .lock()
1909                .contains_key(&(channel_id, user_id)))
1910        }
1911
1912        async fn add_channel_member(
1913            &self,
1914            channel_id: ChannelId,
1915            user_id: UserId,
1916            is_admin: bool,
1917        ) -> Result<()> {
1918            self.background.simulate_random_delay().await;
1919            if !self.channels.lock().contains_key(&channel_id) {
1920                Err(anyhow!("channel does not exist"))?;
1921            }
1922            if !self.users.lock().contains_key(&user_id) {
1923                Err(anyhow!("user does not exist"))?;
1924            }
1925
1926            self.channel_memberships
1927                .lock()
1928                .entry((channel_id, user_id))
1929                .or_insert(is_admin);
1930            Ok(())
1931        }
1932
1933        async fn create_channel_message(
1934            &self,
1935            channel_id: ChannelId,
1936            sender_id: UserId,
1937            body: &str,
1938            timestamp: OffsetDateTime,
1939            nonce: u128,
1940        ) -> Result<MessageId> {
1941            self.background.simulate_random_delay().await;
1942            if !self.channels.lock().contains_key(&channel_id) {
1943                Err(anyhow!("channel does not exist"))?;
1944            }
1945            if !self.users.lock().contains_key(&sender_id) {
1946                Err(anyhow!("user does not exist"))?;
1947            }
1948
1949            let mut messages = self.channel_messages.lock();
1950            if let Some(message) = messages
1951                .values()
1952                .find(|message| message.nonce.as_u128() == nonce)
1953            {
1954                Ok(message.id)
1955            } else {
1956                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1957                messages.insert(
1958                    message_id,
1959                    ChannelMessage {
1960                        id: message_id,
1961                        channel_id,
1962                        sender_id,
1963                        body: body.to_string(),
1964                        sent_at: timestamp,
1965                        nonce: Uuid::from_u128(nonce),
1966                    },
1967                );
1968                Ok(message_id)
1969            }
1970        }
1971
1972        async fn get_channel_messages(
1973            &self,
1974            channel_id: ChannelId,
1975            count: usize,
1976            before_id: Option<MessageId>,
1977        ) -> Result<Vec<ChannelMessage>> {
1978            let mut messages = self
1979                .channel_messages
1980                .lock()
1981                .values()
1982                .rev()
1983                .filter(|message| {
1984                    message.channel_id == channel_id
1985                        && message.id < before_id.unwrap_or(MessageId::MAX)
1986                })
1987                .take(count)
1988                .cloned()
1989                .collect::<Vec<_>>();
1990            messages.sort_unstable_by_key(|message| message.id);
1991            Ok(messages)
1992        }
1993
        /// Does nothing: this fake has nothing to tear down, and its argument is ignored.
        async fn teardown(&self, _: &str) {}
1995
        /// Returns `Some(self)`, giving test code access to the concrete `FakeDb`.
        /// Only compiled under `cfg(test)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2000    }
2001}