db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(
  15        &self,
  16        github_login: &str,
  17        email_address: Option<&str>,
  18        admin: bool,
  19    ) -> Result<UserId>;
  20    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  21    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  22    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  23    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  24    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  25    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  26    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  27    async fn destroy_user(&self, id: UserId) -> Result<()>;
  28
  29    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  30    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  31    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  32    async fn redeem_invite_code(
  33        &self,
  34        code: &str,
  35        login: &str,
  36        email_address: Option<&str>,
  37    ) -> Result<UserId>;
  38
  39    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  40    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  41    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  42    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  43    async fn dismiss_contact_notification(
  44        &self,
  45        responder_id: UserId,
  46        requester_id: UserId,
  47    ) -> Result<()>;
  48    async fn respond_to_contact_request(
  49        &self,
  50        responder_id: UserId,
  51        requester_id: UserId,
  52        accept: bool,
  53    ) -> Result<()>;
  54
  55    async fn create_access_token_hash(
  56        &self,
  57        user_id: UserId,
  58        access_token_hash: &str,
  59        max_access_token_count: usize,
  60    ) -> Result<()>;
  61    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  62    #[cfg(any(test, feature = "seed-support"))]
  63
  64    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  65    #[cfg(any(test, feature = "seed-support"))]
  66    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  67    #[cfg(any(test, feature = "seed-support"))]
  68    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  69    #[cfg(any(test, feature = "seed-support"))]
  70    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  71    #[cfg(any(test, feature = "seed-support"))]
  72
  73    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  74    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  75    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  76        -> Result<bool>;
  77    #[cfg(any(test, feature = "seed-support"))]
  78    async fn add_channel_member(
  79        &self,
  80        channel_id: ChannelId,
  81        user_id: UserId,
  82        is_admin: bool,
  83    ) -> Result<()>;
  84    async fn create_channel_message(
  85        &self,
  86        channel_id: ChannelId,
  87        sender_id: UserId,
  88        body: &str,
  89        timestamp: OffsetDateTime,
  90        nonce: u128,
  91    ) -> Result<MessageId>;
  92    async fn get_channel_messages(
  93        &self,
  94        channel_id: ChannelId,
  95        count: usize,
  96        before_id: Option<MessageId>,
  97    ) -> Result<Vec<ChannelMessage>>;
  98    #[cfg(test)]
  99    async fn teardown(&self, url: &str);
 100    #[cfg(test)]
 101    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 102}
 103
 104pub struct PostgresDb {
 105    pool: sqlx::PgPool,
 106}
 107
 108impl PostgresDb {
 109    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 110        let pool = DbOptions::new()
 111            .max_connections(max_connections)
 112            .connect(&url)
 113            .await
 114            .context("failed to connect to postgres database")?;
 115        Ok(Self { pool })
 116    }
 117}
 118
 119#[async_trait]
 120impl Db for PostgresDb {
 121    // users
 122
 123    async fn create_user(
 124        &self,
 125        github_login: &str,
 126        email_address: Option<&str>,
 127        admin: bool,
 128    ) -> Result<UserId> {
 129        let query = "
 130            INSERT INTO users (github_login, email_address, admin)
 131            VALUES ($1, $2, $3)
 132            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 133            RETURNING id
 134        ";
 135        Ok(sqlx::query_scalar(query)
 136            .bind(github_login)
 137            .bind(email_address)
 138            .bind(admin)
 139            .fetch_one(&self.pool)
 140            .await
 141            .map(UserId)?)
 142    }
 143
 144    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 145        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 146        Ok(sqlx::query_as(query)
 147            .bind(limit)
 148            .bind(page * limit)
 149            .fetch_all(&self.pool)
 150            .await?)
 151    }
 152
 153    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 154        let like_string = fuzzy_like_string(name_query);
 155        let query = "
 156            SELECT users.*
 157            FROM users
 158            WHERE github_login ILIKE $1
 159            ORDER BY github_login <-> $2
 160            LIMIT $3
 161        ";
 162        Ok(sqlx::query_as(query)
 163            .bind(like_string)
 164            .bind(name_query)
 165            .bind(limit)
 166            .fetch_all(&self.pool)
 167            .await?)
 168    }
 169
 170    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 171        let users = self.get_users_by_ids(vec![id]).await?;
 172        Ok(users.into_iter().next())
 173    }
 174
 175    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 176        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 177        let query = "
 178            SELECT users.*
 179            FROM users
 180            WHERE users.id = ANY ($1)
 181        ";
 182        Ok(sqlx::query_as(query)
 183            .bind(&ids)
 184            .fetch_all(&self.pool)
 185            .await?)
 186    }
 187
 188    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 189        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 190        Ok(sqlx::query_as(query)
 191            .bind(github_login)
 192            .fetch_optional(&self.pool)
 193            .await?)
 194    }
 195
 196    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 197        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 198        Ok(sqlx::query(query)
 199            .bind(is_admin)
 200            .bind(id.0)
 201            .execute(&self.pool)
 202            .await
 203            .map(drop)?)
 204    }
 205
 206    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 207        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 208        Ok(sqlx::query(query)
 209            .bind(connected_once)
 210            .bind(id.0)
 211            .execute(&self.pool)
 212            .await
 213            .map(drop)?)
 214    }
 215
 216    async fn destroy_user(&self, id: UserId) -> Result<()> {
 217        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 218        sqlx::query(query)
 219            .bind(id.0)
 220            .execute(&self.pool)
 221            .await
 222            .map(drop)?;
 223        let query = "DELETE FROM users WHERE id = $1;";
 224        Ok(sqlx::query(query)
 225            .bind(id.0)
 226            .execute(&self.pool)
 227            .await
 228            .map(drop)?)
 229    }
 230
 231    // invite codes
 232
 233    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 234        let mut tx = self.pool.begin().await?;
 235        if count > 0 {
 236            sqlx::query(
 237                "
 238                UPDATE users
 239                SET invite_code = $1
 240                WHERE id = $2 AND invite_code IS NULL
 241            ",
 242            )
 243            .bind(nanoid!(16))
 244            .bind(id)
 245            .execute(&mut tx)
 246            .await?;
 247        }
 248
 249        sqlx::query(
 250            "
 251            UPDATE users
 252            SET invite_count = $1
 253            WHERE id = $2
 254            ",
 255        )
 256        .bind(count)
 257        .bind(id)
 258        .execute(&mut tx)
 259        .await?;
 260        tx.commit().await?;
 261        Ok(())
 262    }
 263
 264    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 265        let result: Option<(String, i32)> = sqlx::query_as(
 266            "
 267                SELECT invite_code, invite_count
 268                FROM users
 269                WHERE id = $1 AND invite_code IS NOT NULL 
 270            ",
 271        )
 272        .bind(id)
 273        .fetch_optional(&self.pool)
 274        .await?;
 275        if let Some((code, count)) = result {
 276            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 277        } else {
 278            Ok(None)
 279        }
 280    }
 281
 282    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 283        sqlx::query_as(
 284            "
 285                SELECT *
 286                FROM users
 287                WHERE invite_code = $1
 288            ",
 289        )
 290        .bind(code)
 291        .fetch_optional(&self.pool)
 292        .await?
 293        .ok_or_else(|| {
 294            Error::Http(
 295                StatusCode::NOT_FOUND,
 296                "that invite code does not exist".to_string(),
 297            )
 298        })
 299    }
 300
 301    async fn redeem_invite_code(
 302        &self,
 303        code: &str,
 304        login: &str,
 305        email_address: Option<&str>,
 306    ) -> Result<UserId> {
 307        let mut tx = self.pool.begin().await?;
 308
 309        let inviter_id: Option<UserId> = sqlx::query_scalar(
 310            "
 311                UPDATE users
 312                SET invite_count = invite_count - 1
 313                WHERE
 314                    invite_code = $1 AND
 315                    invite_count > 0
 316                RETURNING id
 317            ",
 318        )
 319        .bind(code)
 320        .fetch_optional(&mut tx)
 321        .await?;
 322
 323        let inviter_id = match inviter_id {
 324            Some(inviter_id) => inviter_id,
 325            None => {
 326                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 327                    .bind(code)
 328                    .fetch_optional(&mut tx)
 329                    .await?
 330                    .is_some()
 331                {
 332                    Err(Error::Http(
 333                        StatusCode::UNAUTHORIZED,
 334                        "no invites remaining".to_string(),
 335                    ))?
 336                } else {
 337                    Err(Error::Http(
 338                        StatusCode::NOT_FOUND,
 339                        "invite code not found".to_string(),
 340                    ))?
 341                }
 342            }
 343        };
 344
 345        let invitee_id = sqlx::query_scalar(
 346            "
 347                INSERT INTO users
 348                    (github_login, email_address, admin, inviter_id)
 349                VALUES
 350                    ($1, $2, 'f', $3)
 351                RETURNING id
 352            ",
 353        )
 354        .bind(login)
 355        .bind(email_address)
 356        .bind(inviter_id)
 357        .fetch_one(&mut tx)
 358        .await
 359        .map(UserId)?;
 360
 361        sqlx::query(
 362            "
 363                INSERT INTO contacts
 364                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 365                VALUES
 366                    ($1, $2, 't', 't', 't')
 367            ",
 368        )
 369        .bind(inviter_id)
 370        .bind(invitee_id)
 371        .execute(&mut tx)
 372        .await?;
 373
 374        tx.commit().await?;
 375        Ok(invitee_id)
 376    }
 377
 378    // contacts
 379
 380    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 381        let query = "
 382            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 383            FROM contacts
 384            WHERE user_id_a = $1 OR user_id_b = $1;
 385        ";
 386
 387        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 388            .bind(user_id)
 389            .fetch(&self.pool);
 390
 391        let mut contacts = vec![Contact::Accepted {
 392            user_id,
 393            should_notify: false,
 394        }];
 395        while let Some(row) = rows.next().await {
 396            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 397
 398            if user_id_a == user_id {
 399                if accepted {
 400                    contacts.push(Contact::Accepted {
 401                        user_id: user_id_b,
 402                        should_notify: should_notify && a_to_b,
 403                    });
 404                } else if a_to_b {
 405                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 406                } else {
 407                    contacts.push(Contact::Incoming {
 408                        user_id: user_id_b,
 409                        should_notify,
 410                    });
 411                }
 412            } else {
 413                if accepted {
 414                    contacts.push(Contact::Accepted {
 415                        user_id: user_id_a,
 416                        should_notify: should_notify && !a_to_b,
 417                    });
 418                } else if a_to_b {
 419                    contacts.push(Contact::Incoming {
 420                        user_id: user_id_a,
 421                        should_notify,
 422                    });
 423                } else {
 424                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 425                }
 426            }
 427        }
 428
 429        contacts.sort_unstable_by_key(|contact| contact.user_id());
 430
 431        Ok(contacts)
 432    }
 433
 434    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 435        let (id_a, id_b) = if user_id_1 < user_id_2 {
 436            (user_id_1, user_id_2)
 437        } else {
 438            (user_id_2, user_id_1)
 439        };
 440
 441        let query = "
 442            SELECT 1 FROM contacts
 443            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 444            LIMIT 1
 445        ";
 446        Ok(sqlx::query_scalar::<_, i32>(query)
 447            .bind(id_a.0)
 448            .bind(id_b.0)
 449            .fetch_optional(&self.pool)
 450            .await?
 451            .is_some())
 452    }
 453
 454    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 455        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 456            (sender_id, receiver_id, true)
 457        } else {
 458            (receiver_id, sender_id, false)
 459        };
 460        let query = "
 461            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 462            VALUES ($1, $2, $3, 'f', 't')
 463            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 464            SET
 465                accepted = 't',
 466                should_notify = 'f'
 467            WHERE
 468                NOT contacts.accepted AND
 469                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 470                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 471        ";
 472        let result = sqlx::query(query)
 473            .bind(id_a.0)
 474            .bind(id_b.0)
 475            .bind(a_to_b)
 476            .execute(&self.pool)
 477            .await?;
 478
 479        if result.rows_affected() == 1 {
 480            Ok(())
 481        } else {
 482            Err(anyhow!("contact already requested"))?
 483        }
 484    }
 485
 486    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 487        let (id_a, id_b) = if responder_id < requester_id {
 488            (responder_id, requester_id)
 489        } else {
 490            (requester_id, responder_id)
 491        };
 492        let query = "
 493            DELETE FROM contacts
 494            WHERE user_id_a = $1 AND user_id_b = $2;
 495        ";
 496        let result = sqlx::query(query)
 497            .bind(id_a.0)
 498            .bind(id_b.0)
 499            .execute(&self.pool)
 500            .await?;
 501
 502        if result.rows_affected() == 1 {
 503            Ok(())
 504        } else {
 505            Err(anyhow!("no such contact"))?
 506        }
 507    }
 508
 509    async fn dismiss_contact_notification(
 510        &self,
 511        user_id: UserId,
 512        contact_user_id: UserId,
 513    ) -> Result<()> {
 514        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 515            (user_id, contact_user_id, true)
 516        } else {
 517            (contact_user_id, user_id, false)
 518        };
 519
 520        let query = "
 521            UPDATE contacts
 522            SET should_notify = 'f'
 523            WHERE
 524                user_id_a = $1 AND user_id_b = $2 AND
 525                (
 526                    (a_to_b = $3 AND accepted) OR
 527                    (a_to_b != $3 AND NOT accepted)
 528                );
 529        ";
 530
 531        let result = sqlx::query(query)
 532            .bind(id_a.0)
 533            .bind(id_b.0)
 534            .bind(a_to_b)
 535            .execute(&self.pool)
 536            .await?;
 537
 538        if result.rows_affected() == 0 {
 539            Err(anyhow!("no such contact request"))?;
 540        }
 541
 542        Ok(())
 543    }
 544
 545    async fn respond_to_contact_request(
 546        &self,
 547        responder_id: UserId,
 548        requester_id: UserId,
 549        accept: bool,
 550    ) -> Result<()> {
 551        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 552            (responder_id, requester_id, false)
 553        } else {
 554            (requester_id, responder_id, true)
 555        };
 556        let result = if accept {
 557            let query = "
 558                UPDATE contacts
 559                SET accepted = 't', should_notify = 't'
 560                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 561            ";
 562            sqlx::query(query)
 563                .bind(id_a.0)
 564                .bind(id_b.0)
 565                .bind(a_to_b)
 566                .execute(&self.pool)
 567                .await?
 568        } else {
 569            let query = "
 570                DELETE FROM contacts
 571                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 572            ";
 573            sqlx::query(query)
 574                .bind(id_a.0)
 575                .bind(id_b.0)
 576                .bind(a_to_b)
 577                .execute(&self.pool)
 578                .await?
 579        };
 580        if result.rows_affected() == 1 {
 581            Ok(())
 582        } else {
 583            Err(anyhow!("no such contact request"))?
 584        }
 585    }
 586
 587    // access tokens
 588
 589    async fn create_access_token_hash(
 590        &self,
 591        user_id: UserId,
 592        access_token_hash: &str,
 593        max_access_token_count: usize,
 594    ) -> Result<()> {
 595        let insert_query = "
 596            INSERT INTO access_tokens (user_id, hash)
 597            VALUES ($1, $2);
 598        ";
 599        let cleanup_query = "
 600            DELETE FROM access_tokens
 601            WHERE id IN (
 602                SELECT id from access_tokens
 603                WHERE user_id = $1
 604                ORDER BY id DESC
 605                OFFSET $3
 606            )
 607        ";
 608
 609        let mut tx = self.pool.begin().await?;
 610        sqlx::query(insert_query)
 611            .bind(user_id.0)
 612            .bind(access_token_hash)
 613            .execute(&mut tx)
 614            .await?;
 615        sqlx::query(cleanup_query)
 616            .bind(user_id.0)
 617            .bind(access_token_hash)
 618            .bind(max_access_token_count as u32)
 619            .execute(&mut tx)
 620            .await?;
 621        Ok(tx.commit().await?)
 622    }
 623
 624    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 625        let query = "
 626            SELECT hash
 627            FROM access_tokens
 628            WHERE user_id = $1
 629            ORDER BY id DESC
 630        ";
 631        Ok(sqlx::query_scalar(query)
 632            .bind(user_id.0)
 633            .fetch_all(&self.pool)
 634            .await?)
 635    }
 636
 637    // orgs
 638
 639    #[allow(unused)] // Help rust-analyzer
 640    #[cfg(any(test, feature = "seed-support"))]
 641    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 642        let query = "
 643            SELECT *
 644            FROM orgs
 645            WHERE slug = $1
 646        ";
 647        Ok(sqlx::query_as(query)
 648            .bind(slug)
 649            .fetch_optional(&self.pool)
 650            .await?)
 651    }
 652
 653    #[cfg(any(test, feature = "seed-support"))]
 654    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 655        let query = "
 656            INSERT INTO orgs (name, slug)
 657            VALUES ($1, $2)
 658            RETURNING id
 659        ";
 660        Ok(sqlx::query_scalar(query)
 661            .bind(name)
 662            .bind(slug)
 663            .fetch_one(&self.pool)
 664            .await
 665            .map(OrgId)?)
 666    }
 667
 668    #[cfg(any(test, feature = "seed-support"))]
 669    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 670        let query = "
 671            INSERT INTO org_memberships (org_id, user_id, admin)
 672            VALUES ($1, $2, $3)
 673            ON CONFLICT DO NOTHING
 674        ";
 675        Ok(sqlx::query(query)
 676            .bind(org_id.0)
 677            .bind(user_id.0)
 678            .bind(is_admin)
 679            .execute(&self.pool)
 680            .await
 681            .map(drop)?)
 682    }
 683
 684    // channels
 685
 686    #[cfg(any(test, feature = "seed-support"))]
 687    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 688        let query = "
 689            INSERT INTO channels (owner_id, owner_is_user, name)
 690            VALUES ($1, false, $2)
 691            RETURNING id
 692        ";
 693        Ok(sqlx::query_scalar(query)
 694            .bind(org_id.0)
 695            .bind(name)
 696            .fetch_one(&self.pool)
 697            .await
 698            .map(ChannelId)?)
 699    }
 700
 701    #[allow(unused)] // Help rust-analyzer
 702    #[cfg(any(test, feature = "seed-support"))]
 703    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 704        let query = "
 705            SELECT *
 706            FROM channels
 707            WHERE
 708                channels.owner_is_user = false AND
 709                channels.owner_id = $1
 710        ";
 711        Ok(sqlx::query_as(query)
 712            .bind(org_id.0)
 713            .fetch_all(&self.pool)
 714            .await?)
 715    }
 716
 717    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 718        let query = "
 719            SELECT
 720                channels.*
 721            FROM
 722                channel_memberships, channels
 723            WHERE
 724                channel_memberships.user_id = $1 AND
 725                channel_memberships.channel_id = channels.id
 726        ";
 727        Ok(sqlx::query_as(query)
 728            .bind(user_id.0)
 729            .fetch_all(&self.pool)
 730            .await?)
 731    }
 732
 733    async fn can_user_access_channel(
 734        &self,
 735        user_id: UserId,
 736        channel_id: ChannelId,
 737    ) -> Result<bool> {
 738        let query = "
 739            SELECT id
 740            FROM channel_memberships
 741            WHERE user_id = $1 AND channel_id = $2
 742            LIMIT 1
 743        ";
 744        Ok(sqlx::query_scalar::<_, i32>(query)
 745            .bind(user_id.0)
 746            .bind(channel_id.0)
 747            .fetch_optional(&self.pool)
 748            .await
 749            .map(|e| e.is_some())?)
 750    }
 751
 752    #[cfg(any(test, feature = "seed-support"))]
 753    async fn add_channel_member(
 754        &self,
 755        channel_id: ChannelId,
 756        user_id: UserId,
 757        is_admin: bool,
 758    ) -> Result<()> {
 759        let query = "
 760            INSERT INTO channel_memberships (channel_id, user_id, admin)
 761            VALUES ($1, $2, $3)
 762            ON CONFLICT DO NOTHING
 763        ";
 764        Ok(sqlx::query(query)
 765            .bind(channel_id.0)
 766            .bind(user_id.0)
 767            .bind(is_admin)
 768            .execute(&self.pool)
 769            .await
 770            .map(drop)?)
 771    }
 772
 773    // messages
 774
 775    async fn create_channel_message(
 776        &self,
 777        channel_id: ChannelId,
 778        sender_id: UserId,
 779        body: &str,
 780        timestamp: OffsetDateTime,
 781        nonce: u128,
 782    ) -> Result<MessageId> {
 783        let query = "
 784            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 785            VALUES ($1, $2, $3, $4, $5)
 786            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 787            RETURNING id
 788        ";
 789        Ok(sqlx::query_scalar(query)
 790            .bind(channel_id.0)
 791            .bind(sender_id.0)
 792            .bind(body)
 793            .bind(timestamp)
 794            .bind(Uuid::from_u128(nonce))
 795            .fetch_one(&self.pool)
 796            .await
 797            .map(MessageId)?)
 798    }
 799
 800    async fn get_channel_messages(
 801        &self,
 802        channel_id: ChannelId,
 803        count: usize,
 804        before_id: Option<MessageId>,
 805    ) -> Result<Vec<ChannelMessage>> {
 806        let query = r#"
 807            SELECT * FROM (
 808                SELECT
 809                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 810                FROM
 811                    channel_messages
 812                WHERE
 813                    channel_id = $1 AND
 814                    id < $2
 815                ORDER BY id DESC
 816                LIMIT $3
 817            ) as recent_messages
 818            ORDER BY id ASC
 819        "#;
 820        Ok(sqlx::query_as(query)
 821            .bind(channel_id.0)
 822            .bind(before_id.unwrap_or(MessageId::MAX))
 823            .bind(count as i64)
 824            .fetch_all(&self.pool)
 825            .await?)
 826    }
 827
 828    #[cfg(test)]
 829    async fn teardown(&self, url: &str) {
 830        use util::ResultExt;
 831
 832        let query = "
 833            SELECT pg_terminate_backend(pg_stat_activity.pid)
 834            FROM pg_stat_activity
 835            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 836        ";
 837        sqlx::query(query).execute(&self.pool).await.log_err();
 838        self.pool.close().await;
 839        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 840            .await
 841            .log_err();
 842    }
 843
 844    #[cfg(test)]
 845    fn as_fake(&self) -> Option<&tests::FakeDb> {
 846        None
 847    }
 848}
 849
 850macro_rules! id_type {
 851    ($name:ident) => {
 852        #[derive(
 853            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 854        )]
 855        #[sqlx(transparent)]
 856        #[serde(transparent)]
 857        pub struct $name(pub i32);
 858
 859        impl $name {
 860            #[allow(unused)]
 861            pub const MAX: Self = Self(i32::MAX);
 862
 863            #[allow(unused)]
 864            pub fn from_proto(value: u64) -> Self {
 865                Self(value as i32)
 866            }
 867
 868            #[allow(unused)]
 869            pub fn to_proto(&self) -> u64 {
 870                self.0 as u64
 871            }
 872        }
 873
 874        impl std::fmt::Display for $name {
 875            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 876                self.0.fmt(f)
 877            }
 878        }
 879    };
 880}
 881
 882id_type!(UserId);
 883#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 884pub struct User {
 885    pub id: UserId,
 886    pub github_login: String,
 887    pub email_address: Option<String>,
 888    pub admin: bool,
 889    pub invite_code: Option<String>,
 890    pub invite_count: i32,
 891    pub connected_once: bool,
 892}
 893
 894id_type!(OrgId);
 895#[derive(FromRow)]
 896pub struct Org {
 897    pub id: OrgId,
 898    pub name: String,
 899    pub slug: String,
 900}
 901
 902id_type!(ChannelId);
 903#[derive(Clone, Debug, FromRow, Serialize)]
 904pub struct Channel {
 905    pub id: ChannelId,
 906    pub name: String,
 907    pub owner_id: i32,
 908    pub owner_is_user: bool,
 909}
 910
 911id_type!(MessageId);
 912#[derive(Clone, Debug, FromRow)]
 913pub struct ChannelMessage {
 914    pub id: MessageId,
 915    pub channel_id: ChannelId,
 916    pub sender_id: UserId,
 917    pub body: String,
 918    pub sent_at: OffsetDateTime,
 919    pub nonce: Uuid,
 920}
 921
 922#[derive(Clone, Debug, PartialEq, Eq)]
 923pub enum Contact {
 924    Accepted {
 925        user_id: UserId,
 926        should_notify: bool,
 927    },
 928    Outgoing {
 929        user_id: UserId,
 930    },
 931    Incoming {
 932        user_id: UserId,
 933        should_notify: bool,
 934    },
 935}
 936
 937impl Contact {
 938    pub fn user_id(&self) -> UserId {
 939        match self {
 940            Contact::Accepted { user_id, .. } => *user_id,
 941            Contact::Outgoing { user_id } => *user_id,
 942            Contact::Incoming { user_id, .. } => *user_id,
 943        }
 944    }
 945}
 946
/// A contact request as seen by its recipient.
///
/// NOTE(review): this type is not constructed in this part of the file; the field notes
/// below mirror `Contact::Incoming` and should be confirmed against the code that builds it.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Presumably whether the recipient should still be notified about the request.
    pub should_notify: bool,
}
 952
 953fn fuzzy_like_string(string: &str) -> String {
 954    let mut result = String::with_capacity(string.len() * 2 + 1);
 955    for c in string.chars() {
 956        if c.is_alphanumeric() {
 957            result.push('%');
 958            result.push(c);
 959        }
 960    }
 961    result.push('%');
 962    result
 963}
 964
 965#[cfg(test)]
 966pub mod tests {
 967    use super::*;
 968    use anyhow::anyhow;
 969    use collections::BTreeMap;
 970    use gpui::executor::Background;
 971    use lazy_static::lazy_static;
 972    use parking_lot::Mutex;
 973    use rand::prelude::*;
 974    use sqlx::{
 975        migrate::{MigrateDatabase, Migrator},
 976        Postgres,
 977    };
 978    use std::{path::Path, sync::Arc};
 979    use util::post_inc;
 980
 981    #[tokio::test(flavor = "multi_thread")]
 982    async fn test_get_users_by_ids() {
 983        for test_db in [
 984            TestDb::postgres().await,
 985            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 986        ] {
 987            let db = test_db.db();
 988
 989            let user = db.create_user("user", None, false).await.unwrap();
 990            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
 991            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
 992            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
 993
 994            assert_eq!(
 995                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 996                    .await
 997                    .unwrap(),
 998                vec![
 999                    User {
1000                        id: user,
1001                        github_login: "user".to_string(),
1002                        admin: false,
1003                        ..Default::default()
1004                    },
1005                    User {
1006                        id: friend1,
1007                        github_login: "friend-1".to_string(),
1008                        admin: false,
1009                        ..Default::default()
1010                    },
1011                    User {
1012                        id: friend2,
1013                        github_login: "friend-2".to_string(),
1014                        admin: false,
1015                        ..Default::default()
1016                    },
1017                    User {
1018                        id: friend3,
1019                        github_login: "friend-3".to_string(),
1020                        admin: false,
1021                        ..Default::default()
1022                    }
1023                ]
1024            );
1025        }
1026    }
1027
1028    #[tokio::test(flavor = "multi_thread")]
1029    async fn test_recent_channel_messages() {
1030        for test_db in [
1031            TestDb::postgres().await,
1032            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1033        ] {
1034            let db = test_db.db();
1035            let user = db.create_user("user", None, false).await.unwrap();
1036            let org = db.create_org("org", "org").await.unwrap();
1037            let channel = db.create_org_channel(org, "channel").await.unwrap();
1038            for i in 0..10 {
1039                db.create_channel_message(
1040                    channel,
1041                    user,
1042                    &i.to_string(),
1043                    OffsetDateTime::now_utc(),
1044                    i,
1045                )
1046                .await
1047                .unwrap();
1048            }
1049
1050            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1051            assert_eq!(
1052                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1053                ["5", "6", "7", "8", "9"]
1054            );
1055
1056            let prev_messages = db
1057                .get_channel_messages(channel, 4, Some(messages[0].id))
1058                .await
1059                .unwrap();
1060            assert_eq!(
1061                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1062                ["1", "2", "3", "4"]
1063            );
1064        }
1065    }
1066
1067    #[tokio::test(flavor = "multi_thread")]
1068    async fn test_channel_message_nonces() {
1069        for test_db in [
1070            TestDb::postgres().await,
1071            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1072        ] {
1073            let db = test_db.db();
1074            let user = db.create_user("user", None, false).await.unwrap();
1075            let org = db.create_org("org", "org").await.unwrap();
1076            let channel = db.create_org_channel(org, "channel").await.unwrap();
1077
1078            let msg1_id = db
1079                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1080                .await
1081                .unwrap();
1082            let msg2_id = db
1083                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1084                .await
1085                .unwrap();
1086            let msg3_id = db
1087                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1088                .await
1089                .unwrap();
1090            let msg4_id = db
1091                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1092                .await
1093                .unwrap();
1094
1095            assert_ne!(msg1_id, msg2_id);
1096            assert_eq!(msg1_id, msg3_id);
1097            assert_eq!(msg2_id, msg4_id);
1098        }
1099    }
1100
1101    #[tokio::test(flavor = "multi_thread")]
1102    async fn test_create_access_tokens() {
1103        let test_db = TestDb::postgres().await;
1104        let db = test_db.db();
1105        let user = db.create_user("the-user", None, false).await.unwrap();
1106
1107        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1108        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1109        assert_eq!(
1110            db.get_access_token_hashes(user).await.unwrap(),
1111            &["h2".to_string(), "h1".to_string()]
1112        );
1113
1114        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1115        assert_eq!(
1116            db.get_access_token_hashes(user).await.unwrap(),
1117            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1118        );
1119
1120        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1121        assert_eq!(
1122            db.get_access_token_hashes(user).await.unwrap(),
1123            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1124        );
1125
1126        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1127        assert_eq!(
1128            db.get_access_token_hashes(user).await.unwrap(),
1129            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1130        );
1131    }
1132
1133    #[test]
1134    fn test_fuzzy_like_string() {
1135        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1136        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1137        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1138    }
1139
1140    #[tokio::test(flavor = "multi_thread")]
1141    async fn test_fuzzy_search_users() {
1142        let test_db = TestDb::postgres().await;
1143        let db = test_db.db();
1144        for github_login in [
1145            "California",
1146            "colorado",
1147            "oregon",
1148            "washington",
1149            "florida",
1150            "delaware",
1151            "rhode-island",
1152        ] {
1153            db.create_user(github_login, None, false).await.unwrap();
1154        }
1155
1156        assert_eq!(
1157            fuzzy_search_user_names(db, "clr").await,
1158            &["colorado", "California"]
1159        );
1160        assert_eq!(
1161            fuzzy_search_user_names(db, "ro").await,
1162            &["rhode-island", "colorado", "oregon"],
1163        );
1164
1165        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1166            db.fuzzy_search_users(query, 10)
1167                .await
1168                .unwrap()
1169                .into_iter()
1170                .map(|user| user.github_login)
1171                .collect::<Vec<_>>()
1172        }
1173    }
1174
    // Walks through the contact lifecycle (request, dismiss, accept, mutual request, decline)
    // against both the Postgres and the fake database. `get_contacts` results are sorted by
    // user id and always include the user themself as an accepted contact.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. (User 1, the requester, has nothing to dismiss.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the requester (user 1) gets notified about the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1394
    // Exercises invite-code assignment, redemption, exhaustion and refill against Postgres.
    // Each redemption creates a new user who becomes an accepted contact of the code's owner;
    // the owner is notified (`should_notify: true`), the new user is not.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt does not
        // consume a redemption.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1547
    /// A database handle for tests that is torn down when dropped.
    pub struct TestDb {
        /// The wrapped database. Kept in an `Option` so `Drop` can take it out for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the test database; empty for the in-memory fake.
        pub url: String,
    }
1552
    impl TestDb {
        /// Creates a fresh, randomly named Postgres database on `localhost`, runs this crate's
        /// `migrations` directory against it, and wraps it so it is torn down on drop.
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // The guard lives until this function returns, so only one test database is
            // created and migrated at a time. NOTE(review): this is a blocking lock held
            // across `.await`s; presumably acceptable because each test drives its own
            // runtime on its own thread — confirm.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            // A random 128-bit suffix keeps concurrently existing test databases distinct.
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a new in-memory `FakeDb`; `url` stays empty since there is no real database.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the wrapped database. Panics if `db` has already been taken (as `Drop` does).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
1587
1588    impl Drop for TestDb {
1589        fn drop(&mut self) {
1590            if let Some(db) = self.db.take() {
1591                futures::executor::block_on(db.teardown(&self.url));
1592            }
1593        }
1594    }
1595
    /// In-memory implementation of `Db` for tests; each table is an individually locked map.
    pub struct FakeDb {
        /// Executor used by the `Db` methods to call `simulate_random_delay`.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` in the membership maps is not visible in
        // this part of the file (possibly an admin flag) — confirm against the writers.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact requests, pending or accepted, in insertion order.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters for the next id of each kind of row; `new` starts them all at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1610
    /// A contact relationship as stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the request.
        pub requester_id: UserId,
        /// The user who received the request.
        pub responder_id: UserId,
        /// `false` while the request is pending, `true` once it has been accepted.
        pub accepted: bool,
        /// Pending notification flag. While unaccepted it is reported to the responder
        /// (`Contact::Incoming`); once accepted it is reported to the requester.
        pub should_notify: bool,
    }
1618
1619    impl FakeDb {
1620        pub fn new(background: Arc<Background>) -> Self {
1621            Self {
1622                background,
1623                users: Default::default(),
1624                next_user_id: Mutex::new(1),
1625                orgs: Default::default(),
1626                next_org_id: Mutex::new(1),
1627                org_memberships: Default::default(),
1628                channels: Default::default(),
1629                next_channel_id: Mutex::new(1),
1630                channel_memberships: Default::default(),
1631                channel_messages: Default::default(),
1632                next_channel_message_id: Mutex::new(1),
1633                contacts: Default::default(),
1634            }
1635        }
1636    }
1637
1638    #[async_trait]
1639    impl Db for FakeDb {
1640        async fn create_user(
1641            &self,
1642            github_login: &str,
1643            email_address: Option<&str>,
1644            admin: bool,
1645        ) -> Result<UserId> {
1646            self.background.simulate_random_delay().await;
1647
1648            let mut users = self.users.lock();
1649            if let Some(user) = users
1650                .values()
1651                .find(|user| user.github_login == github_login)
1652            {
1653                Ok(user.id)
1654            } else {
1655                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1656                users.insert(
1657                    user_id,
1658                    User {
1659                        id: user_id,
1660                        github_login: github_login.to_string(),
1661                        email_address: email_address.map(str::to_string),
1662                        admin,
1663                        invite_code: None,
1664                        invite_count: 0,
1665                        connected_once: false,
1666                    },
1667                );
1668                Ok(user_id)
1669            }
1670        }
1671
        /// Not implemented by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1675
        /// Not implemented by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1679
1680        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1681            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1682        }
1683
1684        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1685            self.background.simulate_random_delay().await;
1686            let users = self.users.lock();
1687            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1688        }
1689
1690        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1691            Ok(self
1692                .users
1693                .lock()
1694                .values()
1695                .find(|user| user.github_login == github_login)
1696                .cloned())
1697        }
1698
        /// Not implemented by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1702
1703        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1704            self.background.simulate_random_delay().await;
1705            let mut users = self.users.lock();
1706            let mut user = users
1707                .get_mut(&id)
1708                .ok_or_else(|| anyhow!("user not found"))?;
1709            user.connected_once = connected_once;
1710            Ok(())
1711        }
1712
        /// Not implemented by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1716
1717        // invite codes
1718
        /// Not implemented by the fake; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
1722
        /// Always reports that the user has no invite code (the fake cannot assign one).
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
1726
        /// Not implemented by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
1730
        /// Not implemented by the fake; panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
1739
1740        // contacts
1741
1742        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1743            self.background.simulate_random_delay().await;
1744            let mut contacts = vec![Contact::Accepted {
1745                user_id: id,
1746                should_notify: false,
1747            }];
1748
1749            for contact in self.contacts.lock().iter() {
1750                if contact.requester_id == id {
1751                    if contact.accepted {
1752                        contacts.push(Contact::Accepted {
1753                            user_id: contact.responder_id,
1754                            should_notify: contact.should_notify,
1755                        });
1756                    } else {
1757                        contacts.push(Contact::Outgoing {
1758                            user_id: contact.responder_id,
1759                        });
1760                    }
1761                } else if contact.responder_id == id {
1762                    if contact.accepted {
1763                        contacts.push(Contact::Accepted {
1764                            user_id: contact.requester_id,
1765                            should_notify: false,
1766                        });
1767                    } else {
1768                        contacts.push(Contact::Incoming {
1769                            user_id: contact.requester_id,
1770                            should_notify: contact.should_notify,
1771                        });
1772                    }
1773                }
1774            }
1775
1776            contacts.sort_unstable_by_key(|contact| contact.user_id());
1777            Ok(contacts)
1778        }
1779
1780        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1781            self.background.simulate_random_delay().await;
1782            Ok(self.contacts.lock().iter().any(|contact| {
1783                contact.accepted
1784                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1785                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1786            }))
1787        }
1788
1789        async fn send_contact_request(
1790            &self,
1791            requester_id: UserId,
1792            responder_id: UserId,
1793        ) -> Result<()> {
1794            let mut contacts = self.contacts.lock();
1795            for contact in contacts.iter_mut() {
1796                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1797                    if contact.accepted {
1798                        Err(anyhow!("contact already exists"))?;
1799                    } else {
1800                        Err(anyhow!("contact already requested"))?;
1801                    }
1802                }
1803                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1804                    if contact.accepted {
1805                        Err(anyhow!("contact already exists"))?;
1806                    } else {
1807                        contact.accepted = true;
1808                        contact.should_notify = false;
1809                        return Ok(());
1810                    }
1811                }
1812            }
1813            contacts.push(FakeContact {
1814                requester_id,
1815                responder_id,
1816                accepted: false,
1817                should_notify: true,
1818            });
1819            Ok(())
1820        }
1821
1822        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1823            self.contacts.lock().retain(|contact| {
1824                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1825            });
1826            Ok(())
1827        }
1828
1829        async fn dismiss_contact_notification(
1830            &self,
1831            user_id: UserId,
1832            contact_user_id: UserId,
1833        ) -> Result<()> {
1834            let mut contacts = self.contacts.lock();
1835            for contact in contacts.iter_mut() {
1836                if contact.requester_id == contact_user_id
1837                    && contact.responder_id == user_id
1838                    && !contact.accepted
1839                {
1840                    contact.should_notify = false;
1841                    return Ok(());
1842                }
1843                if contact.requester_id == user_id
1844                    && contact.responder_id == contact_user_id
1845                    && contact.accepted
1846                {
1847                    contact.should_notify = false;
1848                    return Ok(());
1849                }
1850            }
1851            Err(anyhow!("no such notification"))?
1852        }
1853
1854        async fn respond_to_contact_request(
1855            &self,
1856            responder_id: UserId,
1857            requester_id: UserId,
1858            accept: bool,
1859        ) -> Result<()> {
1860            let mut contacts = self.contacts.lock();
1861            for (ix, contact) in contacts.iter_mut().enumerate() {
1862                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1863                    if contact.accepted {
1864                        Err(anyhow!("contact already confirmed"))?;
1865                    }
1866                    if accept {
1867                        contact.accepted = true;
1868                        contact.should_notify = true;
1869                    } else {
1870                        contacts.remove(ix);
1871                    }
1872                    return Ok(());
1873                }
1874            }
1875            Err(anyhow!("no such contact request"))?
1876        }
1877
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1886
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
1890
1891        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1892            unimplemented!()
1893        }
1894
1895        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1896            self.background.simulate_random_delay().await;
1897            let mut orgs = self.orgs.lock();
1898            if orgs.values().any(|org| org.slug == slug) {
1899                Err(anyhow!("org already exists"))?
1900            } else {
1901                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1902                orgs.insert(
1903                    org_id,
1904                    Org {
1905                        id: org_id,
1906                        name: name.to_string(),
1907                        slug: slug.to_string(),
1908                    },
1909                );
1910                Ok(org_id)
1911            }
1912        }
1913
1914        async fn add_org_member(
1915            &self,
1916            org_id: OrgId,
1917            user_id: UserId,
1918            is_admin: bool,
1919        ) -> Result<()> {
1920            self.background.simulate_random_delay().await;
1921            if !self.orgs.lock().contains_key(&org_id) {
1922                Err(anyhow!("org does not exist"))?;
1923            }
1924            if !self.users.lock().contains_key(&user_id) {
1925                Err(anyhow!("user does not exist"))?;
1926            }
1927
1928            self.org_memberships
1929                .lock()
1930                .entry((org_id, user_id))
1931                .or_insert(is_admin);
1932            Ok(())
1933        }
1934
1935        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1936            self.background.simulate_random_delay().await;
1937            if !self.orgs.lock().contains_key(&org_id) {
1938                Err(anyhow!("org does not exist"))?;
1939            }
1940
1941            let mut channels = self.channels.lock();
1942            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1943            channels.insert(
1944                channel_id,
1945                Channel {
1946                    id: channel_id,
1947                    name: name.to_string(),
1948                    owner_id: org_id.0,
1949                    owner_is_user: false,
1950                },
1951            );
1952            Ok(channel_id)
1953        }
1954
1955        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1956            self.background.simulate_random_delay().await;
1957            Ok(self
1958                .channels
1959                .lock()
1960                .values()
1961                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1962                .cloned()
1963                .collect())
1964        }
1965
1966        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1967            self.background.simulate_random_delay().await;
1968            let channels = self.channels.lock();
1969            let memberships = self.channel_memberships.lock();
1970            Ok(channels
1971                .values()
1972                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1973                .cloned()
1974                .collect())
1975        }
1976
1977        async fn can_user_access_channel(
1978            &self,
1979            user_id: UserId,
1980            channel_id: ChannelId,
1981        ) -> Result<bool> {
1982            self.background.simulate_random_delay().await;
1983            Ok(self
1984                .channel_memberships
1985                .lock()
1986                .contains_key(&(channel_id, user_id)))
1987        }
1988
1989        async fn add_channel_member(
1990            &self,
1991            channel_id: ChannelId,
1992            user_id: UserId,
1993            is_admin: bool,
1994        ) -> Result<()> {
1995            self.background.simulate_random_delay().await;
1996            if !self.channels.lock().contains_key(&channel_id) {
1997                Err(anyhow!("channel does not exist"))?;
1998            }
1999            if !self.users.lock().contains_key(&user_id) {
2000                Err(anyhow!("user does not exist"))?;
2001            }
2002
2003            self.channel_memberships
2004                .lock()
2005                .entry((channel_id, user_id))
2006                .or_insert(is_admin);
2007            Ok(())
2008        }
2009
2010        async fn create_channel_message(
2011            &self,
2012            channel_id: ChannelId,
2013            sender_id: UserId,
2014            body: &str,
2015            timestamp: OffsetDateTime,
2016            nonce: u128,
2017        ) -> Result<MessageId> {
2018            self.background.simulate_random_delay().await;
2019            if !self.channels.lock().contains_key(&channel_id) {
2020                Err(anyhow!("channel does not exist"))?;
2021            }
2022            if !self.users.lock().contains_key(&sender_id) {
2023                Err(anyhow!("user does not exist"))?;
2024            }
2025
2026            let mut messages = self.channel_messages.lock();
2027            if let Some(message) = messages
2028                .values()
2029                .find(|message| message.nonce.as_u128() == nonce)
2030            {
2031                Ok(message.id)
2032            } else {
2033                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2034                messages.insert(
2035                    message_id,
2036                    ChannelMessage {
2037                        id: message_id,
2038                        channel_id,
2039                        sender_id,
2040                        body: body.to_string(),
2041                        sent_at: timestamp,
2042                        nonce: Uuid::from_u128(nonce),
2043                    },
2044                );
2045                Ok(message_id)
2046            }
2047        }
2048
2049        async fn get_channel_messages(
2050            &self,
2051            channel_id: ChannelId,
2052            count: usize,
2053            before_id: Option<MessageId>,
2054        ) -> Result<Vec<ChannelMessage>> {
2055            let mut messages = self
2056                .channel_messages
2057                .lock()
2058                .values()
2059                .rev()
2060                .filter(|message| {
2061                    message.channel_id == channel_id
2062                        && message.id < before_id.unwrap_or(MessageId::MAX)
2063                })
2064                .take(count)
2065                .cloned()
2066                .collect::<Vec<_>>();
2067            messages.sort_unstable_by_key(|message| message.id);
2068            Ok(messages)
2069        }
2070
2071        async fn teardown(&self, _: &str) {}
2072
2073        #[cfg(test)]
2074        fn as_fake(&self) -> Option<&FakeDb> {
2075            Some(self)
2076        }
2077    }
2078}