db.rs

   1mod schema;
   2#[cfg(test)]
   3mod tests;
   4
   5use crate::{Error, Result};
   6use anyhow::anyhow;
   7use axum::http::StatusCode;
   8use collections::{BTreeMap, HashMap, HashSet};
   9use dashmap::DashMap;
  10use futures::{future::BoxFuture, FutureExt, StreamExt};
  11use rpc::{proto, ConnectionId};
  12use sea_query::{Expr, Query};
  13use sea_query_binder::SqlxBinder;
  14use serde::{Deserialize, Serialize};
  15use sqlx::{
  16    migrate::{Migrate as _, Migration, MigrationSource},
  17    types::Uuid,
  18    FromRow,
  19};
  20use std::{
  21    future::Future,
  22    marker::PhantomData,
  23    ops::{Deref, DerefMut},
  24    path::Path,
  25    rc::Rc,
  26    sync::Arc,
  27    time::Duration,
  28};
  29use time::{OffsetDateTime, PrimitiveDateTime};
  30use tokio::sync::{Mutex, OwnedMutexGuard};
  31
  32#[cfg(test)]
  33pub type DefaultDb = Db<sqlx::Sqlite>;
  34
  35#[cfg(not(test))]
  36pub type DefaultDb = Db<sqlx::Postgres>;
  37
/// Database handle, generic over the sqlx backend `D`.
pub struct Db<D: sqlx::Database> {
    /// Connection pool for the backend.
    pool: sqlx::Pool<D>,
    /// One async mutex per room, keyed by `RoomId`. Presumably these are locked to
    /// serialize work on a single room and exposed through `RoomGuard` (the
    /// locking code is not visible in this section — confirm).
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Test-only background executor handle; how it is used is not visible here.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime; `teardown` unwraps it and blocks on it.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
  46
/// Bundles `data` with an owned guard on a room mutex, so the lock stays held
/// for as long as the `RoomGuard` lives and is released when it is dropped.
///
/// The `PhantomData<Rc<()>>` marker makes the type `!Send` and `!Sync`, so a
/// guard cannot be moved to or shared with another thread.
pub struct RoomGuard<T> {
    /// The value handed out under the lock; reachable through `Deref`/`DerefMut`.
    data: T,
    /// Keeps the room mutex locked until this guard is dropped.
    _guard: OwnedMutexGuard<()>,
    /// Zero-sized marker that opts the type out of `Send`/`Sync`.
    _not_send: PhantomData<Rc<()>>,
}
  52
  53impl<T> Deref for RoomGuard<T> {
  54    type Target = T;
  55
  56    fn deref(&self) -> &T {
  57        &self.data
  58    }
  59}
  60
  61impl<T> DerefMut for RoomGuard<T> {
  62    fn deref_mut(&mut self) -> &mut T {
  63        &mut self.data
  64    }
  65}
  66
/// Backend-specific way of opening a transaction.
///
/// Implementors return a boxed future resolving to a transaction that does not
/// borrow the pool (`'static`), so it can be handed to closures by value.
pub trait BeginTransaction: Send + Sync {
    /// The sqlx backend the transaction runs on.
    type Database: sqlx::Database;

    /// Begins a new transaction; any backend-specific setup happens here.
    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
}
  72
  73// In Postgres, serializable transactions are opt-in
  74impl BeginTransaction for Db<sqlx::Postgres> {
  75    type Database = sqlx::Postgres;
  76
  77    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
  78        async move {
  79            let mut tx = self.pool.begin().await?;
  80            sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
  81                .await?;
  82            Ok(tx)
  83        }
  84        .boxed()
  85    }
  86}
  87
  88// In Sqlite, transactions are inherently serializable.
  89#[cfg(test)]
  90impl BeginTransaction for Db<sqlx::Sqlite> {
  91    type Database = sqlx::Sqlite;
  92
  93    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
  94        async move { Ok(self.pool.begin().await?) }.boxed()
  95    }
  96}
  97
/// Renders a sea-query statement into SQL text plus bound values, using the
/// SQL dialect of the implementing backend.
pub trait BuildQuery {
    /// Returns the SQL string and the values to bind to its placeholders.
    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues);
}
 101
 102impl BuildQuery for Db<sqlx::Postgres> {
 103    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
 104        query.build_sqlx(sea_query::PostgresQueryBuilder)
 105    }
 106}
 107
 108#[cfg(test)]
 109impl BuildQuery for Db<sqlx::Sqlite> {
 110    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
 111        query.build_sqlx(sea_query::SqliteQueryBuilder)
 112    }
 113}
 114
/// Backend-agnostic access to the number of rows a statement affected, so the
/// generic `Db<D>` code can bound on `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Number of rows inserted, updated, or deleted by the statement.
    fn rows_affected(&self) -> u64;
}
 118
// Only needed in test builds, where SQLite is the backend.
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Intended to call sqlx's inherent `SqliteQueryResult::rows_affected`.
        // Inherent methods take precedence over trait methods during method
        // resolution, so this does not recurse into this trait impl.
        self.rows_affected()
    }
}
 125
impl RowsAffected for sqlx::postgres::PgQueryResult {
    fn rows_affected(&self) -> u64 {
        // Intended to call sqlx's inherent `PgQueryResult::rows_affected`.
        // Inherent methods take precedence over trait methods during method
        // resolution, so this does not recurse into this trait impl.
        self.rows_affected()
    }
}
 131
 132#[cfg(test)]
 133impl Db<sqlx::Sqlite> {
 134    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 135        use std::str::FromStr as _;
 136        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
 137            .unwrap()
 138            .create_if_missing(true)
 139            .shared_cache(true);
 140        let pool = sqlx::sqlite::SqlitePoolOptions::new()
 141            .min_connections(2)
 142            .max_connections(max_connections)
 143            .connect_with(options)
 144            .await?;
 145        Ok(Self {
 146            pool,
 147            rooms: Default::default(),
 148            background: None,
 149            runtime: None,
 150        })
 151    }
 152
 153    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 154        self.transact(|tx| async {
 155            let mut tx = tx;
 156            let query = "
 157                SELECT users.*
 158                FROM users
 159                WHERE users.id IN (SELECT value from json_each($1))
 160            ";
 161            Ok(sqlx::query_as(query)
 162                .bind(&serde_json::json!(ids))
 163                .fetch_all(&mut tx)
 164                .await?)
 165        })
 166        .await
 167    }
 168
 169    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 170        self.transact(|mut tx| async move {
 171            let query = "
 172                SELECT metrics_id
 173                FROM users
 174                WHERE id = $1
 175            ";
 176            Ok(sqlx::query_scalar(query)
 177                .bind(id)
 178                .fetch_one(&mut tx)
 179                .await?)
 180        })
 181        .await
 182    }
 183
 184    pub async fn create_user(
 185        &self,
 186        email_address: &str,
 187        admin: bool,
 188        params: NewUserParams,
 189    ) -> Result<NewUserResult> {
 190        self.transact(|mut tx| async {
 191            let query = "
 192                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 193                VALUES ($1, $2, $3, $4, $5)
 194                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 195                RETURNING id, metrics_id
 196            ";
 197
 198            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 199                .bind(email_address)
 200                .bind(&params.github_login)
 201                .bind(&params.github_user_id)
 202                .bind(admin)
 203                .bind(Uuid::new_v4().to_string())
 204                .fetch_one(&mut tx)
 205                .await?;
 206            tx.commit().await?;
 207            Ok(NewUserResult {
 208                user_id,
 209                metrics_id,
 210                signup_device_id: None,
 211                inviting_user_id: None,
 212            })
 213        })
 214        .await
 215    }
 216
 217    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 218        unimplemented!()
 219    }
 220
 221    pub async fn create_user_from_invite(
 222        &self,
 223        _invite: &Invite,
 224        _user: NewUserParams,
 225    ) -> Result<Option<NewUserResult>> {
 226        unimplemented!()
 227    }
 228
 229    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 230        unimplemented!()
 231    }
 232
 233    pub async fn create_invite_from_code(
 234        &self,
 235        _code: &str,
 236        _email_address: &str,
 237        _device_id: Option<&str>,
 238    ) -> Result<Invite> {
 239        unimplemented!()
 240    }
 241
 242    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 243        unimplemented!()
 244    }
 245}
 246
 247impl Db<sqlx::Postgres> {
 248    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 249        let pool = sqlx::postgres::PgPoolOptions::new()
 250            .max_connections(max_connections)
 251            .connect(url)
 252            .await?;
 253        Ok(Self {
 254            pool,
 255            rooms: DashMap::with_capacity(16384),
 256            #[cfg(test)]
 257            background: None,
 258            #[cfg(test)]
 259            runtime: None,
 260        })
 261    }
 262
 263    #[cfg(test)]
 264    pub fn teardown(&self, url: &str) {
 265        self.runtime.as_ref().unwrap().block_on(async {
 266            use util::ResultExt;
 267            let query = "
 268                SELECT pg_terminate_backend(pg_stat_activity.pid)
 269                FROM pg_stat_activity
 270                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 271            ";
 272            sqlx::query(query).execute(&self.pool).await.log_err();
 273            self.pool.close().await;
 274            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 275                .await
 276                .log_err();
 277        })
 278    }
 279
 280    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 281        self.transact(|tx| async {
 282            let mut tx = tx;
 283            let like_string = Self::fuzzy_like_string(name_query);
 284            let query = "
 285                SELECT users.*
 286                FROM users
 287                WHERE github_login ILIKE $1
 288                ORDER BY github_login <-> $2
 289                LIMIT $3
 290            ";
 291            Ok(sqlx::query_as(query)
 292                .bind(like_string)
 293                .bind(name_query)
 294                .bind(limit as i32)
 295                .fetch_all(&mut tx)
 296                .await?)
 297        })
 298        .await
 299    }
 300
 301    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 302        let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
 303        self.transact(|tx| async {
 304            let mut tx = tx;
 305            let query = "
 306                SELECT users.*
 307                FROM users
 308                WHERE users.id = ANY ($1)
 309            ";
 310            Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
 311        })
 312        .await
 313    }
 314
 315    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 316        self.transact(|mut tx| async move {
 317            let query = "
 318                SELECT metrics_id::text
 319                FROM users
 320                WHERE id = $1
 321            ";
 322            Ok(sqlx::query_scalar(query)
 323                .bind(id)
 324                .fetch_one(&mut tx)
 325                .await?)
 326        })
 327        .await
 328    }
 329
 330    pub async fn create_user(
 331        &self,
 332        email_address: &str,
 333        admin: bool,
 334        params: NewUserParams,
 335    ) -> Result<NewUserResult> {
 336        self.transact(|mut tx| async {
 337            let query = "
 338                INSERT INTO users (email_address, github_login, github_user_id, admin)
 339                VALUES ($1, $2, $3, $4)
 340                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 341                RETURNING id, metrics_id::text
 342            ";
 343
 344            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 345                .bind(email_address)
 346                .bind(&params.github_login)
 347                .bind(params.github_user_id)
 348                .bind(admin)
 349                .fetch_one(&mut tx)
 350                .await?;
 351            tx.commit().await?;
 352
 353            Ok(NewUserResult {
 354                user_id,
 355                metrics_id,
 356                signup_device_id: None,
 357                inviting_user_id: None,
 358            })
 359        })
 360        .await
 361    }
 362
 363    pub async fn create_user_from_invite(
 364        &self,
 365        invite: &Invite,
 366        user: NewUserParams,
 367    ) -> Result<Option<NewUserResult>> {
 368        self.transact(|mut tx| async {
 369            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 370                i32,
 371                Option<UserId>,
 372                Option<UserId>,
 373                Option<String>,
 374            ) = sqlx::query_as(
 375                "
 376                SELECT id, user_id, inviting_user_id, device_id
 377                FROM signups
 378                WHERE
 379                    email_address = $1 AND
 380                    email_confirmation_code = $2
 381                ",
 382            )
 383            .bind(&invite.email_address)
 384            .bind(&invite.email_confirmation_code)
 385            .fetch_optional(&mut tx)
 386            .await?
 387            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 388
 389            if existing_user_id.is_some() {
 390                return Ok(None);
 391            }
 392
 393            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 394                "
 395                INSERT INTO users
 396                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 397                VALUES
 398                ($1, $2, $3, FALSE, $4, $5)
 399                ON CONFLICT (github_login) DO UPDATE SET
 400                    email_address = excluded.email_address,
 401                    github_user_id = excluded.github_user_id,
 402                    admin = excluded.admin
 403                RETURNING id, metrics_id::text
 404                ",
 405            )
 406            .bind(&invite.email_address)
 407            .bind(&user.github_login)
 408            .bind(&user.github_user_id)
 409            .bind(&user.invite_count)
 410            .bind(random_invite_code())
 411            .fetch_one(&mut tx)
 412            .await?;
 413
 414            sqlx::query(
 415                "
 416                UPDATE signups
 417                SET user_id = $1
 418                WHERE id = $2
 419                ",
 420            )
 421            .bind(&user_id)
 422            .bind(&signup_id)
 423            .execute(&mut tx)
 424            .await?;
 425
 426            if let Some(inviting_user_id) = inviting_user_id {
 427                let id: Option<UserId> = sqlx::query_scalar(
 428                    "
 429                    UPDATE users
 430                    SET invite_count = invite_count - 1
 431                    WHERE id = $1 AND invite_count > 0
 432                    RETURNING id
 433                    ",
 434                )
 435                .bind(&inviting_user_id)
 436                .fetch_optional(&mut tx)
 437                .await?;
 438
 439                if id.is_none() {
 440                    Err(Error::Http(
 441                        StatusCode::UNAUTHORIZED,
 442                        "no invites remaining".to_string(),
 443                    ))?;
 444                }
 445
 446                sqlx::query(
 447                    "
 448                    INSERT INTO contacts
 449                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 450                    VALUES
 451                        ($1, $2, TRUE, TRUE, TRUE)
 452                    ON CONFLICT DO NOTHING
 453                    ",
 454                )
 455                .bind(inviting_user_id)
 456                .bind(user_id)
 457                .execute(&mut tx)
 458                .await?;
 459            }
 460
 461            tx.commit().await?;
 462            Ok(Some(NewUserResult {
 463                user_id,
 464                metrics_id,
 465                inviting_user_id,
 466                signup_device_id,
 467            }))
 468        })
 469        .await
 470    }
 471
 472    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 473        self.transact(|mut tx| async {
 474            sqlx::query(
 475                "
 476                INSERT INTO signups
 477                (
 478                    email_address,
 479                    email_confirmation_code,
 480                    email_confirmation_sent,
 481                    platform_linux,
 482                    platform_mac,
 483                    platform_windows,
 484                    platform_unknown,
 485                    editor_features,
 486                    programming_languages,
 487                    device_id
 488                )
 489                VALUES
 490                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 491                RETURNING id
 492                ",
 493            )
 494            .bind(&signup.email_address)
 495            .bind(&random_email_confirmation_code())
 496            .bind(&signup.platform_linux)
 497            .bind(&signup.platform_mac)
 498            .bind(&signup.platform_windows)
 499            .bind(&signup.editor_features)
 500            .bind(&signup.programming_languages)
 501            .bind(&signup.device_id)
 502            .execute(&mut tx)
 503            .await?;
 504            tx.commit().await?;
 505            Ok(())
 506        })
 507        .await
 508    }
 509
 510    pub async fn create_invite_from_code(
 511        &self,
 512        code: &str,
 513        email_address: &str,
 514        device_id: Option<&str>,
 515    ) -> Result<Invite> {
 516        self.transact(|mut tx| async {
 517            let existing_user: Option<UserId> = sqlx::query_scalar(
 518                "
 519                SELECT id
 520                FROM users
 521                WHERE email_address = $1
 522                ",
 523            )
 524            .bind(email_address)
 525            .fetch_optional(&mut tx)
 526            .await?;
 527            if existing_user.is_some() {
 528                Err(anyhow!("email address is already in use"))?;
 529            }
 530
 531            let row: Option<(UserId, i32)> = sqlx::query_as(
 532                "
 533                SELECT id, invite_count
 534                FROM users
 535                WHERE invite_code = $1
 536                ",
 537            )
 538            .bind(code)
 539            .fetch_optional(&mut tx)
 540            .await?;
 541
 542            let (inviter_id, invite_count) = match row {
 543                Some(row) => row,
 544                None => Err(Error::Http(
 545                    StatusCode::NOT_FOUND,
 546                    "invite code not found".to_string(),
 547                ))?,
 548            };
 549
 550            if invite_count == 0 {
 551                Err(Error::Http(
 552                    StatusCode::UNAUTHORIZED,
 553                    "no invites remaining".to_string(),
 554                ))?;
 555            }
 556
 557            let email_confirmation_code: String = sqlx::query_scalar(
 558                "
 559                INSERT INTO signups
 560                (
 561                    email_address,
 562                    email_confirmation_code,
 563                    email_confirmation_sent,
 564                    inviting_user_id,
 565                    platform_linux,
 566                    platform_mac,
 567                    platform_windows,
 568                    platform_unknown,
 569                    device_id
 570                )
 571                VALUES
 572                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 573                ON CONFLICT (email_address)
 574                DO UPDATE SET
 575                    inviting_user_id = excluded.inviting_user_id
 576                RETURNING email_confirmation_code
 577                ",
 578            )
 579            .bind(&email_address)
 580            .bind(&random_email_confirmation_code())
 581            .bind(&inviter_id)
 582            .bind(&device_id)
 583            .fetch_one(&mut tx)
 584            .await?;
 585
 586            tx.commit().await?;
 587
 588            Ok(Invite {
 589                email_address: email_address.into(),
 590                email_confirmation_code,
 591            })
 592        })
 593        .await
 594    }
 595
 596    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 597        self.transact(|mut tx| async {
 598            let emails = invites
 599                .iter()
 600                .map(|s| s.email_address.as_str())
 601                .collect::<Vec<_>>();
 602            sqlx::query(
 603                "
 604                UPDATE signups
 605                SET email_confirmation_sent = TRUE
 606                WHERE email_address = ANY ($1)
 607                ",
 608            )
 609            .bind(&emails)
 610            .execute(&mut tx)
 611            .await?;
 612            tx.commit().await?;
 613            Ok(())
 614        })
 615        .await
 616    }
 617}
 618
 619impl<D> Db<D>
 620where
 621    Self: BeginTransaction<Database = D> + BuildQuery,
 622    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 623    D::Connection: sqlx::migrate::Migrate,
 624    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 625    for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>,
 626    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 627    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 628    D::QueryResult: RowsAffected,
 629    String: sqlx::Type<D>,
 630    i32: sqlx::Type<D>,
 631    i64: sqlx::Type<D>,
 632    bool: sqlx::Type<D>,
 633    str: sqlx::Type<D>,
 634    Uuid: sqlx::Type<D>,
 635    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 636    OffsetDateTime: sqlx::Type<D>,
 637    PrimitiveDateTime: sqlx::Type<D>,
 638    usize: sqlx::ColumnIndex<D::Row>,
 639    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 640    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 641    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 642    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 643    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 644    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 645    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 646    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 647    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 648    for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 649    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 650    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 651    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 652{
 653    pub async fn migrate(
 654        &self,
 655        migrations_path: &Path,
 656        ignore_checksum_mismatch: bool,
 657    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 658        let migrations = MigrationSource::resolve(migrations_path)
 659            .await
 660            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 661
 662        let mut conn = self.pool.acquire().await?;
 663
 664        conn.ensure_migrations_table().await?;
 665        let applied_migrations: HashMap<_, _> = conn
 666            .list_applied_migrations()
 667            .await?
 668            .into_iter()
 669            .map(|m| (m.version, m))
 670            .collect();
 671
 672        let mut new_migrations = Vec::new();
 673        for migration in migrations {
 674            match applied_migrations.get(&migration.version) {
 675                Some(applied_migration) => {
 676                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 677                    {
 678                        Err(anyhow!(
 679                            "checksum mismatch for applied migration {}",
 680                            migration.description
 681                        ))?;
 682                    }
 683                }
 684                None => {
 685                    let elapsed = conn.apply(&migration).await?;
 686                    new_migrations.push((migration, elapsed));
 687                }
 688            }
 689        }
 690
 691        Ok(new_migrations)
 692    }
 693
 694    pub fn fuzzy_like_string(string: &str) -> String {
 695        let mut result = String::with_capacity(string.len() * 2 + 1);
 696        for c in string.chars() {
 697            if c.is_alphanumeric() {
 698                result.push('%');
 699                result.push(c);
 700            }
 701        }
 702        result.push('%');
 703        result
 704    }
 705
 706    // users
 707
 708    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 709        self.transact(|tx| async {
 710            let mut tx = tx;
 711            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 712            Ok(sqlx::query_as(query)
 713                .bind(limit as i32)
 714                .bind((page * limit) as i32)
 715                .fetch_all(&mut tx)
 716                .await?)
 717        })
 718        .await
 719    }
 720
 721    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 722        self.transact(|tx| async {
 723            let mut tx = tx;
 724            let query = "
 725                SELECT users.*
 726                FROM users
 727                WHERE id = $1
 728                LIMIT 1
 729            ";
 730            Ok(sqlx::query_as(query)
 731                .bind(&id)
 732                .fetch_optional(&mut tx)
 733                .await?)
 734        })
 735        .await
 736    }
 737
 738    pub async fn get_users_with_no_invites(
 739        &self,
 740        invited_by_another_user: bool,
 741    ) -> Result<Vec<User>> {
 742        self.transact(|tx| async {
 743            let mut tx = tx;
 744            let query = format!(
 745                "
 746                SELECT users.*
 747                FROM users
 748                WHERE invite_count = 0
 749                AND inviter_id IS{} NULL
 750                ",
 751                if invited_by_another_user { " NOT" } else { "" }
 752            );
 753
 754            Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
 755        })
 756        .await
 757    }
 758
 759    pub async fn get_user_by_github_account(
 760        &self,
 761        github_login: &str,
 762        github_user_id: Option<i32>,
 763    ) -> Result<Option<User>> {
 764        self.transact(|tx| async {
 765            let mut tx = tx;
 766            if let Some(github_user_id) = github_user_id {
 767                let mut user = sqlx::query_as::<_, User>(
 768                    "
 769                    UPDATE users
 770                    SET github_login = $1
 771                    WHERE github_user_id = $2
 772                    RETURNING *
 773                    ",
 774                )
 775                .bind(github_login)
 776                .bind(github_user_id)
 777                .fetch_optional(&mut tx)
 778                .await?;
 779
 780                if user.is_none() {
 781                    user = sqlx::query_as::<_, User>(
 782                        "
 783                        UPDATE users
 784                        SET github_user_id = $1
 785                        WHERE github_login = $2
 786                        RETURNING *
 787                        ",
 788                    )
 789                    .bind(github_user_id)
 790                    .bind(github_login)
 791                    .fetch_optional(&mut tx)
 792                    .await?;
 793                }
 794
 795                Ok(user)
 796            } else {
 797                let user = sqlx::query_as(
 798                    "
 799                    SELECT * FROM users
 800                    WHERE github_login = $1
 801                    LIMIT 1
 802                    ",
 803                )
 804                .bind(github_login)
 805                .fetch_optional(&mut tx)
 806                .await?;
 807                Ok(user)
 808            }
 809        })
 810        .await
 811    }
 812
 813    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 814        self.transact(|mut tx| async {
 815            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 816            sqlx::query(query)
 817                .bind(is_admin)
 818                .bind(id.0)
 819                .execute(&mut tx)
 820                .await?;
 821            tx.commit().await?;
 822            Ok(())
 823        })
 824        .await
 825    }
 826
 827    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 828        self.transact(|mut tx| async move {
 829            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 830            sqlx::query(query)
 831                .bind(connected_once)
 832                .bind(id.0)
 833                .execute(&mut tx)
 834                .await?;
 835            tx.commit().await?;
 836            Ok(())
 837        })
 838        .await
 839    }
 840
 841    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 842        self.transact(|mut tx| async move {
 843            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 844            sqlx::query(query)
 845                .bind(id.0)
 846                .execute(&mut tx)
 847                .await
 848                .map(drop)?;
 849            let query = "DELETE FROM users WHERE id = $1;";
 850            sqlx::query(query).bind(id.0).execute(&mut tx).await?;
 851            tx.commit().await?;
 852            Ok(())
 853        })
 854        .await
 855    }
 856
 857    // signups
 858
 859    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 860        self.transact(|mut tx| async move {
 861            Ok(sqlx::query_as(
 862                "
 863                SELECT
 864                    COUNT(*) as count,
 865                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 866                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 867                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 868                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 869                FROM (
 870                    SELECT *
 871                    FROM signups
 872                    WHERE
 873                        NOT email_confirmation_sent
 874                ) AS unsent
 875                ",
 876            )
 877            .fetch_one(&mut tx)
 878            .await?)
 879        })
 880        .await
 881    }
 882
 883    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 884        self.transact(|mut tx| async move {
 885            Ok(sqlx::query_as(
 886                "
 887                SELECT
 888                    email_address, email_confirmation_code
 889                FROM signups
 890                WHERE
 891                    NOT email_confirmation_sent AND
 892                    (platform_mac OR platform_unknown)
 893                LIMIT $1
 894                ",
 895            )
 896            .bind(count as i32)
 897            .fetch_all(&mut tx)
 898            .await?)
 899        })
 900        .await
 901    }
 902
 903    // invite codes
 904
 905    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 906        self.transact(|mut tx| async move {
 907            if count > 0 {
 908                sqlx::query(
 909                    "
 910                    UPDATE users
 911                    SET invite_code = $1
 912                    WHERE id = $2 AND invite_code IS NULL
 913                ",
 914                )
 915                .bind(random_invite_code())
 916                .bind(id)
 917                .execute(&mut tx)
 918                .await?;
 919            }
 920
 921            sqlx::query(
 922                "
 923                UPDATE users
 924                SET invite_count = $1
 925                WHERE id = $2
 926                ",
 927            )
 928            .bind(count as i32)
 929            .bind(id)
 930            .execute(&mut tx)
 931            .await?;
 932            tx.commit().await?;
 933            Ok(())
 934        })
 935        .await
 936    }
 937
 938    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 939        self.transact(|mut tx| async move {
 940            let result: Option<(String, i32)> = sqlx::query_as(
 941                "
 942                    SELECT invite_code, invite_count
 943                    FROM users
 944                    WHERE id = $1 AND invite_code IS NOT NULL 
 945                ",
 946            )
 947            .bind(id)
 948            .fetch_optional(&mut tx)
 949            .await?;
 950            if let Some((code, count)) = result {
 951                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 952            } else {
 953                Ok(None)
 954            }
 955        })
 956        .await
 957    }
 958
 959    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 960        self.transact(|tx| async {
 961            let mut tx = tx;
 962            sqlx::query_as(
 963                "
 964                    SELECT *
 965                    FROM users
 966                    WHERE invite_code = $1
 967                ",
 968            )
 969            .bind(code)
 970            .fetch_optional(&mut tx)
 971            .await?
 972            .ok_or_else(|| {
 973                Error::Http(
 974                    StatusCode::NOT_FOUND,
 975                    "that invite code does not exist".to_string(),
 976                )
 977            })
 978        })
 979        .await
 980    }
 981
 982    async fn commit_room_transaction<'a, T>(
 983        &'a self,
 984        room_id: RoomId,
 985        tx: sqlx::Transaction<'static, D>,
 986        data: T,
 987    ) -> Result<RoomGuard<T>> {
 988        let lock = self.rooms.entry(room_id).or_default().clone();
 989        let _guard = lock.lock_owned().await;
 990        tx.commit().await?;
 991        Ok(RoomGuard {
 992            data,
 993            _guard,
 994            _not_send: PhantomData,
 995        })
 996    }
 997
 998    pub async fn create_room(
 999        &self,
1000        user_id: UserId,
1001        connection_id: ConnectionId,
1002        live_kit_room: &str,
1003    ) -> Result<RoomGuard<proto::Room>> {
1004        self.transact(|mut tx| async move {
1005            let room_id = sqlx::query_scalar(
1006                "
1007                INSERT INTO rooms (live_kit_room)
1008                VALUES ($1)
1009                RETURNING id
1010                ",
1011            )
1012            .bind(&live_kit_room)
1013            .fetch_one(&mut tx)
1014            .await
1015            .map(RoomId)?;
1016
1017            sqlx::query(
1018                "
1019                INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
1020                VALUES ($1, $2, $3, $4, $5)
1021                ",
1022            )
1023            .bind(room_id)
1024            .bind(user_id)
1025            .bind(connection_id.0 as i32)
1026            .bind(user_id)
1027            .bind(connection_id.0 as i32)
1028            .execute(&mut tx)
1029            .await?;
1030
1031            let room = self.get_room(room_id, &mut tx).await?;
1032            self.commit_room_transaction(room_id, tx, room).await
1033        }).await
1034    }
1035
1036    pub async fn call(
1037        &self,
1038        room_id: RoomId,
1039        calling_user_id: UserId,
1040        calling_connection_id: ConnectionId,
1041        called_user_id: UserId,
1042        initial_project_id: Option<ProjectId>,
1043    ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1044        self.transact(|mut tx| async move {
1045            sqlx::query(
1046                "
1047                INSERT INTO room_participants (
1048                    room_id,
1049                    user_id,
1050                    calling_user_id,
1051                    calling_connection_id,
1052                    initial_project_id
1053                )
1054                VALUES ($1, $2, $3, $4, $5)
1055                ",
1056            )
1057            .bind(room_id)
1058            .bind(called_user_id)
1059            .bind(calling_user_id)
1060            .bind(calling_connection_id.0 as i32)
1061            .bind(initial_project_id)
1062            .execute(&mut tx)
1063            .await?;
1064
1065            let room = self.get_room(room_id, &mut tx).await?;
1066            let incoming_call = Self::build_incoming_call(&room, called_user_id)
1067                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1068            self.commit_room_transaction(room_id, tx, (room, incoming_call))
1069                .await
1070        })
1071        .await
1072    }
1073
1074    pub async fn incoming_call_for_user(
1075        &self,
1076        user_id: UserId,
1077    ) -> Result<Option<proto::IncomingCall>> {
1078        self.transact(|mut tx| async move {
1079            let room_id = sqlx::query_scalar::<_, RoomId>(
1080                "
1081                SELECT room_id
1082                FROM room_participants
1083                WHERE user_id = $1 AND answering_connection_id IS NULL
1084                ",
1085            )
1086            .bind(user_id)
1087            .fetch_optional(&mut tx)
1088            .await?;
1089
1090            if let Some(room_id) = room_id {
1091                let room = self.get_room(room_id, &mut tx).await?;
1092                Ok(Self::build_incoming_call(&room, user_id))
1093            } else {
1094                Ok(None)
1095            }
1096        })
1097        .await
1098    }
1099
1100    fn build_incoming_call(
1101        room: &proto::Room,
1102        called_user_id: UserId,
1103    ) -> Option<proto::IncomingCall> {
1104        let pending_participant = room
1105            .pending_participants
1106            .iter()
1107            .find(|participant| participant.user_id == called_user_id.to_proto())?;
1108
1109        Some(proto::IncomingCall {
1110            room_id: room.id,
1111            calling_user_id: pending_participant.calling_user_id,
1112            participant_user_ids: room
1113                .participants
1114                .iter()
1115                .map(|participant| participant.user_id)
1116                .collect(),
1117            initial_project: room.participants.iter().find_map(|participant| {
1118                let initial_project_id = pending_participant.initial_project_id?;
1119                participant
1120                    .projects
1121                    .iter()
1122                    .find(|project| project.id == initial_project_id)
1123                    .cloned()
1124            }),
1125        })
1126    }
1127
1128    pub async fn call_failed(
1129        &self,
1130        room_id: RoomId,
1131        called_user_id: UserId,
1132    ) -> Result<RoomGuard<proto::Room>> {
1133        self.transact(|mut tx| async move {
1134            sqlx::query(
1135                "
1136                DELETE FROM room_participants
1137                WHERE room_id = $1 AND user_id = $2
1138                ",
1139            )
1140            .bind(room_id)
1141            .bind(called_user_id)
1142            .execute(&mut tx)
1143            .await?;
1144
1145            let room = self.get_room(room_id, &mut tx).await?;
1146            self.commit_room_transaction(room_id, tx, room).await
1147        })
1148        .await
1149    }
1150
1151    pub async fn decline_call(
1152        &self,
1153        expected_room_id: Option<RoomId>,
1154        user_id: UserId,
1155    ) -> Result<RoomGuard<proto::Room>> {
1156        self.transact(|mut tx| async move {
1157            let room_id = sqlx::query_scalar(
1158                "
1159                DELETE FROM room_participants
1160                WHERE user_id = $1 AND answering_connection_id IS NULL
1161                RETURNING room_id
1162                ",
1163            )
1164            .bind(user_id)
1165            .fetch_one(&mut tx)
1166            .await?;
1167            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1168                return Err(anyhow!("declining call on unexpected room"))?;
1169            }
1170
1171            let room = self.get_room(room_id, &mut tx).await?;
1172            self.commit_room_transaction(room_id, tx, room).await
1173        })
1174        .await
1175    }
1176
1177    pub async fn cancel_call(
1178        &self,
1179        expected_room_id: Option<RoomId>,
1180        calling_connection_id: ConnectionId,
1181        called_user_id: UserId,
1182    ) -> Result<RoomGuard<proto::Room>> {
1183        self.transact(|mut tx| async move {
1184            let room_id = sqlx::query_scalar(
1185                "
1186                DELETE FROM room_participants
1187                WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1188                RETURNING room_id
1189                ",
1190            )
1191            .bind(called_user_id)
1192            .bind(calling_connection_id.0 as i32)
1193            .fetch_one(&mut tx)
1194            .await?;
1195            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1196                return Err(anyhow!("canceling call on unexpected room"))?;
1197            }
1198
1199            let room = self.get_room(room_id, &mut tx).await?;
1200            self.commit_room_transaction(room_id, tx, room).await
1201        }).await
1202    }
1203
1204    pub async fn join_room(
1205        &self,
1206        room_id: RoomId,
1207        user_id: UserId,
1208        connection_id: ConnectionId,
1209    ) -> Result<RoomGuard<proto::Room>> {
1210        self.transact(|mut tx| async move {
1211            sqlx::query(
1212                "
1213                UPDATE room_participants 
1214                SET answering_connection_id = $1
1215                WHERE room_id = $2 AND user_id = $3
1216                RETURNING 1
1217                ",
1218            )
1219            .bind(connection_id.0 as i32)
1220            .bind(room_id)
1221            .bind(user_id)
1222            .fetch_one(&mut tx)
1223            .await?;
1224
1225            let room = self.get_room(room_id, &mut tx).await?;
1226            self.commit_room_transaction(room_id, tx, room).await
1227        })
1228        .await
1229    }
1230
1231    pub async fn leave_room(
1232        &self,
1233        connection_id: ConnectionId,
1234    ) -> Result<Option<RoomGuard<LeftRoom>>> {
1235        self.transact(|mut tx| async move {
1236            // Leave room.
1237            let room_id = sqlx::query_scalar::<_, RoomId>(
1238                "
1239                DELETE FROM room_participants
1240                WHERE answering_connection_id = $1
1241                RETURNING room_id
1242                ",
1243            )
1244            .bind(connection_id.0 as i32)
1245            .fetch_optional(&mut tx)
1246            .await?;
1247
1248            if let Some(room_id) = room_id {
1249                // Cancel pending calls initiated by the leaving user.
1250                let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1251                    "
1252                    DELETE FROM room_participants
1253                    WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1254                    RETURNING user_id
1255                    ",
1256                )
1257                .bind(connection_id.0 as i32)
1258                .fetch_all(&mut tx)
1259                .await?;
1260
1261                let project_ids = sqlx::query_scalar::<_, ProjectId>(
1262                    "
1263                    SELECT project_id
1264                    FROM project_collaborators
1265                    WHERE connection_id = $1
1266                    ",
1267                )
1268                .bind(connection_id.0 as i32)
1269                .fetch_all(&mut tx)
1270                .await?;
1271
1272                // Leave projects.
1273                let mut left_projects = HashMap::default();
1274                if !project_ids.is_empty() {
1275                    let mut params = "?,".repeat(project_ids.len());
1276                    params.pop();
1277                    let query = format!(
1278                        "
1279                        SELECT *
1280                        FROM project_collaborators
1281                        WHERE project_id IN ({params})
1282                    "
1283                    );
1284                    let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1285                    for project_id in project_ids {
1286                        query = query.bind(project_id);
1287                    }
1288
1289                    let mut project_collaborators = query.fetch(&mut tx);
1290                    while let Some(collaborator) = project_collaborators.next().await {
1291                        let collaborator = collaborator?;
1292                        let left_project =
1293                            left_projects
1294                                .entry(collaborator.project_id)
1295                                .or_insert(LeftProject {
1296                                    id: collaborator.project_id,
1297                                    host_user_id: Default::default(),
1298                                    connection_ids: Default::default(),
1299                                    host_connection_id: Default::default(),
1300                                });
1301
1302                        let collaborator_connection_id =
1303                            ConnectionId(collaborator.connection_id as u32);
1304                        if collaborator_connection_id != connection_id {
1305                            left_project.connection_ids.push(collaborator_connection_id);
1306                        }
1307
1308                        if collaborator.is_host {
1309                            left_project.host_user_id = collaborator.user_id;
1310                            left_project.host_connection_id =
1311                                ConnectionId(collaborator.connection_id as u32);
1312                        }
1313                    }
1314                }
1315                sqlx::query(
1316                    "
1317                    DELETE FROM project_collaborators
1318                    WHERE connection_id = $1
1319                    ",
1320                )
1321                .bind(connection_id.0 as i32)
1322                .execute(&mut tx)
1323                .await?;
1324
1325                // Unshare projects.
1326                sqlx::query(
1327                    "
1328                    DELETE FROM projects
1329                    WHERE room_id = $1 AND host_connection_id = $2
1330                    ",
1331                )
1332                .bind(room_id)
1333                .bind(connection_id.0 as i32)
1334                .execute(&mut tx)
1335                .await?;
1336
1337                let room = self.get_room(room_id, &mut tx).await?;
1338                Ok(Some(
1339                    self.commit_room_transaction(
1340                        room_id,
1341                        tx,
1342                        LeftRoom {
1343                            room,
1344                            left_projects,
1345                            canceled_calls_to_user_ids,
1346                        },
1347                    )
1348                    .await?,
1349                ))
1350            } else {
1351                Ok(None)
1352            }
1353        })
1354        .await
1355    }
1356
1357    pub async fn update_room_participant_location(
1358        &self,
1359        room_id: RoomId,
1360        connection_id: ConnectionId,
1361        location: proto::ParticipantLocation,
1362    ) -> Result<RoomGuard<proto::Room>> {
1363        self.transact(|tx| async {
1364            let mut tx = tx;
1365            let location_kind;
1366            let location_project_id;
1367            match location
1368                .variant
1369                .as_ref()
1370                .ok_or_else(|| anyhow!("invalid location"))?
1371            {
1372                proto::participant_location::Variant::SharedProject(project) => {
1373                    location_kind = 0;
1374                    location_project_id = Some(ProjectId::from_proto(project.id));
1375                }
1376                proto::participant_location::Variant::UnsharedProject(_) => {
1377                    location_kind = 1;
1378                    location_project_id = None;
1379                }
1380                proto::participant_location::Variant::External(_) => {
1381                    location_kind = 2;
1382                    location_project_id = None;
1383                }
1384            }
1385
1386            sqlx::query(
1387                "
1388                UPDATE room_participants
1389                SET location_kind = $1, location_project_id = $2
1390                WHERE room_id = $3 AND answering_connection_id = $4
1391                RETURNING 1
1392                ",
1393            )
1394            .bind(location_kind)
1395            .bind(location_project_id)
1396            .bind(room_id)
1397            .bind(connection_id.0 as i32)
1398            .fetch_one(&mut tx)
1399            .await?;
1400
1401            let room = self.get_room(room_id, &mut tx).await?;
1402            self.commit_room_transaction(room_id, tx, room).await
1403        })
1404        .await
1405    }
1406
1407    async fn get_guest_connection_ids(
1408        &self,
1409        project_id: ProjectId,
1410        tx: &mut sqlx::Transaction<'_, D>,
1411    ) -> Result<Vec<ConnectionId>> {
1412        let mut guest_connection_ids = Vec::new();
1413        let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1414            "
1415            SELECT connection_id
1416            FROM project_collaborators
1417            WHERE project_id = $1 AND is_host = FALSE
1418            ",
1419        )
1420        .bind(project_id)
1421        .fetch(tx);
1422        while let Some(connection_id) = db_guest_connection_ids.next().await {
1423            guest_connection_ids.push(ConnectionId(connection_id? as u32));
1424        }
1425        Ok(guest_connection_ids)
1426    }
1427
    /// Loads the full protocol representation of room `room_id` using the
    /// caller's transaction.
    ///
    /// The result contains:
    /// - participants who answered (keyed internally by their answering
    ///   connection id, which also becomes their `peer_id`), each with their
    ///   location and the projects they host, including worktree root names;
    /// - pending participants, whose `answering_connection_id` is NULL.
    ///
    /// Fails if the room row does not exist (`fetch_one`) or if any query
    /// fails.
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                "
                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                FROM room_participants
                WHERE room_id = $1
                ",
            )
            .bind(room_id)
            .fetch(&mut *tx);

        // Answered participants keyed by answering connection id; used below to
        // attach hosted projects by `host_connection_id`.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                // Decode the location written by `update_room_participant_location`:
                // 0 + project id = shared project, 1 = unshared project, and any
                // other combination (2, NULL, or 0 without a project id) is
                // treated as external.
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: the user has been called but has
                // not joined.
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // The stream mutably borrows `tx`; release it before issuing the next
        // query on the same transaction.
        drop(db_participants);

        // LEFT JOIN keeps projects that have no worktrees; for those,
        // `worktree_root_name` is None and nothing is added to the root names.
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            // Projects whose host connection is not an answered participant are
            // skipped.
            if let Some(participant) = participants.get_mut(&connection_id) {
                // One row per (project, worktree): find or create the project
                // entry, then append this row's worktree root name (if any).
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        // NOTE(review): `participants` is a hash map, so the order of
        // `participants` in the returned room is unspecified.
        Ok(proto::Room {
            id: room.id.to_proto(),
            live_kit_room: room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1539
1540    // projects
1541
1542    pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1543        self.transact(|mut tx| async move {
1544            Ok(sqlx::query_scalar::<_, i32>(
1545                "
1546                SELECT COUNT(*)
1547                FROM projects, users
1548                WHERE projects.host_user_id = users.id AND users.admin IS FALSE
1549                ",
1550            )
1551            .fetch_one(&mut tx)
1552            .await? as usize)
1553        })
1554        .await
1555    }
1556
    /// Shares a new project hosted by the participant answering on
    /// `connection_id`.
    ///
    /// Steps:
    /// - looks up that participant's room and user;
    /// - inserts the project row, plus one `worktrees` row per entry in
    ///   `worktrees` (with `scan_id` 0 and `is_complete` false);
    /// - registers the host as a collaborator with `replica_id` 0 and
    ///   `is_host` true.
    ///
    /// Fails if the connection is not an answered participant of any room,
    /// or, before committing, if its room is not `expected_room_id`. On
    /// success, returns the new project id and the updated room while holding
    /// the room lock.
    pub async fn share_project(
        &self,
        expected_room_id: RoomId,
        connection_id: ConnectionId,
        worktrees: &[proto::WorktreeMetadata],
    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
        self.transact(|mut tx| async move {
            // Find the room and user behind the sharing connection.
            let (sql, values) = self.build_query(
                Query::select()
                    .columns([
                        schema::room_participant::Definition::RoomId,
                        schema::room_participant::Definition::UserId,
                    ])
                    .from(schema::room_participant::Definition::Table)
                    .and_where(
                        Expr::col(schema::room_participant::Definition::AnsweringConnectionId)
                            .eq(connection_id.0),
                    ),
            );
            let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values)
                .fetch_one(&mut tx)
                .await?;
            if room_id != expected_room_id {
                return Err(anyhow!("shared project on unexpected room"))?;
            }

            // Create the project row and get its generated id back.
            let (sql, values) = self.build_query(
                Query::insert()
                    .into_table(schema::project::Definition::Table)
                    .columns([
                        schema::project::Definition::RoomId,
                        schema::project::Definition::HostUserId,
                        schema::project::Definition::HostConnectionId,
                    ])
                    .values_panic([room_id.into(), user_id.into(), connection_id.0.into()])
                    .returning_col(schema::project::Definition::Id),
            );
            let project_id: ProjectId = sqlx::query_scalar_with(&sql, values)
                .fetch_one(&mut tx)
                .await?;

            // Insert all worktrees in a single multi-row INSERT; skipped when
            // there are none, since an INSERT needs at least one row of values.
            if !worktrees.is_empty() {
                let mut query = Query::insert()
                    .into_table(schema::worktree::Definition::Table)
                    .columns([
                        schema::worktree::Definition::ProjectId,
                        schema::worktree::Definition::Id,
                        schema::worktree::Definition::RootName,
                        schema::worktree::Definition::AbsPath,
                        schema::worktree::Definition::Visible,
                        schema::worktree::Definition::ScanId,
                        schema::worktree::Definition::IsComplete,
                    ])
                    .to_owned();
                for worktree in worktrees {
                    query.values_panic([
                        project_id.into(),
                        worktree.id.into(),
                        worktree.root_name.clone().into(),
                        worktree.abs_path.clone().into(),
                        worktree.visible.into(),
                        0.into(),
                        false.into(),
                    ]);
                }
                let (sql, values) = self.build_query(&query);
                sqlx::query_with(&sql, values).execute(&mut tx).await?;
            }

            // The host is the project's first collaborator, with replica id 0.
            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(project_id)
            .bind(connection_id.0 as i32)
            .bind(user_id)
            .bind(0)
            .bind(true)
            .execute(&mut tx)
            .await?;

            let room = self.get_room(room_id, &mut tx).await?;
            self.commit_room_transaction(room_id, tx, (project_id, room))
                .await
        })
        .await
    }
1652
1653    pub async fn unshare_project(
1654        &self,
1655        project_id: ProjectId,
1656        connection_id: ConnectionId,
1657    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1658        self.transact(|mut tx| async move {
1659            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1660            let room_id: RoomId = sqlx::query_scalar(
1661                "
1662                DELETE FROM projects
1663                WHERE id = $1 AND host_connection_id = $2
1664                RETURNING room_id
1665                ",
1666            )
1667            .bind(project_id)
1668            .bind(connection_id.0 as i32)
1669            .fetch_one(&mut tx)
1670            .await?;
1671            let room = self.get_room(room_id, &mut tx).await?;
1672            self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1673                .await
1674        })
1675        .await
1676    }
1677
1678    pub async fn update_project(
1679        &self,
1680        project_id: ProjectId,
1681        connection_id: ConnectionId,
1682        worktrees: &[proto::WorktreeMetadata],
1683    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1684        self.transact(|mut tx| async move {
1685            let room_id: RoomId = sqlx::query_scalar(
1686                "
1687                SELECT room_id
1688                FROM projects
1689                WHERE id = $1 AND host_connection_id = $2
1690                ",
1691            )
1692            .bind(project_id)
1693            .bind(connection_id.0 as i32)
1694            .fetch_one(&mut tx)
1695            .await?;
1696
1697            if !worktrees.is_empty() {
1698                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1699                params.pop();
1700                let query = format!(
1701                    "
1702                    INSERT INTO worktrees (
1703                        project_id,
1704                        id,
1705                        root_name,
1706                        abs_path,
1707                        visible,
1708                        scan_id,
1709                        is_complete
1710                    )
1711                    VALUES {params}
1712                    ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1713                    "
1714                );
1715
1716                let mut query = sqlx::query(&query);
1717                for worktree in worktrees {
1718                    query = query
1719                        .bind(project_id)
1720                        .bind(worktree.id as i32)
1721                        .bind(&worktree.root_name)
1722                        .bind(&worktree.abs_path)
1723                        .bind(worktree.visible)
1724                        .bind(0)
1725                        .bind(false)
1726                }
1727                query.execute(&mut tx).await?;
1728            }
1729
1730            let mut params = "?,".repeat(worktrees.len());
1731            if !worktrees.is_empty() {
1732                params.pop();
1733            }
1734            let query = format!(
1735                "
1736                DELETE FROM worktrees
1737                WHERE project_id = ? AND id NOT IN ({params})
1738                ",
1739            );
1740
1741            let mut query = sqlx::query(&query).bind(project_id);
1742            for worktree in worktrees {
1743                query = query.bind(WorktreeId(worktree.id as i32));
1744            }
1745            query.execute(&mut tx).await?;
1746
1747            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1748            let room = self.get_room(room_id, &mut tx).await?;
1749            self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1750                .await
1751        })
1752        .await
1753    }
1754
1755    pub async fn update_worktree(
1756        &self,
1757        update: &proto::UpdateWorktree,
1758        connection_id: ConnectionId,
1759    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1760        self.transact(|mut tx| async move {
1761            let project_id = ProjectId::from_proto(update.project_id);
1762            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1763
1764            // Ensure the update comes from the host.
1765            let room_id: RoomId = sqlx::query_scalar(
1766                "
1767                SELECT room_id
1768                FROM projects
1769                WHERE id = $1 AND host_connection_id = $2
1770                ",
1771            )
1772            .bind(project_id)
1773            .bind(connection_id.0 as i32)
1774            .fetch_one(&mut tx)
1775            .await?;
1776
1777            // Update metadata.
1778            sqlx::query(
1779                "
1780                UPDATE worktrees
1781                SET
1782                    root_name = $1,
1783                    scan_id = $2,
1784                    is_complete = $3,
1785                    abs_path = $4
1786                WHERE project_id = $5 AND id = $6
1787                RETURNING 1
1788                ",
1789            )
1790            .bind(&update.root_name)
1791            .bind(update.scan_id as i64)
1792            .bind(update.is_last_update)
1793            .bind(&update.abs_path)
1794            .bind(project_id)
1795            .bind(worktree_id)
1796            .fetch_one(&mut tx)
1797            .await?;
1798
1799            if !update.updated_entries.is_empty() {
1800                let mut params =
1801                    "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1802                params.pop();
1803
1804                let query = format!(
1805                    "
1806                    INSERT INTO worktree_entries (
1807                        project_id, 
1808                        worktree_id, 
1809                        id, 
1810                        is_dir, 
1811                        path, 
1812                        inode,
1813                        mtime_seconds, 
1814                        mtime_nanos, 
1815                        is_symlink, 
1816                        is_ignored
1817                    )
1818                    VALUES {params}
1819                    ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1820                        is_dir = excluded.is_dir,
1821                        path = excluded.path,
1822                        inode = excluded.inode,
1823                        mtime_seconds = excluded.mtime_seconds,
1824                        mtime_nanos = excluded.mtime_nanos,
1825                        is_symlink = excluded.is_symlink,
1826                        is_ignored = excluded.is_ignored
1827                    "
1828                );
1829                let mut query = sqlx::query(&query);
1830                for entry in &update.updated_entries {
1831                    let mtime = entry.mtime.clone().unwrap_or_default();
1832                    query = query
1833                        .bind(project_id)
1834                        .bind(worktree_id)
1835                        .bind(entry.id as i64)
1836                        .bind(entry.is_dir)
1837                        .bind(&entry.path)
1838                        .bind(entry.inode as i64)
1839                        .bind(mtime.seconds as i64)
1840                        .bind(mtime.nanos as i32)
1841                        .bind(entry.is_symlink)
1842                        .bind(entry.is_ignored);
1843                }
1844                query.execute(&mut tx).await?;
1845            }
1846
1847            if !update.removed_entries.is_empty() {
1848                let mut params = "?,".repeat(update.removed_entries.len());
1849                params.pop();
1850                let query = format!(
1851                    "
1852                    DELETE FROM worktree_entries
1853                    WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1854                    "
1855                );
1856
1857                let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1858                for entry_id in &update.removed_entries {
1859                    query = query.bind(*entry_id as i64);
1860                }
1861                query.execute(&mut tx).await?;
1862            }
1863
1864            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1865            self.commit_room_transaction(room_id, tx, connection_ids)
1866                .await
1867        })
1868        .await
1869    }
1870
1871    pub async fn update_diagnostic_summary(
1872        &self,
1873        update: &proto::UpdateDiagnosticSummary,
1874        connection_id: ConnectionId,
1875    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1876        self.transact(|mut tx| async {
1877            let project_id = ProjectId::from_proto(update.project_id);
1878            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1879            let summary = update
1880                .summary
1881                .as_ref()
1882                .ok_or_else(|| anyhow!("invalid summary"))?;
1883
1884            // Ensure the update comes from the host.
1885            let room_id: RoomId = sqlx::query_scalar(
1886                "
1887                SELECT room_id
1888                FROM projects
1889                WHERE id = $1 AND host_connection_id = $2
1890                ",
1891            )
1892            .bind(project_id)
1893            .bind(connection_id.0 as i32)
1894            .fetch_one(&mut tx)
1895            .await?;
1896
1897            // Update summary.
1898            sqlx::query(
1899                "
1900                INSERT INTO worktree_diagnostic_summaries (
1901                    project_id,
1902                    worktree_id,
1903                    path,
1904                    language_server_id,
1905                    error_count,
1906                    warning_count
1907                )
1908                VALUES ($1, $2, $3, $4, $5, $6)
1909                ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1910                    language_server_id = excluded.language_server_id,
1911                    error_count = excluded.error_count, 
1912                    warning_count = excluded.warning_count
1913                ",
1914            )
1915            .bind(project_id)
1916            .bind(worktree_id)
1917            .bind(&summary.path)
1918            .bind(summary.language_server_id as i64)
1919            .bind(summary.error_count as i32)
1920            .bind(summary.warning_count as i32)
1921            .execute(&mut tx)
1922            .await?;
1923
1924            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1925            self.commit_room_transaction(room_id, tx, connection_ids)
1926                .await
1927        })
1928        .await
1929    }
1930
1931    pub async fn start_language_server(
1932        &self,
1933        update: &proto::StartLanguageServer,
1934        connection_id: ConnectionId,
1935    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1936        self.transact(|mut tx| async {
1937            let project_id = ProjectId::from_proto(update.project_id);
1938            let server = update
1939                .server
1940                .as_ref()
1941                .ok_or_else(|| anyhow!("invalid language server"))?;
1942
1943            // Ensure the update comes from the host.
1944            let room_id: RoomId = sqlx::query_scalar(
1945                "
1946                SELECT room_id
1947                FROM projects
1948                WHERE id = $1 AND host_connection_id = $2
1949                ",
1950            )
1951            .bind(project_id)
1952            .bind(connection_id.0 as i32)
1953            .fetch_one(&mut tx)
1954            .await?;
1955
1956            // Add the newly-started language server.
1957            sqlx::query(
1958                "
1959                INSERT INTO language_servers (project_id, id, name)
1960                VALUES ($1, $2, $3)
1961                ON CONFLICT (project_id, id) DO UPDATE SET
1962                    name = excluded.name
1963                ",
1964            )
1965            .bind(project_id)
1966            .bind(server.id as i64)
1967            .bind(&server.name)
1968            .execute(&mut tx)
1969            .await?;
1970
1971            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1972            self.commit_room_transaction(room_id, tx, connection_ids)
1973                .await
1974        })
1975        .await
1976    }
1977
1978    pub async fn join_project(
1979        &self,
1980        project_id: ProjectId,
1981        connection_id: ConnectionId,
1982    ) -> Result<RoomGuard<(Project, ReplicaId)>> {
1983        self.transact(|mut tx| async move {
1984            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1985                "
1986                SELECT room_id, user_id
1987                FROM room_participants
1988                WHERE answering_connection_id = $1
1989                ",
1990            )
1991            .bind(connection_id.0 as i32)
1992            .fetch_one(&mut tx)
1993            .await?;
1994
1995            // Ensure project id was shared on this room.
1996            sqlx::query(
1997                "
1998                SELECT 1
1999                FROM projects
2000                WHERE id = $1 AND room_id = $2
2001                ",
2002            )
2003            .bind(project_id)
2004            .bind(room_id)
2005            .fetch_one(&mut tx)
2006            .await?;
2007
2008            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2009                "
2010                SELECT *
2011                FROM project_collaborators
2012                WHERE project_id = $1
2013                ",
2014            )
2015            .bind(project_id)
2016            .fetch_all(&mut tx)
2017            .await?;
2018            let replica_ids = collaborators
2019                .iter()
2020                .map(|c| c.replica_id)
2021                .collect::<HashSet<_>>();
2022            let mut replica_id = ReplicaId(1);
2023            while replica_ids.contains(&replica_id) {
2024                replica_id.0 += 1;
2025            }
2026            let new_collaborator = ProjectCollaborator {
2027                project_id,
2028                connection_id: connection_id.0 as i32,
2029                user_id,
2030                replica_id,
2031                is_host: false,
2032            };
2033
2034            sqlx::query(
2035                "
2036                INSERT INTO project_collaborators (
2037                    project_id,
2038                    connection_id,
2039                    user_id,
2040                    replica_id,
2041                    is_host
2042                )
2043                VALUES ($1, $2, $3, $4, $5)
2044                ",
2045            )
2046            .bind(new_collaborator.project_id)
2047            .bind(new_collaborator.connection_id)
2048            .bind(new_collaborator.user_id)
2049            .bind(new_collaborator.replica_id)
2050            .bind(new_collaborator.is_host)
2051            .execute(&mut tx)
2052            .await?;
2053            collaborators.push(new_collaborator);
2054
2055            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
2056                "
2057                SELECT *
2058                FROM worktrees
2059                WHERE project_id = $1
2060                ",
2061            )
2062            .bind(project_id)
2063            .fetch_all(&mut tx)
2064            .await?;
2065            let mut worktrees = worktree_rows
2066                .into_iter()
2067                .map(|worktree_row| {
2068                    (
2069                        worktree_row.id,
2070                        Worktree {
2071                            id: worktree_row.id,
2072                            abs_path: worktree_row.abs_path,
2073                            root_name: worktree_row.root_name,
2074                            visible: worktree_row.visible,
2075                            entries: Default::default(),
2076                            diagnostic_summaries: Default::default(),
2077                            scan_id: worktree_row.scan_id as u64,
2078                            is_complete: worktree_row.is_complete,
2079                        },
2080                    )
2081                })
2082                .collect::<BTreeMap<_, _>>();
2083
2084            // Populate worktree entries.
2085            {
2086                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
2087                    "
2088                    SELECT *
2089                    FROM worktree_entries
2090                    WHERE project_id = $1
2091                    ",
2092                )
2093                .bind(project_id)
2094                .fetch(&mut tx);
2095                while let Some(entry) = entries.next().await {
2096                    let entry = entry?;
2097                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
2098                        worktree.entries.push(proto::Entry {
2099                            id: entry.id as u64,
2100                            is_dir: entry.is_dir,
2101                            path: entry.path,
2102                            inode: entry.inode as u64,
2103                            mtime: Some(proto::Timestamp {
2104                                seconds: entry.mtime_seconds as u64,
2105                                nanos: entry.mtime_nanos as u32,
2106                            }),
2107                            is_symlink: entry.is_symlink,
2108                            is_ignored: entry.is_ignored,
2109                        });
2110                    }
2111                }
2112            }
2113
2114            // Populate worktree diagnostic summaries.
2115            {
2116                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
2117                    "
2118                    SELECT *
2119                    FROM worktree_diagnostic_summaries
2120                    WHERE project_id = $1
2121                    ",
2122                )
2123                .bind(project_id)
2124                .fetch(&mut tx);
2125                while let Some(summary) = summaries.next().await {
2126                    let summary = summary?;
2127                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
2128                        worktree
2129                            .diagnostic_summaries
2130                            .push(proto::DiagnosticSummary {
2131                                path: summary.path,
2132                                language_server_id: summary.language_server_id as u64,
2133                                error_count: summary.error_count as u32,
2134                                warning_count: summary.warning_count as u32,
2135                            });
2136                    }
2137                }
2138            }
2139
2140            // Populate language servers.
2141            let language_servers = sqlx::query_as::<_, LanguageServer>(
2142                "
2143                SELECT *
2144                FROM language_servers
2145                WHERE project_id = $1
2146                ",
2147            )
2148            .bind(project_id)
2149            .fetch_all(&mut tx)
2150            .await?;
2151
2152            self.commit_room_transaction(
2153                room_id,
2154                tx,
2155                (
2156                    Project {
2157                        collaborators,
2158                        worktrees,
2159                        language_servers: language_servers
2160                            .into_iter()
2161                            .map(|language_server| proto::LanguageServer {
2162                                id: language_server.id.to_proto(),
2163                                name: language_server.name,
2164                            })
2165                            .collect(),
2166                    },
2167                    replica_id as ReplicaId,
2168                ),
2169            )
2170            .await
2171        })
2172        .await
2173    }
2174
2175    pub async fn leave_project(
2176        &self,
2177        project_id: ProjectId,
2178        connection_id: ConnectionId,
2179    ) -> Result<RoomGuard<LeftProject>> {
2180        self.transact(|mut tx| async move {
2181            let result = sqlx::query(
2182                "
2183                DELETE FROM project_collaborators
2184                WHERE project_id = $1 AND connection_id = $2
2185                ",
2186            )
2187            .bind(project_id)
2188            .bind(connection_id.0 as i32)
2189            .execute(&mut tx)
2190            .await?;
2191
2192            if result.rows_affected() == 0 {
2193                Err(anyhow!("not a collaborator on this project"))?;
2194            }
2195
2196            let connection_ids = sqlx::query_scalar::<_, i32>(
2197                "
2198                SELECT connection_id
2199                FROM project_collaborators
2200                WHERE project_id = $1
2201                ",
2202            )
2203            .bind(project_id)
2204            .fetch_all(&mut tx)
2205            .await?
2206            .into_iter()
2207            .map(|id| ConnectionId(id as u32))
2208            .collect();
2209
2210            let (room_id, host_user_id, host_connection_id) =
2211                sqlx::query_as::<_, (RoomId, i32, i32)>(
2212                    "
2213                SELECT room_id, host_user_id, host_connection_id
2214                FROM projects
2215                WHERE id = $1
2216                ",
2217                )
2218                .bind(project_id)
2219                .fetch_one(&mut tx)
2220                .await?;
2221
2222            self.commit_room_transaction(
2223                room_id,
2224                tx,
2225                LeftProject {
2226                    id: project_id,
2227                    host_user_id: UserId(host_user_id),
2228                    host_connection_id: ConnectionId(host_connection_id as u32),
2229                    connection_ids,
2230                },
2231            )
2232            .await
2233        })
2234        .await
2235    }
2236
2237    pub async fn project_collaborators(
2238        &self,
2239        project_id: ProjectId,
2240        connection_id: ConnectionId,
2241    ) -> Result<Vec<ProjectCollaborator>> {
2242        self.transact(|mut tx| async move {
2243            let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2244                "
2245                SELECT *
2246                FROM project_collaborators
2247                WHERE project_id = $1
2248                ",
2249            )
2250            .bind(project_id)
2251            .fetch_all(&mut tx)
2252            .await?;
2253
2254            if collaborators
2255                .iter()
2256                .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2257            {
2258                Ok(collaborators)
2259            } else {
2260                Err(anyhow!("no such project"))?
2261            }
2262        })
2263        .await
2264    }
2265
2266    pub async fn project_connection_ids(
2267        &self,
2268        project_id: ProjectId,
2269        connection_id: ConnectionId,
2270    ) -> Result<HashSet<ConnectionId>> {
2271        self.transact(|mut tx| async move {
2272            let connection_ids = sqlx::query_scalar::<_, i32>(
2273                "
2274                SELECT connection_id
2275                FROM project_collaborators
2276                WHERE project_id = $1
2277                ",
2278            )
2279            .bind(project_id)
2280            .fetch_all(&mut tx)
2281            .await?;
2282
2283            if connection_ids.contains(&(connection_id.0 as i32)) {
2284                Ok(connection_ids
2285                    .into_iter()
2286                    .map(|connection_id| ConnectionId(connection_id as u32))
2287                    .collect())
2288            } else {
2289                Err(anyhow!("no such project"))?
2290            }
2291        })
2292        .await
2293    }
2294
2295    // contacts
2296
2297    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2298        self.transact(|mut tx| async move {
2299            let query = "
2300                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2301                FROM contacts
2302                LEFT JOIN room_participants ON room_participants.user_id = $1
2303                WHERE user_id_a = $1 OR user_id_b = $1;
2304            ";
2305
2306            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2307                .bind(user_id)
2308                .fetch(&mut tx);
2309
2310            let mut contacts = Vec::new();
2311            while let Some(row) = rows.next().await {
2312                let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2313                if user_id_a == user_id {
2314                    if accepted {
2315                        contacts.push(Contact::Accepted {
2316                            user_id: user_id_b,
2317                            should_notify: should_notify && a_to_b,
2318                            busy
2319                        });
2320                    } else if a_to_b {
2321                        contacts.push(Contact::Outgoing { user_id: user_id_b })
2322                    } else {
2323                        contacts.push(Contact::Incoming {
2324                            user_id: user_id_b,
2325                            should_notify,
2326                        });
2327                    }
2328                } else if accepted {
2329                    contacts.push(Contact::Accepted {
2330                        user_id: user_id_a,
2331                        should_notify: should_notify && !a_to_b,
2332                        busy
2333                    });
2334                } else if a_to_b {
2335                    contacts.push(Contact::Incoming {
2336                        user_id: user_id_a,
2337                        should_notify,
2338                    });
2339                } else {
2340                    contacts.push(Contact::Outgoing { user_id: user_id_a });
2341                }
2342            }
2343
2344            contacts.sort_unstable_by_key(|contact| contact.user_id());
2345
2346            Ok(contacts)
2347        })
2348        .await
2349    }
2350
2351    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2352        self.transact(|mut tx| async move {
2353            Ok(sqlx::query_scalar::<_, i32>(
2354                "
2355                SELECT 1
2356                FROM room_participants
2357                WHERE room_participants.user_id = $1
2358                ",
2359            )
2360            .bind(user_id)
2361            .fetch_optional(&mut tx)
2362            .await?
2363            .is_some())
2364        })
2365        .await
2366    }
2367
2368    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2369        self.transact(|mut tx| async move {
2370            let (id_a, id_b) = if user_id_1 < user_id_2 {
2371                (user_id_1, user_id_2)
2372            } else {
2373                (user_id_2, user_id_1)
2374            };
2375
2376            let query = "
2377                SELECT 1 FROM contacts
2378                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2379                LIMIT 1
2380            ";
2381            Ok(sqlx::query_scalar::<_, i32>(query)
2382                .bind(id_a.0)
2383                .bind(id_b.0)
2384                .fetch_optional(&mut tx)
2385                .await?
2386                .is_some())
2387        })
2388        .await
2389    }
2390
2391    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2392        self.transact(|mut tx| async move {
2393            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2394                (sender_id, receiver_id, true)
2395            } else {
2396                (receiver_id, sender_id, false)
2397            };
2398            let query = "
2399                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2400                VALUES ($1, $2, $3, FALSE, TRUE)
2401                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2402                SET
2403                    accepted = TRUE,
2404                    should_notify = FALSE
2405                WHERE
2406                    NOT contacts.accepted AND
2407                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2408                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2409            ";
2410            let result = sqlx::query(query)
2411                .bind(id_a.0)
2412                .bind(id_b.0)
2413                .bind(a_to_b)
2414                .execute(&mut tx)
2415                .await?;
2416
2417            if result.rows_affected() == 1 {
2418                tx.commit().await?;
2419                Ok(())
2420            } else {
2421                Err(anyhow!("contact already requested"))?
2422            }
2423        }).await
2424    }
2425
2426    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2427        self.transact(|mut tx| async move {
2428            let (id_a, id_b) = if responder_id < requester_id {
2429                (responder_id, requester_id)
2430            } else {
2431                (requester_id, responder_id)
2432            };
2433            let query = "
2434                DELETE FROM contacts
2435                WHERE user_id_a = $1 AND user_id_b = $2;
2436            ";
2437            let result = sqlx::query(query)
2438                .bind(id_a.0)
2439                .bind(id_b.0)
2440                .execute(&mut tx)
2441                .await?;
2442
2443            if result.rows_affected() == 1 {
2444                tx.commit().await?;
2445                Ok(())
2446            } else {
2447                Err(anyhow!("no such contact"))?
2448            }
2449        })
2450        .await
2451    }
2452
2453    pub async fn dismiss_contact_notification(
2454        &self,
2455        user_id: UserId,
2456        contact_user_id: UserId,
2457    ) -> Result<()> {
2458        self.transact(|mut tx| async move {
2459            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2460                (user_id, contact_user_id, true)
2461            } else {
2462                (contact_user_id, user_id, false)
2463            };
2464
2465            let query = "
2466                UPDATE contacts
2467                SET should_notify = FALSE
2468                WHERE
2469                    user_id_a = $1 AND user_id_b = $2 AND
2470                    (
2471                        (a_to_b = $3 AND accepted) OR
2472                        (a_to_b != $3 AND NOT accepted)
2473                    );
2474            ";
2475
2476            let result = sqlx::query(query)
2477                .bind(id_a.0)
2478                .bind(id_b.0)
2479                .bind(a_to_b)
2480                .execute(&mut tx)
2481                .await?;
2482
2483            if result.rows_affected() == 0 {
2484                Err(anyhow!("no such contact request"))?
2485            } else {
2486                tx.commit().await?;
2487                Ok(())
2488            }
2489        })
2490        .await
2491    }
2492
2493    pub async fn respond_to_contact_request(
2494        &self,
2495        responder_id: UserId,
2496        requester_id: UserId,
2497        accept: bool,
2498    ) -> Result<()> {
2499        self.transact(|mut tx| async move {
2500            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2501                (responder_id, requester_id, false)
2502            } else {
2503                (requester_id, responder_id, true)
2504            };
2505            let result = if accept {
2506                let query = "
2507                    UPDATE contacts
2508                    SET accepted = TRUE, should_notify = TRUE
2509                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2510                ";
2511                sqlx::query(query)
2512                    .bind(id_a.0)
2513                    .bind(id_b.0)
2514                    .bind(a_to_b)
2515                    .execute(&mut tx)
2516                    .await?
2517            } else {
2518                let query = "
2519                    DELETE FROM contacts
2520                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2521                ";
2522                sqlx::query(query)
2523                    .bind(id_a.0)
2524                    .bind(id_b.0)
2525                    .bind(a_to_b)
2526                    .execute(&mut tx)
2527                    .await?
2528            };
2529            if result.rows_affected() == 1 {
2530                tx.commit().await?;
2531                Ok(())
2532            } else {
2533                Err(anyhow!("no such contact request"))?
2534            }
2535        })
2536        .await
2537    }
2538
2539    // access tokens
2540
2541    pub async fn create_access_token_hash(
2542        &self,
2543        user_id: UserId,
2544        access_token_hash: &str,
2545        max_access_token_count: usize,
2546    ) -> Result<()> {
2547        self.transact(|tx| async {
2548            let mut tx = tx;
2549            let insert_query = "
2550                INSERT INTO access_tokens (user_id, hash)
2551                VALUES ($1, $2);
2552            ";
2553            let cleanup_query = "
2554                DELETE FROM access_tokens
2555                WHERE id IN (
2556                    SELECT id from access_tokens
2557                    WHERE user_id = $1
2558                    ORDER BY id DESC
2559                    LIMIT 10000
2560                    OFFSET $3
2561                )
2562            ";
2563
2564            sqlx::query(insert_query)
2565                .bind(user_id.0)
2566                .bind(access_token_hash)
2567                .execute(&mut tx)
2568                .await?;
2569            sqlx::query(cleanup_query)
2570                .bind(user_id.0)
2571                .bind(access_token_hash)
2572                .bind(max_access_token_count as i32)
2573                .execute(&mut tx)
2574                .await?;
2575            Ok(tx.commit().await?)
2576        })
2577        .await
2578    }
2579
2580    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2581        self.transact(|mut tx| async move {
2582            let query = "
2583                SELECT hash
2584                FROM access_tokens
2585                WHERE user_id = $1
2586                ORDER BY id DESC
2587            ";
2588            Ok(sqlx::query_scalar(query)
2589                .bind(user_id.0)
2590                .fetch_all(&mut tx)
2591                .await?)
2592        })
2593        .await
2594    }
2595
2596    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2597    where
2598        F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2599        Fut: Send + Future<Output = Result<T>>,
2600    {
2601        let body = async {
2602            loop {
2603                let tx = self.begin_transaction().await?;
2604                match f(tx).await {
2605                    Ok(result) => return Ok(result),
2606                    Err(error) => match error {
2607                        Error::Database(error)
2608                            if error
2609                                .as_database_error()
2610                                .and_then(|error| error.code())
2611                                .as_deref()
2612                                == Some("40001") =>
2613                        {
2614                            // Retry (don't break the loop)
2615                        }
2616                        error @ _ => return Err(error),
2617                    },
2618                }
2619            }
2620        };
2621
2622        #[cfg(test)]
2623        {
2624            if let Some(background) = self.background.as_ref() {
2625                background.simulate_random_delay().await;
2626            }
2627
2628            self.runtime.as_ref().unwrap().block_on(body)
2629        }
2630
2631        #[cfg(not(test))]
2632        {
2633            body.await
2634        }
2635    }
2636}
2637
/// Declares `pub struct $ty(pub i32)`, an integer id newtype.
///
/// The type is transparent for both `sqlx` and `serde`. It gets `MAX`,
/// `from_proto`/`to_proto` conversions, a `Display` impl that prints the inner
/// integer, and a conversion into `sea_query::Value::Int`.
macro_rules! id_type {
    ($ty:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $ty(pub i32);

        // `#[allow(unused)]`: not every generated id type uses every helper.
        impl $ty {
            /// The id wrapping `i32::MAX`.
            #[allow(unused)]
            pub const MAX: Self = $ty(i32::MAX);

            /// Builds an id from its protobuf `u64` form.
            ///
            /// NOTE(review): the `as` cast truncates values above `i32::MAX`
            /// instead of rejecting them.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $ty(value as i32)
            }

            /// Converts the id to its protobuf `u64` form. A negative inner
            /// value is sign-extended by the `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $ty {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        impl From<$ty> for sea_query::Value {
            fn from(id: $ty) -> Self {
                Self::Int(Some(id.0))
            }
        }
    };
}
2686
id_type!(UserId);
/// A user record. Derives `FromRow`, so it can be decoded directly from a
/// query row whose column names match the field names.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Optional; may be absent for a user.
    pub github_user_id: Option<i32>,
    /// Optional; may be absent for a user.
    pub email_address: Option<String>,
    pub admin: bool,
    /// Optional; may be absent for a user.
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2699
id_type!(RoomId);
/// A room record, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// NOTE(review): presumably the identifier of the associated LiveKit room — confirm.
    pub live_kit_room: String,
}
2706
id_type!(ProjectId);
/// A project assembled from several sources: its collaborators, its worktrees
/// and its language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed by their id; `BTreeMap` keeps iteration ordered by id.
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    /// Language servers in their protobuf form.
    pub language_servers: Vec<proto::LanguageServer>,
}
2713
id_type!(ReplicaId);
/// One participant in a project, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Stored as a raw `i32`, not as `rpc::ConnectionId`.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
2723
id_type!(WorktreeId);
/// Flat `FromRow` form of a worktree: the scalar columns of [`Worktree`]
/// plus `project_id`, without entries or diagnostic summaries.
///
/// Note that `scan_id` is `i64` here but `u64` in [`Worktree`].
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub project_id: ProjectId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub scan_id: i64,
    pub is_complete: bool,
}
2735
/// A worktree together with its entries and diagnostic summaries, both held
/// in protobuf form.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned here, unlike the `i64` column in `WorktreeRow`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2746
/// One filesystem entry row belonging to a worktree, decodable via `FromRow`.
///
/// The modification time is stored as two columns, `mtime_seconds` and
/// `mtime_nanos`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2759
/// Row holding `error_count`/`warning_count` for one `path` within a
/// worktree, tagged with a `language_server_id`; decodable via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2768
id_type!(LanguageServerId);
/// A language-server row (id and name), decodable via `FromRow`. Distinct
/// from `proto::LanguageServer`, which [`Project`] stores.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2775
/// Describes a project that was left: its id, its host, and a set of
/// connection ids.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    /// NOTE(review): presumably the connections still in the project that
    /// should be notified — confirm against the code that builds this value.
    pub connection_ids: Vec<ConnectionId>,
}
2782
/// Result of leaving a room.
///
/// It holds the room in protobuf form, the projects that were left (keyed by
/// project id), and the ids of users whose calls were canceled.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2788
/// A contact relationship with another user, in one of three states. Every
/// variant carries that user's `user_id`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    /// A not-yet-accepted request in the outgoing direction (presumably sent
    /// by the user whose contacts are being listed — confirm).
    Outgoing {
        user_id: UserId,
    },
    /// A not-yet-accepted request in the incoming direction (presumably
    /// received by the user whose contacts are being listed — confirm).
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2804
2805impl Contact {
2806    pub fn user_id(&self) -> UserId {
2807        match self {
2808            Contact::Accepted { user_id, .. } => *user_id,
2809            Contact::Outgoing { user_id } => *user_id,
2810            Contact::Incoming { user_id, .. } => *user_id,
2811        }
2812    }
2813}
2814
/// A pending contact request addressed to some user, identified by who sent it.
///
/// The derived `Ord` compares fields in declaration order: `requester_id`
/// first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2820
/// A sign-up submission. It is only deserialized (derives `Deserialize`, not
/// `Serialize`).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // One boolean flag per platform.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional; may be absent from the submission.
    pub device_id: Option<String>,
}
2831
/// Aggregate waitlist counts: a total plus per-platform and unknown buckets.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the query row
/// decodes as `0` instead of failing.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2845
/// An invite: the target email address and its confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// NOTE(review): presumably generated by `random_email_confirmation_code` — confirm.
    pub email_confirmation_code: String,
}
2851
/// Parameters supplied when creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
2858
/// Outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// Optional; `None` when no inviting user is recorded.
    pub inviting_user_id: Option<UserId>,
    /// Optional; `None` when no sign-up device id is recorded.
    pub signup_device_id: Option<String>,
}
2866
2867fn random_invite_code() -> String {
2868    nanoid::nanoid!(16)
2869}
2870
2871fn random_email_confirmation_code() -> String {
2872    nanoid::nanoid!(64)
2873}
2874
2875#[cfg(test)]
2876pub use test::*;
2877
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Test harness around a migrated, in-memory SQLite [`Db`] with a random name.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// Connection acquired from the pool after migrating and detached from
        /// it. NOTE(review): presumably held so the in-memory database stays
        /// alive while pooled connections come and go — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// Test harness around a freshly created, migrated Postgres database on
    /// `localhost`; the database is torn down when the harness is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    /// Returns a random 128-bit suffix for naming a test database.
    fn random_db_suffix() -> u128 {
        StdRng::from_entropy().gen::<u128>()
    }

    /// Builds the current-thread Tokio runtime (I/O and timers enabled) that a
    /// test `Db` runs its work on.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates and migrates the database, then installs `background` and
        /// the runtime on the `Db`. Panics on any setup failure.
        pub fn new(background: Arc<Background>) -> Self {
            let url = format!("file:zed-test-{}?mode=memory", random_db_suffix());
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // `transact` reads both of these in test builds.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the wrapped database; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a new Postgres database, then installs
        /// `background` and the runtime on the `Db`. Panics on any setup
        /// failure.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Process-wide lock held until this constructor returns, so only one
            // Postgres test database is being set up at a time in this process.
            let _guard = LOCK.lock();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                random_db_suffix()
            );
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations), false).await.unwrap();
                db
            });

            // `transact` reads both of these in test builds.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the wrapped database; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Tears down the database at `url`; panics if `db` was already taken.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}