db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Database backend selected at compile time.
///
/// Test builds use SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database backend selected at compile time.
///
/// Non-test builds use Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  20
/// A connection pool for database backend `D`, plus test-only hooks used by
/// the `test_support!` macro.
pub struct Db<D: sqlx::Database> {
    /// Pool that every query method executes against.
    pool: sqlx::Pool<D>,
    /// Test-only: when `Some`, `test_support!` awaits a simulated random delay
    /// from this executor before running each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: runtime that `test_support!` uses to `block_on` each query
    /// body. It is unwrapped there, so it must be `Some` before query methods
    /// are called in test builds.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
/// Runs a block of database code, adapting how it is driven to the build
/// configuration.
///
/// The block is wrapped in an `async` block. In non-test builds it is simply
/// awaited. In test builds, a simulated random delay is awaited first when
/// `$self.background` is set. The block is then driven to completion with
/// `block_on` on `$self.runtime`, which panics if `runtime` is `None`.
///
/// `$self` is expected to be a `Db`, because its `background` and `runtime`
/// fields are read in test builds. The expansion uses `.await`, so it must
/// appear inside an async context.
// NOTE(review): presumably the separate runtime exists so sqlx runs on tokio
// rather than on the gpui test executor — confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // In non-test builds the `#[cfg(test)]` statements below are
            // stripped; this diverging statement keeps the (dead) branch
            // well-typed against the `else` branch.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  51
/// Backend-agnostic access to the number of rows a statement affected.
///
/// Implemented below for the SQLite and Postgres `sqlx` query-result types, so
/// generic code can bound on `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
  55
  56#[cfg(test)]
  57impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  58    fn rows_affected(&self) -> u64 {
  59        self.rows_affected()
  60    }
  61}
  62
  63impl RowsAffected for sqlx::postgres::PgQueryResult {
  64    fn rows_affected(&self) -> u64 {
  65        self.rows_affected()
  66    }
  67}
  68
  69#[cfg(test)]
  70impl Db<sqlx::Sqlite> {
  71    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  72        use std::str::FromStr as _;
  73        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  74            .unwrap()
  75            .create_if_missing(true)
  76            .shared_cache(true);
  77        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  78            .min_connections(2)
  79            .max_connections(max_connections)
  80            .connect_with(options)
  81            .await?;
  82        Ok(Self {
  83            pool,
  84            background: None,
  85            runtime: None,
  86        })
  87    }
  88
  89    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT users.*
  93                FROM users
  94                WHERE users.id IN (SELECT value from json_each($1))
  95            ";
  96            Ok(sqlx::query_as(query)
  97                .bind(&serde_json::json!(ids))
  98                .fetch_all(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 104        test_support!(self, {
 105            let query = "
 106                SELECT metrics_id
 107                FROM users
 108                WHERE id = $1
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(id)
 112                .fetch_one(&self.pool)
 113                .await?)
 114        })
 115    }
 116
 117    pub async fn create_user(
 118        &self,
 119        email_address: &str,
 120        admin: bool,
 121        params: NewUserParams,
 122    ) -> Result<NewUserResult> {
 123        test_support!(self, {
 124            let query = "
 125                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 126                VALUES ($1, $2, $3, $4, $5)
 127                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 128                RETURNING id, metrics_id
 129            ";
 130
 131            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 132                .bind(email_address)
 133                .bind(params.github_login)
 134                .bind(params.github_user_id)
 135                .bind(admin)
 136                .bind(Uuid::new_v4().to_string())
 137                .fetch_one(&self.pool)
 138                .await?;
 139            Ok(NewUserResult {
 140                user_id,
 141                metrics_id,
 142                signup_device_id: None,
 143                inviting_user_id: None,
 144            })
 145        })
 146    }
 147
 148    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 149        unimplemented!()
 150    }
 151
 152    pub async fn create_user_from_invite(
 153        &self,
 154        _invite: &Invite,
 155        _user: NewUserParams,
 156    ) -> Result<Option<NewUserResult>> {
 157        unimplemented!()
 158    }
 159
 160    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
 161        unimplemented!()
 162    }
 163
 164    pub async fn create_invite_from_code(
 165        &self,
 166        _code: &str,
 167        _email_address: &str,
 168        _device_id: Option<&str>,
 169    ) -> Result<Invite> {
 170        unimplemented!()
 171    }
 172
 173    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 174        unimplemented!()
 175    }
 176}
 177
 178impl Db<sqlx::Postgres> {
 179    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 180        let pool = sqlx::postgres::PgPoolOptions::new()
 181            .max_connections(max_connections)
 182            .connect(url)
 183            .await?;
 184        Ok(Self {
 185            pool,
 186            #[cfg(test)]
 187            background: None,
 188            #[cfg(test)]
 189            runtime: None,
 190        })
 191    }
 192
 193    #[cfg(test)]
 194    pub fn teardown(&self, url: &str) {
 195        self.runtime.as_ref().unwrap().block_on(async {
 196            use util::ResultExt;
 197            let query = "
 198                SELECT pg_terminate_backend(pg_stat_activity.pid)
 199                FROM pg_stat_activity
 200                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 201            ";
 202            sqlx::query(query).execute(&self.pool).await.log_err();
 203            self.pool.close().await;
 204            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 205                .await
 206                .log_err();
 207        })
 208    }
 209
 210    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 211        test_support!(self, {
 212            let like_string = Self::fuzzy_like_string(name_query);
 213            let query = "
 214                SELECT users.*
 215                FROM users
 216                WHERE github_login ILIKE $1
 217                ORDER BY github_login <-> $2
 218                LIMIT $3
 219            ";
 220            Ok(sqlx::query_as(query)
 221                .bind(like_string)
 222                .bind(name_query)
 223                .bind(limit as i32)
 224                .fetch_all(&self.pool)
 225                .await?)
 226        })
 227    }
 228
 229    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 230        test_support!(self, {
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE users.id = ANY ($1)
 235            ";
 236            Ok(sqlx::query_as(query)
 237                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 238                .fetch_all(&self.pool)
 239                .await?)
 240        })
 241    }
 242
 243    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT metrics_id::text
 247                FROM users
 248                WHERE id = $1
 249            ";
 250            Ok(sqlx::query_scalar(query)
 251                .bind(id)
 252                .fetch_one(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    pub async fn create_user(
 258        &self,
 259        email_address: &str,
 260        admin: bool,
 261        params: NewUserParams,
 262    ) -> Result<NewUserResult> {
 263        test_support!(self, {
 264            let query = "
 265                INSERT INTO users (email_address, github_login, github_user_id, admin)
 266                VALUES ($1, $2, $3, $4)
 267                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 268                RETURNING id, metrics_id::text
 269            ";
 270
 271            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272                .bind(email_address)
 273                .bind(params.github_login)
 274                .bind(params.github_user_id)
 275                .bind(admin)
 276                .fetch_one(&self.pool)
 277                .await?;
 278            Ok(NewUserResult {
 279                user_id,
 280                metrics_id,
 281                signup_device_id: None,
 282                inviting_user_id: None,
 283            })
 284        })
 285    }
 286
 287    pub async fn create_user_from_invite(
 288        &self,
 289        invite: &Invite,
 290        user: NewUserParams,
 291    ) -> Result<Option<NewUserResult>> {
 292        test_support!(self, {
 293            let mut tx = self.pool.begin().await?;
 294
 295            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 296                i32,
 297                Option<UserId>,
 298                Option<UserId>,
 299                Option<String>,
 300            ) = sqlx::query_as(
 301                "
 302                SELECT id, user_id, inviting_user_id, device_id
 303                FROM signups
 304                WHERE
 305                    email_address = $1 AND
 306                    email_confirmation_code = $2
 307                ",
 308            )
 309            .bind(&invite.email_address)
 310            .bind(&invite.email_confirmation_code)
 311            .fetch_optional(&mut tx)
 312            .await?
 313            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 314
 315            if existing_user_id.is_some() {
 316                return Ok(None);
 317            }
 318
 319            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 320                "
 321                INSERT INTO users
 322                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 323                VALUES
 324                ($1, $2, $3, FALSE, $4, $5)
 325                ON CONFLICT (github_login) DO UPDATE SET
 326                    email_address = excluded.email_address,
 327                    github_user_id = excluded.github_user_id,
 328                    admin = excluded.admin
 329                RETURNING id, metrics_id::text
 330                ",
 331            )
 332            .bind(&invite.email_address)
 333            .bind(&user.github_login)
 334            .bind(&user.github_user_id)
 335            .bind(&user.invite_count)
 336            .bind(random_invite_code())
 337            .fetch_one(&mut tx)
 338            .await?;
 339
 340            sqlx::query(
 341                "
 342                UPDATE signups
 343                SET user_id = $1
 344                WHERE id = $2
 345                ",
 346            )
 347            .bind(&user_id)
 348            .bind(&signup_id)
 349            .execute(&mut tx)
 350            .await?;
 351
 352            if let Some(inviting_user_id) = inviting_user_id {
 353                sqlx::query(
 354                    "
 355                    INSERT INTO contacts
 356                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 357                    VALUES
 358                        ($1, $2, TRUE, TRUE, TRUE)
 359                    ON CONFLICT DO NOTHING
 360                    ",
 361                )
 362                .bind(inviting_user_id)
 363                .bind(user_id)
 364                .execute(&mut tx)
 365                .await?;
 366            }
 367
 368            tx.commit().await?;
 369            Ok(Some(NewUserResult {
 370                user_id,
 371                metrics_id,
 372                inviting_user_id,
 373                signup_device_id,
 374            }))
 375        })
 376    }
 377
 378    pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
 379        test_support!(self, {
 380            sqlx::query(
 381                "
 382                INSERT INTO signups
 383                (
 384                    email_address,
 385                    email_confirmation_code,
 386                    email_confirmation_sent,
 387                    platform_linux,
 388                    platform_mac,
 389                    platform_windows,
 390                    platform_unknown,
 391                    editor_features,
 392                    programming_languages,
 393                    device_id,
 394                    added_to_mailing_list
 395                )
 396                VALUES
 397                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8, $9)
 398                ON CONFLICT (email_address) DO UPDATE SET
 399                    email_address = excluded.email_address
 400                RETURNING id
 401                ",
 402            )
 403            .bind(&signup.email_address)
 404            .bind(&random_email_confirmation_code())
 405            .bind(&signup.platform_linux)
 406            .bind(&signup.platform_mac)
 407            .bind(&signup.platform_windows)
 408            .bind(&signup.editor_features)
 409            .bind(&signup.programming_languages)
 410            .bind(&signup.device_id)
 411            .bind(&signup.added_to_mailing_list)
 412            .execute(&self.pool)
 413            .await?;
 414            Ok(())
 415        })
 416    }
 417
 418    pub async fn create_invite_from_code(
 419        &self,
 420        code: &str,
 421        email_address: &str,
 422        device_id: Option<&str>,
 423    ) -> Result<Invite> {
 424        test_support!(self, {
 425            let mut tx = self.pool.begin().await?;
 426
 427            let existing_user: Option<UserId> = sqlx::query_scalar(
 428                "
 429                SELECT id
 430                FROM users
 431                WHERE email_address = $1
 432                ",
 433            )
 434            .bind(email_address)
 435            .fetch_optional(&mut tx)
 436            .await?;
 437            if existing_user.is_some() {
 438                Err(anyhow!("email address is already in use"))?;
 439            }
 440
 441            let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
 442                "
 443                UPDATE users
 444                SET invite_count = invite_count - 1
 445                WHERE invite_code = $1 AND invite_count > 0
 446                RETURNING id
 447                ",
 448            )
 449            .bind(code)
 450            .fetch_optional(&mut tx)
 451            .await?;
 452
 453            let Some(inviter_id) = inviting_user_id_with_invites else {
 454                return Err(Error::Http(
 455                    StatusCode::UNAUTHORIZED,
 456                    "unable to find an invite code with invites remaining".to_string(),
 457                ));
 458            };
 459
 460            let email_confirmation_code: String = sqlx::query_scalar(
 461                "
 462                INSERT INTO signups
 463                (
 464                    email_address,
 465                    email_confirmation_code,
 466                    email_confirmation_sent,
 467                    inviting_user_id,
 468                    platform_linux,
 469                    platform_mac,
 470                    platform_windows,
 471                    platform_unknown,
 472                    device_id
 473                )
 474                VALUES
 475                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 476                ON CONFLICT (email_address)
 477                DO UPDATE SET
 478                    inviting_user_id = excluded.inviting_user_id
 479                RETURNING email_confirmation_code
 480                ",
 481            )
 482            .bind(&email_address)
 483            .bind(&random_email_confirmation_code())
 484            .bind(&inviter_id)
 485            .bind(&device_id)
 486            .fetch_one(&mut tx)
 487            .await?;
 488
 489            tx.commit().await?;
 490
 491            Ok(Invite {
 492                email_address: email_address.into(),
 493                email_confirmation_code,
 494            })
 495        })
 496    }
 497
 498    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 499        test_support!(self, {
 500            let emails = invites
 501                .iter()
 502                .map(|s| s.email_address.as_str())
 503                .collect::<Vec<_>>();
 504            sqlx::query(
 505                "
 506                UPDATE signups
 507                SET email_confirmation_sent = TRUE
 508                WHERE email_address = ANY ($1)
 509                ",
 510            )
 511            .bind(&emails)
 512            .execute(&self.pool)
 513            .await?;
 514            Ok(())
 515        })
 516    }
 517}
 518
 519impl<D> Db<D>
 520where
 521    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 522    D::Connection: sqlx::migrate::Migrate,
 523    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 524    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 525    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 526    D::QueryResult: RowsAffected,
 527    String: sqlx::Type<D>,
 528    i32: sqlx::Type<D>,
 529    i64: sqlx::Type<D>,
 530    bool: sqlx::Type<D>,
 531    str: sqlx::Type<D>,
 532    Uuid: sqlx::Type<D>,
 533    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 534    OffsetDateTime: sqlx::Type<D>,
 535    PrimitiveDateTime: sqlx::Type<D>,
 536    usize: sqlx::ColumnIndex<D::Row>,
 537    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 538    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 539    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 540    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 541    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 542    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 543    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 544    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 545    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 546    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 547    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 548    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 549{
 550    pub async fn migrate(
 551        &self,
 552        migrations_path: &Path,
 553        ignore_checksum_mismatch: bool,
 554    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 555        let migrations = MigrationSource::resolve(migrations_path)
 556            .await
 557            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 558
 559        let mut conn = self.pool.acquire().await?;
 560
 561        conn.ensure_migrations_table().await?;
 562        let applied_migrations: HashMap<_, _> = conn
 563            .list_applied_migrations()
 564            .await?
 565            .into_iter()
 566            .map(|m| (m.version, m))
 567            .collect();
 568
 569        let mut new_migrations = Vec::new();
 570        for migration in migrations {
 571            match applied_migrations.get(&migration.version) {
 572                Some(applied_migration) => {
 573                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 574                    {
 575                        Err(anyhow!(
 576                            "checksum mismatch for applied migration {}",
 577                            migration.description
 578                        ))?;
 579                    }
 580                }
 581                None => {
 582                    let elapsed = conn.apply(&migration).await?;
 583                    new_migrations.push((migration, elapsed));
 584                }
 585            }
 586        }
 587
 588        Ok(new_migrations)
 589    }
 590
 591    pub fn fuzzy_like_string(string: &str) -> String {
 592        let mut result = String::with_capacity(string.len() * 2 + 1);
 593        for c in string.chars() {
 594            if c.is_alphanumeric() {
 595                result.push('%');
 596                result.push(c);
 597            }
 598        }
 599        result.push('%');
 600        result
 601    }
 602
 603    // users
 604
 605    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 606        test_support!(self, {
 607            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 608            Ok(sqlx::query_as(query)
 609                .bind(limit as i32)
 610                .bind((page * limit) as i32)
 611                .fetch_all(&self.pool)
 612                .await?)
 613        })
 614    }
 615
 616    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 617        test_support!(self, {
 618            let query = "
 619                SELECT users.*
 620                FROM users
 621                WHERE id = $1
 622                LIMIT 1
 623            ";
 624            Ok(sqlx::query_as(query)
 625                .bind(&id)
 626                .fetch_optional(&self.pool)
 627                .await?)
 628        })
 629    }
 630
 631    pub async fn get_users_with_no_invites(
 632        &self,
 633        invited_by_another_user: bool,
 634    ) -> Result<Vec<User>> {
 635        test_support!(self, {
 636            let query = format!(
 637                "
 638                SELECT users.*
 639                FROM users
 640                WHERE invite_count = 0
 641                AND inviter_id IS{} NULL
 642                ",
 643                if invited_by_another_user { " NOT" } else { "" }
 644            );
 645
 646            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 647        })
 648    }
 649
 650    pub async fn get_user_by_github_account(
 651        &self,
 652        github_login: &str,
 653        github_user_id: Option<i32>,
 654    ) -> Result<Option<User>> {
 655        test_support!(self, {
 656            if let Some(github_user_id) = github_user_id {
 657                let mut user = sqlx::query_as::<_, User>(
 658                    "
 659                    UPDATE users
 660                    SET github_login = $1
 661                    WHERE github_user_id = $2
 662                    RETURNING *
 663                    ",
 664                )
 665                .bind(github_login)
 666                .bind(github_user_id)
 667                .fetch_optional(&self.pool)
 668                .await?;
 669
 670                if user.is_none() {
 671                    user = sqlx::query_as::<_, User>(
 672                        "
 673                        UPDATE users
 674                        SET github_user_id = $1
 675                        WHERE github_login = $2
 676                        RETURNING *
 677                        ",
 678                    )
 679                    .bind(github_user_id)
 680                    .bind(github_login)
 681                    .fetch_optional(&self.pool)
 682                    .await?;
 683                }
 684
 685                Ok(user)
 686            } else {
 687                let user = sqlx::query_as(
 688                    "
 689                    SELECT * FROM users
 690                    WHERE github_login = $1
 691                    LIMIT 1
 692                    ",
 693                )
 694                .bind(github_login)
 695                .fetch_optional(&self.pool)
 696                .await?;
 697                Ok(user)
 698            }
 699        })
 700    }
 701
 702    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 703        test_support!(self, {
 704            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 705            Ok(sqlx::query(query)
 706                .bind(is_admin)
 707                .bind(id.0)
 708                .execute(&self.pool)
 709                .await
 710                .map(drop)?)
 711        })
 712    }
 713
 714    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 715        test_support!(self, {
 716            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 717            Ok(sqlx::query(query)
 718                .bind(connected_once)
 719                .bind(id.0)
 720                .execute(&self.pool)
 721                .await
 722                .map(drop)?)
 723        })
 724    }
 725
 726    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 727        test_support!(self, {
 728            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 729            sqlx::query(query)
 730                .bind(id.0)
 731                .execute(&self.pool)
 732                .await
 733                .map(drop)?;
 734            let query = "DELETE FROM users WHERE id = $1;";
 735            Ok(sqlx::query(query)
 736                .bind(id.0)
 737                .execute(&self.pool)
 738                .await
 739                .map(drop)?)
 740        })
 741    }
 742
 743    // signups
 744
 745    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 746        test_support!(self, {
 747            Ok(sqlx::query_as(
 748                "
 749                SELECT
 750                    COUNT(*) as count,
 751                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 752                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 753                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 754                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 755                FROM (
 756                    SELECT *
 757                    FROM signups
 758                    WHERE
 759                        NOT email_confirmation_sent
 760                ) AS unsent
 761                ",
 762            )
 763            .fetch_one(&self.pool)
 764            .await?)
 765        })
 766    }
 767
 768    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 769        test_support!(self, {
 770            Ok(sqlx::query_as(
 771                "
 772                SELECT
 773                    email_address, email_confirmation_code
 774                FROM signups
 775                WHERE
 776                    NOT email_confirmation_sent AND
 777                    (platform_mac OR platform_unknown)
 778                ORDER BY
 779                    created_at
 780                LIMIT $1
 781                ",
 782            )
 783            .bind(count as i32)
 784            .fetch_all(&self.pool)
 785            .await?)
 786        })
 787    }
 788
 789    // invite codes
 790
 791    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 792        test_support!(self, {
 793            let mut tx = self.pool.begin().await?;
 794            if count > 0 {
 795                sqlx::query(
 796                    "
 797                    UPDATE users
 798                    SET invite_code = $1
 799                    WHERE id = $2 AND invite_code IS NULL
 800                ",
 801                )
 802                .bind(random_invite_code())
 803                .bind(id)
 804                .execute(&mut tx)
 805                .await?;
 806            }
 807
 808            sqlx::query(
 809                "
 810                UPDATE users
 811                SET invite_count = $1
 812                WHERE id = $2
 813                ",
 814            )
 815            .bind(count as i32)
 816            .bind(id)
 817            .execute(&mut tx)
 818            .await?;
 819            tx.commit().await?;
 820            Ok(())
 821        })
 822    }
 823
 824    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 825        test_support!(self, {
 826            let result: Option<(String, i32)> = sqlx::query_as(
 827                "
 828                    SELECT invite_code, invite_count
 829                    FROM users
 830                    WHERE id = $1 AND invite_code IS NOT NULL 
 831                ",
 832            )
 833            .bind(id)
 834            .fetch_optional(&self.pool)
 835            .await?;
 836            if let Some((code, count)) = result {
 837                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 838            } else {
 839                Ok(None)
 840            }
 841        })
 842    }
 843
 844    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 845        test_support!(self, {
 846            sqlx::query_as(
 847                "
 848                    SELECT *
 849                    FROM users
 850                    WHERE invite_code = $1
 851                ",
 852            )
 853            .bind(code)
 854            .fetch_optional(&self.pool)
 855            .await?
 856            .ok_or_else(|| {
 857                Error::Http(
 858                    StatusCode::NOT_FOUND,
 859                    "that invite code does not exist".to_string(),
 860                )
 861            })
 862        })
 863    }
 864
 865    // projects
 866
 867    /// Registers a new project for the given user.
 868    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 869        test_support!(self, {
 870            Ok(sqlx::query_scalar(
 871                "
 872                INSERT INTO projects(host_user_id)
 873                VALUES ($1)
 874                RETURNING id
 875                ",
 876            )
 877            .bind(host_user_id)
 878            .fetch_one(&self.pool)
 879            .await
 880            .map(ProjectId)?)
 881        })
 882    }
 883
 884    /// Unregisters a project for the given project id.
 885    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 886        test_support!(self, {
 887            sqlx::query(
 888                "
 889                UPDATE projects
 890                SET unregistered = TRUE
 891                WHERE id = $1
 892                ",
 893            )
 894            .bind(project_id)
 895            .execute(&self.pool)
 896            .await?;
 897            Ok(())
 898        })
 899    }
 900
 901    // contacts
 902
 903    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 904        test_support!(self, {
 905            let query = "
 906                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 907                FROM contacts
 908                WHERE user_id_a = $1 OR user_id_b = $1;
 909            ";
 910
 911            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 912                .bind(user_id)
 913                .fetch(&self.pool);
 914
 915            let mut contacts = Vec::new();
 916            while let Some(row) = rows.next().await {
 917                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 918
 919                if user_id_a == user_id {
 920                    if accepted {
 921                        contacts.push(Contact::Accepted {
 922                            user_id: user_id_b,
 923                            should_notify: should_notify && a_to_b,
 924                        });
 925                    } else if a_to_b {
 926                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 927                    } else {
 928                        contacts.push(Contact::Incoming {
 929                            user_id: user_id_b,
 930                            should_notify,
 931                        });
 932                    }
 933                } else if accepted {
 934                    contacts.push(Contact::Accepted {
 935                        user_id: user_id_a,
 936                        should_notify: should_notify && !a_to_b,
 937                    });
 938                } else if a_to_b {
 939                    contacts.push(Contact::Incoming {
 940                        user_id: user_id_a,
 941                        should_notify,
 942                    });
 943                } else {
 944                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 945                }
 946            }
 947
 948            contacts.sort_unstable_by_key(|contact| contact.user_id());
 949
 950            Ok(contacts)
 951        })
 952    }
 953
 954    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 955        test_support!(self, {
 956            let (id_a, id_b) = if user_id_1 < user_id_2 {
 957                (user_id_1, user_id_2)
 958            } else {
 959                (user_id_2, user_id_1)
 960            };
 961
 962            let query = "
 963                SELECT 1 FROM contacts
 964                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 965                LIMIT 1
 966            ";
 967            Ok(sqlx::query_scalar::<_, i32>(query)
 968                .bind(id_a.0)
 969                .bind(id_b.0)
 970                .fetch_optional(&self.pool)
 971                .await?
 972                .is_some())
 973        })
 974    }
 975
 976    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 977        test_support!(self, {
 978            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 979                (sender_id, receiver_id, true)
 980            } else {
 981                (receiver_id, sender_id, false)
 982            };
 983            let query = "
 984                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 985                VALUES ($1, $2, $3, FALSE, TRUE)
 986                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 987                SET
 988                    accepted = TRUE,
 989                    should_notify = FALSE
 990                WHERE
 991                    NOT contacts.accepted AND
 992                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 993                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 994            ";
 995            let result = sqlx::query(query)
 996                .bind(id_a.0)
 997                .bind(id_b.0)
 998                .bind(a_to_b)
 999                .execute(&self.pool)
1000                .await?;
1001
1002            if result.rows_affected() == 1 {
1003                Ok(())
1004            } else {
1005                Err(anyhow!("contact already requested"))?
1006            }
1007        })
1008    }
1009
1010    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1011        test_support!(self, {
1012            let (id_a, id_b) = if responder_id < requester_id {
1013                (responder_id, requester_id)
1014            } else {
1015                (requester_id, responder_id)
1016            };
1017            let query = "
1018                DELETE FROM contacts
1019                WHERE user_id_a = $1 AND user_id_b = $2;
1020            ";
1021            let result = sqlx::query(query)
1022                .bind(id_a.0)
1023                .bind(id_b.0)
1024                .execute(&self.pool)
1025                .await?;
1026
1027            if result.rows_affected() == 1 {
1028                Ok(())
1029            } else {
1030                Err(anyhow!("no such contact"))?
1031            }
1032        })
1033    }
1034
1035    pub async fn dismiss_contact_notification(
1036        &self,
1037        user_id: UserId,
1038        contact_user_id: UserId,
1039    ) -> Result<()> {
1040        test_support!(self, {
1041            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1042                (user_id, contact_user_id, true)
1043            } else {
1044                (contact_user_id, user_id, false)
1045            };
1046
1047            let query = "
1048                UPDATE contacts
1049                SET should_notify = FALSE
1050                WHERE
1051                    user_id_a = $1 AND user_id_b = $2 AND
1052                    (
1053                        (a_to_b = $3 AND accepted) OR
1054                        (a_to_b != $3 AND NOT accepted)
1055                    );
1056            ";
1057
1058            let result = sqlx::query(query)
1059                .bind(id_a.0)
1060                .bind(id_b.0)
1061                .bind(a_to_b)
1062                .execute(&self.pool)
1063                .await?;
1064
1065            if result.rows_affected() == 0 {
1066                Err(anyhow!("no such contact request"))?;
1067            }
1068
1069            Ok(())
1070        })
1071    }
1072
1073    pub async fn respond_to_contact_request(
1074        &self,
1075        responder_id: UserId,
1076        requester_id: UserId,
1077        accept: bool,
1078    ) -> Result<()> {
1079        test_support!(self, {
1080            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1081                (responder_id, requester_id, false)
1082            } else {
1083                (requester_id, responder_id, true)
1084            };
1085            let result = if accept {
1086                let query = "
1087                    UPDATE contacts
1088                    SET accepted = TRUE, should_notify = TRUE
1089                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1090                ";
1091                sqlx::query(query)
1092                    .bind(id_a.0)
1093                    .bind(id_b.0)
1094                    .bind(a_to_b)
1095                    .execute(&self.pool)
1096                    .await?
1097            } else {
1098                let query = "
1099                    DELETE FROM contacts
1100                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1101                ";
1102                sqlx::query(query)
1103                    .bind(id_a.0)
1104                    .bind(id_b.0)
1105                    .bind(a_to_b)
1106                    .execute(&self.pool)
1107                    .await?
1108            };
1109            if result.rows_affected() == 1 {
1110                Ok(())
1111            } else {
1112                Err(anyhow!("no such contact request"))?
1113            }
1114        })
1115    }
1116
1117    // access tokens
1118
1119    pub async fn create_access_token_hash(
1120        &self,
1121        user_id: UserId,
1122        access_token_hash: &str,
1123        max_access_token_count: usize,
1124    ) -> Result<()> {
1125        test_support!(self, {
1126            let insert_query = "
1127                INSERT INTO access_tokens (user_id, hash)
1128                VALUES ($1, $2);
1129            ";
1130            let cleanup_query = "
1131                DELETE FROM access_tokens
1132                WHERE id IN (
1133                    SELECT id from access_tokens
1134                    WHERE user_id = $1
1135                    ORDER BY id DESC
1136                    LIMIT 10000
1137                    OFFSET $3
1138                )
1139            ";
1140
1141            let mut tx = self.pool.begin().await?;
1142            sqlx::query(insert_query)
1143                .bind(user_id.0)
1144                .bind(access_token_hash)
1145                .execute(&mut tx)
1146                .await?;
1147            sqlx::query(cleanup_query)
1148                .bind(user_id.0)
1149                .bind(access_token_hash)
1150                .bind(max_access_token_count as i32)
1151                .execute(&mut tx)
1152                .await?;
1153            Ok(tx.commit().await?)
1154        })
1155    }
1156
1157    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1158        test_support!(self, {
1159            let query = "
1160                SELECT hash
1161                FROM access_tokens
1162                WHERE user_id = $1
1163                ORDER BY id DESC
1164            ";
1165            Ok(sqlx::query_scalar(query)
1166                .bind(user_id.0)
1167                .fetch_all(&self.pool)
1168                .await?)
1169        })
1170    }
1171}
1172
/// Declares a newtype id `$name` wrapping an `i32`.
///
/// The generated type derives `Copy`, ordering, hashing and `Default`. It is
/// encoded as its inner `i32` by both sqlx (`#[sqlx(transparent)]`) and serde
/// (`#[serde(transparent)]`). It also gets a `MAX` constant, `u64` conversion
/// helpers and a `Display` impl that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // `#[allow(unused)]` presumably because not every id type uses every
        // helper below.
        impl $name {
            // The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64` (presumably a protobuf id field).
            // NOTE(review): `as i32` silently wraps values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Widens the id to `u64`. NOTE(review): a negative inner value
            // would sign-extend into a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Delegates to the inner `i32`'s `Display`, so formatter flags such as
        // width and padding are honored.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1215
id_type!(UserId);
/// A user record loaded via `FromRow`, which maps result columns to fields by
/// name (presumably a row of the `users` table, e.g. from `SELECT * FROM users`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    // Users can be looked up by this code (`WHERE invite_code = $1`).
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1228
id_type!(ProjectId);
/// A project record loaded via `FromRow` (columns map to fields by name).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // Set to TRUE by `Db::unregister_project`; the row itself is not deleted.
    pub unregistered: bool,
}
1236
/// A contact relationship as seen by one user (see `Db::get_contacts`).
///
/// In every variant, `user_id` identifies the other user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request was accepted. `should_notify` can only be true for the user
    /// who originally sent the request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1251
1252impl Contact {
1253    pub fn user_id(&self) -> UserId {
1254        match self {
1255            Contact::Accepted { user_id, .. } => *user_id,
1256            Contact::Outgoing { user_id } => *user_id,
1257            Contact::Incoming { user_id, .. } => *user_id,
1258        }
1259    }
1260}
1261
/// Pairs a requester id with a notification flag.
///
/// NOTE(review): not constructed in this part of the file. The meaning
/// (a pending request from `requester_id`) is inferred from the field names.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1267
/// Sign-up data deserialized with serde.
///
/// NOTE(review): presumably the body of a waitlist sign-up request; confirm
/// against the caller.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    // One flag per platform the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
}
1279
/// Aggregate counts loaded via `FromRow` (presumably over waitlist sign-ups).
///
/// Every field is `#[sqlx(default)]`, so a column missing from the result row
/// yields `0` instead of a decoding error.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1293
/// An email address paired with its confirmation code, loadable via `FromRow`.
///
/// NOTE(review): the code is presumably produced by
/// `random_email_confirmation_code`; confirm.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1299
/// Input fields for creating a user (serde-serializable).
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1306
/// Values reported back after a user is created.
///
/// NOTE(review): the two `Option` fields are presumably `None` when the user
/// was not invited or no signup device was recorded; confirm at the
/// construction site.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1314
1315fn random_invite_code() -> String {
1316    nanoid::nanoid!(16)
1317}
1318
1319fn random_email_confirmation_code() -> String {
1320    nanoid::nanoid!(64)
1321}
1322
1323#[cfg(test)]
1324pub use test::*;
1325
#[cfg(test)]
// Test-only helpers that build fresh, migrated databases backed by their own
// single-threaded tokio runtime.
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // A connection detached from the pool and held for this struct's
        // lifetime. NOTE(review): presumably keeps the in-memory database alive,
        // since SQLite discards it when its last connection closes; confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A freshly created, migrated Postgres database for tests; torn down on drop.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates and migrates a uniquely named in-memory SQLite database.
        ///
        /// `background` is stored on the `Db` so `test_support!` can inject
        /// random delays; the runtime is stored so queries can `block_on` it.
        /// Panics if setup fails.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // A random name keeps databases of different tests apart.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Borrows the database; panics if it has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates, connects to and migrates a uniquely named Postgres database
        /// on `localhost`. Panics if any step fails.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until `new` returns, so database creation and migration are
            // serialized across threads in this process. NOTE(review):
            // presumably to avoid concurrent `CREATE DATABASE` conflicts.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Borrows the database; panics if it has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        // Hands the database to `Db::teardown` (defined elsewhere).
        // NOTE(review): presumably drops the test database at `self.url`; confirm.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}