db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::{BTreeMap, HashMap, HashSet};
   5use futures::{future::BoxFuture, FutureExt, StreamExt};
   6use rpc::{proto, ConnectionId};
   7use serde::{Deserialize, Serialize};
   8use sqlx::{
   9    migrate::{Migrate as _, Migration, MigrationSource},
  10    types::Uuid,
  11    FromRow,
  12};
  13use std::{future::Future, path::Path, time::Duration};
  14use time::{OffsetDateTime, PrimitiveDateTime};
  15
/// Database type used by the crate under `cfg(test)`: an SQLite-backed [`Db`].
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database type used by the crate outside of tests: a Postgres-backed [`Db`].
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
  21
/// Handle to the collaboration database, wrapping a `sqlx` connection pool for backend `D`.
pub struct Db<D: sqlx::Database> {
    /// Pool from which connections and transactions are acquired.
    pool: sqlx::Pool<D>,
    // Test-only. Both constructors visible here initialize it to `None`; presumably
    // populated by test setup elsewhere — TODO confirm its role.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only Tokio runtime; `teardown` uses it to drive async cleanup synchronously.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  29
/// Backend-specific way of opening a transaction, so generic code on `Db<D>` can start
/// one without knowing the concrete database.
///
/// Both implementations in this file aim for serializable isolation.
pub trait BeginTransaction: Send + Sync {
    /// The `sqlx` database backend the transactions belong to.
    type Database: sqlx::Database;

    /// Begins a new transaction on the underlying pool.
    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
}
  35
  36// In Postgres, serializable transactions are opt-in
  37impl BeginTransaction for Db<sqlx::Postgres> {
  38    type Database = sqlx::Postgres;
  39
  40    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
  41        async move {
  42            let mut tx = self.pool.begin().await?;
  43            sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
  44                .await?;
  45            Ok(tx)
  46        }
  47        .boxed()
  48    }
  49}
  50
  51// In Sqlite, transactions are inherently serializable.
  52impl BeginTransaction for Db<sqlx::Sqlite> {
  53    type Database = sqlx::Sqlite;
  54
  55    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
  56        async move { Ok(self.pool.begin().await?) }.boxed()
  57    }
  58}
  59
/// Backend-agnostic access to the number of rows touched by a query, so generic code can
/// bound on `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
  63
/// Test-only: SQLite is the backend only under `cfg(test)`.
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Resolves to the inherent `SqliteQueryResult::rows_affected` (inherent methods win
        // over trait methods during method lookup), so this forwards rather than recursing.
        self.rows_affected()
    }
}
  70
impl RowsAffected for sqlx::postgres::PgQueryResult {
    fn rows_affected(&self) -> u64 {
        // Resolves to the inherent `PgQueryResult::rows_affected` (inherent methods win over
        // trait methods during method lookup), so this forwards rather than recursing.
        self.rows_affected()
    }
}
  76
  77#[cfg(test)]
  78impl Db<sqlx::Sqlite> {
  79    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
  80        use std::str::FromStr as _;
  81        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
  82            .unwrap()
  83            .create_if_missing(true)
  84            .shared_cache(true);
  85        let pool = sqlx::sqlite::SqlitePoolOptions::new()
  86            .min_connections(2)
  87            .max_connections(max_connections)
  88            .connect_with(options)
  89            .await?;
  90        Ok(Self {
  91            pool,
  92            background: None,
  93            runtime: None,
  94        })
  95    }
  96
  97    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
  98        self.transact(|tx| async {
  99            let mut tx = tx;
 100            let query = "
 101                SELECT users.*
 102                FROM users
 103                WHERE users.id IN (SELECT value from json_each($1))
 104            ";
 105            Ok(sqlx::query_as(query)
 106                .bind(&serde_json::json!(ids))
 107                .fetch_all(&mut tx)
 108                .await?)
 109        })
 110        .await
 111    }
 112
 113    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 114        self.transact(|mut tx| async move {
 115            let query = "
 116                SELECT metrics_id
 117                FROM users
 118                WHERE id = $1
 119            ";
 120            Ok(sqlx::query_scalar(query)
 121                .bind(id)
 122                .fetch_one(&mut tx)
 123                .await?)
 124        })
 125        .await
 126    }
 127
 128    pub async fn create_user(
 129        &self,
 130        email_address: &str,
 131        admin: bool,
 132        params: NewUserParams,
 133    ) -> Result<NewUserResult> {
 134        self.transact(|mut tx| async {
 135            let query = "
 136                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 137                VALUES ($1, $2, $3, $4, $5)
 138                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 139                RETURNING id, metrics_id
 140            ";
 141
 142            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 143                .bind(email_address)
 144                .bind(&params.github_login)
 145                .bind(&params.github_user_id)
 146                .bind(admin)
 147                .bind(Uuid::new_v4().to_string())
 148                .fetch_one(&mut tx)
 149                .await?;
 150            tx.commit().await?;
 151            Ok(NewUserResult {
 152                user_id,
 153                metrics_id,
 154                signup_device_id: None,
 155                inviting_user_id: None,
 156            })
 157        })
 158        .await
 159    }
 160
 161    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 162        unimplemented!()
 163    }
 164
 165    pub async fn create_user_from_invite(
 166        &self,
 167        _invite: &Invite,
 168        _user: NewUserParams,
 169    ) -> Result<Option<NewUserResult>> {
 170        unimplemented!()
 171    }
 172
 173    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 174        unimplemented!()
 175    }
 176
 177    pub async fn create_invite_from_code(
 178        &self,
 179        _code: &str,
 180        _email_address: &str,
 181        _device_id: Option<&str>,
 182    ) -> Result<Invite> {
 183        unimplemented!()
 184    }
 185
 186    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 187        unimplemented!()
 188    }
 189}
 190
 191impl Db<sqlx::Postgres> {
 192    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 193        let pool = sqlx::postgres::PgPoolOptions::new()
 194            .max_connections(max_connections)
 195            .connect(url)
 196            .await?;
 197        Ok(Self {
 198            pool,
 199            #[cfg(test)]
 200            background: None,
 201            #[cfg(test)]
 202            runtime: None,
 203        })
 204    }
 205
 206    #[cfg(test)]
 207    pub fn teardown(&self, url: &str) {
 208        self.runtime.as_ref().unwrap().block_on(async {
 209            use util::ResultExt;
 210            let query = "
 211                SELECT pg_terminate_backend(pg_stat_activity.pid)
 212                FROM pg_stat_activity
 213                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 214            ";
 215            sqlx::query(query).execute(&self.pool).await.log_err();
 216            self.pool.close().await;
 217            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 218                .await
 219                .log_err();
 220        })
 221    }
 222
 223    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 224        self.transact(|tx| async {
 225            let mut tx = tx;
 226            let like_string = Self::fuzzy_like_string(name_query);
 227            let query = "
 228                SELECT users.*
 229                FROM users
 230                WHERE github_login ILIKE $1
 231                ORDER BY github_login <-> $2
 232                LIMIT $3
 233            ";
 234            Ok(sqlx::query_as(query)
 235                .bind(like_string)
 236                .bind(name_query)
 237                .bind(limit as i32)
 238                .fetch_all(&mut tx)
 239                .await?)
 240        })
 241        .await
 242    }
 243
 244    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 245        let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
 246        self.transact(|tx| async {
 247            let mut tx = tx;
 248            let query = "
 249                SELECT users.*
 250                FROM users
 251                WHERE users.id = ANY ($1)
 252            ";
 253            Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
 254        })
 255        .await
 256    }
 257
 258    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 259        self.transact(|mut tx| async move {
 260            let query = "
 261                SELECT metrics_id::text
 262                FROM users
 263                WHERE id = $1
 264            ";
 265            Ok(sqlx::query_scalar(query)
 266                .bind(id)
 267                .fetch_one(&mut tx)
 268                .await?)
 269        })
 270        .await
 271    }
 272
 273    pub async fn create_user(
 274        &self,
 275        email_address: &str,
 276        admin: bool,
 277        params: NewUserParams,
 278    ) -> Result<NewUserResult> {
 279        self.transact(|mut tx| async {
 280            let query = "
 281                INSERT INTO users (email_address, github_login, github_user_id, admin)
 282                VALUES ($1, $2, $3, $4)
 283                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 284                RETURNING id, metrics_id::text
 285            ";
 286
 287            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 288                .bind(email_address)
 289                .bind(&params.github_login)
 290                .bind(params.github_user_id)
 291                .bind(admin)
 292                .fetch_one(&mut tx)
 293                .await?;
 294            tx.commit().await?;
 295
 296            Ok(NewUserResult {
 297                user_id,
 298                metrics_id,
 299                signup_device_id: None,
 300                inviting_user_id: None,
 301            })
 302        })
 303        .await
 304    }
 305
 306    pub async fn create_user_from_invite(
 307        &self,
 308        invite: &Invite,
 309        user: NewUserParams,
 310    ) -> Result<Option<NewUserResult>> {
 311        self.transact(|mut tx| async {
 312            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 313                i32,
 314                Option<UserId>,
 315                Option<UserId>,
 316                Option<String>,
 317            ) = sqlx::query_as(
 318                "
 319                SELECT id, user_id, inviting_user_id, device_id
 320                FROM signups
 321                WHERE
 322                    email_address = $1 AND
 323                    email_confirmation_code = $2
 324                ",
 325            )
 326            .bind(&invite.email_address)
 327            .bind(&invite.email_confirmation_code)
 328            .fetch_optional(&mut tx)
 329            .await?
 330            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 331
 332            if existing_user_id.is_some() {
 333                return Ok(None);
 334            }
 335
 336            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 337                "
 338                INSERT INTO users
 339                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 340                VALUES
 341                ($1, $2, $3, FALSE, $4, $5)
 342                ON CONFLICT (github_login) DO UPDATE SET
 343                    email_address = excluded.email_address,
 344                    github_user_id = excluded.github_user_id,
 345                    admin = excluded.admin
 346                RETURNING id, metrics_id::text
 347                ",
 348            )
 349            .bind(&invite.email_address)
 350            .bind(&user.github_login)
 351            .bind(&user.github_user_id)
 352            .bind(&user.invite_count)
 353            .bind(random_invite_code())
 354            .fetch_one(&mut tx)
 355            .await?;
 356
 357            sqlx::query(
 358                "
 359                UPDATE signups
 360                SET user_id = $1
 361                WHERE id = $2
 362                ",
 363            )
 364            .bind(&user_id)
 365            .bind(&signup_id)
 366            .execute(&mut tx)
 367            .await?;
 368
 369            if let Some(inviting_user_id) = inviting_user_id {
 370                let id: Option<UserId> = sqlx::query_scalar(
 371                    "
 372                    UPDATE users
 373                    SET invite_count = invite_count - 1
 374                    WHERE id = $1 AND invite_count > 0
 375                    RETURNING id
 376                    ",
 377                )
 378                .bind(&inviting_user_id)
 379                .fetch_optional(&mut tx)
 380                .await?;
 381
 382                if id.is_none() {
 383                    Err(Error::Http(
 384                        StatusCode::UNAUTHORIZED,
 385                        "no invites remaining".to_string(),
 386                    ))?;
 387                }
 388
 389                sqlx::query(
 390                    "
 391                    INSERT INTO contacts
 392                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 393                    VALUES
 394                        ($1, $2, TRUE, TRUE, TRUE)
 395                    ON CONFLICT DO NOTHING
 396                    ",
 397                )
 398                .bind(inviting_user_id)
 399                .bind(user_id)
 400                .execute(&mut tx)
 401                .await?;
 402            }
 403
 404            tx.commit().await?;
 405            Ok(Some(NewUserResult {
 406                user_id,
 407                metrics_id,
 408                inviting_user_id,
 409                signup_device_id,
 410            }))
 411        })
 412        .await
 413    }
 414
 415    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 416        self.transact(|mut tx| async {
 417            sqlx::query(
 418                "
 419                INSERT INTO signups
 420                (
 421                    email_address,
 422                    email_confirmation_code,
 423                    email_confirmation_sent,
 424                    platform_linux,
 425                    platform_mac,
 426                    platform_windows,
 427                    platform_unknown,
 428                    editor_features,
 429                    programming_languages,
 430                    device_id
 431                )
 432                VALUES
 433                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 434                RETURNING id
 435                ",
 436            )
 437            .bind(&signup.email_address)
 438            .bind(&random_email_confirmation_code())
 439            .bind(&signup.platform_linux)
 440            .bind(&signup.platform_mac)
 441            .bind(&signup.platform_windows)
 442            .bind(&signup.editor_features)
 443            .bind(&signup.programming_languages)
 444            .bind(&signup.device_id)
 445            .execute(&mut tx)
 446            .await?;
 447            tx.commit().await?;
 448            Ok(())
 449        })
 450        .await
 451    }
 452
 453    pub async fn create_invite_from_code(
 454        &self,
 455        code: &str,
 456        email_address: &str,
 457        device_id: Option<&str>,
 458    ) -> Result<Invite> {
 459        self.transact(|mut tx| async {
 460            let existing_user: Option<UserId> = sqlx::query_scalar(
 461                "
 462                SELECT id
 463                FROM users
 464                WHERE email_address = $1
 465                ",
 466            )
 467            .bind(email_address)
 468            .fetch_optional(&mut tx)
 469            .await?;
 470            if existing_user.is_some() {
 471                Err(anyhow!("email address is already in use"))?;
 472            }
 473
 474            let row: Option<(UserId, i32)> = sqlx::query_as(
 475                "
 476                SELECT id, invite_count
 477                FROM users
 478                WHERE invite_code = $1
 479                ",
 480            )
 481            .bind(code)
 482            .fetch_optional(&mut tx)
 483            .await?;
 484
 485            let (inviter_id, invite_count) = match row {
 486                Some(row) => row,
 487                None => Err(Error::Http(
 488                    StatusCode::NOT_FOUND,
 489                    "invite code not found".to_string(),
 490                ))?,
 491            };
 492
 493            if invite_count == 0 {
 494                Err(Error::Http(
 495                    StatusCode::UNAUTHORIZED,
 496                    "no invites remaining".to_string(),
 497                ))?;
 498            }
 499
 500            let email_confirmation_code: String = sqlx::query_scalar(
 501                "
 502                INSERT INTO signups
 503                (
 504                    email_address,
 505                    email_confirmation_code,
 506                    email_confirmation_sent,
 507                    inviting_user_id,
 508                    platform_linux,
 509                    platform_mac,
 510                    platform_windows,
 511                    platform_unknown,
 512                    device_id
 513                )
 514                VALUES
 515                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 516                ON CONFLICT (email_address)
 517                DO UPDATE SET
 518                    inviting_user_id = excluded.inviting_user_id
 519                RETURNING email_confirmation_code
 520                ",
 521            )
 522            .bind(&email_address)
 523            .bind(&random_email_confirmation_code())
 524            .bind(&inviter_id)
 525            .bind(&device_id)
 526            .fetch_one(&mut tx)
 527            .await?;
 528
 529            tx.commit().await?;
 530
 531            Ok(Invite {
 532                email_address: email_address.into(),
 533                email_confirmation_code,
 534            })
 535        })
 536        .await
 537    }
 538
 539    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 540        self.transact(|mut tx| async {
 541            let emails = invites
 542                .iter()
 543                .map(|s| s.email_address.as_str())
 544                .collect::<Vec<_>>();
 545            sqlx::query(
 546                "
 547                UPDATE signups
 548                SET email_confirmation_sent = TRUE
 549                WHERE email_address = ANY ($1)
 550                ",
 551            )
 552            .bind(&emails)
 553            .execute(&mut tx)
 554            .await?;
 555            tx.commit().await?;
 556            Ok(())
 557        })
 558        .await
 559    }
 560}
 561
 562impl<D> Db<D>
 563where
 564    Self: BeginTransaction<Database = D>,
 565    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 566    D::Connection: sqlx::migrate::Migrate,
 567    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 568    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 569    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 570    D::QueryResult: RowsAffected,
 571    String: sqlx::Type<D>,
 572    i32: sqlx::Type<D>,
 573    i64: sqlx::Type<D>,
 574    bool: sqlx::Type<D>,
 575    str: sqlx::Type<D>,
 576    Uuid: sqlx::Type<D>,
 577    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 578    OffsetDateTime: sqlx::Type<D>,
 579    PrimitiveDateTime: sqlx::Type<D>,
 580    usize: sqlx::ColumnIndex<D::Row>,
 581    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 582    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 583    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 584    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 585    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 586    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 587    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 588    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 589    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 590    for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 591    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 592    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 593    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 594{
 595    pub async fn migrate(
 596        &self,
 597        migrations_path: &Path,
 598        ignore_checksum_mismatch: bool,
 599    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 600        let migrations = MigrationSource::resolve(migrations_path)
 601            .await
 602            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 603
 604        let mut conn = self.pool.acquire().await?;
 605
 606        conn.ensure_migrations_table().await?;
 607        let applied_migrations: HashMap<_, _> = conn
 608            .list_applied_migrations()
 609            .await?
 610            .into_iter()
 611            .map(|m| (m.version, m))
 612            .collect();
 613
 614        let mut new_migrations = Vec::new();
 615        for migration in migrations {
 616            match applied_migrations.get(&migration.version) {
 617                Some(applied_migration) => {
 618                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 619                    {
 620                        Err(anyhow!(
 621                            "checksum mismatch for applied migration {}",
 622                            migration.description
 623                        ))?;
 624                    }
 625                }
 626                None => {
 627                    let elapsed = conn.apply(&migration).await?;
 628                    new_migrations.push((migration, elapsed));
 629                }
 630            }
 631        }
 632
 633        Ok(new_migrations)
 634    }
 635
 636    pub fn fuzzy_like_string(string: &str) -> String {
 637        let mut result = String::with_capacity(string.len() * 2 + 1);
 638        for c in string.chars() {
 639            if c.is_alphanumeric() {
 640                result.push('%');
 641                result.push(c);
 642            }
 643        }
 644        result.push('%');
 645        result
 646    }
 647
 648    // users
 649
 650    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 651        self.transact(|tx| async {
 652            let mut tx = tx;
 653            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 654            Ok(sqlx::query_as(query)
 655                .bind(limit as i32)
 656                .bind((page * limit) as i32)
 657                .fetch_all(&mut tx)
 658                .await?)
 659        })
 660        .await
 661    }
 662
 663    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 664        self.transact(|tx| async {
 665            let mut tx = tx;
 666            let query = "
 667                SELECT users.*
 668                FROM users
 669                WHERE id = $1
 670                LIMIT 1
 671            ";
 672            Ok(sqlx::query_as(query)
 673                .bind(&id)
 674                .fetch_optional(&mut tx)
 675                .await?)
 676        })
 677        .await
 678    }
 679
 680    pub async fn get_users_with_no_invites(
 681        &self,
 682        invited_by_another_user: bool,
 683    ) -> Result<Vec<User>> {
 684        self.transact(|tx| async {
 685            let mut tx = tx;
 686            let query = format!(
 687                "
 688                SELECT users.*
 689                FROM users
 690                WHERE invite_count = 0
 691                AND inviter_id IS{} NULL
 692                ",
 693                if invited_by_another_user { " NOT" } else { "" }
 694            );
 695
 696            Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
 697        })
 698        .await
 699    }
 700
 701    pub async fn get_user_by_github_account(
 702        &self,
 703        github_login: &str,
 704        github_user_id: Option<i32>,
 705    ) -> Result<Option<User>> {
 706        self.transact(|tx| async {
 707            let mut tx = tx;
 708            if let Some(github_user_id) = github_user_id {
 709                let mut user = sqlx::query_as::<_, User>(
 710                    "
 711                    UPDATE users
 712                    SET github_login = $1
 713                    WHERE github_user_id = $2
 714                    RETURNING *
 715                    ",
 716                )
 717                .bind(github_login)
 718                .bind(github_user_id)
 719                .fetch_optional(&mut tx)
 720                .await?;
 721
 722                if user.is_none() {
 723                    user = sqlx::query_as::<_, User>(
 724                        "
 725                        UPDATE users
 726                        SET github_user_id = $1
 727                        WHERE github_login = $2
 728                        RETURNING *
 729                        ",
 730                    )
 731                    .bind(github_user_id)
 732                    .bind(github_login)
 733                    .fetch_optional(&mut tx)
 734                    .await?;
 735                }
 736
 737                Ok(user)
 738            } else {
 739                let user = sqlx::query_as(
 740                    "
 741                    SELECT * FROM users
 742                    WHERE github_login = $1
 743                    LIMIT 1
 744                    ",
 745                )
 746                .bind(github_login)
 747                .fetch_optional(&mut tx)
 748                .await?;
 749                Ok(user)
 750            }
 751        })
 752        .await
 753    }
 754
 755    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 756        self.transact(|mut tx| async {
 757            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 758            sqlx::query(query)
 759                .bind(is_admin)
 760                .bind(id.0)
 761                .execute(&mut tx)
 762                .await?;
 763            tx.commit().await?;
 764            Ok(())
 765        })
 766        .await
 767    }
 768
 769    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 770        self.transact(|mut tx| async move {
 771            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 772            sqlx::query(query)
 773                .bind(connected_once)
 774                .bind(id.0)
 775                .execute(&mut tx)
 776                .await?;
 777            tx.commit().await?;
 778            Ok(())
 779        })
 780        .await
 781    }
 782
 783    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 784        self.transact(|mut tx| async move {
 785            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 786            sqlx::query(query)
 787                .bind(id.0)
 788                .execute(&mut tx)
 789                .await
 790                .map(drop)?;
 791            let query = "DELETE FROM users WHERE id = $1;";
 792            sqlx::query(query).bind(id.0).execute(&mut tx).await?;
 793            tx.commit().await?;
 794            Ok(())
 795        })
 796        .await
 797    }
 798
 799    // signups
 800
 801    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 802        self.transact(|mut tx| async move {
 803            Ok(sqlx::query_as(
 804                "
 805                SELECT
 806                    COUNT(*) as count,
 807                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 808                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 809                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 810                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 811                FROM (
 812                    SELECT *
 813                    FROM signups
 814                    WHERE
 815                        NOT email_confirmation_sent
 816                ) AS unsent
 817                ",
 818            )
 819            .fetch_one(&mut tx)
 820            .await?)
 821        })
 822        .await
 823    }
 824
 825    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 826        self.transact(|mut tx| async move {
 827            Ok(sqlx::query_as(
 828                "
 829                SELECT
 830                    email_address, email_confirmation_code
 831                FROM signups
 832                WHERE
 833                    NOT email_confirmation_sent AND
 834                    (platform_mac OR platform_unknown)
 835                LIMIT $1
 836                ",
 837            )
 838            .bind(count as i32)
 839            .fetch_all(&mut tx)
 840            .await?)
 841        })
 842        .await
 843    }
 844
 845    // invite codes
 846
 847    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 848        self.transact(|mut tx| async move {
 849            if count > 0 {
 850                sqlx::query(
 851                    "
 852                    UPDATE users
 853                    SET invite_code = $1
 854                    WHERE id = $2 AND invite_code IS NULL
 855                ",
 856                )
 857                .bind(random_invite_code())
 858                .bind(id)
 859                .execute(&mut tx)
 860                .await?;
 861            }
 862
 863            sqlx::query(
 864                "
 865                UPDATE users
 866                SET invite_count = $1
 867                WHERE id = $2
 868                ",
 869            )
 870            .bind(count as i32)
 871            .bind(id)
 872            .execute(&mut tx)
 873            .await?;
 874            tx.commit().await?;
 875            Ok(())
 876        })
 877        .await
 878    }
 879
 880    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 881        self.transact(|mut tx| async move {
 882            let result: Option<(String, i32)> = sqlx::query_as(
 883                "
 884                    SELECT invite_code, invite_count
 885                    FROM users
 886                    WHERE id = $1 AND invite_code IS NOT NULL 
 887                ",
 888            )
 889            .bind(id)
 890            .fetch_optional(&mut tx)
 891            .await?;
 892            if let Some((code, count)) = result {
 893                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 894            } else {
 895                Ok(None)
 896            }
 897        })
 898        .await
 899    }
 900
 901    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 902        self.transact(|tx| async {
 903            let mut tx = tx;
 904            sqlx::query_as(
 905                "
 906                    SELECT *
 907                    FROM users
 908                    WHERE invite_code = $1
 909                ",
 910            )
 911            .bind(code)
 912            .fetch_optional(&mut tx)
 913            .await?
 914            .ok_or_else(|| {
 915                Error::Http(
 916                    StatusCode::NOT_FOUND,
 917                    "that invite code does not exist".to_string(),
 918                )
 919            })
 920        })
 921        .await
 922    }
 923
 924    pub async fn create_room(
 925        &self,
 926        user_id: UserId,
 927        connection_id: ConnectionId,
 928    ) -> Result<proto::Room> {
 929        self.transact(|mut tx| async move {
 930            let live_kit_room = nanoid::nanoid!(30);
 931            let room_id = sqlx::query_scalar(
 932                "
 933                INSERT INTO rooms (live_kit_room, version)
 934                VALUES ($1, $2)
 935                RETURNING id
 936                ",
 937            )
 938            .bind(&live_kit_room)
 939            .bind(0)
 940            .fetch_one(&mut tx)
 941            .await
 942            .map(RoomId)?;
 943
 944            sqlx::query(
 945                "
 946                INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
 947                VALUES ($1, $2, $3, $4, $5)
 948                ",
 949            )
 950            .bind(room_id)
 951            .bind(user_id)
 952            .bind(connection_id.0 as i32)
 953            .bind(user_id)
 954            .bind(connection_id.0 as i32)
 955            .execute(&mut tx)
 956            .await?;
 957
 958            self.commit_room_transaction(room_id, tx).await
 959        }).await
 960    }
 961
 962    pub async fn call(
 963        &self,
 964        room_id: RoomId,
 965        calling_user_id: UserId,
 966        calling_connection_id: ConnectionId,
 967        called_user_id: UserId,
 968        initial_project_id: Option<ProjectId>,
 969    ) -> Result<(proto::Room, proto::IncomingCall)> {
 970        self.transact(|mut tx| async move {
 971            sqlx::query(
 972                "
 973                INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
 974                VALUES ($1, $2, $3, $4, $5)
 975                ",
 976            )
 977            .bind(room_id)
 978            .bind(called_user_id)
 979            .bind(calling_user_id)
 980            .bind(calling_connection_id.0 as i32)
 981            .bind(initial_project_id)
 982            .execute(&mut tx)
 983            .await?;
 984
 985            let room = self.commit_room_transaction(room_id, tx).await?;
 986            let incoming_call = Self::build_incoming_call(&room, called_user_id)
 987                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
 988            Ok((room, incoming_call))
 989        }).await
 990    }
 991
 992    pub async fn incoming_call_for_user(
 993        &self,
 994        user_id: UserId,
 995    ) -> Result<Option<proto::IncomingCall>> {
 996        self.transact(|mut tx| async move {
 997            let room_id = sqlx::query_scalar::<_, RoomId>(
 998                "
 999                SELECT room_id
1000                FROM room_participants
1001                WHERE user_id = $1 AND answering_connection_id IS NULL
1002                ",
1003            )
1004            .bind(user_id)
1005            .fetch_optional(&mut tx)
1006            .await?;
1007
1008            if let Some(room_id) = room_id {
1009                let room = self.get_room(room_id, &mut tx).await?;
1010                Ok(Self::build_incoming_call(&room, user_id))
1011            } else {
1012                Ok(None)
1013            }
1014        })
1015        .await
1016    }
1017
1018    fn build_incoming_call(
1019        room: &proto::Room,
1020        called_user_id: UserId,
1021    ) -> Option<proto::IncomingCall> {
1022        let pending_participant = room
1023            .pending_participants
1024            .iter()
1025            .find(|participant| participant.user_id == called_user_id.to_proto())?;
1026
1027        Some(proto::IncomingCall {
1028            room_id: room.id,
1029            calling_user_id: pending_participant.calling_user_id,
1030            participant_user_ids: room
1031                .participants
1032                .iter()
1033                .map(|participant| participant.user_id)
1034                .collect(),
1035            initial_project: room.participants.iter().find_map(|participant| {
1036                let initial_project_id = pending_participant.initial_project_id?;
1037                participant
1038                    .projects
1039                    .iter()
1040                    .find(|project| project.id == initial_project_id)
1041                    .cloned()
1042            }),
1043        })
1044    }
1045
1046    pub async fn call_failed(
1047        &self,
1048        room_id: RoomId,
1049        called_user_id: UserId,
1050    ) -> Result<proto::Room> {
1051        self.transact(|mut tx| async move {
1052            sqlx::query(
1053                "
1054                DELETE FROM room_participants
1055                WHERE room_id = $1 AND user_id = $2
1056                ",
1057            )
1058            .bind(room_id)
1059            .bind(called_user_id)
1060            .execute(&mut tx)
1061            .await?;
1062
1063            self.commit_room_transaction(room_id, tx).await
1064        })
1065        .await
1066    }
1067
1068    pub async fn decline_call(
1069        &self,
1070        expected_room_id: Option<RoomId>,
1071        user_id: UserId,
1072    ) -> Result<proto::Room> {
1073        self.transact(|mut tx| async move {
1074            let room_id = sqlx::query_scalar(
1075                "
1076                DELETE FROM room_participants
1077                WHERE user_id = $1 AND answering_connection_id IS NULL
1078                RETURNING room_id
1079                ",
1080            )
1081            .bind(user_id)
1082            .fetch_one(&mut tx)
1083            .await?;
1084            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1085                return Err(anyhow!("declining call on unexpected room"))?;
1086            }
1087
1088            self.commit_room_transaction(room_id, tx).await
1089        })
1090        .await
1091    }
1092
1093    pub async fn cancel_call(
1094        &self,
1095        expected_room_id: Option<RoomId>,
1096        calling_connection_id: ConnectionId,
1097        called_user_id: UserId,
1098    ) -> Result<proto::Room> {
1099        self.transact(|mut tx| async move {
1100            let room_id = sqlx::query_scalar(
1101                "
1102                DELETE FROM room_participants
1103                WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1104                RETURNING room_id
1105                ",
1106            )
1107            .bind(called_user_id)
1108            .bind(calling_connection_id.0 as i32)
1109            .fetch_one(&mut tx)
1110            .await?;
1111            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1112                return Err(anyhow!("canceling call on unexpected room"))?;
1113            }
1114
1115            self.commit_room_transaction(room_id, tx).await
1116        }).await
1117    }
1118
1119    pub async fn join_room(
1120        &self,
1121        room_id: RoomId,
1122        user_id: UserId,
1123        connection_id: ConnectionId,
1124    ) -> Result<proto::Room> {
1125        self.transact(|mut tx| async move {
1126            sqlx::query(
1127                "
1128                UPDATE room_participants 
1129                SET answering_connection_id = $1
1130                WHERE room_id = $2 AND user_id = $3
1131                RETURNING 1
1132                ",
1133            )
1134            .bind(connection_id.0 as i32)
1135            .bind(room_id)
1136            .bind(user_id)
1137            .fetch_one(&mut tx)
1138            .await?;
1139            self.commit_room_transaction(room_id, tx).await
1140        })
1141        .await
1142    }
1143
1144    pub async fn leave_room_for_connection(
1145        &self,
1146        connection_id: ConnectionId,
1147    ) -> Result<Option<LeftRoom>> {
1148        self.transact(|mut tx| async move {
1149            // Leave room.
1150            let room_id = sqlx::query_scalar::<_, RoomId>(
1151                "
1152                DELETE FROM room_participants
1153                WHERE answering_connection_id = $1
1154                RETURNING room_id
1155                ",
1156            )
1157            .bind(connection_id.0 as i32)
1158            .fetch_optional(&mut tx)
1159            .await?;
1160
1161            if let Some(room_id) = room_id {
1162                // Cancel pending calls initiated by the leaving user.
1163                let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1164                    "
1165                    DELETE FROM room_participants
1166                    WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1167                    RETURNING user_id
1168                    ",
1169                )
1170                .bind(connection_id.0 as i32)
1171                .fetch_all(&mut tx)
1172                .await?;
1173
1174                let project_ids = sqlx::query_scalar::<_, ProjectId>(
1175                    "
1176                    SELECT project_id
1177                    FROM project_collaborators
1178                    WHERE connection_id = $1
1179                    ",
1180                )
1181                .bind(connection_id.0 as i32)
1182                .fetch_all(&mut tx)
1183                .await?;
1184
1185                // Leave projects.
1186                let mut left_projects = HashMap::default();
1187                if !project_ids.is_empty() {
1188                    let mut params = "?,".repeat(project_ids.len());
1189                    params.pop();
1190                    let query = format!(
1191                        "
1192                        SELECT *
1193                        FROM project_collaborators
1194                        WHERE project_id IN ({params})
1195                    "
1196                    );
1197                    let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1198                    for project_id in project_ids {
1199                        query = query.bind(project_id);
1200                    }
1201
1202                    let mut project_collaborators = query.fetch(&mut tx);
1203                    while let Some(collaborator) = project_collaborators.next().await {
1204                        let collaborator = collaborator?;
1205                        let left_project =
1206                            left_projects
1207                                .entry(collaborator.project_id)
1208                                .or_insert(LeftProject {
1209                                    id: collaborator.project_id,
1210                                    host_user_id: Default::default(),
1211                                    connection_ids: Default::default(),
1212                                    host_connection_id: Default::default(),
1213                                });
1214
1215                        let collaborator_connection_id =
1216                            ConnectionId(collaborator.connection_id as u32);
1217                        if collaborator_connection_id != connection_id {
1218                            left_project.connection_ids.push(collaborator_connection_id);
1219                        }
1220
1221                        if collaborator.is_host {
1222                            left_project.host_user_id = collaborator.user_id;
1223                            left_project.host_connection_id =
1224                                ConnectionId(collaborator.connection_id as u32);
1225                        }
1226                    }
1227                }
1228                sqlx::query(
1229                    "
1230                    DELETE FROM project_collaborators
1231                    WHERE connection_id = $1
1232                    ",
1233                )
1234                .bind(connection_id.0 as i32)
1235                .execute(&mut tx)
1236                .await?;
1237
1238                // Unshare projects.
1239                sqlx::query(
1240                    "
1241                    DELETE FROM projects
1242                    WHERE room_id = $1 AND host_connection_id = $2
1243                    ",
1244                )
1245                .bind(room_id)
1246                .bind(connection_id.0 as i32)
1247                .execute(&mut tx)
1248                .await?;
1249
1250                let room = self.commit_room_transaction(room_id, tx).await?;
1251                Ok(Some(LeftRoom {
1252                    room,
1253                    left_projects,
1254                    canceled_calls_to_user_ids,
1255                }))
1256            } else {
1257                Ok(None)
1258            }
1259        })
1260        .await
1261    }
1262
1263    pub async fn update_room_participant_location(
1264        &self,
1265        room_id: RoomId,
1266        connection_id: ConnectionId,
1267        location: proto::ParticipantLocation,
1268    ) -> Result<proto::Room> {
1269        self.transact(|tx| async {
1270            let mut tx = tx;
1271            let location_kind;
1272            let location_project_id;
1273            match location
1274                .variant
1275                .as_ref()
1276                .ok_or_else(|| anyhow!("invalid location"))?
1277            {
1278                proto::participant_location::Variant::SharedProject(project) => {
1279                    location_kind = 0;
1280                    location_project_id = Some(ProjectId::from_proto(project.id));
1281                }
1282                proto::participant_location::Variant::UnsharedProject(_) => {
1283                    location_kind = 1;
1284                    location_project_id = None;
1285                }
1286                proto::participant_location::Variant::External(_) => {
1287                    location_kind = 2;
1288                    location_project_id = None;
1289                }
1290            }
1291
1292            sqlx::query(
1293                "
1294                UPDATE room_participants
1295                SET location_kind = $1, location_project_id = $2
1296                WHERE room_id = $3 AND answering_connection_id = $4
1297                RETURNING 1
1298                ",
1299            )
1300            .bind(location_kind)
1301            .bind(location_project_id)
1302            .bind(room_id)
1303            .bind(connection_id.0 as i32)
1304            .fetch_one(&mut tx)
1305            .await?;
1306
1307            self.commit_room_transaction(room_id, tx).await
1308        })
1309        .await
1310    }
1311
1312    async fn commit_room_transaction(
1313        &self,
1314        room_id: RoomId,
1315        mut tx: sqlx::Transaction<'_, D>,
1316    ) -> Result<proto::Room> {
1317        sqlx::query(
1318            "
1319            UPDATE rooms
1320            SET version = version + 1
1321            WHERE id = $1
1322            ",
1323        )
1324        .bind(room_id)
1325        .execute(&mut tx)
1326        .await?;
1327        let room = self.get_room(room_id, &mut tx).await?;
1328        tx.commit().await?;
1329
1330        Ok(room)
1331    }
1332
1333    async fn get_guest_connection_ids(
1334        &self,
1335        project_id: ProjectId,
1336        tx: &mut sqlx::Transaction<'_, D>,
1337    ) -> Result<Vec<ConnectionId>> {
1338        let mut guest_connection_ids = Vec::new();
1339        let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1340            "
1341            SELECT connection_id
1342            FROM project_collaborators
1343            WHERE project_id = $1 AND is_host = FALSE
1344            ",
1345        )
1346        .bind(project_id)
1347        .fetch(tx);
1348        while let Some(connection_id) = db_guest_connection_ids.next().await {
1349            guest_connection_ids.push(ConnectionId(connection_id? as u32));
1350        }
1351        Ok(guest_connection_ids)
1352    }
1353
1354    async fn get_room(
1355        &self,
1356        room_id: RoomId,
1357        tx: &mut sqlx::Transaction<'_, D>,
1358    ) -> Result<proto::Room> {
1359        let room: Room = sqlx::query_as(
1360            "
1361            SELECT *
1362            FROM rooms
1363            WHERE id = $1
1364            ",
1365        )
1366        .bind(room_id)
1367        .fetch_one(&mut *tx)
1368        .await?;
1369
1370        let mut db_participants =
1371            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1372                "
1373                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1374                FROM room_participants
1375                WHERE room_id = $1
1376                ",
1377            )
1378            .bind(room_id)
1379            .fetch(&mut *tx);
1380
1381        let mut participants = HashMap::default();
1382        let mut pending_participants = Vec::new();
1383        while let Some(participant) = db_participants.next().await {
1384            let (
1385                user_id,
1386                answering_connection_id,
1387                location_kind,
1388                location_project_id,
1389                calling_user_id,
1390                initial_project_id,
1391            ) = participant?;
1392            if let Some(answering_connection_id) = answering_connection_id {
1393                let location = match (location_kind, location_project_id) {
1394                    (Some(0), Some(project_id)) => {
1395                        Some(proto::participant_location::Variant::SharedProject(
1396                            proto::participant_location::SharedProject {
1397                                id: project_id.to_proto(),
1398                            },
1399                        ))
1400                    }
1401                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
1402                        Default::default(),
1403                    )),
1404                    _ => Some(proto::participant_location::Variant::External(
1405                        Default::default(),
1406                    )),
1407                };
1408                participants.insert(
1409                    answering_connection_id,
1410                    proto::Participant {
1411                        user_id: user_id.to_proto(),
1412                        peer_id: answering_connection_id as u32,
1413                        projects: Default::default(),
1414                        location: Some(proto::ParticipantLocation { variant: location }),
1415                    },
1416                );
1417            } else {
1418                pending_participants.push(proto::PendingParticipant {
1419                    user_id: user_id.to_proto(),
1420                    calling_user_id: calling_user_id.to_proto(),
1421                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
1422                });
1423            }
1424        }
1425        drop(db_participants);
1426
1427        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
1428            "
1429            SELECT host_connection_id, projects.id, worktrees.root_name
1430            FROM projects
1431            LEFT JOIN worktrees ON projects.id = worktrees.project_id
1432            WHERE room_id = $1
1433            ",
1434        )
1435        .bind(room_id)
1436        .fetch(&mut *tx);
1437
1438        while let Some(row) = rows.next().await {
1439            let (connection_id, project_id, worktree_root_name) = row?;
1440            if let Some(participant) = participants.get_mut(&connection_id) {
1441                let project = if let Some(project) = participant
1442                    .projects
1443                    .iter_mut()
1444                    .find(|project| project.id == project_id.to_proto())
1445                {
1446                    project
1447                } else {
1448                    participant.projects.push(proto::ParticipantProject {
1449                        id: project_id.to_proto(),
1450                        worktree_root_names: Default::default(),
1451                    });
1452                    participant.projects.last_mut().unwrap()
1453                };
1454                project.worktree_root_names.extend(worktree_root_name);
1455            }
1456        }
1457
1458        Ok(proto::Room {
1459            id: room.id.to_proto(),
1460            version: room.version as u64,
1461            live_kit_room: room.live_kit_room,
1462            participants: participants.into_values().collect(),
1463            pending_participants,
1464        })
1465    }
1466
1467    // projects
1468
1469    pub async fn share_project(
1470        &self,
1471        expected_room_id: RoomId,
1472        connection_id: ConnectionId,
1473        worktrees: &[proto::WorktreeMetadata],
1474    ) -> Result<(ProjectId, proto::Room)> {
1475        self.transact(|mut tx| async move {
1476            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1477                "
1478                SELECT room_id, user_id
1479                FROM room_participants
1480                WHERE answering_connection_id = $1
1481                ",
1482            )
1483            .bind(connection_id.0 as i32)
1484            .fetch_one(&mut tx)
1485            .await?;
1486            if room_id != expected_room_id {
1487                return Err(anyhow!("shared project on unexpected room"))?;
1488            }
1489
1490            let project_id: ProjectId = sqlx::query_scalar(
1491                "
1492                INSERT INTO projects (room_id, host_user_id, host_connection_id)
1493                VALUES ($1, $2, $3)
1494                RETURNING id
1495                ",
1496            )
1497            .bind(room_id)
1498            .bind(user_id)
1499            .bind(connection_id.0 as i32)
1500            .fetch_one(&mut tx)
1501            .await
1502            .unwrap();
1503
1504            if !worktrees.is_empty() {
1505                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1506                params.pop();
1507                let query = format!(
1508                    "
1509                    INSERT INTO worktrees (
1510                        project_id,
1511                        id,
1512                        root_name,
1513                        abs_path,
1514                        visible,
1515                        scan_id,
1516                        is_complete
1517                    )
1518                    VALUES {params}
1519                    "
1520                );
1521
1522                let mut query = sqlx::query(&query);
1523                for worktree in worktrees {
1524                    query = query
1525                        .bind(project_id)
1526                        .bind(worktree.id as i32)
1527                        .bind(&worktree.root_name)
1528                        .bind(&worktree.abs_path)
1529                        .bind(worktree.visible)
1530                        .bind(0)
1531                        .bind(false);
1532                }
1533                query.execute(&mut tx).await.unwrap();
1534            }
1535
1536            sqlx::query(
1537                "
1538                INSERT INTO project_collaborators (
1539                    project_id,
1540                    connection_id,
1541                    user_id,
1542                    replica_id,
1543                    is_host
1544                )
1545                VALUES ($1, $2, $3, $4, $5)
1546                ",
1547            )
1548            .bind(project_id)
1549            .bind(connection_id.0 as i32)
1550            .bind(user_id)
1551            .bind(0)
1552            .bind(true)
1553            .execute(&mut tx)
1554            .await
1555            .unwrap();
1556
1557            let room = self.commit_room_transaction(room_id, tx).await?;
1558            Ok((project_id, room))
1559        })
1560        .await
1561    }
1562
1563    pub async fn unshare_project(
1564        &self,
1565        project_id: ProjectId,
1566        connection_id: ConnectionId,
1567    ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1568        self.transact(|mut tx| async move {
1569            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1570            let room_id: RoomId = sqlx::query_scalar(
1571                "
1572                DELETE FROM projects
1573                WHERE id = $1 AND host_connection_id = $2
1574                RETURNING room_id
1575                ",
1576            )
1577            .bind(project_id)
1578            .bind(connection_id.0 as i32)
1579            .fetch_one(&mut tx)
1580            .await?;
1581            let room = self.commit_room_transaction(room_id, tx).await?;
1582
1583            Ok((room, guest_connection_ids))
1584        })
1585        .await
1586    }
1587
1588    pub async fn update_project(
1589        &self,
1590        project_id: ProjectId,
1591        connection_id: ConnectionId,
1592        worktrees: &[proto::WorktreeMetadata],
1593    ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1594        self.transact(|mut tx| async move {
1595            let room_id: RoomId = sqlx::query_scalar(
1596                "
1597                SELECT room_id
1598                FROM projects
1599                WHERE id = $1 AND host_connection_id = $2
1600                ",
1601            )
1602            .bind(project_id)
1603            .bind(connection_id.0 as i32)
1604            .fetch_one(&mut tx)
1605            .await?;
1606
1607            if !worktrees.is_empty() {
1608                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1609                params.pop();
1610                let query = format!(
1611                    "
1612                    INSERT INTO worktrees (
1613                        project_id,
1614                        id,
1615                        root_name,
1616                        abs_path,
1617                        visible,
1618                        scan_id,
1619                        is_complete
1620                    )
1621                    VALUES {params}
1622                    ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1623                    "
1624                );
1625
1626                let mut query = sqlx::query(&query);
1627                for worktree in worktrees {
1628                    query = query
1629                        .bind(project_id)
1630                        .bind(worktree.id as i32)
1631                        .bind(&worktree.root_name)
1632                        .bind(&worktree.abs_path)
1633                        .bind(worktree.visible)
1634                        .bind(0)
1635                        .bind(false)
1636                }
1637                query.execute(&mut tx).await?;
1638            }
1639
1640            let mut params = "?,".repeat(worktrees.len());
1641            if !worktrees.is_empty() {
1642                params.pop();
1643            }
1644            let query = format!(
1645                "
1646                DELETE FROM worktrees
1647                WHERE project_id = ? AND id NOT IN ({params})
1648                ",
1649            );
1650
1651            let mut query = sqlx::query(&query).bind(project_id);
1652            for worktree in worktrees {
1653                query = query.bind(WorktreeId(worktree.id as i32));
1654            }
1655            query.execute(&mut tx).await?;
1656
1657            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1658            let room = self.commit_room_transaction(room_id, tx).await?;
1659
1660            Ok((room, guest_connection_ids))
1661        })
1662        .await
1663    }
1664
1665    pub async fn update_worktree(
1666        &self,
1667        update: &proto::UpdateWorktree,
1668        connection_id: ConnectionId,
1669    ) -> Result<Vec<ConnectionId>> {
1670        self.transact(|mut tx| async move {
1671            let project_id = ProjectId::from_proto(update.project_id);
1672            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1673
1674            // Ensure the update comes from the host.
1675            sqlx::query(
1676                "
1677                SELECT 1
1678                FROM projects
1679                WHERE id = $1 AND host_connection_id = $2
1680                ",
1681            )
1682            .bind(project_id)
1683            .bind(connection_id.0 as i32)
1684            .fetch_one(&mut tx)
1685            .await?;
1686
1687            // Update metadata.
1688            sqlx::query(
1689                "
1690                UPDATE worktrees
1691                SET
1692                    root_name = $1,
1693                    scan_id = $2,
1694                    is_complete = $3,
1695                    abs_path = $4
1696                WHERE project_id = $5 AND id = $6
1697                RETURNING 1
1698                ",
1699            )
1700            .bind(&update.root_name)
1701            .bind(update.scan_id as i64)
1702            .bind(update.is_last_update)
1703            .bind(&update.abs_path)
1704            .bind(project_id)
1705            .bind(worktree_id)
1706            .fetch_one(&mut tx)
1707            .await?;
1708
1709            if !update.updated_entries.is_empty() {
1710                let mut params =
1711                    "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1712                params.pop();
1713
1714                let query = format!(
1715                    "
1716                    INSERT INTO worktree_entries (
1717                        project_id, 
1718                        worktree_id, 
1719                        id, 
1720                        is_dir, 
1721                        path, 
1722                        inode,
1723                        mtime_seconds, 
1724                        mtime_nanos, 
1725                        is_symlink, 
1726                        is_ignored
1727                    )
1728                    VALUES {params}
1729                    ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1730                        is_dir = excluded.is_dir,
1731                        path = excluded.path,
1732                        inode = excluded.inode,
1733                        mtime_seconds = excluded.mtime_seconds,
1734                        mtime_nanos = excluded.mtime_nanos,
1735                        is_symlink = excluded.is_symlink,
1736                        is_ignored = excluded.is_ignored
1737                    "
1738                );
1739                let mut query = sqlx::query(&query);
1740                for entry in &update.updated_entries {
1741                    let mtime = entry.mtime.clone().unwrap_or_default();
1742                    query = query
1743                        .bind(project_id)
1744                        .bind(worktree_id)
1745                        .bind(entry.id as i64)
1746                        .bind(entry.is_dir)
1747                        .bind(&entry.path)
1748                        .bind(entry.inode as i64)
1749                        .bind(mtime.seconds as i64)
1750                        .bind(mtime.nanos as i32)
1751                        .bind(entry.is_symlink)
1752                        .bind(entry.is_ignored);
1753                }
1754                query.execute(&mut tx).await?;
1755            }
1756
1757            if !update.removed_entries.is_empty() {
1758                let mut params = "?,".repeat(update.removed_entries.len());
1759                params.pop();
1760                let query = format!(
1761                    "
1762                    DELETE FROM worktree_entries
1763                    WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1764                    "
1765                );
1766
1767                let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1768                for entry_id in &update.removed_entries {
1769                    query = query.bind(*entry_id as i64);
1770                }
1771                query.execute(&mut tx).await?;
1772            }
1773
1774            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1775            tx.commit().await?;
1776            Ok(connection_ids)
1777        })
1778        .await
1779    }
1780
1781    pub async fn update_diagnostic_summary(
1782        &self,
1783        update: &proto::UpdateDiagnosticSummary,
1784        connection_id: ConnectionId,
1785    ) -> Result<Vec<ConnectionId>> {
1786        self.transact(|mut tx| async {
1787            let project_id = ProjectId::from_proto(update.project_id);
1788            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1789            let summary = update
1790                .summary
1791                .as_ref()
1792                .ok_or_else(|| anyhow!("invalid summary"))?;
1793
1794            // Ensure the update comes from the host.
1795            sqlx::query(
1796                "
1797                SELECT 1
1798                FROM projects
1799                WHERE id = $1 AND host_connection_id = $2
1800                ",
1801            )
1802            .bind(project_id)
1803            .bind(connection_id.0 as i32)
1804            .fetch_one(&mut tx)
1805            .await?;
1806
1807            // Update summary.
1808            sqlx::query(
1809                "
1810                INSERT INTO worktree_diagnostic_summaries (
1811                    project_id,
1812                    worktree_id,
1813                    path,
1814                    language_server_id,
1815                    error_count,
1816                    warning_count,
1817                    version
1818                )
1819                VALUES ($1, $2, $3, $4, $5, $6, $7)
1820                ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1821                    language_server_id = excluded.language_server_id,
1822                    error_count = excluded.error_count, 
1823                    warning_count = excluded.warning_count,
1824                    version = excluded.version
1825                ",
1826            )
1827            .bind(project_id)
1828            .bind(worktree_id)
1829            .bind(&summary.path)
1830            .bind(summary.language_server_id as i64)
1831            .bind(summary.error_count as i32)
1832            .bind(summary.warning_count as i32)
1833            .bind(summary.version as i32)
1834            .execute(&mut tx)
1835            .await?;
1836
1837            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1838            tx.commit().await?;
1839            Ok(connection_ids)
1840        })
1841        .await
1842    }
1843
1844    pub async fn start_language_server(
1845        &self,
1846        update: &proto::StartLanguageServer,
1847        connection_id: ConnectionId,
1848    ) -> Result<Vec<ConnectionId>> {
1849        self.transact(|mut tx| async {
1850            let project_id = ProjectId::from_proto(update.project_id);
1851            let server = update
1852                .server
1853                .as_ref()
1854                .ok_or_else(|| anyhow!("invalid language server"))?;
1855
1856            // Ensure the update comes from the host.
1857            sqlx::query(
1858                "
1859                SELECT 1
1860                FROM projects
1861                WHERE id = $1 AND host_connection_id = $2
1862                ",
1863            )
1864            .bind(project_id)
1865            .bind(connection_id.0 as i32)
1866            .fetch_one(&mut tx)
1867            .await?;
1868
1869            // Add the newly-started language server.
1870            sqlx::query(
1871                "
1872                INSERT INTO language_servers (project_id, id, name)
1873                VALUES ($1, $2, $3)
1874                ON CONFLICT (project_id, id) DO UPDATE SET
1875                    name = excluded.name
1876                ",
1877            )
1878            .bind(project_id)
1879            .bind(server.id as i64)
1880            .bind(&server.name)
1881            .execute(&mut tx)
1882            .await?;
1883
1884            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1885            tx.commit().await?;
1886            Ok(connection_ids)
1887        })
1888        .await
1889    }
1890
    /// Adds `connection_id` to `project_id` as a non-host collaborator and
    /// returns the project's current state together with the replica id
    /// assigned to the new collaborator.
    ///
    /// Fails if `connection_id` is not the answering connection of any room
    /// participant, or if the project does not belong to that participant's
    /// room (both checks use `fetch_one`, which errors when no row matches).
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<(Project, ReplicaId)> {
        self.transact(|mut tx| async move {
            // Identify the room and user behind this connection.
            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
                "
                SELECT room_id, user_id
                FROM room_participants
                WHERE answering_connection_id = $1
                ",
            )
            .bind(connection_id.0 as i32)
            .fetch_one(&mut tx)
            .await?;

            // Ensure the project was shared in this room.
            sqlx::query(
                "
                SELECT 1
                FROM projects
                WHERE id = $1 AND room_id = $2
                ",
            )
            .bind(project_id)
            .bind(room_id)
            .fetch_one(&mut tx)
            .await?;

            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
                "
                SELECT *
                FROM project_collaborators
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Assign the smallest replica id >= 1 not used by an existing
            // collaborator (replica id 0 is never handed out here).
            // NOTE(review): nothing here rejects a connection that is already a
            // collaborator; presumably a unique constraint or the caller prevents
            // that — confirm.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = ProjectCollaborator {
                project_id,
                connection_id: connection_id.0 as i32,
                user_id,
                replica_id,
                is_host: false,
            };

            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(new_collaborator.project_id)
            .bind(new_collaborator.connection_id)
            .bind(new_collaborator.user_id)
            .bind(new_collaborator.replica_id)
            .bind(new_collaborator.is_host)
            .execute(&mut tx)
            .await?;
            // The returned collaborator list includes the newly inserted one.
            collaborators.push(new_collaborator);

            // Index worktrees by id; their entries and diagnostic summaries start
            // empty and are filled in by the two blocks below.
            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
                "
                SELECT *
                FROM worktrees
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            let mut worktrees = worktree_rows
                .into_iter()
                .map(|worktree_row| {
                    (
                        worktree_row.id,
                        Worktree {
                            id: worktree_row.id,
                            abs_path: worktree_row.abs_path,
                            root_name: worktree_row.root_name,
                            visible: worktree_row.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: worktree_row.scan_id as u64,
                            is_complete: worktree_row.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries. The enclosing block drops the stream (and
            // its mutable borrow of `tx`) before the next query; entries whose
            // worktree id matches no loaded worktree are ignored.
            {
                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                    "
                    SELECT *
                    FROM worktree_entries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                        worktree.entries.push(proto::Entry {
                            id: entry.id as u64,
                            is_dir: entry.is_dir,
                            path: entry.path,
                            inode: entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: entry.mtime_seconds as u64,
                                nanos: entry.mtime_nanos as u32,
                            }),
                            is_symlink: entry.is_symlink,
                            is_ignored: entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries (same streaming pattern as
            // the entries; summaries for unknown worktrees are ignored).
            {
                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                    "
                    SELECT *
                    FROM worktree_diagnostic_summaries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(summary) = summaries.next().await {
                    let summary = summary?;
                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: summary.path,
                                language_server_id: summary.language_server_id as u64,
                                error_count: summary.error_count as u32,
                                warning_count: summary.warning_count as u32,
                                version: summary.version as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = sqlx::query_as::<_, LanguageServer>(
                "
                SELECT *
                FROM language_servers
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;

            tx.commit().await?;
            Ok((
                Project {
                    collaborators,
                    worktrees,
                    language_servers: language_servers
                        .into_iter()
                        .map(|language_server| proto::LanguageServer {
                            id: language_server.id.to_proto(),
                            name: language_server.name,
                        })
                        .collect(),
                },
                replica_id as ReplicaId,
            ))
        })
        .await
    }
2084
2085    pub async fn leave_project(
2086        &self,
2087        project_id: ProjectId,
2088        connection_id: ConnectionId,
2089    ) -> Result<LeftProject> {
2090        self.transact(|mut tx| async move {
2091            let result = sqlx::query(
2092                "
2093                DELETE FROM project_collaborators
2094                WHERE project_id = $1 AND connection_id = $2
2095                ",
2096            )
2097            .bind(project_id)
2098            .bind(connection_id.0 as i32)
2099            .execute(&mut tx)
2100            .await?;
2101
2102            if result.rows_affected() == 0 {
2103                Err(anyhow!("not a collaborator on this project"))?;
2104            }
2105
2106            let connection_ids = sqlx::query_scalar::<_, i32>(
2107                "
2108                SELECT connection_id
2109                FROM project_collaborators
2110                WHERE project_id = $1
2111                ",
2112            )
2113            .bind(project_id)
2114            .fetch_all(&mut tx)
2115            .await?
2116            .into_iter()
2117            .map(|id| ConnectionId(id as u32))
2118            .collect();
2119
2120            let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>(
2121                "
2122                SELECT host_user_id, host_connection_id
2123                FROM projects
2124                WHERE id = $1
2125                ",
2126            )
2127            .bind(project_id)
2128            .fetch_one(&mut tx)
2129            .await?;
2130
2131            tx.commit().await?;
2132
2133            Ok(LeftProject {
2134                id: project_id,
2135                host_user_id: UserId(host_user_id),
2136                host_connection_id: ConnectionId(host_connection_id as u32),
2137                connection_ids,
2138            })
2139        })
2140        .await
2141    }
2142
2143    pub async fn project_collaborators(
2144        &self,
2145        project_id: ProjectId,
2146        connection_id: ConnectionId,
2147    ) -> Result<Vec<ProjectCollaborator>> {
2148        self.transact(|mut tx| async move {
2149            let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2150                "
2151                SELECT *
2152                FROM project_collaborators
2153                WHERE project_id = $1
2154                ",
2155            )
2156            .bind(project_id)
2157            .fetch_all(&mut tx)
2158            .await?;
2159
2160            if collaborators
2161                .iter()
2162                .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2163            {
2164                Ok(collaborators)
2165            } else {
2166                Err(anyhow!("no such project"))?
2167            }
2168        })
2169        .await
2170    }
2171
2172    pub async fn project_connection_ids(
2173        &self,
2174        project_id: ProjectId,
2175        connection_id: ConnectionId,
2176    ) -> Result<HashSet<ConnectionId>> {
2177        self.transact(|mut tx| async move {
2178            let connection_ids = sqlx::query_scalar::<_, i32>(
2179                "
2180                SELECT connection_id
2181                FROM project_collaborators
2182                WHERE project_id = $1
2183                ",
2184            )
2185            .bind(project_id)
2186            .fetch_all(&mut tx)
2187            .await?;
2188
2189            if connection_ids.contains(&(connection_id.0 as i32)) {
2190                Ok(connection_ids
2191                    .into_iter()
2192                    .map(|connection_id| ConnectionId(connection_id as u32))
2193                    .collect())
2194            } else {
2195                Err(anyhow!("no such project"))?
2196            }
2197        })
2198        .await
2199    }
2200
2201    // contacts
2202
2203    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2204        self.transact(|mut tx| async move {
2205            let query = "
2206                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2207                FROM contacts
2208                LEFT JOIN room_participants ON room_participants.user_id = $1
2209                WHERE user_id_a = $1 OR user_id_b = $1;
2210            ";
2211
2212            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2213                .bind(user_id)
2214                .fetch(&mut tx);
2215
2216            let mut contacts = Vec::new();
2217            while let Some(row) = rows.next().await {
2218                let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2219                if user_id_a == user_id {
2220                    if accepted {
2221                        contacts.push(Contact::Accepted {
2222                            user_id: user_id_b,
2223                            should_notify: should_notify && a_to_b,
2224                            busy
2225                        });
2226                    } else if a_to_b {
2227                        contacts.push(Contact::Outgoing { user_id: user_id_b })
2228                    } else {
2229                        contacts.push(Contact::Incoming {
2230                            user_id: user_id_b,
2231                            should_notify,
2232                        });
2233                    }
2234                } else if accepted {
2235                    contacts.push(Contact::Accepted {
2236                        user_id: user_id_a,
2237                        should_notify: should_notify && !a_to_b,
2238                        busy
2239                    });
2240                } else if a_to_b {
2241                    contacts.push(Contact::Incoming {
2242                        user_id: user_id_a,
2243                        should_notify,
2244                    });
2245                } else {
2246                    contacts.push(Contact::Outgoing { user_id: user_id_a });
2247                }
2248            }
2249
2250            contacts.sort_unstable_by_key(|contact| contact.user_id());
2251
2252            Ok(contacts)
2253        })
2254        .await
2255    }
2256
2257    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2258        self.transact(|mut tx| async move {
2259            Ok(sqlx::query_scalar::<_, i32>(
2260                "
2261                SELECT 1
2262                FROM room_participants
2263                WHERE room_participants.user_id = $1
2264                ",
2265            )
2266            .bind(user_id)
2267            .fetch_optional(&mut tx)
2268            .await?
2269            .is_some())
2270        })
2271        .await
2272    }
2273
2274    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2275        self.transact(|mut tx| async move {
2276            let (id_a, id_b) = if user_id_1 < user_id_2 {
2277                (user_id_1, user_id_2)
2278            } else {
2279                (user_id_2, user_id_1)
2280            };
2281
2282            let query = "
2283                SELECT 1 FROM contacts
2284                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2285                LIMIT 1
2286            ";
2287            Ok(sqlx::query_scalar::<_, i32>(query)
2288                .bind(id_a.0)
2289                .bind(id_b.0)
2290                .fetch_optional(&mut tx)
2291                .await?
2292                .is_some())
2293        })
2294        .await
2295    }
2296
2297    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2298        self.transact(|mut tx| async move {
2299            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2300                (sender_id, receiver_id, true)
2301            } else {
2302                (receiver_id, sender_id, false)
2303            };
2304            let query = "
2305                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2306                VALUES ($1, $2, $3, FALSE, TRUE)
2307                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2308                SET
2309                    accepted = TRUE,
2310                    should_notify = FALSE
2311                WHERE
2312                    NOT contacts.accepted AND
2313                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2314                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2315            ";
2316            let result = sqlx::query(query)
2317                .bind(id_a.0)
2318                .bind(id_b.0)
2319                .bind(a_to_b)
2320                .execute(&mut tx)
2321                .await?;
2322
2323            if result.rows_affected() == 1 {
2324                tx.commit().await?;
2325                Ok(())
2326            } else {
2327                Err(anyhow!("contact already requested"))?
2328            }
2329        }).await
2330    }
2331
2332    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2333        self.transact(|mut tx| async move {
2334            let (id_a, id_b) = if responder_id < requester_id {
2335                (responder_id, requester_id)
2336            } else {
2337                (requester_id, responder_id)
2338            };
2339            let query = "
2340                DELETE FROM contacts
2341                WHERE user_id_a = $1 AND user_id_b = $2;
2342            ";
2343            let result = sqlx::query(query)
2344                .bind(id_a.0)
2345                .bind(id_b.0)
2346                .execute(&mut tx)
2347                .await?;
2348
2349            if result.rows_affected() == 1 {
2350                tx.commit().await?;
2351                Ok(())
2352            } else {
2353                Err(anyhow!("no such contact"))?
2354            }
2355        })
2356        .await
2357    }
2358
2359    pub async fn dismiss_contact_notification(
2360        &self,
2361        user_id: UserId,
2362        contact_user_id: UserId,
2363    ) -> Result<()> {
2364        self.transact(|mut tx| async move {
2365            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2366                (user_id, contact_user_id, true)
2367            } else {
2368                (contact_user_id, user_id, false)
2369            };
2370
2371            let query = "
2372                UPDATE contacts
2373                SET should_notify = FALSE
2374                WHERE
2375                    user_id_a = $1 AND user_id_b = $2 AND
2376                    (
2377                        (a_to_b = $3 AND accepted) OR
2378                        (a_to_b != $3 AND NOT accepted)
2379                    );
2380            ";
2381
2382            let result = sqlx::query(query)
2383                .bind(id_a.0)
2384                .bind(id_b.0)
2385                .bind(a_to_b)
2386                .execute(&mut tx)
2387                .await?;
2388
2389            if result.rows_affected() == 0 {
2390                Err(anyhow!("no such contact request"))?
2391            } else {
2392                tx.commit().await?;
2393                Ok(())
2394            }
2395        })
2396        .await
2397    }
2398
2399    pub async fn respond_to_contact_request(
2400        &self,
2401        responder_id: UserId,
2402        requester_id: UserId,
2403        accept: bool,
2404    ) -> Result<()> {
2405        self.transact(|mut tx| async move {
2406            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2407                (responder_id, requester_id, false)
2408            } else {
2409                (requester_id, responder_id, true)
2410            };
2411            let result = if accept {
2412                let query = "
2413                    UPDATE contacts
2414                    SET accepted = TRUE, should_notify = TRUE
2415                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2416                ";
2417                sqlx::query(query)
2418                    .bind(id_a.0)
2419                    .bind(id_b.0)
2420                    .bind(a_to_b)
2421                    .execute(&mut tx)
2422                    .await?
2423            } else {
2424                let query = "
2425                    DELETE FROM contacts
2426                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2427                ";
2428                sqlx::query(query)
2429                    .bind(id_a.0)
2430                    .bind(id_b.0)
2431                    .bind(a_to_b)
2432                    .execute(&mut tx)
2433                    .await?
2434            };
2435            if result.rows_affected() == 1 {
2436                tx.commit().await?;
2437                Ok(())
2438            } else {
2439                Err(anyhow!("no such contact request"))?
2440            }
2441        })
2442        .await
2443    }
2444
2445    // access tokens
2446
2447    pub async fn create_access_token_hash(
2448        &self,
2449        user_id: UserId,
2450        access_token_hash: &str,
2451        max_access_token_count: usize,
2452    ) -> Result<()> {
2453        self.transact(|tx| async {
2454            let mut tx = tx;
2455            let insert_query = "
2456                INSERT INTO access_tokens (user_id, hash)
2457                VALUES ($1, $2);
2458            ";
2459            let cleanup_query = "
2460                DELETE FROM access_tokens
2461                WHERE id IN (
2462                    SELECT id from access_tokens
2463                    WHERE user_id = $1
2464                    ORDER BY id DESC
2465                    LIMIT 10000
2466                    OFFSET $3
2467                )
2468            ";
2469
2470            sqlx::query(insert_query)
2471                .bind(user_id.0)
2472                .bind(access_token_hash)
2473                .execute(&mut tx)
2474                .await?;
2475            sqlx::query(cleanup_query)
2476                .bind(user_id.0)
2477                .bind(access_token_hash)
2478                .bind(max_access_token_count as i32)
2479                .execute(&mut tx)
2480                .await?;
2481            Ok(tx.commit().await?)
2482        })
2483        .await
2484    }
2485
2486    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2487        self.transact(|mut tx| async move {
2488            let query = "
2489                SELECT hash
2490                FROM access_tokens
2491                WHERE user_id = $1
2492                ORDER BY id DESC
2493            ";
2494            Ok(sqlx::query_scalar(query)
2495                .bind(user_id.0)
2496                .fetch_all(&mut tx)
2497                .await?)
2498        })
2499        .await
2500    }
2501
2502    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2503    where
2504        F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2505        Fut: Send + Future<Output = Result<T>>,
2506    {
2507        let body = async {
2508            loop {
2509                let tx = self.begin_transaction().await?;
2510                match f(tx).await {
2511                    Ok(result) => return Ok(result),
2512                    Err(error) => match error {
2513                        Error::Database(error)
2514                            if error
2515                                .as_database_error()
2516                                .and_then(|error| error.code())
2517                                .as_deref()
2518                                == Some("hey") =>
2519                        {
2520                            // Retry (don't break the loop)
2521                        }
2522                        error @ _ => return Err(error),
2523                    },
2524                }
2525            }
2526        };
2527
2528        #[cfg(test)]
2529        {
2530            if let Some(background) = self.background.as_ref() {
2531                background.simulate_random_delay().await;
2532            }
2533
2534            let result = self.runtime.as_ref().unwrap().block_on(body);
2535
2536            if let Some(background) = self.background.as_ref() {
2537                background.simulate_random_delay().await;
2538            }
2539
2540            result
2541        }
2542
2543        #[cfg(not(test))]
2544        {
2545            body.await
2546        }
2547    }
2548}
2549
/// Declares `$name` as an `i32` newtype id that is stored transparently in the
/// database (`sqlx(transparent)`) and serialized transparently by serde.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its protocol (`u64`) form via an `as` cast,
            /// which truncates values outside the `i32` range.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to its protocol (`u64`) form via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
2592
id_type!(UserId);
/// A user account record, decoded via `FromRow`.
/// NOTE(review): presumably one row of the `users` table — confirm against the schema.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}

id_type!(RoomId);
/// A room record, decoded via `FromRow`.
/// NOTE(review): presumably one row of the `rooms` table — confirm against the schema.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    pub version: i32,
    pub live_kit_room: String,
}
2613
id_type!(ProjectId);
/// A shared project's state as assembled by `join_project`: its collaborators,
/// its worktrees keyed by id, and its language servers in protocol form.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

id_type!(ReplicaId);
/// One row of the `project_collaborators` table.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// The raw `ConnectionId` value (`ConnectionId.0`), stored as `i32`.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
2630
// `WorktreeId`: an `i32` newtype generated by `id_type!`.
id_type!(WorktreeId);
/// The scalar columns of a worktree, decodable from a SQL row via `FromRow`.
///
/// See `Worktree` for the assembled form that also carries entries and
/// diagnostic summaries.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    /// Signed here, whereas `Worktree::scan_id` is `u64`.
    // NOTE(review): presumably signed because the column is a signed SQL integer — confirm.
    pub scan_id: i64,
    pub is_complete: bool,
}

/// A worktree together with its entries and diagnostic summaries, which are
/// held as protobuf messages.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned here, whereas `WorktreeRow::scan_id` is `i64`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2652
/// A single entry (file or directory) of a worktree, decodable from a SQL row
/// via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    /// Modification time, split into separate seconds and nanoseconds columns.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2665
/// Diagnostic counts for one path in a worktree, reported by one language
/// server. Decodable from a SQL row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    /// Stored as a raw `i64`, whereas `LanguageServerId` wraps an `i32`.
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
    version: i32,
}
2675
// `LanguageServerId`: an `i32` newtype generated by `id_type!`.
id_type!(LanguageServerId);
/// A language server id/name pair, decodable from a SQL row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2682
/// A project that was left, with the host's identity and a list of connection ids.
// NOTE(review): presumably `connection_ids` are the remaining connections to be
// notified of the departure — confirm against the code that builds this.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// The result of leaving a room: the room's state, the projects that were left
/// (keyed by id), and the users whose calls were canceled.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2695
/// A contact relationship. Every variant carries a `user_id`; see `Contact::user_id`.
// NOTE(review): presumably `user_id` is the *other* user in the relationship and
// `Outgoing`/`Incoming` are relative to the user the contacts were loaded for — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    /// A pending contact request in the outgoing direction.
    Outgoing {
        user_id: UserId,
    },
    /// A pending contact request in the incoming direction.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2711
2712impl Contact {
2713    pub fn user_id(&self) -> UserId {
2714        match self {
2715            Contact::Accepted { user_id, .. } => *user_id,
2716            Contact::Outgoing { user_id } => *user_id,
2717            Contact::Incoming { user_id, .. } => *user_id,
2718        }
2719    }
2720}
2721
/// A pending incoming contact request.
///
/// The derived ordering compares `requester_id` first, then `should_notify`,
/// so field order here is significant.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2727
/// A signup submission, deserialized via serde.
// NOTE(review): presumably this is a waitlist signup request body — confirm at the call site.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platform flags; each indicates interest in the named platform.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// `None` when no device id was supplied.
    pub device_id: Option<String>,
}
2738
/// Aggregate waitlist counts, decodable from a SQL row via `FromRow`.
///
/// Every field is `#[sqlx(default)]`: if its column is absent from the row, it
/// falls back to `i64::default()` (0) instead of failing to decode.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2752
/// An invite addressed to an email, together with its confirmation code.
/// Decodable from a SQL row via `FromRow`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    // NOTE(review): presumably produced by `random_email_confirmation_code` — confirm.
    pub email_confirmation_code: String,
}
2758
/// Parameters describing a user to be created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}

/// Outcome of creating a user: the new id, its metrics id, and, when known,
/// the inviting user and the signup's device id.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// `None` when no inviting user is associated with the new user.
    pub inviting_user_id: Option<UserId>,
    /// `None` when no signup device id is associated with the new user.
    pub signup_device_id: Option<String>,
}
2773
2774fn random_invite_code() -> String {
2775    nanoid::nanoid!(16)
2776}
2777
2778fn random_email_confirmation_code() -> String {
2779    nanoid::nanoid!(64)
2780}
2781
2782#[cfg(test)]
2783pub use test::*;
2784
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Test fixture backed by a uniquely named in-memory SQLite database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool during setup and owned by the fixture.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// Test fixture backed by a freshly created Postgres database.
    pub struct PostgresTestDb {
        /// An `Option` so that `Drop` can move the `Db` out before tearing it down.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the database created for this fixture.
        pub url: String,
    }

    /// Builds the current-thread tokio runtime (I/O and timers enabled) that a
    /// fixture drives its async setup on and then stores on its `Db`.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    /// A random value from a freshly entropy-seeded `StdRng`, used to give each
    /// test database a distinct name.
    fn random_db_suffix() -> u128 {
        StdRng::from_entropy().gen::<u128>()
    }

    impl SqliteTestDb {
        /// Creates an in-memory SQLite database, runs the SQLite migrations on it,
        /// and stores `background` plus the setup runtime on the resulting `Db`.
        pub fn new(background: Arc<Background>) -> Self {
            let url = format!("file:zed-test-{}?mode=memory", random_db_suffix());
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_dir = Path::new(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite"
                ));
                db.migrate(migrations_dir, false).await.unwrap();
                // NOTE(review): presumably held so the in-memory database stays alive
                // for the lifetime of the fixture — confirm.
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, conn }
        }

        /// Returns the fixture's `Db`; panics if it has been taken out.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on localhost, runs the
        /// migrations on it, and stores `background` plus the setup runtime on the
        /// resulting `Db`.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so only one fixture in this process
            // creates and migrates its database at a time.
            let _guard = LOCK.lock();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                random_db_suffix()
            );
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_dir = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
                db.migrate(migrations_dir, false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, url }
        }

        /// Returns the fixture's `Db`; panics if it has been taken out.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Moves the `Db` out of the fixture and hands it to `teardown` with `self.url`.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}