db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
  15#[cfg(test)]
  16pub type DefaultDb = Db<sqlx::Sqlite>;
  17
  18#[cfg(not(test))]
  19pub type DefaultDb = Db<sqlx::Postgres>;
  20
/// Database handle, generic over the sqlx backend `D`.
///
/// Outside of test builds this only wraps a connection pool. Test builds carry
/// two extra optional fields consulted by the `test_support!` macro.
///
/// Fields are dropped in declaration order, so `pool` is dropped before the
/// test-only `runtime`.
pub struct Db<D: sqlx::Database> {
    /// Connection pool every query in this module is executed against.
    pool: sqlx::Pool<D>,
    /// Test-only: when `Some`, `test_support!` awaits
    /// `simulate_random_delay()` on it before running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: runtime that `test_support!` uses to `block_on` query bodies.
    /// It is unwrapped there, so it must be set before queries run in tests.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
/// Runs a query body so it works in both production and test builds.
///
/// The tokens are wrapped in an `async` block. In non-test builds that block is
/// simply awaited. In test builds, if `$self.background` is `Some`, a simulated
/// random delay is awaited first; then the body is driven to completion with
/// `block_on` on `$self.runtime`, which is unwrapped (so it must be `Some`).
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        // `cfg!(test)` is a plain boolean, so both branches must type-check in
        // every build. The `#[cfg]` attributes strip the statements that touch
        // test-only fields from non-test builds, leaving `unreachable!()`.
        if cfg!(test) {
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Blocks the current thread on the test runtime until the query
            // body finishes.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  51
  52pub trait RowsAffected {
  53    fn rows_affected(&self) -> u64;
  54}
  55
  56#[cfg(test)]
  57impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  58    fn rows_affected(&self) -> u64 {
  59        self.rows_affected()
  60    }
  61}
  62
  63impl RowsAffected for sqlx::postgres::PgQueryResult {
  64    fn rows_affected(&self) -> u64 {
  65        self.rows_affected()
  66    }
  67}
  68
  69#[cfg(test)]
  70impl Db<sqlx::Sqlite> {
  71    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  72        use std::str::FromStr as _;
  73        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  74            .unwrap()
  75            .create_if_missing(true)
  76            .shared_cache(true);
  77        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  78            .min_connections(2)
  79            .max_connections(max_connections)
  80            .connect_with(options)
  81            .await?;
  82        Ok(Self {
  83            pool,
  84            background: None,
  85            runtime: None,
  86        })
  87    }
  88
  89    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  90        test_support!(self, {
  91            let query = "
  92                SELECT users.*
  93                FROM users
  94                WHERE users.id IN (SELECT value from json_each($1))
  95            ";
  96            Ok(sqlx::query_as(query)
  97                .bind(&serde_json::json!(ids))
  98                .fetch_all(&self.pool)
  99                .await?)
 100        })
 101    }
 102
 103    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 104        test_support!(self, {
 105            let query = "
 106                SELECT metrics_id
 107                FROM users
 108                WHERE id = $1
 109            ";
 110            Ok(sqlx::query_scalar(query)
 111                .bind(id)
 112                .fetch_one(&self.pool)
 113                .await?)
 114        })
 115    }
 116
 117    pub async fn create_user(
 118        &self,
 119        email_address: &str,
 120        admin: bool,
 121        params: NewUserParams,
 122    ) -> Result<NewUserResult> {
 123        test_support!(self, {
 124            let query = "
 125                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 126                VALUES ($1, $2, $3, $4, $5)
 127                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 128                RETURNING id, metrics_id
 129            ";
 130
 131            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 132                .bind(email_address)
 133                .bind(params.github_login)
 134                .bind(params.github_user_id)
 135                .bind(admin)
 136                .bind(Uuid::new_v4().to_string())
 137                .fetch_one(&self.pool)
 138                .await?;
 139            Ok(NewUserResult {
 140                user_id,
 141                metrics_id,
 142                signup_device_id: None,
 143                inviting_user_id: None,
 144            })
 145        })
 146    }
 147
 148    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 149        unimplemented!()
 150    }
 151
 152    pub async fn create_user_from_invite(
 153        &self,
 154        _invite: &Invite,
 155        _user: NewUserParams,
 156    ) -> Result<Option<NewUserResult>> {
 157        unimplemented!()
 158    }
 159
 160    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
 161        unimplemented!()
 162    }
 163
 164    pub async fn create_invite_from_code(
 165        &self,
 166        _code: &str,
 167        _email_address: &str,
 168        _device_id: Option<&str>,
 169    ) -> Result<Invite> {
 170        unimplemented!()
 171    }
 172
 173    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 174        unimplemented!()
 175    }
 176}
 177
 178impl Db<sqlx::Postgres> {
 179    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 180        let pool = sqlx::postgres::PgPoolOptions::new()
 181            .max_connections(max_connections)
 182            .connect(url)
 183            .await?;
 184        Ok(Self {
 185            pool,
 186            #[cfg(test)]
 187            background: None,
 188            #[cfg(test)]
 189            runtime: None,
 190        })
 191    }
 192
 193    #[cfg(test)]
 194    pub fn teardown(&self, url: &str) {
 195        self.runtime.as_ref().unwrap().block_on(async {
 196            use util::ResultExt;
 197            let query = "
 198                SELECT pg_terminate_backend(pg_stat_activity.pid)
 199                FROM pg_stat_activity
 200                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 201            ";
 202            sqlx::query(query).execute(&self.pool).await.log_err();
 203            self.pool.close().await;
 204            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 205                .await
 206                .log_err();
 207        })
 208    }
 209
 210    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 211        test_support!(self, {
 212            let like_string = Self::fuzzy_like_string(name_query);
 213            let query = "
 214                SELECT users.*
 215                FROM users
 216                WHERE github_login ILIKE $1
 217                ORDER BY github_login <-> $2
 218                LIMIT $3
 219            ";
 220            Ok(sqlx::query_as(query)
 221                .bind(like_string)
 222                .bind(name_query)
 223                .bind(limit as i32)
 224                .fetch_all(&self.pool)
 225                .await?)
 226        })
 227    }
 228
 229    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 230        test_support!(self, {
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE users.id = ANY ($1)
 235            ";
 236            Ok(sqlx::query_as(query)
 237                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 238                .fetch_all(&self.pool)
 239                .await?)
 240        })
 241    }
 242
 243    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 244        test_support!(self, {
 245            let query = "
 246                SELECT metrics_id::text
 247                FROM users
 248                WHERE id = $1
 249            ";
 250            Ok(sqlx::query_scalar(query)
 251                .bind(id)
 252                .fetch_one(&self.pool)
 253                .await?)
 254        })
 255    }
 256
 257    pub async fn create_user(
 258        &self,
 259        email_address: &str,
 260        admin: bool,
 261        params: NewUserParams,
 262    ) -> Result<NewUserResult> {
 263        test_support!(self, {
 264            let query = "
 265                INSERT INTO users (email_address, github_login, github_user_id, admin)
 266                VALUES ($1, $2, $3, $4)
 267                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 268                RETURNING id, metrics_id::text
 269            ";
 270
 271            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272                .bind(email_address)
 273                .bind(params.github_login)
 274                .bind(params.github_user_id)
 275                .bind(admin)
 276                .fetch_one(&self.pool)
 277                .await?;
 278            Ok(NewUserResult {
 279                user_id,
 280                metrics_id,
 281                signup_device_id: None,
 282                inviting_user_id: None,
 283            })
 284        })
 285    }
 286
 287    pub async fn create_user_from_invite(
 288        &self,
 289        invite: &Invite,
 290        user: NewUserParams,
 291    ) -> Result<Option<NewUserResult>> {
 292        test_support!(self, {
 293            let mut tx = self.pool.begin().await?;
 294
 295            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 296                i32,
 297                Option<UserId>,
 298                Option<UserId>,
 299                Option<String>,
 300            ) = sqlx::query_as(
 301                "
 302                SELECT id, user_id, inviting_user_id, device_id
 303                FROM signups
 304                WHERE
 305                    email_address = $1 AND
 306                    email_confirmation_code = $2
 307                ",
 308            )
 309            .bind(&invite.email_address)
 310            .bind(&invite.email_confirmation_code)
 311            .fetch_optional(&mut tx)
 312            .await?
 313            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 314
 315            if existing_user_id.is_some() {
 316                return Ok(None);
 317            }
 318
 319            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 320                "
 321                INSERT INTO users
 322                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 323                VALUES
 324                ($1, $2, $3, FALSE, $4, $5)
 325                ON CONFLICT (github_login) DO UPDATE SET
 326                    email_address = excluded.email_address,
 327                    github_user_id = excluded.github_user_id,
 328                    admin = excluded.admin
 329                RETURNING id, metrics_id::text
 330                ",
 331            )
 332            .bind(&invite.email_address)
 333            .bind(&user.github_login)
 334            .bind(&user.github_user_id)
 335            .bind(&user.invite_count)
 336            .bind(random_invite_code())
 337            .fetch_one(&mut tx)
 338            .await?;
 339
 340            sqlx::query(
 341                "
 342                UPDATE signups
 343                SET user_id = $1
 344                WHERE id = $2
 345                ",
 346            )
 347            .bind(&user_id)
 348            .bind(&signup_id)
 349            .execute(&mut tx)
 350            .await?;
 351
 352            if let Some(inviting_user_id) = inviting_user_id {
 353                let (user_id_a, user_id_b, a_to_b) = if inviting_user_id < user_id {
 354                    (inviting_user_id, user_id, true)
 355                } else {
 356                    (user_id, inviting_user_id, false)
 357                };
 358
 359                sqlx::query(
 360                    "
 361                    INSERT INTO contacts
 362                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 363                    VALUES
 364                        ($1, $2, $3, TRUE, TRUE)
 365                    ON CONFLICT DO NOTHING
 366                    ",
 367                )
 368                .bind(user_id_a)
 369                .bind(user_id_b)
 370                .bind(a_to_b)
 371                .execute(&mut tx)
 372                .await?;
 373            }
 374
 375            tx.commit().await?;
 376            Ok(Some(NewUserResult {
 377                user_id,
 378                metrics_id,
 379                inviting_user_id,
 380                signup_device_id,
 381            }))
 382        })
 383    }
 384
 385    pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
 386        test_support!(self, {
 387            sqlx::query(
 388                "
 389                INSERT INTO signups
 390                (
 391                    email_address,
 392                    email_confirmation_code,
 393                    email_confirmation_sent,
 394                    platform_linux,
 395                    platform_mac,
 396                    platform_windows,
 397                    platform_unknown,
 398                    editor_features,
 399                    programming_languages,
 400                    device_id,
 401                    added_to_mailing_list
 402                )
 403                VALUES
 404                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8, $9)
 405                ON CONFLICT (email_address) DO UPDATE SET
 406                    email_address = excluded.email_address
 407                RETURNING id
 408                ",
 409            )
 410            .bind(&signup.email_address)
 411            .bind(&random_email_confirmation_code())
 412            .bind(&signup.platform_linux)
 413            .bind(&signup.platform_mac)
 414            .bind(&signup.platform_windows)
 415            .bind(&signup.editor_features)
 416            .bind(&signup.programming_languages)
 417            .bind(&signup.device_id)
 418            .bind(&signup.added_to_mailing_list)
 419            .execute(&self.pool)
 420            .await?;
 421            Ok(())
 422        })
 423    }
 424
 425    pub async fn create_invite_from_code(
 426        &self,
 427        code: &str,
 428        email_address: &str,
 429        device_id: Option<&str>,
 430    ) -> Result<Invite> {
 431        test_support!(self, {
 432            let mut tx = self.pool.begin().await?;
 433
 434            let existing_user: Option<UserId> = sqlx::query_scalar(
 435                "
 436                SELECT id
 437                FROM users
 438                WHERE email_address = $1
 439                ",
 440            )
 441            .bind(email_address)
 442            .fetch_optional(&mut tx)
 443            .await?;
 444            if existing_user.is_some() {
 445                Err(anyhow!("email address is already in use"))?;
 446            }
 447
 448            let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
 449                "
 450                UPDATE users
 451                SET invite_count = invite_count - 1
 452                WHERE invite_code = $1 AND invite_count > 0
 453                RETURNING id
 454                ",
 455            )
 456            .bind(code)
 457            .fetch_optional(&mut tx)
 458            .await?;
 459
 460            let Some(inviter_id) = inviting_user_id_with_invites else {
 461                return Err(Error::Http(
 462                    StatusCode::UNAUTHORIZED,
 463                    "unable to find an invite code with invites remaining".to_string(),
 464                ));
 465            };
 466
 467            let email_confirmation_code: String = sqlx::query_scalar(
 468                "
 469                INSERT INTO signups
 470                (
 471                    email_address,
 472                    email_confirmation_code,
 473                    email_confirmation_sent,
 474                    inviting_user_id,
 475                    platform_linux,
 476                    platform_mac,
 477                    platform_windows,
 478                    platform_unknown,
 479                    device_id
 480                )
 481                VALUES
 482                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 483                ON CONFLICT (email_address)
 484                DO UPDATE SET
 485                    inviting_user_id = excluded.inviting_user_id
 486                RETURNING email_confirmation_code
 487                ",
 488            )
 489            .bind(&email_address)
 490            .bind(&random_email_confirmation_code())
 491            .bind(&inviter_id)
 492            .bind(&device_id)
 493            .fetch_one(&mut tx)
 494            .await?;
 495
 496            tx.commit().await?;
 497
 498            Ok(Invite {
 499                email_address: email_address.into(),
 500                email_confirmation_code,
 501            })
 502        })
 503    }
 504
 505    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 506        test_support!(self, {
 507            let emails = invites
 508                .iter()
 509                .map(|s| s.email_address.as_str())
 510                .collect::<Vec<_>>();
 511            sqlx::query(
 512                "
 513                UPDATE signups
 514                SET email_confirmation_sent = TRUE
 515                WHERE email_address = ANY ($1)
 516                ",
 517            )
 518            .bind(&emails)
 519            .execute(&self.pool)
 520            .await?;
 521            Ok(())
 522        })
 523    }
 524}
 525
 526impl<D> Db<D>
 527where
 528    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 529    D::Connection: sqlx::migrate::Migrate,
 530    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 531    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 532    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 533    D::QueryResult: RowsAffected,
 534    String: sqlx::Type<D>,
 535    i32: sqlx::Type<D>,
 536    i64: sqlx::Type<D>,
 537    bool: sqlx::Type<D>,
 538    str: sqlx::Type<D>,
 539    Uuid: sqlx::Type<D>,
 540    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 541    OffsetDateTime: sqlx::Type<D>,
 542    PrimitiveDateTime: sqlx::Type<D>,
 543    usize: sqlx::ColumnIndex<D::Row>,
 544    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 545    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 546    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 547    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 548    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 549    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 550    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 551    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 552    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 553    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 554    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 555    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 556{
 557    pub async fn migrate(
 558        &self,
 559        migrations_path: &Path,
 560        ignore_checksum_mismatch: bool,
 561    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 562        let migrations = MigrationSource::resolve(migrations_path)
 563            .await
 564            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 565
 566        let mut conn = self.pool.acquire().await?;
 567
 568        conn.ensure_migrations_table().await?;
 569        let applied_migrations: HashMap<_, _> = conn
 570            .list_applied_migrations()
 571            .await?
 572            .into_iter()
 573            .map(|m| (m.version, m))
 574            .collect();
 575
 576        let mut new_migrations = Vec::new();
 577        for migration in migrations {
 578            match applied_migrations.get(&migration.version) {
 579                Some(applied_migration) => {
 580                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 581                    {
 582                        Err(anyhow!(
 583                            "checksum mismatch for applied migration {}",
 584                            migration.description
 585                        ))?;
 586                    }
 587                }
 588                None => {
 589                    let elapsed = conn.apply(&migration).await?;
 590                    new_migrations.push((migration, elapsed));
 591                }
 592            }
 593        }
 594
 595        Ok(new_migrations)
 596    }
 597
 598    pub fn fuzzy_like_string(string: &str) -> String {
 599        let mut result = String::with_capacity(string.len() * 2 + 1);
 600        for c in string.chars() {
 601            if c.is_alphanumeric() {
 602                result.push('%');
 603                result.push(c);
 604            }
 605        }
 606        result.push('%');
 607        result
 608    }
 609
 610    // users
 611
 612    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 613        test_support!(self, {
 614            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 615            Ok(sqlx::query_as(query)
 616                .bind(limit as i32)
 617                .bind((page * limit) as i32)
 618                .fetch_all(&self.pool)
 619                .await?)
 620        })
 621    }
 622
 623    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 624        test_support!(self, {
 625            let query = "
 626                SELECT users.*
 627                FROM users
 628                WHERE id = $1
 629                LIMIT 1
 630            ";
 631            Ok(sqlx::query_as(query)
 632                .bind(&id)
 633                .fetch_optional(&self.pool)
 634                .await?)
 635        })
 636    }
 637
 638    pub async fn get_users_with_no_invites(
 639        &self,
 640        invited_by_another_user: bool,
 641    ) -> Result<Vec<User>> {
 642        test_support!(self, {
 643            let query = format!(
 644                "
 645                SELECT users.*
 646                FROM users
 647                WHERE invite_count = 0
 648                AND inviter_id IS{} NULL
 649                ",
 650                if invited_by_another_user { " NOT" } else { "" }
 651            );
 652
 653            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 654        })
 655    }
 656
 657    pub async fn get_user_by_github_account(
 658        &self,
 659        github_login: &str,
 660        github_user_id: Option<i32>,
 661    ) -> Result<Option<User>> {
 662        test_support!(self, {
 663            if let Some(github_user_id) = github_user_id {
 664                let mut user = sqlx::query_as::<_, User>(
 665                    "
 666                    UPDATE users
 667                    SET github_login = $1
 668                    WHERE github_user_id = $2
 669                    RETURNING *
 670                    ",
 671                )
 672                .bind(github_login)
 673                .bind(github_user_id)
 674                .fetch_optional(&self.pool)
 675                .await?;
 676
 677                if user.is_none() {
 678                    user = sqlx::query_as::<_, User>(
 679                        "
 680                        UPDATE users
 681                        SET github_user_id = $1
 682                        WHERE github_login = $2
 683                        RETURNING *
 684                        ",
 685                    )
 686                    .bind(github_user_id)
 687                    .bind(github_login)
 688                    .fetch_optional(&self.pool)
 689                    .await?;
 690                }
 691
 692                Ok(user)
 693            } else {
 694                let user = sqlx::query_as(
 695                    "
 696                    SELECT * FROM users
 697                    WHERE github_login = $1
 698                    LIMIT 1
 699                    ",
 700                )
 701                .bind(github_login)
 702                .fetch_optional(&self.pool)
 703                .await?;
 704                Ok(user)
 705            }
 706        })
 707    }
 708
 709    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 710        test_support!(self, {
 711            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 712            Ok(sqlx::query(query)
 713                .bind(is_admin)
 714                .bind(id.0)
 715                .execute(&self.pool)
 716                .await
 717                .map(drop)?)
 718        })
 719    }
 720
 721    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 722        test_support!(self, {
 723            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 724            Ok(sqlx::query(query)
 725                .bind(connected_once)
 726                .bind(id.0)
 727                .execute(&self.pool)
 728                .await
 729                .map(drop)?)
 730        })
 731    }
 732
 733    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 734        test_support!(self, {
 735            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 736            sqlx::query(query)
 737                .bind(id.0)
 738                .execute(&self.pool)
 739                .await
 740                .map(drop)?;
 741            let query = "DELETE FROM users WHERE id = $1;";
 742            Ok(sqlx::query(query)
 743                .bind(id.0)
 744                .execute(&self.pool)
 745                .await
 746                .map(drop)?)
 747        })
 748    }
 749
 750    // signups
 751
 752    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 753        test_support!(self, {
 754            Ok(sqlx::query_as(
 755                "
 756                SELECT
 757                    COUNT(*) as count,
 758                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 759                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 760                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 761                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 762                FROM (
 763                    SELECT *
 764                    FROM signups
 765                    WHERE
 766                        NOT email_confirmation_sent
 767                ) AS unsent
 768                ",
 769            )
 770            .fetch_one(&self.pool)
 771            .await?)
 772        })
 773    }
 774
 775    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 776        test_support!(self, {
 777            Ok(sqlx::query_as(
 778                "
 779                SELECT
 780                    email_address, email_confirmation_code
 781                FROM signups
 782                WHERE
 783                    NOT email_confirmation_sent AND
 784                    (platform_mac OR platform_unknown)
 785                ORDER BY
 786                    created_at
 787                LIMIT $1
 788                ",
 789            )
 790            .bind(count as i32)
 791            .fetch_all(&self.pool)
 792            .await?)
 793        })
 794    }
 795
 796    // invite codes
 797
 798    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 799        test_support!(self, {
 800            let mut tx = self.pool.begin().await?;
 801            if count > 0 {
 802                sqlx::query(
 803                    "
 804                    UPDATE users
 805                    SET invite_code = $1
 806                    WHERE id = $2 AND invite_code IS NULL
 807                ",
 808                )
 809                .bind(random_invite_code())
 810                .bind(id)
 811                .execute(&mut tx)
 812                .await?;
 813            }
 814
 815            sqlx::query(
 816                "
 817                UPDATE users
 818                SET invite_count = $1
 819                WHERE id = $2
 820                ",
 821            )
 822            .bind(count as i32)
 823            .bind(id)
 824            .execute(&mut tx)
 825            .await?;
 826            tx.commit().await?;
 827            Ok(())
 828        })
 829    }
 830
 831    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 832        test_support!(self, {
 833            let result: Option<(String, i32)> = sqlx::query_as(
 834                "
 835                    SELECT invite_code, invite_count
 836                    FROM users
 837                    WHERE id = $1 AND invite_code IS NOT NULL 
 838                ",
 839            )
 840            .bind(id)
 841            .fetch_optional(&self.pool)
 842            .await?;
 843            if let Some((code, count)) = result {
 844                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 845            } else {
 846                Ok(None)
 847            }
 848        })
 849    }
 850
 851    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 852        test_support!(self, {
 853            sqlx::query_as(
 854                "
 855                    SELECT *
 856                    FROM users
 857                    WHERE invite_code = $1
 858                ",
 859            )
 860            .bind(code)
 861            .fetch_optional(&self.pool)
 862            .await?
 863            .ok_or_else(|| {
 864                Error::Http(
 865                    StatusCode::NOT_FOUND,
 866                    "that invite code does not exist".to_string(),
 867                )
 868            })
 869        })
 870    }
 871
 872    // projects
 873
 874    /// Registers a new project for the given user.
 875    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 876        test_support!(self, {
 877            Ok(sqlx::query_scalar(
 878                "
 879                INSERT INTO projects(host_user_id)
 880                VALUES ($1)
 881                RETURNING id
 882                ",
 883            )
 884            .bind(host_user_id)
 885            .fetch_one(&self.pool)
 886            .await
 887            .map(ProjectId)?)
 888        })
 889    }
 890
 891    /// Unregisters a project for the given project id.
 892    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 893        test_support!(self, {
 894            sqlx::query(
 895                "
 896                UPDATE projects
 897                SET unregistered = TRUE
 898                WHERE id = $1
 899                ",
 900            )
 901            .bind(project_id)
 902            .execute(&self.pool)
 903            .await?;
 904            Ok(())
 905        })
 906    }
 907
 908    // contacts
 909
 910    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 911        test_support!(self, {
 912            let query = "
 913                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 914                FROM contacts
 915                WHERE user_id_a = $1 OR user_id_b = $1;
 916            ";
 917
 918            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 919                .bind(user_id)
 920                .fetch(&self.pool);
 921
 922            let mut contacts = Vec::new();
 923            while let Some(row) = rows.next().await {
 924                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 925
 926                if user_id_a == user_id {
 927                    if accepted {
 928                        contacts.push(Contact::Accepted {
 929                            user_id: user_id_b,
 930                            should_notify: should_notify && a_to_b,
 931                        });
 932                    } else if a_to_b {
 933                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 934                    } else {
 935                        contacts.push(Contact::Incoming {
 936                            user_id: user_id_b,
 937                            should_notify,
 938                        });
 939                    }
 940                } else if accepted {
 941                    contacts.push(Contact::Accepted {
 942                        user_id: user_id_a,
 943                        should_notify: should_notify && !a_to_b,
 944                    });
 945                } else if a_to_b {
 946                    contacts.push(Contact::Incoming {
 947                        user_id: user_id_a,
 948                        should_notify,
 949                    });
 950                } else {
 951                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 952                }
 953            }
 954
 955            contacts.sort_unstable_by_key(|contact| contact.user_id());
 956
 957            Ok(contacts)
 958        })
 959    }
 960
 961    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 962        test_support!(self, {
 963            let (id_a, id_b) = if user_id_1 < user_id_2 {
 964                (user_id_1, user_id_2)
 965            } else {
 966                (user_id_2, user_id_1)
 967            };
 968
 969            let query = "
 970                SELECT 1 FROM contacts
 971                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 972                LIMIT 1
 973            ";
 974            Ok(sqlx::query_scalar::<_, i32>(query)
 975                .bind(id_a.0)
 976                .bind(id_b.0)
 977                .fetch_optional(&self.pool)
 978                .await?
 979                .is_some())
 980        })
 981    }
 982
 983    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 984        test_support!(self, {
 985            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 986                (sender_id, receiver_id, true)
 987            } else {
 988                (receiver_id, sender_id, false)
 989            };
 990            let query = "
 991                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 992                VALUES ($1, $2, $3, FALSE, TRUE)
 993                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 994                SET
 995                    accepted = TRUE,
 996                    should_notify = FALSE
 997                WHERE
 998                    NOT contacts.accepted AND
 999                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1000                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1001            ";
1002            let result = sqlx::query(query)
1003                .bind(id_a.0)
1004                .bind(id_b.0)
1005                .bind(a_to_b)
1006                .execute(&self.pool)
1007                .await?;
1008
1009            if result.rows_affected() == 1 {
1010                Ok(())
1011            } else {
1012                Err(anyhow!("contact already requested"))?
1013            }
1014        })
1015    }
1016
1017    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1018        test_support!(self, {
1019            let (id_a, id_b) = if responder_id < requester_id {
1020                (responder_id, requester_id)
1021            } else {
1022                (requester_id, responder_id)
1023            };
1024            let query = "
1025                DELETE FROM contacts
1026                WHERE user_id_a = $1 AND user_id_b = $2;
1027            ";
1028            let result = sqlx::query(query)
1029                .bind(id_a.0)
1030                .bind(id_b.0)
1031                .execute(&self.pool)
1032                .await?;
1033
1034            if result.rows_affected() == 1 {
1035                Ok(())
1036            } else {
1037                Err(anyhow!("no such contact"))?
1038            }
1039        })
1040    }
1041
1042    pub async fn dismiss_contact_notification(
1043        &self,
1044        user_id: UserId,
1045        contact_user_id: UserId,
1046    ) -> Result<()> {
1047        test_support!(self, {
1048            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1049                (user_id, contact_user_id, true)
1050            } else {
1051                (contact_user_id, user_id, false)
1052            };
1053
1054            let query = "
1055                UPDATE contacts
1056                SET should_notify = FALSE
1057                WHERE
1058                    user_id_a = $1 AND user_id_b = $2 AND
1059                    (
1060                        (a_to_b = $3 AND accepted) OR
1061                        (a_to_b != $3 AND NOT accepted)
1062                    );
1063            ";
1064
1065            let result = sqlx::query(query)
1066                .bind(id_a.0)
1067                .bind(id_b.0)
1068                .bind(a_to_b)
1069                .execute(&self.pool)
1070                .await?;
1071
1072            if result.rows_affected() == 0 {
1073                Err(anyhow!("no such contact request"))?;
1074            }
1075
1076            Ok(())
1077        })
1078    }
1079
1080    pub async fn respond_to_contact_request(
1081        &self,
1082        responder_id: UserId,
1083        requester_id: UserId,
1084        accept: bool,
1085    ) -> Result<()> {
1086        test_support!(self, {
1087            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1088                (responder_id, requester_id, false)
1089            } else {
1090                (requester_id, responder_id, true)
1091            };
1092            let result = if accept {
1093                let query = "
1094                    UPDATE contacts
1095                    SET accepted = TRUE, should_notify = TRUE
1096                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1097                ";
1098                sqlx::query(query)
1099                    .bind(id_a.0)
1100                    .bind(id_b.0)
1101                    .bind(a_to_b)
1102                    .execute(&self.pool)
1103                    .await?
1104            } else {
1105                let query = "
1106                    DELETE FROM contacts
1107                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1108                ";
1109                sqlx::query(query)
1110                    .bind(id_a.0)
1111                    .bind(id_b.0)
1112                    .bind(a_to_b)
1113                    .execute(&self.pool)
1114                    .await?
1115            };
1116            if result.rows_affected() == 1 {
1117                Ok(())
1118            } else {
1119                Err(anyhow!("no such contact request"))?
1120            }
1121        })
1122    }
1123
1124    // access tokens
1125
1126    pub async fn create_access_token_hash(
1127        &self,
1128        user_id: UserId,
1129        access_token_hash: &str,
1130        max_access_token_count: usize,
1131    ) -> Result<()> {
1132        test_support!(self, {
1133            let insert_query = "
1134                INSERT INTO access_tokens (user_id, hash)
1135                VALUES ($1, $2);
1136            ";
1137            let cleanup_query = "
1138                DELETE FROM access_tokens
1139                WHERE id IN (
1140                    SELECT id from access_tokens
1141                    WHERE user_id = $1
1142                    ORDER BY id DESC
1143                    LIMIT 10000
1144                    OFFSET $3
1145                )
1146            ";
1147
1148            let mut tx = self.pool.begin().await?;
1149            sqlx::query(insert_query)
1150                .bind(user_id.0)
1151                .bind(access_token_hash)
1152                .execute(&mut tx)
1153                .await?;
1154            sqlx::query(cleanup_query)
1155                .bind(user_id.0)
1156                .bind(access_token_hash)
1157                .bind(max_access_token_count as i32)
1158                .execute(&mut tx)
1159                .await?;
1160            Ok(tx.commit().await?)
1161        })
1162    }
1163
1164    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1165        test_support!(self, {
1166            let query = "
1167                SELECT hash
1168                FROM access_tokens
1169                WHERE user_id = $1
1170                ORDER BY id DESC
1171            ";
1172            Ok(sqlx::query_scalar(query)
1173                .bind(user_id.0)
1174                .fetch_all(&self.pool)
1175                .await?)
1176        })
1177    }
1178}
1179
/// Declares a newtype id `$name` wrapping an `i32`.
///
/// The generated type is `Copy`, totally ordered and hashable. It encodes and decodes as its
/// bare inner `i32`, both in SQL (`#[sqlx(transparent)]`) and in serde (`#[serde(transparent)]`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a `u64` protocol id into this id type.
            ///
            /// This is a plain `as` cast: values outside the `i32` range keep only their low
            /// 32 bits rather than failing.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts this id into its `u64` protocol form via an `as` cast. Negative ids are
            /// sign-extended.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Displays as the bare inner integer.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1222
// `UserId` is the id type of `User::id`; it is also used for user references elsewhere in
// this file (contacts, project hosts, access tokens).
id_type!(UserId);
/// A user record, decoded via [`FromRow`].
///
/// `get_user_for_invite_code` loads it with `SELECT * FROM users`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// This user's invite code, if any; matched by `get_user_for_invite_code`.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites this user may still send — confirm.
    pub invite_count: i32,
    pub connected_once: bool,
}
1235
// `ProjectId` is the id type returned by `register_project` for rows of `projects`.
id_type!(ProjectId);
/// A project record, decoded via [`FromRow`].
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user passed to `register_project` when the project was created.
    pub host_user_id: UserId,
    /// Set to `TRUE` by `unregister_project`; the row itself is kept.
    pub unregistered: bool,
}
1243
/// One contact relationship, as seen by a particular user (see `get_contacts`).
///
/// `user_id` is always the *other* party.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts. `should_notify` is only set for the user who sent the original
    /// request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request that this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request that `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1258
1259impl Contact {
1260    pub fn user_id(&self) -> UserId {
1261        match self {
1262            Contact::Accepted { user_id, .. } => *user_id,
1263            Contact::Outgoing { user_id } => *user_id,
1264            Contact::Incoming { user_id, .. } => *user_id,
1265        }
1266    }
1267}
1268
/// A contact request received from `requester_id`.
///
/// NOTE(review): not used within this section; presumably a pending request addressed to the
/// current user — confirm against callers.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1274
/// Data describing a signup, deserialized via serde.
///
/// NOTE(review): presumably submitted by a waitlist form — confirm against the handler that
/// deserializes it.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    // Which operating systems the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
}
1286
/// Aggregate counts over waitlist signups, decoded via [`FromRow`].
///
/// Every field is `#[sqlx(default)]`, so a column missing from the result row decodes as `0`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1300
/// An invite addressed to `email_address`.
///
/// NOTE(review): `email_confirmation_code` is presumably produced by
/// `random_email_confirmation_code` — confirm at the insertion site.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1306
/// Parameters for creating a user account.
///
/// NOTE(review): the consuming method is outside this section — confirm field semantics there.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1313
/// Outcome of creating a user account.
///
/// NOTE(review): produced by a method outside this section. `inviting_user_id` is presumably
/// the user whose invite was redeemed, if any — confirm.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1321
1322fn random_invite_code() -> String {
1323    nanoid::nanoid!(16)
1324}
1325
1326fn random_email_confirmation_code() -> String {
1327    nanoid::nanoid!(64)
1328}
1329
1330#[cfg(test)]
1331pub use test::*;
1332
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A test database backed by a uniquely named in-memory SQLite database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // A connection detached from the pool and held for this value's lifetime.
        // NOTE(review): presumably keeps the in-memory database alive — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A test database backed by a freshly created Postgres database on `localhost`.
    ///
    /// The database is torn down when this value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    /// Builds the single-threaded Tokio runtime, with I/O and timers enabled, that a test
    /// `Db` uses to run its queries.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates and migrates a new in-memory SQLite database.
        ///
        /// Query delays are simulated through `background`.
        pub fn new(background: Arc<Background>) -> Self {
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the wrapped database.
        ///
        /// Panics if it has already been taken out of `self.db`.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a new uniquely named Postgres database.
        ///
        /// Creation is serialized across threads by a process-wide lock.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("postgres://postgres@localhost/zed-test-{}", suffix);
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the wrapped database.
        ///
        /// Panics if it has already been taken out of `self.db`.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database out of `self` and tears down the Postgres database at `self.url`.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}