db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(
  15        &self,
  16        github_login: &str,
  17        email_address: Option<&str>,
  18        admin: bool,
  19    ) -> Result<UserId>;
  20    async fn get_all_users(&self) -> Result<Vec<User>>;
  21    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  22    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  23    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  24    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  25    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  26    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  27    async fn destroy_user(&self, id: UserId) -> Result<()>;
  28
  29    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  30    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  31    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  32    async fn redeem_invite_code(
  33        &self,
  34        code: &str,
  35        login: &str,
  36        email_address: Option<&str>,
  37    ) -> Result<UserId>;
  38
  39    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  40    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  41    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  42    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  43    async fn dismiss_contact_notification(
  44        &self,
  45        responder_id: UserId,
  46        requester_id: UserId,
  47    ) -> Result<()>;
  48    async fn respond_to_contact_request(
  49        &self,
  50        responder_id: UserId,
  51        requester_id: UserId,
  52        accept: bool,
  53    ) -> Result<()>;
  54
  55    async fn create_access_token_hash(
  56        &self,
  57        user_id: UserId,
  58        access_token_hash: &str,
  59        max_access_token_count: usize,
  60    ) -> Result<()>;
  61    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  62    #[cfg(any(test, feature = "seed-support"))]
  63
  64    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  65    #[cfg(any(test, feature = "seed-support"))]
  66    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  67    #[cfg(any(test, feature = "seed-support"))]
  68    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  69    #[cfg(any(test, feature = "seed-support"))]
  70    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  71    #[cfg(any(test, feature = "seed-support"))]
  72
  73    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  74    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  75    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  76        -> Result<bool>;
  77    #[cfg(any(test, feature = "seed-support"))]
  78    async fn add_channel_member(
  79        &self,
  80        channel_id: ChannelId,
  81        user_id: UserId,
  82        is_admin: bool,
  83    ) -> Result<()>;
  84    async fn create_channel_message(
  85        &self,
  86        channel_id: ChannelId,
  87        sender_id: UserId,
  88        body: &str,
  89        timestamp: OffsetDateTime,
  90        nonce: u128,
  91    ) -> Result<MessageId>;
  92    async fn get_channel_messages(
  93        &self,
  94        channel_id: ChannelId,
  95        count: usize,
  96        before_id: Option<MessageId>,
  97    ) -> Result<Vec<ChannelMessage>>;
  98    #[cfg(test)]
  99    async fn teardown(&self, url: &str);
 100    #[cfg(test)]
 101    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 102}
 103
 104pub struct PostgresDb {
 105    pool: sqlx::PgPool,
 106}
 107
 108impl PostgresDb {
 109    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 110        let pool = DbOptions::new()
 111            .max_connections(max_connections)
 112            .connect(&url)
 113            .await
 114            .context("failed to connect to postgres database")?;
 115        Ok(Self { pool })
 116    }
 117}
 118
 119#[async_trait]
 120impl Db for PostgresDb {
 121    // users
 122
 123    async fn create_user(
 124        &self,
 125        github_login: &str,
 126        email_address: Option<&str>,
 127        admin: bool,
 128    ) -> Result<UserId> {
 129        let query = "
 130            INSERT INTO users (github_login, email_address, admin)
 131            VALUES ($1, $2, $3)
 132            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 133            RETURNING id
 134        ";
 135        Ok(sqlx::query_scalar(query)
 136            .bind(github_login)
 137            .bind(email_address)
 138            .bind(admin)
 139            .fetch_one(&self.pool)
 140            .await
 141            .map(UserId)?)
 142    }
 143
 144    async fn get_all_users(&self) -> Result<Vec<User>> {
 145        let query = "SELECT * FROM users ORDER BY github_login ASC";
 146        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 147    }
 148
 149    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 150        let like_string = fuzzy_like_string(name_query);
 151        let query = "
 152            SELECT users.*
 153            FROM users
 154            WHERE github_login ILIKE $1
 155            ORDER BY github_login <-> $2
 156            LIMIT $3
 157        ";
 158        Ok(sqlx::query_as(query)
 159            .bind(like_string)
 160            .bind(name_query)
 161            .bind(limit)
 162            .fetch_all(&self.pool)
 163            .await?)
 164    }
 165
 166    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 167        let users = self.get_users_by_ids(vec![id]).await?;
 168        Ok(users.into_iter().next())
 169    }
 170
 171    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 172        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 173        let query = "
 174            SELECT users.*
 175            FROM users
 176            WHERE users.id = ANY ($1)
 177        ";
 178        Ok(sqlx::query_as(query)
 179            .bind(&ids)
 180            .fetch_all(&self.pool)
 181            .await?)
 182    }
 183
 184    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 185        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 186        Ok(sqlx::query_as(query)
 187            .bind(github_login)
 188            .fetch_optional(&self.pool)
 189            .await?)
 190    }
 191
 192    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 193        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 194        Ok(sqlx::query(query)
 195            .bind(is_admin)
 196            .bind(id.0)
 197            .execute(&self.pool)
 198            .await
 199            .map(drop)?)
 200    }
 201
 202    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 203        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 204        Ok(sqlx::query(query)
 205            .bind(connected_once)
 206            .bind(id.0)
 207            .execute(&self.pool)
 208            .await
 209            .map(drop)?)
 210    }
 211
 212    async fn destroy_user(&self, id: UserId) -> Result<()> {
 213        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 214        sqlx::query(query)
 215            .bind(id.0)
 216            .execute(&self.pool)
 217            .await
 218            .map(drop)?;
 219        let query = "DELETE FROM users WHERE id = $1;";
 220        Ok(sqlx::query(query)
 221            .bind(id.0)
 222            .execute(&self.pool)
 223            .await
 224            .map(drop)?)
 225    }
 226
 227    // invite codes
 228
 229    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 230        let mut tx = self.pool.begin().await?;
 231        if count > 0 {
 232            sqlx::query(
 233                "
 234                UPDATE users
 235                SET invite_code = $1
 236                WHERE id = $2 AND invite_code IS NULL
 237            ",
 238            )
 239            .bind(nanoid!(16))
 240            .bind(id)
 241            .execute(&mut tx)
 242            .await?;
 243        }
 244
 245        sqlx::query(
 246            "
 247            UPDATE users
 248            SET invite_count = $1
 249            WHERE id = $2
 250            ",
 251        )
 252        .bind(count)
 253        .bind(id)
 254        .execute(&mut tx)
 255        .await?;
 256        tx.commit().await?;
 257        Ok(())
 258    }
 259
 260    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 261        let result: Option<(String, i32)> = sqlx::query_as(
 262            "
 263                SELECT invite_code, invite_count
 264                FROM users
 265                WHERE id = $1 AND invite_code IS NOT NULL 
 266            ",
 267        )
 268        .bind(id)
 269        .fetch_optional(&self.pool)
 270        .await?;
 271        if let Some((code, count)) = result {
 272            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 273        } else {
 274            Ok(None)
 275        }
 276    }
 277
 278    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 279        sqlx::query_as(
 280            "
 281                SELECT *
 282                FROM users
 283                WHERE invite_code = $1
 284            ",
 285        )
 286        .bind(code)
 287        .fetch_optional(&self.pool)
 288        .await?
 289        .ok_or_else(|| {
 290            Error::Http(
 291                StatusCode::NOT_FOUND,
 292                "that invite code does not exist".to_string(),
 293            )
 294        })
 295    }
 296
 297    async fn redeem_invite_code(
 298        &self,
 299        code: &str,
 300        login: &str,
 301        email_address: Option<&str>,
 302    ) -> Result<UserId> {
 303        let mut tx = self.pool.begin().await?;
 304
 305        let inviter_id: Option<UserId> = sqlx::query_scalar(
 306            "
 307                UPDATE users
 308                SET invite_count = invite_count - 1
 309                WHERE
 310                    invite_code = $1 AND
 311                    invite_count > 0
 312                RETURNING id
 313            ",
 314        )
 315        .bind(code)
 316        .fetch_optional(&mut tx)
 317        .await?;
 318
 319        let inviter_id = match inviter_id {
 320            Some(inviter_id) => inviter_id,
 321            None => {
 322                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 323                    .bind(code)
 324                    .fetch_optional(&mut tx)
 325                    .await?
 326                    .is_some()
 327                {
 328                    Err(Error::Http(
 329                        StatusCode::UNAUTHORIZED,
 330                        "no invites remaining".to_string(),
 331                    ))?
 332                } else {
 333                    Err(Error::Http(
 334                        StatusCode::NOT_FOUND,
 335                        "invite code not found".to_string(),
 336                    ))?
 337                }
 338            }
 339        };
 340
 341        let invitee_id = sqlx::query_scalar(
 342            "
 343                INSERT INTO users
 344                    (github_login, email_address, admin, inviter_id)
 345                VALUES
 346                    ($1, $2, 'f', $3)
 347                RETURNING id
 348            ",
 349        )
 350        .bind(login)
 351        .bind(email_address)
 352        .bind(inviter_id)
 353        .fetch_one(&mut tx)
 354        .await
 355        .map(UserId)?;
 356
 357        sqlx::query(
 358            "
 359                INSERT INTO contacts
 360                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 361                VALUES
 362                    ($1, $2, 't', 't', 't')
 363            ",
 364        )
 365        .bind(inviter_id)
 366        .bind(invitee_id)
 367        .execute(&mut tx)
 368        .await?;
 369
 370        tx.commit().await?;
 371        Ok(invitee_id)
 372    }
 373
 374    // contacts
 375
 376    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 377        let query = "
 378            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 379            FROM contacts
 380            WHERE user_id_a = $1 OR user_id_b = $1;
 381        ";
 382
 383        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 384            .bind(user_id)
 385            .fetch(&self.pool);
 386
 387        let mut contacts = vec![Contact::Accepted {
 388            user_id,
 389            should_notify: false,
 390        }];
 391        while let Some(row) = rows.next().await {
 392            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 393
 394            if user_id_a == user_id {
 395                if accepted {
 396                    contacts.push(Contact::Accepted {
 397                        user_id: user_id_b,
 398                        should_notify: should_notify && a_to_b,
 399                    });
 400                } else if a_to_b {
 401                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 402                } else {
 403                    contacts.push(Contact::Incoming {
 404                        user_id: user_id_b,
 405                        should_notify,
 406                    });
 407                }
 408            } else {
 409                if accepted {
 410                    contacts.push(Contact::Accepted {
 411                        user_id: user_id_a,
 412                        should_notify: should_notify && !a_to_b,
 413                    });
 414                } else if a_to_b {
 415                    contacts.push(Contact::Incoming {
 416                        user_id: user_id_a,
 417                        should_notify,
 418                    });
 419                } else {
 420                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 421                }
 422            }
 423        }
 424
 425        contacts.sort_unstable_by_key(|contact| contact.user_id());
 426
 427        Ok(contacts)
 428    }
 429
 430    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 431        let (id_a, id_b) = if user_id_1 < user_id_2 {
 432            (user_id_1, user_id_2)
 433        } else {
 434            (user_id_2, user_id_1)
 435        };
 436
 437        let query = "
 438            SELECT 1 FROM contacts
 439            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 440            LIMIT 1
 441        ";
 442        Ok(sqlx::query_scalar::<_, i32>(query)
 443            .bind(id_a.0)
 444            .bind(id_b.0)
 445            .fetch_optional(&self.pool)
 446            .await?
 447            .is_some())
 448    }
 449
 450    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 451        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 452            (sender_id, receiver_id, true)
 453        } else {
 454            (receiver_id, sender_id, false)
 455        };
 456        let query = "
 457            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 458            VALUES ($1, $2, $3, 'f', 't')
 459            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 460            SET
 461                accepted = 't',
 462                should_notify = 'f'
 463            WHERE
 464                NOT contacts.accepted AND
 465                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 466                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 467        ";
 468        let result = sqlx::query(query)
 469            .bind(id_a.0)
 470            .bind(id_b.0)
 471            .bind(a_to_b)
 472            .execute(&self.pool)
 473            .await?;
 474
 475        if result.rows_affected() == 1 {
 476            Ok(())
 477        } else {
 478            Err(anyhow!("contact already requested"))?
 479        }
 480    }
 481
 482    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 483        let (id_a, id_b) = if responder_id < requester_id {
 484            (responder_id, requester_id)
 485        } else {
 486            (requester_id, responder_id)
 487        };
 488        let query = "
 489            DELETE FROM contacts
 490            WHERE user_id_a = $1 AND user_id_b = $2;
 491        ";
 492        let result = sqlx::query(query)
 493            .bind(id_a.0)
 494            .bind(id_b.0)
 495            .execute(&self.pool)
 496            .await?;
 497
 498        if result.rows_affected() == 1 {
 499            Ok(())
 500        } else {
 501            Err(anyhow!("no such contact"))?
 502        }
 503    }
 504
 505    async fn dismiss_contact_notification(
 506        &self,
 507        user_id: UserId,
 508        contact_user_id: UserId,
 509    ) -> Result<()> {
 510        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 511            (user_id, contact_user_id, true)
 512        } else {
 513            (contact_user_id, user_id, false)
 514        };
 515
 516        let query = "
 517            UPDATE contacts
 518            SET should_notify = 'f'
 519            WHERE
 520                user_id_a = $1 AND user_id_b = $2 AND
 521                (
 522                    (a_to_b = $3 AND accepted) OR
 523                    (a_to_b != $3 AND NOT accepted)
 524                );
 525        ";
 526
 527        let result = sqlx::query(query)
 528            .bind(id_a.0)
 529            .bind(id_b.0)
 530            .bind(a_to_b)
 531            .execute(&self.pool)
 532            .await?;
 533
 534        if result.rows_affected() == 0 {
 535            Err(anyhow!("no such contact request"))?;
 536        }
 537
 538        Ok(())
 539    }
 540
 541    async fn respond_to_contact_request(
 542        &self,
 543        responder_id: UserId,
 544        requester_id: UserId,
 545        accept: bool,
 546    ) -> Result<()> {
 547        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 548            (responder_id, requester_id, false)
 549        } else {
 550            (requester_id, responder_id, true)
 551        };
 552        let result = if accept {
 553            let query = "
 554                UPDATE contacts
 555                SET accepted = 't', should_notify = 't'
 556                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 557            ";
 558            sqlx::query(query)
 559                .bind(id_a.0)
 560                .bind(id_b.0)
 561                .bind(a_to_b)
 562                .execute(&self.pool)
 563                .await?
 564        } else {
 565            let query = "
 566                DELETE FROM contacts
 567                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 568            ";
 569            sqlx::query(query)
 570                .bind(id_a.0)
 571                .bind(id_b.0)
 572                .bind(a_to_b)
 573                .execute(&self.pool)
 574                .await?
 575        };
 576        if result.rows_affected() == 1 {
 577            Ok(())
 578        } else {
 579            Err(anyhow!("no such contact request"))?
 580        }
 581    }
 582
 583    // access tokens
 584
 585    async fn create_access_token_hash(
 586        &self,
 587        user_id: UserId,
 588        access_token_hash: &str,
 589        max_access_token_count: usize,
 590    ) -> Result<()> {
 591        let insert_query = "
 592            INSERT INTO access_tokens (user_id, hash)
 593            VALUES ($1, $2);
 594        ";
 595        let cleanup_query = "
 596            DELETE FROM access_tokens
 597            WHERE id IN (
 598                SELECT id from access_tokens
 599                WHERE user_id = $1
 600                ORDER BY id DESC
 601                OFFSET $3
 602            )
 603        ";
 604
 605        let mut tx = self.pool.begin().await?;
 606        sqlx::query(insert_query)
 607            .bind(user_id.0)
 608            .bind(access_token_hash)
 609            .execute(&mut tx)
 610            .await?;
 611        sqlx::query(cleanup_query)
 612            .bind(user_id.0)
 613            .bind(access_token_hash)
 614            .bind(max_access_token_count as u32)
 615            .execute(&mut tx)
 616            .await?;
 617        Ok(tx.commit().await?)
 618    }
 619
 620    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 621        let query = "
 622            SELECT hash
 623            FROM access_tokens
 624            WHERE user_id = $1
 625            ORDER BY id DESC
 626        ";
 627        Ok(sqlx::query_scalar(query)
 628            .bind(user_id.0)
 629            .fetch_all(&self.pool)
 630            .await?)
 631    }
 632
 633    // orgs
 634
 635    #[allow(unused)] // Help rust-analyzer
 636    #[cfg(any(test, feature = "seed-support"))]
 637    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 638        let query = "
 639            SELECT *
 640            FROM orgs
 641            WHERE slug = $1
 642        ";
 643        Ok(sqlx::query_as(query)
 644            .bind(slug)
 645            .fetch_optional(&self.pool)
 646            .await?)
 647    }
 648
 649    #[cfg(any(test, feature = "seed-support"))]
 650    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 651        let query = "
 652            INSERT INTO orgs (name, slug)
 653            VALUES ($1, $2)
 654            RETURNING id
 655        ";
 656        Ok(sqlx::query_scalar(query)
 657            .bind(name)
 658            .bind(slug)
 659            .fetch_one(&self.pool)
 660            .await
 661            .map(OrgId)?)
 662    }
 663
 664    #[cfg(any(test, feature = "seed-support"))]
 665    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 666        let query = "
 667            INSERT INTO org_memberships (org_id, user_id, admin)
 668            VALUES ($1, $2, $3)
 669            ON CONFLICT DO NOTHING
 670        ";
 671        Ok(sqlx::query(query)
 672            .bind(org_id.0)
 673            .bind(user_id.0)
 674            .bind(is_admin)
 675            .execute(&self.pool)
 676            .await
 677            .map(drop)?)
 678    }
 679
 680    // channels
 681
 682    #[cfg(any(test, feature = "seed-support"))]
 683    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 684        let query = "
 685            INSERT INTO channels (owner_id, owner_is_user, name)
 686            VALUES ($1, false, $2)
 687            RETURNING id
 688        ";
 689        Ok(sqlx::query_scalar(query)
 690            .bind(org_id.0)
 691            .bind(name)
 692            .fetch_one(&self.pool)
 693            .await
 694            .map(ChannelId)?)
 695    }
 696
 697    #[allow(unused)] // Help rust-analyzer
 698    #[cfg(any(test, feature = "seed-support"))]
 699    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 700        let query = "
 701            SELECT *
 702            FROM channels
 703            WHERE
 704                channels.owner_is_user = false AND
 705                channels.owner_id = $1
 706        ";
 707        Ok(sqlx::query_as(query)
 708            .bind(org_id.0)
 709            .fetch_all(&self.pool)
 710            .await?)
 711    }
 712
 713    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 714        let query = "
 715            SELECT
 716                channels.*
 717            FROM
 718                channel_memberships, channels
 719            WHERE
 720                channel_memberships.user_id = $1 AND
 721                channel_memberships.channel_id = channels.id
 722        ";
 723        Ok(sqlx::query_as(query)
 724            .bind(user_id.0)
 725            .fetch_all(&self.pool)
 726            .await?)
 727    }
 728
 729    async fn can_user_access_channel(
 730        &self,
 731        user_id: UserId,
 732        channel_id: ChannelId,
 733    ) -> Result<bool> {
 734        let query = "
 735            SELECT id
 736            FROM channel_memberships
 737            WHERE user_id = $1 AND channel_id = $2
 738            LIMIT 1
 739        ";
 740        Ok(sqlx::query_scalar::<_, i32>(query)
 741            .bind(user_id.0)
 742            .bind(channel_id.0)
 743            .fetch_optional(&self.pool)
 744            .await
 745            .map(|e| e.is_some())?)
 746    }
 747
 748    #[cfg(any(test, feature = "seed-support"))]
 749    async fn add_channel_member(
 750        &self,
 751        channel_id: ChannelId,
 752        user_id: UserId,
 753        is_admin: bool,
 754    ) -> Result<()> {
 755        let query = "
 756            INSERT INTO channel_memberships (channel_id, user_id, admin)
 757            VALUES ($1, $2, $3)
 758            ON CONFLICT DO NOTHING
 759        ";
 760        Ok(sqlx::query(query)
 761            .bind(channel_id.0)
 762            .bind(user_id.0)
 763            .bind(is_admin)
 764            .execute(&self.pool)
 765            .await
 766            .map(drop)?)
 767    }
 768
 769    // messages
 770
 771    async fn create_channel_message(
 772        &self,
 773        channel_id: ChannelId,
 774        sender_id: UserId,
 775        body: &str,
 776        timestamp: OffsetDateTime,
 777        nonce: u128,
 778    ) -> Result<MessageId> {
 779        let query = "
 780            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 781            VALUES ($1, $2, $3, $4, $5)
 782            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 783            RETURNING id
 784        ";
 785        Ok(sqlx::query_scalar(query)
 786            .bind(channel_id.0)
 787            .bind(sender_id.0)
 788            .bind(body)
 789            .bind(timestamp)
 790            .bind(Uuid::from_u128(nonce))
 791            .fetch_one(&self.pool)
 792            .await
 793            .map(MessageId)?)
 794    }
 795
 796    async fn get_channel_messages(
 797        &self,
 798        channel_id: ChannelId,
 799        count: usize,
 800        before_id: Option<MessageId>,
 801    ) -> Result<Vec<ChannelMessage>> {
 802        let query = r#"
 803            SELECT * FROM (
 804                SELECT
 805                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 806                FROM
 807                    channel_messages
 808                WHERE
 809                    channel_id = $1 AND
 810                    id < $2
 811                ORDER BY id DESC
 812                LIMIT $3
 813            ) as recent_messages
 814            ORDER BY id ASC
 815        "#;
 816        Ok(sqlx::query_as(query)
 817            .bind(channel_id.0)
 818            .bind(before_id.unwrap_or(MessageId::MAX))
 819            .bind(count as i64)
 820            .fetch_all(&self.pool)
 821            .await?)
 822    }
 823
 824    #[cfg(test)]
 825    async fn teardown(&self, url: &str) {
 826        use util::ResultExt;
 827
 828        let query = "
 829            SELECT pg_terminate_backend(pg_stat_activity.pid)
 830            FROM pg_stat_activity
 831            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 832        ";
 833        sqlx::query(query).execute(&self.pool).await.log_err();
 834        self.pool.close().await;
 835        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 836            .await
 837            .log_err();
 838    }
 839
 840    #[cfg(test)]
 841    fn as_fake(&self) -> Option<&tests::FakeDb> {
 842        None
 843    }
 844}
 845
 846macro_rules! id_type {
 847    ($name:ident) => {
 848        #[derive(
 849            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 850        )]
 851        #[sqlx(transparent)]
 852        #[serde(transparent)]
 853        pub struct $name(pub i32);
 854
 855        impl $name {
 856            #[allow(unused)]
 857            pub const MAX: Self = Self(i32::MAX);
 858
 859            #[allow(unused)]
 860            pub fn from_proto(value: u64) -> Self {
 861                Self(value as i32)
 862            }
 863
 864            #[allow(unused)]
 865            pub fn to_proto(&self) -> u64 {
 866                self.0 as u64
 867            }
 868        }
 869
 870        impl std::fmt::Display for $name {
 871            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 872                self.0.fmt(f)
 873            }
 874        }
 875    };
 876}
 877
 878id_type!(UserId);
 879#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 880pub struct User {
 881    pub id: UserId,
 882    pub github_login: String,
 883    pub email_address: Option<String>,
 884    pub admin: bool,
 885    pub invite_code: Option<String>,
 886    pub invite_count: i32,
 887    pub connected_once: bool,
 888}
 889
 890id_type!(OrgId);
 891#[derive(FromRow)]
 892pub struct Org {
 893    pub id: OrgId,
 894    pub name: String,
 895    pub slug: String,
 896}
 897
 898id_type!(ChannelId);
 899#[derive(Clone, Debug, FromRow, Serialize)]
 900pub struct Channel {
 901    pub id: ChannelId,
 902    pub name: String,
 903    pub owner_id: i32,
 904    pub owner_is_user: bool,
 905}
 906
 907id_type!(MessageId);
 908#[derive(Clone, Debug, FromRow)]
 909pub struct ChannelMessage {
 910    pub id: MessageId,
 911    pub channel_id: ChannelId,
 912    pub sender_id: UserId,
 913    pub body: String,
 914    pub sent_at: OffsetDateTime,
 915    pub nonce: Uuid,
 916}
 917
 918#[derive(Clone, Debug, PartialEq, Eq)]
 919pub enum Contact {
 920    Accepted {
 921        user_id: UserId,
 922        should_notify: bool,
 923    },
 924    Outgoing {
 925        user_id: UserId,
 926    },
 927    Incoming {
 928        user_id: UserId,
 929        should_notify: bool,
 930    },
 931}
 932
 933impl Contact {
 934    pub fn user_id(&self) -> UserId {
 935        match self {
 936            Contact::Accepted { user_id, .. } => *user_id,
 937            Contact::Outgoing { user_id } => *user_id,
 938            Contact::Incoming { user_id, .. } => *user_id,
 939        }
 940    }
 941}
 942
 943#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 944pub struct IncomingContactRequest {
 945    pub requester_id: UserId,
 946    pub should_notify: bool,
 947}
 948
 949fn fuzzy_like_string(string: &str) -> String {
 950    let mut result = String::with_capacity(string.len() * 2 + 1);
 951    for c in string.chars() {
 952        if c.is_alphanumeric() {
 953            result.push('%');
 954            result.push(c);
 955        }
 956    }
 957    result.push('%');
 958    result
 959}
 960
 961#[cfg(test)]
 962pub mod tests {
 963    use super::*;
 964    use anyhow::anyhow;
 965    use collections::BTreeMap;
 966    use gpui::executor::Background;
 967    use lazy_static::lazy_static;
 968    use parking_lot::Mutex;
 969    use rand::prelude::*;
 970    use sqlx::{
 971        migrate::{MigrateDatabase, Migrator},
 972        Postgres,
 973    };
 974    use std::{path::Path, sync::Arc};
 975    use util::post_inc;
 976
 977    #[tokio::test(flavor = "multi_thread")]
 978    async fn test_get_users_by_ids() {
 979        for test_db in [
 980            TestDb::postgres().await,
 981            TestDb::fake(Arc::new(gpui::executor::Background::new())),
 982        ] {
 983            let db = test_db.db();
 984
 985            let user = db.create_user("user", None, false).await.unwrap();
 986            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
 987            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
 988            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
 989
 990            assert_eq!(
 991                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
 992                    .await
 993                    .unwrap(),
 994                vec![
 995                    User {
 996                        id: user,
 997                        github_login: "user".to_string(),
 998                        admin: false,
 999                        ..Default::default()
1000                    },
1001                    User {
1002                        id: friend1,
1003                        github_login: "friend-1".to_string(),
1004                        admin: false,
1005                        ..Default::default()
1006                    },
1007                    User {
1008                        id: friend2,
1009                        github_login: "friend-2".to_string(),
1010                        admin: false,
1011                        ..Default::default()
1012                    },
1013                    User {
1014                        id: friend3,
1015                        github_login: "friend-3".to_string(),
1016                        admin: false,
1017                        ..Default::default()
1018                    }
1019                ]
1020            );
1021        }
1022    }
1023
    /// Pages through channel history: without a cursor `get_channel_messages` returns the
    /// newest messages in ascending order, and with `Some(id)` it returns the messages
    /// immediately preceding that id.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        // Run the same scenario against Postgres and the in-memory fake.
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Post ten messages with bodies "0".."9"; the loop index doubles as the nonce.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            // No cursor: the five most recent messages, oldest first.
            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // Cursor at message "5": the four messages just before it.
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
1062
    /// Creating a message with a nonce that was already used (in the same channel here)
    /// returns the id of the message originally created with that nonce.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            // Nonces used, in order: 1, 2, then 1 and 2 again.
            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            // Distinct nonces give distinct ids; repeated nonces resolve to the first message.
            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
1096
1097    #[tokio::test(flavor = "multi_thread")]
1098    async fn test_create_access_tokens() {
1099        let test_db = TestDb::postgres().await;
1100        let db = test_db.db();
1101        let user = db.create_user("the-user", None, false).await.unwrap();
1102
1103        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1104        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1105        assert_eq!(
1106            db.get_access_token_hashes(user).await.unwrap(),
1107            &["h2".to_string(), "h1".to_string()]
1108        );
1109
1110        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1111        assert_eq!(
1112            db.get_access_token_hashes(user).await.unwrap(),
1113            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1114        );
1115
1116        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1117        assert_eq!(
1118            db.get_access_token_hashes(user).await.unwrap(),
1119            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1120        );
1121
1122        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1123        assert_eq!(
1124            db.get_access_token_hashes(user).await.unwrap(),
1125            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1126        );
1127    }
1128
1129    #[test]
1130    fn test_fuzzy_like_string() {
1131        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1132        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1133        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1134    }
1135
1136    #[tokio::test(flavor = "multi_thread")]
1137    async fn test_fuzzy_search_users() {
1138        let test_db = TestDb::postgres().await;
1139        let db = test_db.db();
1140        for github_login in [
1141            "California",
1142            "colorado",
1143            "oregon",
1144            "washington",
1145            "florida",
1146            "delaware",
1147            "rhode-island",
1148        ] {
1149            db.create_user(github_login, None, false).await.unwrap();
1150        }
1151
1152        assert_eq!(
1153            fuzzy_search_user_names(db, "clr").await,
1154            &["colorado", "California"]
1155        );
1156        assert_eq!(
1157            fuzzy_search_user_names(db, "ro").await,
1158            &["rhode-island", "colorado", "oregon"],
1159        );
1160
1161        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1162            db.fuzzy_search_users(query, 10)
1163                .await
1164                .unwrap()
1165                .into_iter()
1166                .map(|user| user.github_login)
1167                .collect::<Vec<_>>()
1168        }
1169    }
1170
    /// Walks three users through the contact lifecycle on Postgres and the fake:
    /// - sending a request and dismissing its notification
    /// - accepting a request, which the requester themself cannot do
    /// - re-request errors and acceptance notifications
    /// - concurrent mutual requests and declining
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // The requester (user 1) cannot dismiss the notification for a request they sent.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1390
    /// Exercises invite codes against Postgres, covering:
    /// - creating a code via `set_invite_count`
    /// - redeeming it, which makes the redeemer and the code owner contacts and decrements
    ///   the count
    /// - exhausting and topping up the count
    /// - rejecting redemption by an existing user
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume an invite either.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1543
    /// A database handle for tests, backed either by a throwaway Postgres database or by
    /// an in-memory `FakeDb`.
    pub struct TestDb {
        /// Both constructors set this to `Some`; `Drop` takes it to run teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake backend.
        pub url: String,
    }
1548
    impl TestDb {
        /// Creates a fresh Postgres database named `zed-test-<random u128>` on `localhost`,
        /// connects to it, and applies the migrations in this crate's `migrations` directory.
        ///
        /// Panics if any of those steps fails.
        pub async fn postgres() -> Self {
            // Process-wide lock. The guard below lives until this function returns (across
            // the awaits), so concurrent callers set up their databases one at a time.
            // NOTE(review): presumably avoids races between concurrent CREATE DATABASE /
            // migration runs; confirm before removing.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            // NOTE(review): `5` is presumably the connection-pool size; confirm in `PostgresDb::new`.
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a new, empty `FakeDb` driven by `background`; `url` is left empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the wrapped database.
        ///
        /// Panics if `db` is `None`; both constructors set it to `Some`.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
1583
    impl Drop for TestDb {
        /// Calls `Db::teardown` with the stored URL, blocking the current thread until it
        /// finishes.
        fn drop(&mut self) {
            // `take` moves the handle out of `self` so it can be torn down by value.
            // NOTE(review): `futures::executor::block_on` runs here while the test's tokio
            // runtime is typically active; presumably this is why the tests use the
            // multi-threaded flavor. Confirm before changing the runtime flavor.
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
1591
    /// In-memory `Db` implementation used by tests.
    ///
    /// Each table lives behind a `parking_lot::Mutex`. The `next_*` counters hand out ids,
    /// starting at 1 in `FakeDb::new`.
    pub struct FakeDb {
        /// Executor used for `simulate_random_delay` calls in the async methods.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` value (e.g. an admin flag) is not visible
        // in this chunk; confirm before relying on it.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Pending and accepted contact requests.
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1606
    /// A contact relationship stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the request.
        pub requester_id: UserId,
        /// User the request was sent to.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// While pending: whether the responder should be notified of the request.
        /// Once accepted: whether the requester should be notified of the acceptance.
        pub should_notify: bool,
    }
1614
1615    impl FakeDb {
1616        pub fn new(background: Arc<Background>) -> Self {
1617            Self {
1618                background,
1619                users: Default::default(),
1620                next_user_id: Mutex::new(1),
1621                orgs: Default::default(),
1622                next_org_id: Mutex::new(1),
1623                org_memberships: Default::default(),
1624                channels: Default::default(),
1625                next_channel_id: Mutex::new(1),
1626                channel_memberships: Default::default(),
1627                channel_messages: Default::default(),
1628                next_channel_message_id: Mutex::new(1),
1629                contacts: Default::default(),
1630            }
1631        }
1632    }
1633
1634    #[async_trait]
1635    impl Db for FakeDb {
1636        async fn create_user(
1637            &self,
1638            github_login: &str,
1639            email_address: Option<&str>,
1640            admin: bool,
1641        ) -> Result<UserId> {
1642            self.background.simulate_random_delay().await;
1643
1644            let mut users = self.users.lock();
1645            if let Some(user) = users
1646                .values()
1647                .find(|user| user.github_login == github_login)
1648            {
1649                Ok(user.id)
1650            } else {
1651                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1652                users.insert(
1653                    user_id,
1654                    User {
1655                        id: user_id,
1656                        github_login: github_login.to_string(),
1657                        email_address: email_address.map(str::to_string),
1658                        admin,
1659                        invite_code: None,
1660                        invite_count: 0,
1661                        connected_once: false,
1662                    },
1663                );
1664                Ok(user_id)
1665            }
1666        }
1667
        // Not implemented by the fake: panics if called.
        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }
1671
        // Not implemented by the fake (query and limit are ignored): panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1675
1676        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1677            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1678        }
1679
1680        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1681            self.background.simulate_random_delay().await;
1682            let users = self.users.lock();
1683            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1684        }
1685
1686        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1687            Ok(self
1688                .users
1689                .lock()
1690                .values()
1691                .find(|user| user.github_login == github_login)
1692                .cloned())
1693        }
1694
        // Not implemented by the fake: panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1698
1699        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1700            self.background.simulate_random_delay().await;
1701            let mut users = self.users.lock();
1702            let mut user = users
1703                .get_mut(&id)
1704                .ok_or_else(|| anyhow!("user not found"))?;
1705            user.connected_once = connected_once;
1706            Ok(())
1707        }
1708
        // Not implemented by the fake: panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1712
1713        // invite codes
1714
        // Not implemented by the fake: panics if called. Invite-code behavior is only
        // exercised against Postgres.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
1718
        // The fake never assigns invite codes, so every user reports having none.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
1722
        // Not implemented by the fake: panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
1726
        // Not implemented by the fake: panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
1735
1736        // contacts
1737
1738        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1739            self.background.simulate_random_delay().await;
1740            let mut contacts = vec![Contact::Accepted {
1741                user_id: id,
1742                should_notify: false,
1743            }];
1744
1745            for contact in self.contacts.lock().iter() {
1746                if contact.requester_id == id {
1747                    if contact.accepted {
1748                        contacts.push(Contact::Accepted {
1749                            user_id: contact.responder_id,
1750                            should_notify: contact.should_notify,
1751                        });
1752                    } else {
1753                        contacts.push(Contact::Outgoing {
1754                            user_id: contact.responder_id,
1755                        });
1756                    }
1757                } else if contact.responder_id == id {
1758                    if contact.accepted {
1759                        contacts.push(Contact::Accepted {
1760                            user_id: contact.requester_id,
1761                            should_notify: false,
1762                        });
1763                    } else {
1764                        contacts.push(Contact::Incoming {
1765                            user_id: contact.requester_id,
1766                            should_notify: contact.should_notify,
1767                        });
1768                    }
1769                }
1770            }
1771
1772            contacts.sort_unstable_by_key(|contact| contact.user_id());
1773            Ok(contacts)
1774        }
1775
1776        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1777            self.background.simulate_random_delay().await;
1778            Ok(self.contacts.lock().iter().any(|contact| {
1779                contact.accepted
1780                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1781                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1782            }))
1783        }
1784
1785        async fn send_contact_request(
1786            &self,
1787            requester_id: UserId,
1788            responder_id: UserId,
1789        ) -> Result<()> {
1790            let mut contacts = self.contacts.lock();
1791            for contact in contacts.iter_mut() {
1792                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1793                    if contact.accepted {
1794                        Err(anyhow!("contact already exists"))?;
1795                    } else {
1796                        Err(anyhow!("contact already requested"))?;
1797                    }
1798                }
1799                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1800                    if contact.accepted {
1801                        Err(anyhow!("contact already exists"))?;
1802                    } else {
1803                        contact.accepted = true;
1804                        contact.should_notify = false;
1805                        return Ok(());
1806                    }
1807                }
1808            }
1809            contacts.push(FakeContact {
1810                requester_id,
1811                responder_id,
1812                accepted: false,
1813                should_notify: true,
1814            });
1815            Ok(())
1816        }
1817
1818        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1819            self.contacts.lock().retain(|contact| {
1820                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1821            });
1822            Ok(())
1823        }
1824
1825        async fn dismiss_contact_notification(
1826            &self,
1827            user_id: UserId,
1828            contact_user_id: UserId,
1829        ) -> Result<()> {
1830            let mut contacts = self.contacts.lock();
1831            for contact in contacts.iter_mut() {
1832                if contact.requester_id == contact_user_id
1833                    && contact.responder_id == user_id
1834                    && !contact.accepted
1835                {
1836                    contact.should_notify = false;
1837                    return Ok(());
1838                }
1839                if contact.requester_id == user_id
1840                    && contact.responder_id == contact_user_id
1841                    && contact.accepted
1842                {
1843                    contact.should_notify = false;
1844                    return Ok(());
1845                }
1846            }
1847            Err(anyhow!("no such notification"))?
1848        }
1849
1850        async fn respond_to_contact_request(
1851            &self,
1852            responder_id: UserId,
1853            requester_id: UserId,
1854            accept: bool,
1855        ) -> Result<()> {
1856            let mut contacts = self.contacts.lock();
1857            for (ix, contact) in contacts.iter_mut().enumerate() {
1858                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1859                    if contact.accepted {
1860                        Err(anyhow!("contact already confirmed"))?;
1861                    }
1862                    if accept {
1863                        contact.accepted = true;
1864                        contact.should_notify = true;
1865                    } else {
1866                        contacts.remove(ix);
1867                    }
1868                    return Ok(());
1869                }
1870            }
1871            Err(anyhow!("no such contact request"))?
1872        }
1873
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`. All arguments are ignored.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
1882
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
1886
1887        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1888            unimplemented!()
1889        }
1890
1891        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1892            self.background.simulate_random_delay().await;
1893            let mut orgs = self.orgs.lock();
1894            if orgs.values().any(|org| org.slug == slug) {
1895                Err(anyhow!("org already exists"))?
1896            } else {
1897                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
1898                orgs.insert(
1899                    org_id,
1900                    Org {
1901                        id: org_id,
1902                        name: name.to_string(),
1903                        slug: slug.to_string(),
1904                    },
1905                );
1906                Ok(org_id)
1907            }
1908        }
1909
1910        async fn add_org_member(
1911            &self,
1912            org_id: OrgId,
1913            user_id: UserId,
1914            is_admin: bool,
1915        ) -> Result<()> {
1916            self.background.simulate_random_delay().await;
1917            if !self.orgs.lock().contains_key(&org_id) {
1918                Err(anyhow!("org does not exist"))?;
1919            }
1920            if !self.users.lock().contains_key(&user_id) {
1921                Err(anyhow!("user does not exist"))?;
1922            }
1923
1924            self.org_memberships
1925                .lock()
1926                .entry((org_id, user_id))
1927                .or_insert(is_admin);
1928            Ok(())
1929        }
1930
1931        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1932            self.background.simulate_random_delay().await;
1933            if !self.orgs.lock().contains_key(&org_id) {
1934                Err(anyhow!("org does not exist"))?;
1935            }
1936
1937            let mut channels = self.channels.lock();
1938            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
1939            channels.insert(
1940                channel_id,
1941                Channel {
1942                    id: channel_id,
1943                    name: name.to_string(),
1944                    owner_id: org_id.0,
1945                    owner_is_user: false,
1946                },
1947            );
1948            Ok(channel_id)
1949        }
1950
1951        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1952            self.background.simulate_random_delay().await;
1953            Ok(self
1954                .channels
1955                .lock()
1956                .values()
1957                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
1958                .cloned()
1959                .collect())
1960        }
1961
1962        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1963            self.background.simulate_random_delay().await;
1964            let channels = self.channels.lock();
1965            let memberships = self.channel_memberships.lock();
1966            Ok(channels
1967                .values()
1968                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
1969                .cloned()
1970                .collect())
1971        }
1972
1973        async fn can_user_access_channel(
1974            &self,
1975            user_id: UserId,
1976            channel_id: ChannelId,
1977        ) -> Result<bool> {
1978            self.background.simulate_random_delay().await;
1979            Ok(self
1980                .channel_memberships
1981                .lock()
1982                .contains_key(&(channel_id, user_id)))
1983        }
1984
1985        async fn add_channel_member(
1986            &self,
1987            channel_id: ChannelId,
1988            user_id: UserId,
1989            is_admin: bool,
1990        ) -> Result<()> {
1991            self.background.simulate_random_delay().await;
1992            if !self.channels.lock().contains_key(&channel_id) {
1993                Err(anyhow!("channel does not exist"))?;
1994            }
1995            if !self.users.lock().contains_key(&user_id) {
1996                Err(anyhow!("user does not exist"))?;
1997            }
1998
1999            self.channel_memberships
2000                .lock()
2001                .entry((channel_id, user_id))
2002                .or_insert(is_admin);
2003            Ok(())
2004        }
2005
2006        async fn create_channel_message(
2007            &self,
2008            channel_id: ChannelId,
2009            sender_id: UserId,
2010            body: &str,
2011            timestamp: OffsetDateTime,
2012            nonce: u128,
2013        ) -> Result<MessageId> {
2014            self.background.simulate_random_delay().await;
2015            if !self.channels.lock().contains_key(&channel_id) {
2016                Err(anyhow!("channel does not exist"))?;
2017            }
2018            if !self.users.lock().contains_key(&sender_id) {
2019                Err(anyhow!("user does not exist"))?;
2020            }
2021
2022            let mut messages = self.channel_messages.lock();
2023            if let Some(message) = messages
2024                .values()
2025                .find(|message| message.nonce.as_u128() == nonce)
2026            {
2027                Ok(message.id)
2028            } else {
2029                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2030                messages.insert(
2031                    message_id,
2032                    ChannelMessage {
2033                        id: message_id,
2034                        channel_id,
2035                        sender_id,
2036                        body: body.to_string(),
2037                        sent_at: timestamp,
2038                        nonce: Uuid::from_u128(nonce),
2039                    },
2040                );
2041                Ok(message_id)
2042            }
2043        }
2044
2045        async fn get_channel_messages(
2046            &self,
2047            channel_id: ChannelId,
2048            count: usize,
2049            before_id: Option<MessageId>,
2050        ) -> Result<Vec<ChannelMessage>> {
2051            let mut messages = self
2052                .channel_messages
2053                .lock()
2054                .values()
2055                .rev()
2056                .filter(|message| {
2057                    message.channel_id == channel_id
2058                        && message.id < before_id.unwrap_or(MessageId::MAX)
2059                })
2060                .take(count)
2061                .cloned()
2062                .collect::<Vec<_>>();
2063            messages.sort_unstable_by_key(|message| message.id);
2064            Ok(messages)
2065        }
2066
        /// Teardown is a no-op for the fake database; the argument is ignored.
        async fn teardown(&self, _: &str) {}
2068
        /// Test-only hook for recovering the concrete fake from a `Db` value; this
        /// implementation always returns `Some(self)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2073    }
2074}