db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use futures::StreamExt;
   6use nanoid::nanoid;
   7use serde::Serialize;
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  10use time::OffsetDateTime;
  11
  12#[async_trait]
  13pub trait Db: Send + Sync {
  14    async fn create_user(
  15        &self,
  16        github_login: &str,
  17        email_address: Option<&str>,
  18        admin: bool,
  19    ) -> Result<UserId>;
  20    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  21    async fn get_all_users(&self) -> Result<Vec<User>>;
  22    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  23    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  24    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  25    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  26    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  27    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  28    async fn destroy_user(&self, id: UserId) -> Result<()>;
  29
  30    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  31    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  32    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  33    async fn redeem_invite_code(
  34        &self,
  35        code: &str,
  36        login: &str,
  37        email_address: Option<&str>,
  38    ) -> Result<UserId>;
  39
  40    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  41    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  42    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  43    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  44    async fn dismiss_contact_notification(
  45        &self,
  46        responder_id: UserId,
  47        requester_id: UserId,
  48    ) -> Result<()>;
  49    async fn respond_to_contact_request(
  50        &self,
  51        responder_id: UserId,
  52        requester_id: UserId,
  53        accept: bool,
  54    ) -> Result<()>;
  55
  56    async fn create_access_token_hash(
  57        &self,
  58        user_id: UserId,
  59        access_token_hash: &str,
  60        max_access_token_count: usize,
  61    ) -> Result<()>;
  62    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
  63    #[cfg(any(test, feature = "seed-support"))]
  64
  65    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
  66    #[cfg(any(test, feature = "seed-support"))]
  67    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
  68    #[cfg(any(test, feature = "seed-support"))]
  69    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
  70    #[cfg(any(test, feature = "seed-support"))]
  71    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
  72    #[cfg(any(test, feature = "seed-support"))]
  73
  74    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
  75    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
  76    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
  77        -> Result<bool>;
  78    #[cfg(any(test, feature = "seed-support"))]
  79    async fn add_channel_member(
  80        &self,
  81        channel_id: ChannelId,
  82        user_id: UserId,
  83        is_admin: bool,
  84    ) -> Result<()>;
  85    async fn create_channel_message(
  86        &self,
  87        channel_id: ChannelId,
  88        sender_id: UserId,
  89        body: &str,
  90        timestamp: OffsetDateTime,
  91        nonce: u128,
  92    ) -> Result<MessageId>;
  93    async fn get_channel_messages(
  94        &self,
  95        channel_id: ChannelId,
  96        count: usize,
  97        before_id: Option<MessageId>,
  98    ) -> Result<Vec<ChannelMessage>>;
  99    #[cfg(test)]
 100    async fn teardown(&self, url: &str);
 101    #[cfg(test)]
 102    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 103}
 104
/// [`Db`] implementation backed by a Postgres connection pool (created by `PostgresDb::new`).
pub struct PostgresDb {
    pool: sqlx::PgPool,
}
 108
 109impl PostgresDb {
 110    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 111        let pool = DbOptions::new()
 112            .max_connections(max_connections)
 113            .connect(&url)
 114            .await
 115            .context("failed to connect to postgres database")?;
 116        Ok(Self { pool })
 117    }
 118}
 119
 120#[async_trait]
 121impl Db for PostgresDb {
 122    // users
 123
 124    async fn create_user(
 125        &self,
 126        github_login: &str,
 127        email_address: Option<&str>,
 128        admin: bool,
 129    ) -> Result<UserId> {
 130        let query = "
 131            INSERT INTO users (github_login, email_address, admin)
 132            VALUES ($1, $2, $3)
 133            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 134            RETURNING id
 135        ";
 136        Ok(sqlx::query_scalar(query)
 137            .bind(github_login)
 138            .bind(email_address)
 139            .bind(admin)
 140            .fetch_one(&self.pool)
 141            .await
 142            .map(UserId)?)
 143    }
 144
 145    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 146        let mut query = QueryBuilder::new(
 147            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 148        );
 149        query.push_values(
 150            users,
 151            |mut query, (github_login, email_address, invite_count)| {
 152                query
 153                    .push_bind(github_login)
 154                    .push_bind(email_address)
 155                    .push_bind(false)
 156                    .push_bind(nanoid!(16))
 157                    .push_bind(invite_count as u32);
 158            },
 159        );
 160        query.push(
 161            "
 162            ON CONFLICT (github_login) DO UPDATE SET
 163                github_login = excluded.github_login,
 164                invite_count = excluded.invite_count,
 165                invite_code = CASE WHEN users.invite_code IS NULL
 166                                   THEN excluded.invite_code
 167                                   ELSE users.invite_code
 168                              END
 169            RETURNING id
 170            ",
 171        );
 172
 173        let rows = query.build().fetch_all(&self.pool).await?;
 174        Ok(rows
 175            .into_iter()
 176            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 177            .collect())
 178    }
 179
 180    async fn get_all_users(&self) -> Result<Vec<User>> {
 181        let query = "SELECT * FROM users ORDER BY github_login ASC";
 182        Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
 183    }
 184
 185    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 186        let like_string = fuzzy_like_string(name_query);
 187        let query = "
 188            SELECT users.*
 189            FROM users
 190            WHERE github_login ILIKE $1
 191            ORDER BY github_login <-> $2
 192            LIMIT $3
 193        ";
 194        Ok(sqlx::query_as(query)
 195            .bind(like_string)
 196            .bind(name_query)
 197            .bind(limit)
 198            .fetch_all(&self.pool)
 199            .await?)
 200    }
 201
 202    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 203        let users = self.get_users_by_ids(vec![id]).await?;
 204        Ok(users.into_iter().next())
 205    }
 206
 207    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 208        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 209        let query = "
 210            SELECT users.*
 211            FROM users
 212            WHERE users.id = ANY ($1)
 213        ";
 214        Ok(sqlx::query_as(query)
 215            .bind(&ids)
 216            .fetch_all(&self.pool)
 217            .await?)
 218    }
 219
 220    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 221        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 222        Ok(sqlx::query_as(query)
 223            .bind(github_login)
 224            .fetch_optional(&self.pool)
 225            .await?)
 226    }
 227
 228    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 229        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 230        Ok(sqlx::query(query)
 231            .bind(is_admin)
 232            .bind(id.0)
 233            .execute(&self.pool)
 234            .await
 235            .map(drop)?)
 236    }
 237
 238    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 239        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 240        Ok(sqlx::query(query)
 241            .bind(connected_once)
 242            .bind(id.0)
 243            .execute(&self.pool)
 244            .await
 245            .map(drop)?)
 246    }
 247
 248    async fn destroy_user(&self, id: UserId) -> Result<()> {
 249        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 250        sqlx::query(query)
 251            .bind(id.0)
 252            .execute(&self.pool)
 253            .await
 254            .map(drop)?;
 255        let query = "DELETE FROM users WHERE id = $1;";
 256        Ok(sqlx::query(query)
 257            .bind(id.0)
 258            .execute(&self.pool)
 259            .await
 260            .map(drop)?)
 261    }
 262
 263    // invite codes
 264
 265    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 266        let mut tx = self.pool.begin().await?;
 267        if count > 0 {
 268            sqlx::query(
 269                "
 270                UPDATE users
 271                SET invite_code = $1
 272                WHERE id = $2 AND invite_code IS NULL
 273            ",
 274            )
 275            .bind(nanoid!(16))
 276            .bind(id)
 277            .execute(&mut tx)
 278            .await?;
 279        }
 280
 281        sqlx::query(
 282            "
 283            UPDATE users
 284            SET invite_count = $1
 285            WHERE id = $2
 286            ",
 287        )
 288        .bind(count)
 289        .bind(id)
 290        .execute(&mut tx)
 291        .await?;
 292        tx.commit().await?;
 293        Ok(())
 294    }
 295
 296    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 297        let result: Option<(String, i32)> = sqlx::query_as(
 298            "
 299                SELECT invite_code, invite_count
 300                FROM users
 301                WHERE id = $1 AND invite_code IS NOT NULL 
 302            ",
 303        )
 304        .bind(id)
 305        .fetch_optional(&self.pool)
 306        .await?;
 307        if let Some((code, count)) = result {
 308            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 309        } else {
 310            Ok(None)
 311        }
 312    }
 313
 314    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 315        sqlx::query_as(
 316            "
 317                SELECT *
 318                FROM users
 319                WHERE invite_code = $1
 320            ",
 321        )
 322        .bind(code)
 323        .fetch_optional(&self.pool)
 324        .await?
 325        .ok_or_else(|| {
 326            Error::Http(
 327                StatusCode::NOT_FOUND,
 328                "that invite code does not exist".to_string(),
 329            )
 330        })
 331    }
 332
 333    async fn redeem_invite_code(
 334        &self,
 335        code: &str,
 336        login: &str,
 337        email_address: Option<&str>,
 338    ) -> Result<UserId> {
 339        let mut tx = self.pool.begin().await?;
 340
 341        let inviter_id: Option<UserId> = sqlx::query_scalar(
 342            "
 343                UPDATE users
 344                SET invite_count = invite_count - 1
 345                WHERE
 346                    invite_code = $1 AND
 347                    invite_count > 0
 348                RETURNING id
 349            ",
 350        )
 351        .bind(code)
 352        .fetch_optional(&mut tx)
 353        .await?;
 354
 355        let inviter_id = match inviter_id {
 356            Some(inviter_id) => inviter_id,
 357            None => {
 358                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 359                    .bind(code)
 360                    .fetch_optional(&mut tx)
 361                    .await?
 362                    .is_some()
 363                {
 364                    Err(Error::Http(
 365                        StatusCode::UNAUTHORIZED,
 366                        "no invites remaining".to_string(),
 367                    ))?
 368                } else {
 369                    Err(Error::Http(
 370                        StatusCode::NOT_FOUND,
 371                        "invite code not found".to_string(),
 372                    ))?
 373                }
 374            }
 375        };
 376
 377        let invitee_id = sqlx::query_scalar(
 378            "
 379                INSERT INTO users
 380                    (github_login, email_address, admin, inviter_id)
 381                VALUES
 382                    ($1, $2, 'f', $3)
 383                RETURNING id
 384            ",
 385        )
 386        .bind(login)
 387        .bind(email_address)
 388        .bind(inviter_id)
 389        .fetch_one(&mut tx)
 390        .await
 391        .map(UserId)?;
 392
 393        sqlx::query(
 394            "
 395                INSERT INTO contacts
 396                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 397                VALUES
 398                    ($1, $2, 't', 't', 't')
 399            ",
 400        )
 401        .bind(inviter_id)
 402        .bind(invitee_id)
 403        .execute(&mut tx)
 404        .await?;
 405
 406        tx.commit().await?;
 407        Ok(invitee_id)
 408    }
 409
 410    // contacts
 411
 412    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 413        let query = "
 414            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 415            FROM contacts
 416            WHERE user_id_a = $1 OR user_id_b = $1;
 417        ";
 418
 419        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 420            .bind(user_id)
 421            .fetch(&self.pool);
 422
 423        let mut contacts = vec![Contact::Accepted {
 424            user_id,
 425            should_notify: false,
 426        }];
 427        while let Some(row) = rows.next().await {
 428            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 429
 430            if user_id_a == user_id {
 431                if accepted {
 432                    contacts.push(Contact::Accepted {
 433                        user_id: user_id_b,
 434                        should_notify: should_notify && a_to_b,
 435                    });
 436                } else if a_to_b {
 437                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 438                } else {
 439                    contacts.push(Contact::Incoming {
 440                        user_id: user_id_b,
 441                        should_notify,
 442                    });
 443                }
 444            } else {
 445                if accepted {
 446                    contacts.push(Contact::Accepted {
 447                        user_id: user_id_a,
 448                        should_notify: should_notify && !a_to_b,
 449                    });
 450                } else if a_to_b {
 451                    contacts.push(Contact::Incoming {
 452                        user_id: user_id_a,
 453                        should_notify,
 454                    });
 455                } else {
 456                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 457                }
 458            }
 459        }
 460
 461        contacts.sort_unstable_by_key(|contact| contact.user_id());
 462
 463        Ok(contacts)
 464    }
 465
 466    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 467        let (id_a, id_b) = if user_id_1 < user_id_2 {
 468            (user_id_1, user_id_2)
 469        } else {
 470            (user_id_2, user_id_1)
 471        };
 472
 473        let query = "
 474            SELECT 1 FROM contacts
 475            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 476            LIMIT 1
 477        ";
 478        Ok(sqlx::query_scalar::<_, i32>(query)
 479            .bind(id_a.0)
 480            .bind(id_b.0)
 481            .fetch_optional(&self.pool)
 482            .await?
 483            .is_some())
 484    }
 485
 486    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 487        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 488            (sender_id, receiver_id, true)
 489        } else {
 490            (receiver_id, sender_id, false)
 491        };
 492        let query = "
 493            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 494            VALUES ($1, $2, $3, 'f', 't')
 495            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 496            SET
 497                accepted = 't',
 498                should_notify = 'f'
 499            WHERE
 500                NOT contacts.accepted AND
 501                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 502                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 503        ";
 504        let result = sqlx::query(query)
 505            .bind(id_a.0)
 506            .bind(id_b.0)
 507            .bind(a_to_b)
 508            .execute(&self.pool)
 509            .await?;
 510
 511        if result.rows_affected() == 1 {
 512            Ok(())
 513        } else {
 514            Err(anyhow!("contact already requested"))?
 515        }
 516    }
 517
 518    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 519        let (id_a, id_b) = if responder_id < requester_id {
 520            (responder_id, requester_id)
 521        } else {
 522            (requester_id, responder_id)
 523        };
 524        let query = "
 525            DELETE FROM contacts
 526            WHERE user_id_a = $1 AND user_id_b = $2;
 527        ";
 528        let result = sqlx::query(query)
 529            .bind(id_a.0)
 530            .bind(id_b.0)
 531            .execute(&self.pool)
 532            .await?;
 533
 534        if result.rows_affected() == 1 {
 535            Ok(())
 536        } else {
 537            Err(anyhow!("no such contact"))?
 538        }
 539    }
 540
 541    async fn dismiss_contact_notification(
 542        &self,
 543        user_id: UserId,
 544        contact_user_id: UserId,
 545    ) -> Result<()> {
 546        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 547            (user_id, contact_user_id, true)
 548        } else {
 549            (contact_user_id, user_id, false)
 550        };
 551
 552        let query = "
 553            UPDATE contacts
 554            SET should_notify = 'f'
 555            WHERE
 556                user_id_a = $1 AND user_id_b = $2 AND
 557                (
 558                    (a_to_b = $3 AND accepted) OR
 559                    (a_to_b != $3 AND NOT accepted)
 560                );
 561        ";
 562
 563        let result = sqlx::query(query)
 564            .bind(id_a.0)
 565            .bind(id_b.0)
 566            .bind(a_to_b)
 567            .execute(&self.pool)
 568            .await?;
 569
 570        if result.rows_affected() == 0 {
 571            Err(anyhow!("no such contact request"))?;
 572        }
 573
 574        Ok(())
 575    }
 576
 577    async fn respond_to_contact_request(
 578        &self,
 579        responder_id: UserId,
 580        requester_id: UserId,
 581        accept: bool,
 582    ) -> Result<()> {
 583        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 584            (responder_id, requester_id, false)
 585        } else {
 586            (requester_id, responder_id, true)
 587        };
 588        let result = if accept {
 589            let query = "
 590                UPDATE contacts
 591                SET accepted = 't', should_notify = 't'
 592                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 593            ";
 594            sqlx::query(query)
 595                .bind(id_a.0)
 596                .bind(id_b.0)
 597                .bind(a_to_b)
 598                .execute(&self.pool)
 599                .await?
 600        } else {
 601            let query = "
 602                DELETE FROM contacts
 603                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 604            ";
 605            sqlx::query(query)
 606                .bind(id_a.0)
 607                .bind(id_b.0)
 608                .bind(a_to_b)
 609                .execute(&self.pool)
 610                .await?
 611        };
 612        if result.rows_affected() == 1 {
 613            Ok(())
 614        } else {
 615            Err(anyhow!("no such contact request"))?
 616        }
 617    }
 618
 619    // access tokens
 620
 621    async fn create_access_token_hash(
 622        &self,
 623        user_id: UserId,
 624        access_token_hash: &str,
 625        max_access_token_count: usize,
 626    ) -> Result<()> {
 627        let insert_query = "
 628            INSERT INTO access_tokens (user_id, hash)
 629            VALUES ($1, $2);
 630        ";
 631        let cleanup_query = "
 632            DELETE FROM access_tokens
 633            WHERE id IN (
 634                SELECT id from access_tokens
 635                WHERE user_id = $1
 636                ORDER BY id DESC
 637                OFFSET $3
 638            )
 639        ";
 640
 641        let mut tx = self.pool.begin().await?;
 642        sqlx::query(insert_query)
 643            .bind(user_id.0)
 644            .bind(access_token_hash)
 645            .execute(&mut tx)
 646            .await?;
 647        sqlx::query(cleanup_query)
 648            .bind(user_id.0)
 649            .bind(access_token_hash)
 650            .bind(max_access_token_count as u32)
 651            .execute(&mut tx)
 652            .await?;
 653        Ok(tx.commit().await?)
 654    }
 655
 656    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 657        let query = "
 658            SELECT hash
 659            FROM access_tokens
 660            WHERE user_id = $1
 661            ORDER BY id DESC
 662        ";
 663        Ok(sqlx::query_scalar(query)
 664            .bind(user_id.0)
 665            .fetch_all(&self.pool)
 666            .await?)
 667    }
 668
 669    // orgs
 670
 671    #[allow(unused)] // Help rust-analyzer
 672    #[cfg(any(test, feature = "seed-support"))]
 673    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 674        let query = "
 675            SELECT *
 676            FROM orgs
 677            WHERE slug = $1
 678        ";
 679        Ok(sqlx::query_as(query)
 680            .bind(slug)
 681            .fetch_optional(&self.pool)
 682            .await?)
 683    }
 684
 685    #[cfg(any(test, feature = "seed-support"))]
 686    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
 687        let query = "
 688            INSERT INTO orgs (name, slug)
 689            VALUES ($1, $2)
 690            RETURNING id
 691        ";
 692        Ok(sqlx::query_scalar(query)
 693            .bind(name)
 694            .bind(slug)
 695            .fetch_one(&self.pool)
 696            .await
 697            .map(OrgId)?)
 698    }
 699
 700    #[cfg(any(test, feature = "seed-support"))]
 701    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
 702        let query = "
 703            INSERT INTO org_memberships (org_id, user_id, admin)
 704            VALUES ($1, $2, $3)
 705            ON CONFLICT DO NOTHING
 706        ";
 707        Ok(sqlx::query(query)
 708            .bind(org_id.0)
 709            .bind(user_id.0)
 710            .bind(is_admin)
 711            .execute(&self.pool)
 712            .await
 713            .map(drop)?)
 714    }
 715
 716    // channels
 717
 718    #[cfg(any(test, feature = "seed-support"))]
 719    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
 720        let query = "
 721            INSERT INTO channels (owner_id, owner_is_user, name)
 722            VALUES ($1, false, $2)
 723            RETURNING id
 724        ";
 725        Ok(sqlx::query_scalar(query)
 726            .bind(org_id.0)
 727            .bind(name)
 728            .fetch_one(&self.pool)
 729            .await
 730            .map(ChannelId)?)
 731    }
 732
 733    #[allow(unused)] // Help rust-analyzer
 734    #[cfg(any(test, feature = "seed-support"))]
 735    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
 736        let query = "
 737            SELECT *
 738            FROM channels
 739            WHERE
 740                channels.owner_is_user = false AND
 741                channels.owner_id = $1
 742        ";
 743        Ok(sqlx::query_as(query)
 744            .bind(org_id.0)
 745            .fetch_all(&self.pool)
 746            .await?)
 747    }
 748
 749    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
 750        let query = "
 751            SELECT
 752                channels.*
 753            FROM
 754                channel_memberships, channels
 755            WHERE
 756                channel_memberships.user_id = $1 AND
 757                channel_memberships.channel_id = channels.id
 758        ";
 759        Ok(sqlx::query_as(query)
 760            .bind(user_id.0)
 761            .fetch_all(&self.pool)
 762            .await?)
 763    }
 764
 765    async fn can_user_access_channel(
 766        &self,
 767        user_id: UserId,
 768        channel_id: ChannelId,
 769    ) -> Result<bool> {
 770        let query = "
 771            SELECT id
 772            FROM channel_memberships
 773            WHERE user_id = $1 AND channel_id = $2
 774            LIMIT 1
 775        ";
 776        Ok(sqlx::query_scalar::<_, i32>(query)
 777            .bind(user_id.0)
 778            .bind(channel_id.0)
 779            .fetch_optional(&self.pool)
 780            .await
 781            .map(|e| e.is_some())?)
 782    }
 783
 784    #[cfg(any(test, feature = "seed-support"))]
 785    async fn add_channel_member(
 786        &self,
 787        channel_id: ChannelId,
 788        user_id: UserId,
 789        is_admin: bool,
 790    ) -> Result<()> {
 791        let query = "
 792            INSERT INTO channel_memberships (channel_id, user_id, admin)
 793            VALUES ($1, $2, $3)
 794            ON CONFLICT DO NOTHING
 795        ";
 796        Ok(sqlx::query(query)
 797            .bind(channel_id.0)
 798            .bind(user_id.0)
 799            .bind(is_admin)
 800            .execute(&self.pool)
 801            .await
 802            .map(drop)?)
 803    }
 804
 805    // messages
 806
 807    async fn create_channel_message(
 808        &self,
 809        channel_id: ChannelId,
 810        sender_id: UserId,
 811        body: &str,
 812        timestamp: OffsetDateTime,
 813        nonce: u128,
 814    ) -> Result<MessageId> {
 815        let query = "
 816            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
 817            VALUES ($1, $2, $3, $4, $5)
 818            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
 819            RETURNING id
 820        ";
 821        Ok(sqlx::query_scalar(query)
 822            .bind(channel_id.0)
 823            .bind(sender_id.0)
 824            .bind(body)
 825            .bind(timestamp)
 826            .bind(Uuid::from_u128(nonce))
 827            .fetch_one(&self.pool)
 828            .await
 829            .map(MessageId)?)
 830    }
 831
 832    async fn get_channel_messages(
 833        &self,
 834        channel_id: ChannelId,
 835        count: usize,
 836        before_id: Option<MessageId>,
 837    ) -> Result<Vec<ChannelMessage>> {
 838        let query = r#"
 839            SELECT * FROM (
 840                SELECT
 841                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
 842                FROM
 843                    channel_messages
 844                WHERE
 845                    channel_id = $1 AND
 846                    id < $2
 847                ORDER BY id DESC
 848                LIMIT $3
 849            ) as recent_messages
 850            ORDER BY id ASC
 851        "#;
 852        Ok(sqlx::query_as(query)
 853            .bind(channel_id.0)
 854            .bind(before_id.unwrap_or(MessageId::MAX))
 855            .bind(count as i64)
 856            .fetch_all(&self.pool)
 857            .await?)
 858    }
 859
 860    #[cfg(test)]
 861    async fn teardown(&self, url: &str) {
 862        use util::ResultExt;
 863
 864        let query = "
 865            SELECT pg_terminate_backend(pg_stat_activity.pid)
 866            FROM pg_stat_activity
 867            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 868        ";
 869        sqlx::query(query).execute(&self.pool).await.log_err();
 870        self.pool.close().await;
 871        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
 872            .await
 873            .log_err();
 874    }
 875
 876    #[cfg(test)]
 877    fn as_fake(&self) -> Option<&tests::FakeDb> {
 878        None
 879    }
 880}
 881
// Declares a newtype database id `$name(pub i32)`.
//
// The generated type:
// - is `Copy`, totally ordered and hashable;
// - maps to the underlying integer column (`#[sqlx(transparent)]`);
// - serializes as a bare integer (`#[serde(transparent)]`);
// - displays as that integer.
//
// `from_proto`/`to_proto` convert to and from a `u64` representation (presumably the RPC
// protocol's id type — confirm against callers).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // `allow(unused)`: not every generated id type uses every helper.
            #[allow(unused)]
            // Largest representable id; `get_channel_messages` uses `MessageId::MAX` as its
            // default upper bound.
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            // NOTE(review): `as i32` silently wraps values above `i32::MAX`.
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            #[allow(unused)]
            // NOTE(review): a negative id would wrap to a very large `u64`.
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            // Delegates to the inner `i32`, so ids print as plain numbers.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
 913
// Primary key of the `users` table.
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Set via `Db::set_user_is_admin`.
    pub admin: bool,
    /// Code others redeem to sign up via this user; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Invites this user may still hand out; decremented on each redemption.
    pub invite_count: i32,
    /// Set via `Db::set_user_connected_once`.
    pub connected_once: bool,
}
 925
// Primary key of the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Short handle used to look the org up (`Db::find_org_by_slug`).
    pub slug: String,
}
 933
// Id newtype for channels.
id_type!(ChannelId);
/// A channel record as loaded from the database (derives `FromRow`).
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    // Raw owner id. `owner_is_user` says whether it names a user; otherwise
    // presumably an org (cf. `create_org_channel`) — confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
 942
// Id newtype for channel messages.
id_type!(MessageId);
/// A message posted to a channel, as loaded from the database (derives `FromRow`).
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    // NOTE(review): presumably UTC (the tests pass `OffsetDateTime::now_utc()`) — confirm.
    pub sent_at: OffsetDateTime,
    // Nonce supplied when the message was created. Creating a message again with
    // the same nonce yields the earlier message's id (see `test_channel_message_nonces`).
    pub nonce: Uuid,
}
 953
/// One entry of a user's contact list, as returned by `Db::get_contacts`.
///
/// The tests show the list also contains the user themself as an
/// `Accepted` entry.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutually accepted contact. `should_notify` is true when the listing
    /// user has a pending notification about it, e.g. their request was just
    /// accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request the listing user sent to `user_id` that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request from `user_id` to the listing user. `should_notify`
    /// is cleared once the notification is dismissed.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
 968
 969impl Contact {
 970    pub fn user_id(&self) -> UserId {
 971        match self {
 972            Contact::Accepted { user_id, .. } => *user_id,
 973            Contact::Outgoing { user_id } => *user_id,
 974            Contact::Incoming { user_id, .. } => *user_id,
 975        }
 976    }
 977}
 978
/// A pending contact request addressed to some user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`,
/// following field declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
 984
 985fn fuzzy_like_string(string: &str) -> String {
 986    let mut result = String::with_capacity(string.len() * 2 + 1);
 987    for c in string.chars() {
 988        if c.is_alphanumeric() {
 989            result.push('%');
 990            result.push(c);
 991        }
 992    }
 993    result.push('%');
 994    result
 995}
 996
 997#[cfg(test)]
 998pub mod tests {
 999    use super::*;
1000    use anyhow::anyhow;
1001    use collections::BTreeMap;
1002    use gpui::executor::Background;
1003    use lazy_static::lazy_static;
1004    use parking_lot::Mutex;
1005    use rand::prelude::*;
1006    use sqlx::{
1007        migrate::{MigrateDatabase, Migrator},
1008        Postgres,
1009    };
1010    use std::{path::Path, sync::Arc};
1011    use util::post_inc;
1012
1013    #[tokio::test(flavor = "multi_thread")]
1014    async fn test_get_users_by_ids() {
1015        for test_db in [
1016            TestDb::postgres().await,
1017            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1018        ] {
1019            let db = test_db.db();
1020
1021            let user = db.create_user("user", None, false).await.unwrap();
1022            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1023            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1024            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1025
1026            assert_eq!(
1027                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1028                    .await
1029                    .unwrap(),
1030                vec![
1031                    User {
1032                        id: user,
1033                        github_login: "user".to_string(),
1034                        admin: false,
1035                        ..Default::default()
1036                    },
1037                    User {
1038                        id: friend1,
1039                        github_login: "friend-1".to_string(),
1040                        admin: false,
1041                        ..Default::default()
1042                    },
1043                    User {
1044                        id: friend2,
1045                        github_login: "friend-2".to_string(),
1046                        admin: false,
1047                        ..Default::default()
1048                    },
1049                    User {
1050                        id: friend3,
1051                        github_login: "friend-3".to_string(),
1052                        admin: false,
1053                        ..Default::default()
1054                    }
1055                ]
1056            );
1057        }
1058    }
1059
    /// `create_users` assigns the requested invite counts and distinct invite
    /// codes. Re-importing an existing user updates its invite count but keeps
    /// its id and invite code. Postgres only: `FakeDb::create_users` is
    /// unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1123
    /// On both backends, `get_channel_messages(channel, count, None)` returns
    /// the most recent `count` messages in ascending order. Passing
    /// `Some(message_id)` returns the `count` messages immediately before it.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Ten messages with bodies "0".."9", each with a distinct nonce.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // Page backwards from the oldest message of the previous page.
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
1162
    /// On both backends, creating a message with a nonce that was already used
    /// returns the id of the earlier message with that nonce.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            // Reuses nonce 1, so it should resolve to msg1.
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            // Reuses nonce 2, so it should resolve to msg2.
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
1196
    /// Access token hashes are listed newest first. Every call here passes 3
    /// as the last argument, and the assertions show that only the 3 most
    /// recent hashes are retained.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", None, false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // From here on the oldest hash is evicted on each insert.
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
1228
1229    #[test]
1230    fn test_fuzzy_like_string() {
1231        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1232        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1233        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1234    }
1235
    /// Fuzzy user search matches logins containing the query's characters in
    /// order, and returns them in the order asserted below. Postgres only:
    /// `FakeDb::fuzzy_search_users` is unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzzy_search_users() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        for github_login in [
            "California",
            "colorado",
            "oregon",
            "washington",
            "florida",
            "delaware",
            "rhode-island",
        ] {
            db.create_user(github_login, None, false).await.unwrap();
        }

        assert_eq!(
            fuzzy_search_user_names(db, "clr").await,
            &["colorado", "California"]
        );
        assert_eq!(
            fuzzy_search_user_names(db, "ro").await,
            &["rhode-island", "colorado", "oregon"],
        );

        // Runs a search with a limit of 10 and returns just the matching logins.
        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
            db.fuzzy_search_users(query, 10)
                .await
                .unwrap()
                .into_iter()
                .map(|user| user.github_login)
                .collect::<Vec<_>>()
        }
    }
1270
    /// Runs both backends through the contact lifecycle: sending a request,
    /// dismissing notifications, accepting, mutual requests (which are
    /// accepted immediately), and declining.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1490
    /// Covers the invite code lifecycle:
    /// - codes are assigned through `set_invite_count`;
    /// - each redemption creates an accepted contact pair and decrements the count;
    /// - an exhausted code is rejected until the count is raised again;
    /// - existing users cannot redeem codes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1643
    /// A database handle for tests: either a freshly created Postgres database
    /// or an in-memory `FakeDb`. Dropping it tears the database down.
    pub struct TestDb {
        // Wrapped in `Option` so `Drop` can take the handle and pass it to `teardown`.
        pub db: Option<Arc<dyn Db>>,
        // Connection URL of the Postgres database; empty for the fake backend.
        pub url: String,
    }
1648
    impl TestDb {
        /// Creates a uniquely named Postgres database on localhost, runs the
        /// crate's `migrations` directory against it, and wraps the result.
        ///
        /// # Panics
        /// Panics if the database cannot be created or connected to, or if the
        /// migrations cannot be loaded or applied.
        pub async fn postgres() -> Self {
            // Process-wide lock so only one test at a time creates and migrates a database.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // NOTE(review): this `parking_lot` guard is deliberately held across the
            // `.await`s below, blocking other threads that call `postgres()` until setup
            // completes. Presumably fine because each tokio test drives its own runtime — confirm.
            let _guard = LOCK.lock();
            // A random 128-bit suffix keeps test database names distinct.
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a new in-memory `FakeDb` that uses `background` for its
        /// simulated delays. `url` is left empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if the handle has already been taken out of `self.db`.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
1683
1684    impl Drop for TestDb {
1685        fn drop(&mut self) {
1686            if let Some(db) = self.db.take() {
1687                futures::executor::block_on(db.teardown(&self.url));
1688            }
1689        }
1690    }
1691
    /// In-memory implementation of `Db` for tests. Each table is a
    /// mutex-guarded map, and ids come from per-table counters.
    pub struct FakeDb {
        // Source of the random delays injected into fake queries.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` value (e.g. an admin flag) is not visible here — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): as with org memberships, the `bool` value's meaning is not visible here — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to assign for each table; all start at 1 in `FakeDb::new`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
    }
1706
    /// A contact relationship row stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        // Presumably the `requester_id` passed to `Db::send_contact_request`.
        pub requester_id: UserId,
        // Presumably the `responder_id` passed to `Db::send_contact_request`.
        pub responder_id: UserId,
        pub accepted: bool,
        // NOTE(review): presumably whether the other party has a pending notification — confirm.
        pub should_notify: bool,
    }
1714
1715    impl FakeDb {
1716        pub fn new(background: Arc<Background>) -> Self {
1717            Self {
1718                background,
1719                users: Default::default(),
1720                next_user_id: Mutex::new(1),
1721                orgs: Default::default(),
1722                next_org_id: Mutex::new(1),
1723                org_memberships: Default::default(),
1724                channels: Default::default(),
1725                next_channel_id: Mutex::new(1),
1726                channel_memberships: Default::default(),
1727                channel_messages: Default::default(),
1728                next_channel_message_id: Mutex::new(1),
1729                contacts: Default::default(),
1730            }
1731        }
1732    }
1733
1734    #[async_trait]
1735    impl Db for FakeDb {
1736        async fn create_user(
1737            &self,
1738            github_login: &str,
1739            email_address: Option<&str>,
1740            admin: bool,
1741        ) -> Result<UserId> {
1742            self.background.simulate_random_delay().await;
1743
1744            let mut users = self.users.lock();
1745            if let Some(user) = users
1746                .values()
1747                .find(|user| user.github_login == github_login)
1748            {
1749                Ok(user.id)
1750            } else {
1751                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1752                users.insert(
1753                    user_id,
1754                    User {
1755                        id: user_id,
1756                        github_login: github_login.to_string(),
1757                        email_address: email_address.map(str::to_string),
1758                        admin,
1759                        invite_code: None,
1760                        invite_count: 0,
1761                        connected_once: false,
1762                    },
1763                );
1764                Ok(user_id)
1765            }
1766        }
1767
1768        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
1769            unimplemented!()
1770        }
1771
1772        async fn get_all_users(&self) -> Result<Vec<User>> {
1773            unimplemented!()
1774        }
1775
1776        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1777            unimplemented!()
1778        }
1779
1780        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1781            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1782        }
1783
1784        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1785            self.background.simulate_random_delay().await;
1786            let users = self.users.lock();
1787            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1788        }
1789
1790        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1791            Ok(self
1792                .users
1793                .lock()
1794                .values()
1795                .find(|user| user.github_login == github_login)
1796                .cloned())
1797        }
1798
1799        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1800            unimplemented!()
1801        }
1802
1803        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1804            self.background.simulate_random_delay().await;
1805            let mut users = self.users.lock();
1806            let mut user = users
1807                .get_mut(&id)
1808                .ok_or_else(|| anyhow!("user not found"))?;
1809            user.connected_once = connected_once;
1810            Ok(())
1811        }
1812
1813        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1814            unimplemented!()
1815        }
1816
1817        // invite codes
1818
1819        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
1820            unimplemented!()
1821        }
1822
1823        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1824            Ok(None)
1825        }
1826
1827        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1828            unimplemented!()
1829        }
1830
1831        async fn redeem_invite_code(
1832            &self,
1833            _code: &str,
1834            _login: &str,
1835            _email_address: Option<&str>,
1836        ) -> Result<UserId> {
1837            unimplemented!()
1838        }
1839
1840        // contacts
1841
1842        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1843            self.background.simulate_random_delay().await;
1844            let mut contacts = vec![Contact::Accepted {
1845                user_id: id,
1846                should_notify: false,
1847            }];
1848
1849            for contact in self.contacts.lock().iter() {
1850                if contact.requester_id == id {
1851                    if contact.accepted {
1852                        contacts.push(Contact::Accepted {
1853                            user_id: contact.responder_id,
1854                            should_notify: contact.should_notify,
1855                        });
1856                    } else {
1857                        contacts.push(Contact::Outgoing {
1858                            user_id: contact.responder_id,
1859                        });
1860                    }
1861                } else if contact.responder_id == id {
1862                    if contact.accepted {
1863                        contacts.push(Contact::Accepted {
1864                            user_id: contact.requester_id,
1865                            should_notify: false,
1866                        });
1867                    } else {
1868                        contacts.push(Contact::Incoming {
1869                            user_id: contact.requester_id,
1870                            should_notify: contact.should_notify,
1871                        });
1872                    }
1873                }
1874            }
1875
1876            contacts.sort_unstable_by_key(|contact| contact.user_id());
1877            Ok(contacts)
1878        }
1879
1880        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
1881            self.background.simulate_random_delay().await;
1882            Ok(self.contacts.lock().iter().any(|contact| {
1883                contact.accepted
1884                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
1885                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
1886            }))
1887        }
1888
1889        async fn send_contact_request(
1890            &self,
1891            requester_id: UserId,
1892            responder_id: UserId,
1893        ) -> Result<()> {
1894            let mut contacts = self.contacts.lock();
1895            for contact in contacts.iter_mut() {
1896                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1897                    if contact.accepted {
1898                        Err(anyhow!("contact already exists"))?;
1899                    } else {
1900                        Err(anyhow!("contact already requested"))?;
1901                    }
1902                }
1903                if contact.responder_id == requester_id && contact.requester_id == responder_id {
1904                    if contact.accepted {
1905                        Err(anyhow!("contact already exists"))?;
1906                    } else {
1907                        contact.accepted = true;
1908                        contact.should_notify = false;
1909                        return Ok(());
1910                    }
1911                }
1912            }
1913            contacts.push(FakeContact {
1914                requester_id,
1915                responder_id,
1916                accepted: false,
1917                should_notify: true,
1918            });
1919            Ok(())
1920        }
1921
1922        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1923            self.contacts.lock().retain(|contact| {
1924                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
1925            });
1926            Ok(())
1927        }
1928
1929        async fn dismiss_contact_notification(
1930            &self,
1931            user_id: UserId,
1932            contact_user_id: UserId,
1933        ) -> Result<()> {
1934            let mut contacts = self.contacts.lock();
1935            for contact in contacts.iter_mut() {
1936                if contact.requester_id == contact_user_id
1937                    && contact.responder_id == user_id
1938                    && !contact.accepted
1939                {
1940                    contact.should_notify = false;
1941                    return Ok(());
1942                }
1943                if contact.requester_id == user_id
1944                    && contact.responder_id == contact_user_id
1945                    && contact.accepted
1946                {
1947                    contact.should_notify = false;
1948                    return Ok(());
1949                }
1950            }
1951            Err(anyhow!("no such notification"))?
1952        }
1953
1954        async fn respond_to_contact_request(
1955            &self,
1956            responder_id: UserId,
1957            requester_id: UserId,
1958            accept: bool,
1959        ) -> Result<()> {
1960            let mut contacts = self.contacts.lock();
1961            for (ix, contact) in contacts.iter_mut().enumerate() {
1962                if contact.requester_id == requester_id && contact.responder_id == responder_id {
1963                    if contact.accepted {
1964                        Err(anyhow!("contact already confirmed"))?;
1965                    }
1966                    if accept {
1967                        contact.accepted = true;
1968                        contact.should_notify = true;
1969                    } else {
1970                        contacts.remove(ix);
1971                    }
1972                    return Ok(());
1973                }
1974            }
1975            Err(anyhow!("no such contact request"))?
1976        }
1977
1978        async fn create_access_token_hash(
1979            &self,
1980            _user_id: UserId,
1981            _access_token_hash: &str,
1982            _max_access_token_count: usize,
1983        ) -> Result<()> {
1984            unimplemented!()
1985        }
1986
1987        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
1988            unimplemented!()
1989        }
1990
1991        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
1992            unimplemented!()
1993        }
1994
1995        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1996            self.background.simulate_random_delay().await;
1997            let mut orgs = self.orgs.lock();
1998            if orgs.values().any(|org| org.slug == slug) {
1999                Err(anyhow!("org already exists"))?
2000            } else {
2001                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2002                orgs.insert(
2003                    org_id,
2004                    Org {
2005                        id: org_id,
2006                        name: name.to_string(),
2007                        slug: slug.to_string(),
2008                    },
2009                );
2010                Ok(org_id)
2011            }
2012        }
2013
2014        async fn add_org_member(
2015            &self,
2016            org_id: OrgId,
2017            user_id: UserId,
2018            is_admin: bool,
2019        ) -> Result<()> {
2020            self.background.simulate_random_delay().await;
2021            if !self.orgs.lock().contains_key(&org_id) {
2022                Err(anyhow!("org does not exist"))?;
2023            }
2024            if !self.users.lock().contains_key(&user_id) {
2025                Err(anyhow!("user does not exist"))?;
2026            }
2027
2028            self.org_memberships
2029                .lock()
2030                .entry((org_id, user_id))
2031                .or_insert(is_admin);
2032            Ok(())
2033        }
2034
2035        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2036            self.background.simulate_random_delay().await;
2037            if !self.orgs.lock().contains_key(&org_id) {
2038                Err(anyhow!("org does not exist"))?;
2039            }
2040
2041            let mut channels = self.channels.lock();
2042            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2043            channels.insert(
2044                channel_id,
2045                Channel {
2046                    id: channel_id,
2047                    name: name.to_string(),
2048                    owner_id: org_id.0,
2049                    owner_is_user: false,
2050                },
2051            );
2052            Ok(channel_id)
2053        }
2054
2055        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2056            self.background.simulate_random_delay().await;
2057            Ok(self
2058                .channels
2059                .lock()
2060                .values()
2061                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2062                .cloned()
2063                .collect())
2064        }
2065
2066        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2067            self.background.simulate_random_delay().await;
2068            let channels = self.channels.lock();
2069            let memberships = self.channel_memberships.lock();
2070            Ok(channels
2071                .values()
2072                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2073                .cloned()
2074                .collect())
2075        }
2076
2077        async fn can_user_access_channel(
2078            &self,
2079            user_id: UserId,
2080            channel_id: ChannelId,
2081        ) -> Result<bool> {
2082            self.background.simulate_random_delay().await;
2083            Ok(self
2084                .channel_memberships
2085                .lock()
2086                .contains_key(&(channel_id, user_id)))
2087        }
2088
2089        async fn add_channel_member(
2090            &self,
2091            channel_id: ChannelId,
2092            user_id: UserId,
2093            is_admin: bool,
2094        ) -> Result<()> {
2095            self.background.simulate_random_delay().await;
2096            if !self.channels.lock().contains_key(&channel_id) {
2097                Err(anyhow!("channel does not exist"))?;
2098            }
2099            if !self.users.lock().contains_key(&user_id) {
2100                Err(anyhow!("user does not exist"))?;
2101            }
2102
2103            self.channel_memberships
2104                .lock()
2105                .entry((channel_id, user_id))
2106                .or_insert(is_admin);
2107            Ok(())
2108        }
2109
2110        async fn create_channel_message(
2111            &self,
2112            channel_id: ChannelId,
2113            sender_id: UserId,
2114            body: &str,
2115            timestamp: OffsetDateTime,
2116            nonce: u128,
2117        ) -> Result<MessageId> {
2118            self.background.simulate_random_delay().await;
2119            if !self.channels.lock().contains_key(&channel_id) {
2120                Err(anyhow!("channel does not exist"))?;
2121            }
2122            if !self.users.lock().contains_key(&sender_id) {
2123                Err(anyhow!("user does not exist"))?;
2124            }
2125
2126            let mut messages = self.channel_messages.lock();
2127            if let Some(message) = messages
2128                .values()
2129                .find(|message| message.nonce.as_u128() == nonce)
2130            {
2131                Ok(message.id)
2132            } else {
2133                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2134                messages.insert(
2135                    message_id,
2136                    ChannelMessage {
2137                        id: message_id,
2138                        channel_id,
2139                        sender_id,
2140                        body: body.to_string(),
2141                        sent_at: timestamp,
2142                        nonce: Uuid::from_u128(nonce),
2143                    },
2144                );
2145                Ok(message_id)
2146            }
2147        }
2148
2149        async fn get_channel_messages(
2150            &self,
2151            channel_id: ChannelId,
2152            count: usize,
2153            before_id: Option<MessageId>,
2154        ) -> Result<Vec<ChannelMessage>> {
2155            let mut messages = self
2156                .channel_messages
2157                .lock()
2158                .values()
2159                .rev()
2160                .filter(|message| {
2161                    message.channel_id == channel_id
2162                        && message.id < before_id.unwrap_or(MessageId::MAX)
2163                })
2164                .take(count)
2165                .cloned()
2166                .collect::<Vec<_>>();
2167            messages.sort_unstable_by_key(|message| message.id);
2168            Ok(messages)
2169        }
2170
2171        async fn teardown(&self, _: &str) {}
2172
2173        #[cfg(test)]
2174        fn as_fake(&self) -> Option<&FakeDb> {
2175            Some(self)
2176        }
2177    }
2178}