db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use serde::{Deserialize, Serialize};
   7use sqlx::{
   8    migrate::{Migrate as _, Migration, MigrationSource},
   9    types::Uuid,
  10    FromRow,
  11};
  12use std::{path::Path, time::Duration};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Default database handle in test builds: a [`Db`] backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Default database handle in non-test builds: a [`Db`] backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  20
/// A database handle wrapping an `sqlx` connection pool for the backend `D`.
///
/// Test builds add two optional fields that the `test_support!` macro reads when it runs
/// a query body.
pub struct Db<D: sqlx::Database> {
    /// Connection pool that every query is executed against.
    pool: sqlx::Pool<D>,
    /// Test-only: when set, `test_support!` awaits `simulate_random_delay()` on it before
    /// running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: Tokio runtime that `test_support!` and `teardown` block on. Both unwrap
    /// it, so it must be set before they are used.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  28
/// Wraps the body of a database method so that it runs differently in test and non-test
/// builds.
///
/// The tokens inside the braces become the body of an `async` block. In non-test builds
/// that block is simply awaited. In test builds the macro first awaits
/// `simulate_random_delay()` on `$self.background` (when it is set) and then drives the
/// block to completion with `block_on` on `$self.runtime`, which is unwrapped and therefore
/// must be set.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        // The caller's tokens form an async block; nothing runs until it is awaited or
        // handed to `block_on` below.
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is false outside test builds, so in those builds this statement
            // sits in a branch that is never taken.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Run the body to completion on the handle's runtime (panics if it is unset).
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  51
/// Backend-independent access to the number of rows a statement affected.
///
/// Implemented for the SQLite and Postgres query-result types so that generic `Db<D>` code
/// can require `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}

impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Presumably resolves to sqlx's inherent `rows_affected` (inherent methods take
        // precedence over trait methods), so this delegates rather than recursing.
        self.rows_affected()
    }
}

impl RowsAffected for sqlx::postgres::PgQueryResult {
    fn rows_affected(&self) -> u64 {
        // Presumably resolves to sqlx's inherent `rows_affected` (inherent methods take
        // precedence over trait methods), so this delegates rather than recursing.
        self.rows_affected()
    }
}
  67
  68#[cfg(test)]
  69impl Db<sqlx::Sqlite> {
  70    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  71        use std::str::FromStr as _;
  72        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  73            .unwrap()
  74            .create_if_missing(true)
  75            .shared_cache(true);
  76        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  77            .min_connections(2)
  78            .max_connections(max_connections)
  79            .connect_with(options)
  80            .await?;
  81        Ok(Self {
  82            pool,
  83            background: None,
  84            runtime: None,
  85        })
  86    }
  87
  88    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  89        test_support!(self, {
  90            let query = "
  91                SELECT users.*
  92                FROM users
  93                WHERE users.id IN (SELECT value from json_each($1))
  94            ";
  95            Ok(sqlx::query_as(query)
  96                .bind(&serde_json::json!(ids))
  97                .fetch_all(&self.pool)
  98                .await?)
  99        })
 100    }
 101
 102    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 103        test_support!(self, {
 104            let query = "
 105                SELECT metrics_id
 106                FROM users
 107                WHERE id = $1
 108            ";
 109            Ok(sqlx::query_scalar(query)
 110                .bind(id)
 111                .fetch_one(&self.pool)
 112                .await?)
 113        })
 114    }
 115
 116    pub async fn create_user(
 117        &self,
 118        email_address: &str,
 119        admin: bool,
 120        params: NewUserParams,
 121    ) -> Result<NewUserResult> {
 122        test_support!(self, {
 123            let query = "
 124                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 125                VALUES ($1, $2, $3, $4, $5)
 126                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 127                RETURNING id, metrics_id
 128            ";
 129
 130            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 131                .bind(email_address)
 132                .bind(params.github_login)
 133                .bind(params.github_user_id)
 134                .bind(admin)
 135                .bind(Uuid::new_v4().to_string())
 136                .fetch_one(&self.pool)
 137                .await?;
 138            Ok(NewUserResult {
 139                user_id,
 140                metrics_id,
 141                signup_device_id: None,
 142                inviting_user_id: None,
 143            })
 144        })
 145    }
 146
    /// Not implemented for the SQLite backend; calling it panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }
 150
    /// Not implemented for the SQLite backend; calling it panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }
 158
    /// Not implemented for the SQLite backend; calling it panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }
 162
    /// Not implemented for the SQLite backend; calling it panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
 171
    /// Not implemented for the SQLite backend; calling it panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
 175}
 176
 177impl Db<sqlx::Postgres> {
 178    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 179        let pool = sqlx::postgres::PgPoolOptions::new()
 180            .max_connections(max_connections)
 181            .connect(url)
 182            .await?;
 183        Ok(Self {
 184            pool,
 185            #[cfg(test)]
 186            background: None,
 187            #[cfg(test)]
 188            runtime: None,
 189        })
 190    }
 191
 192    #[cfg(test)]
 193    pub fn teardown(&self, url: &str) {
 194        self.runtime.as_ref().unwrap().block_on(async {
 195            use util::ResultExt;
 196            let query = "
 197                SELECT pg_terminate_backend(pg_stat_activity.pid)
 198                FROM pg_stat_activity
 199                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 200            ";
 201            sqlx::query(query).execute(&self.pool).await.log_err();
 202            self.pool.close().await;
 203            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 204                .await
 205                .log_err();
 206        })
 207    }
 208
 209    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 210        test_support!(self, {
 211            let like_string = Self::fuzzy_like_string(name_query);
 212            let query = "
 213                SELECT users.*
 214                FROM users
 215                WHERE github_login ILIKE $1
 216                ORDER BY github_login <-> $2
 217                LIMIT $3
 218            ";
 219            Ok(sqlx::query_as(query)
 220                .bind(like_string)
 221                .bind(name_query)
 222                .bind(limit as i32)
 223                .fetch_all(&self.pool)
 224                .await?)
 225        })
 226    }
 227
 228    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 229        test_support!(self, {
 230            let query = "
 231                SELECT users.*
 232                FROM users
 233                WHERE users.id = ANY ($1)
 234            ";
 235            Ok(sqlx::query_as(query)
 236                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 237                .fetch_all(&self.pool)
 238                .await?)
 239        })
 240    }
 241
 242    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 243        test_support!(self, {
 244            let query = "
 245                SELECT metrics_id::text
 246                FROM users
 247                WHERE id = $1
 248            ";
 249            Ok(sqlx::query_scalar(query)
 250                .bind(id)
 251                .fetch_one(&self.pool)
 252                .await?)
 253        })
 254    }
 255
 256    pub async fn create_user(
 257        &self,
 258        email_address: &str,
 259        admin: bool,
 260        params: NewUserParams,
 261    ) -> Result<NewUserResult> {
 262        test_support!(self, {
 263            let query = "
 264                INSERT INTO users (email_address, github_login, github_user_id, admin)
 265                VALUES ($1, $2, $3, $4)
 266                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 267                RETURNING id, metrics_id::text
 268            ";
 269
 270            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 271                .bind(email_address)
 272                .bind(params.github_login)
 273                .bind(params.github_user_id)
 274                .bind(admin)
 275                .fetch_one(&self.pool)
 276                .await?;
 277            Ok(NewUserResult {
 278                user_id,
 279                metrics_id,
 280                signup_device_id: None,
 281                inviting_user_id: None,
 282            })
 283        })
 284    }
 285
    /// Creates a user for a confirmed invite inside a single transaction.
    ///
    /// Looks up the signup that matches the invite's email address and confirmation code.
    /// Returns `Ok(None)` when that signup is already linked to a user. Otherwise it inserts
    /// the user (or, on a `github_login` conflict, updates the existing row), links the
    /// signup to that user and, when the signup records an inviting user, decrements that
    /// user's `invite_count` and inserts an accepted contact row between the two.
    ///
    /// # Errors
    ///
    /// * `Error::Http` with `NOT_FOUND` if no signup matches the invite.
    /// * `Error::Http` with `UNAUTHORIZED` if the inviting user has no invites remaining.
    /// * Any database error raised by the queries or the commit.
    pub async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        test_support!(self, {
            // Every early exit below leaves `tx` uncommitted — presumably sqlx rolls the
            // transaction back when it is dropped; confirm against the sqlx version in use.
            let mut tx = self.pool.begin().await?;

            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
                i32,
                Option<UserId>,
                Option<UserId>,
                Option<String>,
            ) = sqlx::query_as(
                "
                SELECT id, user_id, inviting_user_id, device_id
                FROM signups
                WHERE
                    email_address = $1 AND
                    email_confirmation_code = $2
                ",
            )
            .bind(&invite.email_address)
            .bind(&invite.email_confirmation_code)
            .fetch_optional(&mut tx)
            .await?
            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

            // The signup is already linked to a user, so there is nothing to create.
            if existing_user_id.is_some() {
                return Ok(None);
            }

            // On a `github_login` conflict the existing row's email address, GitHub user id
            // and admin flag are overwritten with the inserted values (admin is inserted as
            // FALSE); invite_count and invite_code are left as they were.
            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
                "
                INSERT INTO users
                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
                VALUES
                ($1, $2, $3, FALSE, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET
                    email_address = excluded.email_address,
                    github_user_id = excluded.github_user_id,
                    admin = excluded.admin
                RETURNING id, metrics_id::text
                ",
            )
            .bind(&invite.email_address)
            .bind(&user.github_login)
            .bind(&user.github_user_id)
            .bind(&user.invite_count)
            .bind(random_invite_code())
            .fetch_one(&mut tx)
            .await?;

            // Link the signup to the user that was just created or updated.
            sqlx::query(
                "
                UPDATE signups
                SET user_id = $1
                WHERE id = $2
                ",
            )
            .bind(&user_id)
            .bind(&signup_id)
            .execute(&mut tx)
            .await?;

            if let Some(inviting_user_id) = inviting_user_id {
                // The `invite_count > 0` guard makes the UPDATE match no row when the
                // inviter has no invites left, which shows up as `None` here.
                let id: Option<UserId> = sqlx::query_scalar(
                    "
                    UPDATE users
                    SET invite_count = invite_count - 1
                    WHERE id = $1 AND invite_count > 0
                    RETURNING id
                    ",
                )
                .bind(&inviting_user_id)
                .fetch_optional(&mut tx)
                .await?;

                if id.is_none() {
                    Err(Error::Http(
                        StatusCode::UNAUTHORIZED,
                        "no invites remaining".to_string(),
                    ))?;
                }

                // NOTE(review): the row is written as (inviter, new user) without ordering
                // the two ids — confirm this matches how other code writes `contacts` rows.
                sqlx::query(
                    "
                    INSERT INTO contacts
                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                    VALUES
                        ($1, $2, TRUE, TRUE, TRUE)
                    ON CONFLICT DO NOTHING
                    ",
                )
                .bind(inviting_user_id)
                .bind(user_id)
                .execute(&mut tx)
                .await?;
            }

            tx.commit().await?;
            Ok(Some(NewUserResult {
                user_id,
                metrics_id,
                inviting_user_id,
                signup_device_id,
            }))
        })
    }
 395
 396    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 397        test_support!(self, {
 398            sqlx::query(
 399                "
 400                INSERT INTO signups
 401                (
 402                    email_address,
 403                    email_confirmation_code,
 404                    email_confirmation_sent,
 405                    platform_linux,
 406                    platform_mac,
 407                    platform_windows,
 408                    platform_unknown,
 409                    editor_features,
 410                    programming_languages,
 411                    device_id
 412                )
 413                VALUES
 414                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 415                RETURNING id
 416                ",
 417            )
 418            .bind(&signup.email_address)
 419            .bind(&random_email_confirmation_code())
 420            .bind(&signup.platform_linux)
 421            .bind(&signup.platform_mac)
 422            .bind(&signup.platform_windows)
 423            .bind(&signup.editor_features)
 424            .bind(&signup.programming_languages)
 425            .bind(&signup.device_id)
 426            .execute(&self.pool)
 427            .await?;
 428            Ok(())
 429        })
 430    }
 431
    /// Creates (or re-targets) a signup for `email_address` from another user's invite code
    /// and returns the resulting invite, all inside one transaction.
    ///
    /// # Errors
    ///
    /// * An `anyhow` error if a user with `email_address` already exists.
    /// * `Error::Http` with `NOT_FOUND` if no user owns `code`.
    /// * `Error::Http` with `UNAUTHORIZED` if the owner of `code` has an `invite_count` of 0.
    /// * Any database error raised by the queries or the commit.
    pub async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite> {
        test_support!(self, {
            let mut tx = self.pool.begin().await?;

            // Refuse to invite an address that already belongs to a user.
            let existing_user: Option<UserId> = sqlx::query_scalar(
                "
                SELECT id
                FROM users
                WHERE email_address = $1
                ",
            )
            .bind(email_address)
            .fetch_optional(&mut tx)
            .await?;
            if existing_user.is_some() {
                Err(anyhow!("email address is already in use"))?;
            }

            // Resolve the invite code to its owner and that owner's remaining invites.
            let row: Option<(UserId, i32)> = sqlx::query_as(
                "
                SELECT id, invite_count
                FROM users
                WHERE invite_code = $1
                ",
            )
            .bind(code)
            .fetch_optional(&mut tx)
            .await?;

            let (inviter_id, invite_count) = match row {
                Some(row) => row,
                None => Err(Error::Http(
                    StatusCode::NOT_FOUND,
                    "invite code not found".to_string(),
                ))?,
            };

            // The count is only checked here; this function does not decrement it.
            if invite_count == 0 {
                Err(Error::Http(
                    StatusCode::UNAUTHORIZED,
                    "no invites remaining".to_string(),
                ))?;
            }

            // When a signup for this address already exists, only its inviting_user_id is
            // updated and its existing confirmation code is returned; the freshly generated
            // code is stored only for a new row.
            let email_confirmation_code: String = sqlx::query_scalar(
                "
                INSERT INTO signups
                (
                    email_address,
                    email_confirmation_code,
                    email_confirmation_sent,
                    inviting_user_id,
                    platform_linux,
                    platform_mac,
                    platform_windows,
                    platform_unknown,
                    device_id
                )
                VALUES
                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
                ON CONFLICT (email_address)
                DO UPDATE SET
                    inviting_user_id = excluded.inviting_user_id
                RETURNING email_confirmation_code
                ",
            )
            .bind(&email_address)
            .bind(&random_email_confirmation_code())
            .bind(&inviter_id)
            .bind(&device_id)
            .fetch_one(&mut tx)
            .await?;

            tx.commit().await?;

            Ok(Invite {
                email_address: email_address.into(),
                email_confirmation_code,
            })
        })
    }
 518
 519    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 520        test_support!(self, {
 521            let emails = invites
 522                .iter()
 523                .map(|s| s.email_address.as_str())
 524                .collect::<Vec<_>>();
 525            sqlx::query(
 526                "
 527                UPDATE signups
 528                SET email_confirmation_sent = TRUE
 529                WHERE email_address = ANY ($1)
 530                ",
 531            )
 532            .bind(&emails)
 533            .execute(&self.pool)
 534            .await?;
 535            Ok(())
 536        })
 537    }
 538}
 539
 540impl<D> Db<D>
 541where
 542    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 543    D::Connection: sqlx::migrate::Migrate,
 544    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 545    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 546    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 547    D::QueryResult: RowsAffected,
 548    String: sqlx::Type<D>,
 549    i32: sqlx::Type<D>,
 550    i64: sqlx::Type<D>,
 551    bool: sqlx::Type<D>,
 552    str: sqlx::Type<D>,
 553    Uuid: sqlx::Type<D>,
 554    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 555    OffsetDateTime: sqlx::Type<D>,
 556    PrimitiveDateTime: sqlx::Type<D>,
 557    usize: sqlx::ColumnIndex<D::Row>,
 558    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 559    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 560    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 561    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 562    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 563    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 564    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 565    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 566    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 567    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 568    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 569    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 570{
 571    pub async fn migrate(
 572        &self,
 573        migrations_path: &Path,
 574        ignore_checksum_mismatch: bool,
 575    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 576        let migrations = MigrationSource::resolve(migrations_path)
 577            .await
 578            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 579
 580        let mut conn = self.pool.acquire().await?;
 581
 582        conn.ensure_migrations_table().await?;
 583        let applied_migrations: HashMap<_, _> = conn
 584            .list_applied_migrations()
 585            .await?
 586            .into_iter()
 587            .map(|m| (m.version, m))
 588            .collect();
 589
 590        let mut new_migrations = Vec::new();
 591        for migration in migrations {
 592            match applied_migrations.get(&migration.version) {
 593                Some(applied_migration) => {
 594                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 595                    {
 596                        Err(anyhow!(
 597                            "checksum mismatch for applied migration {}",
 598                            migration.description
 599                        ))?;
 600                    }
 601                }
 602                None => {
 603                    let elapsed = conn.apply(&migration).await?;
 604                    new_migrations.push((migration, elapsed));
 605                }
 606            }
 607        }
 608
 609        Ok(new_migrations)
 610    }
 611
 612    pub fn fuzzy_like_string(string: &str) -> String {
 613        let mut result = String::with_capacity(string.len() * 2 + 1);
 614        for c in string.chars() {
 615            if c.is_alphanumeric() {
 616                result.push('%');
 617                result.push(c);
 618            }
 619        }
 620        result.push('%');
 621        result
 622    }
 623
 624    // users
 625
 626    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 627        test_support!(self, {
 628            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 629            Ok(sqlx::query_as(query)
 630                .bind(limit as i32)
 631                .bind((page * limit) as i32)
 632                .fetch_all(&self.pool)
 633                .await?)
 634        })
 635    }
 636
 637    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 638        test_support!(self, {
 639            let query = "
 640                SELECT users.*
 641                FROM users
 642                WHERE id = $1
 643                LIMIT 1
 644            ";
 645            Ok(sqlx::query_as(query)
 646                .bind(&id)
 647                .fetch_optional(&self.pool)
 648                .await?)
 649        })
 650    }
 651
 652    pub async fn get_users_with_no_invites(
 653        &self,
 654        invited_by_another_user: bool,
 655    ) -> Result<Vec<User>> {
 656        test_support!(self, {
 657            let query = format!(
 658                "
 659                SELECT users.*
 660                FROM users
 661                WHERE invite_count = 0
 662                AND inviter_id IS{} NULL
 663                ",
 664                if invited_by_another_user { " NOT" } else { "" }
 665            );
 666
 667            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 668        })
 669    }
 670
 671    pub async fn get_user_by_github_account(
 672        &self,
 673        github_login: &str,
 674        github_user_id: Option<i32>,
 675    ) -> Result<Option<User>> {
 676        test_support!(self, {
 677            if let Some(github_user_id) = github_user_id {
 678                let mut user = sqlx::query_as::<_, User>(
 679                    "
 680                    UPDATE users
 681                    SET github_login = $1
 682                    WHERE github_user_id = $2
 683                    RETURNING *
 684                    ",
 685                )
 686                .bind(github_login)
 687                .bind(github_user_id)
 688                .fetch_optional(&self.pool)
 689                .await?;
 690
 691                if user.is_none() {
 692                    user = sqlx::query_as::<_, User>(
 693                        "
 694                        UPDATE users
 695                        SET github_user_id = $1
 696                        WHERE github_login = $2
 697                        RETURNING *
 698                        ",
 699                    )
 700                    .bind(github_user_id)
 701                    .bind(github_login)
 702                    .fetch_optional(&self.pool)
 703                    .await?;
 704                }
 705
 706                Ok(user)
 707            } else {
 708                let user = sqlx::query_as(
 709                    "
 710                    SELECT * FROM users
 711                    WHERE github_login = $1
 712                    LIMIT 1
 713                    ",
 714                )
 715                .bind(github_login)
 716                .fetch_optional(&self.pool)
 717                .await?;
 718                Ok(user)
 719            }
 720        })
 721    }
 722
 723    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 724        test_support!(self, {
 725            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 726            Ok(sqlx::query(query)
 727                .bind(is_admin)
 728                .bind(id.0)
 729                .execute(&self.pool)
 730                .await
 731                .map(drop)?)
 732        })
 733    }
 734
 735    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 736        test_support!(self, {
 737            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 738            Ok(sqlx::query(query)
 739                .bind(connected_once)
 740                .bind(id.0)
 741                .execute(&self.pool)
 742                .await
 743                .map(drop)?)
 744        })
 745    }
 746
 747    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 748        test_support!(self, {
 749            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 750            sqlx::query(query)
 751                .bind(id.0)
 752                .execute(&self.pool)
 753                .await
 754                .map(drop)?;
 755            let query = "DELETE FROM users WHERE id = $1;";
 756            Ok(sqlx::query(query)
 757                .bind(id.0)
 758                .execute(&self.pool)
 759                .await
 760                .map(drop)?)
 761        })
 762    }
 763
 764    // signups
 765
 766    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 767        test_support!(self, {
 768            Ok(sqlx::query_as(
 769                "
 770                SELECT
 771                    COUNT(*) as count,
 772                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 773                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 774                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 775                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 776                FROM (
 777                    SELECT *
 778                    FROM signups
 779                    WHERE
 780                        NOT email_confirmation_sent
 781                ) AS unsent
 782                ",
 783            )
 784            .fetch_one(&self.pool)
 785            .await?)
 786        })
 787    }
 788
 789    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 790        test_support!(self, {
 791            Ok(sqlx::query_as(
 792                "
 793                SELECT
 794                    email_address, email_confirmation_code
 795                FROM signups
 796                WHERE
 797                    NOT email_confirmation_sent AND
 798                    (platform_mac OR platform_unknown)
 799                LIMIT $1
 800                ",
 801            )
 802            .bind(count as i32)
 803            .fetch_all(&self.pool)
 804            .await?)
 805        })
 806    }
 807
 808    // invite codes
 809
 810    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 811        test_support!(self, {
 812            let mut tx = self.pool.begin().await?;
 813            if count > 0 {
 814                sqlx::query(
 815                    "
 816                    UPDATE users
 817                    SET invite_code = $1
 818                    WHERE id = $2 AND invite_code IS NULL
 819                ",
 820                )
 821                .bind(random_invite_code())
 822                .bind(id)
 823                .execute(&mut tx)
 824                .await?;
 825            }
 826
 827            sqlx::query(
 828                "
 829                UPDATE users
 830                SET invite_count = $1
 831                WHERE id = $2
 832                ",
 833            )
 834            .bind(count as i32)
 835            .bind(id)
 836            .execute(&mut tx)
 837            .await?;
 838            tx.commit().await?;
 839            Ok(())
 840        })
 841    }
 842
 843    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 844        test_support!(self, {
 845            let result: Option<(String, i32)> = sqlx::query_as(
 846                "
 847                    SELECT invite_code, invite_count
 848                    FROM users
 849                    WHERE id = $1 AND invite_code IS NOT NULL 
 850                ",
 851            )
 852            .bind(id)
 853            .fetch_optional(&self.pool)
 854            .await?;
 855            if let Some((code, count)) = result {
 856                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 857            } else {
 858                Ok(None)
 859            }
 860        })
 861    }
 862
 863    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 864        test_support!(self, {
 865            sqlx::query_as(
 866                "
 867                    SELECT *
 868                    FROM users
 869                    WHERE invite_code = $1
 870                ",
 871            )
 872            .bind(code)
 873            .fetch_optional(&self.pool)
 874            .await?
 875            .ok_or_else(|| {
 876                Error::Http(
 877                    StatusCode::NOT_FOUND,
 878                    "that invite code does not exist".to_string(),
 879                )
 880            })
 881        })
 882    }
 883
 884    // projects
 885
 886    /// Registers a new project for the given user.
 887    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 888        test_support!(self, {
 889            Ok(sqlx::query_scalar(
 890                "
 891                INSERT INTO projects(host_user_id)
 892                VALUES ($1)
 893                RETURNING id
 894                ",
 895            )
 896            .bind(host_user_id)
 897            .fetch_one(&self.pool)
 898            .await
 899            .map(ProjectId)?)
 900        })
 901    }
 902
 903    /// Unregisters a project for the given project id.
 904    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 905        test_support!(self, {
 906            sqlx::query(
 907                "
 908                UPDATE projects
 909                SET unregistered = TRUE
 910                WHERE id = $1
 911                ",
 912            )
 913            .bind(project_id)
 914            .execute(&self.pool)
 915            .await?;
 916            Ok(())
 917        })
 918    }
 919
 920    // contacts
 921
 922    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 923        test_support!(self, {
 924            let query = "
 925                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 926                FROM contacts
 927                WHERE user_id_a = $1 OR user_id_b = $1;
 928            ";
 929
 930            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 931                .bind(user_id)
 932                .fetch(&self.pool);
 933
 934            let mut contacts = Vec::new();
 935            while let Some(row) = rows.next().await {
 936                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 937
 938                if user_id_a == user_id {
 939                    if accepted {
 940                        contacts.push(Contact::Accepted {
 941                            user_id: user_id_b,
 942                            should_notify: should_notify && a_to_b,
 943                        });
 944                    } else if a_to_b {
 945                        contacts.push(Contact::Outgoing { user_id: user_id_b })
 946                    } else {
 947                        contacts.push(Contact::Incoming {
 948                            user_id: user_id_b,
 949                            should_notify,
 950                        });
 951                    }
 952                } else if accepted {
 953                    contacts.push(Contact::Accepted {
 954                        user_id: user_id_a,
 955                        should_notify: should_notify && !a_to_b,
 956                    });
 957                } else if a_to_b {
 958                    contacts.push(Contact::Incoming {
 959                        user_id: user_id_a,
 960                        should_notify,
 961                    });
 962                } else {
 963                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 964                }
 965            }
 966
 967            contacts.sort_unstable_by_key(|contact| contact.user_id());
 968
 969            Ok(contacts)
 970        })
 971    }
 972
 973    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 974        test_support!(self, {
 975            let (id_a, id_b) = if user_id_1 < user_id_2 {
 976                (user_id_1, user_id_2)
 977            } else {
 978                (user_id_2, user_id_1)
 979            };
 980
 981            let query = "
 982                SELECT 1 FROM contacts
 983                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
 984                LIMIT 1
 985            ";
 986            Ok(sqlx::query_scalar::<_, i32>(query)
 987                .bind(id_a.0)
 988                .bind(id_b.0)
 989                .fetch_optional(&self.pool)
 990                .await?
 991                .is_some())
 992        })
 993    }
 994
 995    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 996        test_support!(self, {
 997            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 998                (sender_id, receiver_id, true)
 999            } else {
1000                (receiver_id, sender_id, false)
1001            };
1002            let query = "
1003                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1004                VALUES ($1, $2, $3, FALSE, TRUE)
1005                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1006                SET
1007                    accepted = TRUE,
1008                    should_notify = FALSE
1009                WHERE
1010                    NOT contacts.accepted AND
1011                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1012                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1013            ";
1014            let result = sqlx::query(query)
1015                .bind(id_a.0)
1016                .bind(id_b.0)
1017                .bind(a_to_b)
1018                .execute(&self.pool)
1019                .await?;
1020
1021            if result.rows_affected() == 1 {
1022                Ok(())
1023            } else {
1024                Err(anyhow!("contact already requested"))?
1025            }
1026        })
1027    }
1028
1029    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1030        test_support!(self, {
1031            let (id_a, id_b) = if responder_id < requester_id {
1032                (responder_id, requester_id)
1033            } else {
1034                (requester_id, responder_id)
1035            };
1036            let query = "
1037                DELETE FROM contacts
1038                WHERE user_id_a = $1 AND user_id_b = $2;
1039            ";
1040            let result = sqlx::query(query)
1041                .bind(id_a.0)
1042                .bind(id_b.0)
1043                .execute(&self.pool)
1044                .await?;
1045
1046            if result.rows_affected() == 1 {
1047                Ok(())
1048            } else {
1049                Err(anyhow!("no such contact"))?
1050            }
1051        })
1052    }
1053
1054    pub async fn dismiss_contact_notification(
1055        &self,
1056        user_id: UserId,
1057        contact_user_id: UserId,
1058    ) -> Result<()> {
1059        test_support!(self, {
1060            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1061                (user_id, contact_user_id, true)
1062            } else {
1063                (contact_user_id, user_id, false)
1064            };
1065
1066            let query = "
1067                UPDATE contacts
1068                SET should_notify = FALSE
1069                WHERE
1070                    user_id_a = $1 AND user_id_b = $2 AND
1071                    (
1072                        (a_to_b = $3 AND accepted) OR
1073                        (a_to_b != $3 AND NOT accepted)
1074                    );
1075            ";
1076
1077            let result = sqlx::query(query)
1078                .bind(id_a.0)
1079                .bind(id_b.0)
1080                .bind(a_to_b)
1081                .execute(&self.pool)
1082                .await?;
1083
1084            if result.rows_affected() == 0 {
1085                Err(anyhow!("no such contact request"))?;
1086            }
1087
1088            Ok(())
1089        })
1090    }
1091
1092    pub async fn respond_to_contact_request(
1093        &self,
1094        responder_id: UserId,
1095        requester_id: UserId,
1096        accept: bool,
1097    ) -> Result<()> {
1098        test_support!(self, {
1099            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1100                (responder_id, requester_id, false)
1101            } else {
1102                (requester_id, responder_id, true)
1103            };
1104            let result = if accept {
1105                let query = "
1106                    UPDATE contacts
1107                    SET accepted = TRUE, should_notify = TRUE
1108                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1109                ";
1110                sqlx::query(query)
1111                    .bind(id_a.0)
1112                    .bind(id_b.0)
1113                    .bind(a_to_b)
1114                    .execute(&self.pool)
1115                    .await?
1116            } else {
1117                let query = "
1118                    DELETE FROM contacts
1119                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1120                ";
1121                sqlx::query(query)
1122                    .bind(id_a.0)
1123                    .bind(id_b.0)
1124                    .bind(a_to_b)
1125                    .execute(&self.pool)
1126                    .await?
1127            };
1128            if result.rows_affected() == 1 {
1129                Ok(())
1130            } else {
1131                Err(anyhow!("no such contact request"))?
1132            }
1133        })
1134    }
1135
1136    // access tokens
1137
1138    pub async fn create_access_token_hash(
1139        &self,
1140        user_id: UserId,
1141        access_token_hash: &str,
1142        max_access_token_count: usize,
1143    ) -> Result<()> {
1144        test_support!(self, {
1145            let insert_query = "
1146                INSERT INTO access_tokens (user_id, hash)
1147                VALUES ($1, $2);
1148            ";
1149            let cleanup_query = "
1150                DELETE FROM access_tokens
1151                WHERE id IN (
1152                    SELECT id from access_tokens
1153                    WHERE user_id = $1
1154                    ORDER BY id DESC
1155                    LIMIT 10000
1156                    OFFSET $3
1157                )
1158            ";
1159
1160            let mut tx = self.pool.begin().await?;
1161            sqlx::query(insert_query)
1162                .bind(user_id.0)
1163                .bind(access_token_hash)
1164                .execute(&mut tx)
1165                .await?;
1166            sqlx::query(cleanup_query)
1167                .bind(user_id.0)
1168                .bind(access_token_hash)
1169                .bind(max_access_token_count as i32)
1170                .execute(&mut tx)
1171                .await?;
1172            Ok(tx.commit().await?)
1173        })
1174    }
1175
1176    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1177        test_support!(self, {
1178            let query = "
1179                SELECT hash
1180                FROM access_tokens
1181                WHERE user_id = $1
1182                ORDER BY id DESC
1183            ";
1184            Ok(sqlx::query_scalar(query)
1185                .bind(user_id.0)
1186                .fetch_all(&self.pool)
1187                .await?)
1188        })
1189    }
1190}
1191
/// Declares `$name` as a public newtype around an `i32` id.
///
/// The generated type is `Copy`, ordered and hashable, is transparent for both sqlx
/// and serde, exposes a `MAX` constant plus `from_proto`/`to_proto` conversions to and
/// from `u64`, and displays exactly like the wrapped integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a `u64` into an id with an `as` cast, so only the low 32 bits
            /// of `value` are kept.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id into a `u64` with an `as` cast; a negative id is
            /// sign-extended.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate to the integer so width/fill flags keep working.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1234
// Id newtype for users, generated by `id_type!`.
id_type!(UserId);
/// A user record; `get_user_for_invite_code` loads it with `SELECT * FROM users`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Id of the user.
    pub id: UserId,
    // NOTE(review): presumably the GitHub account name and numeric id — confirm.
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code matched by `get_user_for_invite_code`; may be absent.
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1247
// Id newtype for projects, generated by `id_type!`.
id_type!(ProjectId);
/// A project record. The field names match the `projects` columns used by
/// `register_project` and `unregister_project`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    /// Id of the project (`RETURNING id` in `register_project`).
    pub id: ProjectId,
    /// User the project was registered for.
    pub host_user_id: UserId,
    /// Set to `TRUE` by `unregister_project`.
    pub unregistered: bool,
}
1255
/// One contact relationship as seen by a particular user; produced by `get_contacts`.
///
/// In every variant `user_id` is the id of the *other* user in the relationship.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request has been accepted.
    Accepted {
        user_id: UserId,
        /// `get_contacts` only sets this for the user who sent the original request.
        should_notify: bool,
    },
    /// A still-pending request that the viewing user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A still-pending request that `user_id` sent to the viewing user.
    Incoming {
        user_id: UserId,
        /// Value of the row's `should_notify` column.
        should_notify: bool,
    },
}
1270
1271impl Contact {
1272    pub fn user_id(&self) -> UserId {
1273        match self {
1274            Contact::Accepted { user_id, .. } => *user_id,
1275            Contact::Outgoing { user_id } => *user_id,
1276            Contact::Incoming { user_id, .. } => *user_id,
1277        }
1278    }
1279}
1280
/// Pairs the id of a requesting user with a notification flag.
// NOTE(review): not referenced in this part of the file — confirm where it is used.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1286
/// Deserializable signup payload: an email address, three platform flags, two lists
/// of free-form strings and an optional device id.
// NOTE(review): presumably one waitlist signup submission — confirm against callers.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1297
/// Set of `i64` counters readable from a database row.
///
/// Every field is marked `#[sqlx(default)]`; per the sqlx `FromRow` documentation a
/// column missing from the row then falls back to the field type's `Default` value.
// NOTE(review): presumably per-platform waitlist totals — confirm against the query
// that produces it.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1311
/// An email address together with its email confirmation code; readable from a
/// database row and (de)serializable.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1317
/// Parameter bundle for a new user: GitHub login, GitHub user id and invite count.
// NOTE(review): presumably the input for creating a user — confirm against callers.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1324
/// Result bundle describing a new user: its id, a metrics id, and optionally the
/// inviting user's id and a signup device id.
// NOTE(review): presumably returned by the user-creation path — confirm against callers.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1332
1333fn random_invite_code() -> String {
1334    nanoid::nanoid!(16)
1335}
1336
1337fn random_email_confirmation_code() -> String {
1338    nanoid::nanoid!(64)
1339}
1340
1341#[cfg(test)]
1342pub use test::*;
1343
1344#[cfg(test)]
1345mod test {
1346    use super::*;
1347    use gpui::executor::Background;
1348    use lazy_static::lazy_static;
1349    use parking_lot::Mutex;
1350    use rand::prelude::*;
1351    use sqlx::migrate::MigrateDatabase;
1352    use std::sync::Arc;
1353
    /// Test fixture wrapping a SQLite-backed `Db`.
    pub struct SqliteTestDb {
        /// The database handle; `SqliteTestDb::new` always fills it in.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool in `SqliteTestDb::new`.
        // NOTE(review): presumably held so the in-memory database outlives pooled
        // connections — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }
1358
    /// Test fixture wrapping a Postgres-backed `Db`.
    pub struct PostgresTestDb {
        /// The database handle; taken out again when the fixture is dropped.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the database created for this fixture; passed to
        /// `teardown` on drop.
        pub url: String,
    }
1363
1364    impl SqliteTestDb {
1365        pub fn new(background: Arc<Background>) -> Self {
1366            let mut rng = StdRng::from_entropy();
1367            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
1368            let runtime = tokio::runtime::Builder::new_current_thread()
1369                .enable_io()
1370                .enable_time()
1371                .build()
1372                .unwrap();
1373
1374            let (mut db, conn) = runtime.block_on(async {
1375                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
1376                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
1377                db.migrate(migrations_path.as_ref(), false).await.unwrap();
1378                let conn = db.pool.acquire().await.unwrap().detach();
1379                (db, conn)
1380            });
1381
1382            db.background = Some(background);
1383            db.runtime = Some(runtime);
1384
1385            Self {
1386                db: Some(Arc::new(db)),
1387                conn,
1388            }
1389        }
1390
1391        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
1392            self.db.as_ref().unwrap()
1393        }
1394    }
1395
1396    impl PostgresTestDb {
1397        pub fn new(background: Arc<Background>) -> Self {
1398            lazy_static! {
1399                static ref LOCK: Mutex<()> = Mutex::new(());
1400            }
1401
1402            let _guard = LOCK.lock();
1403            let mut rng = StdRng::from_entropy();
1404            let url = format!(
1405                "postgres://postgres@localhost/zed-test-{}",
1406                rng.gen::<u128>()
1407            );
1408            let runtime = tokio::runtime::Builder::new_current_thread()
1409                .enable_io()
1410                .enable_time()
1411                .build()
1412                .unwrap();
1413
1414            let mut db = runtime.block_on(async {
1415                sqlx::Postgres::create_database(&url)
1416                    .await
1417                    .expect("failed to create test db");
1418                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
1419                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
1420                db.migrate(Path::new(migrations_path), false).await.unwrap();
1421                db
1422            });
1423
1424            db.background = Some(background);
1425            db.runtime = Some(runtime);
1426
1427            Self {
1428                db: Some(Arc::new(db)),
1429                url,
1430            }
1431        }
1432
1433        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
1434            self.db.as_ref().unwrap()
1435        }
1436    }
1437
1438    impl Drop for PostgresTestDb {
1439        fn drop(&mut self) {
1440            let db = self.db.take().unwrap();
1441            db.teardown(&self.url);
1442        }
1443    }
1444}