db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Default database handle in test builds: a [`Db`] backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Default database handle in non-test builds: a [`Db`] backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  20
/// A connection pool for the database backend `D`, plus test-only execution hooks.
pub struct Db<D: sqlx::Database> {
    /// Connection pool that the queries in this file run against.
    pool: sqlx::Pool<D>,
    /// Test-only executor; when set, `test_support!` awaits its
    /// `simulate_random_delay()` before running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime on which `test_support!` (and `teardown`) block to run
    /// their async bodies.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
// Wraps the body of a database method so the same tokens run in test and non-test builds.
//
// The tokens are placed in an `async` block. When `cfg!(test)` is true, the macro first
// awaits `simulate_random_delay()` on `$self.background` (if one is set) and then drives
// the block to completion with `block_on` on `$self.runtime`, panicking if `runtime` is
// `None`. Otherwise the block is simply awaited.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // This branch is still compiled (though never taken) in non-test builds, where
            // the `background`/`runtime` fields do not exist; the `#[cfg]` attributes
            // below replace its contents with `unreachable!()` there.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Runs the body synchronously on the test runtime.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  51
/// Backend-independent access to the number of rows a statement affected.
///
/// The generic `impl<D> Db<D>` requires `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
  55
  56#[cfg(test)]
  57impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  58    fn rows_affected(&self) -> u64 {
  59        self.rows_affected()
  60    }
  61}
  62
  63impl RowsAffected for sqlx::postgres::PgQueryResult {
  64    fn rows_affected(&self) -> u64 {
  65        self.rows_affected()
  66    }
  67}
  68
  69#[cfg(test)]
  70impl Db<sqlx::Sqlite> {
  71    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  72        use std::str::FromStr as _;
  73        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  74            .unwrap()
  75            .create_if_missing(true)
  76            .shared_cache(true);
  77        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  78            .min_connections(2)
  79            .max_connections(max_connections)
  80            .connect_with(options)
  81            .await?;
  82        Ok(Self {
  83            pool,
  84            background: None,
  85            runtime: None,
  86        })
  87    }
  88
  89    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT users.*
  93                FROM users
  94                WHERE users.id IN (SELECT value from json_each($1))
  95            ";
  96            Ok(sqlx::query_as(query)
  97                .bind(&serde_json::json!(ids))
  98                .fetch_all(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 104        test_support!(self, {
 105            let query = "
 106                SELECT metrics_id
 107                FROM users
 108                WHERE id = $1
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(id)
 112                .fetch_one(&self.pool)
 113                .await?)
 114        })
 115    }
 116
 117    pub async fn create_user(
 118        &self,
 119        email_address: &str,
 120        admin: bool,
 121        params: NewUserParams,
 122    ) -> Result<NewUserResult> {
 123        test_support!(self, {
 124            let query = "
 125                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 126                VALUES ($1, $2, $3, $4, $5)
 127                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 128                RETURNING id, metrics_id
 129            ";
 130
 131            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 132                .bind(email_address)
 133                .bind(params.github_login)
 134                .bind(params.github_user_id)
 135                .bind(admin)
 136                .bind(Uuid::new_v4().to_string())
 137                .fetch_one(&self.pool)
 138                .await?;
 139            Ok(NewUserResult {
 140                user_id,
 141                metrics_id,
 142                signup_device_id: None,
 143                inviting_user_id: None,
 144            })
 145        })
 146    }
 147
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }
 151
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }
 159
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }
 163
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
 172
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
 176}
 177
 178impl Db<sqlx::Postgres> {
 179    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 180        let pool = sqlx::postgres::PgPoolOptions::new()
 181            .max_connections(max_connections)
 182            .connect(url)
 183            .await?;
 184        Ok(Self {
 185            pool,
 186            #[cfg(test)]
 187            background: None,
 188            #[cfg(test)]
 189            runtime: None,
 190        })
 191    }
 192
 193    #[cfg(test)]
 194    pub fn teardown(&self, url: &str) {
 195        self.runtime.as_ref().unwrap().block_on(async {
 196            use util::ResultExt;
 197            let query = "
 198                SELECT pg_terminate_backend(pg_stat_activity.pid)
 199                FROM pg_stat_activity
 200                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 201            ";
 202            sqlx::query(query).execute(&self.pool).await.log_err();
 203            self.pool.close().await;
 204            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 205                .await
 206                .log_err();
 207        })
 208    }
 209
 210    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 211        test_support!(self, {
 212            let like_string = Self::fuzzy_like_string(name_query);
 213            let query = "
 214                SELECT users.*
 215                FROM users
 216                WHERE github_login ILIKE $1
 217                ORDER BY github_login <-> $2
 218                LIMIT $3
 219            ";
 220            Ok(sqlx::query_as(query)
 221                .bind(like_string)
 222                .bind(name_query)
 223                .bind(limit as i32)
 224                .fetch_all(&self.pool)
 225                .await?)
 226        })
 227    }
 228
 229    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 230        test_support!(self, {
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE users.id = ANY ($1)
 235            ";
 236            Ok(sqlx::query_as(query)
 237                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 238                .fetch_all(&self.pool)
 239                .await?)
 240        })
 241    }
 242
 243    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT metrics_id::text
 247                FROM users
 248                WHERE id = $1
 249            ";
 250            Ok(sqlx::query_scalar(query)
 251                .bind(id)
 252                .fetch_one(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    pub async fn create_user(
 258        &self,
 259        email_address: &str,
 260        admin: bool,
 261        params: NewUserParams,
 262    ) -> Result<NewUserResult> {
 263        test_support!(self, {
 264            let query = "
 265                INSERT INTO users (email_address, github_login, github_user_id, admin)
 266                VALUES ($1, $2, $3, $4)
 267                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 268                RETURNING id, metrics_id::text
 269            ";
 270
 271            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272                .bind(email_address)
 273                .bind(params.github_login)
 274                .bind(params.github_user_id)
 275                .bind(admin)
 276                .fetch_one(&self.pool)
 277                .await?;
 278            Ok(NewUserResult {
 279                user_id,
 280                metrics_id,
 281                signup_device_id: None,
 282                inviting_user_id: None,
 283            })
 284        })
 285    }
 286
 287    pub async fn create_user_from_invite(
 288        &self,
 289        invite: &Invite,
 290        user: NewUserParams,
 291    ) -> Result<Option<NewUserResult>> {
 292        test_support!(self, {
 293            let mut tx = self.pool.begin().await?;
 294
 295            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 296                i32,
 297                Option<UserId>,
 298                Option<UserId>,
 299                Option<String>,
 300            ) = sqlx::query_as(
 301                "
 302                SELECT id, user_id, inviting_user_id, device_id
 303                FROM signups
 304                WHERE
 305                    email_address = $1 AND
 306                    email_confirmation_code = $2
 307                ",
 308            )
 309            .bind(&invite.email_address)
 310            .bind(&invite.email_confirmation_code)
 311            .fetch_optional(&mut tx)
 312            .await?
 313            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 314
 315            if existing_user_id.is_some() {
 316                return Ok(None);
 317            }
 318
 319            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 320                "
 321                INSERT INTO users
 322                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 323                VALUES
 324                ($1, $2, $3, FALSE, $4, $5)
 325                ON CONFLICT (github_login) DO UPDATE SET
 326                    email_address = excluded.email_address,
 327                    github_user_id = excluded.github_user_id,
 328                    admin = excluded.admin
 329                RETURNING id, metrics_id::text
 330                ",
 331            )
 332            .bind(&invite.email_address)
 333            .bind(&user.github_login)
 334            .bind(&user.github_user_id)
 335            .bind(&user.invite_count)
 336            .bind(random_invite_code())
 337            .fetch_one(&mut tx)
 338            .await?;
 339
 340            sqlx::query(
 341                "
 342                UPDATE signups
 343                SET user_id = $1
 344                WHERE id = $2
 345                ",
 346            )
 347            .bind(&user_id)
 348            .bind(&signup_id)
 349            .execute(&mut tx)
 350            .await?;
 351
 352            if let Some(inviting_user_id) = inviting_user_id {
 353                sqlx::query(
 354                    "
 355                    INSERT INTO contacts
 356                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 357                    VALUES
 358                        ($1, $2, TRUE, TRUE, TRUE)
 359                    ON CONFLICT DO NOTHING
 360                    ",
 361                )
 362                .bind(inviting_user_id)
 363                .bind(user_id)
 364                .execute(&mut tx)
 365                .await?;
 366            }
 367
 368            tx.commit().await?;
 369            Ok(Some(NewUserResult {
 370                user_id,
 371                metrics_id,
 372                inviting_user_id,
 373                signup_device_id,
 374            }))
 375        })
 376    }
 377
 378    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 379        test_support!(self, {
 380            sqlx::query(
 381                "
 382                INSERT INTO signups
 383                (
 384                    email_address,
 385                    email_confirmation_code,
 386                    email_confirmation_sent,
 387                    platform_linux,
 388                    platform_mac,
 389                    platform_windows,
 390                    platform_unknown,
 391                    editor_features,
 392                    programming_languages,
 393                    device_id
 394                )
 395                VALUES
 396                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 397                RETURNING id
 398                ",
 399            )
 400            .bind(&signup.email_address)
 401            .bind(&random_email_confirmation_code())
 402            .bind(&signup.platform_linux)
 403            .bind(&signup.platform_mac)
 404            .bind(&signup.platform_windows)
 405            .bind(&signup.editor_features)
 406            .bind(&signup.programming_languages)
 407            .bind(&signup.device_id)
 408            .execute(&self.pool)
 409            .await?;
 410            Ok(())
 411        })
 412    }
 413
 414    pub async fn create_invite_from_code(
 415        &self,
 416        code: &str,
 417        email_address: &str,
 418        device_id: Option<&str>,
 419    ) -> Result<Invite> {
 420        test_support!(self, {
 421            let mut tx = self.pool.begin().await?;
 422
 423            let existing_user: Option<UserId> = sqlx::query_scalar(
 424                "
 425                SELECT id
 426                FROM users
 427                WHERE email_address = $1
 428                ",
 429            )
 430            .bind(email_address)
 431            .fetch_optional(&mut tx)
 432            .await?;
 433            if existing_user.is_some() {
 434                Err(anyhow!("email address is already in use"))?;
 435            }
 436
 437            let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
 438                "
 439                UPDATE users
 440                SET invite_count = invite_count - 1
 441                WHERE invite_code = $1 AND invite_count > 0
 442                RETURNING id
 443                ",
 444            )
 445            .bind(code)
 446            .fetch_optional(&mut tx)
 447            .await?;
 448
 449            let Some(inviter_id) = inviting_user_id_with_invites else {
 450                return Err(Error::Http(
 451                    StatusCode::UNAUTHORIZED,
 452                    "unable to find an invite code with invites remaining".to_string(),
 453                ));
 454            };
 455
 456            let email_confirmation_code: String = sqlx::query_scalar(
 457                "
 458                INSERT INTO signups
 459                (
 460                    email_address,
 461                    email_confirmation_code,
 462                    email_confirmation_sent,
 463                    inviting_user_id,
 464                    platform_linux,
 465                    platform_mac,
 466                    platform_windows,
 467                    platform_unknown,
 468                    device_id
 469                )
 470                VALUES
 471                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 472                ON CONFLICT (email_address)
 473                DO UPDATE SET
 474                    inviting_user_id = excluded.inviting_user_id
 475                RETURNING email_confirmation_code
 476                ",
 477            )
 478            .bind(&email_address)
 479            .bind(&random_email_confirmation_code())
 480            .bind(&inviter_id)
 481            .bind(&device_id)
 482            .fetch_one(&mut tx)
 483            .await?;
 484
 485            tx.commit().await?;
 486
 487            Ok(Invite {
 488                email_address: email_address.into(),
 489                email_confirmation_code,
 490            })
 491        })
 492    }
 493
 494    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 495        test_support!(self, {
 496            let emails = invites
 497                .iter()
 498                .map(|s| s.email_address.as_str())
 499                .collect::<Vec<_>>();
 500            sqlx::query(
 501                "
 502                UPDATE signups
 503                SET email_confirmation_sent = TRUE
 504                WHERE email_address = ANY ($1)
 505                ",
 506            )
 507            .bind(&emails)
 508            .execute(&self.pool)
 509            .await?;
 510            Ok(())
 511        })
 512    }
 513}
 514
 515impl<D> Db<D>
 516where
 517    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 518    D::Connection: sqlx::migrate::Migrate,
 519    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 520    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 521    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 522    D::QueryResult: RowsAffected,
 523    String: sqlx::Type<D>,
 524    i32: sqlx::Type<D>,
 525    i64: sqlx::Type<D>,
 526    bool: sqlx::Type<D>,
 527    str: sqlx::Type<D>,
 528    Uuid: sqlx::Type<D>,
 529    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 530    OffsetDateTime: sqlx::Type<D>,
 531    PrimitiveDateTime: sqlx::Type<D>,
 532    usize: sqlx::ColumnIndex<D::Row>,
 533    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 534    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 535    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 536    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 537    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 538    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 539    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 540    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 541    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 542    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 543    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 544    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 545{
 546    pub async fn migrate(
 547        &self,
 548        migrations_path: &Path,
 549        ignore_checksum_mismatch: bool,
 550    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 551        let migrations = MigrationSource::resolve(migrations_path)
 552            .await
 553            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 554
 555        let mut conn = self.pool.acquire().await?;
 556
 557        conn.ensure_migrations_table().await?;
 558        let applied_migrations: HashMap<_, _> = conn
 559            .list_applied_migrations()
 560            .await?
 561            .into_iter()
 562            .map(|m| (m.version, m))
 563            .collect();
 564
 565        let mut new_migrations = Vec::new();
 566        for migration in migrations {
 567            match applied_migrations.get(&migration.version) {
 568                Some(applied_migration) => {
 569                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 570                    {
 571                        Err(anyhow!(
 572                            "checksum mismatch for applied migration {}",
 573                            migration.description
 574                        ))?;
 575                    }
 576                }
 577                None => {
 578                    let elapsed = conn.apply(&migration).await?;
 579                    new_migrations.push((migration, elapsed));
 580                }
 581            }
 582        }
 583
 584        Ok(new_migrations)
 585    }
 586
 587    pub fn fuzzy_like_string(string: &str) -> String {
 588        let mut result = String::with_capacity(string.len() * 2 + 1);
 589        for c in string.chars() {
 590            if c.is_alphanumeric() {
 591                result.push('%');
 592                result.push(c);
 593            }
 594        }
 595        result.push('%');
 596        result
 597    }
 598
 599    // users
 600
 601    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 602        test_support!(self, {
 603            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 604            Ok(sqlx::query_as(query)
 605                .bind(limit as i32)
 606                .bind((page * limit) as i32)
 607                .fetch_all(&self.pool)
 608                .await?)
 609        })
 610    }
 611
 612    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 613        test_support!(self, {
 614            let query = "
 615                SELECT users.*
 616                FROM users
 617                WHERE id = $1
 618                LIMIT 1
 619            ";
 620            Ok(sqlx::query_as(query)
 621                .bind(&id)
 622                .fetch_optional(&self.pool)
 623                .await?)
 624        })
 625    }
 626
 627    pub async fn get_users_with_no_invites(
 628        &self,
 629        invited_by_another_user: bool,
 630    ) -> Result<Vec<User>> {
 631        test_support!(self, {
 632            let query = format!(
 633                "
 634                SELECT users.*
 635                FROM users
 636                WHERE invite_count = 0
 637                AND inviter_id IS{} NULL
 638                ",
 639                if invited_by_another_user { " NOT" } else { "" }
 640            );
 641
 642            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 643        })
 644    }
 645
 646    pub async fn get_user_by_github_account(
 647        &self,
 648        github_login: &str,
 649        github_user_id: Option<i32>,
 650    ) -> Result<Option<User>> {
 651        test_support!(self, {
 652            if let Some(github_user_id) = github_user_id {
 653                let mut user = sqlx::query_as::<_, User>(
 654                    "
 655                    UPDATE users
 656                    SET github_login = $1
 657                    WHERE github_user_id = $2
 658                    RETURNING *
 659                    ",
 660                )
 661                .bind(github_login)
 662                .bind(github_user_id)
 663                .fetch_optional(&self.pool)
 664                .await?;
 665
 666                if user.is_none() {
 667                    user = sqlx::query_as::<_, User>(
 668                        "
 669                        UPDATE users
 670                        SET github_user_id = $1
 671                        WHERE github_login = $2
 672                        RETURNING *
 673                        ",
 674                    )
 675                    .bind(github_user_id)
 676                    .bind(github_login)
 677                    .fetch_optional(&self.pool)
 678                    .await?;
 679                }
 680
 681                Ok(user)
 682            } else {
 683                let user = sqlx::query_as(
 684                    "
 685                    SELECT * FROM users
 686                    WHERE github_login = $1
 687                    LIMIT 1
 688                    ",
 689                )
 690                .bind(github_login)
 691                .fetch_optional(&self.pool)
 692                .await?;
 693                Ok(user)
 694            }
 695        })
 696    }
 697
 698    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 699        test_support!(self, {
 700            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 701            Ok(sqlx::query(query)
 702                .bind(is_admin)
 703                .bind(id.0)
 704                .execute(&self.pool)
 705                .await
 706                .map(drop)?)
 707        })
 708    }
 709
 710    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 711        test_support!(self, {
 712            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 713            Ok(sqlx::query(query)
 714                .bind(connected_once)
 715                .bind(id.0)
 716                .execute(&self.pool)
 717                .await
 718                .map(drop)?)
 719        })
 720    }
 721
 722    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 723        test_support!(self, {
 724            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 725            sqlx::query(query)
 726                .bind(id.0)
 727                .execute(&self.pool)
 728                .await
 729                .map(drop)?;
 730            let query = "DELETE FROM users WHERE id = $1;";
 731            Ok(sqlx::query(query)
 732                .bind(id.0)
 733                .execute(&self.pool)
 734                .await
 735                .map(drop)?)
 736        })
 737    }
 738
 739    // signups
 740
 741    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 742        test_support!(self, {
 743            Ok(sqlx::query_as(
 744                "
 745                SELECT
 746                    COUNT(*) as count,
 747                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 748                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 749                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 750                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 751                FROM (
 752                    SELECT *
 753                    FROM signups
 754                    WHERE
 755                        NOT email_confirmation_sent
 756                ) AS unsent
 757                ",
 758            )
 759            .fetch_one(&self.pool)
 760            .await?)
 761        })
 762    }
 763
 764    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 765        test_support!(self, {
 766            Ok(sqlx::query_as(
 767                "
 768                SELECT
 769                    email_address, email_confirmation_code
 770                FROM signups
 771                WHERE
 772                    NOT email_confirmation_sent AND
 773                    (platform_mac OR platform_unknown)
 774                LIMIT $1
 775                ",
 776            )
 777            .bind(count as i32)
 778            .fetch_all(&self.pool)
 779            .await?)
 780        })
 781    }
 782
 783    // invite codes
 784
 785    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 786        test_support!(self, {
 787            let mut tx = self.pool.begin().await?;
 788            if count > 0 {
 789                sqlx::query(
 790                    "
 791                    UPDATE users
 792                    SET invite_code = $1
 793                    WHERE id = $2 AND invite_code IS NULL
 794                ",
 795                )
 796                .bind(random_invite_code())
 797                .bind(id)
 798                .execute(&mut tx)
 799                .await?;
 800            }
 801
 802            sqlx::query(
 803                "
 804                UPDATE users
 805                SET invite_count = $1
 806                WHERE id = $2
 807                ",
 808            )
 809            .bind(count as i32)
 810            .bind(id)
 811            .execute(&mut tx)
 812            .await?;
 813            tx.commit().await?;
 814            Ok(())
 815        })
 816    }
 817
 818    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 819        test_support!(self, {
 820            let result: Option<(String, i32)> = sqlx::query_as(
 821                "
 822                    SELECT invite_code, invite_count
 823                    FROM users
 824                    WHERE id = $1 AND invite_code IS NOT NULL 
 825                ",
 826            )
 827            .bind(id)
 828            .fetch_optional(&self.pool)
 829            .await?;
 830            if let Some((code, count)) = result {
 831                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 832            } else {
 833                Ok(None)
 834            }
 835        })
 836    }
 837
 838    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 839        test_support!(self, {
 840            sqlx::query_as(
 841                "
 842                    SELECT *
 843                    FROM users
 844                    WHERE invite_code = $1
 845                ",
 846            )
 847            .bind(code)
 848            .fetch_optional(&self.pool)
 849            .await?
 850            .ok_or_else(|| {
 851                Error::Http(
 852                    StatusCode::NOT_FOUND,
 853                    "that invite code does not exist".to_string(),
 854                )
 855            })
 856        })
 857    }
 858
 859    // projects
 860
 861    /// Registers a new project for the given user.
 862    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 863        test_support!(self, {
 864            Ok(sqlx::query_scalar(
 865                "
 866                INSERT INTO projects(host_user_id)
 867                VALUES ($1)
 868                RETURNING id
 869                ",
 870            )
 871            .bind(host_user_id)
 872            .fetch_one(&self.pool)
 873            .await
 874            .map(ProjectId)?)
 875        })
 876    }
 877
 878    /// Unregisters a project for the given project id.
 879    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 880        test_support!(self, {
 881            sqlx::query(
 882                "
 883                UPDATE projects
 884                SET unregistered = TRUE
 885                WHERE id = $1
 886                ",
 887            )
 888            .bind(project_id)
 889            .execute(&self.pool)
 890            .await?;
 891            Ok(())
 892        })
 893    }
 894
 895    // contacts
 896
 897    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 898        test_support!(self, {
 899            let query = "
 900                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 901                FROM contacts
 902                WHERE user_id_a = $1 OR user_id_b = $1;
 903            ";
 904
 905            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 906                .bind(user_id)
 907                .fetch(&self.pool);
 908
 909            let mut contacts = Vec::new();
 910            while let Some(row) = rows.next().await {
 911                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 912
 913                if user_id_a == user_id {
 914                    if accepted {
 915                        contacts.push(Contact::Accepted {
 916                            user_id: user_id_b,
 917                            should_notify: should_notify && a_to_b,
 918                        });
 919                    } else if a_to_b {
 920                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 921                    } else {
 922                        contacts.push(Contact::Incoming {
 923                            user_id: user_id_b,
 924                            should_notify,
 925                        });
 926                    }
 927                } else if accepted {
 928                    contacts.push(Contact::Accepted {
 929                        user_id: user_id_a,
 930                        should_notify: should_notify && !a_to_b,
 931                    });
 932                } else if a_to_b {
 933                    contacts.push(Contact::Incoming {
 934                        user_id: user_id_a,
 935                        should_notify,
 936                    });
 937                } else {
 938                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 939                }
 940            }
 941
 942            contacts.sort_unstable_by_key(|contact| contact.user_id());
 943
 944            Ok(contacts)
 945        })
 946    }
 947
 948    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 949        test_support!(self, {
 950            let (id_a, id_b) = if user_id_1 < user_id_2 {
 951                (user_id_1, user_id_2)
 952            } else {
 953                (user_id_2, user_id_1)
 954            };
 955
 956            let query = "
 957                SELECT 1 FROM contacts
 958                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 959                LIMIT 1
 960            ";
 961            Ok(sqlx::query_scalar::<_, i32>(query)
 962                .bind(id_a.0)
 963                .bind(id_b.0)
 964                .fetch_optional(&self.pool)
 965                .await?
 966                .is_some())
 967        })
 968    }
 969
 970    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 971        test_support!(self, {
 972            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 973                (sender_id, receiver_id, true)
 974            } else {
 975                (receiver_id, sender_id, false)
 976            };
 977            let query = "
 978                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 979                VALUES ($1, $2, $3, FALSE, TRUE)
 980                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 981                SET
 982                    accepted = TRUE,
 983                    should_notify = FALSE
 984                WHERE
 985                    NOT contacts.accepted AND
 986                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 987                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 988            ";
 989            let result = sqlx::query(query)
 990                .bind(id_a.0)
 991                .bind(id_b.0)
 992                .bind(a_to_b)
 993                .execute(&self.pool)
 994                .await?;
 995
 996            if result.rows_affected() == 1 {
 997                Ok(())
 998            } else {
 999                Err(anyhow!("contact already requested"))?
1000            }
1001        })
1002    }
1003
1004    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1005        test_support!(self, {
1006            let (id_a, id_b) = if responder_id < requester_id {
1007                (responder_id, requester_id)
1008            } else {
1009                (requester_id, responder_id)
1010            };
1011            let query = "
1012                DELETE FROM contacts
1013                WHERE user_id_a = $1 AND user_id_b = $2;
1014            ";
1015            let result = sqlx::query(query)
1016                .bind(id_a.0)
1017                .bind(id_b.0)
1018                .execute(&self.pool)
1019                .await?;
1020
1021            if result.rows_affected() == 1 {
1022                Ok(())
1023            } else {
1024                Err(anyhow!("no such contact"))?
1025            }
1026        })
1027    }
1028
1029    pub async fn dismiss_contact_notification(
1030        &self,
1031        user_id: UserId,
1032        contact_user_id: UserId,
1033    ) -> Result<()> {
1034        test_support!(self, {
1035            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1036                (user_id, contact_user_id, true)
1037            } else {
1038                (contact_user_id, user_id, false)
1039            };
1040
1041            let query = "
1042                UPDATE contacts
1043                SET should_notify = FALSE
1044                WHERE
1045                    user_id_a = $1 AND user_id_b = $2 AND
1046                    (
1047                        (a_to_b = $3 AND accepted) OR
1048                        (a_to_b != $3 AND NOT accepted)
1049                    );
1050            ";
1051
1052            let result = sqlx::query(query)
1053                .bind(id_a.0)
1054                .bind(id_b.0)
1055                .bind(a_to_b)
1056                .execute(&self.pool)
1057                .await?;
1058
1059            if result.rows_affected() == 0 {
1060                Err(anyhow!("no such contact request"))?;
1061            }
1062
1063            Ok(())
1064        })
1065    }
1066
1067    pub async fn respond_to_contact_request(
1068        &self,
1069        responder_id: UserId,
1070        requester_id: UserId,
1071        accept: bool,
1072    ) -> Result<()> {
1073        test_support!(self, {
1074            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1075                (responder_id, requester_id, false)
1076            } else {
1077                (requester_id, responder_id, true)
1078            };
1079            let result = if accept {
1080                let query = "
1081                    UPDATE contacts
1082                    SET accepted = TRUE, should_notify = TRUE
1083                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1084                ";
1085                sqlx::query(query)
1086                    .bind(id_a.0)
1087                    .bind(id_b.0)
1088                    .bind(a_to_b)
1089                    .execute(&self.pool)
1090                    .await?
1091            } else {
1092                let query = "
1093                    DELETE FROM contacts
1094                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1095                ";
1096                sqlx::query(query)
1097                    .bind(id_a.0)
1098                    .bind(id_b.0)
1099                    .bind(a_to_b)
1100                    .execute(&self.pool)
1101                    .await?
1102            };
1103            if result.rows_affected() == 1 {
1104                Ok(())
1105            } else {
1106                Err(anyhow!("no such contact request"))?
1107            }
1108        })
1109    }
1110
1111    // access tokens
1112
1113    pub async fn create_access_token_hash(
1114        &self,
1115        user_id: UserId,
1116        access_token_hash: &str,
1117        max_access_token_count: usize,
1118    ) -> Result<()> {
1119        test_support!(self, {
1120            let insert_query = "
1121                INSERT INTO access_tokens (user_id, hash)
1122                VALUES ($1, $2);
1123            ";
1124            let cleanup_query = "
1125                DELETE FROM access_tokens
1126                WHERE id IN (
1127                    SELECT id from access_tokens
1128                    WHERE user_id = $1
1129                    ORDER BY id DESC
1130                    LIMIT 10000
1131                    OFFSET $3
1132                )
1133            ";
1134
1135            let mut tx = self.pool.begin().await?;
1136            sqlx::query(insert_query)
1137                .bind(user_id.0)
1138                .bind(access_token_hash)
1139                .execute(&mut tx)
1140                .await?;
1141            sqlx::query(cleanup_query)
1142                .bind(user_id.0)
1143                .bind(access_token_hash)
1144                .bind(max_access_token_count as i32)
1145                .execute(&mut tx)
1146                .await?;
1147            Ok(tx.commit().await?)
1148        })
1149    }
1150
1151    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1152        test_support!(self, {
1153            let query = "
1154                SELECT hash
1155                FROM access_tokens
1156                WHERE user_id = $1
1157                ORDER BY id DESC
1158            ";
1159            Ok(sqlx::query_scalar(query)
1160                .bind(user_id.0)
1161                .fetch_all(&self.pool)
1162                .await?)
1163        })
1164    }
1165}
1166
/// Declares a public newtype `$name` wrapping an `i32`, for use as a
/// strongly-typed database id.
///
/// The generated type derives the common value traits (`Clone`, `Copy`,
/// `Debug`, `Default`, comparison, ordering, `Hash`) together with
/// `sqlx::Type`, `Serialize` and `Deserialize`. The `transparent` attributes
/// make it encode and (de)serialize exactly like the wrapped integer. It also
/// receives a `MAX` constant, `from_proto` / `to_proto` conversions to and
/// from `u64`, and a `Display` impl that prints the raw integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64` protocol value.
            // NOTE(review): `as i32` keeps only the low 32 bits, so values
            // above `i32::MAX` wrap silently instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to a `u64` protocol value.
            // NOTE(review): a negative id sign-extends to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Displays the id as its bare integer value.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1209
// Generates the `UserId` newtype (an `i32` wrapper produced by `id_type!`).
id_type!(UserId);
/// A user record that can be loaded from a database row (`FromRow`) and
/// serialized (`Serialize`).
///
/// `FromRow` maps columns to fields by name, so the field names here are
/// expected to match the column names of the queries that produce a `User`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Primary identifier of the user.
    pub id: UserId,
    pub github_login: String,
    /// Optional: not every row carries a GitHub user id.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1222
// Generates the `ProjectId` newtype (an `i32` wrapper produced by `id_type!`).
id_type!(ProjectId);
/// A project record that can be loaded from a database row (`FromRow`).
///
/// The fields match the `projects` columns used by `register_project`
/// (`host_user_id`, returned `id`) and `unregister_project` (`unregistered`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user the project was registered for.
    pub host_user_id: UserId,
    /// Set to `TRUE` by `unregister_project`; the row itself is kept.
    pub unregistered: bool,
}
1230
/// One contact relationship as seen from a particular user's point of view.
///
/// In every variant `user_id` is the id carried for that relationship; as
/// built by `get_contacts`, it is the *other* user in the pair.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request has been accepted.
    Accepted {
        user_id: UserId,
        /// Whether the viewing user should still be notified about the
        /// acceptance.
        should_notify: bool,
    },
    /// A pending request sent by the viewing user.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request sent to the viewing user.
    Incoming {
        user_id: UserId,
        /// Whether the viewing user should still be notified about the
        /// request.
        should_notify: bool,
    },
}
1245
1246impl Contact {
1247    pub fn user_id(&self) -> UserId {
1248        match self {
1249            Contact::Accepted { user_id, .. } => *user_id,
1250            Contact::Outgoing { user_id } => *user_id,
1251            Contact::Incoming { user_id, .. } => *user_id,
1252        }
1253    }
1254}
1255
/// A contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first and `should_notify`
/// second, following field declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    pub should_notify: bool,
}
1261
/// Signup data, deserializable from an external representation
/// (`Deserialize`).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // One flag per platform; several may be set at once.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional device identifier supplied with the signup.
    pub device_id: Option<String>,
}
1272
/// Aggregate waitlist counts, loadable from a single database row.
///
/// Every field is marked `#[sqlx(default)]`, so a column that is absent from
/// the row falls back to the field type's default (`0`) instead of failing
/// the conversion.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1286
/// An invite: an email address paired with its email confirmation code.
///
/// Loadable from a database row (`FromRow`) and (de)serializable via serde.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1292
/// Parameters describing a user to be created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1299
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the newly created user.
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user who invited the new user, when there was one.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded at signup, when there was one.
    pub signup_device_id: Option<String>,
}
1307
1308fn random_invite_code() -> String {
1309    nanoid::nanoid!(16)
1310}
1311
1312fn random_email_confirmation_code() -> String {
1313    nanoid::nanoid!(64)
1314}
1315
1316#[cfg(test)]
1317pub use test::*;
1318
/// Test-only fixtures that stand up a throwaway database (SQLite or Postgres),
/// run the migrations against it, and hand back a ready-to-use `Db`.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A SQLite-backed test database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection acquired from the pool in `new` and detached from it.
        // NOTE(review): presumably held so the in-memory database stays alive
        // for the lifetime of the fixture — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A Postgres-backed test database.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the database created in `new`; passed to
        /// `teardown` when the fixture is dropped.
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates an in-memory SQLite database with a random name, applies
        /// the SQLite migrations, and wires `background` plus a dedicated
        /// single-threaded tokio runtime into the resulting `Db`.
        ///
        /// Panics if the runtime, the connection, or the migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // Random suffix so each fixture gets its own in-memory database.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                // NOTE(review): the meaning of the `false` flag is defined by
                // `Db::migrate`, which is not visible here.
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // These fields are what `test_support!` reads when running
            // queries under test.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the shared `Db` handle. Panics if `db` is `None`.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a randomly named Postgres database on `localhost`, applies
        /// the migrations, and wires `background` plus a dedicated
        /// single-threaded tokio runtime into the resulting `Db`.
        ///
        /// Panics if the database cannot be created, connected to, or
        /// migrated.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so only one fixture is being
            // set up at a time within the process.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                // NOTE(review): the meaning of the `false` flag is defined by
                // `Db::migrate`, which is not visible here.
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            // These fields are what `test_support!` reads when running
            // queries under test.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the shared `Db` handle. Panics if `db` is `None`.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the `Db` out of the fixture and calls `teardown` with the
        /// database URL. Panics if `db` is already `None`.
        // NOTE(review): `teardown` is defined elsewhere; presumably it drops
        // the database created in `new` — confirm.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}