db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::{BTreeMap, HashMap, HashSet};
   5use futures::{future::BoxFuture, FutureExt, StreamExt};
   6use rpc::{proto, ConnectionId};
   7use serde::{Deserialize, Serialize};
   8use sqlx::{
   9    migrate::{Migrate as _, Migration, MigrationSource},
  10    types::Uuid,
  11    FromRow,
  12};
  13use std::{
  14    future::Future,
  15    path::{Path, PathBuf},
  16    time::Duration,
  17};
  18use time::{OffsetDateTime, PrimitiveDateTime};
  19
  20#[cfg(test)]
  21pub type DefaultDb = Db<sqlx::Sqlite>;
  22
  23#[cfg(not(test))]
  24pub type DefaultDb = Db<sqlx::Postgres>;
  25
/// Wrapper around a sqlx connection pool for backend `D`.
///
/// In `cfg(test)` builds it can additionally hold a gpui background executor and a
/// tokio runtime; the Postgres `teardown` helper drives its cleanup with
/// `runtime.block_on`. Both `new` constructors initialize these to `None`.
// NOTE(review): fields drop in declaration order (`pool` before `runtime`); keep this
// order unless it is confirmed that pool shutdown does not need the runtime.
pub struct Db<D: sqlx::Database> {
    /// Pool that every query and transaction in this module is issued against.
    pool: sqlx::Pool<D>,
    /// Test-only executor handle; presumably populated by test harness code outside
    /// this view — TODO confirm.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only runtime used by `teardown` to run async cleanup synchronously.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  33
  34pub trait BeginTransaction: Send + Sync {
  35    type Database: sqlx::Database;
  36
  37    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
  38}
  39
  40// In Postgres, serializable transactions are opt-in
  41impl BeginTransaction for Db<sqlx::Postgres> {
  42    type Database = sqlx::Postgres;
  43
  44    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
  45        async move {
  46            let mut tx = self.pool.begin().await?;
  47            sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
  48                .await?;
  49            Ok(tx)
  50        }
  51        .boxed()
  52    }
  53}
  54
  55// In Sqlite, transactions are inherently serializable.
  56impl BeginTransaction for Db<sqlx::Sqlite> {
  57    type Database = sqlx::Sqlite;
  58
  59    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
  60        async move { Ok(self.pool.begin().await?) }.boxed()
  61    }
  62}
  63
  64pub trait RowsAffected {
  65    fn rows_affected(&self) -> u64;
  66}
  67
  68#[cfg(test)]
  69impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  70    fn rows_affected(&self) -> u64 {
  71        self.rows_affected()
  72    }
  73}
  74
  75impl RowsAffected for sqlx::postgres::PgQueryResult {
  76    fn rows_affected(&self) -> u64 {
  77        self.rows_affected()
  78    }
  79}
  80
  81#[cfg(test)]
  82impl Db<sqlx::Sqlite> {
  83    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  84        use std::str::FromStr as _;
  85        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  86            .unwrap()
  87            .create_if_missing(true)
  88            .shared_cache(true);
  89        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  90            .min_connections(2)
  91            .max_connections(max_connections)
  92            .connect_with(options)
  93            .await?;
  94        Ok(Self {
  95            pool,
  96            background: None,
  97            runtime: None,
  98        })
  99    }
 100
 101    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 102        self.transact(|tx| async {
 103            let mut tx = tx;
 104            let query = "
 105                SELECT users.*
 106                FROM users
 107                WHERE users.id IN (SELECT value from json_each($1))
 108            ";
 109            Ok(sqlx::query_as(query)
 110                .bind(&serde_json::json!(ids))
 111                .fetch_all(&mut tx)
 112                .await?)
 113        })
 114        .await
 115    }
 116
 117    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 118        self.transact(|mut tx| async move {
 119            let query = "
 120                SELECT metrics_id
 121                FROM users
 122                WHERE id = $1
 123            ";
 124            Ok(sqlx::query_scalar(query)
 125                .bind(id)
 126                .fetch_one(&mut tx)
 127                .await?)
 128        })
 129        .await
 130    }
 131
 132    pub async fn create_user(
 133        &self,
 134        email_address: &str,
 135        admin: bool,
 136        params: NewUserParams,
 137    ) -> Result<NewUserResult> {
 138        self.transact(|mut tx| async {
 139            let query = "
 140                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 141                VALUES ($1, $2, $3, $4, $5)
 142                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 143                RETURNING id, metrics_id
 144            ";
 145
 146            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 147                .bind(email_address)
 148                .bind(&params.github_login)
 149                .bind(&params.github_user_id)
 150                .bind(admin)
 151                .bind(Uuid::new_v4().to_string())
 152                .fetch_one(&mut tx)
 153                .await?;
 154            tx.commit().await?;
 155            Ok(NewUserResult {
 156                user_id,
 157                metrics_id,
 158                signup_device_id: None,
 159                inviting_user_id: None,
 160            })
 161        })
 162        .await
 163    }
 164
 165    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 166        unimplemented!()
 167    }
 168
 169    pub async fn create_user_from_invite(
 170        &self,
 171        _invite: &Invite,
 172        _user: NewUserParams,
 173    ) -> Result<Option<NewUserResult>> {
 174        unimplemented!()
 175    }
 176
 177    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 178        unimplemented!()
 179    }
 180
 181    pub async fn create_invite_from_code(
 182        &self,
 183        _code: &str,
 184        _email_address: &str,
 185        _device_id: Option<&str>,
 186    ) -> Result<Invite> {
 187        unimplemented!()
 188    }
 189
 190    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 191        unimplemented!()
 192    }
 193}
 194
 195impl Db<sqlx::Postgres> {
 196    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 197        let pool = sqlx::postgres::PgPoolOptions::new()
 198            .max_connections(max_connections)
 199            .connect(url)
 200            .await?;
 201        Ok(Self {
 202            pool,
 203            #[cfg(test)]
 204            background: None,
 205            #[cfg(test)]
 206            runtime: None,
 207        })
 208    }
 209
 210    #[cfg(test)]
 211    pub fn teardown(&self, url: &str) {
 212        self.runtime.as_ref().unwrap().block_on(async {
 213            use util::ResultExt;
 214            let query = "
 215                SELECT pg_terminate_backend(pg_stat_activity.pid)
 216                FROM pg_stat_activity
 217                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 218            ";
 219            sqlx::query(query).execute(&self.pool).await.log_err();
 220            self.pool.close().await;
 221            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 222                .await
 223                .log_err();
 224        })
 225    }
 226
 227    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 228        self.transact(|tx| async {
 229            let mut tx = tx;
 230            let like_string = Self::fuzzy_like_string(name_query);
 231            let query = "
 232                SELECT users.*
 233                FROM users
 234                WHERE github_login ILIKE $1
 235                ORDER BY github_login <-> $2
 236                LIMIT $3
 237            ";
 238            Ok(sqlx::query_as(query)
 239                .bind(like_string)
 240                .bind(name_query)
 241                .bind(limit as i32)
 242                .fetch_all(&mut tx)
 243                .await?)
 244        })
 245        .await
 246    }
 247
 248    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 249        let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
 250        self.transact(|tx| async {
 251            let mut tx = tx;
 252            let query = "
 253                SELECT users.*
 254                FROM users
 255                WHERE users.id = ANY ($1)
 256            ";
 257            Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
 258        })
 259        .await
 260    }
 261
 262    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 263        self.transact(|mut tx| async move {
 264            let query = "
 265                SELECT metrics_id::text
 266                FROM users
 267                WHERE id = $1
 268            ";
 269            Ok(sqlx::query_scalar(query)
 270                .bind(id)
 271                .fetch_one(&mut tx)
 272                .await?)
 273        })
 274        .await
 275    }
 276
 277    pub async fn create_user(
 278        &self,
 279        email_address: &str,
 280        admin: bool,
 281        params: NewUserParams,
 282    ) -> Result<NewUserResult> {
 283        self.transact(|mut tx| async {
 284            let query = "
 285                INSERT INTO users (email_address, github_login, github_user_id, admin)
 286                VALUES ($1, $2, $3, $4)
 287                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 288                RETURNING id, metrics_id::text
 289            ";
 290
 291            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 292                .bind(email_address)
 293                .bind(&params.github_login)
 294                .bind(params.github_user_id)
 295                .bind(admin)
 296                .fetch_one(&mut tx)
 297                .await?;
 298            tx.commit().await?;
 299
 300            Ok(NewUserResult {
 301                user_id,
 302                metrics_id,
 303                signup_device_id: None,
 304                inviting_user_id: None,
 305            })
 306        })
 307        .await
 308    }
 309
 310    pub async fn create_user_from_invite(
 311        &self,
 312        invite: &Invite,
 313        user: NewUserParams,
 314    ) -> Result<Option<NewUserResult>> {
 315        self.transact(|mut tx| async {
 316            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 317                i32,
 318                Option<UserId>,
 319                Option<UserId>,
 320                Option<String>,
 321            ) = sqlx::query_as(
 322                "
 323                SELECT id, user_id, inviting_user_id, device_id
 324                FROM signups
 325                WHERE
 326                    email_address = $1 AND
 327                    email_confirmation_code = $2
 328                ",
 329            )
 330            .bind(&invite.email_address)
 331            .bind(&invite.email_confirmation_code)
 332            .fetch_optional(&mut tx)
 333            .await?
 334            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 335
 336            if existing_user_id.is_some() {
 337                return Ok(None);
 338            }
 339
 340            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 341                "
 342                INSERT INTO users
 343                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 344                VALUES
 345                ($1, $2, $3, FALSE, $4, $5)
 346                ON CONFLICT (github_login) DO UPDATE SET
 347                    email_address = excluded.email_address,
 348                    github_user_id = excluded.github_user_id,
 349                    admin = excluded.admin
 350                RETURNING id, metrics_id::text
 351                ",
 352            )
 353            .bind(&invite.email_address)
 354            .bind(&user.github_login)
 355            .bind(&user.github_user_id)
 356            .bind(&user.invite_count)
 357            .bind(random_invite_code())
 358            .fetch_one(&mut tx)
 359            .await?;
 360
 361            sqlx::query(
 362                "
 363                UPDATE signups
 364                SET user_id = $1
 365                WHERE id = $2
 366                ",
 367            )
 368            .bind(&user_id)
 369            .bind(&signup_id)
 370            .execute(&mut tx)
 371            .await?;
 372
 373            if let Some(inviting_user_id) = inviting_user_id {
 374                let id: Option<UserId> = sqlx::query_scalar(
 375                    "
 376                    UPDATE users
 377                    SET invite_count = invite_count - 1
 378                    WHERE id = $1 AND invite_count > 0
 379                    RETURNING id
 380                    ",
 381                )
 382                .bind(&inviting_user_id)
 383                .fetch_optional(&mut tx)
 384                .await?;
 385
 386                if id.is_none() {
 387                    Err(Error::Http(
 388                        StatusCode::UNAUTHORIZED,
 389                        "no invites remaining".to_string(),
 390                    ))?;
 391                }
 392
 393                sqlx::query(
 394                    "
 395                    INSERT INTO contacts
 396                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 397                    VALUES
 398                        ($1, $2, TRUE, TRUE, TRUE)
 399                    ON CONFLICT DO NOTHING
 400                    ",
 401                )
 402                .bind(inviting_user_id)
 403                .bind(user_id)
 404                .execute(&mut tx)
 405                .await?;
 406            }
 407
 408            tx.commit().await?;
 409            Ok(Some(NewUserResult {
 410                user_id,
 411                metrics_id,
 412                inviting_user_id,
 413                signup_device_id,
 414            }))
 415        })
 416        .await
 417    }
 418
 419    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 420        self.transact(|mut tx| async {
 421            sqlx::query(
 422                "
 423                INSERT INTO signups
 424                (
 425                    email_address,
 426                    email_confirmation_code,
 427                    email_confirmation_sent,
 428                    platform_linux,
 429                    platform_mac,
 430                    platform_windows,
 431                    platform_unknown,
 432                    editor_features,
 433                    programming_languages,
 434                    device_id
 435                )
 436                VALUES
 437                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 438                RETURNING id
 439                ",
 440            )
 441            .bind(&signup.email_address)
 442            .bind(&random_email_confirmation_code())
 443            .bind(&signup.platform_linux)
 444            .bind(&signup.platform_mac)
 445            .bind(&signup.platform_windows)
 446            .bind(&signup.editor_features)
 447            .bind(&signup.programming_languages)
 448            .bind(&signup.device_id)
 449            .execute(&mut tx)
 450            .await?;
 451            tx.commit().await?;
 452            Ok(())
 453        })
 454        .await
 455    }
 456
 457    pub async fn create_invite_from_code(
 458        &self,
 459        code: &str,
 460        email_address: &str,
 461        device_id: Option<&str>,
 462    ) -> Result<Invite> {
 463        self.transact(|mut tx| async {
 464            let existing_user: Option<UserId> = sqlx::query_scalar(
 465                "
 466                SELECT id
 467                FROM users
 468                WHERE email_address = $1
 469                ",
 470            )
 471            .bind(email_address)
 472            .fetch_optional(&mut tx)
 473            .await?;
 474            if existing_user.is_some() {
 475                Err(anyhow!("email address is already in use"))?;
 476            }
 477
 478            let row: Option<(UserId, i32)> = sqlx::query_as(
 479                "
 480                SELECT id, invite_count
 481                FROM users
 482                WHERE invite_code = $1
 483                ",
 484            )
 485            .bind(code)
 486            .fetch_optional(&mut tx)
 487            .await?;
 488
 489            let (inviter_id, invite_count) = match row {
 490                Some(row) => row,
 491                None => Err(Error::Http(
 492                    StatusCode::NOT_FOUND,
 493                    "invite code not found".to_string(),
 494                ))?,
 495            };
 496
 497            if invite_count == 0 {
 498                Err(Error::Http(
 499                    StatusCode::UNAUTHORIZED,
 500                    "no invites remaining".to_string(),
 501                ))?;
 502            }
 503
 504            let email_confirmation_code: String = sqlx::query_scalar(
 505                "
 506                INSERT INTO signups
 507                (
 508                    email_address,
 509                    email_confirmation_code,
 510                    email_confirmation_sent,
 511                    inviting_user_id,
 512                    platform_linux,
 513                    platform_mac,
 514                    platform_windows,
 515                    platform_unknown,
 516                    device_id
 517                )
 518                VALUES
 519                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 520                ON CONFLICT (email_address)
 521                DO UPDATE SET
 522                    inviting_user_id = excluded.inviting_user_id
 523                RETURNING email_confirmation_code
 524                ",
 525            )
 526            .bind(&email_address)
 527            .bind(&random_email_confirmation_code())
 528            .bind(&inviter_id)
 529            .bind(&device_id)
 530            .fetch_one(&mut tx)
 531            .await?;
 532
 533            tx.commit().await?;
 534
 535            Ok(Invite {
 536                email_address: email_address.into(),
 537                email_confirmation_code,
 538            })
 539        })
 540        .await
 541    }
 542
 543    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 544        self.transact(|mut tx| async {
 545            let emails = invites
 546                .iter()
 547                .map(|s| s.email_address.as_str())
 548                .collect::<Vec<_>>();
 549            sqlx::query(
 550                "
 551                UPDATE signups
 552                SET email_confirmation_sent = TRUE
 553                WHERE email_address = ANY ($1)
 554                ",
 555            )
 556            .bind(&emails)
 557            .execute(&mut tx)
 558            .await?;
 559            tx.commit().await?;
 560            Ok(())
 561        })
 562        .await
 563    }
 564}
 565
 566impl<D> Db<D>
 567where
 568    Self: BeginTransaction<Database = D>,
 569    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 570    D::Connection: sqlx::migrate::Migrate,
 571    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 572    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 573    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 574    D::QueryResult: RowsAffected,
 575    String: sqlx::Type<D>,
 576    i32: sqlx::Type<D>,
 577    i64: sqlx::Type<D>,
 578    bool: sqlx::Type<D>,
 579    str: sqlx::Type<D>,
 580    Uuid: sqlx::Type<D>,
 581    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 582    OffsetDateTime: sqlx::Type<D>,
 583    PrimitiveDateTime: sqlx::Type<D>,
 584    usize: sqlx::ColumnIndex<D::Row>,
 585    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 586    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 587    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 588    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 589    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 590    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 591    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 592    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 593    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 594    for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 595    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 596    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 597    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 598{
 599    pub async fn migrate(
 600        &self,
 601        migrations_path: &Path,
 602        ignore_checksum_mismatch: bool,
 603    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 604        let migrations = MigrationSource::resolve(migrations_path)
 605            .await
 606            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 607
 608        let mut conn = self.pool.acquire().await?;
 609
 610        conn.ensure_migrations_table().await?;
 611        let applied_migrations: HashMap<_, _> = conn
 612            .list_applied_migrations()
 613            .await?
 614            .into_iter()
 615            .map(|m| (m.version, m))
 616            .collect();
 617
 618        let mut new_migrations = Vec::new();
 619        for migration in migrations {
 620            match applied_migrations.get(&migration.version) {
 621                Some(applied_migration) => {
 622                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 623                    {
 624                        Err(anyhow!(
 625                            "checksum mismatch for applied migration {}",
 626                            migration.description
 627                        ))?;
 628                    }
 629                }
 630                None => {
 631                    let elapsed = conn.apply(&migration).await?;
 632                    new_migrations.push((migration, elapsed));
 633                }
 634            }
 635        }
 636
 637        Ok(new_migrations)
 638    }
 639
 640    pub fn fuzzy_like_string(string: &str) -> String {
 641        let mut result = String::with_capacity(string.len() * 2 + 1);
 642        for c in string.chars() {
 643            if c.is_alphanumeric() {
 644                result.push('%');
 645                result.push(c);
 646            }
 647        }
 648        result.push('%');
 649        result
 650    }
 651
 652    // users
 653
 654    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 655        self.transact(|tx| async {
 656            let mut tx = tx;
 657            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 658            Ok(sqlx::query_as(query)
 659                .bind(limit as i32)
 660                .bind((page * limit) as i32)
 661                .fetch_all(&mut tx)
 662                .await?)
 663        })
 664        .await
 665    }
 666
 667    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 668        self.transact(|tx| async {
 669            let mut tx = tx;
 670            let query = "
 671                SELECT users.*
 672                FROM users
 673                WHERE id = $1
 674                LIMIT 1
 675            ";
 676            Ok(sqlx::query_as(query)
 677                .bind(&id)
 678                .fetch_optional(&mut tx)
 679                .await?)
 680        })
 681        .await
 682    }
 683
 684    pub async fn get_users_with_no_invites(
 685        &self,
 686        invited_by_another_user: bool,
 687    ) -> Result<Vec<User>> {
 688        self.transact(|tx| async {
 689            let mut tx = tx;
 690            let query = format!(
 691                "
 692                SELECT users.*
 693                FROM users
 694                WHERE invite_count = 0
 695                AND inviter_id IS{} NULL
 696                ",
 697                if invited_by_another_user { " NOT" } else { "" }
 698            );
 699
 700            Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
 701        })
 702        .await
 703    }
 704
 705    pub async fn get_user_by_github_account(
 706        &self,
 707        github_login: &str,
 708        github_user_id: Option<i32>,
 709    ) -> Result<Option<User>> {
 710        self.transact(|tx| async {
 711            let mut tx = tx;
 712            if let Some(github_user_id) = github_user_id {
 713                let mut user = sqlx::query_as::<_, User>(
 714                    "
 715                    UPDATE users
 716                    SET github_login = $1
 717                    WHERE github_user_id = $2
 718                    RETURNING *
 719                    ",
 720                )
 721                .bind(github_login)
 722                .bind(github_user_id)
 723                .fetch_optional(&mut tx)
 724                .await?;
 725
 726                if user.is_none() {
 727                    user = sqlx::query_as::<_, User>(
 728                        "
 729                        UPDATE users
 730                        SET github_user_id = $1
 731                        WHERE github_login = $2
 732                        RETURNING *
 733                        ",
 734                    )
 735                    .bind(github_user_id)
 736                    .bind(github_login)
 737                    .fetch_optional(&mut tx)
 738                    .await?;
 739                }
 740
 741                Ok(user)
 742            } else {
 743                let user = sqlx::query_as(
 744                    "
 745                    SELECT * FROM users
 746                    WHERE github_login = $1
 747                    LIMIT 1
 748                    ",
 749                )
 750                .bind(github_login)
 751                .fetch_optional(&mut tx)
 752                .await?;
 753                Ok(user)
 754            }
 755        })
 756        .await
 757    }
 758
 759    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 760        self.transact(|mut tx| async {
 761            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 762            sqlx::query(query)
 763                .bind(is_admin)
 764                .bind(id.0)
 765                .execute(&mut tx)
 766                .await?;
 767            tx.commit().await?;
 768            Ok(())
 769        })
 770        .await
 771    }
 772
 773    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 774        self.transact(|mut tx| async move {
 775            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 776            sqlx::query(query)
 777                .bind(connected_once)
 778                .bind(id.0)
 779                .execute(&mut tx)
 780                .await?;
 781            tx.commit().await?;
 782            Ok(())
 783        })
 784        .await
 785    }
 786
 787    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 788        self.transact(|mut tx| async move {
 789            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 790            sqlx::query(query)
 791                .bind(id.0)
 792                .execute(&mut tx)
 793                .await
 794                .map(drop)?;
 795            let query = "DELETE FROM users WHERE id = $1;";
 796            sqlx::query(query).bind(id.0).execute(&mut tx).await?;
 797            tx.commit().await?;
 798            Ok(())
 799        })
 800        .await
 801    }
 802
 803    // signups
 804
 805    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 806        self.transact(|mut tx| async move {
 807            Ok(sqlx::query_as(
 808                "
 809                SELECT
 810                    COUNT(*) as count,
 811                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 812                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 813                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 814                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 815                FROM (
 816                    SELECT *
 817                    FROM signups
 818                    WHERE
 819                        NOT email_confirmation_sent
 820                ) AS unsent
 821                ",
 822            )
 823            .fetch_one(&mut tx)
 824            .await?)
 825        })
 826        .await
 827    }
 828
 829    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 830        self.transact(|mut tx| async move {
 831            Ok(sqlx::query_as(
 832                "
 833                SELECT
 834                    email_address, email_confirmation_code
 835                FROM signups
 836                WHERE
 837                    NOT email_confirmation_sent AND
 838                    (platform_mac OR platform_unknown)
 839                LIMIT $1
 840                ",
 841            )
 842            .bind(count as i32)
 843            .fetch_all(&mut tx)
 844            .await?)
 845        })
 846        .await
 847    }
 848
 849    // invite codes
 850
 851    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 852        self.transact(|mut tx| async move {
 853            if count > 0 {
 854                sqlx::query(
 855                    "
 856                    UPDATE users
 857                    SET invite_code = $1
 858                    WHERE id = $2 AND invite_code IS NULL
 859                ",
 860                )
 861                .bind(random_invite_code())
 862                .bind(id)
 863                .execute(&mut tx)
 864                .await?;
 865            }
 866
 867            sqlx::query(
 868                "
 869                UPDATE users
 870                SET invite_count = $1
 871                WHERE id = $2
 872                ",
 873            )
 874            .bind(count as i32)
 875            .bind(id)
 876            .execute(&mut tx)
 877            .await?;
 878            tx.commit().await?;
 879            Ok(())
 880        })
 881        .await
 882    }
 883
 884    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 885        self.transact(|mut tx| async move {
 886            let result: Option<(String, i32)> = sqlx::query_as(
 887                "
 888                    SELECT invite_code, invite_count
 889                    FROM users
 890                    WHERE id = $1 AND invite_code IS NOT NULL 
 891                ",
 892            )
 893            .bind(id)
 894            .fetch_optional(&mut tx)
 895            .await?;
 896            if let Some((code, count)) = result {
 897                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 898            } else {
 899                Ok(None)
 900            }
 901        })
 902        .await
 903    }
 904
 905    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 906        self.transact(|tx| async {
 907            let mut tx = tx;
 908            sqlx::query_as(
 909                "
 910                    SELECT *
 911                    FROM users
 912                    WHERE invite_code = $1
 913                ",
 914            )
 915            .bind(code)
 916            .fetch_optional(&mut tx)
 917            .await?
 918            .ok_or_else(|| {
 919                Error::Http(
 920                    StatusCode::NOT_FOUND,
 921                    "that invite code does not exist".to_string(),
 922                )
 923            })
 924        })
 925        .await
 926    }
 927
 928    pub async fn create_room(
 929        &self,
 930        user_id: UserId,
 931        connection_id: ConnectionId,
 932    ) -> Result<proto::Room> {
 933        self.transact(|mut tx| async move {
 934            let live_kit_room = nanoid::nanoid!(30);
 935            let room_id = sqlx::query_scalar(
 936                "
 937                INSERT INTO rooms (live_kit_room, version)
 938                VALUES ($1, $2)
 939                RETURNING id
 940                ",
 941            )
 942            .bind(&live_kit_room)
 943            .bind(0)
 944            .fetch_one(&mut tx)
 945            .await
 946            .map(RoomId)?;
 947
 948            sqlx::query(
 949                "
 950                INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
 951                VALUES ($1, $2, $3, $4, $5)
 952                ",
 953            )
 954            .bind(room_id)
 955            .bind(user_id)
 956            .bind(connection_id.0 as i32)
 957            .bind(user_id)
 958            .bind(connection_id.0 as i32)
 959            .execute(&mut tx)
 960            .await?;
 961
 962            self.commit_room_transaction(room_id, tx).await
 963        }).await
 964    }
 965
 966    pub async fn call(
 967        &self,
 968        room_id: RoomId,
 969        calling_user_id: UserId,
 970        calling_connection_id: ConnectionId,
 971        called_user_id: UserId,
 972        initial_project_id: Option<ProjectId>,
 973    ) -> Result<(proto::Room, proto::IncomingCall)> {
 974        self.transact(|mut tx| async move {
 975            sqlx::query(
 976                "
 977                INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
 978                VALUES ($1, $2, $3, $4, $5)
 979                ",
 980            )
 981            .bind(room_id)
 982            .bind(called_user_id)
 983            .bind(calling_user_id)
 984            .bind(calling_connection_id.0 as i32)
 985            .bind(initial_project_id)
 986            .execute(&mut tx)
 987            .await?;
 988
 989            let room = self.commit_room_transaction(room_id, tx).await?;
 990            let incoming_call = Self::build_incoming_call(&room, called_user_id)
 991                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
 992            Ok((room, incoming_call))
 993        }).await
 994    }
 995
 996    pub async fn incoming_call_for_user(
 997        &self,
 998        user_id: UserId,
 999    ) -> Result<Option<proto::IncomingCall>> {
1000        self.transact(|mut tx| async move {
1001            let room_id = sqlx::query_scalar::<_, RoomId>(
1002                "
1003                SELECT room_id
1004                FROM room_participants
1005                WHERE user_id = $1 AND answering_connection_id IS NULL
1006                ",
1007            )
1008            .bind(user_id)
1009            .fetch_optional(&mut tx)
1010            .await?;
1011
1012            if let Some(room_id) = room_id {
1013                let room = self.get_room(room_id, &mut tx).await?;
1014                Ok(Self::build_incoming_call(&room, user_id))
1015            } else {
1016                Ok(None)
1017            }
1018        })
1019        .await
1020    }
1021
1022    fn build_incoming_call(
1023        room: &proto::Room,
1024        called_user_id: UserId,
1025    ) -> Option<proto::IncomingCall> {
1026        let pending_participant = room
1027            .pending_participants
1028            .iter()
1029            .find(|participant| participant.user_id == called_user_id.to_proto())?;
1030
1031        Some(proto::IncomingCall {
1032            room_id: room.id,
1033            calling_user_id: pending_participant.calling_user_id,
1034            participant_user_ids: room
1035                .participants
1036                .iter()
1037                .map(|participant| participant.user_id)
1038                .collect(),
1039            initial_project: room.participants.iter().find_map(|participant| {
1040                let initial_project_id = pending_participant.initial_project_id?;
1041                participant
1042                    .projects
1043                    .iter()
1044                    .find(|project| project.id == initial_project_id)
1045                    .cloned()
1046            }),
1047        })
1048    }
1049
1050    pub async fn call_failed(
1051        &self,
1052        room_id: RoomId,
1053        called_user_id: UserId,
1054    ) -> Result<proto::Room> {
1055        self.transact(|mut tx| async move {
1056            sqlx::query(
1057                "
1058                DELETE FROM room_participants
1059                WHERE room_id = $1 AND user_id = $2
1060                ",
1061            )
1062            .bind(room_id)
1063            .bind(called_user_id)
1064            .execute(&mut tx)
1065            .await?;
1066
1067            self.commit_room_transaction(room_id, tx).await
1068        })
1069        .await
1070    }
1071
1072    pub async fn decline_call(
1073        &self,
1074        expected_room_id: Option<RoomId>,
1075        user_id: UserId,
1076    ) -> Result<proto::Room> {
1077        self.transact(|mut tx| async move {
1078            let room_id = sqlx::query_scalar(
1079                "
1080                DELETE FROM room_participants
1081                WHERE user_id = $1 AND answering_connection_id IS NULL
1082                RETURNING room_id
1083                ",
1084            )
1085            .bind(user_id)
1086            .fetch_one(&mut tx)
1087            .await?;
1088            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1089                return Err(anyhow!("declining call on unexpected room"))?;
1090            }
1091
1092            self.commit_room_transaction(room_id, tx).await
1093        })
1094        .await
1095    }
1096
1097    pub async fn cancel_call(
1098        &self,
1099        expected_room_id: Option<RoomId>,
1100        calling_connection_id: ConnectionId,
1101        called_user_id: UserId,
1102    ) -> Result<proto::Room> {
1103        self.transact(|mut tx| async move {
1104            let room_id = sqlx::query_scalar(
1105                "
1106                DELETE FROM room_participants
1107                WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1108                RETURNING room_id
1109                ",
1110            )
1111            .bind(called_user_id)
1112            .bind(calling_connection_id.0 as i32)
1113            .fetch_one(&mut tx)
1114            .await?;
1115            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1116                return Err(anyhow!("canceling call on unexpected room"))?;
1117            }
1118
1119            self.commit_room_transaction(room_id, tx).await
1120        }).await
1121    }
1122
1123    pub async fn join_room(
1124        &self,
1125        room_id: RoomId,
1126        user_id: UserId,
1127        connection_id: ConnectionId,
1128    ) -> Result<proto::Room> {
1129        self.transact(|mut tx| async move {
1130            sqlx::query(
1131                "
1132                UPDATE room_participants 
1133                SET answering_connection_id = $1
1134                WHERE room_id = $2 AND user_id = $3
1135                RETURNING 1
1136                ",
1137            )
1138            .bind(connection_id.0 as i32)
1139            .bind(room_id)
1140            .bind(user_id)
1141            .fetch_one(&mut tx)
1142            .await?;
1143            self.commit_room_transaction(room_id, tx).await
1144        })
1145        .await
1146    }
1147
1148    pub async fn leave_room_for_connection(
1149        &self,
1150        connection_id: ConnectionId,
1151    ) -> Result<Option<LeftRoom>> {
1152        self.transact(|mut tx| async move {
1153            // Leave room.
1154            let room_id = sqlx::query_scalar::<_, RoomId>(
1155                "
1156                DELETE FROM room_participants
1157                WHERE answering_connection_id = $1
1158                RETURNING room_id
1159                ",
1160            )
1161            .bind(connection_id.0 as i32)
1162            .fetch_optional(&mut tx)
1163            .await?;
1164
1165            if let Some(room_id) = room_id {
1166                // Cancel pending calls initiated by the leaving user.
1167                let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1168                    "
1169                    DELETE FROM room_participants
1170                    WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1171                    RETURNING user_id
1172                    ",
1173                )
1174                .bind(connection_id.0 as i32)
1175                .fetch_all(&mut tx)
1176                .await?;
1177
1178                let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
1179                    "
1180                    SELECT project_collaborators.*
1181                    FROM projects, project_collaborators
1182                    WHERE
1183                        projects.room_id = $1 AND
1184                        projects.id = project_collaborators.project_id AND
1185                        project_collaborators.connection_id = $2
1186                    ",
1187                )
1188                .bind(room_id)
1189                .bind(connection_id.0 as i32)
1190                .fetch(&mut tx);
1191
1192                let mut left_projects = HashMap::default();
1193                while let Some(collaborator) = project_collaborators.next().await {
1194                    let collaborator = collaborator?;
1195                    let left_project =
1196                        left_projects
1197                            .entry(collaborator.project_id)
1198                            .or_insert(LeftProject {
1199                                id: collaborator.project_id,
1200                                host_user_id: Default::default(),
1201                                connection_ids: Default::default(),
1202                            });
1203
1204                    let collaborator_connection_id =
1205                        ConnectionId(collaborator.connection_id as u32);
1206                    if collaborator_connection_id != connection_id || collaborator.is_host {
1207                        left_project.connection_ids.push(collaborator_connection_id);
1208                    }
1209
1210                    if collaborator.is_host {
1211                        left_project.host_user_id = collaborator.user_id;
1212                    }
1213                }
1214                drop(project_collaborators);
1215
1216                sqlx::query(
1217                    "
1218                    DELETE FROM projects
1219                    WHERE room_id = $1 AND host_connection_id = $2
1220                    ",
1221                )
1222                .bind(room_id)
1223                .bind(connection_id.0 as i32)
1224                .execute(&mut tx)
1225                .await?;
1226
1227                let room = self.commit_room_transaction(room_id, tx).await?;
1228                Ok(Some(LeftRoom {
1229                    room,
1230                    left_projects,
1231                    canceled_calls_to_user_ids,
1232                }))
1233            } else {
1234                Ok(None)
1235            }
1236        })
1237        .await
1238    }
1239
1240    pub async fn update_room_participant_location(
1241        &self,
1242        room_id: RoomId,
1243        connection_id: ConnectionId,
1244        location: proto::ParticipantLocation,
1245    ) -> Result<proto::Room> {
1246        self.transact(|tx| async {
1247            let mut tx = tx;
1248            let location_kind;
1249            let location_project_id;
1250            match location
1251                .variant
1252                .as_ref()
1253                .ok_or_else(|| anyhow!("invalid location"))?
1254            {
1255                proto::participant_location::Variant::SharedProject(project) => {
1256                    location_kind = 0;
1257                    location_project_id = Some(ProjectId::from_proto(project.id));
1258                }
1259                proto::participant_location::Variant::UnsharedProject(_) => {
1260                    location_kind = 1;
1261                    location_project_id = None;
1262                }
1263                proto::participant_location::Variant::External(_) => {
1264                    location_kind = 2;
1265                    location_project_id = None;
1266                }
1267            }
1268
1269            sqlx::query(
1270                "
1271                UPDATE room_participants
1272                SET location_kind = $1 AND location_project_id = $2
1273                WHERE room_id = $3 AND answering_connection_id = $4
1274                ",
1275            )
1276            .bind(location_kind)
1277            .bind(location_project_id)
1278            .bind(room_id)
1279            .bind(connection_id.0 as i32)
1280            .execute(&mut tx)
1281            .await?;
1282
1283            self.commit_room_transaction(room_id, tx).await
1284        })
1285        .await
1286    }
1287
1288    async fn commit_room_transaction(
1289        &self,
1290        room_id: RoomId,
1291        mut tx: sqlx::Transaction<'_, D>,
1292    ) -> Result<proto::Room> {
1293        sqlx::query(
1294            "
1295            UPDATE rooms
1296            SET version = version + 1
1297            WHERE id = $1
1298            ",
1299        )
1300        .bind(room_id)
1301        .execute(&mut tx)
1302        .await?;
1303        let room = self.get_room(room_id, &mut tx).await?;
1304        tx.commit().await?;
1305
1306        Ok(room)
1307    }
1308
1309    async fn get_room(
1310        &self,
1311        room_id: RoomId,
1312        tx: &mut sqlx::Transaction<'_, D>,
1313    ) -> Result<proto::Room> {
1314        let room: Room = sqlx::query_as(
1315            "
1316            SELECT *
1317            FROM rooms
1318            WHERE id = $1
1319            ",
1320        )
1321        .bind(room_id)
1322        .fetch_one(&mut *tx)
1323        .await?;
1324
1325        let mut db_participants =
1326            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1327                "
1328                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1329                FROM room_participants
1330                WHERE room_id = $1
1331                ",
1332            )
1333            .bind(room_id)
1334            .fetch(&mut *tx);
1335
1336        let mut participants = Vec::new();
1337        let mut pending_participants = Vec::new();
1338        while let Some(participant) = db_participants.next().await {
1339            let (
1340                user_id,
1341                answering_connection_id,
1342                _location_kind,
1343                _location_project_id,
1344                calling_user_id,
1345                initial_project_id,
1346            ) = participant?;
1347            if let Some(answering_connection_id) = answering_connection_id {
1348                participants.push(proto::Participant {
1349                    user_id: user_id.to_proto(),
1350                    peer_id: answering_connection_id as u32,
1351                    projects: Default::default(),
1352                    location: Some(proto::ParticipantLocation {
1353                        variant: Some(proto::participant_location::Variant::External(
1354                            Default::default(),
1355                        )),
1356                    }),
1357                });
1358            } else {
1359                pending_participants.push(proto::PendingParticipant {
1360                    user_id: user_id.to_proto(),
1361                    calling_user_id: calling_user_id.to_proto(),
1362                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
1363                });
1364            }
1365        }
1366        drop(db_participants);
1367
1368        for participant in &mut participants {
1369            let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1370                "
1371                SELECT projects.id, worktrees.root_name
1372                FROM projects
1373                LEFT JOIN worktrees ON projects.id = worktrees.project_id
1374                WHERE room_id = $1 AND host_connection_id = $2
1375                ",
1376            )
1377            .bind(room_id)
1378            .bind(participant.peer_id as i32)
1379            .fetch(&mut *tx);
1380
1381            let mut projects = HashMap::default();
1382            while let Some(entry) = entries.next().await {
1383                let (project_id, worktree_root_name) = entry?;
1384                let participant_project =
1385                    projects
1386                        .entry(project_id)
1387                        .or_insert(proto::ParticipantProject {
1388                            id: project_id.to_proto(),
1389                            worktree_root_names: Default::default(),
1390                        });
1391                participant_project
1392                    .worktree_root_names
1393                    .push(worktree_root_name);
1394            }
1395
1396            participant.projects = projects.into_values().collect();
1397        }
1398        Ok(proto::Room {
1399            id: room.id.to_proto(),
1400            version: room.version as u64,
1401            live_kit_room: room.live_kit_room,
1402            participants,
1403            pending_participants,
1404        })
1405    }
1406
1407    // projects
1408
1409    pub async fn share_project(
1410        &self,
1411        expected_room_id: RoomId,
1412        connection_id: ConnectionId,
1413        worktrees: &[proto::WorktreeMetadata],
1414    ) -> Result<(ProjectId, proto::Room)> {
1415        self.transact(|mut tx| async move {
1416            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1417                "
1418                SELECT room_id, user_id
1419                FROM room_participants
1420                WHERE answering_connection_id = $1
1421                ",
1422            )
1423            .bind(connection_id.0 as i32)
1424            .fetch_one(&mut tx)
1425            .await?;
1426            if room_id != expected_room_id {
1427                return Err(anyhow!("shared project on unexpected room"))?;
1428            }
1429
1430            let project_id: ProjectId = sqlx::query_scalar(
1431                "
1432                INSERT INTO projects (room_id, host_user_id, host_connection_id)
1433                VALUES ($1, $2, $3)
1434                RETURNING id
1435                ",
1436            )
1437            .bind(room_id)
1438            .bind(user_id)
1439            .bind(connection_id.0 as i32)
1440            .fetch_one(&mut tx)
1441            .await?;
1442
1443            for worktree in worktrees {
1444                sqlx::query(
1445                    "
1446                    INSERT INTO worktrees (id, project_id, root_name)
1447                    VALUES ($1, $2, $3)
1448                    ",
1449                )
1450                .bind(worktree.id as i32)
1451                .bind(project_id)
1452                .bind(&worktree.root_name)
1453                .execute(&mut tx)
1454                .await?;
1455            }
1456
1457            sqlx::query(
1458                "
1459                INSERT INTO project_collaborators (
1460                    project_id,
1461                    connection_id,
1462                    user_id,
1463                    replica_id,
1464                    is_host
1465                )
1466                VALUES ($1, $2, $3, $4, $5)
1467                ",
1468            )
1469            .bind(project_id)
1470            .bind(connection_id.0 as i32)
1471            .bind(user_id)
1472            .bind(0)
1473            .bind(true)
1474            .execute(&mut tx)
1475            .await?;
1476
1477            let room = self.commit_room_transaction(room_id, tx).await?;
1478            Ok((project_id, room))
1479        })
1480        .await
1481    }
1482
1483    pub async fn update_project(
1484        &self,
1485        project_id: ProjectId,
1486        connection_id: ConnectionId,
1487        worktrees: &[proto::WorktreeMetadata],
1488    ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1489        self.transact(|mut tx| async move {
1490            let room_id: RoomId = sqlx::query_scalar(
1491                "
1492                SELECT room_id
1493                FROM projects
1494                WHERE id = $1 AND host_connection_id = $2
1495                ",
1496            )
1497            .bind(project_id)
1498            .bind(connection_id.0 as i32)
1499            .fetch_one(&mut tx)
1500            .await?;
1501
1502            for worktree in worktrees {
1503                sqlx::query(
1504                    "
1505                    INSERT INTO worktrees (project_id, id, root_name)
1506                    VALUES ($1, $2, $3)
1507                    ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1508                    ",
1509                )
1510                .bind(project_id)
1511                .bind(worktree.id as i32)
1512                .bind(&worktree.root_name)
1513                .execute(&mut tx)
1514                .await?;
1515            }
1516
1517            let mut params = "?,".repeat(worktrees.len());
1518            if !worktrees.is_empty() {
1519                params.pop();
1520            }
1521            let query = format!(
1522                "
1523                DELETE FROM worktrees
1524                WHERE id NOT IN ({params})
1525                ",
1526            );
1527
1528            let mut query = sqlx::query(&query);
1529            for worktree in worktrees {
1530                query = query.bind(worktree.id as i32);
1531            }
1532            query.execute(&mut tx).await?;
1533
1534            let mut guest_connection_ids = Vec::new();
1535            {
1536                let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1537                    "
1538                    SELECT connection_id
1539                    FROM project_collaborators
1540                    WHERE project_id = $1 AND is_host = FALSE
1541                    ",
1542                )
1543                .fetch(&mut tx);
1544                while let Some(connection_id) = db_guest_connection_ids.next().await {
1545                    guest_connection_ids.push(ConnectionId(connection_id? as u32));
1546                }
1547            }
1548
1549            let room = self.commit_room_transaction(room_id, tx).await?;
1550            Ok((room, guest_connection_ids))
1551        })
1552        .await
1553    }
1554
1555    pub async fn join_project(
1556        &self,
1557        project_id: ProjectId,
1558        connection_id: ConnectionId,
1559    ) -> Result<(Project, i32)> {
1560        self.transact(|mut tx| async move {
1561            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1562                "
1563                SELECT room_id, user_id
1564                FROM room_participants
1565                WHERE answering_connection_id = $1
1566                ",
1567            )
1568            .bind(connection_id.0 as i32)
1569            .fetch_one(&mut tx)
1570            .await?;
1571
1572            // Ensure project id was shared on this room.
1573            sqlx::query(
1574                "
1575                SELECT 1
1576                FROM projects
1577                WHERE project_id = $1 AND room_id = $2
1578                ",
1579            )
1580            .bind(project_id)
1581            .bind(room_id)
1582            .fetch_one(&mut tx)
1583            .await?;
1584
1585            let replica_ids = sqlx::query_scalar::<_, i32>(
1586                "
1587                SELECT replica_id
1588                FROM project_collaborators
1589                WHERE project_id = $1
1590                ",
1591            )
1592            .bind(project_id)
1593            .fetch_all(&mut tx)
1594            .await?;
1595            let replica_ids = HashSet::from_iter(replica_ids);
1596            let mut replica_id = 1;
1597            while replica_ids.contains(&replica_id) {
1598                replica_id += 1;
1599            }
1600
1601            sqlx::query(
1602                "
1603                INSERT INTO project_collaborators (
1604                    project_id,
1605                    connection_id,
1606                    user_id,
1607                    replica_id,
1608                    is_host
1609                )
1610                VALUES ($1, $2, $3, $4, $5)
1611                ",
1612            )
1613            .bind(project_id)
1614            .bind(connection_id.0 as i32)
1615            .bind(user_id)
1616            .bind(replica_id)
1617            .bind(false)
1618            .execute(&mut tx)
1619            .await?;
1620
1621            tx.commit().await?;
1622            todo!()
1623        })
1624        .await
1625        // sqlx::query(
1626        //     "
1627        //     SELECT replica_id
1628        //     FROM project_collaborators
1629        //     WHERE project_id = $
1630        //     ",
1631        // )
1632        // .bind(project_id)
1633        // .bind(connection_id.0 as i32)
1634        // .bind(user_id)
1635        // .bind(0)
1636        // .bind(true)
1637        // .execute(&mut tx)
1638        // .await?;
1639        // sqlx::query(
1640        //     "
1641        //     INSERT INTO project_collaborators (
1642        //         project_id,
1643        //         connection_id,
1644        //         user_id,
1645        //         replica_id,
1646        //         is_host
1647        //     )
1648        //     VALUES ($1, $2, $3, $4, $5)
1649        //     ",
1650        // )
1651        // .bind(project_id)
1652        // .bind(connection_id.0 as i32)
1653        // .bind(user_id)
1654        // .bind(0)
1655        // .bind(true)
1656        // .execute(&mut tx)
1657        // .await?;
1658    }
1659
1660    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
1661        todo!()
1662        // test_support!(self, {
1663        //     sqlx::query(
1664        //         "
1665        //         UPDATE projects
1666        //         SET unregistered = TRUE
1667        //         WHERE id = $1
1668        //         ",
1669        //     )
1670        //     .bind(project_id)
1671        //     .execute(&self.pool)
1672        //     .await?;
1673        //     Ok(())
1674        // })
1675    }
1676
1677    // contacts
1678
1679    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1680        self.transact(|mut tx| async move {
1681            let query = "
1682                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
1683                FROM contacts
1684                LEFT JOIN room_participants ON room_participants.user_id = $1
1685                WHERE user_id_a = $1 OR user_id_b = $1;
1686            ";
1687
1688            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
1689                .bind(user_id)
1690                .fetch(&mut tx);
1691
1692            let mut contacts = Vec::new();
1693            while let Some(row) = rows.next().await {
1694                let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
1695                if user_id_a == user_id {
1696                    if accepted {
1697                        contacts.push(Contact::Accepted {
1698                            user_id: user_id_b,
1699                            should_notify: should_notify && a_to_b,
1700                            busy
1701                        });
1702                    } else if a_to_b {
1703                        contacts.push(Contact::Outgoing { user_id: user_id_b })
1704                    } else {
1705                        contacts.push(Contact::Incoming {
1706                            user_id: user_id_b,
1707                            should_notify,
1708                        });
1709                    }
1710                } else if accepted {
1711                    contacts.push(Contact::Accepted {
1712                        user_id: user_id_a,
1713                        should_notify: should_notify && !a_to_b,
1714                        busy
1715                    });
1716                } else if a_to_b {
1717                    contacts.push(Contact::Incoming {
1718                        user_id: user_id_a,
1719                        should_notify,
1720                    });
1721                } else {
1722                    contacts.push(Contact::Outgoing { user_id: user_id_a });
1723                }
1724            }
1725
1726            contacts.sort_unstable_by_key(|contact| contact.user_id());
1727
1728            Ok(contacts)
1729        })
1730        .await
1731    }
1732
1733    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
1734        self.transact(|mut tx| async move {
1735            Ok(sqlx::query_scalar::<_, i32>(
1736                "
1737                SELECT 1
1738                FROM room_participants
1739                WHERE room_participants.user_id = $1
1740                ",
1741            )
1742            .bind(user_id)
1743            .fetch_optional(&mut tx)
1744            .await?
1745            .is_some())
1746        })
1747        .await
1748    }
1749
1750    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1751        self.transact(|mut tx| async move {
1752            let (id_a, id_b) = if user_id_1 < user_id_2 {
1753                (user_id_1, user_id_2)
1754            } else {
1755                (user_id_2, user_id_1)
1756            };
1757
1758            let query = "
1759                SELECT 1 FROM contacts
1760                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1761                LIMIT 1
1762            ";
1763            Ok(sqlx::query_scalar::<_, i32>(query)
1764                .bind(id_a.0)
1765                .bind(id_b.0)
1766                .fetch_optional(&mut tx)
1767                .await?
1768                .is_some())
1769        })
1770        .await
1771    }
1772
1773    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1774        self.transact(|mut tx| async move {
1775            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1776                (sender_id, receiver_id, true)
1777            } else {
1778                (receiver_id, sender_id, false)
1779            };
1780            let query = "
1781                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1782                VALUES ($1, $2, $3, FALSE, TRUE)
1783                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1784                SET
1785                    accepted = TRUE,
1786                    should_notify = FALSE
1787                WHERE
1788                    NOT contacts.accepted AND
1789                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1790                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1791            ";
1792            let result = sqlx::query(query)
1793                .bind(id_a.0)
1794                .bind(id_b.0)
1795                .bind(a_to_b)
1796                .execute(&mut tx)
1797                .await?;
1798
1799            if result.rows_affected() == 1 {
1800                tx.commit().await?;
1801                Ok(())
1802            } else {
1803                Err(anyhow!("contact already requested"))?
1804            }
1805        }).await
1806    }
1807
1808    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1809        self.transact(|mut tx| async move {
1810            let (id_a, id_b) = if responder_id < requester_id {
1811                (responder_id, requester_id)
1812            } else {
1813                (requester_id, responder_id)
1814            };
1815            let query = "
1816                DELETE FROM contacts
1817                WHERE user_id_a = $1 AND user_id_b = $2;
1818            ";
1819            let result = sqlx::query(query)
1820                .bind(id_a.0)
1821                .bind(id_b.0)
1822                .execute(&mut tx)
1823                .await?;
1824
1825            if result.rows_affected() == 1 {
1826                tx.commit().await?;
1827                Ok(())
1828            } else {
1829                Err(anyhow!("no such contact"))?
1830            }
1831        })
1832        .await
1833    }
1834
1835    pub async fn dismiss_contact_notification(
1836        &self,
1837        user_id: UserId,
1838        contact_user_id: UserId,
1839    ) -> Result<()> {
1840        self.transact(|mut tx| async move {
1841            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1842                (user_id, contact_user_id, true)
1843            } else {
1844                (contact_user_id, user_id, false)
1845            };
1846
1847            let query = "
1848                UPDATE contacts
1849                SET should_notify = FALSE
1850                WHERE
1851                    user_id_a = $1 AND user_id_b = $2 AND
1852                    (
1853                        (a_to_b = $3 AND accepted) OR
1854                        (a_to_b != $3 AND NOT accepted)
1855                    );
1856            ";
1857
1858            let result = sqlx::query(query)
1859                .bind(id_a.0)
1860                .bind(id_b.0)
1861                .bind(a_to_b)
1862                .execute(&mut tx)
1863                .await?;
1864
1865            if result.rows_affected() == 0 {
1866                Err(anyhow!("no such contact request"))?
1867            } else {
1868                tx.commit().await?;
1869                Ok(())
1870            }
1871        })
1872        .await
1873    }
1874
1875    pub async fn respond_to_contact_request(
1876        &self,
1877        responder_id: UserId,
1878        requester_id: UserId,
1879        accept: bool,
1880    ) -> Result<()> {
1881        self.transact(|mut tx| async move {
1882            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1883                (responder_id, requester_id, false)
1884            } else {
1885                (requester_id, responder_id, true)
1886            };
1887            let result = if accept {
1888                let query = "
1889                    UPDATE contacts
1890                    SET accepted = TRUE, should_notify = TRUE
1891                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1892                ";
1893                sqlx::query(query)
1894                    .bind(id_a.0)
1895                    .bind(id_b.0)
1896                    .bind(a_to_b)
1897                    .execute(&mut tx)
1898                    .await?
1899            } else {
1900                let query = "
1901                    DELETE FROM contacts
1902                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1903                ";
1904                sqlx::query(query)
1905                    .bind(id_a.0)
1906                    .bind(id_b.0)
1907                    .bind(a_to_b)
1908                    .execute(&mut tx)
1909                    .await?
1910            };
1911            if result.rows_affected() == 1 {
1912                tx.commit().await?;
1913                Ok(())
1914            } else {
1915                Err(anyhow!("no such contact request"))?
1916            }
1917        })
1918        .await
1919    }
1920
1921    // access tokens
1922
1923    pub async fn create_access_token_hash(
1924        &self,
1925        user_id: UserId,
1926        access_token_hash: &str,
1927        max_access_token_count: usize,
1928    ) -> Result<()> {
1929        self.transact(|tx| async {
1930            let mut tx = tx;
1931            let insert_query = "
1932                INSERT INTO access_tokens (user_id, hash)
1933                VALUES ($1, $2);
1934            ";
1935            let cleanup_query = "
1936                DELETE FROM access_tokens
1937                WHERE id IN (
1938                    SELECT id from access_tokens
1939                    WHERE user_id = $1
1940                    ORDER BY id DESC
1941                    LIMIT 10000
1942                    OFFSET $3
1943                )
1944            ";
1945
1946            sqlx::query(insert_query)
1947                .bind(user_id.0)
1948                .bind(access_token_hash)
1949                .execute(&mut tx)
1950                .await?;
1951            sqlx::query(cleanup_query)
1952                .bind(user_id.0)
1953                .bind(access_token_hash)
1954                .bind(max_access_token_count as i32)
1955                .execute(&mut tx)
1956                .await?;
1957            Ok(tx.commit().await?)
1958        })
1959        .await
1960    }
1961
1962    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1963        self.transact(|mut tx| async move {
1964            let query = "
1965                SELECT hash
1966                FROM access_tokens
1967                WHERE user_id = $1
1968                ORDER BY id DESC
1969            ";
1970            Ok(sqlx::query_scalar(query)
1971                .bind(user_id.0)
1972                .fetch_all(&mut tx)
1973                .await?)
1974        })
1975        .await
1976    }
1977
1978    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
1979    where
1980        F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
1981        Fut: Send + Future<Output = Result<T>>,
1982    {
1983        let body = async {
1984            loop {
1985                let tx = self.begin_transaction().await?;
1986                match f(tx).await {
1987                    Ok(result) => return Ok(result),
1988                    Err(error) => match error {
1989                        Error::Database(error)
1990                            if error
1991                                .as_database_error()
1992                                .and_then(|error| error.code())
1993                                .as_deref()
1994                                == Some("hey") =>
1995                        {
1996                            // Retry (don't break the loop)
1997                        }
1998                        error @ _ => return Err(error),
1999                    },
2000                }
2001            }
2002        };
2003
2004        #[cfg(test)]
2005        {
2006            if let Some(background) = self.background.as_ref() {
2007                background.simulate_random_delay().await;
2008            }
2009
2010            let result = self.runtime.as_ref().unwrap().block_on(body);
2011
2012            if let Some(background) = self.background.as_ref() {
2013                background.simulate_random_delay().await;
2014            }
2015
2016            result
2017        }
2018
2019        #[cfg(not(test))]
2020        {
2021            body.await
2022        }
2023    }
2024}
2025
/// Defines `pub struct $name(pub i32)`, a newtype id over `i32`.
///
/// The generated type derives `Copy`, equality, ordering, `Hash`, and `Default`. It is
/// `#[sqlx(transparent)]` / `#[serde(transparent)]`, so it is stored and serialized exactly
/// like the inner `i32`. It also gets a `MAX` constant, `from_proto`/`to_proto` conversions
/// to and from `u64`, and a `Display` impl that forwards to the inner value.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts a protocol (`u64`) id into this id type.
            // NOTE(review): `as i32` silently truncates values above `i32::MAX`. Confirm
            // protocol ids always fit in an i32.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts this id into its protocol (`u64`) form.
            // NOTE(review): a negative inner value would sign-extend to a huge `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Displays the id as its bare integer value (formatter flags are forwarded too).
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
2068
// `UserId`: `i32` newtype id generated by `id_type!`.
id_type!(UserId);
/// A user record, decodable from query rows via `FromRow` and serializable with serde.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Optional GitHub numeric id.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2081
// `RoomId`: `i32` newtype id generated by `id_type!`.
id_type!(RoomId);
/// A room record, decodable from query rows via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// NOTE(review): presumably bumped whenever the room changes. Confirm against the
    /// room-update queries.
    pub version: i32,
    /// Name of the associated LiveKit room (inferred from the field name).
    pub live_kit_room: String,
}
2089
// `ProjectId`: `i32` newtype id generated by `id_type!`.
id_type!(ProjectId);
/// An in-memory view of a shared project: its collaborators, worktrees and language servers.
pub struct Project {
    pub id: ProjectId,
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed by a `u64` id (presumably the worktree id; confirm).
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2097
/// One participant in a project, decodable from query rows via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Connection id stored as a raw `i32`, unlike the `rpc::ConnectionId` used elsewhere.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: i32,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2106
/// Snapshot of one worktree of a shared project.
#[derive(Default)]
pub struct Worktree {
    pub abs_path: PathBuf,
    pub root_name: String,
    pub visible: bool,
    /// File-system entries keyed by a `u64` id (presumably the entry id; confirm).
    pub entries: BTreeMap<u64, proto::Entry>,
    /// Diagnostic summaries keyed by path.
    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
    pub scan_id: u64,
    /// NOTE(review): presumably true once the worktree's initial scan has been fully
    /// received. Confirm.
    pub is_complete: bool,
}
2117
/// Describes a project that a connection has left.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// NOTE(review): presumably the connections that remain in the project and should be
    /// told about the departure. Confirm with the code that builds this value.
    pub connection_ids: Vec<ConnectionId>,
}
2123
/// Result of leaving a room: the room's protocol state and what the departure affected.
pub struct LeftRoom {
    pub room: proto::Room,
    /// Projects that were left as part of leaving the room, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls were canceled by the departure.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2129
/// A contact relationship from one user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        /// NOTE(review): presumably whether the contact is currently in a room. Confirm
        /// how it is populated.
        busy: bool,
    },
    /// A request this user sent that has not been accepted yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that has not been accepted yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2145
2146impl Contact {
2147    pub fn user_id(&self) -> UserId {
2148        match self {
2149            Contact::Accepted { user_id, .. } => *user_id,
2150            Contact::Outgoing { user_id } => *user_id,
2151            Contact::Incoming { user_id, .. } => *user_id,
2152        }
2153    }
2154}
2155
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2161
/// Waitlist signup details, deserialized from incoming data.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
2172
/// Aggregate waitlist counts, overall and per platform.
///
/// Every field is `#[sqlx(default)]`, so a query may omit any of these columns. Missing
/// columns decode as `0`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2186
/// An invite addressed to an email, with the code used to confirm that email.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2192
/// Parameters for creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Number of invites the new user starts with.
    pub invite_count: i32,
}
2199
/// Outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this signup, if any.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
2207
/// Generates a random 16-character invite code using `nanoid`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}

/// Generates a random 64-character email confirmation code using `nanoid`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
2215
2216#[cfg(test)]
2217pub use test::*;
2218
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A uniquely named in-memory SQLite database for tests, migrated with
    /// `migrations.sqlite`.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // A pool connection detached and held for the lifetime of this value.
        // NOTE(review): presumably this keeps the in-memory database alive. Confirm against
        // how `Db::<sqlx::Sqlite>::new` configures the pool.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A freshly created, uniquely named Postgres database for tests. It is torn down via
    /// `Db::teardown` on drop.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    /// Builds the single-threaded Tokio runtime (IO + time enabled) that a test database
    /// uses to drive its queries.
    fn test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        pub fn new(background: Arc<Background>) -> Self {
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Only one PostgresTestDb is constructed at a time. The guard lives until this
            // constructor returns.
            let _guard = LOCK.lock();
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("postgres://postgres@localhost/zed-test-{}", suffix);
            let runtime = test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}