db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
  15#[cfg(test)]
  16pub type DefaultDb = Db<sqlx::Sqlite>;
  17
  18#[cfg(not(test))]
  19pub type DefaultDb = Db<sqlx::Postgres>;
  20
  21pub struct Db<D: sqlx::Database> {
  22    pool: sqlx::Pool<D>,
  23    #[cfg(test)]
  24    background: Option<std::sync::Arc<gpui::executor::Background>>,
  25    #[cfg(test)]
  26    runtime: Option<tokio::runtime::Runtime>,
  27}
  28
  29macro_rules! test_support {
  30    ($self:ident, { $($token:tt)* }) => {{
  31        let body = async {
  32            $($token)*
  33        };
  34
  35        if cfg!(test) {
  36            #[cfg(not(test))]
  37            unreachable!();
  38
  39            #[cfg(test)]
  40            if let Some(background) = $self.background.as_ref() {
  41                background.simulate_random_delay().await;
  42            }
  43
  44            #[cfg(test)]
  45            $self.runtime.as_ref().unwrap().block_on(body)
  46        } else {
  47            body.await
  48        }
  49    }};
  50}
  51
  52pub trait RowsAffected {
  53    fn rows_affected(&self) -> u64;
  54}
  55
  56#[cfg(test)]
  57impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  58    fn rows_affected(&self) -> u64 {
  59        self.rows_affected()
  60    }
  61}
  62
  63impl RowsAffected for sqlx::postgres::PgQueryResult {
  64    fn rows_affected(&self) -> u64 {
  65        self.rows_affected()
  66    }
  67}
  68
  69#[cfg(test)]
  70impl Db<sqlx::Sqlite> {
  71    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  72        use std::str::FromStr as _;
  73        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  74            .unwrap()
  75            .create_if_missing(true)
  76            .shared_cache(true);
  77        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  78            .min_connections(2)
  79            .max_connections(max_connections)
  80            .connect_with(options)
  81            .await?;
  82        Ok(Self {
  83            pool,
  84            background: None,
  85            runtime: None,
  86        })
  87    }
  88
  89    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT users.*
  93                FROM users
  94                WHERE users.id IN (SELECT value from json_each($1))
  95            ";
  96            Ok(sqlx::query_as(query)
  97                .bind(&serde_json::json!(ids))
  98                .fetch_all(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 104        test_support!(self, {
 105            let query = "
 106                SELECT metrics_id
 107                FROM users
 108                WHERE id = $1
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(id)
 112                .fetch_one(&self.pool)
 113                .await?)
 114        })
 115    }
 116
 117    pub async fn create_user(
 118        &self,
 119        email_address: &str,
 120        admin: bool,
 121        params: NewUserParams,
 122    ) -> Result<NewUserResult> {
 123        test_support!(self, {
 124            let query = "
 125                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 126                VALUES ($1, $2, $3, $4, $5)
 127                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 128                RETURNING id, metrics_id
 129            ";
 130
 131            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 132                .bind(email_address)
 133                .bind(params.github_login)
 134                .bind(params.github_user_id)
 135                .bind(admin)
 136                .bind(Uuid::new_v4().to_string())
 137                .fetch_one(&self.pool)
 138                .await?;
 139            Ok(NewUserResult {
 140                user_id,
 141                metrics_id,
 142                signup_device_id: None,
 143                inviting_user_id: None,
 144            })
 145        })
 146    }
 147
 148    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 149        unimplemented!()
 150    }
 151
 152    pub async fn create_user_from_invite(
 153        &self,
 154        _invite: &Invite,
 155        _user: NewUserParams,
 156    ) -> Result<Option<NewUserResult>> {
 157        unimplemented!()
 158    }
 159
 160    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
 161        unimplemented!()
 162    }
 163
 164    pub async fn create_invite_from_code(
 165        &self,
 166        _code: &str,
 167        _email_address: &str,
 168        _device_id: Option<&str>,
 169    ) -> Result<Invite> {
 170        unimplemented!()
 171    }
 172
 173    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 174        unimplemented!()
 175    }
 176}
 177
 178impl Db<sqlx::Postgres> {
 179    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 180        let pool = sqlx::postgres::PgPoolOptions::new()
 181            .max_connections(max_connections)
 182            .connect(url)
 183            .await?;
 184        Ok(Self {
 185            pool,
 186            #[cfg(test)]
 187            background: None,
 188            #[cfg(test)]
 189            runtime: None,
 190        })
 191    }
 192
 193    #[cfg(test)]
 194    pub fn teardown(&self, url: &str) {
 195        self.runtime.as_ref().unwrap().block_on(async {
 196            use util::ResultExt;
 197            let query = "
 198                SELECT pg_terminate_backend(pg_stat_activity.pid)
 199                FROM pg_stat_activity
 200                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 201            ";
 202            sqlx::query(query).execute(&self.pool).await.log_err();
 203            self.pool.close().await;
 204            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 205                .await
 206                .log_err();
 207        })
 208    }
 209
 210    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 211        test_support!(self, {
 212            let like_string = Self::fuzzy_like_string(name_query);
 213            let query = "
 214                SELECT users.*
 215                FROM users
 216                WHERE github_login ILIKE $1
 217                ORDER BY github_login <-> $2
 218                LIMIT $3
 219            ";
 220            Ok(sqlx::query_as(query)
 221                .bind(like_string)
 222                .bind(name_query)
 223                .bind(limit as i32)
 224                .fetch_all(&self.pool)
 225                .await?)
 226        })
 227    }
 228
 229    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 230        test_support!(self, {
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE users.id = ANY ($1)
 235            ";
 236            Ok(sqlx::query_as(query)
 237                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 238                .fetch_all(&self.pool)
 239                .await?)
 240        })
 241    }
 242
 243    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT metrics_id::text
 247                FROM users
 248                WHERE id = $1
 249            ";
 250            Ok(sqlx::query_scalar(query)
 251                .bind(id)
 252                .fetch_one(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    pub async fn create_user(
 258        &self,
 259        email_address: &str,
 260        admin: bool,
 261        params: NewUserParams,
 262    ) -> Result<NewUserResult> {
 263        test_support!(self, {
 264            let query = "
 265                INSERT INTO users (email_address, github_login, github_user_id, admin)
 266                VALUES ($1, $2, $3, $4)
 267                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 268                RETURNING id, metrics_id::text
 269            ";
 270
 271            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272                .bind(email_address)
 273                .bind(params.github_login)
 274                .bind(params.github_user_id)
 275                .bind(admin)
 276                .fetch_one(&self.pool)
 277                .await?;
 278            Ok(NewUserResult {
 279                user_id,
 280                metrics_id,
 281                signup_device_id: None,
 282                inviting_user_id: None,
 283            })
 284        })
 285    }
 286
 287    pub async fn create_user_from_invite(
 288        &self,
 289        invite: &Invite,
 290        user: NewUserParams,
 291    ) -> Result<Option<NewUserResult>> {
 292        test_support!(self, {
 293            let mut tx = self.pool.begin().await?;
 294
 295            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 296                i32,
 297                Option<UserId>,
 298                Option<UserId>,
 299                Option<String>,
 300            ) = sqlx::query_as(
 301                "
 302                SELECT id, user_id, inviting_user_id, device_id
 303                FROM signups
 304                WHERE
 305                    email_address = $1 AND
 306                    email_confirmation_code = $2
 307                ",
 308            )
 309            .bind(&invite.email_address)
 310            .bind(&invite.email_confirmation_code)
 311            .fetch_optional(&mut tx)
 312            .await?
 313            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 314
 315            if existing_user_id.is_some() {
 316                return Ok(None);
 317            }
 318
 319            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 320                "
 321                INSERT INTO users
 322                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 323                VALUES
 324                ($1, $2, $3, FALSE, $4, $5)
 325                ON CONFLICT (github_login) DO UPDATE SET
 326                    email_address = excluded.email_address,
 327                    github_user_id = excluded.github_user_id,
 328                    admin = excluded.admin
 329                RETURNING id, metrics_id::text
 330                ",
 331            )
 332            .bind(&invite.email_address)
 333            .bind(&user.github_login)
 334            .bind(&user.github_user_id)
 335            .bind(&user.invite_count)
 336            .bind(random_invite_code())
 337            .fetch_one(&mut tx)
 338            .await?;
 339
 340            sqlx::query(
 341                "
 342                UPDATE signups
 343                SET user_id = $1
 344                WHERE id = $2
 345                ",
 346            )
 347            .bind(&user_id)
 348            .bind(&signup_id)
 349            .execute(&mut tx)
 350            .await?;
 351
 352            if let Some(inviting_user_id) = inviting_user_id {
 353                sqlx::query(
 354                    "
 355                    INSERT INTO contacts
 356                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 357                    VALUES
 358                        ($1, $2, TRUE, TRUE, TRUE)
 359                    ON CONFLICT DO NOTHING
 360                    ",
 361                )
 362                .bind(inviting_user_id)
 363                .bind(user_id)
 364                .execute(&mut tx)
 365                .await?;
 366            }
 367
 368            tx.commit().await?;
 369            Ok(Some(NewUserResult {
 370                user_id,
 371                metrics_id,
 372                inviting_user_id,
 373                signup_device_id,
 374            }))
 375        })
 376    }
 377
 378    pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
 379        test_support!(self, {
 380            sqlx::query(
 381                "
 382                INSERT INTO signups
 383                (
 384                    email_address,
 385                    email_confirmation_code,
 386                    email_confirmation_sent,
 387                    platform_linux,
 388                    platform_mac,
 389                    platform_windows,
 390                    platform_unknown,
 391                    editor_features,
 392                    programming_languages,
 393                    device_id
 394                )
 395                VALUES
 396                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 397                ON CONFLICT (email_address) DO UPDATE SET
 398                    email_address = excluded.email_address
 399                RETURNING id
 400                ",
 401            )
 402            .bind(&signup.email_address)
 403            .bind(&random_email_confirmation_code())
 404            .bind(&signup.platform_linux)
 405            .bind(&signup.platform_mac)
 406            .bind(&signup.platform_windows)
 407            .bind(&signup.editor_features)
 408            .bind(&signup.programming_languages)
 409            .bind(&signup.device_id)
 410            .execute(&self.pool)
 411            .await?;
 412            Ok(())
 413        })
 414    }
 415
 416    pub async fn create_invite_from_code(
 417        &self,
 418        code: &str,
 419        email_address: &str,
 420        device_id: Option<&str>,
 421    ) -> Result<Invite> {
 422        test_support!(self, {
 423            let mut tx = self.pool.begin().await?;
 424
 425            let existing_user: Option<UserId> = sqlx::query_scalar(
 426                "
 427                SELECT id
 428                FROM users
 429                WHERE email_address = $1
 430                ",
 431            )
 432            .bind(email_address)
 433            .fetch_optional(&mut tx)
 434            .await?;
 435            if existing_user.is_some() {
 436                Err(anyhow!("email address is already in use"))?;
 437            }
 438
 439            let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
 440                "
 441                UPDATE users
 442                SET invite_count = invite_count - 1
 443                WHERE invite_code = $1 AND invite_count > 0
 444                RETURNING id
 445                ",
 446            )
 447            .bind(code)
 448            .fetch_optional(&mut tx)
 449            .await?;
 450
 451            let Some(inviter_id) = inviting_user_id_with_invites else {
 452                return Err(Error::Http(
 453                    StatusCode::UNAUTHORIZED,
 454                    "unable to find an invite code with invites remaining".to_string(),
 455                ));
 456            };
 457
 458            let email_confirmation_code: String = sqlx::query_scalar(
 459                "
 460                INSERT INTO signups
 461                (
 462                    email_address,
 463                    email_confirmation_code,
 464                    email_confirmation_sent,
 465                    inviting_user_id,
 466                    platform_linux,
 467                    platform_mac,
 468                    platform_windows,
 469                    platform_unknown,
 470                    device_id
 471                )
 472                VALUES
 473                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 474                ON CONFLICT (email_address)
 475                DO UPDATE SET
 476                    inviting_user_id = excluded.inviting_user_id
 477                RETURNING email_confirmation_code
 478                ",
 479            )
 480            .bind(&email_address)
 481            .bind(&random_email_confirmation_code())
 482            .bind(&inviter_id)
 483            .bind(&device_id)
 484            .fetch_one(&mut tx)
 485            .await?;
 486
 487            tx.commit().await?;
 488
 489            Ok(Invite {
 490                email_address: email_address.into(),
 491                email_confirmation_code,
 492            })
 493        })
 494    }
 495
 496    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 497        test_support!(self, {
 498            let emails = invites
 499                .iter()
 500                .map(|s| s.email_address.as_str())
 501                .collect::<Vec<_>>();
 502            sqlx::query(
 503                "
 504                UPDATE signups
 505                SET email_confirmation_sent = TRUE
 506                WHERE email_address = ANY ($1)
 507                ",
 508            )
 509            .bind(&emails)
 510            .execute(&self.pool)
 511            .await?;
 512            Ok(())
 513        })
 514    }
 515}
 516
 517impl<D> Db<D>
 518where
 519    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 520    D::Connection: sqlx::migrate::Migrate,
 521    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 522    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 523    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 524    D::QueryResult: RowsAffected,
 525    String: sqlx::Type<D>,
 526    i32: sqlx::Type<D>,
 527    i64: sqlx::Type<D>,
 528    bool: sqlx::Type<D>,
 529    str: sqlx::Type<D>,
 530    Uuid: sqlx::Type<D>,
 531    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 532    OffsetDateTime: sqlx::Type<D>,
 533    PrimitiveDateTime: sqlx::Type<D>,
 534    usize: sqlx::ColumnIndex<D::Row>,
 535    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 536    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 537    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 538    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 539    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 540    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 541    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 542    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 543    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 544    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 545    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 546    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 547{
 548    pub async fn migrate(
 549        &self,
 550        migrations_path: &Path,
 551        ignore_checksum_mismatch: bool,
 552    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 553        let migrations = MigrationSource::resolve(migrations_path)
 554            .await
 555            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 556
 557        let mut conn = self.pool.acquire().await?;
 558
 559        conn.ensure_migrations_table().await?;
 560        let applied_migrations: HashMap<_, _> = conn
 561            .list_applied_migrations()
 562            .await?
 563            .into_iter()
 564            .map(|m| (m.version, m))
 565            .collect();
 566
 567        let mut new_migrations = Vec::new();
 568        for migration in migrations {
 569            match applied_migrations.get(&migration.version) {
 570                Some(applied_migration) => {
 571                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 572                    {
 573                        Err(anyhow!(
 574                            "checksum mismatch for applied migration {}",
 575                            migration.description
 576                        ))?;
 577                    }
 578                }
 579                None => {
 580                    let elapsed = conn.apply(&migration).await?;
 581                    new_migrations.push((migration, elapsed));
 582                }
 583            }
 584        }
 585
 586        Ok(new_migrations)
 587    }
 588
 589    pub fn fuzzy_like_string(string: &str) -> String {
 590        let mut result = String::with_capacity(string.len() * 2 + 1);
 591        for c in string.chars() {
 592            if c.is_alphanumeric() {
 593                result.push('%');
 594                result.push(c);
 595            }
 596        }
 597        result.push('%');
 598        result
 599    }
 600
 601    // users
 602
 603    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 604        test_support!(self, {
 605            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 606            Ok(sqlx::query_as(query)
 607                .bind(limit as i32)
 608                .bind((page * limit) as i32)
 609                .fetch_all(&self.pool)
 610                .await?)
 611        })
 612    }
 613
 614    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 615        test_support!(self, {
 616            let query = "
 617                SELECT users.*
 618                FROM users
 619                WHERE id = $1
 620                LIMIT 1
 621            ";
 622            Ok(sqlx::query_as(query)
 623                .bind(&id)
 624                .fetch_optional(&self.pool)
 625                .await?)
 626        })
 627    }
 628
 629    pub async fn get_users_with_no_invites(
 630        &self,
 631        invited_by_another_user: bool,
 632    ) -> Result<Vec<User>> {
 633        test_support!(self, {
 634            let query = format!(
 635                "
 636                SELECT users.*
 637                FROM users
 638                WHERE invite_count = 0
 639                AND inviter_id IS{} NULL
 640                ",
 641                if invited_by_another_user { " NOT" } else { "" }
 642            );
 643
 644            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 645        })
 646    }
 647
 648    pub async fn get_user_by_github_account(
 649        &self,
 650        github_login: &str,
 651        github_user_id: Option<i32>,
 652    ) -> Result<Option<User>> {
 653        test_support!(self, {
 654            if let Some(github_user_id) = github_user_id {
 655                let mut user = sqlx::query_as::<_, User>(
 656                    "
 657                    UPDATE users
 658                    SET github_login = $1
 659                    WHERE github_user_id = $2
 660                    RETURNING *
 661                    ",
 662                )
 663                .bind(github_login)
 664                .bind(github_user_id)
 665                .fetch_optional(&self.pool)
 666                .await?;
 667
 668                if user.is_none() {
 669                    user = sqlx::query_as::<_, User>(
 670                        "
 671                        UPDATE users
 672                        SET github_user_id = $1
 673                        WHERE github_login = $2
 674                        RETURNING *
 675                        ",
 676                    )
 677                    .bind(github_user_id)
 678                    .bind(github_login)
 679                    .fetch_optional(&self.pool)
 680                    .await?;
 681                }
 682
 683                Ok(user)
 684            } else {
 685                let user = sqlx::query_as(
 686                    "
 687                    SELECT * FROM users
 688                    WHERE github_login = $1
 689                    LIMIT 1
 690                    ",
 691                )
 692                .bind(github_login)
 693                .fetch_optional(&self.pool)
 694                .await?;
 695                Ok(user)
 696            }
 697        })
 698    }
 699
 700    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 701        test_support!(self, {
 702            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 703            Ok(sqlx::query(query)
 704                .bind(is_admin)
 705                .bind(id.0)
 706                .execute(&self.pool)
 707                .await
 708                .map(drop)?)
 709        })
 710    }
 711
 712    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 713        test_support!(self, {
 714            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 715            Ok(sqlx::query(query)
 716                .bind(connected_once)
 717                .bind(id.0)
 718                .execute(&self.pool)
 719                .await
 720                .map(drop)?)
 721        })
 722    }
 723
 724    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 725        test_support!(self, {
 726            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 727            sqlx::query(query)
 728                .bind(id.0)
 729                .execute(&self.pool)
 730                .await
 731                .map(drop)?;
 732            let query = "DELETE FROM users WHERE id = $1;";
 733            Ok(sqlx::query(query)
 734                .bind(id.0)
 735                .execute(&self.pool)
 736                .await
 737                .map(drop)?)
 738        })
 739    }
 740
 741    // signups
 742
 743    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 744        test_support!(self, {
 745            Ok(sqlx::query_as(
 746                "
 747                SELECT
 748                    COUNT(*) as count,
 749                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 750                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 751                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 752                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 753                FROM (
 754                    SELECT *
 755                    FROM signups
 756                    WHERE
 757                        NOT email_confirmation_sent
 758                ) AS unsent
 759                ",
 760            )
 761            .fetch_one(&self.pool)
 762            .await?)
 763        })
 764    }
 765
 766    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 767        test_support!(self, {
 768            Ok(sqlx::query_as(
 769                "
 770                SELECT
 771                    email_address, email_confirmation_code
 772                FROM signups
 773                WHERE
 774                    NOT email_confirmation_sent AND
 775                    (platform_mac OR platform_unknown)
 776                LIMIT $1
 777                ",
 778            )
 779            .bind(count as i32)
 780            .fetch_all(&self.pool)
 781            .await?)
 782        })
 783    }
 784
 785    // invite codes
 786
 787    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 788        test_support!(self, {
 789            let mut tx = self.pool.begin().await?;
 790            if count > 0 {
 791                sqlx::query(
 792                    "
 793                    UPDATE users
 794                    SET invite_code = $1
 795                    WHERE id = $2 AND invite_code IS NULL
 796                ",
 797                )
 798                .bind(random_invite_code())
 799                .bind(id)
 800                .execute(&mut tx)
 801                .await?;
 802            }
 803
 804            sqlx::query(
 805                "
 806                UPDATE users
 807                SET invite_count = $1
 808                WHERE id = $2
 809                ",
 810            )
 811            .bind(count as i32)
 812            .bind(id)
 813            .execute(&mut tx)
 814            .await?;
 815            tx.commit().await?;
 816            Ok(())
 817        })
 818    }
 819
 820    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 821        test_support!(self, {
 822            let result: Option<(String, i32)> = sqlx::query_as(
 823                "
 824                    SELECT invite_code, invite_count
 825                    FROM users
 826                    WHERE id = $1 AND invite_code IS NOT NULL 
 827                ",
 828            )
 829            .bind(id)
 830            .fetch_optional(&self.pool)
 831            .await?;
 832            if let Some((code, count)) = result {
 833                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 834            } else {
 835                Ok(None)
 836            }
 837        })
 838    }
 839
 840    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 841        test_support!(self, {
 842            sqlx::query_as(
 843                "
 844                    SELECT *
 845                    FROM users
 846                    WHERE invite_code = $1
 847                ",
 848            )
 849            .bind(code)
 850            .fetch_optional(&self.pool)
 851            .await?
 852            .ok_or_else(|| {
 853                Error::Http(
 854                    StatusCode::NOT_FOUND,
 855                    "that invite code does not exist".to_string(),
 856                )
 857            })
 858        })
 859    }
 860
 861    // projects
 862
 863    /// Registers a new project for the given user.
 864    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 865        test_support!(self, {
 866            Ok(sqlx::query_scalar(
 867                "
 868                INSERT INTO projects(host_user_id)
 869                VALUES ($1)
 870                RETURNING id
 871                ",
 872            )
 873            .bind(host_user_id)
 874            .fetch_one(&self.pool)
 875            .await
 876            .map(ProjectId)?)
 877        })
 878    }
 879
 880    /// Unregisters a project for the given project id.
 881    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 882        test_support!(self, {
 883            sqlx::query(
 884                "
 885                UPDATE projects
 886                SET unregistered = TRUE
 887                WHERE id = $1
 888                ",
 889            )
 890            .bind(project_id)
 891            .execute(&self.pool)
 892            .await?;
 893            Ok(())
 894        })
 895    }
 896
 897    // contacts
 898
 899    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 900        test_support!(self, {
 901            let query = "
 902                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 903                FROM contacts
 904                WHERE user_id_a = $1 OR user_id_b = $1;
 905            ";
 906
 907            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 908                .bind(user_id)
 909                .fetch(&self.pool);
 910
 911            let mut contacts = Vec::new();
 912            while let Some(row) = rows.next().await {
 913                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 914
 915                if user_id_a == user_id {
 916                    if accepted {
 917                        contacts.push(Contact::Accepted {
 918                            user_id: user_id_b,
 919                            should_notify: should_notify && a_to_b,
 920                        });
 921                    } else if a_to_b {
 922                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 923                    } else {
 924                        contacts.push(Contact::Incoming {
 925                            user_id: user_id_b,
 926                            should_notify,
 927                        });
 928                    }
 929                } else if accepted {
 930                    contacts.push(Contact::Accepted {
 931                        user_id: user_id_a,
 932                        should_notify: should_notify && !a_to_b,
 933                    });
 934                } else if a_to_b {
 935                    contacts.push(Contact::Incoming {
 936                        user_id: user_id_a,
 937                        should_notify,
 938                    });
 939                } else {
 940                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 941                }
 942            }
 943
 944            contacts.sort_unstable_by_key(|contact| contact.user_id());
 945
 946            Ok(contacts)
 947        })
 948    }
 949
 950    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 951        test_support!(self, {
 952            let (id_a, id_b) = if user_id_1 < user_id_2 {
 953                (user_id_1, user_id_2)
 954            } else {
 955                (user_id_2, user_id_1)
 956            };
 957
 958            let query = "
 959                SELECT 1 FROM contacts
 960                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 961                LIMIT 1
 962            ";
 963            Ok(sqlx::query_scalar::<_, i32>(query)
 964                .bind(id_a.0)
 965                .bind(id_b.0)
 966                .fetch_optional(&self.pool)
 967                .await?
 968                .is_some())
 969        })
 970    }
 971
 972    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 973        test_support!(self, {
 974            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 975                (sender_id, receiver_id, true)
 976            } else {
 977                (receiver_id, sender_id, false)
 978            };
 979            let query = "
 980                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 981                VALUES ($1, $2, $3, FALSE, TRUE)
 982                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 983                SET
 984                    accepted = TRUE,
 985                    should_notify = FALSE
 986                WHERE
 987                    NOT contacts.accepted AND
 988                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 989                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 990            ";
 991            let result = sqlx::query(query)
 992                .bind(id_a.0)
 993                .bind(id_b.0)
 994                .bind(a_to_b)
 995                .execute(&self.pool)
 996                .await?;
 997
 998            if result.rows_affected() == 1 {
 999                Ok(())
1000            } else {
1001                Err(anyhow!("contact already requested"))?
1002            }
1003        })
1004    }
1005
1006    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1007        test_support!(self, {
1008            let (id_a, id_b) = if responder_id < requester_id {
1009                (responder_id, requester_id)
1010            } else {
1011                (requester_id, responder_id)
1012            };
1013            let query = "
1014                DELETE FROM contacts
1015                WHERE user_id_a = $1 AND user_id_b = $2;
1016            ";
1017            let result = sqlx::query(query)
1018                .bind(id_a.0)
1019                .bind(id_b.0)
1020                .execute(&self.pool)
1021                .await?;
1022
1023            if result.rows_affected() == 1 {
1024                Ok(())
1025            } else {
1026                Err(anyhow!("no such contact"))?
1027            }
1028        })
1029    }
1030
1031    pub async fn dismiss_contact_notification(
1032        &self,
1033        user_id: UserId,
1034        contact_user_id: UserId,
1035    ) -> Result<()> {
1036        test_support!(self, {
1037            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1038                (user_id, contact_user_id, true)
1039            } else {
1040                (contact_user_id, user_id, false)
1041            };
1042
1043            let query = "
1044                UPDATE contacts
1045                SET should_notify = FALSE
1046                WHERE
1047                    user_id_a = $1 AND user_id_b = $2 AND
1048                    (
1049                        (a_to_b = $3 AND accepted) OR
1050                        (a_to_b != $3 AND NOT accepted)
1051                    );
1052            ";
1053
1054            let result = sqlx::query(query)
1055                .bind(id_a.0)
1056                .bind(id_b.0)
1057                .bind(a_to_b)
1058                .execute(&self.pool)
1059                .await?;
1060
1061            if result.rows_affected() == 0 {
1062                Err(anyhow!("no such contact request"))?;
1063            }
1064
1065            Ok(())
1066        })
1067    }
1068
1069    pub async fn respond_to_contact_request(
1070        &self,
1071        responder_id: UserId,
1072        requester_id: UserId,
1073        accept: bool,
1074    ) -> Result<()> {
1075        test_support!(self, {
1076            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1077                (responder_id, requester_id, false)
1078            } else {
1079                (requester_id, responder_id, true)
1080            };
1081            let result = if accept {
1082                let query = "
1083                    UPDATE contacts
1084                    SET accepted = TRUE, should_notify = TRUE
1085                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1086                ";
1087                sqlx::query(query)
1088                    .bind(id_a.0)
1089                    .bind(id_b.0)
1090                    .bind(a_to_b)
1091                    .execute(&self.pool)
1092                    .await?
1093            } else {
1094                let query = "
1095                    DELETE FROM contacts
1096                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1097                ";
1098                sqlx::query(query)
1099                    .bind(id_a.0)
1100                    .bind(id_b.0)
1101                    .bind(a_to_b)
1102                    .execute(&self.pool)
1103                    .await?
1104            };
1105            if result.rows_affected() == 1 {
1106                Ok(())
1107            } else {
1108                Err(anyhow!("no such contact request"))?
1109            }
1110        })
1111    }
1112
1113    // access tokens
1114
1115    pub async fn create_access_token_hash(
1116        &self,
1117        user_id: UserId,
1118        access_token_hash: &str,
1119        max_access_token_count: usize,
1120    ) -> Result<()> {
1121        test_support!(self, {
1122            let insert_query = "
1123                INSERT INTO access_tokens (user_id, hash)
1124                VALUES ($1, $2);
1125            ";
1126            let cleanup_query = "
1127                DELETE FROM access_tokens
1128                WHERE id IN (
1129                    SELECT id from access_tokens
1130                    WHERE user_id = $1
1131                    ORDER BY id DESC
1132                    LIMIT 10000
1133                    OFFSET $3
1134                )
1135            ";
1136
1137            let mut tx = self.pool.begin().await?;
1138            sqlx::query(insert_query)
1139                .bind(user_id.0)
1140                .bind(access_token_hash)
1141                .execute(&mut tx)
1142                .await?;
1143            sqlx::query(cleanup_query)
1144                .bind(user_id.0)
1145                .bind(access_token_hash)
1146                .bind(max_access_token_count as i32)
1147                .execute(&mut tx)
1148                .await?;
1149            Ok(tx.commit().await?)
1150        })
1151    }
1152
1153    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1154        test_support!(self, {
1155            let query = "
1156                SELECT hash
1157                FROM access_tokens
1158                WHERE user_id = $1
1159                ORDER BY id DESC
1160            ";
1161            Ok(sqlx::query_scalar(query)
1162                .bind(user_id.0)
1163                .fetch_all(&self.pool)
1164                .await?)
1165        })
1166    }
1167}
1168
// Declares a public newtype `$name` wrapping an `i32`.
//
// The generated type is `Copy`, ordered, hashable, and transparent for both
// sqlx and serde, so it is bound/decoded and (de)serialized exactly like the
// inner `i32`. It also gets a `MAX` constant, conversions to and from the
// `u64` representation (`from_proto` / `to_proto`), and a `Display` impl
// that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64`. The `as` cast keeps only the low
            // 32 bits, so values above `i32::MAX` wrap instead of failing.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to a `u64`. The `as` cast sign-extends, so a
            // negative id becomes a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Formats the id exactly like its inner integer.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1211
// Generates the `UserId` newtype (an `i32` wrapper) via `id_type!`.
id_type!(UserId);
/// A user record.
///
/// Derives `FromRow`, so it can be decoded from a query row whose column
/// names match the fields below, and `Serialize`, so it can be serialized
/// with serde.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Identifier of this user.
    pub id: UserId,
    // NOTE(review): the next three fields are presumably the user's GitHub
    // login, numeric GitHub id and e-mail address; confirm against the schema.
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Optional invite code; `None` when the user has none.
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1224
// Generates the `ProjectId` newtype (an `i32` wrapper) via `id_type!`.
id_type!(ProjectId);
/// A project record.
///
/// The field names match the `projects` columns used by `register_project`
/// and `unregister_project` in this file (`id`, `host_user_id`,
/// `unregistered`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    /// Identifier of the project.
    pub id: ProjectId,
    /// User the project was registered for.
    pub host_user_id: UserId,
    /// Set once the project has been unregistered.
    pub unregistered: bool,
}
1232
1233#[derive(Clone, Debug, PartialEq, Eq)]
1234pub enum Contact {
1235    Accepted {
1236        user_id: UserId,
1237        should_notify: bool,
1238    },
1239    Outgoing {
1240        user_id: UserId,
1241    },
1242    Incoming {
1243        user_id: UserId,
1244        should_notify: bool,
1245    },
1246}
1247
1248impl Contact {
1249    pub fn user_id(&self) -> UserId {
1250        match self {
1251            Contact::Accepted { user_id, .. } => *user_id,
1252            Contact::Outgoing { user_id } => *user_id,
1253            Contact::Incoming { user_id, .. } => *user_id,
1254        }
1255    }
1256}
1257
/// A contact request received from another user.
///
/// The derived ordering compares `requester_id` first and `should_notify`
/// second, following the field declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// User who sent the request.
    pub requester_id: UserId,
    pub should_notify: bool,
}
1263
/// Signup data, deserializable with serde.
///
/// `Default` yields empty strings and lists, `false` flags and no device id.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    // NOTE(review): the three flags presumably record which platforms the
    // person signing up uses; confirm with the signup form.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional device identifier; `None` when not supplied.
    pub device_id: Option<String>,
}
1274
/// Aggregated counts, decoded from a query row via `FromRow`.
///
/// Every field is marked `#[sqlx(default)]`, so a column that is absent from
/// the row decodes to the type's default (`0`) instead of failing.
// NOTE(review): presumably `count` is the overall total and the other fields
// break it down by platform; confirm against the query that fills this struct.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1288
/// An e-mail address paired with its confirmation code.
///
/// Decodable from a query row (`FromRow`) and (de)serializable with serde.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    // NOTE(review): presumably produced by `random_email_confirmation_code`
    // in this file; confirm at the call site.
    pub email_confirmation_code: String,
}
1294
/// Parameters describing a user to be created; (de)serializable with serde.
///
/// The field names mirror the corresponding fields of [`User`]; note that
/// `github_user_id` is required here while it is optional on `User`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1301
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    /// Identifier of the user.
    pub user_id: UserId,
    pub metrics_id: String,
    /// Inviting user, when there is one.
    pub inviting_user_id: Option<UserId>,
    /// Device id associated with the signup, when there is one.
    // NOTE(review): presumably copied from `Signup::device_id`; confirm.
    pub signup_device_id: Option<String>,
}
1309
1310fn random_invite_code() -> String {
1311    nanoid::nanoid!(16)
1312}
1313
1314fn random_email_confirmation_code() -> String {
1315    nanoid::nanoid!(64)
1316}
1317
1318#[cfg(test)]
1319pub use test::*;
1320
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Test fixture around a [`Db`] backed by an in-memory SQLite database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// Connection detached from the pool while the fixture is built.
        // NOTE(review): presumably held so the in-memory database outlives
        // the pool's own connections; confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// Test fixture around a [`Db`] backed by a freshly created Postgres
    /// database; the database is torn down when the fixture is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// URL of the database created for this fixture.
        pub url: String,
    }

    /// Builds the current-thread Tokio runtime (I/O and time enabled) that a
    /// test database uses to drive its queries.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory database, applies the SQLite
        /// migrations and wires the given executor and a fresh runtime into
        /// the resulting [`Db`].
        pub fn new(background: Arc<Background>) -> Self {
            const MIGRATIONS_PATH: &str =
                concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");

            // A random suffix gives every fixture its own database name.
            let suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = current_thread_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                db.migrate(MIGRATIONS_PATH.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the wrapped database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on `localhost`, applies
        /// the migrations and wires the given executor and a fresh runtime
        /// into the resulting [`Db`].
        ///
        /// A process-wide lock is held for the whole call, so fixtures are
        /// created one at a time.
        pub fn new(background: Arc<Background>) -> Self {
            const MIGRATIONS_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");

            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Released when this function returns.
            let _guard = LOCK.lock();

            let suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("postgres://postgres@localhost/zed-test-{}", suffix);
            let runtime = current_thread_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                db.migrate(Path::new(MIGRATIONS_PATH), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the wrapped database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database handle out of the fixture and tears down the
        /// database behind `self.url`.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}