db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(
  15        &self,
  16        github_login: &str,
  17        email_address: Option<&str>,
  18        admin: bool,
  19    ) -> Result<UserId>;
  20    async fn get_all_users(&self) -> Result<Vec<User>>;
  21    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  22    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  23    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  24    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  25    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  26    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  27    async fn destroy_user(&self, id: UserId) -> Result<()>;
  28
  29    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  30    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  31    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  32    async fn redeem_invite_code(
  33        &self,
  34        code: &str,
  35        login: &str,
  36        email_address: Option<&str>,
  37    ) -> Result<UserId>;
  38
  39    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  40    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  41    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  42    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  43    async fn dismiss_contact_notification(
  44        &self,
  45        responder_id: UserId,
  46        requester_id: UserId,
  47    ) -> Result<()>;
  48    async fn respond_to_contact_request(
  49        &self,
  50        responder_id: UserId,
  51        requester_id: UserId,
  52        accept: bool,
  53    ) -> Result<()>;
  54
  55    async fn create_access_token_hash(
  56        &self,
  57        user_id: UserId,
  58        access_token_hash: &str,
  59        max_access_token_count: usize,
  60    ) -> Result<()>;
  61    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  62    #[cfg(any(test, feature = "seed-support"))]
  63
  64    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  65    #[cfg(any(test, feature = "seed-support"))]
  66    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  67    #[cfg(any(test, feature = "seed-support"))]
  68    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  69    #[cfg(any(test, feature = "seed-support"))]
  70    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  71    #[cfg(any(test, feature = "seed-support"))]
  72
  73    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  74    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  75    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  76        -> Result<bool>;
  77    #[cfg(any(test, feature = "seed-support"))]
  78    async fn add_channel_member(
  79        &self,
  80        channel_id: ChannelId,
  81        user_id: UserId,
  82        is_admin: bool,
  83    ) -> Result<()>;
  84    async fn create_channel_message(
  85        &self,
  86        channel_id: ChannelId,
  87        sender_id: UserId,
  88        body: &str,
  89        timestamp: OffsetDateTime,
  90        nonce: u128,
  91    ) -> Result<MessageId>;
  92    async fn get_channel_messages(
  93        &self,
  94        channel_id: ChannelId,
  95        count: usize,
  96        before_id: Option<MessageId>,
  97    ) -> Result<Vec<ChannelMessage>>;
  98    #[cfg(test)]
  99    async fn teardown(&self, url: &str);
 100    #[cfg(test)]
 101    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 102}
 103
/// [`Db`] implementation backed by a sqlx Postgres connection pool.
pub struct PostgresDb {
    // Shared pool used by every query; transactions are started from it.
    pool: sqlx::PgPool,
}
 107
 108impl PostgresDb {
 109    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 110        let pool = DbOptions::new()
 111            .max_connections(max_connections)
 112            .connect(&url)
 113            .await
 114            .context("failed to connect to postgres database")?;
 115        Ok(Self { pool })
 116    }
 117}
 118
 119#[async_trait]
 120impl Db for PostgresDb {
 121    // users
 122
 123    async fn create_user(
 124        &self,
 125        github_login: &str,
 126        email_address: Option<&str>,
 127        admin: bool,
 128    ) -> Result<UserId> {
 129        let query = "
 130            INSERT INTO users (github_login, email_address, admin)
 131            VALUES ($1, $2, $3)
 132            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 133            RETURNING id
 134        ";
 135        Ok(sqlx::query_scalar(query)
 136            .bind(github_login)
 137            .bind(email_address)
 138            .bind(admin)
 139            .fetch_one(&self.pool)
 140            .await
 141            .map(UserId)?)
 142    }
 143
 144    async fn get_all_users(&self) -> Result<Vec<User>> {
 145        let query = "SELECT * FROM users ORDER BY github_login ASC";
 146        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 147    }
 148
 149    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 150        let like_string = fuzzy_like_string(name_query);
 151        let query = "
 152            SELECT users.*
 153            FROM users
 154            WHERE github_login ILIKE $1
 155            ORDER BY github_login <-> $2
 156            LIMIT $3
 157        ";
 158        Ok(sqlx::query_as(query)
 159            .bind(like_string)
 160            .bind(name_query)
 161            .bind(limit)
 162            .fetch_all(&self.pool)
 163            .await?)
 164    }
 165
 166    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 167        let users = self.get_users_by_ids(vec![id]).await?;
 168        Ok(users.into_iter().next())
 169    }
 170
 171    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 172        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 173        let query = "
 174            SELECT users.*
 175            FROM users
 176            WHERE users.id = ANY ($1)
 177        ";
 178        Ok(sqlx::query_as(query)
 179            .bind(&ids)
 180            .fetch_all(&self.pool)
 181            .await?)
 182    }
 183
 184    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 185        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 186        Ok(sqlx::query_as(query)
 187            .bind(github_login)
 188            .fetch_optional(&self.pool)
 189            .await?)
 190    }
 191
 192    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 193        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 194        Ok(sqlx::query(query)
 195            .bind(is_admin)
 196            .bind(id.0)
 197            .execute(&self.pool)
 198            .await
 199            .map(drop)?)
 200    }
 201
 202    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 203        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 204        Ok(sqlx::query(query)
 205            .bind(connected_once)
 206            .bind(id.0)
 207            .execute(&self.pool)
 208            .await
 209            .map(drop)?)
 210    }
 211
 212    async fn destroy_user(&self, id: UserId) -> Result<()> {
 213        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 214        sqlx::query(query)
 215            .bind(id.0)
 216            .execute(&self.pool)
 217            .await
 218            .map(drop)?;
 219        let query = "DELETE FROM users WHERE id = $1;";
 220        Ok(sqlx::query(query)
 221            .bind(id.0)
 222            .execute(&self.pool)
 223            .await
 224            .map(drop)?)
 225    }
 226
 227    // invite codes
 228
 229    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 230        let mut tx = self.pool.begin().await?;
 231        sqlx::query(
 232            "
 233                UPDATE users
 234                SET invite_code = $1
 235                WHERE id = $2 AND invite_code IS NULL
 236            ",
 237        )
 238        .bind(nanoid!(16))
 239        .bind(id)
 240        .execute(&mut tx)
 241        .await?;
 242        sqlx::query(
 243            "
 244            UPDATE users
 245            SET invite_count = $1
 246            WHERE id = $2
 247            ",
 248        )
 249        .bind(count)
 250        .bind(id)
 251        .execute(&mut tx)
 252        .await?;
 253        tx.commit().await?;
 254        Ok(())
 255    }
 256
 257    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 258        let result: Option<(String, i32)> = sqlx::query_as(
 259            "
 260                SELECT invite_code, invite_count
 261                FROM users
 262                WHERE id = $1 AND invite_code IS NOT NULL 
 263            ",
 264        )
 265        .bind(id)
 266        .fetch_optional(&self.pool)
 267        .await?;
 268        if let Some((code, count)) = result {
 269            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 270        } else {
 271            Ok(None)
 272        }
 273    }
 274
 275    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 276        sqlx::query_as(
 277            "
 278                SELECT *
 279                FROM users
 280                WHERE invite_code = $1
 281            ",
 282        )
 283        .bind(code)
 284        .fetch_optional(&self.pool)
 285        .await?
 286        .ok_or_else(|| {
 287            Error::Http(
 288                StatusCode::NOT_FOUND,
 289                "that invite code does not exist".to_string(),
 290            )
 291        })
 292    }
 293
 294    async fn redeem_invite_code(
 295        &self,
 296        code: &str,
 297        login: &str,
 298        email_address: Option<&str>,
 299    ) -> Result<UserId> {
 300        let mut tx = self.pool.begin().await?;
 301
 302        let inviter_id: Option<UserId> = sqlx::query_scalar(
 303            "
 304                UPDATE users
 305                SET invite_count = invite_count - 1
 306                WHERE
 307                    invite_code = $1 AND
 308                    invite_count > 0
 309                RETURNING id
 310            ",
 311        )
 312        .bind(code)
 313        .fetch_optional(&mut tx)
 314        .await?;
 315
 316        let inviter_id = match inviter_id {
 317            Some(inviter_id) => inviter_id,
 318            None => {
 319                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 320                    .bind(code)
 321                    .fetch_optional(&mut tx)
 322                    .await?
 323                    .is_some()
 324                {
 325                    Err(Error::Http(
 326                        StatusCode::UNAUTHORIZED,
 327                        "no invites remaining".to_string(),
 328                    ))?
 329                } else {
 330                    Err(Error::Http(
 331                        StatusCode::NOT_FOUND,
 332                        "invite code not found".to_string(),
 333                    ))?
 334                }
 335            }
 336        };
 337
 338        let invitee_id = sqlx::query_scalar(
 339            "
 340                INSERT INTO users
 341                    (github_login, email_address, admin, inviter_id)
 342                VALUES
 343                    ($1, $2, 'f', $3)
 344                RETURNING id
 345            ",
 346        )
 347        .bind(login)
 348        .bind(email_address)
 349        .bind(inviter_id)
 350        .fetch_one(&mut tx)
 351        .await
 352        .map(UserId)?;
 353
 354        sqlx::query(
 355            "
 356                INSERT INTO contacts
 357                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 358                VALUES
 359                    ($1, $2, 't', 't', 't')
 360            ",
 361        )
 362        .bind(inviter_id)
 363        .bind(invitee_id)
 364        .execute(&mut tx)
 365        .await?;
 366
 367        tx.commit().await?;
 368        Ok(invitee_id)
 369    }
 370
 371    // contacts
 372
 373    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 374        let query = "
 375            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 376            FROM contacts
 377            WHERE user_id_a = $1 OR user_id_b = $1;
 378        ";
 379
 380        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 381            .bind(user_id)
 382            .fetch(&self.pool);
 383
 384        let mut contacts = vec![Contact::Accepted {
 385            user_id,
 386            should_notify: false,
 387        }];
 388        while let Some(row) = rows.next().await {
 389            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 390
 391            if user_id_a == user_id {
 392                if accepted {
 393                    contacts.push(Contact::Accepted {
 394                        user_id: user_id_b,
 395                        should_notify: should_notify && a_to_b,
 396                    });
 397                } else if a_to_b {
 398                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 399                } else {
 400                    contacts.push(Contact::Incoming {
 401                        user_id: user_id_b,
 402                        should_notify,
 403                    });
 404                }
 405            } else {
 406                if accepted {
 407                    contacts.push(Contact::Accepted {
 408                        user_id: user_id_a,
 409                        should_notify: should_notify && !a_to_b,
 410                    });
 411                } else if a_to_b {
 412                    contacts.push(Contact::Incoming {
 413                        user_id: user_id_a,
 414                        should_notify,
 415                    });
 416                } else {
 417                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 418                }
 419            }
 420        }
 421
 422        contacts.sort_unstable_by_key(|contact| contact.user_id());
 423
 424        Ok(contacts)
 425    }
 426
 427    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 428        let (id_a, id_b) = if user_id_1 < user_id_2 {
 429            (user_id_1, user_id_2)
 430        } else {
 431            (user_id_2, user_id_1)
 432        };
 433
 434        let query = "
 435            SELECT 1 FROM contacts
 436            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 437            LIMIT 1
 438        ";
 439        Ok(sqlx::query_scalar::<_, i32>(query)
 440            .bind(id_a.0)
 441            .bind(id_b.0)
 442            .fetch_optional(&self.pool)
 443            .await?
 444            .is_some())
 445    }
 446
 447    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 448        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 449            (sender_id, receiver_id, true)
 450        } else {
 451            (receiver_id, sender_id, false)
 452        };
 453        let query = "
 454            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 455            VALUES ($1, $2, $3, 'f', 't')
 456            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 457            SET
 458                accepted = 't',
 459                should_notify = 'f'
 460            WHERE
 461                NOT contacts.accepted AND
 462                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 463                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 464        ";
 465        let result = sqlx::query(query)
 466            .bind(id_a.0)
 467            .bind(id_b.0)
 468            .bind(a_to_b)
 469            .execute(&self.pool)
 470            .await?;
 471
 472        if result.rows_affected() == 1 {
 473            Ok(())
 474        } else {
 475            Err(anyhow!("contact already requested"))?
 476        }
 477    }
 478
 479    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 480        let (id_a, id_b) = if responder_id < requester_id {
 481            (responder_id, requester_id)
 482        } else {
 483            (requester_id, responder_id)
 484        };
 485        let query = "
 486            DELETE FROM contacts
 487            WHERE user_id_a = $1 AND user_id_b = $2;
 488        ";
 489        let result = sqlx::query(query)
 490            .bind(id_a.0)
 491            .bind(id_b.0)
 492            .execute(&self.pool)
 493            .await?;
 494
 495        if result.rows_affected() == 1 {
 496            Ok(())
 497        } else {
 498            Err(anyhow!("no such contact"))?
 499        }
 500    }
 501
 502    async fn dismiss_contact_notification(
 503        &self,
 504        user_id: UserId,
 505        contact_user_id: UserId,
 506    ) -> Result<()> {
 507        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 508            (user_id, contact_user_id, true)
 509        } else {
 510            (contact_user_id, user_id, false)
 511        };
 512
 513        let query = "
 514            UPDATE contacts
 515            SET should_notify = 'f'
 516            WHERE
 517                user_id_a = $1 AND user_id_b = $2 AND
 518                (
 519                    (a_to_b = $3 AND accepted) OR
 520                    (a_to_b != $3 AND NOT accepted)
 521                );
 522        ";
 523
 524        let result = sqlx::query(query)
 525            .bind(id_a.0)
 526            .bind(id_b.0)
 527            .bind(a_to_b)
 528            .execute(&self.pool)
 529            .await?;
 530
 531        if result.rows_affected() == 0 {
 532            Err(anyhow!("no such contact request"))?;
 533        }
 534
 535        Ok(())
 536    }
 537
 538    async fn respond_to_contact_request(
 539        &self,
 540        responder_id: UserId,
 541        requester_id: UserId,
 542        accept: bool,
 543    ) -> Result<()> {
 544        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 545            (responder_id, requester_id, false)
 546        } else {
 547            (requester_id, responder_id, true)
 548        };
 549        let result = if accept {
 550            let query = "
 551                UPDATE contacts
 552                SET accepted = 't', should_notify = 't'
 553                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 554            ";
 555            sqlx::query(query)
 556                .bind(id_a.0)
 557                .bind(id_b.0)
 558                .bind(a_to_b)
 559                .execute(&self.pool)
 560                .await?
 561        } else {
 562            let query = "
 563                DELETE FROM contacts
 564                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 565            ";
 566            sqlx::query(query)
 567                .bind(id_a.0)
 568                .bind(id_b.0)
 569                .bind(a_to_b)
 570                .execute(&self.pool)
 571                .await?
 572        };
 573        if result.rows_affected() == 1 {
 574            Ok(())
 575        } else {
 576            Err(anyhow!("no such contact request"))?
 577        }
 578    }
 579
 580    // access tokens
 581
 582    async fn create_access_token_hash(
 583        &self,
 584        user_id: UserId,
 585        access_token_hash: &str,
 586        max_access_token_count: usize,
 587    ) -> Result<()> {
 588        let insert_query = "
 589            INSERT INTO access_tokens (user_id, hash)
 590            VALUES ($1, $2);
 591        ";
 592        let cleanup_query = "
 593            DELETE FROM access_tokens
 594            WHERE id IN (
 595                SELECT id from access_tokens
 596                WHERE user_id = $1
 597                ORDER BY id DESC
 598                OFFSET $3
 599            )
 600        ";
 601
 602        let mut tx = self.pool.begin().await?;
 603        sqlx::query(insert_query)
 604            .bind(user_id.0)
 605            .bind(access_token_hash)
 606            .execute(&mut tx)
 607            .await?;
 608        sqlx::query(cleanup_query)
 609            .bind(user_id.0)
 610            .bind(access_token_hash)
 611            .bind(max_access_token_count as u32)
 612            .execute(&mut tx)
 613            .await?;
 614        Ok(tx.commit().await?)
 615    }
 616
 617    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 618        let query = "
 619            SELECT hash
 620            FROM access_tokens
 621            WHERE user_id = $1
 622            ORDER BY id DESC
 623        ";
 624        Ok(sqlx::query_scalar(query)
 625            .bind(user_id.0)
 626            .fetch_all(&self.pool)
 627            .await?)
 628    }
 629
 630    // orgs
 631
 632    #[allow(unused)] // Help rust-analyzer
 633    #[cfg(any(test, feature = "seed-support"))]
 634    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 635        let query = "
 636            SELECT *
 637            FROM orgs
 638            WHERE slug = $1
 639        ";
 640        Ok(sqlx::query_as(query)
 641            .bind(slug)
 642            .fetch_optional(&self.pool)
 643            .await?)
 644    }
 645
 646    #[cfg(any(test, feature = "seed-support"))]
 647    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 648        let query = "
 649            INSERT INTO orgs (name, slug)
 650            VALUES ($1, $2)
 651            RETURNING id
 652        ";
 653        Ok(sqlx::query_scalar(query)
 654            .bind(name)
 655            .bind(slug)
 656            .fetch_one(&self.pool)
 657            .await
 658            .map(OrgId)?)
 659    }
 660
 661    #[cfg(any(test, feature = "seed-support"))]
 662    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 663        let query = "
 664            INSERT INTO org_memberships (org_id, user_id, admin)
 665            VALUES ($1, $2, $3)
 666            ON CONFLICT DO NOTHING
 667        ";
 668        Ok(sqlx::query(query)
 669            .bind(org_id.0)
 670            .bind(user_id.0)
 671            .bind(is_admin)
 672            .execute(&self.pool)
 673            .await
 674            .map(drop)?)
 675    }
 676
 677    // channels
 678
 679    #[cfg(any(test, feature = "seed-support"))]
 680    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 681        let query = "
 682            INSERT INTO channels (owner_id, owner_is_user, name)
 683            VALUES ($1, false, $2)
 684            RETURNING id
 685        ";
 686        Ok(sqlx::query_scalar(query)
 687            .bind(org_id.0)
 688            .bind(name)
 689            .fetch_one(&self.pool)
 690            .await
 691            .map(ChannelId)?)
 692    }
 693
 694    #[allow(unused)] // Help rust-analyzer
 695    #[cfg(any(test, feature = "seed-support"))]
 696    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 697        let query = "
 698            SELECT *
 699            FROM channels
 700            WHERE
 701                channels.owner_is_user = false AND
 702                channels.owner_id = $1
 703        ";
 704        Ok(sqlx::query_as(query)
 705            .bind(org_id.0)
 706            .fetch_all(&self.pool)
 707            .await?)
 708    }
 709
 710    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 711        let query = "
 712            SELECT
 713                channels.*
 714            FROM
 715                channel_memberships, channels
 716            WHERE
 717                channel_memberships.user_id = $1 AND
 718                channel_memberships.channel_id = channels.id
 719        ";
 720        Ok(sqlx::query_as(query)
 721            .bind(user_id.0)
 722            .fetch_all(&self.pool)
 723            .await?)
 724    }
 725
 726    async fn can_user_access_channel(
 727        &self,
 728        user_id: UserId,
 729        channel_id: ChannelId,
 730    ) -> Result<bool> {
 731        let query = "
 732            SELECT id
 733            FROM channel_memberships
 734            WHERE user_id = $1 AND channel_id = $2
 735            LIMIT 1
 736        ";
 737        Ok(sqlx::query_scalar::<_, i32>(query)
 738            .bind(user_id.0)
 739            .bind(channel_id.0)
 740            .fetch_optional(&self.pool)
 741            .await
 742            .map(|e| e.is_some())?)
 743    }
 744
 745    #[cfg(any(test, feature = "seed-support"))]
 746    async fn add_channel_member(
 747        &self,
 748        channel_id: ChannelId,
 749        user_id: UserId,
 750        is_admin: bool,
 751    ) -> Result<()> {
 752        let query = "
 753            INSERT INTO channel_memberships (channel_id, user_id, admin)
 754            VALUES ($1, $2, $3)
 755            ON CONFLICT DO NOTHING
 756        ";
 757        Ok(sqlx::query(query)
 758            .bind(channel_id.0)
 759            .bind(user_id.0)
 760            .bind(is_admin)
 761            .execute(&self.pool)
 762            .await
 763            .map(drop)?)
 764    }
 765
 766    // messages
 767
 768    async fn create_channel_message(
 769        &self,
 770        channel_id: ChannelId,
 771        sender_id: UserId,
 772        body: &str,
 773        timestamp: OffsetDateTime,
 774        nonce: u128,
 775    ) -> Result<MessageId> {
 776        let query = "
 777            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 778            VALUES ($1, $2, $3, $4, $5)
 779            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 780            RETURNING id
 781        ";
 782        Ok(sqlx::query_scalar(query)
 783            .bind(channel_id.0)
 784            .bind(sender_id.0)
 785            .bind(body)
 786            .bind(timestamp)
 787            .bind(Uuid::from_u128(nonce))
 788            .fetch_one(&self.pool)
 789            .await
 790            .map(MessageId)?)
 791    }
 792
 793    async fn get_channel_messages(
 794        &self,
 795        channel_id: ChannelId,
 796        count: usize,
 797        before_id: Option<MessageId>,
 798    ) -> Result<Vec<ChannelMessage>> {
 799        let query = r#"
 800            SELECT * FROM (
 801                SELECT
 802                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 803                FROM
 804                    channel_messages
 805                WHERE
 806                    channel_id = $1 AND
 807                    id < $2
 808                ORDER BY id DESC
 809                LIMIT $3
 810            ) as recent_messages
 811            ORDER BY id ASC
 812        "#;
 813        Ok(sqlx::query_as(query)
 814            .bind(channel_id.0)
 815            .bind(before_id.unwrap_or(MessageId::MAX))
 816            .bind(count as i64)
 817            .fetch_all(&self.pool)
 818            .await?)
 819    }
 820
 821    #[cfg(test)]
 822    async fn teardown(&self, url: &str) {
 823        use util::ResultExt;
 824
 825        let query = "
 826            SELECT pg_terminate_backend(pg_stat_activity.pid)
 827            FROM pg_stat_activity
 828            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 829        ";
 830        sqlx::query(query).execute(&self.pool).await.log_err();
 831        self.pool.close().await;
 832        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 833            .await
 834            .log_err();
 835    }
 836
 837    #[cfg(test)]
 838    fn as_fake(&self) -> Option<&tests::FakeDb> {
 839        None
 840    }
 841}
 842
 843macro_rules! id_type {
 844    ($name:ident) => {
 845        #[derive(
 846            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 847        )]
 848        #[sqlx(transparent)]
 849        #[serde(transparent)]
 850        pub struct $name(pub i32);
 851
 852        impl $name {
 853            #[allow(unused)]
 854            pub const MAX: Self = Self(i32::MAX);
 855
 856            #[allow(unused)]
 857            pub fn from_proto(value: u64) -> Self {
 858                Self(value as i32)
 859            }
 860
 861            #[allow(unused)]
 862            pub fn to_proto(&self) -> u64 {
 863                self.0 as u64
 864            }
 865        }
 866
 867        impl std::fmt::Display for $name {
 868            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 869                self.0.fmt(f)
 870            }
 871        }
 872    };
 873}
 874
// Id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Administrator flag (written by `set_user_is_admin`).
    pub admin: bool,
    /// Code others can redeem to sign up as this user's invitee; assigned by
    /// `set_invite_count` when absent.
    pub invite_code: Option<String>,
    /// Remaining invites; decremented by `redeem_invite_code`.
    pub invite_count: i32,
    /// Written by `set_user_connected_once`.
    pub connected_once: bool,
}
 886
// Id of a row in the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
 894
// Id of a row in the `channels` table.
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// The owning org's id when `owner_is_user` is false (see
    /// `create_org_channel`); presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
 903
// Id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A channel message as loaded by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; selected as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied `u128` nonce stored as a UUID; used to deduplicate
    /// message creation.
    pub nonce: Uuid,
}
 914
/// One user's view of their relationship with another user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request was accepted. `should_notify` can only be true for the user
    /// who originally sent the request, until they dismiss the notification.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent `user_id` a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending;
    /// `should_notify` tells whether that notification is still outstanding.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
 929
 930impl Contact {
 931    pub fn user_id(&self) -> UserId {
 932        match self {
 933            Contact::Accepted { user_id, .. } => *user_id,
 934            Contact::Outgoing { user_id } => *user_id,
 935            Contact::Incoming { user_id, .. } => *user_id,
 936        }
 937    }
 938}
 939
/// Presumably a pending contact request received from `requester_id` (it is
/// not constructed in the visible part of this file).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    // Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
 945
 946fn fuzzy_like_string(string: &str) -> String {
 947    let mut result = String::with_capacity(string.len() * 2 + 1);
 948    for c in string.chars() {
 949        if c.is_alphanumeric() {
 950            result.push('%');
 951            result.push(c);
 952        }
 953    }
 954    result.push('%');
 955    result
 956}
 957
 958#[cfg(test)]
 959pub mod tests {
 960    use super::*;
 961    use anyhow::anyhow;
 962    use collections::BTreeMap;
 963    use gpui::executor::Background;
 964    use lazy_static::lazy_static;
 965    use parking_lot::Mutex;
 966    use rand::prelude::*;
 967    use sqlx::{
 968        migrate::{MigrateDatabase, Migrator},
 969        Postgres,
 970    };
 971    use std::{path::Path, sync::Arc};
 972    use util::post_inc;
 973
 974    #[tokio::test(flavor = "multi_thread")]
 975    async fn test_get_users_by_ids() {
 976        for test_db in [
 977            TestDb::postgres().await,
 978            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 979        ] {
 980            let db = test_db.db();
 981
 982            let user = db.create_user("user", None, false).await.unwrap();
 983            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
 984            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
 985            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
 986
 987            assert_eq!(
 988                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 989                    .await
 990                    .unwrap(),
 991                vec![
 992                    User {
 993                        id: user,
 994                        github_login: "user".to_string(),
 995                        admin: false,
 996                        ..Default::default()
 997                    },
 998                    User {
 999                        id: friend1,
1000                        github_login: "friend-1".to_string(),
1001                        admin: false,
1002                        ..Default::default()
1003                    },
1004                    User {
1005                        id: friend2,
1006                        github_login: "friend-2".to_string(),
1007                        admin: false,
1008                        ..Default::default()
1009                    },
1010                    User {
1011                        id: friend3,
1012                        github_login: "friend-3".to_string(),
1013                        admin: false,
1014                        ..Default::default()
1015                    }
1016                ]
1017            );
1018        }
1019    }
1020
1021    #[tokio::test(flavor = "multi_thread")]
1022    async fn test_recent_channel_messages() {
1023        for test_db in [
1024            TestDb::postgres().await,
1025            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1026        ] {
1027            let db = test_db.db();
1028            let user = db.create_user("user", None, false).await.unwrap();
1029            let org = db.create_org("org", "org").await.unwrap();
1030            let channel = db.create_org_channel(org, "channel").await.unwrap();
1031            for i in 0..10 {
1032                db.create_channel_message(
1033                    channel,
1034                    user,
1035                    &i.to_string(),
1036                    OffsetDateTime::now_utc(),
1037                    i,
1038                )
1039                .await
1040                .unwrap();
1041            }
1042
1043            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1044            assert_eq!(
1045                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1046                ["5", "6", "7", "8", "9"]
1047            );
1048
1049            let prev_messages = db
1050                .get_channel_messages(channel, 4, Some(messages[0].id))
1051                .await
1052                .unwrap();
1053            assert_eq!(
1054                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1055                ["1", "2", "3", "4"]
1056            );
1057        }
1058    }
1059
1060    #[tokio::test(flavor = "multi_thread")]
1061    async fn test_channel_message_nonces() {
1062        for test_db in [
1063            TestDb::postgres().await,
1064            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1065        ] {
1066            let db = test_db.db();
1067            let user = db.create_user("user", None, false).await.unwrap();
1068            let org = db.create_org("org", "org").await.unwrap();
1069            let channel = db.create_org_channel(org, "channel").await.unwrap();
1070
1071            let msg1_id = db
1072                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1073                .await
1074                .unwrap();
1075            let msg2_id = db
1076                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1077                .await
1078                .unwrap();
1079            let msg3_id = db
1080                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1081                .await
1082                .unwrap();
1083            let msg4_id = db
1084                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1085                .await
1086                .unwrap();
1087
1088            assert_ne!(msg1_id, msg2_id);
1089            assert_eq!(msg1_id, msg3_id);
1090            assert_eq!(msg2_id, msg4_id);
1091        }
1092    }
1093
1094    #[tokio::test(flavor = "multi_thread")]
1095    async fn test_create_access_tokens() {
1096        let test_db = TestDb::postgres().await;
1097        let db = test_db.db();
1098        let user = db.create_user("the-user", None, false).await.unwrap();
1099
1100        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1101        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1102        assert_eq!(
1103            db.get_access_token_hashes(user).await.unwrap(),
1104            &["h2".to_string(), "h1".to_string()]
1105        );
1106
1107        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1108        assert_eq!(
1109            db.get_access_token_hashes(user).await.unwrap(),
1110            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1111        );
1112
1113        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1114        assert_eq!(
1115            db.get_access_token_hashes(user).await.unwrap(),
1116            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1117        );
1118
1119        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1120        assert_eq!(
1121            db.get_access_token_hashes(user).await.unwrap(),
1122            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1123        );
1124    }
1125
1126    #[test]
1127    fn test_fuzzy_like_string() {
1128        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1129        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1130        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1131    }
1132
1133    #[tokio::test(flavor = "multi_thread")]
1134    async fn test_fuzzy_search_users() {
1135        let test_db = TestDb::postgres().await;
1136        let db = test_db.db();
1137        for github_login in [
1138            "California",
1139            "colorado",
1140            "oregon",
1141            "washington",
1142            "florida",
1143            "delaware",
1144            "rhode-island",
1145        ] {
1146            db.create_user(github_login, None, false).await.unwrap();
1147        }
1148
1149        assert_eq!(
1150            fuzzy_search_user_names(db, "clr").await,
1151            &["colorado", "California"]
1152        );
1153        assert_eq!(
1154            fuzzy_search_user_names(db, "ro").await,
1155            &["rhode-island", "colorado", "oregon"],
1156        );
1157
1158        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1159            db.fuzzy_search_users(query, 10)
1160                .await
1161                .unwrap()
1162                .into_iter()
1163                .map(|user| user.github_login)
1164                .collect::<Vec<_>>()
1165        }
1166    }
1167
    /// Walks the contact lifecycle against both the Postgres and fake databases:
    /// sending a request, dismissing its notification, accepting, re-requesting,
    /// simultaneous mutual requests, and declining.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves: every user is
            // listed as an accepted contact of their own.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // (User 1, the requester, has nothing to dismiss for an outgoing request.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the requester (user 1) is notified that the request was accepted.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            // Neither side is left with a notification in that case.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1387
    /// Exercises invite codes against Postgres only (the fake database does not
    /// implement them). It covers code creation, redemption (which makes the
    /// redeemer a contact of the inviter), exhausting the count, topping the
    /// count back up, and rejecting redemption by an existing user.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // The inviter (user 1) is notified; the redeemer is not.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume an invite.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1536
    /// A database under test: either a throwaway Postgres database or a [`FakeDb`].
    ///
    /// The database is torn down (via `Db::teardown`) when this value is dropped.
    pub struct TestDb {
        /// Set to `Some` by both constructors; wrapped in `Option` so `Drop` can take it out.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake database.
        pub url: String,
    }
1541
1542    impl TestDb {
1543        pub async fn postgres() -> Self {
1544            lazy_static! {
1545                static ref LOCK: Mutex<()> = Mutex::new(());
1546            }
1547
1548            let _guard = LOCK.lock();
1549            let mut rng = StdRng::from_entropy();
1550            let name = format!("zed-test-{}", rng.gen::<u128>());
1551            let url = format!("postgres://postgres@localhost/{}", name);
1552            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1553            Postgres::create_database(&url)
1554                .await
1555                .expect("failed to create test db");
1556            let db = PostgresDb::new(&url, 5).await.unwrap();
1557            let migrator = Migrator::new(migrations_path).await.unwrap();
1558            migrator.run(&db.pool).await.unwrap();
1559            Self {
1560                db: Some(Arc::new(db)),
1561                url,
1562            }
1563        }
1564
1565        pub fn fake(background: Arc<Background>) -> Self {
1566            Self {
1567                db: Some(Arc::new(FakeDb::new(background))),
1568                url: Default::default(),
1569            }
1570        }
1571
1572        pub fn db(&self) -> &Arc<dyn Db> {
1573            self.db.as_ref().unwrap()
1574        }
1575    }
1576
1577    impl Drop for TestDb {
1578        fn drop(&mut self) {
1579            if let Some(db) = self.db.take() {
1580                futures::executor::block_on(db.teardown(&self.url));
1581            }
1582        }
1583    }
1584
    /// In-memory implementation of [`Db`] used by tests.
    ///
    /// Each "table" is a `Mutex`-guarded map or list. Ids are handed out from the
    /// `next_*` counters, as `create_user` does with `next_user_id`.
    pub struct FakeDb {
        /// Executor used to inject random delays into simulated queries.
        background: Arc<Background>,
        /// Users keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// Organizations keyed by id.
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Org memberships keyed by (org, user); the `bool` is presumably an admin flag — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        /// Channels keyed by id.
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Channel memberships keyed by (channel, user); the `bool` is presumably an admin
        /// flag — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        /// Channel messages keyed by id.
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships, both pending and accepted.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each record kind; all start at 1 in `FakeDb::new`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1599
    /// A contact relationship stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the request.
        pub requester_id: UserId,
        /// User who received the request.
        pub responder_id: UserId,
        /// Whether the request was accepted (or became mutual when both users
        /// requested each other).
        pub accepted: bool,
        /// Undismissed-notification flag. `get_contacts` shows it to the responder
        /// while the request is pending, and to the requester once it is accepted.
        pub should_notify: bool,
    }
1607
1608    impl FakeDb {
1609        pub fn new(background: Arc<Background>) -> Self {
1610            Self {
1611                background,
1612                users: Default::default(),
1613                next_user_id: Mutex::new(1),
1614                orgs: Default::default(),
1615                next_org_id: Mutex::new(1),
1616                org_memberships: Default::default(),
1617                channels: Default::default(),
1618                next_channel_id: Mutex::new(1),
1619                channel_memberships: Default::default(),
1620                channel_messages: Default::default(),
1621                next_channel_message_id: Mutex::new(1),
1622                contacts: Default::default(),
1623            }
1624        }
1625    }
1626
1627    #[async_trait]
1628    impl Db for FakeDb {
1629        async fn create_user(
1630            &self,
1631            github_login: &str,
1632            email_address: Option<&str>,
1633            admin: bool,
1634        ) -> Result<UserId> {
1635            self.background.simulate_random_delay().await;
1636
1637            let mut users = self.users.lock();
1638            if let Some(user) = users
1639                .values()
1640                .find(|user| user.github_login == github_login)
1641            {
1642                Ok(user.id)
1643            } else {
1644                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1645                users.insert(
1646                    user_id,
1647                    User {
1648                        id: user_id,
1649                        github_login: github_login.to_string(),
1650                        email_address: email_address.map(str::to_string),
1651                        admin,
1652                        invite_code: None,
1653                        invite_count: 0,
1654                        connected_once: false,
1655                    },
1656                );
1657                Ok(user_id)
1658            }
1659        }
1660
        /// Not supported by the fake database; panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
1664
        /// Not supported by the fake database; panics if called. Fuzzy search is only
        /// exercised against Postgres (see `test_fuzzy_search_users`).
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1668
1669        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1670            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1671        }
1672
1673        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1674            self.background.simulate_random_delay().await;
1675            let users = self.users.lock();
1676            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1677        }
1678
1679        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1680            Ok(self
1681                .users
1682                .lock()
1683                .values()
1684                .find(|user| user.github_login == github_login)
1685                .cloned())
1686        }
1687
        /// Not supported by the fake database; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1691
1692        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1693            self.background.simulate_random_delay().await;
1694            let mut users = self.users.lock();
1695            let mut user = users
1696                .get_mut(&id)
1697                .ok_or_else(|| anyhow!("user not found"))?;
1698            user.connected_once = connected_once;
1699            Ok(())
1700        }
1701
        /// Not supported by the fake database; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1705
        // invite codes
        //
        // The fake database implements none of the invite-code queries; invite codes
        // are only exercised against Postgres (see `test_invite_codes`).

        /// Not supported by the fake database; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
1711
        /// Not supported by the fake database; panics if called.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            unimplemented!()
        }
1715
        /// Not supported by the fake database; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
1719
        /// Not supported by the fake database; panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
1728
1729        // contacts
1730
1731        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1732            self.background.simulate_random_delay().await;
1733            let mut contacts = vec![Contact::Accepted {
1734                user_id: id,
1735                should_notify: false,
1736            }];
1737
1738            for contact in self.contacts.lock().iter() {
1739                if contact.requester_id == id {
1740                    if contact.accepted {
1741                        contacts.push(Contact::Accepted {
1742                            user_id: contact.responder_id,
1743                            should_notify: contact.should_notify,
1744                        });
1745                    } else {
1746                        contacts.push(Contact::Outgoing {
1747                            user_id: contact.responder_id,
1748                        });
1749                    }
1750                } else if contact.responder_id == id {
1751                    if contact.accepted {
1752                        contacts.push(Contact::Accepted {
1753                            user_id: contact.requester_id,
1754                            should_notify: false,
1755                        });
1756                    } else {
1757                        contacts.push(Contact::Incoming {
1758                            user_id: contact.requester_id,
1759                            should_notify: contact.should_notify,
1760                        });
1761                    }
1762                }
1763            }
1764
1765            contacts.sort_unstable_by_key(|contact| contact.user_id());
1766            Ok(contacts)
1767        }
1768
1769        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1770            self.background.simulate_random_delay().await;
1771            Ok(self.contacts.lock().iter().any(|contact| {
1772                contact.accepted
1773                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1774                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1775            }))
1776        }
1777
1778        async fn send_contact_request(
1779            &self,
1780            requester_id: UserId,
1781            responder_id: UserId,
1782        ) -> Result<()> {
1783            let mut contacts = self.contacts.lock();
1784            for contact in contacts.iter_mut() {
1785                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1786                    if contact.accepted {
1787                        Err(anyhow!("contact already exists"))?;
1788                    } else {
1789                        Err(anyhow!("contact already requested"))?;
1790                    }
1791                }
1792                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1793                    if contact.accepted {
1794                        Err(anyhow!("contact already exists"))?;
1795                    } else {
1796                        contact.accepted = true;
1797                        contact.should_notify = false;
1798                        return Ok(());
1799                    }
1800                }
1801            }
1802            contacts.push(FakeContact {
1803                requester_id,
1804                responder_id,
1805                accepted: false,
1806                should_notify: true,
1807            });
1808            Ok(())
1809        }
1810
1811        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1812            self.contacts.lock().retain(|contact| {
1813                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1814            });
1815            Ok(())
1816        }
1817
1818        async fn dismiss_contact_notification(
1819            &self,
1820            user_id: UserId,
1821            contact_user_id: UserId,
1822        ) -> Result<()> {
1823            let mut contacts = self.contacts.lock();
1824            for contact in contacts.iter_mut() {
1825                if contact.requester_id == contact_user_id
1826                    && contact.responder_id == user_id
1827                    && !contact.accepted
1828                {
1829                    contact.should_notify = false;
1830                    return Ok(());
1831                }
1832                if contact.requester_id == user_id
1833                    && contact.responder_id == contact_user_id
1834                    && contact.accepted
1835                {
1836                    contact.should_notify = false;
1837                    return Ok(());
1838                }
1839            }
1840            Err(anyhow!("no such notification"))?
1841        }
1842
1843        async fn respond_to_contact_request(
1844            &self,
1845            responder_id: UserId,
1846            requester_id: UserId,
1847            accept: bool,
1848        ) -> Result<()> {
1849            let mut contacts = self.contacts.lock();
1850            for (ix, contact) in contacts.iter_mut().enumerate() {
1851                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1852                    if contact.accepted {
1853                        Err(anyhow!("contact already confirmed"))?;
1854                    }
1855                    if accept {
1856                        contact.accepted = true;
1857                        contact.should_notify = true;
1858                    } else {
1859                        contacts.remove(ix);
1860                    }
1861                    return Ok(());
1862                }
1863            }
1864            Err(anyhow!("no such contact request"))?
1865        }
1866
        /// Not implemented by the fake database.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()` when the returned future is awaited.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1875
        /// Not implemented by the fake database.
        ///
        /// # Panics
        ///
        /// Always panics via `unimplemented!()` when the returned future is awaited.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
1879
1880        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1881            unimplemented!()
1882        }
1883
1884        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1885            self.background.simulate_random_delay().await;
1886            let mut orgs = self.orgs.lock();
1887            if orgs.values().any(|org| org.slug == slug) {
1888                Err(anyhow!("org already exists"))?
1889            } else {
1890                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1891                orgs.insert(
1892                    org_id,
1893                    Org {
1894                        id: org_id,
1895                        name: name.to_string(),
1896                        slug: slug.to_string(),
1897                    },
1898                );
1899                Ok(org_id)
1900            }
1901        }
1902
1903        async fn add_org_member(
1904            &self,
1905            org_id: OrgId,
1906            user_id: UserId,
1907            is_admin: bool,
1908        ) -> Result<()> {
1909            self.background.simulate_random_delay().await;
1910            if !self.orgs.lock().contains_key(&org_id) {
1911                Err(anyhow!("org does not exist"))?;
1912            }
1913            if !self.users.lock().contains_key(&user_id) {
1914                Err(anyhow!("user does not exist"))?;
1915            }
1916
1917            self.org_memberships
1918                .lock()
1919                .entry((org_id, user_id))
1920                .or_insert(is_admin);
1921            Ok(())
1922        }
1923
1924        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1925            self.background.simulate_random_delay().await;
1926            if !self.orgs.lock().contains_key(&org_id) {
1927                Err(anyhow!("org does not exist"))?;
1928            }
1929
1930            let mut channels = self.channels.lock();
1931            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1932            channels.insert(
1933                channel_id,
1934                Channel {
1935                    id: channel_id,
1936                    name: name.to_string(),
1937                    owner_id: org_id.0,
1938                    owner_is_user: false,
1939                },
1940            );
1941            Ok(channel_id)
1942        }
1943
1944        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1945            self.background.simulate_random_delay().await;
1946            Ok(self
1947                .channels
1948                .lock()
1949                .values()
1950                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1951                .cloned()
1952                .collect())
1953        }
1954
1955        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1956            self.background.simulate_random_delay().await;
1957            let channels = self.channels.lock();
1958            let memberships = self.channel_memberships.lock();
1959            Ok(channels
1960                .values()
1961                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1962                .cloned()
1963                .collect())
1964        }
1965
1966        async fn can_user_access_channel(
1967            &self,
1968            user_id: UserId,
1969            channel_id: ChannelId,
1970        ) -> Result<bool> {
1971            self.background.simulate_random_delay().await;
1972            Ok(self
1973                .channel_memberships
1974                .lock()
1975                .contains_key(&(channel_id, user_id)))
1976        }
1977
1978        async fn add_channel_member(
1979            &self,
1980            channel_id: ChannelId,
1981            user_id: UserId,
1982            is_admin: bool,
1983        ) -> Result<()> {
1984            self.background.simulate_random_delay().await;
1985            if !self.channels.lock().contains_key(&channel_id) {
1986                Err(anyhow!("channel does not exist"))?;
1987            }
1988            if !self.users.lock().contains_key(&user_id) {
1989                Err(anyhow!("user does not exist"))?;
1990            }
1991
1992            self.channel_memberships
1993                .lock()
1994                .entry((channel_id, user_id))
1995                .or_insert(is_admin);
1996            Ok(())
1997        }
1998
1999        async fn create_channel_message(
2000            &self,
2001            channel_id: ChannelId,
2002            sender_id: UserId,
2003            body: &str,
2004            timestamp: OffsetDateTime,
2005            nonce: u128,
2006        ) -> Result<MessageId> {
2007            self.background.simulate_random_delay().await;
2008            if !self.channels.lock().contains_key(&channel_id) {
2009                Err(anyhow!("channel does not exist"))?;
2010            }
2011            if !self.users.lock().contains_key(&sender_id) {
2012                Err(anyhow!("user does not exist"))?;
2013            }
2014
2015            let mut messages = self.channel_messages.lock();
2016            if let Some(message) = messages
2017                .values()
2018                .find(|message| message.nonce.as_u128() == nonce)
2019            {
2020                Ok(message.id)
2021            } else {
2022                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2023                messages.insert(
2024                    message_id,
2025                    ChannelMessage {
2026                        id: message_id,
2027                        channel_id,
2028                        sender_id,
2029                        body: body.to_string(),
2030                        sent_at: timestamp,
2031                        nonce: Uuid::from_u128(nonce),
2032                    },
2033                );
2034                Ok(message_id)
2035            }
2036        }
2037
2038        async fn get_channel_messages(
2039            &self,
2040            channel_id: ChannelId,
2041            count: usize,
2042            before_id: Option<MessageId>,
2043        ) -> Result<Vec<ChannelMessage>> {
2044            let mut messages = self
2045                .channel_messages
2046                .lock()
2047                .values()
2048                .rev()
2049                .filter(|message| {
2050                    message.channel_id == channel_id
2051                        && message.id < before_id.unwrap_or(MessageId::MAX)
2052                })
2053                .take(count)
2054                .cloned()
2055                .collect::<Vec<_>>();
2056            messages.sort_unstable_by_key(|message| message.id);
2057            Ok(messages)
2058        }
2059
2060        async fn teardown(&self, _: &str) {}
2061
2062        #[cfg(test)]
2063        fn as_fake(&self) -> Option<&FakeDb> {
2064            Some(self)
2065        }
2066    }
2067}