db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(
  15        &self,
  16        github_login: &str,
  17        email_address: Option<&str>,
  18        admin: bool,
  19    ) -> Result<UserId>;
  20    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  21    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  22    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  23    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  24    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  25    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  26    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  27    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  28    async fn destroy_user(&self, id: UserId) -> Result<()>;
  29
  30    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  31    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  32    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  33    async fn redeem_invite_code(
  34        &self,
  35        code: &str,
  36        login: &str,
  37        email_address: Option<&str>,
  38    ) -> Result<UserId>;
  39
  40    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  41    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  42    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  43    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  44    async fn dismiss_contact_notification(
  45        &self,
  46        responder_id: UserId,
  47        requester_id: UserId,
  48    ) -> Result<()>;
  49    async fn respond_to_contact_request(
  50        &self,
  51        responder_id: UserId,
  52        requester_id: UserId,
  53        accept: bool,
  54    ) -> Result<()>;
  55
  56    async fn create_access_token_hash(
  57        &self,
  58        user_id: UserId,
  59        access_token_hash: &str,
  60        max_access_token_count: usize,
  61    ) -> Result<()>;
  62    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  63    #[cfg(any(test, feature = "seed-support"))]
  64
  65    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  66    #[cfg(any(test, feature = "seed-support"))]
  67    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  68    #[cfg(any(test, feature = "seed-support"))]
  69    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  70    #[cfg(any(test, feature = "seed-support"))]
  71    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  72    #[cfg(any(test, feature = "seed-support"))]
  73
  74    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  75    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  76    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  77        -> Result<bool>;
  78    #[cfg(any(test, feature = "seed-support"))]
  79    async fn add_channel_member(
  80        &self,
  81        channel_id: ChannelId,
  82        user_id: UserId,
  83        is_admin: bool,
  84    ) -> Result<()>;
  85    async fn create_channel_message(
  86        &self,
  87        channel_id: ChannelId,
  88        sender_id: UserId,
  89        body: &str,
  90        timestamp: OffsetDateTime,
  91        nonce: u128,
  92    ) -> Result<MessageId>;
  93    async fn get_channel_messages(
  94        &self,
  95        channel_id: ChannelId,
  96        count: usize,
  97        before_id: Option<MessageId>,
  98    ) -> Result<Vec<ChannelMessage>>;
  99    #[cfg(test)]
 100    async fn teardown(&self, url: &str);
 101    #[cfg(test)]
 102    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 103}
 104
 105pub struct PostgresDb {
 106    pool: sqlx::PgPool,
 107}
 108
 109impl PostgresDb {
 110    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 111        let pool = DbOptions::new()
 112            .max_connections(max_connections)
 113            .connect(&url)
 114            .await
 115            .context("failed to connect to postgres database")?;
 116        Ok(Self { pool })
 117    }
 118}
 119
 120#[async_trait]
 121impl Db for PostgresDb {
 122    // users
 123
 124    async fn create_user(
 125        &self,
 126        github_login: &str,
 127        email_address: Option<&str>,
 128        admin: bool,
 129    ) -> Result<UserId> {
 130        let query = "
 131            INSERT INTO users (github_login, email_address, admin)
 132            VALUES ($1, $2, $3)
 133            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 134            RETURNING id
 135        ";
 136        Ok(sqlx::query_scalar(query)
 137            .bind(github_login)
 138            .bind(email_address)
 139            .bind(admin)
 140            .fetch_one(&self.pool)
 141            .await
 142            .map(UserId)?)
 143    }
 144
 145    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 146        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 147        Ok(sqlx::query_as(query)
 148            .bind(limit)
 149            .bind(page * limit)
 150            .fetch_all(&self.pool)
 151            .await?)
 152    }
 153  
 154    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 155        let mut query = QueryBuilder::new(
 156            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 157        );
 158        query.push_values(
 159            users,
 160            |mut query, (github_login, email_address, invite_count)| {
 161                query
 162                    .push_bind(github_login)
 163                    .push_bind(email_address)
 164                    .push_bind(false)
 165                    .push_bind(nanoid!(16))
 166                    .push_bind(invite_count as u32);
 167            },
 168        );
 169        query.push(
 170            "
 171            ON CONFLICT (github_login) DO UPDATE SET
 172                github_login = excluded.github_login,
 173                invite_count = excluded.invite_count,
 174                invite_code = CASE WHEN users.invite_code IS NULL
 175                                   THEN excluded.invite_code
 176                                   ELSE users.invite_code
 177                              END
 178            RETURNING id
 179            ",
 180        );
 181
 182        let rows = query.build().fetch_all(&self.pool).await?;
 183        Ok(rows
 184            .into_iter()
 185            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 186            .collect())
 187    }
 188
 189    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 190        let like_string = fuzzy_like_string(name_query);
 191        let query = "
 192            SELECT users.*
 193            FROM users
 194            WHERE github_login ILIKE $1
 195            ORDER BY github_login <-> $2
 196            LIMIT $3
 197        ";
 198        Ok(sqlx::query_as(query)
 199            .bind(like_string)
 200            .bind(name_query)
 201            .bind(limit)
 202            .fetch_all(&self.pool)
 203            .await?)
 204    }
 205
 206    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 207        let users = self.get_users_by_ids(vec![id]).await?;
 208        Ok(users.into_iter().next())
 209    }
 210
 211    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 212        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 213        let query = "
 214            SELECT users.*
 215            FROM users
 216            WHERE users.id = ANY ($1)
 217        ";
 218        Ok(sqlx::query_as(query)
 219            .bind(&ids)
 220            .fetch_all(&self.pool)
 221            .await?)
 222    }
 223
 224    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 225        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 226        Ok(sqlx::query_as(query)
 227            .bind(github_login)
 228            .fetch_optional(&self.pool)
 229            .await?)
 230    }
 231
 232    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 233        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 234        Ok(sqlx::query(query)
 235            .bind(is_admin)
 236            .bind(id.0)
 237            .execute(&self.pool)
 238            .await
 239            .map(drop)?)
 240    }
 241
 242    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 243        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 244        Ok(sqlx::query(query)
 245            .bind(connected_once)
 246            .bind(id.0)
 247            .execute(&self.pool)
 248            .await
 249            .map(drop)?)
 250    }
 251
 252    async fn destroy_user(&self, id: UserId) -> Result<()> {
 253        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 254        sqlx::query(query)
 255            .bind(id.0)
 256            .execute(&self.pool)
 257            .await
 258            .map(drop)?;
 259        let query = "DELETE FROM users WHERE id = $1;";
 260        Ok(sqlx::query(query)
 261            .bind(id.0)
 262            .execute(&self.pool)
 263            .await
 264            .map(drop)?)
 265    }
 266
 267    // invite codes
 268
 269    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 270        let mut tx = self.pool.begin().await?;
 271        if count > 0 {
 272            sqlx::query(
 273                "
 274                UPDATE users
 275                SET invite_code = $1
 276                WHERE id = $2 AND invite_code IS NULL
 277            ",
 278            )
 279            .bind(nanoid!(16))
 280            .bind(id)
 281            .execute(&mut tx)
 282            .await?;
 283        }
 284
 285        sqlx::query(
 286            "
 287            UPDATE users
 288            SET invite_count = $1
 289            WHERE id = $2
 290            ",
 291        )
 292        .bind(count)
 293        .bind(id)
 294        .execute(&mut tx)
 295        .await?;
 296        tx.commit().await?;
 297        Ok(())
 298    }
 299
 300    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 301        let result: Option<(String, i32)> = sqlx::query_as(
 302            "
 303                SELECT invite_code, invite_count
 304                FROM users
 305                WHERE id = $1 AND invite_code IS NOT NULL 
 306            ",
 307        )
 308        .bind(id)
 309        .fetch_optional(&self.pool)
 310        .await?;
 311        if let Some((code, count)) = result {
 312            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 313        } else {
 314            Ok(None)
 315        }
 316    }
 317
 318    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 319        sqlx::query_as(
 320            "
 321                SELECT *
 322                FROM users
 323                WHERE invite_code = $1
 324            ",
 325        )
 326        .bind(code)
 327        .fetch_optional(&self.pool)
 328        .await?
 329        .ok_or_else(|| {
 330            Error::Http(
 331                StatusCode::NOT_FOUND,
 332                "that invite code does not exist".to_string(),
 333            )
 334        })
 335    }
 336
 337    async fn redeem_invite_code(
 338        &self,
 339        code: &str,
 340        login: &str,
 341        email_address: Option<&str>,
 342    ) -> Result<UserId> {
 343        let mut tx = self.pool.begin().await?;
 344
 345        let inviter_id: Option<UserId> = sqlx::query_scalar(
 346            "
 347                UPDATE users
 348                SET invite_count = invite_count - 1
 349                WHERE
 350                    invite_code = $1 AND
 351                    invite_count > 0
 352                RETURNING id
 353            ",
 354        )
 355        .bind(code)
 356        .fetch_optional(&mut tx)
 357        .await?;
 358
 359        let inviter_id = match inviter_id {
 360            Some(inviter_id) => inviter_id,
 361            None => {
 362                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 363                    .bind(code)
 364                    .fetch_optional(&mut tx)
 365                    .await?
 366                    .is_some()
 367                {
 368                    Err(Error::Http(
 369                        StatusCode::UNAUTHORIZED,
 370                        "no invites remaining".to_string(),
 371                    ))?
 372                } else {
 373                    Err(Error::Http(
 374                        StatusCode::NOT_FOUND,
 375                        "invite code not found".to_string(),
 376                    ))?
 377                }
 378            }
 379        };
 380
 381        let invitee_id = sqlx::query_scalar(
 382            "
 383                INSERT INTO users
 384                    (github_login, email_address, admin, inviter_id)
 385                VALUES
 386                    ($1, $2, 'f', $3)
 387                RETURNING id
 388            ",
 389        )
 390        .bind(login)
 391        .bind(email_address)
 392        .bind(inviter_id)
 393        .fetch_one(&mut tx)
 394        .await
 395        .map(UserId)?;
 396
 397        sqlx::query(
 398            "
 399                INSERT INTO contacts
 400                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 401                VALUES
 402                    ($1, $2, 't', 't', 't')
 403            ",
 404        )
 405        .bind(inviter_id)
 406        .bind(invitee_id)
 407        .execute(&mut tx)
 408        .await?;
 409
 410        tx.commit().await?;
 411        Ok(invitee_id)
 412    }
 413
 414    // contacts
 415
 416    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 417        let query = "
 418            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 419            FROM contacts
 420            WHERE user_id_a = $1 OR user_id_b = $1;
 421        ";
 422
 423        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 424            .bind(user_id)
 425            .fetch(&self.pool);
 426
 427        let mut contacts = vec![Contact::Accepted {
 428            user_id,
 429            should_notify: false,
 430        }];
 431        while let Some(row) = rows.next().await {
 432            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 433
 434            if user_id_a == user_id {
 435                if accepted {
 436                    contacts.push(Contact::Accepted {
 437                        user_id: user_id_b,
 438                        should_notify: should_notify && a_to_b,
 439                    });
 440                } else if a_to_b {
 441                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 442                } else {
 443                    contacts.push(Contact::Incoming {
 444                        user_id: user_id_b,
 445                        should_notify,
 446                    });
 447                }
 448            } else {
 449                if accepted {
 450                    contacts.push(Contact::Accepted {
 451                        user_id: user_id_a,
 452                        should_notify: should_notify && !a_to_b,
 453                    });
 454                } else if a_to_b {
 455                    contacts.push(Contact::Incoming {
 456                        user_id: user_id_a,
 457                        should_notify,
 458                    });
 459                } else {
 460                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 461                }
 462            }
 463        }
 464
 465        contacts.sort_unstable_by_key(|contact| contact.user_id());
 466
 467        Ok(contacts)
 468    }
 469
 470    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 471        let (id_a, id_b) = if user_id_1 < user_id_2 {
 472            (user_id_1, user_id_2)
 473        } else {
 474            (user_id_2, user_id_1)
 475        };
 476
 477        let query = "
 478            SELECT 1 FROM contacts
 479            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 480            LIMIT 1
 481        ";
 482        Ok(sqlx::query_scalar::<_, i32>(query)
 483            .bind(id_a.0)
 484            .bind(id_b.0)
 485            .fetch_optional(&self.pool)
 486            .await?
 487            .is_some())
 488    }
 489
 490    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 491        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 492            (sender_id, receiver_id, true)
 493        } else {
 494            (receiver_id, sender_id, false)
 495        };
 496        let query = "
 497            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 498            VALUES ($1, $2, $3, 'f', 't')
 499            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 500            SET
 501                accepted = 't',
 502                should_notify = 'f'
 503            WHERE
 504                NOT contacts.accepted AND
 505                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 506                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 507        ";
 508        let result = sqlx::query(query)
 509            .bind(id_a.0)
 510            .bind(id_b.0)
 511            .bind(a_to_b)
 512            .execute(&self.pool)
 513            .await?;
 514
 515        if result.rows_affected() == 1 {
 516            Ok(())
 517        } else {
 518            Err(anyhow!("contact already requested"))?
 519        }
 520    }
 521
 522    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 523        let (id_a, id_b) = if responder_id < requester_id {
 524            (responder_id, requester_id)
 525        } else {
 526            (requester_id, responder_id)
 527        };
 528        let query = "
 529            DELETE FROM contacts
 530            WHERE user_id_a = $1 AND user_id_b = $2;
 531        ";
 532        let result = sqlx::query(query)
 533            .bind(id_a.0)
 534            .bind(id_b.0)
 535            .execute(&self.pool)
 536            .await?;
 537
 538        if result.rows_affected() == 1 {
 539            Ok(())
 540        } else {
 541            Err(anyhow!("no such contact"))?
 542        }
 543    }
 544
 545    async fn dismiss_contact_notification(
 546        &self,
 547        user_id: UserId,
 548        contact_user_id: UserId,
 549    ) -> Result<()> {
 550        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 551            (user_id, contact_user_id, true)
 552        } else {
 553            (contact_user_id, user_id, false)
 554        };
 555
 556        let query = "
 557            UPDATE contacts
 558            SET should_notify = 'f'
 559            WHERE
 560                user_id_a = $1 AND user_id_b = $2 AND
 561                (
 562                    (a_to_b = $3 AND accepted) OR
 563                    (a_to_b != $3 AND NOT accepted)
 564                );
 565        ";
 566
 567        let result = sqlx::query(query)
 568            .bind(id_a.0)
 569            .bind(id_b.0)
 570            .bind(a_to_b)
 571            .execute(&self.pool)
 572            .await?;
 573
 574        if result.rows_affected() == 0 {
 575            Err(anyhow!("no such contact request"))?;
 576        }
 577
 578        Ok(())
 579    }
 580
 581    async fn respond_to_contact_request(
 582        &self,
 583        responder_id: UserId,
 584        requester_id: UserId,
 585        accept: bool,
 586    ) -> Result<()> {
 587        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 588            (responder_id, requester_id, false)
 589        } else {
 590            (requester_id, responder_id, true)
 591        };
 592        let result = if accept {
 593            let query = "
 594                UPDATE contacts
 595                SET accepted = 't', should_notify = 't'
 596                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 597            ";
 598            sqlx::query(query)
 599                .bind(id_a.0)
 600                .bind(id_b.0)
 601                .bind(a_to_b)
 602                .execute(&self.pool)
 603                .await?
 604        } else {
 605            let query = "
 606                DELETE FROM contacts
 607                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 608            ";
 609            sqlx::query(query)
 610                .bind(id_a.0)
 611                .bind(id_b.0)
 612                .bind(a_to_b)
 613                .execute(&self.pool)
 614                .await?
 615        };
 616        if result.rows_affected() == 1 {
 617            Ok(())
 618        } else {
 619            Err(anyhow!("no such contact request"))?
 620        }
 621    }
 622
 623    // access tokens
 624
 625    async fn create_access_token_hash(
 626        &self,
 627        user_id: UserId,
 628        access_token_hash: &str,
 629        max_access_token_count: usize,
 630    ) -> Result<()> {
 631        let insert_query = "
 632            INSERT INTO access_tokens (user_id, hash)
 633            VALUES ($1, $2);
 634        ";
 635        let cleanup_query = "
 636            DELETE FROM access_tokens
 637            WHERE id IN (
 638                SELECT id from access_tokens
 639                WHERE user_id = $1
 640                ORDER BY id DESC
 641                OFFSET $3
 642            )
 643        ";
 644
 645        let mut tx = self.pool.begin().await?;
 646        sqlx::query(insert_query)
 647            .bind(user_id.0)
 648            .bind(access_token_hash)
 649            .execute(&mut tx)
 650            .await?;
 651        sqlx::query(cleanup_query)
 652            .bind(user_id.0)
 653            .bind(access_token_hash)
 654            .bind(max_access_token_count as u32)
 655            .execute(&mut tx)
 656            .await?;
 657        Ok(tx.commit().await?)
 658    }
 659
 660    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 661        let query = "
 662            SELECT hash
 663            FROM access_tokens
 664            WHERE user_id = $1
 665            ORDER BY id DESC
 666        ";
 667        Ok(sqlx::query_scalar(query)
 668            .bind(user_id.0)
 669            .fetch_all(&self.pool)
 670            .await?)
 671    }
 672
 673    // orgs
 674
 675    #[allow(unused)] // Help rust-analyzer
 676    #[cfg(any(test, feature = "seed-support"))]
 677    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 678        let query = "
 679            SELECT *
 680            FROM orgs
 681            WHERE slug = $1
 682        ";
 683        Ok(sqlx::query_as(query)
 684            .bind(slug)
 685            .fetch_optional(&self.pool)
 686            .await?)
 687    }
 688
 689    #[cfg(any(test, feature = "seed-support"))]
 690    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 691        let query = "
 692            INSERT INTO orgs (name, slug)
 693            VALUES ($1, $2)
 694            RETURNING id
 695        ";
 696        Ok(sqlx::query_scalar(query)
 697            .bind(name)
 698            .bind(slug)
 699            .fetch_one(&self.pool)
 700            .await
 701            .map(OrgId)?)
 702    }
 703
 704    #[cfg(any(test, feature = "seed-support"))]
 705    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 706        let query = "
 707            INSERT INTO org_memberships (org_id, user_id, admin)
 708            VALUES ($1, $2, $3)
 709            ON CONFLICT DO NOTHING
 710        ";
 711        Ok(sqlx::query(query)
 712            .bind(org_id.0)
 713            .bind(user_id.0)
 714            .bind(is_admin)
 715            .execute(&self.pool)
 716            .await
 717            .map(drop)?)
 718    }
 719
 720    // channels
 721
 722    #[cfg(any(test, feature = "seed-support"))]
 723    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 724        let query = "
 725            INSERT INTO channels (owner_id, owner_is_user, name)
 726            VALUES ($1, false, $2)
 727            RETURNING id
 728        ";
 729        Ok(sqlx::query_scalar(query)
 730            .bind(org_id.0)
 731            .bind(name)
 732            .fetch_one(&self.pool)
 733            .await
 734            .map(ChannelId)?)
 735    }
 736
 737    #[allow(unused)] // Help rust-analyzer
 738    #[cfg(any(test, feature = "seed-support"))]
 739    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 740        let query = "
 741            SELECT *
 742            FROM channels
 743            WHERE
 744                channels.owner_is_user = false AND
 745                channels.owner_id = $1
 746        ";
 747        Ok(sqlx::query_as(query)
 748            .bind(org_id.0)
 749            .fetch_all(&self.pool)
 750            .await?)
 751    }
 752
 753    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 754        let query = "
 755            SELECT
 756                channels.*
 757            FROM
 758                channel_memberships, channels
 759            WHERE
 760                channel_memberships.user_id = $1 AND
 761                channel_memberships.channel_id = channels.id
 762        ";
 763        Ok(sqlx::query_as(query)
 764            .bind(user_id.0)
 765            .fetch_all(&self.pool)
 766            .await?)
 767    }
 768
 769    async fn can_user_access_channel(
 770        &self,
 771        user_id: UserId,
 772        channel_id: ChannelId,
 773    ) -> Result<bool> {
 774        let query = "
 775            SELECT id
 776            FROM channel_memberships
 777            WHERE user_id = $1 AND channel_id = $2
 778            LIMIT 1
 779        ";
 780        Ok(sqlx::query_scalar::<_, i32>(query)
 781            .bind(user_id.0)
 782            .bind(channel_id.0)
 783            .fetch_optional(&self.pool)
 784            .await
 785            .map(|e| e.is_some())?)
 786    }
 787
 788    #[cfg(any(test, feature = "seed-support"))]
 789    async fn add_channel_member(
 790        &self,
 791        channel_id: ChannelId,
 792        user_id: UserId,
 793        is_admin: bool,
 794    ) -> Result<()> {
 795        let query = "
 796            INSERT INTO channel_memberships (channel_id, user_id, admin)
 797            VALUES ($1, $2, $3)
 798            ON CONFLICT DO NOTHING
 799        ";
 800        Ok(sqlx::query(query)
 801            .bind(channel_id.0)
 802            .bind(user_id.0)
 803            .bind(is_admin)
 804            .execute(&self.pool)
 805            .await
 806            .map(drop)?)
 807    }
 808
 809    // messages
 810
 811    async fn create_channel_message(
 812        &self,
 813        channel_id: ChannelId,
 814        sender_id: UserId,
 815        body: &str,
 816        timestamp: OffsetDateTime,
 817        nonce: u128,
 818    ) -> Result<MessageId> {
 819        let query = "
 820            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 821            VALUES ($1, $2, $3, $4, $5)
 822            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 823            RETURNING id
 824        ";
 825        Ok(sqlx::query_scalar(query)
 826            .bind(channel_id.0)
 827            .bind(sender_id.0)
 828            .bind(body)
 829            .bind(timestamp)
 830            .bind(Uuid::from_u128(nonce))
 831            .fetch_one(&self.pool)
 832            .await
 833            .map(MessageId)?)
 834    }
 835
 836    async fn get_channel_messages(
 837        &self,
 838        channel_id: ChannelId,
 839        count: usize,
 840        before_id: Option<MessageId>,
 841    ) -> Result<Vec<ChannelMessage>> {
 842        let query = r#"
 843            SELECT * FROM (
 844                SELECT
 845                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 846                FROM
 847                    channel_messages
 848                WHERE
 849                    channel_id = $1 AND
 850                    id < $2
 851                ORDER BY id DESC
 852                LIMIT $3
 853            ) as recent_messages
 854            ORDER BY id ASC
 855        "#;
 856        Ok(sqlx::query_as(query)
 857            .bind(channel_id.0)
 858            .bind(before_id.unwrap_or(MessageId::MAX))
 859            .bind(count as i64)
 860            .fetch_all(&self.pool)
 861            .await?)
 862    }
 863
 864    #[cfg(test)]
 865    async fn teardown(&self, url: &str) {
 866        use util::ResultExt;
 867
 868        let query = "
 869            SELECT pg_terminate_backend(pg_stat_activity.pid)
 870            FROM pg_stat_activity
 871            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 872        ";
 873        sqlx::query(query).execute(&self.pool).await.log_err();
 874        self.pool.close().await;
 875        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 876            .await
 877            .log_err();
 878    }
 879
 880    #[cfg(test)]
 881    fn as_fake(&self) -> Option<&tests::FakeDb> {
 882        None
 883    }
 884}
 885
 886macro_rules! id_type {
 887    ($name:ident) => {
 888        #[derive(
 889            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
 890        )]
 891        #[sqlx(transparent)]
 892        #[serde(transparent)]
 893        pub struct $name(pub i32);
 894
 895        impl $name {
 896            #[allow(unused)]
 897            pub const MAX: Self = Self(i32::MAX);
 898
 899            #[allow(unused)]
 900            pub fn from_proto(value: u64) -> Self {
 901                Self(value as i32)
 902            }
 903
 904            #[allow(unused)]
 905            pub fn to_proto(&self) -> u64 {
 906                self.0 as u64
 907            }
 908        }
 909
 910        impl std::fmt::Display for $name {
 911            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 912                self.0.fmt(f)
 913            }
 914        }
 915    };
 916}
 917
// Primary key of the `users` table.
id_type!(UserId);
/// A row of the `users` table, decoded by column name via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Written by `create_user` and `set_user_is_admin`.
    pub admin: bool,
    /// Code others can redeem; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Invites left to hand out; decremented by `redeem_invite_code`.
    pub invite_count: i32,
    /// Written by `set_user_connected_once`.
    pub connected_once: bool,
}
 929
// `OrgId`: `i32` newtype id for `Org` rows (see `id_type!`).
id_type!(OrgId);
/// An organization; `FromRow` lets it be decoded directly from a database row.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
 937
// `ChannelId`: `i32` newtype id for `Channel` rows (see `id_type!`).
id_type!(ChannelId);
/// A channel that messages are posted to; decodable from a database row via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner; which table it refers to depends on `owner_is_user`.
    pub owner_id: i32,
    /// `true` when `owner_id` is a user id. Presumably it is an org id otherwise
    /// (channels are created via `create_org_channel`) — confirm against the schema.
    pub owner_is_user: bool,
}
 946
// `MessageId`: `i32` newtype id for `ChannelMessage` rows (see `id_type!`).
id_type!(MessageId);
/// A message posted to a channel; decodable from a database row via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Deduplication key: the nonce tests show that creating a message with a nonce
    /// already used in the channel returns the existing message's id.
    pub nonce: Uuid,
}
 957
/// One entry in a user's contact list, as returned by `Db::get_contacts`.
///
/// In this file's tests a user's own id also appears in their list, as `Accepted`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutual contact. In the tests `should_notify` is `true` on the requester's
    /// side after acceptance (and on the inviter's side after an invite code is
    /// redeemed) until it is cleared with `dismiss_contact_notification`.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request `user_id` sent to this user that has not been answered yet.
    /// `should_notify` starts `true` and is cleared by `dismiss_contact_notification`.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
 972
 973impl Contact {
 974    pub fn user_id(&self) -> UserId {
 975        match self {
 976            Contact::Accepted { user_id, .. } => *user_id,
 977            Contact::Outgoing { user_id } => *user_id,
 978            Contact::Incoming { user_id, .. } => *user_id,
 979        }
 980    }
 981}
 982
/// A contact request received by some user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
 988
 989fn fuzzy_like_string(string: &str) -> String {
 990    let mut result = String::with_capacity(string.len() * 2 + 1);
 991    for c in string.chars() {
 992        if c.is_alphanumeric() {
 993            result.push('%');
 994            result.push(c);
 995        }
 996    }
 997    result.push('%');
 998    result
 999}
1000
1001#[cfg(test)]
1002pub mod tests {
1003    use super::*;
1004    use anyhow::anyhow;
1005    use collections::BTreeMap;
1006    use gpui::executor::Background;
1007    use lazy_static::lazy_static;
1008    use parking_lot::Mutex;
1009    use rand::prelude::*;
1010    use sqlx::{
1011        migrate::{MigrateDatabase, Migrator},
1012        Postgres,
1013    };
1014    use std::{path::Path, sync::Arc};
1015    use util::post_inc;
1016
1017    #[tokio::test(flavor = "multi_thread")]
1018    async fn test_get_users_by_ids() {
1019        for test_db in [
1020            TestDb::postgres().await,
1021            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1022        ] {
1023            let db = test_db.db();
1024
1025            let user = db.create_user("user", None, false).await.unwrap();
1026            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1027            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1028            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1029
1030            assert_eq!(
1031                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1032                    .await
1033                    .unwrap(),
1034                vec![
1035                    User {
1036                        id: user,
1037                        github_login: "user".to_string(),
1038                        admin: false,
1039                        ..Default::default()
1040                    },
1041                    User {
1042                        id: friend1,
1043                        github_login: "friend-1".to_string(),
1044                        admin: false,
1045                        ..Default::default()
1046                    },
1047                    User {
1048                        id: friend2,
1049                        github_login: "friend-2".to_string(),
1050                        admin: false,
1051                        ..Default::default()
1052                    },
1053                    User {
1054                        id: friend3,
1055                        github_login: "friend-3".to_string(),
1056                        admin: false,
1057                        ..Default::default()
1058                    }
1059                ]
1060            );
1061        }
1062    }
1063
    // Postgres only: `FakeDb::create_users` is `unimplemented!()`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        // The `TestDb` binding is shadowed, not dropped, so the database stays
        // alive (and is torn down) only at the end of the test.
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        // Returned ids line up with the input order.
        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The pre-existing "user2" keeps its original id.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        // The newly created user gets a code distinct from all earlier ones.
        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1127
1128    #[tokio::test(flavor = "multi_thread")]
1129    async fn test_recent_channel_messages() {
1130        for test_db in [
1131            TestDb::postgres().await,
1132            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1133        ] {
1134            let db = test_db.db();
1135            let user = db.create_user("user", None, false).await.unwrap();
1136            let org = db.create_org("org", "org").await.unwrap();
1137            let channel = db.create_org_channel(org, "channel").await.unwrap();
1138            for i in 0..10 {
1139                db.create_channel_message(
1140                    channel,
1141                    user,
1142                    &i.to_string(),
1143                    OffsetDateTime::now_utc(),
1144                    i,
1145                )
1146                .await
1147                .unwrap();
1148            }
1149
1150            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1151            assert_eq!(
1152                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1153                ["5", "6", "7", "8", "9"]
1154            );
1155
1156            let prev_messages = db
1157                .get_channel_messages(channel, 4, Some(messages[0].id))
1158                .await
1159                .unwrap();
1160            assert_eq!(
1161                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1162                ["1", "2", "3", "4"]
1163            );
1164        }
1165    }
1166
1167    #[tokio::test(flavor = "multi_thread")]
1168    async fn test_channel_message_nonces() {
1169        for test_db in [
1170            TestDb::postgres().await,
1171            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1172        ] {
1173            let db = test_db.db();
1174            let user = db.create_user("user", None, false).await.unwrap();
1175            let org = db.create_org("org", "org").await.unwrap();
1176            let channel = db.create_org_channel(org, "channel").await.unwrap();
1177
1178            let msg1_id = db
1179                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1180                .await
1181                .unwrap();
1182            let msg2_id = db
1183                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1184                .await
1185                .unwrap();
1186            let msg3_id = db
1187                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1188                .await
1189                .unwrap();
1190            let msg4_id = db
1191                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1192                .await
1193                .unwrap();
1194
1195            assert_ne!(msg1_id, msg2_id);
1196            assert_eq!(msg1_id, msg3_id);
1197            assert_eq!(msg2_id, msg4_id);
1198        }
1199    }
1200
    // Postgres only. Every call passes 3 as the last argument; the assertions show
    // that at most three hashes are kept, newest first, with the oldest evicted.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", None, false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // A fourth hash pushes out the oldest ("h1").
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
1232
1233    #[test]
1234    fn test_fuzzy_like_string() {
1235        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1236        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1237        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1238    }
1239
1240    #[tokio::test(flavor = "multi_thread")]
1241    async fn test_fuzzy_search_users() {
1242        let test_db = TestDb::postgres().await;
1243        let db = test_db.db();
1244        for github_login in [
1245            "California",
1246            "colorado",
1247            "oregon",
1248            "washington",
1249            "florida",
1250            "delaware",
1251            "rhode-island",
1252        ] {
1253            db.create_user(github_login, None, false).await.unwrap();
1254        }
1255
1256        assert_eq!(
1257            fuzzy_search_user_names(db, "clr").await,
1258            &["colorado", "California"]
1259        );
1260        assert_eq!(
1261            fuzzy_search_user_names(db, "ro").await,
1262            &["rhode-island", "colorado", "oregon"],
1263        );
1264
1265        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1266            db.fuzzy_search_users(query, 10)
1267                .await
1268                .unwrap()
1269                .into_iter()
1270                .map(|user| user.github_login)
1271                .collect::<Vec<_>>()
1272        }
1273    }
1274
    // Walks one stateful contact lifecycle (request, dismiss, accept, mutual request,
    // decline) against both Postgres and the in-memory fake; step order matters.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // A new user has no contacts other than themselves (listed as accepted).
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. User 1 (the requester) has no notification
            // to dismiss, so their attempt fails.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact, and the
            // requester is notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other contact requests (one after the other, each
            // requesting the other) and see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1494
    // Postgres only. Exercises the invite-code lifecycle in order: assignment,
    // redemption (which creates accepted contacts), exhaustion, refill, and rejection
    // of existing users.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt does
        // not consume a use of the code.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1647
1648    pub struct TestDb {
1649        pub db: Option<Arc<dyn Db>>,
1650        pub url: String,
1651    }
1652
1653    impl TestDb {
1654        pub async fn postgres() -> Self {
1655            lazy_static! {
1656                static ref LOCK: Mutex<()> = Mutex::new(());
1657            }
1658
1659            let _guard = LOCK.lock();
1660            let mut rng = StdRng::from_entropy();
1661            let name = format!("zed-test-{}", rng.gen::<u128>());
1662            let url = format!("postgres://postgres@localhost/{}", name);
1663            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1664            Postgres::create_database(&url)
1665                .await
1666                .expect("failed to create test db");
1667            let db = PostgresDb::new(&url, 5).await.unwrap();
1668            let migrator = Migrator::new(migrations_path).await.unwrap();
1669            migrator.run(&db.pool).await.unwrap();
1670            Self {
1671                db: Some(Arc::new(db)),
1672                url,
1673            }
1674        }
1675
1676        pub fn fake(background: Arc<Background>) -> Self {
1677            Self {
1678                db: Some(Arc::new(FakeDb::new(background))),
1679                url: Default::default(),
1680            }
1681        }
1682
1683        pub fn db(&self) -> &Arc<dyn Db> {
1684            self.db.as_ref().unwrap()
1685        }
1686    }
1687
1688    impl Drop for TestDb {
1689        fn drop(&mut self) {
1690            if let Some(db) = self.db.take() {
1691                futures::executor::block_on(db.teardown(&self.url));
1692            }
1693        }
1694    }
1695
1696    pub struct FakeDb {
1697        background: Arc<Background>,
1698        pub users: Mutex<BTreeMap<UserId, User>>,
1699        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1700        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1701        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1702        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1703        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1704        pub contacts: Mutex<Vec<FakeContact>>,
1705        next_channel_message_id: Mutex<i32>,
1706        next_user_id: Mutex<i32>,
1707        next_org_id: Mutex<i32>,
1708        next_channel_id: Mutex<i32>,
1709    }
1710
    /// One row of `FakeDb`'s contacts table: a request from `requester_id` to `responder_id`.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        /// Presumably `true` once `responder_id` accepts the request — confirm in
        /// the fake `respond_to_contact_request`.
        pub accepted: bool,
        /// Presumably backs `Contact::should_notify` in `get_contacts` — confirm in
        /// the fake `Db` impl.
        pub should_notify: bool,
    }
1718
1719    impl FakeDb {
1720        pub fn new(background: Arc<Background>) -> Self {
1721            Self {
1722                background,
1723                users: Default::default(),
1724                next_user_id: Mutex::new(1),
1725                orgs: Default::default(),
1726                next_org_id: Mutex::new(1),
1727                org_memberships: Default::default(),
1728                channels: Default::default(),
1729                next_channel_id: Mutex::new(1),
1730                channel_memberships: Default::default(),
1731                channel_messages: Default::default(),
1732                next_channel_message_id: Mutex::new(1),
1733                contacts: Default::default(),
1734            }
1735        }
1736    }
1737
1738    #[async_trait]
1739    impl Db for FakeDb {
1740        async fn create_user(
1741            &self,
1742            github_login: &str,
1743            email_address: Option<&str>,
1744            admin: bool,
1745        ) -> Result<UserId> {
1746            self.background.simulate_random_delay().await;
1747
1748            let mut users = self.users.lock();
1749            if let Some(user) = users
1750                .values()
1751                .find(|user| user.github_login == github_login)
1752            {
1753                Ok(user.id)
1754            } else {
1755                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1756                users.insert(
1757                    user_id,
1758                    User {
1759                        id: user_id,
1760                        github_login: github_login.to_string(),
1761                        email_address: email_address.map(str::to_string),
1762                        admin,
1763                        invite_code: None,
1764                        invite_count: 0,
1765                        connected_once: false,
1766                    },
1767                );
1768                Ok(user_id)
1769            }
1770        }
1771
1772        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1773            unimplemented!()
1774        }
1775
1776        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
1777            unimplemented!()
1778        }
1779
1780        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1781            unimplemented!()
1782        }
1783
1784        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1785            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1786        }
1787
1788        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1789            self.background.simulate_random_delay().await;
1790            let users = self.users.lock();
1791            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1792        }
1793
1794        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1795            Ok(self
1796                .users
1797                .lock()
1798                .values()
1799                .find(|user| user.github_login == github_login)
1800                .cloned())
1801        }
1802
1803        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1804            unimplemented!()
1805        }
1806
1807        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1808            self.background.simulate_random_delay().await;
1809            let mut users = self.users.lock();
1810            let mut user = users
1811                .get_mut(&id)
1812                .ok_or_else(|| anyhow!("user not found"))?;
1813            user.connected_once = connected_once;
1814            Ok(())
1815        }
1816
1817        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1818            unimplemented!()
1819        }
1820
1821        // invite codes
1822
1823        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1824            unimplemented!()
1825        }
1826
1827        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1828            Ok(None)
1829        }
1830
1831        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1832            unimplemented!()
1833        }
1834
1835        async fn redeem_invite_code(
1836            &self,
1837            _code: &str,
1838            _login: &str,
1839            _email_address: Option<&str>,
1840        ) -> Result<UserId> {
1841            unimplemented!()
1842        }
1843
1844        // contacts
1845
1846        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1847            self.background.simulate_random_delay().await;
1848            let mut contacts = vec![Contact::Accepted {
1849                user_id: id,
1850                should_notify: false,
1851            }];
1852
1853            for contact in self.contacts.lock().iter() {
1854                if contact.requester_id == id {
1855                    if contact.accepted {
1856                        contacts.push(Contact::Accepted {
1857                            user_id: contact.responder_id,
1858                            should_notify: contact.should_notify,
1859                        });
1860                    } else {
1861                        contacts.push(Contact::Outgoing {
1862                            user_id: contact.responder_id,
1863                        });
1864                    }
1865                } else if contact.responder_id == id {
1866                    if contact.accepted {
1867                        contacts.push(Contact::Accepted {
1868                            user_id: contact.requester_id,
1869                            should_notify: false,
1870                        });
1871                    } else {
1872                        contacts.push(Contact::Incoming {
1873                            user_id: contact.requester_id,
1874                            should_notify: contact.should_notify,
1875                        });
1876                    }
1877                }
1878            }
1879
1880            contacts.sort_unstable_by_key(|contact| contact.user_id());
1881            Ok(contacts)
1882        }
1883
1884        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1885            self.background.simulate_random_delay().await;
1886            Ok(self.contacts.lock().iter().any(|contact| {
1887                contact.accepted
1888                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1889                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1890            }))
1891        }
1892
1893        async fn send_contact_request(
1894            &self,
1895            requester_id: UserId,
1896            responder_id: UserId,
1897        ) -> Result<()> {
1898            let mut contacts = self.contacts.lock();
1899            for contact in contacts.iter_mut() {
1900                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1901                    if contact.accepted {
1902                        Err(anyhow!("contact already exists"))?;
1903                    } else {
1904                        Err(anyhow!("contact already requested"))?;
1905                    }
1906                }
1907                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1908                    if contact.accepted {
1909                        Err(anyhow!("contact already exists"))?;
1910                    } else {
1911                        contact.accepted = true;
1912                        contact.should_notify = false;
1913                        return Ok(());
1914                    }
1915                }
1916            }
1917            contacts.push(FakeContact {
1918                requester_id,
1919                responder_id,
1920                accepted: false,
1921                should_notify: true,
1922            });
1923            Ok(())
1924        }
1925
1926        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1927            self.contacts.lock().retain(|contact| {
1928                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1929            });
1930            Ok(())
1931        }
1932
1933        async fn dismiss_contact_notification(
1934            &self,
1935            user_id: UserId,
1936            contact_user_id: UserId,
1937        ) -> Result<()> {
1938            let mut contacts = self.contacts.lock();
1939            for contact in contacts.iter_mut() {
1940                if contact.requester_id == contact_user_id
1941                    && contact.responder_id == user_id
1942                    && !contact.accepted
1943                {
1944                    contact.should_notify = false;
1945                    return Ok(());
1946                }
1947                if contact.requester_id == user_id
1948                    && contact.responder_id == contact_user_id
1949                    && contact.accepted
1950                {
1951                    contact.should_notify = false;
1952                    return Ok(());
1953                }
1954            }
1955            Err(anyhow!("no such notification"))?
1956        }
1957
1958        async fn respond_to_contact_request(
1959            &self,
1960            responder_id: UserId,
1961            requester_id: UserId,
1962            accept: bool,
1963        ) -> Result<()> {
1964            let mut contacts = self.contacts.lock();
1965            for (ix, contact) in contacts.iter_mut().enumerate() {
1966                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1967                    if contact.accepted {
1968                        Err(anyhow!("contact already confirmed"))?;
1969                    }
1970                    if accept {
1971                        contact.accepted = true;
1972                        contact.should_notify = true;
1973                    } else {
1974                        contacts.remove(ix);
1975                    }
1976                    return Ok(());
1977                }
1978            }
1979            Err(anyhow!("no such contact request"))?
1980        }
1981
1982        async fn create_access_token_hash(
1983            &self,
1984            _user_id: UserId,
1985            _access_token_hash: &str,
1986            _max_access_token_count: usize,
1987        ) -> Result<()> {
1988            unimplemented!()
1989        }
1990
1991        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1992            unimplemented!()
1993        }
1994
1995        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1996            unimplemented!()
1997        }
1998
1999        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2000            self.background.simulate_random_delay().await;
2001            let mut orgs = self.orgs.lock();
2002            if orgs.values().any(|org| org.slug == slug) {
2003                Err(anyhow!("org already exists"))?
2004            } else {
2005                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2006                orgs.insert(
2007                    org_id,
2008                    Org {
2009                        id: org_id,
2010                        name: name.to_string(),
2011                        slug: slug.to_string(),
2012                    },
2013                );
2014                Ok(org_id)
2015            }
2016        }
2017
2018        async fn add_org_member(
2019            &self,
2020            org_id: OrgId,
2021            user_id: UserId,
2022            is_admin: bool,
2023        ) -> Result<()> {
2024            self.background.simulate_random_delay().await;
2025            if !self.orgs.lock().contains_key(&org_id) {
2026                Err(anyhow!("org does not exist"))?;
2027            }
2028            if !self.users.lock().contains_key(&user_id) {
2029                Err(anyhow!("user does not exist"))?;
2030            }
2031
2032            self.org_memberships
2033                .lock()
2034                .entry((org_id, user_id))
2035                .or_insert(is_admin);
2036            Ok(())
2037        }
2038
2039        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2040            self.background.simulate_random_delay().await;
2041            if !self.orgs.lock().contains_key(&org_id) {
2042                Err(anyhow!("org does not exist"))?;
2043            }
2044
2045            let mut channels = self.channels.lock();
2046            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2047            channels.insert(
2048                channel_id,
2049                Channel {
2050                    id: channel_id,
2051                    name: name.to_string(),
2052                    owner_id: org_id.0,
2053                    owner_is_user: false,
2054                },
2055            );
2056            Ok(channel_id)
2057        }
2058
2059        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2060            self.background.simulate_random_delay().await;
2061            Ok(self
2062                .channels
2063                .lock()
2064                .values()
2065                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2066                .cloned()
2067                .collect())
2068        }
2069
2070        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2071            self.background.simulate_random_delay().await;
2072            let channels = self.channels.lock();
2073            let memberships = self.channel_memberships.lock();
2074            Ok(channels
2075                .values()
2076                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2077                .cloned()
2078                .collect())
2079        }
2080
2081        async fn can_user_access_channel(
2082            &self,
2083            user_id: UserId,
2084            channel_id: ChannelId,
2085        ) -> Result<bool> {
2086            self.background.simulate_random_delay().await;
2087            Ok(self
2088                .channel_memberships
2089                .lock()
2090                .contains_key(&(channel_id, user_id)))
2091        }
2092
2093        async fn add_channel_member(
2094            &self,
2095            channel_id: ChannelId,
2096            user_id: UserId,
2097            is_admin: bool,
2098        ) -> Result<()> {
2099            self.background.simulate_random_delay().await;
2100            if !self.channels.lock().contains_key(&channel_id) {
2101                Err(anyhow!("channel does not exist"))?;
2102            }
2103            if !self.users.lock().contains_key(&user_id) {
2104                Err(anyhow!("user does not exist"))?;
2105            }
2106
2107            self.channel_memberships
2108                .lock()
2109                .entry((channel_id, user_id))
2110                .or_insert(is_admin);
2111            Ok(())
2112        }
2113
2114        async fn create_channel_message(
2115            &self,
2116            channel_id: ChannelId,
2117            sender_id: UserId,
2118            body: &str,
2119            timestamp: OffsetDateTime,
2120            nonce: u128,
2121        ) -> Result<MessageId> {
2122            self.background.simulate_random_delay().await;
2123            if !self.channels.lock().contains_key(&channel_id) {
2124                Err(anyhow!("channel does not exist"))?;
2125            }
2126            if !self.users.lock().contains_key(&sender_id) {
2127                Err(anyhow!("user does not exist"))?;
2128            }
2129
2130            let mut messages = self.channel_messages.lock();
2131            if let Some(message) = messages
2132                .values()
2133                .find(|message| message.nonce.as_u128() == nonce)
2134            {
2135                Ok(message.id)
2136            } else {
2137                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2138                messages.insert(
2139                    message_id,
2140                    ChannelMessage {
2141                        id: message_id,
2142                        channel_id,
2143                        sender_id,
2144                        body: body.to_string(),
2145                        sent_at: timestamp,
2146                        nonce: Uuid::from_u128(nonce),
2147                    },
2148                );
2149                Ok(message_id)
2150            }
2151        }
2152
2153        async fn get_channel_messages(
2154            &self,
2155            channel_id: ChannelId,
2156            count: usize,
2157            before_id: Option<MessageId>,
2158        ) -> Result<Vec<ChannelMessage>> {
2159            let mut messages = self
2160                .channel_messages
2161                .lock()
2162                .values()
2163                .rev()
2164                .filter(|message| {
2165                    message.channel_id == channel_id
2166                        && message.id < before_id.unwrap_or(MessageId::MAX)
2167                })
2168                .take(count)
2169                .cloned()
2170                .collect::<Vec<_>>();
2171            messages.sort_unstable_by_key(|message| message.id);
2172            Ok(messages)
2173        }
2174
2175        async fn teardown(&self, _: &str) {}
2176
2177        #[cfg(test)]
2178        fn as_fake(&self) -> Option<&FakeDb> {
2179            Some(self)
2180        }
2181    }
2182}