db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Database backend selected for this build: SQLite under `cfg(test)`.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database backend selected for this build: Postgres outside of `cfg(test)`.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;

/// A pooled connection to a database of backend `D`.
///
/// In test builds two extra fields exist, consumed by `test_support!`:
/// an optional background executor whose `simulate_random_delay` is awaited
/// before each operation, and an optional runtime on which each operation is
/// run with `block_on`. Both are `None` as constructed by `new`; presumably
/// test setup installs them — TODO confirm.
pub struct Db<D: sqlx::Database> {
    // Connection pool for every query issued through this handle.
    pool: sqlx::Pool<D>,
    // Test-only: source of simulated random delays before queries.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only: runtime that executes query bodies (see `test_support!`).
    // NOTE: fields drop in declaration order, so `pool` is dropped before this runtime.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
  29macro_rules! test_support {
  30    ($self:ident, { $($token:tt)* }) => {{
  31        let body = async {
  32            $($token)*
  33        };
  34
  35        if cfg!(test) {
  36            #[cfg(test)]
  37            if let Some(background) = $self.background.as_ref() {
  38                background.simulate_random_delay().await;
  39            }
  40            #[cfg(test)]
  41            $self.runtime.as_ref().unwrap().block_on(body)
  42        } else {
  43            body.await
  44        }
  45    }};
  46}
  47
  48pub trait RowsAffected {
  49    fn rows_affected(&self) -> u64;
  50}
  51
  52impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  53    fn rows_affected(&self) -> u64 {
  54        self.rows_affected()
  55    }
  56}
  57
  58impl RowsAffected for sqlx::postgres::PgQueryResult {
  59    fn rows_affected(&self) -> u64 {
  60        self.rows_affected()
  61    }
  62}
  63
  64#[cfg(test)]
  65impl Db<sqlx::Sqlite> {
  66    const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
  67
  68    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  69        use std::str::FromStr as _;
  70        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  71            .unwrap()
  72            .create_if_missing(true)
  73            .shared_cache(true);
  74        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  75            .min_connections(2)
  76            .max_connections(max_connections)
  77            .connect_with(options)
  78            .await?;
  79        Ok(Self {
  80            pool,
  81            background: None,
  82            runtime: None,
  83        })
  84    }
  85
  86    #[cfg(test)]
  87    pub fn teardown(&self, _url: &str) {}
  88
  89    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT metrics_id
  93                FROM users
  94                WHERE id = $1
  95            ";
  96            Ok(sqlx::query_scalar(query)
  97                .bind(id)
  98                .fetch_one(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn create_user(
 104        &self,
 105        email_address: &str,
 106        admin: bool,
 107        params: NewUserParams,
 108    ) -> Result<NewUserResult> {
 109        test_support!(self, {
 110            let query = "
 111                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 112                VALUES ($1, $2, $3, $4, $5)
 113                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 114                RETURNING id, metrics_id
 115            ";
 116
 117            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 118                .bind(email_address)
 119                .bind(params.github_login)
 120                .bind(params.github_user_id)
 121                .bind(admin)
 122                .bind(Uuid::new_v4().to_string())
 123                .fetch_one(&self.pool)
 124                .await?;
 125            Ok(NewUserResult {
 126                user_id,
 127                metrics_id,
 128                signup_device_id: None,
 129                inviting_user_id: None,
 130            })
 131        })
 132    }
 133
 134    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 135        unimplemented!()
 136    }
 137
 138    pub async fn create_user_from_invite(
 139        &self,
 140        _invite: &Invite,
 141        _user: NewUserParams,
 142    ) -> Result<Option<NewUserResult>> {
 143        unimplemented!()
 144    }
 145
 146    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 147        unimplemented!()
 148    }
 149
 150    pub async fn create_invite_from_code(
 151        &self,
 152        _code: &str,
 153        _email_address: &str,
 154        _device_id: Option<&str>,
 155    ) -> Result<Invite> {
 156        unimplemented!()
 157    }
 158}
 159
 160impl Db<sqlx::Postgres> {
 161    const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 162
 163    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 164        let pool = sqlx::postgres::PgPoolOptions::new()
 165            .max_connections(max_connections)
 166            .connect(url)
 167            .await?;
 168        Ok(Self {
 169            pool,
 170            #[cfg(test)]
 171            background: None,
 172            #[cfg(test)]
 173            runtime: None,
 174        })
 175    }
 176
 177    #[cfg(test)]
 178    pub fn teardown(&self, url: &str) {
 179        self.runtime.as_ref().unwrap().block_on(async {
 180            use util::ResultExt;
 181            let query = "
 182                SELECT pg_terminate_backend(pg_stat_activity.pid)
 183                FROM pg_stat_activity
 184                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 185            ";
 186            sqlx::query(query).execute(&self.pool).await.log_err();
 187            self.pool.close().await;
 188            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 189                .await
 190                .log_err();
 191        })
 192    }
 193
 194    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 195        test_support!(self, {
 196            let like_string = Self::fuzzy_like_string(name_query);
 197            let query = "
 198                SELECT users.*
 199                FROM users
 200                WHERE github_login ILIKE $1
 201                ORDER BY github_login <-> $2
 202                LIMIT $3
 203            ";
 204            Ok(sqlx::query_as(query)
 205                .bind(like_string)
 206                .bind(name_query)
 207                .bind(limit as i32)
 208                .fetch_all(&self.pool)
 209                .await?)
 210        })
 211    }
 212
 213    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 214        test_support!(self, {
 215            let query = "
 216                SELECT metrics_id::text
 217                FROM users
 218                WHERE id = $1
 219            ";
 220            Ok(sqlx::query_scalar(query)
 221                .bind(id)
 222                .fetch_one(&self.pool)
 223                .await?)
 224        })
 225    }
 226
 227    pub async fn create_user(
 228        &self,
 229        email_address: &str,
 230        admin: bool,
 231        params: NewUserParams,
 232    ) -> Result<NewUserResult> {
 233        test_support!(self, {
 234            let query = "
 235                INSERT INTO users (email_address, github_login, github_user_id, admin)
 236                VALUES ($1, $2, $3, $4)
 237                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 238                RETURNING id, metrics_id::text
 239            ";
 240
 241            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 242                .bind(email_address)
 243                .bind(params.github_login)
 244                .bind(params.github_user_id)
 245                .bind(admin)
 246                .fetch_one(&self.pool)
 247                .await?;
 248            Ok(NewUserResult {
 249                user_id,
 250                metrics_id,
 251                signup_device_id: None,
 252                inviting_user_id: None,
 253            })
 254        })
 255    }
 256
 257    pub async fn create_user_from_invite(
 258        &self,
 259        invite: &Invite,
 260        user: NewUserParams,
 261    ) -> Result<Option<NewUserResult>> {
 262        test_support!(self, {
 263            let mut tx = self.pool.begin().await?;
 264
 265            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 266                i32,
 267                Option<UserId>,
 268                Option<UserId>,
 269                Option<String>,
 270            ) = sqlx::query_as(
 271                "
 272                SELECT id, user_id, inviting_user_id, device_id
 273                FROM signups
 274                WHERE
 275                    email_address = $1 AND
 276                    email_confirmation_code = $2
 277                ",
 278            )
 279            .bind(&invite.email_address)
 280            .bind(&invite.email_confirmation_code)
 281            .fetch_optional(&mut tx)
 282            .await?
 283            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 284
 285            if existing_user_id.is_some() {
 286                return Ok(None);
 287            }
 288
 289            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 290                "
 291                INSERT INTO users
 292                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 293                VALUES
 294                ($1, $2, $3, FALSE, $4, $5)
 295                ON CONFLICT (github_login) DO UPDATE SET
 296                    email_address = excluded.email_address,
 297                    github_user_id = excluded.github_user_id,
 298                    admin = excluded.admin
 299                RETURNING id, metrics_id::text
 300                ",
 301            )
 302            .bind(&invite.email_address)
 303            .bind(&user.github_login)
 304            .bind(&user.github_user_id)
 305            .bind(&user.invite_count)
 306            .bind(random_invite_code())
 307            .fetch_one(&mut tx)
 308            .await?;
 309
 310            sqlx::query(
 311                "
 312                UPDATE signups
 313                SET user_id = $1
 314                WHERE id = $2
 315                ",
 316            )
 317            .bind(&user_id)
 318            .bind(&signup_id)
 319            .execute(&mut tx)
 320            .await?;
 321
 322            if let Some(inviting_user_id) = inviting_user_id {
 323                let id: Option<UserId> = sqlx::query_scalar(
 324                    "
 325                    UPDATE users
 326                    SET invite_count = invite_count - 1
 327                    WHERE id = $1 AND invite_count > 0
 328                    RETURNING id
 329                    ",
 330                )
 331                .bind(&inviting_user_id)
 332                .fetch_optional(&mut tx)
 333                .await?;
 334
 335                if id.is_none() {
 336                    Err(Error::Http(
 337                        StatusCode::UNAUTHORIZED,
 338                        "no invites remaining".to_string(),
 339                    ))?;
 340                }
 341
 342                sqlx::query(
 343                    "
 344                    INSERT INTO contacts
 345                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 346                    VALUES
 347                        ($1, $2, TRUE, TRUE, TRUE)
 348                    ON CONFLICT DO NOTHING
 349                    ",
 350                )
 351                .bind(inviting_user_id)
 352                .bind(user_id)
 353                .execute(&mut tx)
 354                .await?;
 355            }
 356
 357            tx.commit().await?;
 358            Ok(Some(NewUserResult {
 359                user_id,
 360                metrics_id,
 361                inviting_user_id,
 362                signup_device_id,
 363            }))
 364        })
 365    }
 366
 367    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 368        test_support!(self, {
 369            sqlx::query(
 370                "
 371                INSERT INTO signups
 372                (
 373                    email_address,
 374                    email_confirmation_code,
 375                    email_confirmation_sent,
 376                    platform_linux,
 377                    platform_mac,
 378                    platform_windows,
 379                    platform_unknown,
 380                    editor_features,
 381                    programming_languages,
 382                    device_id
 383                )
 384                VALUES
 385                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6)
 386                RETURNING id
 387                ",
 388            )
 389            .bind(&signup.email_address)
 390            .bind(&random_email_confirmation_code())
 391            .bind(&signup.platform_linux)
 392            .bind(&signup.platform_mac)
 393            .bind(&signup.platform_windows)
 394            .bind(&signup.editor_features)
 395            .bind(&signup.programming_languages)
 396            .bind(&signup.device_id)
 397            .execute(&self.pool)
 398            .await?;
 399            Ok(())
 400        })
 401    }
 402
 403    pub async fn create_invite_from_code(
 404        &self,
 405        code: &str,
 406        email_address: &str,
 407        device_id: Option<&str>,
 408    ) -> Result<Invite> {
 409        test_support!(self, {
 410            let mut tx = self.pool.begin().await?;
 411
 412            let existing_user: Option<UserId> = sqlx::query_scalar(
 413                "
 414                SELECT id
 415                FROM users
 416                WHERE email_address = $1
 417                ",
 418            )
 419            .bind(email_address)
 420            .fetch_optional(&mut tx)
 421            .await?;
 422            if existing_user.is_some() {
 423                Err(anyhow!("email address is already in use"))?;
 424            }
 425
 426            let row: Option<(UserId, i32)> = sqlx::query_as(
 427                "
 428                SELECT id, invite_count
 429                FROM users
 430                WHERE invite_code = $1
 431                ",
 432            )
 433            .bind(code)
 434            .fetch_optional(&mut tx)
 435            .await?;
 436
 437            let (inviter_id, invite_count) = match row {
 438                Some(row) => row,
 439                None => Err(Error::Http(
 440                    StatusCode::NOT_FOUND,
 441                    "invite code not found".to_string(),
 442                ))?,
 443            };
 444
 445            if invite_count == 0 {
 446                Err(Error::Http(
 447                    StatusCode::UNAUTHORIZED,
 448                    "no invites remaining".to_string(),
 449                ))?;
 450            }
 451
 452            let email_confirmation_code: String = sqlx::query_scalar(
 453                "
 454                INSERT INTO signups
 455                (
 456                    email_address,
 457                    email_confirmation_code,
 458                    email_confirmation_sent,
 459                    inviting_user_id,
 460                    platform_linux,
 461                    platform_mac,
 462                    platform_windows,
 463                    platform_unknown,
 464                    device_id
 465                )
 466                VALUES
 467                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 468                ON CONFLICT (email_address)
 469                DO UPDATE SET
 470                    inviting_user_id = excluded.inviting_user_id
 471                RETURNING email_confirmation_code
 472                ",
 473            )
 474            .bind(&email_address)
 475            .bind(&random_email_confirmation_code())
 476            .bind(&inviter_id)
 477            .bind(&device_id)
 478            .fetch_one(&mut tx)
 479            .await?;
 480
 481            tx.commit().await?;
 482
 483            Ok(Invite {
 484                email_address: email_address.into(),
 485                email_confirmation_code,
 486            })
 487        })
 488    }
 489}
 490
 491impl<D> Db<D>
 492where
 493    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 494    D::Connection: sqlx::migrate::Migrate,
 495    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 496    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 497    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 498    D::QueryResult: RowsAffected,
 499    String: sqlx::Type<D>,
 500    i32: sqlx::Type<D>,
 501    i64: sqlx::Type<D>,
 502    bool: sqlx::Type<D>,
 503    str: sqlx::Type<D>,
 504    Uuid: sqlx::Type<D>,
 505    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 506    OffsetDateTime: sqlx::Type<D>,
 507    PrimitiveDateTime: sqlx::Type<D>,
 508    usize: sqlx::ColumnIndex<D::Row>,
 509    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 510    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 511    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 512    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 513    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 514    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 515    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 516    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 517    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 518    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 519    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 520    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 521{
 522    pub async fn migrate(
 523        &self,
 524        migrations_path: &Path,
 525        ignore_checksum_mismatch: bool,
 526    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 527        let migrations = MigrationSource::resolve(migrations_path)
 528            .await
 529            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 530
 531        let mut conn = self.pool.acquire().await?;
 532
 533        conn.ensure_migrations_table().await?;
 534        let applied_migrations: HashMap<_, _> = conn
 535            .list_applied_migrations()
 536            .await?
 537            .into_iter()
 538            .map(|m| (m.version, m))
 539            .collect();
 540
 541        let mut new_migrations = Vec::new();
 542        for migration in migrations {
 543            match applied_migrations.get(&migration.version) {
 544                Some(applied_migration) => {
 545                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 546                    {
 547                        Err(anyhow!(
 548                            "checksum mismatch for applied migration {}",
 549                            migration.description
 550                        ))?;
 551                    }
 552                }
 553                None => {
 554                    let elapsed = conn.apply(&migration).await?;
 555                    new_migrations.push((migration, elapsed));
 556                }
 557            }
 558        }
 559
 560        Ok(new_migrations)
 561    }
 562
 563    pub fn fuzzy_like_string(string: &str) -> String {
 564        let mut result = String::with_capacity(string.len() * 2 + 1);
 565        for c in string.chars() {
 566            if c.is_alphanumeric() {
 567                result.push('%');
 568                result.push(c);
 569            }
 570        }
 571        result.push('%');
 572        result
 573    }
 574
 575    // users
 576
 577    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 578        test_support!(self, {
 579            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 580            Ok(sqlx::query_as(query)
 581                .bind(limit as i32)
 582                .bind((page * limit) as i32)
 583                .fetch_all(&self.pool)
 584                .await?)
 585        })
 586    }
 587
 588    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 589        test_support!(self, {
 590            let query = "
 591                SELECT users.*
 592                FROM users
 593                WHERE id = $1
 594                LIMIT 1
 595            ";
 596            Ok(sqlx::query_as(query)
 597                .bind(&id)
 598                .fetch_optional(&self.pool)
 599                .await?)
 600        })
 601    }
 602
 603    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 604        test_support!(self, {
 605            let query = "
 606                SELECT users.*
 607                FROM users
 608                WHERE users.id IN (SELECT value from json_each($1))
 609            ";
 610            Ok(sqlx::query_as(query)
 611                .bind(&serde_json::json!(ids))
 612                .fetch_all(&self.pool)
 613                .await?)
 614        })
 615    }
 616
 617    pub async fn get_users_with_no_invites(
 618        &self,
 619        invited_by_another_user: bool,
 620    ) -> Result<Vec<User>> {
 621        test_support!(self, {
 622            let query = format!(
 623                "
 624                SELECT users.*
 625                FROM users
 626                WHERE invite_count = 0
 627                AND inviter_id IS{} NULL
 628                ",
 629                if invited_by_another_user { " NOT" } else { "" }
 630            );
 631
 632            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 633        })
 634    }
 635
    /// Finds the user for a GitHub account, syncing the stored login/id on the way.
    ///
    /// When `github_user_id` is provided:
    /// 1. the user with that `github_user_id` has its `github_login` set to
    ///    `github_login`;
    /// 2. if no such user exists, the user whose `github_login` matches gets
    ///    `github_user_id` recorded instead.
    ///
    /// The updated row is returned, or `None` if neither update matched a row.
    /// Without a `github_user_id`, this is a plain lookup by `github_login`.
    pub async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>> {
        test_support!(self, {
            if let Some(github_user_id) = github_user_id {
                // Match on the numeric id first (presumably because logins can be
                // renamed on GitHub while the id stays stable — TODO confirm).
                let mut user = sqlx::query_as::<_, User>(
                    "
                    UPDATE users
                    SET github_login = $1
                    WHERE github_user_id = $2
                    RETURNING *
                    ",
                )
                .bind(github_login)
                .bind(github_user_id)
                .fetch_optional(&self.pool)
                .await?;

                // Fall back to the login, backfilling the id for rows lacking it.
                if user.is_none() {
                    user = sqlx::query_as::<_, User>(
                        "
                        UPDATE users
                        SET github_user_id = $1
                        WHERE github_login = $2
                        RETURNING *
                        ",
                    )
                    .bind(github_user_id)
                    .bind(github_login)
                    .fetch_optional(&self.pool)
                    .await?;
                }

                Ok(user)
            } else {
                let user = sqlx::query_as(
                    "
                    SELECT * FROM users
                    WHERE github_login = $1
                    LIMIT 1
                    ",
                )
                .bind(github_login)
                .fetch_optional(&self.pool)
                .await?;
                Ok(user)
            }
        })
    }
 687
 688    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 689        test_support!(self, {
 690            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 691            Ok(sqlx::query(query)
 692                .bind(is_admin)
 693                .bind(id.0)
 694                .execute(&self.pool)
 695                .await
 696                .map(drop)?)
 697        })
 698    }
 699
 700    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 701        test_support!(self, {
 702            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 703            Ok(sqlx::query(query)
 704                .bind(connected_once)
 705                .bind(id.0)
 706                .execute(&self.pool)
 707                .await
 708                .map(drop)?)
 709        })
 710    }
 711
 712    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 713        test_support!(self, {
 714            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 715            sqlx::query(query)
 716                .bind(id.0)
 717                .execute(&self.pool)
 718                .await
 719                .map(drop)?;
 720            let query = "DELETE FROM users WHERE id = $1;";
 721            Ok(sqlx::query(query)
 722                .bind(id.0)
 723                .execute(&self.pool)
 724                .await
 725                .map(drop)?)
 726        })
 727    }
 728
 729    // signups
 730
 731    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 732        test_support!(self, {
 733            Ok(sqlx::query_as(
 734                "
 735                SELECT
 736                    COUNT(*) as count,
 737                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 738                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 739                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 740                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 741                FROM (
 742                    SELECT *
 743                    FROM signups
 744                    WHERE
 745                        NOT email_confirmation_sent
 746                ) AS unsent
 747                ",
 748            )
 749            .fetch_one(&self.pool)
 750            .await?)
 751        })
 752    }
 753
 754    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 755        test_support!(self, {
 756            Ok(sqlx::query_as(
 757                "
 758                SELECT
 759                    email_address, email_confirmation_code
 760                FROM signups
 761                WHERE
 762                    NOT email_confirmation_sent AND
 763                    (platform_mac OR platform_unknown)
 764                LIMIT $1
 765                ",
 766            )
 767            .bind(count as i32)
 768            .fetch_all(&self.pool)
 769            .await?)
 770        })
 771    }
 772
 773    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 774        test_support!(self, {
 775            let emails = invites
 776                .iter()
 777                .map(|s| s.email_address.as_str())
 778                .collect::<Vec<_>>();
 779            sqlx::query(
 780                "
 781                UPDATE signups
 782                SET email_confirmation_sent = TRUE
 783                WHERE email_address IN (SELECT value from json_each($1))
 784                ",
 785            )
 786            .bind(&serde_json::json!(emails))
 787            .execute(&self.pool)
 788            .await?;
 789            Ok(())
 790        })
 791    }
 792
 793    // invite codes
 794
 795    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 796        test_support!(self, {
 797            let mut tx = self.pool.begin().await?;
 798            if count > 0 {
 799                sqlx::query(
 800                    "
 801                    UPDATE users
 802                    SET invite_code = $1
 803                    WHERE id = $2 AND invite_code IS NULL
 804                ",
 805                )
 806                .bind(random_invite_code())
 807                .bind(id)
 808                .execute(&mut tx)
 809                .await?;
 810            }
 811
 812            sqlx::query(
 813                "
 814                UPDATE users
 815                SET invite_count = $1
 816                WHERE id = $2
 817                ",
 818            )
 819            .bind(count as i32)
 820            .bind(id)
 821            .execute(&mut tx)
 822            .await?;
 823            tx.commit().await?;
 824            Ok(())
 825        })
 826    }
 827
 828    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 829        test_support!(self, {
 830            let result: Option<(String, i32)> = sqlx::query_as(
 831                "
 832                    SELECT invite_code, invite_count
 833                    FROM users
 834                    WHERE id = $1 AND invite_code IS NOT NULL 
 835                ",
 836            )
 837            .bind(id)
 838            .fetch_optional(&self.pool)
 839            .await?;
 840            if let Some((code, count)) = result {
 841                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 842            } else {
 843                Ok(None)
 844            }
 845        })
 846    }
 847
 848    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 849        test_support!(self, {
 850            sqlx::query_as(
 851                "
 852                    SELECT *
 853                    FROM users
 854                    WHERE invite_code = $1
 855                ",
 856            )
 857            .bind(code)
 858            .fetch_optional(&self.pool)
 859            .await?
 860            .ok_or_else(|| {
 861                Error::Http(
 862                    StatusCode::NOT_FOUND,
 863                    "that invite code does not exist".to_string(),
 864                )
 865            })
 866        })
 867    }
 868
 869    // projects
 870
 871    /// Registers a new project for the given user.
 872    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 873        test_support!(self, {
 874            Ok(sqlx::query_scalar(
 875                "
 876                INSERT INTO projects(host_user_id)
 877                VALUES ($1)
 878                RETURNING id
 879                ",
 880            )
 881            .bind(host_user_id)
 882            .fetch_one(&self.pool)
 883            .await
 884            .map(ProjectId)?)
 885        })
 886    }
 887
 888    /// Unregisters a project for the given project id.
 889    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 890        test_support!(self, {
 891            sqlx::query(
 892                "
 893                UPDATE projects
 894                SET unregistered = TRUE
 895                WHERE id = $1
 896                ",
 897            )
 898            .bind(project_id)
 899            .execute(&self.pool)
 900            .await?;
 901            Ok(())
 902        })
 903    }
 904
 905    // contacts
 906
 907    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 908        test_support!(self, {
 909            let query = "
 910                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 911                FROM contacts
 912                WHERE user_id_a = $1 OR user_id_b = $1;
 913            ";
 914
 915            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 916                .bind(user_id)
 917                .fetch(&self.pool);
 918
 919            let mut contacts = Vec::new();
 920            while let Some(row) = rows.next().await {
 921                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 922
 923                if user_id_a == user_id {
 924                    if accepted {
 925                        contacts.push(Contact::Accepted {
 926                            user_id: user_id_b,
 927                            should_notify: should_notify && a_to_b,
 928                        });
 929                    } else if a_to_b {
 930                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 931                    } else {
 932                        contacts.push(Contact::Incoming {
 933                            user_id: user_id_b,
 934                            should_notify,
 935                        });
 936                    }
 937                } else if accepted {
 938                    contacts.push(Contact::Accepted {
 939                        user_id: user_id_a,
 940                        should_notify: should_notify && !a_to_b,
 941                    });
 942                } else if a_to_b {
 943                    contacts.push(Contact::Incoming {
 944                        user_id: user_id_a,
 945                        should_notify,
 946                    });
 947                } else {
 948                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 949                }
 950            }
 951
 952            contacts.sort_unstable_by_key(|contact| contact.user_id());
 953
 954            Ok(contacts)
 955        })
 956    }
 957
 958    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 959        test_support!(self, {
 960            let (id_a, id_b) = if user_id_1 < user_id_2 {
 961                (user_id_1, user_id_2)
 962            } else {
 963                (user_id_2, user_id_1)
 964            };
 965
 966            let query = "
 967                SELECT 1 FROM contacts
 968                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 969                LIMIT 1
 970            ";
 971            Ok(sqlx::query_scalar::<_, i32>(query)
 972                .bind(id_a.0)
 973                .bind(id_b.0)
 974                .fetch_optional(&self.pool)
 975                .await?
 976                .is_some())
 977        })
 978    }
 979
 980    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 981        test_support!(self, {
 982            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 983                (sender_id, receiver_id, true)
 984            } else {
 985                (receiver_id, sender_id, false)
 986            };
 987            let query = "
 988                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 989                VALUES ($1, $2, $3, FALSE, TRUE)
 990                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 991                SET
 992                    accepted = TRUE,
 993                    should_notify = FALSE
 994                WHERE
 995                    NOT contacts.accepted AND
 996                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 997                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 998            ";
 999            let result = sqlx::query(query)
1000                .bind(id_a.0)
1001                .bind(id_b.0)
1002                .bind(a_to_b)
1003                .execute(&self.pool)
1004                .await?;
1005
1006            if result.rows_affected() == 1 {
1007                Ok(())
1008            } else {
1009                Err(anyhow!("contact already requested"))?
1010            }
1011        })
1012    }
1013
1014    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1015        test_support!(self, {
1016            let (id_a, id_b) = if responder_id < requester_id {
1017                (responder_id, requester_id)
1018            } else {
1019                (requester_id, responder_id)
1020            };
1021            let query = "
1022                DELETE FROM contacts
1023                WHERE user_id_a = $1 AND user_id_b = $2;
1024            ";
1025            let result = sqlx::query(query)
1026                .bind(id_a.0)
1027                .bind(id_b.0)
1028                .execute(&self.pool)
1029                .await?;
1030
1031            if result.rows_affected() == 1 {
1032                Ok(())
1033            } else {
1034                Err(anyhow!("no such contact"))?
1035            }
1036        })
1037    }
1038
1039    pub async fn dismiss_contact_notification(
1040        &self,
1041        user_id: UserId,
1042        contact_user_id: UserId,
1043    ) -> Result<()> {
1044        test_support!(self, {
1045            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1046                (user_id, contact_user_id, true)
1047            } else {
1048                (contact_user_id, user_id, false)
1049            };
1050
1051            let query = "
1052                UPDATE contacts
1053                SET should_notify = FALSE
1054                WHERE
1055                    user_id_a = $1 AND user_id_b = $2 AND
1056                    (
1057                        (a_to_b = $3 AND accepted) OR
1058                        (a_to_b != $3 AND NOT accepted)
1059                    );
1060            ";
1061
1062            let result = sqlx::query(query)
1063                .bind(id_a.0)
1064                .bind(id_b.0)
1065                .bind(a_to_b)
1066                .execute(&self.pool)
1067                .await?;
1068
1069            if result.rows_affected() == 0 {
1070                Err(anyhow!("no such contact request"))?;
1071            }
1072
1073            Ok(())
1074        })
1075    }
1076
1077    pub async fn respond_to_contact_request(
1078        &self,
1079        responder_id: UserId,
1080        requester_id: UserId,
1081        accept: bool,
1082    ) -> Result<()> {
1083        test_support!(self, {
1084            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1085                (responder_id, requester_id, false)
1086            } else {
1087                (requester_id, responder_id, true)
1088            };
1089            let result = if accept {
1090                let query = "
1091                    UPDATE contacts
1092                    SET accepted = TRUE, should_notify = TRUE
1093                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1094                ";
1095                sqlx::query(query)
1096                    .bind(id_a.0)
1097                    .bind(id_b.0)
1098                    .bind(a_to_b)
1099                    .execute(&self.pool)
1100                    .await?
1101            } else {
1102                let query = "
1103                    DELETE FROM contacts
1104                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1105                ";
1106                sqlx::query(query)
1107                    .bind(id_a.0)
1108                    .bind(id_b.0)
1109                    .bind(a_to_b)
1110                    .execute(&self.pool)
1111                    .await?
1112            };
1113            if result.rows_affected() == 1 {
1114                Ok(())
1115            } else {
1116                Err(anyhow!("no such contact request"))?
1117            }
1118        })
1119    }
1120
1121    // access tokens
1122
1123    pub async fn create_access_token_hash(
1124        &self,
1125        user_id: UserId,
1126        access_token_hash: &str,
1127        max_access_token_count: usize,
1128    ) -> Result<()> {
1129        test_support!(self, {
1130            let insert_query = "
1131                INSERT INTO access_tokens (user_id, hash)
1132                VALUES ($1, $2);
1133            ";
1134            let cleanup_query = "
1135                DELETE FROM access_tokens
1136                WHERE id IN (
1137                    SELECT id from access_tokens
1138                    WHERE user_id = $1
1139                    ORDER BY id DESC
1140                    LIMIT 10000
1141                    OFFSET $3
1142                )
1143            ";
1144
1145            let mut tx = self.pool.begin().await?;
1146            sqlx::query(insert_query)
1147                .bind(user_id.0)
1148                .bind(access_token_hash)
1149                .execute(&mut tx)
1150                .await?;
1151            sqlx::query(cleanup_query)
1152                .bind(user_id.0)
1153                .bind(access_token_hash)
1154                .bind(max_access_token_count as i32)
1155                .execute(&mut tx)
1156                .await?;
1157            Ok(tx.commit().await?)
1158        })
1159    }
1160
1161    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1162        test_support!(self, {
1163            let query = "
1164                SELECT hash
1165                FROM access_tokens
1166                WHERE user_id = $1
1167                ORDER BY id DESC
1168            ";
1169            Ok(sqlx::query_scalar(query)
1170                .bind(user_id.0)
1171                .fetch_all(&self.pool)
1172                .await?)
1173        })
1174    }
1175}
1176
/// Declares an `i32`-backed newtype id named `$name`.
///
/// The generated type:
/// - is encoded and decoded as its inner `i32` by both sqlx and serde
///   (`transparent`);
/// - is totally ordered;
/// - converts to and from `u64` via `from_proto` / `to_proto`, presumably the
///   protocol representation.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default,
            PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its `u64` protocol representation.
            ///
            /// NOTE(review): the `as` cast keeps only the low 32 bits, so values
            /// above `i32::MAX` wrap rather than error.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to its `u64` protocol representation.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1219
// `UserId`: `i32`-backed newtype id for users, generated by `id_type!`.
id_type!(UserId);
/// A row of the `users` table, loaded via `sqlx::FromRow`
/// (e.g. `SELECT * FROM users` in `get_user_for_invite_code`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Looked up by `get_user_for_invite_code`; `None` when the user has no code.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites this user may still send — confirm.
    pub invite_count: i32,
    pub connected_once: bool,
}
1232
// `ProjectId`: `i32`-backed newtype id for projects, generated by `id_type!`.
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user passed to `register_project`.
    pub host_user_id: UserId,
    /// Set to TRUE by `unregister_project`; the row itself is not deleted.
    pub unregistered: bool,
}
1240
/// One of a user's contacts, expressed from that user's point of view.
///
/// `user_id` always names the *other* user (see `Db::get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request was accepted. `should_notify` can only be true for the user
    /// who sent the original request, until dismissed via
    /// `dismiss_contact_notification`.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// The user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// The user received a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1255
1256impl Contact {
1257    pub fn user_id(&self) -> UserId {
1258        match self {
1259            Contact::Accepted { user_id, .. } => *user_id,
1260            Contact::Outgoing { user_id } => *user_id,
1261            Contact::Incoming { user_id, .. } => *user_id,
1262        }
1263    }
1264}
1265
/// A pending contact request received from `requester_id`.
///
/// Ordering (derived) compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    // NOTE(review): presumably mirrors `contacts.should_notify` for this request — confirm.
    pub should_notify: bool,
}
1271
/// Waitlist signup details. The type is only `Deserialize`, presumably parsed
/// from an incoming request body — confirm.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // One flag per platform the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1282
/// Aggregate counts over waitlist signups.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the result row
/// falls back to `i64::default()` (0) instead of failing the row decode.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    /// Total number of signups.
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    // NOTE(review): presumably signups that selected no platform — confirm.
    #[sqlx(default)]
    pub unknown_count: i64,
}
1296
/// An email address paired with its email confirmation code.
// NOTE(review): codes are presumably produced by `random_email_confirmation_code` — confirm.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1302
/// Parameters describing a user account to be created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): presumably the initial value for `User::invite_count` — confirm.
    pub invite_count: i32,
}
1309
/// Outcome of creating a user account.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the newly created user.
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this signup, if any.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1317
1318fn random_invite_code() -> String {
1319    nanoid::nanoid!(16)
1320}
1321
1322fn random_email_confirmation_code() -> String {
1323    nanoid::nanoid!(64)
1324}
1325
1326#[cfg(test)]
1327pub use test::*;
1328
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use rand::prelude::*;
    use std::sync::Arc;

    /// A throwaway SQLite database for tests, torn down when dropped.
    pub struct TestDb {
        pub db: Option<Arc<DefaultDb>>,
        /// A connection detached from the pool and held for the `TestDb`'s
        /// lifetime. NOTE(review): presumably keeps the in-memory database
        /// alive independently of pool connections — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
        pub url: String,
    }

    impl TestDb {
        /// Creates and prepares a fresh test database:
        /// - a uniquely named in-memory SQLite database;
        /// - the SQLite migrations applied to it;
        /// - `background` attached, which `test_support!` uses to inject random
        ///   delays;
        /// - a dedicated current-thread tokio runtime, which `test_support!`
        ///   uses to block on queries.
        pub fn new(background: Arc<Background>) -> Self {
            // A random suffix gives every `TestDb` its own database name.
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = DefaultDb::new(&url, 5).await.unwrap();
                let migrations_dir = Path::new(DefaultDb::MIGRATIONS_PATH);
                db.migrate(migrations_dir, false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });
            db.runtime = Some(runtime);
            db.background = Some(background);

            Self {
                url,
                conn,
                db: Some(Arc::new(db)),
            }
        }

        /// Returns the shared database handle.
        ///
        /// # Panics
        ///
        /// Panics if the handle was already taken, which only `drop` does.
        pub fn db(&self) -> &Arc<DefaultDb> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// Takes the database handle and tears the database down.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}