db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::HashMap;
   5use futures::StreamExt;
   6use rpc::{proto, ConnectionId};
   7use serde::{Deserialize, Serialize};
   8use sqlx::{
   9    migrate::{Migrate as _, Migration, MigrationSource},
  10    types::Uuid,
  11    FromRow,
  12};
  13use std::{path::Path, time::Duration};
  14use time::{OffsetDateTime, PrimitiveDateTime};
  15
/// The `Db` backend selected at compile time.
///
/// Test builds use an SQLite-backed database.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// The `Db` backend selected at compile time.
///
/// Non-test builds use a Postgres-backed database.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  21
/// A database handle for backend `D`: a sqlx connection pool plus, in test
/// builds, the hooks consumed by the `test_support!` macro.
pub struct Db<D: sqlx::Database> {
    /// Pool every query in this file executes against.
    pool: sqlx::Pool<D>,
    /// Test-only: when `Some`, `test_support!` awaits
    /// `simulate_random_delay()` on this executor before each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: runtime `test_support!` uses to `block_on` each query body.
    /// The constructors leave it `None`; it must be set before queries run in
    /// tests, because the macro unwraps it.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  29
/// Runs a query body, routing it through the `Db`'s test hooks in test builds.
///
/// Usage: `test_support!(self, { ...body evaluating to a Result... })`.
///
/// The body is wrapped in an `async` block.
/// - In test builds, the macro first awaits `simulate_random_delay()` on
///   `$self.background` (when set). It then drives the body to completion with
///   `$self.runtime.as_ref().unwrap().block_on(..)`, which panics if `runtime`
///   is `None`.
/// - In non-test builds the body is simply `.await`ed.
///
/// NOTE(review): presumably `block_on` is used because tests poll these futures
/// from a non-tokio executor while sqlx needs a tokio runtime. Confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // The `background`/`runtime` fields only exist under `cfg(test)`, so
            // every line touching them is cfg-gated. In non-test builds this
            // branch is dead but must still type-check; `unreachable!()` makes
            // it diverge so it is compatible with the `else` branch's type.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
  52
/// Backend-agnostic access to a query result's affected-row count.
///
/// This lets generic `Db<D>` code bound on `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement, as reported by the
    /// backend's query result.
    fn rows_affected(&self) -> u64;
}
  56
  57#[cfg(test)]
  58impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  59    fn rows_affected(&self) -> u64 {
  60        self.rows_affected()
  61    }
  62}
  63
  64impl RowsAffected for sqlx::postgres::PgQueryResult {
  65    fn rows_affected(&self) -> u64 {
  66        self.rows_affected()
  67    }
  68}
  69
  70#[cfg(test)]
  71impl Db<sqlx::Sqlite> {
  72    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  73        use std::str::FromStr as _;
  74        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  75            .unwrap()
  76            .create_if_missing(true)
  77            .shared_cache(true);
  78        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  79            .min_connections(2)
  80            .max_connections(max_connections)
  81            .connect_with(options)
  82            .await?;
  83        Ok(Self {
  84            pool,
  85            background: None,
  86            runtime: None,
  87        })
  88    }
  89
  90    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  91        test_support!(self, {
  92            let query = "
  93                SELECT users.*
  94                FROM users
  95                WHERE users.id IN (SELECT value from json_each($1))
  96            ";
  97            Ok(sqlx::query_as(query)
  98                .bind(&serde_json::json!(ids))
  99                .fetch_all(&self.pool)
 100                .await?)
 101        })
 102    }
 103
 104    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 105        test_support!(self, {
 106            let query = "
 107                SELECT metrics_id
 108                FROM users
 109                WHERE id = $1
 110            ";
 111            Ok(sqlx::query_scalar(query)
 112                .bind(id)
 113                .fetch_one(&self.pool)
 114                .await?)
 115        })
 116    }
 117
 118    pub async fn create_user(
 119        &self,
 120        email_address: &str,
 121        admin: bool,
 122        params: NewUserParams,
 123    ) -> Result<NewUserResult> {
 124        test_support!(self, {
 125            let query = "
 126                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 127                VALUES ($1, $2, $3, $4, $5)
 128                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 129                RETURNING id, metrics_id
 130            ";
 131
 132            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 133                .bind(email_address)
 134                .bind(params.github_login)
 135                .bind(params.github_user_id)
 136                .bind(admin)
 137                .bind(Uuid::new_v4().to_string())
 138                .fetch_one(&self.pool)
 139                .await?;
 140            Ok(NewUserResult {
 141                user_id,
 142                metrics_id,
 143                signup_device_id: None,
 144                inviting_user_id: None,
 145            })
 146        })
 147    }
 148
 149    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 150        unimplemented!()
 151    }
 152
 153    pub async fn create_user_from_invite(
 154        &self,
 155        _invite: &Invite,
 156        _user: NewUserParams,
 157    ) -> Result<Option<NewUserResult>> {
 158        unimplemented!()
 159    }
 160
 161    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 162        unimplemented!()
 163    }
 164
 165    pub async fn create_invite_from_code(
 166        &self,
 167        _code: &str,
 168        _email_address: &str,
 169        _device_id: Option<&str>,
 170    ) -> Result<Invite> {
 171        unimplemented!()
 172    }
 173
 174    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 175        unimplemented!()
 176    }
 177}
 178
 179impl Db<sqlx::Postgres> {
 180    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 181        let pool = sqlx::postgres::PgPoolOptions::new()
 182            .max_connections(max_connections)
 183            .connect(url)
 184            .await?;
 185        Ok(Self {
 186            pool,
 187            #[cfg(test)]
 188            background: None,
 189            #[cfg(test)]
 190            runtime: None,
 191        })
 192    }
 193
 194    #[cfg(test)]
 195    pub fn teardown(&self, url: &str) {
 196        self.runtime.as_ref().unwrap().block_on(async {
 197            use util::ResultExt;
 198            let query = "
 199                SELECT pg_terminate_backend(pg_stat_activity.pid)
 200                FROM pg_stat_activity
 201                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 202            ";
 203            sqlx::query(query).execute(&self.pool).await.log_err();
 204            self.pool.close().await;
 205            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 206                .await
 207                .log_err();
 208        })
 209    }
 210
 211    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 212        test_support!(self, {
 213            let like_string = Self::fuzzy_like_string(name_query);
 214            let query = "
 215                SELECT users.*
 216                FROM users
 217                WHERE github_login ILIKE $1
 218                ORDER BY github_login <-> $2
 219                LIMIT $3
 220            ";
 221            Ok(sqlx::query_as(query)
 222                .bind(like_string)
 223                .bind(name_query)
 224                .bind(limit as i32)
 225                .fetch_all(&self.pool)
 226                .await?)
 227        })
 228    }
 229
 230    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 231        test_support!(self, {
 232            let query = "
 233                SELECT users.*
 234                FROM users
 235                WHERE users.id = ANY ($1)
 236            ";
 237            Ok(sqlx::query_as(query)
 238                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
 239                .fetch_all(&self.pool)
 240                .await?)
 241        })
 242    }
 243
 244    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 245        test_support!(self, {
 246            let query = "
 247                SELECT metrics_id::text
 248                FROM users
 249                WHERE id = $1
 250            ";
 251            Ok(sqlx::query_scalar(query)
 252                .bind(id)
 253                .fetch_one(&self.pool)
 254                .await?)
 255        })
 256    }
 257
 258    pub async fn create_user(
 259        &self,
 260        email_address: &str,
 261        admin: bool,
 262        params: NewUserParams,
 263    ) -> Result<NewUserResult> {
 264        test_support!(self, {
 265            let query = "
 266                INSERT INTO users (email_address, github_login, github_user_id, admin)
 267                VALUES ($1, $2, $3, $4)
 268                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 269                RETURNING id, metrics_id::text
 270            ";
 271
 272            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 273                .bind(email_address)
 274                .bind(params.github_login)
 275                .bind(params.github_user_id)
 276                .bind(admin)
 277                .fetch_one(&self.pool)
 278                .await?;
 279            Ok(NewUserResult {
 280                user_id,
 281                metrics_id,
 282                signup_device_id: None,
 283                inviting_user_id: None,
 284            })
 285        })
 286    }
 287
 288    pub async fn create_user_from_invite(
 289        &self,
 290        invite: &Invite,
 291        user: NewUserParams,
 292    ) -> Result<Option<NewUserResult>> {
 293        test_support!(self, {
 294            let mut tx = self.pool.begin().await?;
 295
 296            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 297                i32,
 298                Option<UserId>,
 299                Option<UserId>,
 300                Option<String>,
 301            ) = sqlx::query_as(
 302                "
 303                SELECT id, user_id, inviting_user_id, device_id
 304                FROM signups
 305                WHERE
 306                    email_address = $1 AND
 307                    email_confirmation_code = $2
 308                ",
 309            )
 310            .bind(&invite.email_address)
 311            .bind(&invite.email_confirmation_code)
 312            .fetch_optional(&mut tx)
 313            .await?
 314            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 315
 316            if existing_user_id.is_some() {
 317                return Ok(None);
 318            }
 319
 320            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 321                "
 322                INSERT INTO users
 323                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 324                VALUES
 325                ($1, $2, $3, FALSE, $4, $5)
 326                ON CONFLICT (github_login) DO UPDATE SET
 327                    email_address = excluded.email_address,
 328                    github_user_id = excluded.github_user_id,
 329                    admin = excluded.admin
 330                RETURNING id, metrics_id::text
 331                ",
 332            )
 333            .bind(&invite.email_address)
 334            .bind(&user.github_login)
 335            .bind(&user.github_user_id)
 336            .bind(&user.invite_count)
 337            .bind(random_invite_code())
 338            .fetch_one(&mut tx)
 339            .await?;
 340
 341            sqlx::query(
 342                "
 343                UPDATE signups
 344                SET user_id = $1
 345                WHERE id = $2
 346                ",
 347            )
 348            .bind(&user_id)
 349            .bind(&signup_id)
 350            .execute(&mut tx)
 351            .await?;
 352
 353            if let Some(inviting_user_id) = inviting_user_id {
 354                let id: Option<UserId> = sqlx::query_scalar(
 355                    "
 356                    UPDATE users
 357                    SET invite_count = invite_count - 1
 358                    WHERE id = $1 AND invite_count > 0
 359                    RETURNING id
 360                    ",
 361                )
 362                .bind(&inviting_user_id)
 363                .fetch_optional(&mut tx)
 364                .await?;
 365
 366                if id.is_none() {
 367                    Err(Error::Http(
 368                        StatusCode::UNAUTHORIZED,
 369                        "no invites remaining".to_string(),
 370                    ))?;
 371                }
 372
 373                sqlx::query(
 374                    "
 375                    INSERT INTO contacts
 376                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 377                    VALUES
 378                        ($1, $2, TRUE, TRUE, TRUE)
 379                    ON CONFLICT DO NOTHING
 380                    ",
 381                )
 382                .bind(inviting_user_id)
 383                .bind(user_id)
 384                .execute(&mut tx)
 385                .await?;
 386            }
 387
 388            tx.commit().await?;
 389            Ok(Some(NewUserResult {
 390                user_id,
 391                metrics_id,
 392                inviting_user_id,
 393                signup_device_id,
 394            }))
 395        })
 396    }
 397
 398    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 399        test_support!(self, {
 400            sqlx::query(
 401                "
 402                INSERT INTO signups
 403                (
 404                    email_address,
 405                    email_confirmation_code,
 406                    email_confirmation_sent,
 407                    platform_linux,
 408                    platform_mac,
 409                    platform_windows,
 410                    platform_unknown,
 411                    editor_features,
 412                    programming_languages,
 413                    device_id
 414                )
 415                VALUES
 416                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 417                RETURNING id
 418                ",
 419            )
 420            .bind(&signup.email_address)
 421            .bind(&random_email_confirmation_code())
 422            .bind(&signup.platform_linux)
 423            .bind(&signup.platform_mac)
 424            .bind(&signup.platform_windows)
 425            .bind(&signup.editor_features)
 426            .bind(&signup.programming_languages)
 427            .bind(&signup.device_id)
 428            .execute(&self.pool)
 429            .await?;
 430            Ok(())
 431        })
 432    }
 433
 434    pub async fn create_invite_from_code(
 435        &self,
 436        code: &str,
 437        email_address: &str,
 438        device_id: Option<&str>,
 439    ) -> Result<Invite> {
 440        test_support!(self, {
 441            let mut tx = self.pool.begin().await?;
 442
 443            let existing_user: Option<UserId> = sqlx::query_scalar(
 444                "
 445                SELECT id
 446                FROM users
 447                WHERE email_address = $1
 448                ",
 449            )
 450            .bind(email_address)
 451            .fetch_optional(&mut tx)
 452            .await?;
 453            if existing_user.is_some() {
 454                Err(anyhow!("email address is already in use"))?;
 455            }
 456
 457            let row: Option<(UserId, i32)> = sqlx::query_as(
 458                "
 459                SELECT id, invite_count
 460                FROM users
 461                WHERE invite_code = $1
 462                ",
 463            )
 464            .bind(code)
 465            .fetch_optional(&mut tx)
 466            .await?;
 467
 468            let (inviter_id, invite_count) = match row {
 469                Some(row) => row,
 470                None => Err(Error::Http(
 471                    StatusCode::NOT_FOUND,
 472                    "invite code not found".to_string(),
 473                ))?,
 474            };
 475
 476            if invite_count == 0 {
 477                Err(Error::Http(
 478                    StatusCode::UNAUTHORIZED,
 479                    "no invites remaining".to_string(),
 480                ))?;
 481            }
 482
 483            let email_confirmation_code: String = sqlx::query_scalar(
 484                "
 485                INSERT INTO signups
 486                (
 487                    email_address,
 488                    email_confirmation_code,
 489                    email_confirmation_sent,
 490                    inviting_user_id,
 491                    platform_linux,
 492                    platform_mac,
 493                    platform_windows,
 494                    platform_unknown,
 495                    device_id
 496                )
 497                VALUES
 498                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 499                ON CONFLICT (email_address)
 500                DO UPDATE SET
 501                    inviting_user_id = excluded.inviting_user_id
 502                RETURNING email_confirmation_code
 503                ",
 504            )
 505            .bind(&email_address)
 506            .bind(&random_email_confirmation_code())
 507            .bind(&inviter_id)
 508            .bind(&device_id)
 509            .fetch_one(&mut tx)
 510            .await?;
 511
 512            tx.commit().await?;
 513
 514            Ok(Invite {
 515                email_address: email_address.into(),
 516                email_confirmation_code,
 517            })
 518        })
 519    }
 520
 521    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 522        test_support!(self, {
 523            let emails = invites
 524                .iter()
 525                .map(|s| s.email_address.as_str())
 526                .collect::<Vec<_>>();
 527            sqlx::query(
 528                "
 529                UPDATE signups
 530                SET email_confirmation_sent = TRUE
 531                WHERE email_address = ANY ($1)
 532                ",
 533            )
 534            .bind(&emails)
 535            .execute(&self.pool)
 536            .await?;
 537            Ok(())
 538        })
 539    }
 540}
 541
 542impl<D> Db<D>
 543where
 544    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 545    D::Connection: sqlx::migrate::Migrate,
 546    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 547    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 548    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 549    D::QueryResult: RowsAffected,
 550    String: sqlx::Type<D>,
 551    i32: sqlx::Type<D>,
 552    i64: sqlx::Type<D>,
 553    bool: sqlx::Type<D>,
 554    str: sqlx::Type<D>,
 555    Uuid: sqlx::Type<D>,
 556    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 557    OffsetDateTime: sqlx::Type<D>,
 558    PrimitiveDateTime: sqlx::Type<D>,
 559    usize: sqlx::ColumnIndex<D::Row>,
 560    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 561    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 562    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 563    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 564    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 565    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 566    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 567    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 568    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 569    for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 570    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 571    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 572    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 573{
 574    pub async fn migrate(
 575        &self,
 576        migrations_path: &Path,
 577        ignore_checksum_mismatch: bool,
 578    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 579        let migrations = MigrationSource::resolve(migrations_path)
 580            .await
 581            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 582
 583        let mut conn = self.pool.acquire().await?;
 584
 585        conn.ensure_migrations_table().await?;
 586        let applied_migrations: HashMap<_, _> = conn
 587            .list_applied_migrations()
 588            .await?
 589            .into_iter()
 590            .map(|m| (m.version, m))
 591            .collect();
 592
 593        let mut new_migrations = Vec::new();
 594        for migration in migrations {
 595            match applied_migrations.get(&migration.version) {
 596                Some(applied_migration) => {
 597                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 598                    {
 599                        Err(anyhow!(
 600                            "checksum mismatch for applied migration {}",
 601                            migration.description
 602                        ))?;
 603                    }
 604                }
 605                None => {
 606                    let elapsed = conn.apply(&migration).await?;
 607                    new_migrations.push((migration, elapsed));
 608                }
 609            }
 610        }
 611
 612        Ok(new_migrations)
 613    }
 614
 615    pub fn fuzzy_like_string(string: &str) -> String {
 616        let mut result = String::with_capacity(string.len() * 2 + 1);
 617        for c in string.chars() {
 618            if c.is_alphanumeric() {
 619                result.push('%');
 620                result.push(c);
 621            }
 622        }
 623        result.push('%');
 624        result
 625    }
 626
 627    // users
 628
 629    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 630        test_support!(self, {
 631            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 632            Ok(sqlx::query_as(query)
 633                .bind(limit as i32)
 634                .bind((page * limit) as i32)
 635                .fetch_all(&self.pool)
 636                .await?)
 637        })
 638    }
 639
 640    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 641        test_support!(self, {
 642            let query = "
 643                SELECT users.*
 644                FROM users
 645                WHERE id = $1
 646                LIMIT 1
 647            ";
 648            Ok(sqlx::query_as(query)
 649                .bind(&id)
 650                .fetch_optional(&self.pool)
 651                .await?)
 652        })
 653    }
 654
 655    pub async fn get_users_with_no_invites(
 656        &self,
 657        invited_by_another_user: bool,
 658    ) -> Result<Vec<User>> {
 659        test_support!(self, {
 660            let query = format!(
 661                "
 662                SELECT users.*
 663                FROM users
 664                WHERE invite_count = 0
 665                AND inviter_id IS{} NULL
 666                ",
 667                if invited_by_another_user { " NOT" } else { "" }
 668            );
 669
 670            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 671        })
 672    }
 673
 674    pub async fn get_user_by_github_account(
 675        &self,
 676        github_login: &str,
 677        github_user_id: Option<i32>,
 678    ) -> Result<Option<User>> {
 679        test_support!(self, {
 680            if let Some(github_user_id) = github_user_id {
 681                let mut user = sqlx::query_as::<_, User>(
 682                    "
 683                    UPDATE users
 684                    SET github_login = $1
 685                    WHERE github_user_id = $2
 686                    RETURNING *
 687                    ",
 688                )
 689                .bind(github_login)
 690                .bind(github_user_id)
 691                .fetch_optional(&self.pool)
 692                .await?;
 693
 694                if user.is_none() {
 695                    user = sqlx::query_as::<_, User>(
 696                        "
 697                        UPDATE users
 698                        SET github_user_id = $1
 699                        WHERE github_login = $2
 700                        RETURNING *
 701                        ",
 702                    )
 703                    .bind(github_user_id)
 704                    .bind(github_login)
 705                    .fetch_optional(&self.pool)
 706                    .await?;
 707                }
 708
 709                Ok(user)
 710            } else {
 711                let user = sqlx::query_as(
 712                    "
 713                    SELECT * FROM users
 714                    WHERE github_login = $1
 715                    LIMIT 1
 716                    ",
 717                )
 718                .bind(github_login)
 719                .fetch_optional(&self.pool)
 720                .await?;
 721                Ok(user)
 722            }
 723        })
 724    }
 725
 726    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 727        test_support!(self, {
 728            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 729            Ok(sqlx::query(query)
 730                .bind(is_admin)
 731                .bind(id.0)
 732                .execute(&self.pool)
 733                .await
 734                .map(drop)?)
 735        })
 736    }
 737
 738    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 739        test_support!(self, {
 740            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 741            Ok(sqlx::query(query)
 742                .bind(connected_once)
 743                .bind(id.0)
 744                .execute(&self.pool)
 745                .await
 746                .map(drop)?)
 747        })
 748    }
 749
 750    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 751        test_support!(self, {
 752            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 753            sqlx::query(query)
 754                .bind(id.0)
 755                .execute(&self.pool)
 756                .await
 757                .map(drop)?;
 758            let query = "DELETE FROM users WHERE id = $1;";
 759            Ok(sqlx::query(query)
 760                .bind(id.0)
 761                .execute(&self.pool)
 762                .await
 763                .map(drop)?)
 764        })
 765    }
 766
 767    // signups
 768
 769    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 770        test_support!(self, {
 771            Ok(sqlx::query_as(
 772                "
 773                SELECT
 774                    COUNT(*) as count,
 775                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 776                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 777                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 778                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 779                FROM (
 780                    SELECT *
 781                    FROM signups
 782                    WHERE
 783                        NOT email_confirmation_sent
 784                ) AS unsent
 785                ",
 786            )
 787            .fetch_one(&self.pool)
 788            .await?)
 789        })
 790    }
 791
 792    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 793        test_support!(self, {
 794            Ok(sqlx::query_as(
 795                "
 796                SELECT
 797                    email_address, email_confirmation_code
 798                FROM signups
 799                WHERE
 800                    NOT email_confirmation_sent AND
 801                    (platform_mac OR platform_unknown)
 802                LIMIT $1
 803                ",
 804            )
 805            .bind(count as i32)
 806            .fetch_all(&self.pool)
 807            .await?)
 808        })
 809    }
 810
 811    // invite codes
 812
 813    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 814        test_support!(self, {
 815            let mut tx = self.pool.begin().await?;
 816            if count > 0 {
 817                sqlx::query(
 818                    "
 819                    UPDATE users
 820                    SET invite_code = $1
 821                    WHERE id = $2 AND invite_code IS NULL
 822                ",
 823                )
 824                .bind(random_invite_code())
 825                .bind(id)
 826                .execute(&mut tx)
 827                .await?;
 828            }
 829
 830            sqlx::query(
 831                "
 832                UPDATE users
 833                SET invite_count = $1
 834                WHERE id = $2
 835                ",
 836            )
 837            .bind(count as i32)
 838            .bind(id)
 839            .execute(&mut tx)
 840            .await?;
 841            tx.commit().await?;
 842            Ok(())
 843        })
 844    }
 845
 846    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 847        test_support!(self, {
 848            let result: Option<(String, i32)> = sqlx::query_as(
 849                "
 850                    SELECT invite_code, invite_count
 851                    FROM users
 852                    WHERE id = $1 AND invite_code IS NOT NULL 
 853                ",
 854            )
 855            .bind(id)
 856            .fetch_optional(&self.pool)
 857            .await?;
 858            if let Some((code, count)) = result {
 859                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 860            } else {
 861                Ok(None)
 862            }
 863        })
 864    }
 865
 866    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 867        test_support!(self, {
 868            sqlx::query_as(
 869                "
 870                    SELECT *
 871                    FROM users
 872                    WHERE invite_code = $1
 873                ",
 874            )
 875            .bind(code)
 876            .fetch_optional(&self.pool)
 877            .await?
 878            .ok_or_else(|| {
 879                Error::Http(
 880                    StatusCode::NOT_FOUND,
 881                    "that invite code does not exist".to_string(),
 882                )
 883            })
 884        })
 885    }
 886
 887    pub async fn create_room(
 888        &self,
 889        user_id: UserId,
 890        connection_id: ConnectionId,
 891    ) -> Result<proto::Room> {
 892        test_support!(self, {
 893            let mut tx = self.pool.begin().await?;
 894            let live_kit_room = nanoid::nanoid!(30);
 895            let room_id = sqlx::query_scalar(
 896                "
 897                INSERT INTO rooms (live_kit_room, version)
 898                VALUES ($1, $2)
 899                RETURNING id
 900                ",
 901            )
 902            .bind(&live_kit_room)
 903            .bind(0)
 904            .fetch_one(&mut tx)
 905            .await
 906            .map(RoomId)?;
 907
 908            sqlx::query(
 909                "
 910                INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id)
 911                VALUES ($1, $2, $3, $4)
 912                ",
 913            )
 914            .bind(room_id)
 915            .bind(user_id)
 916            .bind(connection_id.0 as i32)
 917            .bind(user_id)
 918            .execute(&mut tx)
 919            .await?;
 920
 921            self.commit_room_transaction(room_id, tx).await
 922        })
 923    }
 924
 925    pub async fn call(
 926        &self,
 927        room_id: RoomId,
 928        calling_user_id: UserId,
 929        called_user_id: UserId,
 930        initial_project_id: Option<ProjectId>,
 931    ) -> Result<(proto::Room, proto::IncomingCall)> {
 932        test_support!(self, {
 933            let mut tx = self.pool.begin().await?;
 934            sqlx::query(
 935                "
 936                INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id)
 937                VALUES ($1, $2, $3, $4)
 938                ",
 939            )
 940            .bind(room_id)
 941            .bind(called_user_id)
 942            .bind(calling_user_id)
 943            .bind(initial_project_id)
 944            .execute(&mut tx)
 945            .await?;
 946
 947            let room = self.commit_room_transaction(room_id, tx).await?;
 948            let incoming_call = Self::build_incoming_call(&room, called_user_id)
 949                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
 950            Ok((room, incoming_call))
 951        })
 952    }
 953
 954    pub async fn incoming_call_for_user(
 955        &self,
 956        user_id: UserId,
 957    ) -> Result<Option<proto::IncomingCall>> {
 958        test_support!(self, {
 959            let mut tx = self.pool.begin().await?;
 960            let room_id = sqlx::query_scalar::<_, RoomId>(
 961                "
 962                SELECT room_id
 963                FROM room_participants
 964                WHERE user_id = $1 AND connection_id IS NULL
 965                ",
 966            )
 967            .bind(user_id)
 968            .fetch_optional(&mut tx)
 969            .await?;
 970
 971            if let Some(room_id) = room_id {
 972                let room = self.get_room(room_id, &mut tx).await?;
 973                Ok(Self::build_incoming_call(&room, user_id))
 974            } else {
 975                Ok(None)
 976            }
 977        })
 978    }
 979
 980    fn build_incoming_call(
 981        room: &proto::Room,
 982        called_user_id: UserId,
 983    ) -> Option<proto::IncomingCall> {
 984        let pending_participant = room
 985            .pending_participants
 986            .iter()
 987            .find(|participant| participant.user_id == called_user_id.to_proto())?;
 988
 989        Some(proto::IncomingCall {
 990            room_id: room.id,
 991            calling_user_id: pending_participant.calling_user_id,
 992            participant_user_ids: room
 993                .participants
 994                .iter()
 995                .map(|participant| participant.user_id)
 996                .collect(),
 997            initial_project: room.participants.iter().find_map(|participant| {
 998                let initial_project_id = pending_participant.initial_project_id?;
 999                participant
1000                    .projects
1001                    .iter()
1002                    .find(|project| project.id == initial_project_id)
1003                    .cloned()
1004            }),
1005        })
1006    }
1007
1008    pub async fn call_failed(
1009        &self,
1010        room_id: RoomId,
1011        called_user_id: UserId,
1012    ) -> Result<proto::Room> {
1013        test_support!(self, {
1014            let mut tx = self.pool.begin().await?;
1015            sqlx::query(
1016                "
1017                DELETE FROM room_participants
1018                WHERE room_id = $1 AND user_id = $2
1019                ",
1020            )
1021            .bind(room_id)
1022            .bind(called_user_id)
1023            .execute(&mut tx)
1024            .await?;
1025
1026            self.commit_room_transaction(room_id, tx).await
1027        })
1028    }
1029
1030    pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result<proto::Room> {
1031        test_support!(self, {
1032            let mut tx = self.pool.begin().await?;
1033            sqlx::query(
1034                "
1035                DELETE FROM room_participants
1036                WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL
1037                ",
1038            )
1039            .bind(room_id)
1040            .bind(user_id)
1041            .execute(&mut tx)
1042            .await?;
1043
1044            self.commit_room_transaction(room_id, tx).await
1045        })
1046    }
1047
1048    pub async fn join_room(
1049        &self,
1050        room_id: RoomId,
1051        user_id: UserId,
1052        connection_id: ConnectionId,
1053    ) -> Result<proto::Room> {
1054        test_support!(self, {
1055            let mut tx = self.pool.begin().await?;
1056            sqlx::query(
1057                "
1058                UPDATE room_participants 
1059                SET connection_id = $1
1060                WHERE room_id = $2 AND user_id = $3
1061                RETURNING 1
1062                ",
1063            )
1064            .bind(connection_id.0 as i32)
1065            .bind(room_id)
1066            .bind(user_id)
1067            .fetch_one(&mut tx)
1068            .await?;
1069            self.commit_room_transaction(room_id, tx).await
1070        })
1071    }
1072
1073    pub async fn leave_room(
1074        &self,
1075        room_id: RoomId,
1076        connection_id: ConnectionId,
1077    ) -> Result<LeftRoom> {
1078        test_support!(self, {
1079            let mut tx = self.pool.begin().await?;
1080
1081            // Leave room.
1082            let user_id: UserId = sqlx::query_scalar(
1083                "
1084                DELETE FROM room_participants
1085                WHERE room_id = $1 AND connection_id = $2
1086                RETURNING user_id
1087                ",
1088            )
1089            .bind(room_id)
1090            .bind(connection_id.0 as i32)
1091            .fetch_one(&mut tx)
1092            .await?;
1093
1094            // Cancel pending calls initiated by the leaving user.
1095            let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1096                "
1097                DELETE FROM room_participants
1098                WHERE calling_user_id = $1 AND connection_id IS NULL
1099                RETURNING user_id
1100                ",
1101            )
1102            .bind(room_id)
1103            .bind(connection_id.0 as i32)
1104            .fetch_all(&mut tx)
1105            .await?;
1106
1107            let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
1108                "
1109                SELECT project_collaborators.*
1110                FROM projects, project_collaborators
1111                WHERE
1112                    projects.room_id = $1 AND
1113                    projects.user_id = $2 AND
1114                    projects.id = project_collaborators.project_id
1115                ",
1116            )
1117            .bind(room_id)
1118            .bind(user_id)
1119            .fetch(&mut tx);
1120
1121            let mut left_projects = HashMap::default();
1122            while let Some(collaborator) = project_collaborators.next().await {
1123                let collaborator = collaborator?;
1124                let left_project =
1125                    left_projects
1126                        .entry(collaborator.project_id)
1127                        .or_insert(LeftProject {
1128                            id: collaborator.project_id,
1129                            host_user_id: Default::default(),
1130                            connection_ids: Default::default(),
1131                        });
1132
1133                let collaborator_connection_id = ConnectionId(collaborator.connection_id as u32);
1134                if collaborator_connection_id != connection_id || collaborator.is_host {
1135                    left_project.connection_ids.push(collaborator_connection_id);
1136                }
1137
1138                if collaborator.is_host {
1139                    left_project.host_user_id = collaborator.user_id;
1140                }
1141            }
1142            drop(project_collaborators);
1143
1144            sqlx::query(
1145                "
1146                DELETE FROM projects
1147                WHERE room_id = $1 AND user_id = $2
1148                ",
1149            )
1150            .bind(room_id)
1151            .bind(user_id)
1152            .execute(&mut tx)
1153            .await?;
1154
1155            let room = self.commit_room_transaction(room_id, tx).await?;
1156            Ok(LeftRoom {
1157                room,
1158                left_projects,
1159                canceled_calls_to_user_ids,
1160            })
1161        })
1162    }
1163
1164    pub async fn update_room_participant_location(
1165        &self,
1166        room_id: RoomId,
1167        user_id: UserId,
1168        location: proto::ParticipantLocation,
1169    ) -> Result<proto::Room> {
1170        test_support!(self, {
1171            let mut tx = self.pool.begin().await?;
1172
1173            let location_kind;
1174            let location_project_id;
1175            match location
1176                .variant
1177                .ok_or_else(|| anyhow!("invalid location"))?
1178            {
1179                proto::participant_location::Variant::SharedProject(project) => {
1180                    location_kind = 0;
1181                    location_project_id = Some(ProjectId::from_proto(project.id));
1182                }
1183                proto::participant_location::Variant::UnsharedProject(_) => {
1184                    location_kind = 1;
1185                    location_project_id = None;
1186                }
1187                proto::participant_location::Variant::External(_) => {
1188                    location_kind = 2;
1189                    location_project_id = None;
1190                }
1191            }
1192
1193            sqlx::query(
1194                "
1195                UPDATE room_participants
1196                SET location_kind = $1 AND location_project_id = $2
1197                WHERE room_id = $1 AND user_id = $2
1198                ",
1199            )
1200            .bind(location_kind)
1201            .bind(location_project_id)
1202            .bind(room_id)
1203            .bind(user_id)
1204            .execute(&mut tx)
1205            .await?;
1206
1207            self.commit_room_transaction(room_id, tx).await
1208        })
1209    }
1210
1211    async fn commit_room_transaction(
1212        &self,
1213        room_id: RoomId,
1214        mut tx: sqlx::Transaction<'_, D>,
1215    ) -> Result<proto::Room> {
1216        sqlx::query(
1217            "
1218            UPDATE rooms
1219            SET version = version + 1
1220            WHERE id = $1
1221            ",
1222        )
1223        .bind(room_id)
1224        .execute(&mut tx)
1225        .await?;
1226        let room = self.get_room(room_id, &mut tx).await?;
1227        tx.commit().await?;
1228
1229        Ok(room)
1230    }
1231
1232    async fn get_room(
1233        &self,
1234        room_id: RoomId,
1235        tx: &mut sqlx::Transaction<'_, D>,
1236    ) -> Result<proto::Room> {
1237        let room: Room = sqlx::query_as(
1238            "
1239            SELECT *
1240            FROM rooms
1241            WHERE id = $1
1242            ",
1243        )
1244        .bind(room_id)
1245        .fetch_one(&mut *tx)
1246        .await?;
1247
1248        let mut db_participants =
1249            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1250                "
1251                SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1252                FROM room_participants
1253                WHERE room_id = $1
1254                ",
1255            )
1256            .bind(room_id)
1257            .fetch(&mut *tx);
1258
1259        let mut participants = Vec::new();
1260        let mut pending_participants = Vec::new();
1261        while let Some(participant) = db_participants.next().await {
1262            let (
1263                user_id,
1264                connection_id,
1265                _location_kind,
1266                _location_project_id,
1267                calling_user_id,
1268                initial_project_id,
1269            ) = participant?;
1270            if let Some(connection_id) = connection_id {
1271                participants.push(proto::Participant {
1272                    user_id: user_id.to_proto(),
1273                    peer_id: connection_id as u32,
1274                    projects: Default::default(),
1275                    location: Some(proto::ParticipantLocation {
1276                        variant: Some(proto::participant_location::Variant::External(
1277                            Default::default(),
1278                        )),
1279                    }),
1280                });
1281            } else {
1282                pending_participants.push(proto::PendingParticipant {
1283                    user_id: user_id.to_proto(),
1284                    calling_user_id: calling_user_id.to_proto(),
1285                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
1286                });
1287            }
1288        }
1289        drop(db_participants);
1290
1291        for participant in &mut participants {
1292            let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1293                "
1294                SELECT projects.id, worktrees.root_name
1295                FROM projects
1296                LEFT JOIN worktrees ON projects.id = worktrees.project_id
1297                WHERE room_id = $1 AND host_user_id = $2
1298                ",
1299            )
1300            .bind(room_id)
1301            .fetch(&mut *tx);
1302
1303            let mut projects = HashMap::default();
1304            while let Some(entry) = entries.next().await {
1305                let (project_id, worktree_root_name) = entry?;
1306                let participant_project =
1307                    projects
1308                        .entry(project_id)
1309                        .or_insert(proto::ParticipantProject {
1310                            id: project_id.to_proto(),
1311                            worktree_root_names: Default::default(),
1312                        });
1313                participant_project
1314                    .worktree_root_names
1315                    .push(worktree_root_name);
1316            }
1317
1318            participant.projects = projects.into_values().collect();
1319        }
1320        Ok(proto::Room {
1321            id: room.id.to_proto(),
1322            version: room.version as u64,
1323            live_kit_room: room.live_kit_room,
1324            participants,
1325            pending_participants,
1326        })
1327    }
1328
1329    // projects
1330
1331    pub async fn share_project(
1332        &self,
1333        user_id: UserId,
1334        connection_id: ConnectionId,
1335        room_id: RoomId,
1336        worktrees: &[proto::WorktreeMetadata],
1337    ) -> Result<(ProjectId, proto::Room)> {
1338        test_support!(self, {
1339            let mut tx = self.pool.begin().await?;
1340            let project_id = sqlx::query_scalar(
1341                "
1342                INSERT INTO projects (host_user_id, room_id)
1343                VALUES ($1)
1344                RETURNING id
1345            ",
1346            )
1347            .bind(user_id)
1348            .bind(room_id)
1349            .fetch_one(&mut tx)
1350            .await
1351            .map(ProjectId)?;
1352
1353            for worktree in worktrees {
1354                sqlx::query(
1355                    "
1356                INSERT INTO worktrees (id, project_id, root_name)
1357                ",
1358                )
1359                .bind(worktree.id as i32)
1360                .bind(project_id)
1361                .bind(&worktree.root_name)
1362                .execute(&mut tx)
1363                .await?;
1364            }
1365
1366            sqlx::query(
1367                "
1368                INSERT INTO project_collaborators (
1369                project_id,
1370                connection_id,
1371                user_id,
1372                replica_id,
1373                is_host
1374                )
1375                VALUES ($1, $2, $3, $4, $5)
1376                ",
1377            )
1378            .bind(project_id)
1379            .bind(connection_id.0 as i32)
1380            .bind(user_id)
1381            .bind(0)
1382            .bind(true)
1383            .execute(&mut tx)
1384            .await?;
1385
1386            let room = self.commit_room_transaction(room_id, tx).await?;
1387            Ok((project_id, room))
1388        })
1389    }
1390
    /// Stops sharing `project_id`.
    ///
    /// NOTE(review): not implemented — calling this panics via `todo!()`.
    /// The commented-out code below is a draft that set `unregistered = TRUE`
    /// on the project row; confirm it still matches the current `projects`
    /// schema before reviving it.
    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
        todo!()
        // Draft implementation, kept for reference:
        // test_support!(self, {
        //     sqlx::query(
        //         "
        //         UPDATE projects
        //         SET unregistered = TRUE
        //         WHERE id = $1
        //         ",
        //     )
        //     .bind(project_id)
        //     .execute(&self.pool)
        //     .await?;
        //     Ok(())
        // })
    }
1407
1408    // contacts
1409
1410    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1411        test_support!(self, {
1412            let query = "
1413                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1414                FROM contacts
1415                WHERE user_id_a = $1 OR user_id_b = $1;
1416            ";
1417
1418            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1419                .bind(user_id)
1420                .fetch(&self.pool);
1421
1422            let mut contacts = Vec::new();
1423            while let Some(row) = rows.next().await {
1424                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1425
1426                if user_id_a == user_id {
1427                    if accepted {
1428                        contacts.push(Contact::Accepted {
1429                            user_id: user_id_b,
1430                            should_notify: should_notify && a_to_b,
1431                        });
1432                    } else if a_to_b {
1433                        contacts.push(Contact::Outgoing { user_id: user_id_b })
1434                    } else {
1435                        contacts.push(Contact::Incoming {
1436                            user_id: user_id_b,
1437                            should_notify,
1438                        });
1439                    }
1440                } else if accepted {
1441                    contacts.push(Contact::Accepted {
1442                        user_id: user_id_a,
1443                        should_notify: should_notify && !a_to_b,
1444                    });
1445                } else if a_to_b {
1446                    contacts.push(Contact::Incoming {
1447                        user_id: user_id_a,
1448                        should_notify,
1449                    });
1450                } else {
1451                    contacts.push(Contact::Outgoing { user_id: user_id_a });
1452                }
1453            }
1454
1455            contacts.sort_unstable_by_key(|contact| contact.user_id());
1456
1457            Ok(contacts)
1458        })
1459    }
1460
1461    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1462        test_support!(self, {
1463            let (id_a, id_b) = if user_id_1 < user_id_2 {
1464                (user_id_1, user_id_2)
1465            } else {
1466                (user_id_2, user_id_1)
1467            };
1468
1469            let query = "
1470                SELECT 1 FROM contacts
1471                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1472                LIMIT 1
1473            ";
1474            Ok(sqlx::query_scalar::<_, i32>(query)
1475                .bind(id_a.0)
1476                .bind(id_b.0)
1477                .fetch_optional(&self.pool)
1478                .await?
1479                .is_some())
1480        })
1481    }
1482
1483    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1484        test_support!(self, {
1485            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1486                (sender_id, receiver_id, true)
1487            } else {
1488                (receiver_id, sender_id, false)
1489            };
1490            let query = "
1491                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1492                VALUES ($1, $2, $3, FALSE, TRUE)
1493                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1494                SET
1495                    accepted = TRUE,
1496                    should_notify = FALSE
1497                WHERE
1498                    NOT contacts.accepted AND
1499                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1500                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1501            ";
1502            let result = sqlx::query(query)
1503                .bind(id_a.0)
1504                .bind(id_b.0)
1505                .bind(a_to_b)
1506                .execute(&self.pool)
1507                .await?;
1508
1509            if result.rows_affected() == 1 {
1510                Ok(())
1511            } else {
1512                Err(anyhow!("contact already requested"))?
1513            }
1514        })
1515    }
1516
1517    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1518        test_support!(self, {
1519            let (id_a, id_b) = if responder_id < requester_id {
1520                (responder_id, requester_id)
1521            } else {
1522                (requester_id, responder_id)
1523            };
1524            let query = "
1525                DELETE FROM contacts
1526                WHERE user_id_a = $1 AND user_id_b = $2;
1527            ";
1528            let result = sqlx::query(query)
1529                .bind(id_a.0)
1530                .bind(id_b.0)
1531                .execute(&self.pool)
1532                .await?;
1533
1534            if result.rows_affected() == 1 {
1535                Ok(())
1536            } else {
1537                Err(anyhow!("no such contact"))?
1538            }
1539        })
1540    }
1541
1542    pub async fn dismiss_contact_notification(
1543        &self,
1544        user_id: UserId,
1545        contact_user_id: UserId,
1546    ) -> Result<()> {
1547        test_support!(self, {
1548            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1549                (user_id, contact_user_id, true)
1550            } else {
1551                (contact_user_id, user_id, false)
1552            };
1553
1554            let query = "
1555                UPDATE contacts
1556                SET should_notify = FALSE
1557                WHERE
1558                    user_id_a = $1 AND user_id_b = $2 AND
1559                    (
1560                        (a_to_b = $3 AND accepted) OR
1561                        (a_to_b != $3 AND NOT accepted)
1562                    );
1563            ";
1564
1565            let result = sqlx::query(query)
1566                .bind(id_a.0)
1567                .bind(id_b.0)
1568                .bind(a_to_b)
1569                .execute(&self.pool)
1570                .await?;
1571
1572            if result.rows_affected() == 0 {
1573                Err(anyhow!("no such contact request"))?;
1574            }
1575
1576            Ok(())
1577        })
1578    }
1579
1580    pub async fn respond_to_contact_request(
1581        &self,
1582        responder_id: UserId,
1583        requester_id: UserId,
1584        accept: bool,
1585    ) -> Result<()> {
1586        test_support!(self, {
1587            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1588                (responder_id, requester_id, false)
1589            } else {
1590                (requester_id, responder_id, true)
1591            };
1592            let result = if accept {
1593                let query = "
1594                    UPDATE contacts
1595                    SET accepted = TRUE, should_notify = TRUE
1596                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1597                ";
1598                sqlx::query(query)
1599                    .bind(id_a.0)
1600                    .bind(id_b.0)
1601                    .bind(a_to_b)
1602                    .execute(&self.pool)
1603                    .await?
1604            } else {
1605                let query = "
1606                    DELETE FROM contacts
1607                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1608                ";
1609                sqlx::query(query)
1610                    .bind(id_a.0)
1611                    .bind(id_b.0)
1612                    .bind(a_to_b)
1613                    .execute(&self.pool)
1614                    .await?
1615            };
1616            if result.rows_affected() == 1 {
1617                Ok(())
1618            } else {
1619                Err(anyhow!("no such contact request"))?
1620            }
1621        })
1622    }
1623
1624    // access tokens
1625
1626    pub async fn create_access_token_hash(
1627        &self,
1628        user_id: UserId,
1629        access_token_hash: &str,
1630        max_access_token_count: usize,
1631    ) -> Result<()> {
1632        test_support!(self, {
1633            let insert_query = "
1634                INSERT INTO access_tokens (user_id, hash)
1635                VALUES ($1, $2);
1636            ";
1637            let cleanup_query = "
1638                DELETE FROM access_tokens
1639                WHERE id IN (
1640                    SELECT id from access_tokens
1641                    WHERE user_id = $1
1642                    ORDER BY id DESC
1643                    LIMIT 10000
1644                    OFFSET $3
1645                )
1646            ";
1647
1648            let mut tx = self.pool.begin().await?;
1649            sqlx::query(insert_query)
1650                .bind(user_id.0)
1651                .bind(access_token_hash)
1652                .execute(&mut tx)
1653                .await?;
1654            sqlx::query(cleanup_query)
1655                .bind(user_id.0)
1656                .bind(access_token_hash)
1657                .bind(max_access_token_count as i32)
1658                .execute(&mut tx)
1659                .await?;
1660            Ok(tx.commit().await?)
1661        })
1662    }
1663
1664    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1665        test_support!(self, {
1666            let query = "
1667                SELECT hash
1668                FROM access_tokens
1669                WHERE user_id = $1
1670                ORDER BY id DESC
1671            ";
1672            Ok(sqlx::query_scalar(query)
1673                .bind(user_id.0)
1674                .fetch_all(&self.pool)
1675                .await?)
1676        })
1677    }
1678}
1679
/// Declares `$name` as a transparent `i32` newtype used for database ids.
///
/// The generated type derives the usual value traits, and `sqlx::Type` and
/// serde both treat it as the bare inner `i32`. It converts to and from the
/// `u64` ids used in protobuf messages via `from_proto` / `to_proto`, and
/// `Display` prints just the integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a protobuf id; an `as` cast, so values beyond `i32`
            /// are truncated to their low 32 bits.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts to a protobuf id via an `as` cast (negative values
            /// sign-extend).
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1722
// `UserId`: the id type of `users` rows (see `User::id`).
id_type!(UserId);
/// A row of the `users` table, loaded with `SELECT * FROM users ...`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code other people can use to sign up; looked up by
    /// `get_user_for_invite_code`.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites remaining — confirm.
    pub invite_count: i32,
    pub connected_once: bool,
}
1735
// `RoomId`: the id type of `rooms` rows.
id_type!(RoomId);
/// A row of the `rooms` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// Starts at 0 and is incremented by `commit_room_transaction` on every
    /// committed room change.
    pub version: i32,
    /// Random 30-character name generated in `create_room`.
    pub live_kit_room: String,
}
1743
/// A call from `calling_user_id` to `called_user_id` in `room_id`.
///
/// NOTE(review): not referenced by the room code in this section; presumably
/// `answering_connection_id` is set once the callee answers — confirm against
/// the schema/callers.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct Call {
    pub room_id: RoomId,
    pub calling_user_id: UserId,
    pub called_user_id: UserId,
    pub answering_connection_id: Option<i32>,
    pub initial_project_id: Option<ProjectId>,
}
1752
// `ProjectId`: the id type of `projects` rows.
id_type!(ProjectId);
/// A project row, keyed by id and owned by `host_user_id`.
///
/// NOTE(review): `unregistered` is only referenced by the commented-out draft
/// of `unshare_project`; confirm the column still exists in the schema.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub unregistered: bool,
}
1760
/// A row of the `project_collaborators` table: one connection participating
/// in a project.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Stored as `i32`; converted to `ConnectionId` with an `as u32` cast.
    pub connection_id: i32,
    pub user_id: UserId,
    /// The host is inserted with replica id 0 by `share_project`.
    pub replica_id: i32,
    pub is_host: bool,
}
1769
/// A project affected by a participant leaving a room, as assembled by
/// `leave_room`.
pub struct LeftProject {
    pub id: ProjectId,
    /// Taken from the collaborator flagged `is_host`; `UserId::default()` if
    /// no such collaborator was found.
    pub host_user_id: UserId,
    /// Collaborator connections gathered for this project.
    pub connection_ids: Vec<ConnectionId>,
}
1775
/// Result of `leave_room`.
pub struct LeftRoom {
    /// Room state after the participant was removed.
    pub room: proto::Room,
    /// Projects the leaving user hosted in the room, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls from the leaving user were cancelled.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
1781
/// A contact relationship seen from one user's perspective (see
/// `get_contacts`). `user_id` is always the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request was accepted. `should_notify` can only be true for the user
    /// who sent the original request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request this user received from `user_id`.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1796
1797impl Contact {
1798    pub fn user_id(&self) -> UserId {
1799        match self {
1800            Contact::Accepted { user_id, .. } => *user_id,
1801            Contact::Outgoing { user_id } => *user_id,
1802            Contact::Incoming { user_id, .. } => *user_id,
1803        }
1804    }
1805}
1806
/// A contact request received from `requester_id`.
///
/// NOTE(review): not constructed in this section; presumably `should_notify`
/// mirrors the `contacts.should_notify` column — confirm with callers.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1812
/// Deserializable sign-up record: contact email, platform flags, and the
/// free-form interest lists submitted with it.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1823
/// Aggregate waitlist counts, total and per platform.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the query row
/// decodes as 0.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1837
/// An email address paired with its email confirmation code.
///
/// Derives `FromRow`, so it can be decoded directly from a query result row.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    /// The invited email address.
    pub email_address: String,
    /// The confirmation code for this address (presumably produced by
    /// `random_email_confirmation_code` — confirm).
    pub email_confirmation_code: String,
}
1843
/// Parameters for creating a new user from a GitHub identity.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// The user's GitHub login name.
    pub github_login: String,
    /// The user's numeric GitHub id.
    pub github_user_id: i32,
    /// Presumably the number of invites granted to the new user — confirm.
    pub invite_count: i32,
}
1850
/// The result of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    /// The id of the newly created user.
    pub user_id: UserId,
    /// The metrics id associated with the new user.
    pub metrics_id: String,
    /// The user who invited the new user, if any.
    pub inviting_user_id: Option<UserId>,
    /// The device id recorded with the user's signup, if any.
    pub signup_device_id: Option<String>,
}
1858
1859fn random_invite_code() -> String {
1860    nanoid::nanoid!(16)
1861}
1862
1863fn random_email_confirmation_code() -> String {
1864    nanoid::nanoid!(64)
1865}
1866
1867#[cfg(test)]
1868pub use test::*;
1869
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A uniquely named, migrated in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        /// The migrated database; set by [`SqliteTestDb::new`].
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and owned by this value for its
        /// whole lifetime (presumably so the in-memory database survives pooled
        /// connections being closed — confirm).
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A uniquely named, migrated Postgres database on `localhost` for tests.
    /// It is handed to `Db::teardown` when this value is dropped.
    pub struct PostgresTestDb {
        /// The migrated database; taken out and torn down in `Drop`.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// The database's connection URL, passed to `Db::teardown` on drop.
        pub url: String,
    }

    /// Builds the current-thread tokio runtime, with I/O and timers enabled,
    /// that a test database drives its queries on.
    fn new_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    /// Draws a random number from an entropy-seeded RNG; used to give every
    /// test database a unique name.
    fn random_suffix() -> u128 {
        StdRng::from_entropy().gen::<u128>()
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database and runs
        /// `Db::migrate` against `$CARGO_MANIFEST_DIR/migrations.sqlite`.
        /// It also attaches `background`, so queries can simulate random delays.
        ///
        /// # Panics
        ///
        /// Panics if the runtime cannot be built, or if connecting, migrating,
        /// or acquiring the held connection fails.
        pub fn new(background: Arc<Background>) -> Self {
            let url = format!("file:zed-test-{}?mode=memory", random_suffix());
            let runtime = new_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let fresh = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_dir =
                    Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
                fresh.migrate(migrations_dir, false).await.unwrap();
                let held_conn = fresh.pool.acquire().await.unwrap().detach();
                (fresh, held_conn)
            });

            // The runtime can only move into the `Db` once `block_on` has
            // released its borrow of it.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken out.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on `localhost` and runs
        /// `Db::migrate` against `$CARGO_MANIFEST_DIR/migrations`. It also
        /// attaches `background`, so queries can simulate random delays.
        ///
        /// A process-wide lock serializes calls to this constructor.
        ///
        /// # Panics
        ///
        /// Panics if the runtime cannot be built, or if creating, connecting
        /// to, or migrating the database fails.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so only one test database is
            // created and migrated at a time.
            let _guard = LOCK.lock();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                random_suffix()
            );
            let runtime = new_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let fresh = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_dir = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
                fresh.migrate(migrations_dir, false).await.unwrap();
                fresh
            });

            // The runtime can only move into the `Db` once `block_on` has
            // released its borrow of it.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken out.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database out and passes it, with its URL, to `Db::teardown`.
        fn drop(&mut self) {
            // `db` is set in `new` and only ever taken here, so it is present.
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}
1970}