db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// The database type used by default in test builds: a [`Db`] backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// The database type used by default in non-test builds: a [`Db`] backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  20
/// A database handle, generic over the sqlx backend `D`.
pub struct Db<D: sqlx::Database> {
    /// Connection pool that all queries are executed against.
    pool: sqlx::Pool<D>,
    /// Test-only: when set, `test_support!` awaits `simulate_random_delay()` on it
    /// before running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: Tokio runtime on which `test_support!` (and `teardown`) block
    /// while running database futures. Constructors in this file leave it `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
/// Wraps the body of a `Db` method so it can be driven differently in tests.
///
/// The given tokens are placed in an `async` block. In non-test builds that
/// block is simply awaited. In test builds the macro first awaits
/// `simulate_random_delay()` on `$self.background` (when it is set) and then
/// runs the block to completion with `block_on` on `$self.runtime`.
///
/// Panics in test builds if `$self.runtime` is `None`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is an ordinary boolean expression, so this branch is
            // still compiled in non-test builds. There the `#[cfg(test)]`
            // statements below are removed and only `unreachable!()` remains.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Drive the body on the test runtime; `unwrap` panics if none was set.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  51
/// Backend-independent access to the number of rows a statement affected.
///
/// Required of `D::QueryResult` by the generic `Db<D>` implementation.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
  55
  56#[cfg(test)]
  57impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  58    fn rows_affected(&self) -> u64 {
  59        self.rows_affected()
  60    }
  61}
  62
  63impl RowsAffected for sqlx::postgres::PgQueryResult {
  64    fn rows_affected(&self) -> u64 {
  65        self.rows_affected()
  66    }
  67}
  68
  69#[cfg(test)]
  70impl Db<sqlx::Sqlite> {
  71    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  72        use std::str::FromStr as _;
  73        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  74            .unwrap()
  75            .create_if_missing(true)
  76            .shared_cache(true);
  77        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  78            .min_connections(2)
  79            .max_connections(max_connections)
  80            .connect_with(options)
  81            .await?;
  82        Ok(Self {
  83            pool,
  84            background: None,
  85            runtime: None,
  86        })
  87    }
  88
  89    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT users.*
  93                FROM users
  94                WHERE users.id IN (SELECT value from json_each($1))
  95            ";
  96            Ok(sqlx::query_as(query)
  97                .bind(&serde_json::json!(ids))
  98                .fetch_all(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 104        test_support!(self, {
 105            let query = "
 106                SELECT metrics_id
 107                FROM users
 108                WHERE id = $1
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(id)
 112                .fetch_one(&self.pool)
 113                .await?)
 114        })
 115    }
 116
 117    pub async fn create_user(
 118        &self,
 119        email_address: &str,
 120        admin: bool,
 121        params: NewUserParams,
 122    ) -> Result<NewUserResult> {
 123        test_support!(self, {
 124            let query = "
 125                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 126                VALUES ($1, $2, $3, $4, $5)
 127                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 128                RETURNING id, metrics_id
 129            ";
 130
 131            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 132                .bind(email_address)
 133                .bind(params.github_login)
 134                .bind(params.github_user_id)
 135                .bind(admin)
 136                .bind(Uuid::new_v4().to_string())
 137                .fetch_one(&self.pool)
 138                .await?;
 139            Ok(NewUserResult {
 140                user_id,
 141                metrics_id,
 142                signup_device_id: None,
 143                inviting_user_id: None,
 144            })
 145        })
 146    }
 147
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }
 151
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }
 159
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }
 163
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
 172
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
 176}
 177
 178impl Db<sqlx::Postgres> {
 179    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 180        let pool = sqlx::postgres::PgPoolOptions::new()
 181            .max_connections(max_connections)
 182            .connect(url)
 183            .await?;
 184        Ok(Self {
 185            pool,
 186            #[cfg(test)]
 187            background: None,
 188            #[cfg(test)]
 189            runtime: None,
 190        })
 191    }
 192
 193    #[cfg(test)]
 194    pub fn teardown(&self, url: &str) {
 195        self.runtime.as_ref().unwrap().block_on(async {
 196            use util::ResultExt;
 197            let query = "
 198                SELECT pg_terminate_backend(pg_stat_activity.pid)
 199                FROM pg_stat_activity
 200                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 201            ";
 202            sqlx::query(query).execute(&self.pool).await.log_err();
 203            self.pool.close().await;
 204            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 205                .await
 206                .log_err();
 207        })
 208    }
 209
 210    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 211        test_support!(self, {
 212            let like_string = Self::fuzzy_like_string(name_query);
 213            let query = "
 214                SELECT users.*
 215                FROM users
 216                WHERE github_login ILIKE $1
 217                ORDER BY github_login <-> $2
 218                LIMIT $3
 219            ";
 220            Ok(sqlx::query_as(query)
 221                .bind(like_string)
 222                .bind(name_query)
 223                .bind(limit as i32)
 224                .fetch_all(&self.pool)
 225                .await?)
 226        })
 227    }
 228
 229    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 230        test_support!(self, {
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE users.id = ANY ($1)
 235            ";
 236            Ok(sqlx::query_as(query)
 237                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 238                .fetch_all(&self.pool)
 239                .await?)
 240        })
 241    }
 242
 243    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT metrics_id::text
 247                FROM users
 248                WHERE id = $1
 249            ";
 250            Ok(sqlx::query_scalar(query)
 251                .bind(id)
 252                .fetch_one(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    pub async fn create_user(
 258        &self,
 259        email_address: &str,
 260        admin: bool,
 261        params: NewUserParams,
 262    ) -> Result<NewUserResult> {
 263        test_support!(self, {
 264            let query = "
 265                INSERT INTO users (email_address, github_login, github_user_id, admin)
 266                VALUES ($1, $2, $3, $4)
 267                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 268                RETURNING id, metrics_id::text
 269            ";
 270
 271            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272                .bind(email_address)
 273                .bind(params.github_login)
 274                .bind(params.github_user_id)
 275                .bind(admin)
 276                .fetch_one(&self.pool)
 277                .await?;
 278            Ok(NewUserResult {
 279                user_id,
 280                metrics_id,
 281                signup_device_id: None,
 282                inviting_user_id: None,
 283            })
 284        })
 285    }
 286
 287    pub async fn create_user_from_invite(
 288        &self,
 289        invite: &Invite,
 290        user: NewUserParams,
 291    ) -> Result<Option<NewUserResult>> {
 292        test_support!(self, {
 293            let mut tx = self.pool.begin().await?;
 294
 295            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 296                i32,
 297                Option<UserId>,
 298                Option<UserId>,
 299                Option<String>,
 300            ) = sqlx::query_as(
 301                "
 302                SELECT id, user_id, inviting_user_id, device_id
 303                FROM signups
 304                WHERE
 305                    email_address = $1 AND
 306                    email_confirmation_code = $2
 307                ",
 308            )
 309            .bind(&invite.email_address)
 310            .bind(&invite.email_confirmation_code)
 311            .fetch_optional(&mut tx)
 312            .await?
 313            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 314
 315            if existing_user_id.is_some() {
 316                return Ok(None);
 317            }
 318
 319            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 320                "
 321                INSERT INTO users
 322                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 323                VALUES
 324                ($1, $2, $3, FALSE, $4, $5)
 325                ON CONFLICT (github_login) DO UPDATE SET
 326                    email_address = excluded.email_address,
 327                    github_user_id = excluded.github_user_id,
 328                    admin = excluded.admin
 329                RETURNING id, metrics_id::text
 330                ",
 331            )
 332            .bind(&invite.email_address)
 333            .bind(&user.github_login)
 334            .bind(&user.github_user_id)
 335            .bind(&user.invite_count)
 336            .bind(random_invite_code())
 337            .fetch_one(&mut tx)
 338            .await?;
 339
 340            sqlx::query(
 341                "
 342                UPDATE signups
 343                SET user_id = $1
 344                WHERE id = $2
 345                ",
 346            )
 347            .bind(&user_id)
 348            .bind(&signup_id)
 349            .execute(&mut tx)
 350            .await?;
 351
 352            if let Some(inviting_user_id) = inviting_user_id {
 353                let id: Option<UserId> = sqlx::query_scalar(
 354                    "
 355                    UPDATE users
 356                    SET invite_count = invite_count - 1
 357                    WHERE id = $1 AND invite_count > 0
 358                    RETURNING id
 359                    ",
 360                )
 361                .bind(&inviting_user_id)
 362                .fetch_optional(&mut tx)
 363                .await?;
 364
 365                if id.is_none() {
 366                    Err(Error::Http(
 367                        StatusCode::UNAUTHORIZED,
 368                        "no invites remaining".to_string(),
 369                    ))?;
 370                }
 371
 372                sqlx::query(
 373                    "
 374                    INSERT INTO contacts
 375                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 376                    VALUES
 377                        ($1, $2, TRUE, TRUE, TRUE)
 378                    ON CONFLICT DO NOTHING
 379                    ",
 380                )
 381                .bind(inviting_user_id)
 382                .bind(user_id)
 383                .execute(&mut tx)
 384                .await?;
 385            }
 386
 387            tx.commit().await?;
 388            Ok(Some(NewUserResult {
 389                user_id,
 390                metrics_id,
 391                inviting_user_id,
 392                signup_device_id,
 393            }))
 394        })
 395    }
 396
 397    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 398        test_support!(self, {
 399            sqlx::query(
 400                "
 401                INSERT INTO signups
 402                (
 403                    email_address,
 404                    email_confirmation_code,
 405                    email_confirmation_sent,
 406                    platform_linux,
 407                    platform_mac,
 408                    platform_windows,
 409                    platform_unknown,
 410                    editor_features,
 411                    programming_languages,
 412                    device_id
 413                )
 414                VALUES
 415                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 416                RETURNING id
 417                ",
 418            )
 419            .bind(&signup.email_address)
 420            .bind(&random_email_confirmation_code())
 421            .bind(&signup.platform_linux)
 422            .bind(&signup.platform_mac)
 423            .bind(&signup.platform_windows)
 424            .bind(&signup.editor_features)
 425            .bind(&signup.programming_languages)
 426            .bind(&signup.device_id)
 427            .execute(&self.pool)
 428            .await?;
 429            Ok(())
 430        })
 431    }
 432
 433    pub async fn create_invite_from_code(
 434        &self,
 435        code: &str,
 436        email_address: &str,
 437        device_id: Option<&str>,
 438    ) -> Result<Invite> {
 439        test_support!(self, {
 440            let mut tx = self.pool.begin().await?;
 441
 442            let existing_user: Option<UserId> = sqlx::query_scalar(
 443                "
 444                SELECT id
 445                FROM users
 446                WHERE email_address = $1
 447                ",
 448            )
 449            .bind(email_address)
 450            .fetch_optional(&mut tx)
 451            .await?;
 452            if existing_user.is_some() {
 453                Err(anyhow!("email address is already in use"))?;
 454            }
 455
 456            let row: Option<(UserId, i32)> = sqlx::query_as(
 457                "
 458                SELECT id, invite_count
 459                FROM users
 460                WHERE invite_code = $1
 461                ",
 462            )
 463            .bind(code)
 464            .fetch_optional(&mut tx)
 465            .await?;
 466
 467            let (inviter_id, invite_count) = match row {
 468                Some(row) => row,
 469                None => Err(Error::Http(
 470                    StatusCode::NOT_FOUND,
 471                    "invite code not found".to_string(),
 472                ))?,
 473            };
 474
 475            if invite_count == 0 {
 476                Err(Error::Http(
 477                    StatusCode::UNAUTHORIZED,
 478                    "no invites remaining".to_string(),
 479                ))?;
 480            }
 481
 482            let email_confirmation_code: String = sqlx::query_scalar(
 483                "
 484                INSERT INTO signups
 485                (
 486                    email_address,
 487                    email_confirmation_code,
 488                    email_confirmation_sent,
 489                    inviting_user_id,
 490                    platform_linux,
 491                    platform_mac,
 492                    platform_windows,
 493                    platform_unknown,
 494                    device_id
 495                )
 496                VALUES
 497                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 498                ON CONFLICT (email_address)
 499                DO UPDATE SET
 500                    inviting_user_id = excluded.inviting_user_id
 501                RETURNING email_confirmation_code
 502                ",
 503            )
 504            .bind(&email_address)
 505            .bind(&random_email_confirmation_code())
 506            .bind(&inviter_id)
 507            .bind(&device_id)
 508            .fetch_one(&mut tx)
 509            .await?;
 510
 511            tx.commit().await?;
 512
 513            Ok(Invite {
 514                email_address: email_address.into(),
 515                email_confirmation_code,
 516            })
 517        })
 518    }
 519
 520    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 521        test_support!(self, {
 522            let emails = invites
 523                .iter()
 524                .map(|s| s.email_address.as_str())
 525                .collect::<Vec<_>>();
 526            sqlx::query(
 527                "
 528                UPDATE signups
 529                SET email_confirmation_sent = TRUE
 530                WHERE email_address = ANY ($1)
 531                ",
 532            )
 533            .bind(&emails)
 534            .execute(&self.pool)
 535            .await?;
 536            Ok(())
 537        })
 538    }
 539}
 540
 541impl<D> Db<D>
 542where
 543    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 544    D::Connection: sqlx::migrate::Migrate,
 545    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 546    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 547    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 548    D::QueryResult: RowsAffected,
 549    String: sqlx::Type<D>,
 550    i32: sqlx::Type<D>,
 551    i64: sqlx::Type<D>,
 552    bool: sqlx::Type<D>,
 553    str: sqlx::Type<D>,
 554    Uuid: sqlx::Type<D>,
 555    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 556    OffsetDateTime: sqlx::Type<D>,
 557    PrimitiveDateTime: sqlx::Type<D>,
 558    usize: sqlx::ColumnIndex<D::Row>,
 559    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 560    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 561    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 562    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 563    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 564    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 565    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 566    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 567    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 568    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 569    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 570    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 571{
 572    pub async fn migrate(
 573        &self,
 574        migrations_path: &Path,
 575        ignore_checksum_mismatch: bool,
 576    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 577        let migrations = MigrationSource::resolve(migrations_path)
 578            .await
 579            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 580
 581        let mut conn = self.pool.acquire().await?;
 582
 583        conn.ensure_migrations_table().await?;
 584        let applied_migrations: HashMap<_, _> = conn
 585            .list_applied_migrations()
 586            .await?
 587            .into_iter()
 588            .map(|m| (m.version, m))
 589            .collect();
 590
 591        let mut new_migrations = Vec::new();
 592        for migration in migrations {
 593            match applied_migrations.get(&migration.version) {
 594                Some(applied_migration) => {
 595                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 596                    {
 597                        Err(anyhow!(
 598                            "checksum mismatch for applied migration {}",
 599                            migration.description
 600                        ))?;
 601                    }
 602                }
 603                None => {
 604                    let elapsed = conn.apply(&migration).await?;
 605                    new_migrations.push((migration, elapsed));
 606                }
 607            }
 608        }
 609
 610        Ok(new_migrations)
 611    }
 612
 613    pub fn fuzzy_like_string(string: &str) -> String {
 614        let mut result = String::with_capacity(string.len() * 2 + 1);
 615        for c in string.chars() {
 616            if c.is_alphanumeric() {
 617                result.push('%');
 618                result.push(c);
 619            }
 620        }
 621        result.push('%');
 622        result
 623    }
 624
 625    // users
 626
 627    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 628        test_support!(self, {
 629            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 630            Ok(sqlx::query_as(query)
 631                .bind(limit as i32)
 632                .bind((page * limit) as i32)
 633                .fetch_all(&self.pool)
 634                .await?)
 635        })
 636    }
 637
 638    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 639        test_support!(self, {
 640            let query = "
 641                SELECT users.*
 642                FROM users
 643                WHERE id = $1
 644                LIMIT 1
 645            ";
 646            Ok(sqlx::query_as(query)
 647                .bind(&id)
 648                .fetch_optional(&self.pool)
 649                .await?)
 650        })
 651    }
 652
 653    pub async fn get_users_with_no_invites(
 654        &self,
 655        invited_by_another_user: bool,
 656    ) -> Result<Vec<User>> {
 657        test_support!(self, {
 658            let query = format!(
 659                "
 660                SELECT users.*
 661                FROM users
 662                WHERE invite_count = 0
 663                AND inviter_id IS{} NULL
 664                ",
 665                if invited_by_another_user { " NOT" } else { "" }
 666            );
 667
 668            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 669        })
 670    }
 671
 672    pub async fn get_user_by_github_account(
 673        &self,
 674        github_login: &str,
 675        github_user_id: Option<i32>,
 676    ) -> Result<Option<User>> {
 677        test_support!(self, {
 678            if let Some(github_user_id) = github_user_id {
 679                let mut user = sqlx::query_as::<_, User>(
 680                    "
 681                    UPDATE users
 682                    SET github_login = $1
 683                    WHERE github_user_id = $2
 684                    RETURNING *
 685                    ",
 686                )
 687                .bind(github_login)
 688                .bind(github_user_id)
 689                .fetch_optional(&self.pool)
 690                .await?;
 691
 692                if user.is_none() {
 693                    user = sqlx::query_as::<_, User>(
 694                        "
 695                        UPDATE users
 696                        SET github_user_id = $1
 697                        WHERE github_login = $2
 698                        RETURNING *
 699                        ",
 700                    )
 701                    .bind(github_user_id)
 702                    .bind(github_login)
 703                    .fetch_optional(&self.pool)
 704                    .await?;
 705                }
 706
 707                Ok(user)
 708            } else {
 709                let user = sqlx::query_as(
 710                    "
 711                    SELECT * FROM users
 712                    WHERE github_login = $1
 713                    LIMIT 1
 714                    ",
 715                )
 716                .bind(github_login)
 717                .fetch_optional(&self.pool)
 718                .await?;
 719                Ok(user)
 720            }
 721        })
 722    }
 723
 724    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 725        test_support!(self, {
 726            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 727            Ok(sqlx::query(query)
 728                .bind(is_admin)
 729                .bind(id.0)
 730                .execute(&self.pool)
 731                .await
 732                .map(drop)?)
 733        })
 734    }
 735
 736    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 737        test_support!(self, {
 738            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 739            Ok(sqlx::query(query)
 740                .bind(connected_once)
 741                .bind(id.0)
 742                .execute(&self.pool)
 743                .await
 744                .map(drop)?)
 745        })
 746    }
 747
 748    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 749        test_support!(self, {
 750            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 751            sqlx::query(query)
 752                .bind(id.0)
 753                .execute(&self.pool)
 754                .await
 755                .map(drop)?;
 756            let query = "DELETE FROM users WHERE id = $1;";
 757            Ok(sqlx::query(query)
 758                .bind(id.0)
 759                .execute(&self.pool)
 760                .await
 761                .map(drop)?)
 762        })
 763    }
 764
 765    // signups
 766
 767    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 768        test_support!(self, {
 769            Ok(sqlx::query_as(
 770                "
 771                SELECT
 772                    COUNT(*) as count,
 773                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 774                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 775                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 776                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 777                FROM (
 778                    SELECT *
 779                    FROM signups
 780                    WHERE
 781                        NOT email_confirmation_sent
 782                ) AS unsent
 783                ",
 784            )
 785            .fetch_one(&self.pool)
 786            .await?)
 787        })
 788    }
 789
 790    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 791        test_support!(self, {
 792            Ok(sqlx::query_as(
 793                "
 794                SELECT
 795                    email_address, email_confirmation_code
 796                FROM signups
 797                WHERE
 798                    NOT email_confirmation_sent AND
 799                    (platform_mac OR platform_unknown)
 800                LIMIT $1
 801                ",
 802            )
 803            .bind(count as i32)
 804            .fetch_all(&self.pool)
 805            .await?)
 806        })
 807    }
 808
 809    // invite codes
 810
 811    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 812        test_support!(self, {
 813            let mut tx = self.pool.begin().await?;
 814            if count > 0 {
 815                sqlx::query(
 816                    "
 817                    UPDATE users
 818                    SET invite_code = $1
 819                    WHERE id = $2 AND invite_code IS NULL
 820                ",
 821                )
 822                .bind(random_invite_code())
 823                .bind(id)
 824                .execute(&mut tx)
 825                .await?;
 826            }
 827
 828            sqlx::query(
 829                "
 830                UPDATE users
 831                SET invite_count = $1
 832                WHERE id = $2
 833                ",
 834            )
 835            .bind(count as i32)
 836            .bind(id)
 837            .execute(&mut tx)
 838            .await?;
 839            tx.commit().await?;
 840            Ok(())
 841        })
 842    }
 843
 844    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 845        test_support!(self, {
 846            let result: Option<(String, i32)> = sqlx::query_as(
 847                "
 848                    SELECT invite_code, invite_count
 849                    FROM users
 850                    WHERE id = $1 AND invite_code IS NOT NULL 
 851                ",
 852            )
 853            .bind(id)
 854            .fetch_optional(&self.pool)
 855            .await?;
 856            if let Some((code, count)) = result {
 857                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 858            } else {
 859                Ok(None)
 860            }
 861        })
 862    }
 863
 864    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 865        test_support!(self, {
 866            sqlx::query_as(
 867                "
 868                    SELECT *
 869                    FROM users
 870                    WHERE invite_code = $1
 871                ",
 872            )
 873            .bind(code)
 874            .fetch_optional(&self.pool)
 875            .await?
 876            .ok_or_else(|| {
 877                Error::Http(
 878                    StatusCode::NOT_FOUND,
 879                    "that invite code does not exist".to_string(),
 880                )
 881            })
 882        })
 883    }
 884
 885    // projects
 886
 887    /// Registers a new project for the given user.
 888    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 889        test_support!(self, {
 890            Ok(sqlx::query_scalar(
 891                "
 892                INSERT INTO projects(host_user_id)
 893                VALUES ($1)
 894                RETURNING id
 895                ",
 896            )
 897            .bind(host_user_id)
 898            .fetch_one(&self.pool)
 899            .await
 900            .map(ProjectId)?)
 901        })
 902    }
 903
 904    /// Unregisters a project for the given project id.
 905    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 906        test_support!(self, {
 907            sqlx::query(
 908                "
 909                UPDATE projects
 910                SET unregistered = TRUE
 911                WHERE id = $1
 912                ",
 913            )
 914            .bind(project_id)
 915            .execute(&self.pool)
 916            .await?;
 917            Ok(())
 918        })
 919    }
 920
 921    // contacts
 922
 923    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 924        test_support!(self, {
 925            let query = "
 926                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 927                FROM contacts
 928                WHERE user_id_a = $1 OR user_id_b = $1;
 929            ";
 930
 931            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 932                .bind(user_id)
 933                .fetch(&self.pool);
 934
 935            let mut contacts = Vec::new();
 936            while let Some(row) = rows.next().await {
 937                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 938
 939                if user_id_a == user_id {
 940                    if accepted {
 941                        contacts.push(Contact::Accepted {
 942                            user_id: user_id_b,
 943                            should_notify: should_notify && a_to_b,
 944                        });
 945                    } else if a_to_b {
 946                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 947                    } else {
 948                        contacts.push(Contact::Incoming {
 949                            user_id: user_id_b,
 950                            should_notify,
 951                        });
 952                    }
 953                } else if accepted {
 954                    contacts.push(Contact::Accepted {
 955                        user_id: user_id_a,
 956                        should_notify: should_notify && !a_to_b,
 957                    });
 958                } else if a_to_b {
 959                    contacts.push(Contact::Incoming {
 960                        user_id: user_id_a,
 961                        should_notify,
 962                    });
 963                } else {
 964                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 965                }
 966            }
 967
 968            contacts.sort_unstable_by_key(|contact| contact.user_id());
 969
 970            Ok(contacts)
 971        })
 972    }
 973
 974    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 975        test_support!(self, {
 976            let (id_a, id_b) = if user_id_1 < user_id_2 {
 977                (user_id_1, user_id_2)
 978            } else {
 979                (user_id_2, user_id_1)
 980            };
 981
 982            let query = "
 983                SELECT 1 FROM contacts
 984                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 985                LIMIT 1
 986            ";
 987            Ok(sqlx::query_scalar::<_, i32>(query)
 988                .bind(id_a.0)
 989                .bind(id_b.0)
 990                .fetch_optional(&self.pool)
 991                .await?
 992                .is_some())
 993        })
 994    }
 995
 996    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 997        test_support!(self, {
 998            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 999                (sender_id, receiver_id, true)
1000            } else {
1001                (receiver_id, sender_id, false)
1002            };
1003            let query = "
1004                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1005                VALUES ($1, $2, $3, FALSE, TRUE)
1006                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1007                SET
1008                    accepted = TRUE,
1009                    should_notify = FALSE
1010                WHERE
1011                    NOT contacts.accepted AND
1012                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1013                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1014            ";
1015            let result = sqlx::query(query)
1016                .bind(id_a.0)
1017                .bind(id_b.0)
1018                .bind(a_to_b)
1019                .execute(&self.pool)
1020                .await?;
1021
1022            if result.rows_affected() == 1 {
1023                Ok(())
1024            } else {
1025                Err(anyhow!("contact already requested"))?
1026            }
1027        })
1028    }
1029
1030    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1031        test_support!(self, {
1032            let (id_a, id_b) = if responder_id < requester_id {
1033                (responder_id, requester_id)
1034            } else {
1035                (requester_id, responder_id)
1036            };
1037            let query = "
1038                DELETE FROM contacts
1039                WHERE user_id_a = $1 AND user_id_b = $2;
1040            ";
1041            let result = sqlx::query(query)
1042                .bind(id_a.0)
1043                .bind(id_b.0)
1044                .execute(&self.pool)
1045                .await?;
1046
1047            if result.rows_affected() == 1 {
1048                Ok(())
1049            } else {
1050                Err(anyhow!("no such contact"))?
1051            }
1052        })
1053    }
1054
1055    pub async fn dismiss_contact_notification(
1056        &self,
1057        user_id: UserId,
1058        contact_user_id: UserId,
1059    ) -> Result<()> {
1060        test_support!(self, {
1061            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1062                (user_id, contact_user_id, true)
1063            } else {
1064                (contact_user_id, user_id, false)
1065            };
1066
1067            let query = "
1068                UPDATE contacts
1069                SET should_notify = FALSE
1070                WHERE
1071                    user_id_a = $1 AND user_id_b = $2 AND
1072                    (
1073                        (a_to_b = $3 AND accepted) OR
1074                        (a_to_b != $3 AND NOT accepted)
1075                    );
1076            ";
1077
1078            let result = sqlx::query(query)
1079                .bind(id_a.0)
1080                .bind(id_b.0)
1081                .bind(a_to_b)
1082                .execute(&self.pool)
1083                .await?;
1084
1085            if result.rows_affected() == 0 {
1086                Err(anyhow!("no such contact request"))?;
1087            }
1088
1089            Ok(())
1090        })
1091    }
1092
1093    pub async fn respond_to_contact_request(
1094        &self,
1095        responder_id: UserId,
1096        requester_id: UserId,
1097        accept: bool,
1098    ) -> Result<()> {
1099        test_support!(self, {
1100            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1101                (responder_id, requester_id, false)
1102            } else {
1103                (requester_id, responder_id, true)
1104            };
1105            let result = if accept {
1106                let query = "
1107                    UPDATE contacts
1108                    SET accepted = TRUE, should_notify = TRUE
1109                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1110                ";
1111                sqlx::query(query)
1112                    .bind(id_a.0)
1113                    .bind(id_b.0)
1114                    .bind(a_to_b)
1115                    .execute(&self.pool)
1116                    .await?
1117            } else {
1118                let query = "
1119                    DELETE FROM contacts
1120                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1121                ";
1122                sqlx::query(query)
1123                    .bind(id_a.0)
1124                    .bind(id_b.0)
1125                    .bind(a_to_b)
1126                    .execute(&self.pool)
1127                    .await?
1128            };
1129            if result.rows_affected() == 1 {
1130                Ok(())
1131            } else {
1132                Err(anyhow!("no such contact request"))?
1133            }
1134        })
1135    }
1136
1137    // access tokens
1138
1139    pub async fn create_access_token_hash(
1140        &self,
1141        user_id: UserId,
1142        access_token_hash: &str,
1143        max_access_token_count: usize,
1144    ) -> Result<()> {
1145        test_support!(self, {
1146            let insert_query = "
1147                INSERT INTO access_tokens (user_id, hash)
1148                VALUES ($1, $2);
1149            ";
1150            let cleanup_query = "
1151                DELETE FROM access_tokens
1152                WHERE id IN (
1153                    SELECT id from access_tokens
1154                    WHERE user_id = $1
1155                    ORDER BY id DESC
1156                    LIMIT 10000
1157                    OFFSET $3
1158                )
1159            ";
1160
1161            let mut tx = self.pool.begin().await?;
1162            sqlx::query(insert_query)
1163                .bind(user_id.0)
1164                .bind(access_token_hash)
1165                .execute(&mut tx)
1166                .await?;
1167            sqlx::query(cleanup_query)
1168                .bind(user_id.0)
1169                .bind(access_token_hash)
1170                .bind(max_access_token_count as i32)
1171                .execute(&mut tx)
1172                .await?;
1173            Ok(tx.commit().await?)
1174        })
1175    }
1176
1177    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1178        test_support!(self, {
1179            let query = "
1180                SELECT hash
1181                FROM access_tokens
1182                WHERE user_id = $1
1183                ORDER BY id DESC
1184            ";
1185            Ok(sqlx::query_scalar(query)
1186                .bind(user_id.0)
1187                .fetch_all(&self.pool)
1188                .await?)
1189        })
1190    }
1191}
1192
/// Declares a strongly typed id: a public newtype around `i32` named `$name`.
///
/// The generated type is `Copy`, ordered and hashable, is transparent to both
/// sqlx and serde (so it is encoded exactly like the wrapped `i32`), and comes
/// with a `MAX` constant, `u64` conversions for the protocol layer and a
/// `Display` impl that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest id representable by the wrapped `i32`.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64`. The `as` cast keeps only the low
            // 32 bits, so values above `i32::MAX` do not round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Widens the id to `u64`. The `as` cast sign-extends, so a
            // negative id maps to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward to the wrapped integer's own formatting.
                self.0.fmt(f)
            }
        }
    };
}
1235
id_type!(UserId);
/// A user record.
///
/// `get_user_for_invite_code` loads it with `SELECT * FROM users` through the
/// derived `FromRow` impl, so every field here must have a column of the same
/// name in that result set.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    // `None` for users without a code; queries in this file filter on
    // `invite_code IS NOT NULL` / `invite_code = $1`.
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1248
id_type!(ProjectId);
/// A project record; derives `FromRow`, so it can be loaded straight from a
/// query result.
///
/// The field names match the `projects` columns used by `register_project`
/// (`host_user_id`, `id`) and `unregister_project` (`unregistered`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // Set to `TRUE` by `unregister_project`; the row itself is kept.
    pub unregistered: bool,
}
1256
/// One contact relationship as seen by a particular user, as produced by
/// `get_contacts`. In every variant `user_id` is the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact request has been accepted.
    Accepted {
        user_id: UserId,
        // `get_contacts` only leaves this set for the user who sent the request.
        should_notify: bool,
    },
    /// A pending request that the viewing user sent.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request that the viewing user received.
    Incoming {
        user_id: UserId,
        // Copied from the row's `should_notify` column.
        should_notify: bool,
    },
}
1271
1272impl Contact {
1273    pub fn user_id(&self) -> UserId {
1274        match self {
1275            Contact::Accepted { user_id, .. } => *user_id,
1276            Contact::Outgoing { user_id } => *user_id,
1277            Contact::Incoming { user_id, .. } => *user_id,
1278        }
1279    }
1280}
1281
/// A contact request paired with a notification flag.
///
/// NOTE(review): judging by the names, this describes a request received from
/// `requester_id`; no code in the surrounding section constructs it — confirm
/// where it is used.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1287
/// Signup data. It derives only `Deserialize` (no `Serialize`), so it is
/// meant to be parsed from incoming data rather than written back out.
///
/// NOTE(review): presumably the payload of a waitlist signup request; the code
/// that consumes it is outside this section — confirm.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    // Optional: may be absent from the incoming data.
    pub device_id: Option<String>,
}
1298
/// A set of `i64` counters loaded from a single query row.
///
/// Every field carries `#[sqlx(default)]`, so a counter whose column is
/// missing from the row falls back to `i64::default()` (zero) instead of
/// making the row conversion fail.
///
/// NOTE(review): presumably counts waitlist signups overall and per platform;
/// the query that fills it is outside this section — confirm.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1312
/// An email address together with its confirmation code.
///
/// Derives `FromRow`, so it can be read directly from a query whose columns
/// are named `email_address` and `email_confirmation_code`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    // NOTE(review): presumably produced by `random_email_confirmation_code` — confirm.
    pub email_confirmation_code: String,
}
1318
/// Input values for creating a user; the field names mirror the
/// `github_login`, `github_user_id` and `invite_count` fields of [`User`].
///
/// NOTE(review): the function that consumes these parameters is outside this
/// section — confirm how each value is applied.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    // Required here, although `User::github_user_id` is optional.
    pub github_user_id: i32,
    pub invite_count: i32,
}
1325
/// Values reported back after a user has been created.
///
/// NOTE(review): the code that builds this result is outside this section;
/// the field notes below are inferred from the names — confirm.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // Presumably the user whose invite was redeemed, if any.
    pub inviting_user_id: Option<UserId>,
    // Presumably the `device_id` recorded with the signup, if any.
    pub signup_device_id: Option<String>,
}
1333
/// Generates a random invite code: a 16-character id produced by the
/// `nanoid!` macro (no custom alphabet is passed).
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1337
/// Generates a random email confirmation code: a 64-character id produced by
/// the `nanoid!` macro (no custom alphabet is passed).
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1341
1342#[cfg(test)]
1343pub use test::*;
1344
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Test database backed by a uniquely named in-memory SQLite instance.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// Connection detached from the pool in [`SqliteTestDb::new`] and held
        /// for as long as this value lives.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// Test database backed by a freshly created Postgres database on
    /// `localhost`; the database is torn down when the value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the database created for this test.
        pub url: String,
    }

    /// Builds the current-thread tokio runtime (I/O and time drivers enabled)
    /// that a test database stores in its `runtime` field.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates the in-memory database, runs the SQLite migrations and
        /// attaches `background` plus a dedicated runtime to the `Db`.
        ///
        /// Panics if the database cannot be opened or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            let unique_suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("file:zed-test-{}?mode=memory", unique_suffix);
            let runtime = current_thread_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, conn }
        }

        /// Returns the shared handle to the underlying `Db`.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database, runs the migrations
        /// and attaches `background` plus a dedicated runtime to the `Db`.
        ///
        /// A process-wide lock is held for the whole call, so test databases
        /// are set up one at a time. Panics if creation or migration fails.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let unique_suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("postgres://postgres@localhost/zed-test-{}", unique_suffix);
            let runtime = current_thread_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, url }
        }

        /// Returns the shared handle to the underlying `Db`.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the `Db` out of the wrapper and tears down the test database
        /// at `self.url`.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}