db.rs

   1use crate::{Error, Result};
   2use anyhow::anyhow;
   3use axum::http::StatusCode;
   4use collections::{BTreeMap, HashMap, HashSet};
   5use dashmap::DashMap;
   6use futures::{future::BoxFuture, FutureExt, StreamExt};
   7use rpc::{proto, ConnectionId};
   8use serde::{Deserialize, Serialize};
   9use sqlx::{
  10    migrate::{Migrate as _, Migration, MigrationSource},
  11    types::Uuid,
  12    FromRow,
  13};
  14use std::{
  15    future::Future,
  16    marker::PhantomData,
  17    ops::{Deref, DerefMut},
  18    path::Path,
  19    rc::Rc,
  20    sync::Arc,
  21    time::Duration,
  22};
  23use time::{OffsetDateTime, PrimitiveDateTime};
  24use tokio::sync::{Mutex, OwnedMutexGuard};
  25
  26#[cfg(test)]
  27pub type DefaultDb = Db<sqlx::Sqlite>;
  28
  29#[cfg(not(test))]
  30pub type DefaultDb = Db<sqlx::Postgres>;
  31
  32pub struct Db<D: sqlx::Database> {
  33    pool: sqlx::Pool<D>,
  34    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
  35    #[cfg(test)]
  36    background: Option<std::sync::Arc<gpui::executor::Background>>,
  37    #[cfg(test)]
  38    runtime: Option<tokio::runtime::Runtime>,
  39}
  40
/// A value handed back together with a held per-room lock.
///
/// The room lock stays held for as long as the guard is alive and is released when
/// it is dropped. Struct fields drop in declaration order, so `data` is dropped
/// before `_guard` releases the lock — keep the field order as is.
pub struct RoomGuard<T> {
    // The payload callers access through `Deref`/`DerefMut`.
    data: T,
    // Owned guard on the room's mutex; holding it keeps the room locked.
    _guard: OwnedMutexGuard<()>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes `RoomGuard` `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
  46
  47impl<T> Deref for RoomGuard<T> {
  48    type Target = T;
  49
  50    fn deref(&self) -> &T {
  51        &self.data
  52    }
  53}
  54
  55impl<T> DerefMut for RoomGuard<T> {
  56    fn deref_mut(&mut self) -> &mut T {
  57        &mut self.data
  58    }
  59}
  60
  61pub trait BeginTransaction: Send + Sync {
  62    type Database: sqlx::Database;
  63
  64    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
  65}
  66
  67// In Postgres, serializable transactions are opt-in
  68impl BeginTransaction for Db<sqlx::Postgres> {
  69    type Database = sqlx::Postgres;
  70
  71    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
  72        async move {
  73            let mut tx = self.pool.begin().await?;
  74            sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
  75                .await?;
  76            Ok(tx)
  77        }
  78        .boxed()
  79    }
  80}
  81
  82// In Sqlite, transactions are inherently serializable.
  83#[cfg(test)]
  84impl BeginTransaction for Db<sqlx::Sqlite> {
  85    type Database = sqlx::Sqlite;
  86
  87    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
  88        async move { Ok(self.pool.begin().await?) }.boxed()
  89    }
  90}
  91
/// Backend-independent access to the number of rows a statement affected.
///
/// Implemented for each sqlx backend's query result type so generic code can rely on
/// the bound `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
  95
  96#[cfg(test)]
  97impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
  98    fn rows_affected(&self) -> u64 {
  99        self.rows_affected()
 100    }
 101}
 102
 103impl RowsAffected for sqlx::postgres::PgQueryResult {
 104    fn rows_affected(&self) -> u64 {
 105        self.rows_affected()
 106    }
 107}
 108
 109#[cfg(test)]
 110impl Db<sqlx::Sqlite> {
 111    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 112        use std::str::FromStr as _;
 113        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
 114            .unwrap()
 115            .create_if_missing(true)
 116            .shared_cache(true);
 117        let pool = sqlx::sqlite::SqlitePoolOptions::new()
 118            .min_connections(2)
 119            .max_connections(max_connections)
 120            .connect_with(options)
 121            .await?;
 122        Ok(Self {
 123            pool,
 124            rooms: Default::default(),
 125            background: None,
 126            runtime: None,
 127        })
 128    }
 129
 130    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 131        self.transact(|tx| async {
 132            let mut tx = tx;
 133            let query = "
 134                SELECT users.*
 135                FROM users
 136                WHERE users.id IN (SELECT value from json_each($1))
 137            ";
 138            Ok(sqlx::query_as(query)
 139                .bind(&serde_json::json!(ids))
 140                .fetch_all(&mut tx)
 141                .await?)
 142        })
 143        .await
 144    }
 145
 146    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 147        self.transact(|mut tx| async move {
 148            let query = "
 149                SELECT metrics_id
 150                FROM users
 151                WHERE id = $1
 152            ";
 153            Ok(sqlx::query_scalar(query)
 154                .bind(id)
 155                .fetch_one(&mut tx)
 156                .await?)
 157        })
 158        .await
 159    }
 160
 161    pub async fn create_user(
 162        &self,
 163        email_address: &str,
 164        admin: bool,
 165        params: NewUserParams,
 166    ) -> Result<NewUserResult> {
 167        self.transact(|mut tx| async {
 168            let query = "
 169                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
 170                VALUES ($1, $2, $3, $4, $5)
 171                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 172                RETURNING id, metrics_id
 173            ";
 174
 175            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 176                .bind(email_address)
 177                .bind(&params.github_login)
 178                .bind(&params.github_user_id)
 179                .bind(admin)
 180                .bind(Uuid::new_v4().to_string())
 181                .fetch_one(&mut tx)
 182                .await?;
 183            tx.commit().await?;
 184            Ok(NewUserResult {
 185                user_id,
 186                metrics_id,
 187                signup_device_id: None,
 188                inviting_user_id: None,
 189            })
 190        })
 191        .await
 192    }
 193
 194    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
 195        unimplemented!()
 196    }
 197
 198    pub async fn create_user_from_invite(
 199        &self,
 200        _invite: &Invite,
 201        _user: NewUserParams,
 202    ) -> Result<Option<NewUserResult>> {
 203        unimplemented!()
 204    }
 205
 206    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
 207        unimplemented!()
 208    }
 209
 210    pub async fn create_invite_from_code(
 211        &self,
 212        _code: &str,
 213        _email_address: &str,
 214        _device_id: Option<&str>,
 215    ) -> Result<Invite> {
 216        unimplemented!()
 217    }
 218
 219    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
 220        unimplemented!()
 221    }
 222}
 223
 224impl Db<sqlx::Postgres> {
 225    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 226        let pool = sqlx::postgres::PgPoolOptions::new()
 227            .max_connections(max_connections)
 228            .connect(url)
 229            .await?;
 230        Ok(Self {
 231            pool,
 232            rooms: DashMap::with_capacity(16384),
 233            #[cfg(test)]
 234            background: None,
 235            #[cfg(test)]
 236            runtime: None,
 237        })
 238    }
 239
 240    #[cfg(test)]
 241    pub fn teardown(&self, url: &str) {
 242        self.runtime.as_ref().unwrap().block_on(async {
 243            use util::ResultExt;
 244            let query = "
 245                SELECT pg_terminate_backend(pg_stat_activity.pid)
 246                FROM pg_stat_activity
 247                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
 248            ";
 249            sqlx::query(query).execute(&self.pool).await.log_err();
 250            self.pool.close().await;
 251            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
 252                .await
 253                .log_err();
 254        })
 255    }
 256
 257    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 258        self.transact(|tx| async {
 259            let mut tx = tx;
 260            let like_string = Self::fuzzy_like_string(name_query);
 261            let query = "
 262                SELECT users.*
 263                FROM users
 264                WHERE github_login ILIKE $1
 265                ORDER BY github_login <-> $2
 266                LIMIT $3
 267            ";
 268            Ok(sqlx::query_as(query)
 269                .bind(like_string)
 270                .bind(name_query)
 271                .bind(limit as i32)
 272                .fetch_all(&mut tx)
 273                .await?)
 274        })
 275        .await
 276    }
 277
 278    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 279        let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
 280        self.transact(|tx| async {
 281            let mut tx = tx;
 282            let query = "
 283                SELECT users.*
 284                FROM users
 285                WHERE users.id = ANY ($1)
 286            ";
 287            Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
 288        })
 289        .await
 290    }
 291
 292    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 293        self.transact(|mut tx| async move {
 294            let query = "
 295                SELECT metrics_id::text
 296                FROM users
 297                WHERE id = $1
 298            ";
 299            Ok(sqlx::query_scalar(query)
 300                .bind(id)
 301                .fetch_one(&mut tx)
 302                .await?)
 303        })
 304        .await
 305    }
 306
 307    pub async fn create_user(
 308        &self,
 309        email_address: &str,
 310        admin: bool,
 311        params: NewUserParams,
 312    ) -> Result<NewUserResult> {
 313        self.transact(|mut tx| async {
 314            let query = "
 315                INSERT INTO users (email_address, github_login, github_user_id, admin)
 316                VALUES ($1, $2, $3, $4)
 317                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 318                RETURNING id, metrics_id::text
 319            ";
 320
 321            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 322                .bind(email_address)
 323                .bind(&params.github_login)
 324                .bind(params.github_user_id)
 325                .bind(admin)
 326                .fetch_one(&mut tx)
 327                .await?;
 328            tx.commit().await?;
 329
 330            Ok(NewUserResult {
 331                user_id,
 332                metrics_id,
 333                signup_device_id: None,
 334                inviting_user_id: None,
 335            })
 336        })
 337        .await
 338    }
 339
 340    pub async fn create_user_from_invite(
 341        &self,
 342        invite: &Invite,
 343        user: NewUserParams,
 344    ) -> Result<Option<NewUserResult>> {
 345        self.transact(|mut tx| async {
 346            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 347                i32,
 348                Option<UserId>,
 349                Option<UserId>,
 350                Option<String>,
 351            ) = sqlx::query_as(
 352                "
 353                SELECT id, user_id, inviting_user_id, device_id
 354                FROM signups
 355                WHERE
 356                    email_address = $1 AND
 357                    email_confirmation_code = $2
 358                ",
 359            )
 360            .bind(&invite.email_address)
 361            .bind(&invite.email_confirmation_code)
 362            .fetch_optional(&mut tx)
 363            .await?
 364            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 365
 366            if existing_user_id.is_some() {
 367                return Ok(None);
 368            }
 369
 370            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 371                "
 372                INSERT INTO users
 373                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 374                VALUES
 375                ($1, $2, $3, FALSE, $4, $5)
 376                ON CONFLICT (github_login) DO UPDATE SET
 377                    email_address = excluded.email_address,
 378                    github_user_id = excluded.github_user_id,
 379                    admin = excluded.admin
 380                RETURNING id, metrics_id::text
 381                ",
 382            )
 383            .bind(&invite.email_address)
 384            .bind(&user.github_login)
 385            .bind(&user.github_user_id)
 386            .bind(&user.invite_count)
 387            .bind(random_invite_code())
 388            .fetch_one(&mut tx)
 389            .await?;
 390
 391            sqlx::query(
 392                "
 393                UPDATE signups
 394                SET user_id = $1
 395                WHERE id = $2
 396                ",
 397            )
 398            .bind(&user_id)
 399            .bind(&signup_id)
 400            .execute(&mut tx)
 401            .await?;
 402
 403            if let Some(inviting_user_id) = inviting_user_id {
 404                let id: Option<UserId> = sqlx::query_scalar(
 405                    "
 406                    UPDATE users
 407                    SET invite_count = invite_count - 1
 408                    WHERE id = $1 AND invite_count > 0
 409                    RETURNING id
 410                    ",
 411                )
 412                .bind(&inviting_user_id)
 413                .fetch_optional(&mut tx)
 414                .await?;
 415
 416                if id.is_none() {
 417                    Err(Error::Http(
 418                        StatusCode::UNAUTHORIZED,
 419                        "no invites remaining".to_string(),
 420                    ))?;
 421                }
 422
 423                sqlx::query(
 424                    "
 425                    INSERT INTO contacts
 426                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 427                    VALUES
 428                        ($1, $2, TRUE, TRUE, TRUE)
 429                    ON CONFLICT DO NOTHING
 430                    ",
 431                )
 432                .bind(inviting_user_id)
 433                .bind(user_id)
 434                .execute(&mut tx)
 435                .await?;
 436            }
 437
 438            tx.commit().await?;
 439            Ok(Some(NewUserResult {
 440                user_id,
 441                metrics_id,
 442                inviting_user_id,
 443                signup_device_id,
 444            }))
 445        })
 446        .await
 447    }
 448
 449    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
 450        self.transact(|mut tx| async {
 451            sqlx::query(
 452                "
 453                INSERT INTO signups
 454                (
 455                    email_address,
 456                    email_confirmation_code,
 457                    email_confirmation_sent,
 458                    platform_linux,
 459                    platform_mac,
 460                    platform_windows,
 461                    platform_unknown,
 462                    editor_features,
 463                    programming_languages,
 464                    device_id
 465                )
 466                VALUES
 467                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
 468                RETURNING id
 469                ",
 470            )
 471            .bind(&signup.email_address)
 472            .bind(&random_email_confirmation_code())
 473            .bind(&signup.platform_linux)
 474            .bind(&signup.platform_mac)
 475            .bind(&signup.platform_windows)
 476            .bind(&signup.editor_features)
 477            .bind(&signup.programming_languages)
 478            .bind(&signup.device_id)
 479            .execute(&mut tx)
 480            .await?;
 481            tx.commit().await?;
 482            Ok(())
 483        })
 484        .await
 485    }
 486
 487    pub async fn create_invite_from_code(
 488        &self,
 489        code: &str,
 490        email_address: &str,
 491        device_id: Option<&str>,
 492    ) -> Result<Invite> {
 493        self.transact(|mut tx| async {
 494            let existing_user: Option<UserId> = sqlx::query_scalar(
 495                "
 496                SELECT id
 497                FROM users
 498                WHERE email_address = $1
 499                ",
 500            )
 501            .bind(email_address)
 502            .fetch_optional(&mut tx)
 503            .await?;
 504            if existing_user.is_some() {
 505                Err(anyhow!("email address is already in use"))?;
 506            }
 507
 508            let row: Option<(UserId, i32)> = sqlx::query_as(
 509                "
 510                SELECT id, invite_count
 511                FROM users
 512                WHERE invite_code = $1
 513                ",
 514            )
 515            .bind(code)
 516            .fetch_optional(&mut tx)
 517            .await?;
 518
 519            let (inviter_id, invite_count) = match row {
 520                Some(row) => row,
 521                None => Err(Error::Http(
 522                    StatusCode::NOT_FOUND,
 523                    "invite code not found".to_string(),
 524                ))?,
 525            };
 526
 527            if invite_count == 0 {
 528                Err(Error::Http(
 529                    StatusCode::UNAUTHORIZED,
 530                    "no invites remaining".to_string(),
 531                ))?;
 532            }
 533
 534            let email_confirmation_code: String = sqlx::query_scalar(
 535                "
 536                INSERT INTO signups
 537                (
 538                    email_address,
 539                    email_confirmation_code,
 540                    email_confirmation_sent,
 541                    inviting_user_id,
 542                    platform_linux,
 543                    platform_mac,
 544                    platform_windows,
 545                    platform_unknown,
 546                    device_id
 547                )
 548                VALUES
 549                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 550                ON CONFLICT (email_address)
 551                DO UPDATE SET
 552                    inviting_user_id = excluded.inviting_user_id
 553                RETURNING email_confirmation_code
 554                ",
 555            )
 556            .bind(&email_address)
 557            .bind(&random_email_confirmation_code())
 558            .bind(&inviter_id)
 559            .bind(&device_id)
 560            .fetch_one(&mut tx)
 561            .await?;
 562
 563            tx.commit().await?;
 564
 565            Ok(Invite {
 566                email_address: email_address.into(),
 567                email_confirmation_code,
 568            })
 569        })
 570        .await
 571    }
 572
 573    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 574        self.transact(|mut tx| async {
 575            let emails = invites
 576                .iter()
 577                .map(|s| s.email_address.as_str())
 578                .collect::<Vec<_>>();
 579            sqlx::query(
 580                "
 581                UPDATE signups
 582                SET email_confirmation_sent = TRUE
 583                WHERE email_address = ANY ($1)
 584                ",
 585            )
 586            .bind(&emails)
 587            .execute(&mut tx)
 588            .await?;
 589            tx.commit().await?;
 590            Ok(())
 591        })
 592        .await
 593    }
 594}
 595
 596impl<D> Db<D>
 597where
 598    Self: BeginTransaction<Database = D>,
 599    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
 600    D::Connection: sqlx::migrate::Migrate,
 601    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
 602    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
 603    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
 604    D::QueryResult: RowsAffected,
 605    String: sqlx::Type<D>,
 606    i32: sqlx::Type<D>,
 607    i64: sqlx::Type<D>,
 608    bool: sqlx::Type<D>,
 609    str: sqlx::Type<D>,
 610    Uuid: sqlx::Type<D>,
 611    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
 612    OffsetDateTime: sqlx::Type<D>,
 613    PrimitiveDateTime: sqlx::Type<D>,
 614    usize: sqlx::ColumnIndex<D::Row>,
 615    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
 616    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 617    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 618    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 619    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 620    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 621    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 622    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 623    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 624    for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 625    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 626    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
 627    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 628{
 629    pub async fn migrate(
 630        &self,
 631        migrations_path: &Path,
 632        ignore_checksum_mismatch: bool,
 633    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 634        let migrations = MigrationSource::resolve(migrations_path)
 635            .await
 636            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 637
 638        let mut conn = self.pool.acquire().await?;
 639
 640        conn.ensure_migrations_table().await?;
 641        let applied_migrations: HashMap<_, _> = conn
 642            .list_applied_migrations()
 643            .await?
 644            .into_iter()
 645            .map(|m| (m.version, m))
 646            .collect();
 647
 648        let mut new_migrations = Vec::new();
 649        for migration in migrations {
 650            match applied_migrations.get(&migration.version) {
 651                Some(applied_migration) => {
 652                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 653                    {
 654                        Err(anyhow!(
 655                            "checksum mismatch for applied migration {}",
 656                            migration.description
 657                        ))?;
 658                    }
 659                }
 660                None => {
 661                    let elapsed = conn.apply(&migration).await?;
 662                    new_migrations.push((migration, elapsed));
 663                }
 664            }
 665        }
 666
 667        Ok(new_migrations)
 668    }
 669
 670    pub fn fuzzy_like_string(string: &str) -> String {
 671        let mut result = String::with_capacity(string.len() * 2 + 1);
 672        for c in string.chars() {
 673            if c.is_alphanumeric() {
 674                result.push('%');
 675                result.push(c);
 676            }
 677        }
 678        result.push('%');
 679        result
 680    }
 681
 682    // users
 683
 684    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 685        self.transact(|tx| async {
 686            let mut tx = tx;
 687            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 688            Ok(sqlx::query_as(query)
 689                .bind(limit as i32)
 690                .bind((page * limit) as i32)
 691                .fetch_all(&mut tx)
 692                .await?)
 693        })
 694        .await
 695    }
 696
 697    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 698        self.transact(|tx| async {
 699            let mut tx = tx;
 700            let query = "
 701                SELECT users.*
 702                FROM users
 703                WHERE id = $1
 704                LIMIT 1
 705            ";
 706            Ok(sqlx::query_as(query)
 707                .bind(&id)
 708                .fetch_optional(&mut tx)
 709                .await?)
 710        })
 711        .await
 712    }
 713
 714    pub async fn get_users_with_no_invites(
 715        &self,
 716        invited_by_another_user: bool,
 717    ) -> Result<Vec<User>> {
 718        self.transact(|tx| async {
 719            let mut tx = tx;
 720            let query = format!(
 721                "
 722                SELECT users.*
 723                FROM users
 724                WHERE invite_count = 0
 725                AND inviter_id IS{} NULL
 726                ",
 727                if invited_by_another_user { " NOT" } else { "" }
 728            );
 729
 730            Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
 731        })
 732        .await
 733    }
 734
 735    pub async fn get_user_by_github_account(
 736        &self,
 737        github_login: &str,
 738        github_user_id: Option<i32>,
 739    ) -> Result<Option<User>> {
 740        self.transact(|tx| async {
 741            let mut tx = tx;
 742            if let Some(github_user_id) = github_user_id {
 743                let mut user = sqlx::query_as::<_, User>(
 744                    "
 745                    UPDATE users
 746                    SET github_login = $1
 747                    WHERE github_user_id = $2
 748                    RETURNING *
 749                    ",
 750                )
 751                .bind(github_login)
 752                .bind(github_user_id)
 753                .fetch_optional(&mut tx)
 754                .await?;
 755
 756                if user.is_none() {
 757                    user = sqlx::query_as::<_, User>(
 758                        "
 759                        UPDATE users
 760                        SET github_user_id = $1
 761                        WHERE github_login = $2
 762                        RETURNING *
 763                        ",
 764                    )
 765                    .bind(github_user_id)
 766                    .bind(github_login)
 767                    .fetch_optional(&mut tx)
 768                    .await?;
 769                }
 770
 771                Ok(user)
 772            } else {
 773                let user = sqlx::query_as(
 774                    "
 775                    SELECT * FROM users
 776                    WHERE github_login = $1
 777                    LIMIT 1
 778                    ",
 779                )
 780                .bind(github_login)
 781                .fetch_optional(&mut tx)
 782                .await?;
 783                Ok(user)
 784            }
 785        })
 786        .await
 787    }
 788
 789    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 790        self.transact(|mut tx| async {
 791            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 792            sqlx::query(query)
 793                .bind(is_admin)
 794                .bind(id.0)
 795                .execute(&mut tx)
 796                .await?;
 797            tx.commit().await?;
 798            Ok(())
 799        })
 800        .await
 801    }
 802
 803    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 804        self.transact(|mut tx| async move {
 805            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 806            sqlx::query(query)
 807                .bind(connected_once)
 808                .bind(id.0)
 809                .execute(&mut tx)
 810                .await?;
 811            tx.commit().await?;
 812            Ok(())
 813        })
 814        .await
 815    }
 816
 817    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 818        self.transact(|mut tx| async move {
 819            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 820            sqlx::query(query)
 821                .bind(id.0)
 822                .execute(&mut tx)
 823                .await
 824                .map(drop)?;
 825            let query = "DELETE FROM users WHERE id = $1;";
 826            sqlx::query(query).bind(id.0).execute(&mut tx).await?;
 827            tx.commit().await?;
 828            Ok(())
 829        })
 830        .await
 831    }
 832
 833    // signups
 834
 835    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 836        self.transact(|mut tx| async move {
 837            Ok(sqlx::query_as(
 838                "
 839                SELECT
 840                    COUNT(*) as count,
 841                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 842                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 843                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 844                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 845                FROM (
 846                    SELECT *
 847                    FROM signups
 848                    WHERE
 849                        NOT email_confirmation_sent
 850                ) AS unsent
 851                ",
 852            )
 853            .fetch_one(&mut tx)
 854            .await?)
 855        })
 856        .await
 857    }
 858
 859    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 860        self.transact(|mut tx| async move {
 861            Ok(sqlx::query_as(
 862                "
 863                SELECT
 864                    email_address, email_confirmation_code
 865                FROM signups
 866                WHERE
 867                    NOT email_confirmation_sent AND
 868                    (platform_mac OR platform_unknown)
 869                LIMIT $1
 870                ",
 871            )
 872            .bind(count as i32)
 873            .fetch_all(&mut tx)
 874            .await?)
 875        })
 876        .await
 877    }
 878
 879    // invite codes
 880
 881    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 882        self.transact(|mut tx| async move {
 883            if count > 0 {
 884                sqlx::query(
 885                    "
 886                    UPDATE users
 887                    SET invite_code = $1
 888                    WHERE id = $2 AND invite_code IS NULL
 889                ",
 890                )
 891                .bind(random_invite_code())
 892                .bind(id)
 893                .execute(&mut tx)
 894                .await?;
 895            }
 896
 897            sqlx::query(
 898                "
 899                UPDATE users
 900                SET invite_count = $1
 901                WHERE id = $2
 902                ",
 903            )
 904            .bind(count as i32)
 905            .bind(id)
 906            .execute(&mut tx)
 907            .await?;
 908            tx.commit().await?;
 909            Ok(())
 910        })
 911        .await
 912    }
 913
 914    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 915        self.transact(|mut tx| async move {
 916            let result: Option<(String, i32)> = sqlx::query_as(
 917                "
 918                    SELECT invite_code, invite_count
 919                    FROM users
 920                    WHERE id = $1 AND invite_code IS NOT NULL 
 921                ",
 922            )
 923            .bind(id)
 924            .fetch_optional(&mut tx)
 925            .await?;
 926            if let Some((code, count)) = result {
 927                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 928            } else {
 929                Ok(None)
 930            }
 931        })
 932        .await
 933    }
 934
 935    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 936        self.transact(|tx| async {
 937            let mut tx = tx;
 938            sqlx::query_as(
 939                "
 940                    SELECT *
 941                    FROM users
 942                    WHERE invite_code = $1
 943                ",
 944            )
 945            .bind(code)
 946            .fetch_optional(&mut tx)
 947            .await?
 948            .ok_or_else(|| {
 949                Error::Http(
 950                    StatusCode::NOT_FOUND,
 951                    "that invite code does not exist".to_string(),
 952                )
 953            })
 954        })
 955        .await
 956    }
 957
 958    async fn commit_room_transaction<'a, T>(
 959        &'a self,
 960        room_id: RoomId,
 961        tx: sqlx::Transaction<'static, D>,
 962        data: T,
 963    ) -> Result<RoomGuard<T>> {
 964        let lock = self.rooms.entry(room_id).or_default().clone();
 965        let _guard = lock.lock_owned().await;
 966        tx.commit().await?;
 967        Ok(RoomGuard {
 968            data,
 969            _guard,
 970            _not_send: PhantomData,
 971        })
 972    }
 973
 974    pub async fn create_room(
 975        &self,
 976        user_id: UserId,
 977        connection_id: ConnectionId,
 978        live_kit_room: &str,
 979    ) -> Result<RoomGuard<proto::Room>> {
 980        self.transact(|mut tx| async move {
 981            let room_id = sqlx::query_scalar(
 982                "
 983                INSERT INTO rooms (live_kit_room)
 984                VALUES ($1)
 985                RETURNING id
 986                ",
 987            )
 988            .bind(&live_kit_room)
 989            .fetch_one(&mut tx)
 990            .await
 991            .map(RoomId)?;
 992
 993            sqlx::query(
 994                "
 995                INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
 996                VALUES ($1, $2, $3, $4, $5)
 997                ",
 998            )
 999            .bind(room_id)
1000            .bind(user_id)
1001            .bind(connection_id.0 as i32)
1002            .bind(user_id)
1003            .bind(connection_id.0 as i32)
1004            .execute(&mut tx)
1005            .await?;
1006
1007            let room = self.get_room(room_id, &mut tx).await?;
1008            self.commit_room_transaction(room_id, tx, room).await
1009        }).await
1010    }
1011
1012    pub async fn call(
1013        &self,
1014        room_id: RoomId,
1015        calling_user_id: UserId,
1016        calling_connection_id: ConnectionId,
1017        called_user_id: UserId,
1018        initial_project_id: Option<ProjectId>,
1019    ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1020        self.transact(|mut tx| async move {
1021            sqlx::query(
1022                "
1023                INSERT INTO room_participants (
1024                    room_id,
1025                    user_id,
1026                    calling_user_id,
1027                    calling_connection_id,
1028                    initial_project_id
1029                )
1030                VALUES ($1, $2, $3, $4, $5)
1031                ",
1032            )
1033            .bind(room_id)
1034            .bind(called_user_id)
1035            .bind(calling_user_id)
1036            .bind(calling_connection_id.0 as i32)
1037            .bind(initial_project_id)
1038            .execute(&mut tx)
1039            .await?;
1040
1041            let room = self.get_room(room_id, &mut tx).await?;
1042            let incoming_call = Self::build_incoming_call(&room, called_user_id)
1043                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1044            self.commit_room_transaction(room_id, tx, (room, incoming_call))
1045                .await
1046        })
1047        .await
1048    }
1049
1050    pub async fn incoming_call_for_user(
1051        &self,
1052        user_id: UserId,
1053    ) -> Result<Option<proto::IncomingCall>> {
1054        self.transact(|mut tx| async move {
1055            let room_id = sqlx::query_scalar::<_, RoomId>(
1056                "
1057                SELECT room_id
1058                FROM room_participants
1059                WHERE user_id = $1 AND answering_connection_id IS NULL
1060                ",
1061            )
1062            .bind(user_id)
1063            .fetch_optional(&mut tx)
1064            .await?;
1065
1066            if let Some(room_id) = room_id {
1067                let room = self.get_room(room_id, &mut tx).await?;
1068                Ok(Self::build_incoming_call(&room, user_id))
1069            } else {
1070                Ok(None)
1071            }
1072        })
1073        .await
1074    }
1075
1076    fn build_incoming_call(
1077        room: &proto::Room,
1078        called_user_id: UserId,
1079    ) -> Option<proto::IncomingCall> {
1080        let pending_participant = room
1081            .pending_participants
1082            .iter()
1083            .find(|participant| participant.user_id == called_user_id.to_proto())?;
1084
1085        Some(proto::IncomingCall {
1086            room_id: room.id,
1087            calling_user_id: pending_participant.calling_user_id,
1088            participant_user_ids: room
1089                .participants
1090                .iter()
1091                .map(|participant| participant.user_id)
1092                .collect(),
1093            initial_project: room.participants.iter().find_map(|participant| {
1094                let initial_project_id = pending_participant.initial_project_id?;
1095                participant
1096                    .projects
1097                    .iter()
1098                    .find(|project| project.id == initial_project_id)
1099                    .cloned()
1100            }),
1101        })
1102    }
1103
1104    pub async fn call_failed(
1105        &self,
1106        room_id: RoomId,
1107        called_user_id: UserId,
1108    ) -> Result<RoomGuard<proto::Room>> {
1109        self.transact(|mut tx| async move {
1110            sqlx::query(
1111                "
1112                DELETE FROM room_participants
1113                WHERE room_id = $1 AND user_id = $2
1114                ",
1115            )
1116            .bind(room_id)
1117            .bind(called_user_id)
1118            .execute(&mut tx)
1119            .await?;
1120
1121            let room = self.get_room(room_id, &mut tx).await?;
1122            self.commit_room_transaction(room_id, tx, room).await
1123        })
1124        .await
1125    }
1126
1127    pub async fn decline_call(
1128        &self,
1129        expected_room_id: Option<RoomId>,
1130        user_id: UserId,
1131    ) -> Result<RoomGuard<proto::Room>> {
1132        self.transact(|mut tx| async move {
1133            let room_id = sqlx::query_scalar(
1134                "
1135                DELETE FROM room_participants
1136                WHERE user_id = $1 AND answering_connection_id IS NULL
1137                RETURNING room_id
1138                ",
1139            )
1140            .bind(user_id)
1141            .fetch_one(&mut tx)
1142            .await?;
1143            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1144                return Err(anyhow!("declining call on unexpected room"))?;
1145            }
1146
1147            let room = self.get_room(room_id, &mut tx).await?;
1148            self.commit_room_transaction(room_id, tx, room).await
1149        })
1150        .await
1151    }
1152
1153    pub async fn cancel_call(
1154        &self,
1155        expected_room_id: Option<RoomId>,
1156        calling_connection_id: ConnectionId,
1157        called_user_id: UserId,
1158    ) -> Result<RoomGuard<proto::Room>> {
1159        self.transact(|mut tx| async move {
1160            let room_id = sqlx::query_scalar(
1161                "
1162                DELETE FROM room_participants
1163                WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1164                RETURNING room_id
1165                ",
1166            )
1167            .bind(called_user_id)
1168            .bind(calling_connection_id.0 as i32)
1169            .fetch_one(&mut tx)
1170            .await?;
1171            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1172                return Err(anyhow!("canceling call on unexpected room"))?;
1173            }
1174
1175            let room = self.get_room(room_id, &mut tx).await?;
1176            self.commit_room_transaction(room_id, tx, room).await
1177        }).await
1178    }
1179
1180    pub async fn join_room(
1181        &self,
1182        room_id: RoomId,
1183        user_id: UserId,
1184        connection_id: ConnectionId,
1185    ) -> Result<RoomGuard<proto::Room>> {
1186        self.transact(|mut tx| async move {
1187            sqlx::query(
1188                "
1189                UPDATE room_participants 
1190                SET answering_connection_id = $1
1191                WHERE room_id = $2 AND user_id = $3
1192                RETURNING 1
1193                ",
1194            )
1195            .bind(connection_id.0 as i32)
1196            .bind(room_id)
1197            .bind(user_id)
1198            .fetch_one(&mut tx)
1199            .await?;
1200
1201            let room = self.get_room(room_id, &mut tx).await?;
1202            self.commit_room_transaction(room_id, tx, room).await
1203        })
1204        .await
1205    }
1206
1207    pub async fn leave_room(
1208        &self,
1209        connection_id: ConnectionId,
1210    ) -> Result<Option<RoomGuard<LeftRoom>>> {
1211        self.transact(|mut tx| async move {
1212            // Leave room.
1213            let room_id = sqlx::query_scalar::<_, RoomId>(
1214                "
1215                DELETE FROM room_participants
1216                WHERE answering_connection_id = $1
1217                RETURNING room_id
1218                ",
1219            )
1220            .bind(connection_id.0 as i32)
1221            .fetch_optional(&mut tx)
1222            .await?;
1223
1224            if let Some(room_id) = room_id {
1225                // Cancel pending calls initiated by the leaving user.
1226                let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1227                    "
1228                    DELETE FROM room_participants
1229                    WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1230                    RETURNING user_id
1231                    ",
1232                )
1233                .bind(connection_id.0 as i32)
1234                .fetch_all(&mut tx)
1235                .await?;
1236
1237                let project_ids = sqlx::query_scalar::<_, ProjectId>(
1238                    "
1239                    SELECT project_id
1240                    FROM project_collaborators
1241                    WHERE connection_id = $1
1242                    ",
1243                )
1244                .bind(connection_id.0 as i32)
1245                .fetch_all(&mut tx)
1246                .await?;
1247
1248                // Leave projects.
1249                let mut left_projects = HashMap::default();
1250                if !project_ids.is_empty() {
1251                    let mut params = "?,".repeat(project_ids.len());
1252                    params.pop();
1253                    let query = format!(
1254                        "
1255                        SELECT *
1256                        FROM project_collaborators
1257                        WHERE project_id IN ({params})
1258                    "
1259                    );
1260                    let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1261                    for project_id in project_ids {
1262                        query = query.bind(project_id);
1263                    }
1264
1265                    let mut project_collaborators = query.fetch(&mut tx);
1266                    while let Some(collaborator) = project_collaborators.next().await {
1267                        let collaborator = collaborator?;
1268                        let left_project =
1269                            left_projects
1270                                .entry(collaborator.project_id)
1271                                .or_insert(LeftProject {
1272                                    id: collaborator.project_id,
1273                                    host_user_id: Default::default(),
1274                                    connection_ids: Default::default(),
1275                                    host_connection_id: Default::default(),
1276                                });
1277
1278                        let collaborator_connection_id =
1279                            ConnectionId(collaborator.connection_id as u32);
1280                        if collaborator_connection_id != connection_id {
1281                            left_project.connection_ids.push(collaborator_connection_id);
1282                        }
1283
1284                        if collaborator.is_host {
1285                            left_project.host_user_id = collaborator.user_id;
1286                            left_project.host_connection_id =
1287                                ConnectionId(collaborator.connection_id as u32);
1288                        }
1289                    }
1290                }
1291                sqlx::query(
1292                    "
1293                    DELETE FROM project_collaborators
1294                    WHERE connection_id = $1
1295                    ",
1296                )
1297                .bind(connection_id.0 as i32)
1298                .execute(&mut tx)
1299                .await?;
1300
1301                // Unshare projects.
1302                sqlx::query(
1303                    "
1304                    DELETE FROM projects
1305                    WHERE room_id = $1 AND host_connection_id = $2
1306                    ",
1307                )
1308                .bind(room_id)
1309                .bind(connection_id.0 as i32)
1310                .execute(&mut tx)
1311                .await?;
1312
1313                let room = self.get_room(room_id, &mut tx).await?;
1314                Ok(Some(
1315                    self.commit_room_transaction(
1316                        room_id,
1317                        tx,
1318                        LeftRoom {
1319                            room,
1320                            left_projects,
1321                            canceled_calls_to_user_ids,
1322                        },
1323                    )
1324                    .await?,
1325                ))
1326            } else {
1327                Ok(None)
1328            }
1329        })
1330        .await
1331    }
1332
1333    pub async fn update_room_participant_location(
1334        &self,
1335        room_id: RoomId,
1336        connection_id: ConnectionId,
1337        location: proto::ParticipantLocation,
1338    ) -> Result<RoomGuard<proto::Room>> {
1339        self.transact(|tx| async {
1340            let mut tx = tx;
1341            let location_kind;
1342            let location_project_id;
1343            match location
1344                .variant
1345                .as_ref()
1346                .ok_or_else(|| anyhow!("invalid location"))?
1347            {
1348                proto::participant_location::Variant::SharedProject(project) => {
1349                    location_kind = 0;
1350                    location_project_id = Some(ProjectId::from_proto(project.id));
1351                }
1352                proto::participant_location::Variant::UnsharedProject(_) => {
1353                    location_kind = 1;
1354                    location_project_id = None;
1355                }
1356                proto::participant_location::Variant::External(_) => {
1357                    location_kind = 2;
1358                    location_project_id = None;
1359                }
1360            }
1361
1362            sqlx::query(
1363                "
1364                UPDATE room_participants
1365                SET location_kind = $1, location_project_id = $2
1366                WHERE room_id = $3 AND answering_connection_id = $4
1367                RETURNING 1
1368                ",
1369            )
1370            .bind(location_kind)
1371            .bind(location_project_id)
1372            .bind(room_id)
1373            .bind(connection_id.0 as i32)
1374            .fetch_one(&mut tx)
1375            .await?;
1376
1377            let room = self.get_room(room_id, &mut tx).await?;
1378            self.commit_room_transaction(room_id, tx, room).await
1379        })
1380        .await
1381    }
1382
1383    async fn get_guest_connection_ids(
1384        &self,
1385        project_id: ProjectId,
1386        tx: &mut sqlx::Transaction<'_, D>,
1387    ) -> Result<Vec<ConnectionId>> {
1388        let mut guest_connection_ids = Vec::new();
1389        let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1390            "
1391            SELECT connection_id
1392            FROM project_collaborators
1393            WHERE project_id = $1 AND is_host = FALSE
1394            ",
1395        )
1396        .bind(project_id)
1397        .fetch(tx);
1398        while let Some(connection_id) = db_guest_connection_ids.next().await {
1399            guest_connection_ids.push(ConnectionId(connection_id? as u32));
1400        }
1401        Ok(guest_connection_ids)
1402    }
1403
    /// Loads the full state of room `room_id` through the caller's transaction `tx`.
    ///
    /// The result combines three things:
    /// - the `rooms` row;
    /// - the room's participants, split into answered participants (rows with an
    ///   `answering_connection_id`) and pending participants (rows without one);
    /// - for each answered participant, the projects they host in this room, with
    ///   those projects' worktree root names.
    ///
    /// Fails if the room does not exist (`fetch_one` on the `rooms` query).
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                "
                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                FROM room_participants
                WHERE room_id = $1
                ",
            )
            .bind(room_id)
            .fetch(&mut *tx);

        // Answered participants are keyed by their answering connection id, so that
        // projects (looked up below by `host_connection_id`) can be attached to them.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                // Mapping of the stored location columns:
                // - kind 0 with a project id -> SharedProject(project id)
                // - kind 1 -> UnsharedProject
                // - anything else (NULL kind, or kind 0 without a project id) -> External
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        // The participant's peer id is its answering connection id.
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: the call to this user is still ringing.
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // The stream mutably borrows `tx`; end that borrow before the next query.
        drop(db_participants);

        // One row per (project, worktree). The LEFT JOIN also yields a row with a NULL
        // root name for projects that have no worktrees.
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            // Projects whose host connection is not an answered participant are skipped.
            if let Some(participant) = participants.get_mut(&connection_id) {
                // Reuse the participant's entry for this project, or append a new one.
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                // Extending with `None` (no worktree) adds nothing, so worktree-less
                // projects are still listed, with no root names.
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        // Participant order follows the map's iteration order, not necessarily the
        // order the rows were returned in.
        Ok(proto::Room {
            id: room.id.to_proto(),
            live_kit_room: room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1515
1516    // projects
1517
1518    pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1519        self.transact(|mut tx| async move {
1520            Ok(sqlx::query_scalar::<_, i32>(
1521                "
1522                SELECT COUNT(*)
1523                FROM projects, users
1524                WHERE projects.host_user_id = users.id AND users.admin IS FALSE
1525                ",
1526            )
1527            .fetch_one(&mut tx)
1528            .await? as usize)
1529        })
1530        .await
1531    }
1532
1533    pub async fn share_project(
1534        &self,
1535        expected_room_id: RoomId,
1536        connection_id: ConnectionId,
1537        worktrees: &[proto::WorktreeMetadata],
1538    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
1539        self.transact(|mut tx| async move {
1540            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1541                "
1542                SELECT room_id, user_id
1543                FROM room_participants
1544                WHERE answering_connection_id = $1
1545                ",
1546            )
1547            .bind(connection_id.0 as i32)
1548            .fetch_one(&mut tx)
1549            .await?;
1550            if room_id != expected_room_id {
1551                return Err(anyhow!("shared project on unexpected room"))?;
1552            }
1553
1554            let project_id: ProjectId = sqlx::query_scalar(
1555                "
1556                INSERT INTO projects (room_id, host_user_id, host_connection_id)
1557                VALUES ($1, $2, $3)
1558                RETURNING id
1559                ",
1560            )
1561            .bind(room_id)
1562            .bind(user_id)
1563            .bind(connection_id.0 as i32)
1564            .fetch_one(&mut tx)
1565            .await?;
1566
1567            if !worktrees.is_empty() {
1568                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1569                params.pop();
1570                let query = format!(
1571                    "
1572                    INSERT INTO worktrees (
1573                        project_id,
1574                        id,
1575                        root_name,
1576                        abs_path,
1577                        visible,
1578                        scan_id,
1579                        is_complete
1580                    )
1581                    VALUES {params}
1582                    "
1583                );
1584
1585                let mut query = sqlx::query(&query);
1586                for worktree in worktrees {
1587                    query = query
1588                        .bind(project_id)
1589                        .bind(worktree.id as i32)
1590                        .bind(&worktree.root_name)
1591                        .bind(&worktree.abs_path)
1592                        .bind(worktree.visible)
1593                        .bind(0)
1594                        .bind(false);
1595                }
1596                query.execute(&mut tx).await?;
1597            }
1598
1599            sqlx::query(
1600                "
1601                INSERT INTO project_collaborators (
1602                    project_id,
1603                    connection_id,
1604                    user_id,
1605                    replica_id,
1606                    is_host
1607                )
1608                VALUES ($1, $2, $3, $4, $5)
1609                ",
1610            )
1611            .bind(project_id)
1612            .bind(connection_id.0 as i32)
1613            .bind(user_id)
1614            .bind(0)
1615            .bind(true)
1616            .execute(&mut tx)
1617            .await?;
1618
1619            let room = self.get_room(room_id, &mut tx).await?;
1620            self.commit_room_transaction(room_id, tx, (project_id, room))
1621                .await
1622        })
1623        .await
1624    }
1625
1626    pub async fn unshare_project(
1627        &self,
1628        project_id: ProjectId,
1629        connection_id: ConnectionId,
1630    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1631        self.transact(|mut tx| async move {
1632            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1633            let room_id: RoomId = sqlx::query_scalar(
1634                "
1635                DELETE FROM projects
1636                WHERE id = $1 AND host_connection_id = $2
1637                RETURNING room_id
1638                ",
1639            )
1640            .bind(project_id)
1641            .bind(connection_id.0 as i32)
1642            .fetch_one(&mut tx)
1643            .await?;
1644            let room = self.get_room(room_id, &mut tx).await?;
1645            self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1646                .await
1647        })
1648        .await
1649    }
1650
1651    pub async fn update_project(
1652        &self,
1653        project_id: ProjectId,
1654        connection_id: ConnectionId,
1655        worktrees: &[proto::WorktreeMetadata],
1656    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1657        self.transact(|mut tx| async move {
1658            let room_id: RoomId = sqlx::query_scalar(
1659                "
1660                SELECT room_id
1661                FROM projects
1662                WHERE id = $1 AND host_connection_id = $2
1663                ",
1664            )
1665            .bind(project_id)
1666            .bind(connection_id.0 as i32)
1667            .fetch_one(&mut tx)
1668            .await?;
1669
1670            if !worktrees.is_empty() {
1671                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1672                params.pop();
1673                let query = format!(
1674                    "
1675                    INSERT INTO worktrees (
1676                        project_id,
1677                        id,
1678                        root_name,
1679                        abs_path,
1680                        visible,
1681                        scan_id,
1682                        is_complete
1683                    )
1684                    VALUES {params}
1685                    ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1686                    "
1687                );
1688
1689                let mut query = sqlx::query(&query);
1690                for worktree in worktrees {
1691                    query = query
1692                        .bind(project_id)
1693                        .bind(worktree.id as i32)
1694                        .bind(&worktree.root_name)
1695                        .bind(&worktree.abs_path)
1696                        .bind(worktree.visible)
1697                        .bind(0)
1698                        .bind(false)
1699                }
1700                query.execute(&mut tx).await?;
1701            }
1702
1703            let mut params = "?,".repeat(worktrees.len());
1704            if !worktrees.is_empty() {
1705                params.pop();
1706            }
1707            let query = format!(
1708                "
1709                DELETE FROM worktrees
1710                WHERE project_id = ? AND id NOT IN ({params})
1711                ",
1712            );
1713
1714            let mut query = sqlx::query(&query).bind(project_id);
1715            for worktree in worktrees {
1716                query = query.bind(WorktreeId(worktree.id as i32));
1717            }
1718            query.execute(&mut tx).await?;
1719
1720            let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1721            let room = self.get_room(room_id, &mut tx).await?;
1722            self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1723                .await
1724        })
1725        .await
1726    }
1727
1728    pub async fn update_worktree(
1729        &self,
1730        update: &proto::UpdateWorktree,
1731        connection_id: ConnectionId,
1732    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1733        self.transact(|mut tx| async move {
1734            let project_id = ProjectId::from_proto(update.project_id);
1735            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1736
1737            // Ensure the update comes from the host.
1738            let room_id: RoomId = sqlx::query_scalar(
1739                "
1740                SELECT room_id
1741                FROM projects
1742                WHERE id = $1 AND host_connection_id = $2
1743                ",
1744            )
1745            .bind(project_id)
1746            .bind(connection_id.0 as i32)
1747            .fetch_one(&mut tx)
1748            .await?;
1749
1750            // Update metadata.
1751            sqlx::query(
1752                "
1753                UPDATE worktrees
1754                SET
1755                    root_name = $1,
1756                    scan_id = $2,
1757                    is_complete = $3,
1758                    abs_path = $4
1759                WHERE project_id = $5 AND id = $6
1760                RETURNING 1
1761                ",
1762            )
1763            .bind(&update.root_name)
1764            .bind(update.scan_id as i64)
1765            .bind(update.is_last_update)
1766            .bind(&update.abs_path)
1767            .bind(project_id)
1768            .bind(worktree_id)
1769            .fetch_one(&mut tx)
1770            .await?;
1771
1772            if !update.updated_entries.is_empty() {
1773                let mut params =
1774                    "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1775                params.pop();
1776
1777                let query = format!(
1778                    "
1779                    INSERT INTO worktree_entries (
1780                        project_id, 
1781                        worktree_id, 
1782                        id, 
1783                        is_dir, 
1784                        path, 
1785                        inode,
1786                        mtime_seconds, 
1787                        mtime_nanos, 
1788                        is_symlink, 
1789                        is_ignored
1790                    )
1791                    VALUES {params}
1792                    ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1793                        is_dir = excluded.is_dir,
1794                        path = excluded.path,
1795                        inode = excluded.inode,
1796                        mtime_seconds = excluded.mtime_seconds,
1797                        mtime_nanos = excluded.mtime_nanos,
1798                        is_symlink = excluded.is_symlink,
1799                        is_ignored = excluded.is_ignored
1800                    "
1801                );
1802                let mut query = sqlx::query(&query);
1803                for entry in &update.updated_entries {
1804                    let mtime = entry.mtime.clone().unwrap_or_default();
1805                    query = query
1806                        .bind(project_id)
1807                        .bind(worktree_id)
1808                        .bind(entry.id as i64)
1809                        .bind(entry.is_dir)
1810                        .bind(&entry.path)
1811                        .bind(entry.inode as i64)
1812                        .bind(mtime.seconds as i64)
1813                        .bind(mtime.nanos as i32)
1814                        .bind(entry.is_symlink)
1815                        .bind(entry.is_ignored);
1816                }
1817                query.execute(&mut tx).await?;
1818            }
1819
1820            if !update.removed_entries.is_empty() {
1821                let mut params = "?,".repeat(update.removed_entries.len());
1822                params.pop();
1823                let query = format!(
1824                    "
1825                    DELETE FROM worktree_entries
1826                    WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1827                    "
1828                );
1829
1830                let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1831                for entry_id in &update.removed_entries {
1832                    query = query.bind(*entry_id as i64);
1833                }
1834                query.execute(&mut tx).await?;
1835            }
1836
1837            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1838            self.commit_room_transaction(room_id, tx, connection_ids)
1839                .await
1840        })
1841        .await
1842    }
1843
1844    pub async fn update_diagnostic_summary(
1845        &self,
1846        update: &proto::UpdateDiagnosticSummary,
1847        connection_id: ConnectionId,
1848    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1849        self.transact(|mut tx| async {
1850            let project_id = ProjectId::from_proto(update.project_id);
1851            let worktree_id = WorktreeId::from_proto(update.worktree_id);
1852            let summary = update
1853                .summary
1854                .as_ref()
1855                .ok_or_else(|| anyhow!("invalid summary"))?;
1856
1857            // Ensure the update comes from the host.
1858            let room_id: RoomId = sqlx::query_scalar(
1859                "
1860                SELECT room_id
1861                FROM projects
1862                WHERE id = $1 AND host_connection_id = $2
1863                ",
1864            )
1865            .bind(project_id)
1866            .bind(connection_id.0 as i32)
1867            .fetch_one(&mut tx)
1868            .await?;
1869
1870            // Update summary.
1871            sqlx::query(
1872                "
1873                INSERT INTO worktree_diagnostic_summaries (
1874                    project_id,
1875                    worktree_id,
1876                    path,
1877                    language_server_id,
1878                    error_count,
1879                    warning_count
1880                )
1881                VALUES ($1, $2, $3, $4, $5, $6)
1882                ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1883                    language_server_id = excluded.language_server_id,
1884                    error_count = excluded.error_count, 
1885                    warning_count = excluded.warning_count
1886                ",
1887            )
1888            .bind(project_id)
1889            .bind(worktree_id)
1890            .bind(&summary.path)
1891            .bind(summary.language_server_id as i64)
1892            .bind(summary.error_count as i32)
1893            .bind(summary.warning_count as i32)
1894            .execute(&mut tx)
1895            .await?;
1896
1897            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1898            self.commit_room_transaction(room_id, tx, connection_ids)
1899                .await
1900        })
1901        .await
1902    }
1903
1904    pub async fn start_language_server(
1905        &self,
1906        update: &proto::StartLanguageServer,
1907        connection_id: ConnectionId,
1908    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1909        self.transact(|mut tx| async {
1910            let project_id = ProjectId::from_proto(update.project_id);
1911            let server = update
1912                .server
1913                .as_ref()
1914                .ok_or_else(|| anyhow!("invalid language server"))?;
1915
1916            // Ensure the update comes from the host.
1917            let room_id: RoomId = sqlx::query_scalar(
1918                "
1919                SELECT room_id
1920                FROM projects
1921                WHERE id = $1 AND host_connection_id = $2
1922                ",
1923            )
1924            .bind(project_id)
1925            .bind(connection_id.0 as i32)
1926            .fetch_one(&mut tx)
1927            .await?;
1928
1929            // Add the newly-started language server.
1930            sqlx::query(
1931                "
1932                INSERT INTO language_servers (project_id, id, name)
1933                VALUES ($1, $2, $3)
1934                ON CONFLICT (project_id, id) DO UPDATE SET
1935                    name = excluded.name
1936                ",
1937            )
1938            .bind(project_id)
1939            .bind(server.id as i64)
1940            .bind(&server.name)
1941            .execute(&mut tx)
1942            .await?;
1943
1944            let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1945            self.commit_room_transaction(room_id, tx, connection_ids)
1946                .await
1947        })
1948        .await
1949    }
1950
    /// Adds `connection_id` as a guest collaborator of `project_id` and loads the project.
    ///
    /// The joining connection must appear in `room_participants` (matched on
    /// `answering_connection_id`), and the project must belong to that participant's room.
    /// The new collaborator gets the smallest `ReplicaId` >= 1 not already used by an
    /// existing collaborator. Returns, under the room lock, the project (collaborators
    /// including the new one, worktrees with their entries and diagnostic summaries, and
    /// language servers) together with the assigned replica id.
    ///
    /// # Errors
    ///
    /// Fails if the connection is not a room participant, if the project is not in that
    /// participant's room (both via `fetch_one` finding no row), or on any database error.
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<RoomGuard<(Project, ReplicaId)>> {
        self.transact(|mut tx| async move {
            // Resolve the room and user behind the joining connection.
            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
                "
                SELECT room_id, user_id
                FROM room_participants
                WHERE answering_connection_id = $1
                ",
            )
            .bind(connection_id.0 as i32)
            .fetch_one(&mut tx)
            .await?;

            // Ensure project id was shared on this room.
            sqlx::query(
                "
                SELECT 1
                FROM projects
                WHERE id = $1 AND room_id = $2
                ",
            )
            .bind(project_id)
            .bind(room_id)
            .fetch_one(&mut tx)
            .await?;

            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
                "
                SELECT *
                FROM project_collaborators
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Pick the lowest free replica id, starting at 1.
            // NOTE(review): 0 is never assigned here; presumably reserved for the host — confirm.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = ProjectCollaborator {
                project_id,
                connection_id: connection_id.0 as i32,
                user_id,
                replica_id,
                is_host: false,
            };

            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(new_collaborator.project_id)
            .bind(new_collaborator.connection_id)
            .bind(new_collaborator.user_id)
            .bind(new_collaborator.replica_id)
            .bind(new_collaborator.is_host)
            .execute(&mut tx)
            .await?;
            collaborators.push(new_collaborator);

            // Build worktrees keyed by id; entries and summaries are filled in below.
            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
                "
                SELECT *
                FROM worktrees
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            let mut worktrees = worktree_rows
                .into_iter()
                .map(|worktree_row| {
                    (
                        worktree_row.id,
                        Worktree {
                            id: worktree_row.id,
                            abs_path: worktree_row.abs_path,
                            root_name: worktree_row.root_name,
                            visible: worktree_row.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: worktree_row.scan_id as u64,
                            is_complete: worktree_row.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries.
            // The block scope ends the stream's mutable borrow of `tx` before `tx` is reused.
            // Rows whose worktree id is not in `worktrees` are skipped.
            {
                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                    "
                    SELECT *
                    FROM worktree_entries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                        worktree.entries.push(proto::Entry {
                            id: entry.id as u64,
                            is_dir: entry.is_dir,
                            path: entry.path,
                            inode: entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: entry.mtime_seconds as u64,
                                nanos: entry.mtime_nanos as u32,
                            }),
                            is_symlink: entry.is_symlink,
                            is_ignored: entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries.
            // Same borrow-scoping pattern as the entries block above.
            {
                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                    "
                    SELECT *
                    FROM worktree_diagnostic_summaries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(summary) = summaries.next().await {
                    let summary = summary?;
                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: summary.path,
                                language_server_id: summary.language_server_id as u64,
                                error_count: summary.error_count as u32,
                                warning_count: summary.warning_count as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = sqlx::query_as::<_, LanguageServer>(
                "
                SELECT *
                FROM language_servers
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;

            self.commit_room_transaction(
                room_id,
                tx,
                (
                    Project {
                        collaborators,
                        worktrees,
                        language_servers: language_servers
                            .into_iter()
                            .map(|language_server| proto::LanguageServer {
                                id: language_server.id.to_proto(),
                                name: language_server.name,
                            })
                            .collect(),
                    },
                    replica_id as ReplicaId,
                ),
            )
            .await
        })
        .await
    }
2147
2148    pub async fn leave_project(
2149        &self,
2150        project_id: ProjectId,
2151        connection_id: ConnectionId,
2152    ) -> Result<RoomGuard<LeftProject>> {
2153        self.transact(|mut tx| async move {
2154            let result = sqlx::query(
2155                "
2156                DELETE FROM project_collaborators
2157                WHERE project_id = $1 AND connection_id = $2
2158                ",
2159            )
2160            .bind(project_id)
2161            .bind(connection_id.0 as i32)
2162            .execute(&mut tx)
2163            .await?;
2164
2165            if result.rows_affected() == 0 {
2166                Err(anyhow!("not a collaborator on this project"))?;
2167            }
2168
2169            let connection_ids = sqlx::query_scalar::<_, i32>(
2170                "
2171                SELECT connection_id
2172                FROM project_collaborators
2173                WHERE project_id = $1
2174                ",
2175            )
2176            .bind(project_id)
2177            .fetch_all(&mut tx)
2178            .await?
2179            .into_iter()
2180            .map(|id| ConnectionId(id as u32))
2181            .collect();
2182
2183            let (room_id, host_user_id, host_connection_id) =
2184                sqlx::query_as::<_, (RoomId, i32, i32)>(
2185                    "
2186                SELECT room_id, host_user_id, host_connection_id
2187                FROM projects
2188                WHERE id = $1
2189                ",
2190                )
2191                .bind(project_id)
2192                .fetch_one(&mut tx)
2193                .await?;
2194
2195            self.commit_room_transaction(
2196                room_id,
2197                tx,
2198                LeftProject {
2199                    id: project_id,
2200                    host_user_id: UserId(host_user_id),
2201                    host_connection_id: ConnectionId(host_connection_id as u32),
2202                    connection_ids,
2203                },
2204            )
2205            .await
2206        })
2207        .await
2208    }
2209
2210    pub async fn project_collaborators(
2211        &self,
2212        project_id: ProjectId,
2213        connection_id: ConnectionId,
2214    ) -> Result<Vec<ProjectCollaborator>> {
2215        self.transact(|mut tx| async move {
2216            let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2217                "
2218                SELECT *
2219                FROM project_collaborators
2220                WHERE project_id = $1
2221                ",
2222            )
2223            .bind(project_id)
2224            .fetch_all(&mut tx)
2225            .await?;
2226
2227            if collaborators
2228                .iter()
2229                .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2230            {
2231                Ok(collaborators)
2232            } else {
2233                Err(anyhow!("no such project"))?
2234            }
2235        })
2236        .await
2237    }
2238
2239    pub async fn project_connection_ids(
2240        &self,
2241        project_id: ProjectId,
2242        connection_id: ConnectionId,
2243    ) -> Result<HashSet<ConnectionId>> {
2244        self.transact(|mut tx| async move {
2245            let connection_ids = sqlx::query_scalar::<_, i32>(
2246                "
2247                SELECT connection_id
2248                FROM project_collaborators
2249                WHERE project_id = $1
2250                ",
2251            )
2252            .bind(project_id)
2253            .fetch_all(&mut tx)
2254            .await?;
2255
2256            if connection_ids.contains(&(connection_id.0 as i32)) {
2257                Ok(connection_ids
2258                    .into_iter()
2259                    .map(|connection_id| ConnectionId(connection_id as u32))
2260                    .collect())
2261            } else {
2262                Err(anyhow!("no such project"))?
2263            }
2264        })
2265        .await
2266    }
2267
2268    // contacts
2269
2270    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2271        self.transact(|mut tx| async move {
2272            let query = "
2273                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2274                FROM contacts
2275                LEFT JOIN room_participants ON room_participants.user_id = $1
2276                WHERE user_id_a = $1 OR user_id_b = $1;
2277            ";
2278
2279            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2280                .bind(user_id)
2281                .fetch(&mut tx);
2282
2283            let mut contacts = Vec::new();
2284            while let Some(row) = rows.next().await {
2285                let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2286                if user_id_a == user_id {
2287                    if accepted {
2288                        contacts.push(Contact::Accepted {
2289                            user_id: user_id_b,
2290                            should_notify: should_notify && a_to_b,
2291                            busy
2292                        });
2293                    } else if a_to_b {
2294                        contacts.push(Contact::Outgoing { user_id: user_id_b })
2295                    } else {
2296                        contacts.push(Contact::Incoming {
2297                            user_id: user_id_b,
2298                            should_notify,
2299                        });
2300                    }
2301                } else if accepted {
2302                    contacts.push(Contact::Accepted {
2303                        user_id: user_id_a,
2304                        should_notify: should_notify && !a_to_b,
2305                        busy
2306                    });
2307                } else if a_to_b {
2308                    contacts.push(Contact::Incoming {
2309                        user_id: user_id_a,
2310                        should_notify,
2311                    });
2312                } else {
2313                    contacts.push(Contact::Outgoing { user_id: user_id_a });
2314                }
2315            }
2316
2317            contacts.sort_unstable_by_key(|contact| contact.user_id());
2318
2319            Ok(contacts)
2320        })
2321        .await
2322    }
2323
2324    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2325        self.transact(|mut tx| async move {
2326            Ok(sqlx::query_scalar::<_, i32>(
2327                "
2328                SELECT 1
2329                FROM room_participants
2330                WHERE room_participants.user_id = $1
2331                ",
2332            )
2333            .bind(user_id)
2334            .fetch_optional(&mut tx)
2335            .await?
2336            .is_some())
2337        })
2338        .await
2339    }
2340
2341    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2342        self.transact(|mut tx| async move {
2343            let (id_a, id_b) = if user_id_1 < user_id_2 {
2344                (user_id_1, user_id_2)
2345            } else {
2346                (user_id_2, user_id_1)
2347            };
2348
2349            let query = "
2350                SELECT 1 FROM contacts
2351                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2352                LIMIT 1
2353            ";
2354            Ok(sqlx::query_scalar::<_, i32>(query)
2355                .bind(id_a.0)
2356                .bind(id_b.0)
2357                .fetch_optional(&mut tx)
2358                .await?
2359                .is_some())
2360        })
2361        .await
2362    }
2363
2364    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2365        self.transact(|mut tx| async move {
2366            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2367                (sender_id, receiver_id, true)
2368            } else {
2369                (receiver_id, sender_id, false)
2370            };
2371            let query = "
2372                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2373                VALUES ($1, $2, $3, FALSE, TRUE)
2374                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2375                SET
2376                    accepted = TRUE,
2377                    should_notify = FALSE
2378                WHERE
2379                    NOT contacts.accepted AND
2380                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2381                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2382            ";
2383            let result = sqlx::query(query)
2384                .bind(id_a.0)
2385                .bind(id_b.0)
2386                .bind(a_to_b)
2387                .execute(&mut tx)
2388                .await?;
2389
2390            if result.rows_affected() == 1 {
2391                tx.commit().await?;
2392                Ok(())
2393            } else {
2394                Err(anyhow!("contact already requested"))?
2395            }
2396        }).await
2397    }
2398
2399    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2400        self.transact(|mut tx| async move {
2401            let (id_a, id_b) = if responder_id < requester_id {
2402                (responder_id, requester_id)
2403            } else {
2404                (requester_id, responder_id)
2405            };
2406            let query = "
2407                DELETE FROM contacts
2408                WHERE user_id_a = $1 AND user_id_b = $2;
2409            ";
2410            let result = sqlx::query(query)
2411                .bind(id_a.0)
2412                .bind(id_b.0)
2413                .execute(&mut tx)
2414                .await?;
2415
2416            if result.rows_affected() == 1 {
2417                tx.commit().await?;
2418                Ok(())
2419            } else {
2420                Err(anyhow!("no such contact"))?
2421            }
2422        })
2423        .await
2424    }
2425
2426    pub async fn dismiss_contact_notification(
2427        &self,
2428        user_id: UserId,
2429        contact_user_id: UserId,
2430    ) -> Result<()> {
2431        self.transact(|mut tx| async move {
2432            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2433                (user_id, contact_user_id, true)
2434            } else {
2435                (contact_user_id, user_id, false)
2436            };
2437
2438            let query = "
2439                UPDATE contacts
2440                SET should_notify = FALSE
2441                WHERE
2442                    user_id_a = $1 AND user_id_b = $2 AND
2443                    (
2444                        (a_to_b = $3 AND accepted) OR
2445                        (a_to_b != $3 AND NOT accepted)
2446                    );
2447            ";
2448
2449            let result = sqlx::query(query)
2450                .bind(id_a.0)
2451                .bind(id_b.0)
2452                .bind(a_to_b)
2453                .execute(&mut tx)
2454                .await?;
2455
2456            if result.rows_affected() == 0 {
2457                Err(anyhow!("no such contact request"))?
2458            } else {
2459                tx.commit().await?;
2460                Ok(())
2461            }
2462        })
2463        .await
2464    }
2465
2466    pub async fn respond_to_contact_request(
2467        &self,
2468        responder_id: UserId,
2469        requester_id: UserId,
2470        accept: bool,
2471    ) -> Result<()> {
2472        self.transact(|mut tx| async move {
2473            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2474                (responder_id, requester_id, false)
2475            } else {
2476                (requester_id, responder_id, true)
2477            };
2478            let result = if accept {
2479                let query = "
2480                    UPDATE contacts
2481                    SET accepted = TRUE, should_notify = TRUE
2482                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2483                ";
2484                sqlx::query(query)
2485                    .bind(id_a.0)
2486                    .bind(id_b.0)
2487                    .bind(a_to_b)
2488                    .execute(&mut tx)
2489                    .await?
2490            } else {
2491                let query = "
2492                    DELETE FROM contacts
2493                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2494                ";
2495                sqlx::query(query)
2496                    .bind(id_a.0)
2497                    .bind(id_b.0)
2498                    .bind(a_to_b)
2499                    .execute(&mut tx)
2500                    .await?
2501            };
2502            if result.rows_affected() == 1 {
2503                tx.commit().await?;
2504                Ok(())
2505            } else {
2506                Err(anyhow!("no such contact request"))?
2507            }
2508        })
2509        .await
2510    }
2511
2512    // access tokens
2513
2514    pub async fn create_access_token_hash(
2515        &self,
2516        user_id: UserId,
2517        access_token_hash: &str,
2518        max_access_token_count: usize,
2519    ) -> Result<()> {
2520        self.transact(|tx| async {
2521            let mut tx = tx;
2522            let insert_query = "
2523                INSERT INTO access_tokens (user_id, hash)
2524                VALUES ($1, $2);
2525            ";
2526            let cleanup_query = "
2527                DELETE FROM access_tokens
2528                WHERE id IN (
2529                    SELECT id from access_tokens
2530                    WHERE user_id = $1
2531                    ORDER BY id DESC
2532                    LIMIT 10000
2533                    OFFSET $3
2534                )
2535            ";
2536
2537            sqlx::query(insert_query)
2538                .bind(user_id.0)
2539                .bind(access_token_hash)
2540                .execute(&mut tx)
2541                .await?;
2542            sqlx::query(cleanup_query)
2543                .bind(user_id.0)
2544                .bind(access_token_hash)
2545                .bind(max_access_token_count as i32)
2546                .execute(&mut tx)
2547                .await?;
2548            Ok(tx.commit().await?)
2549        })
2550        .await
2551    }
2552
2553    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2554        self.transact(|mut tx| async move {
2555            let query = "
2556                SELECT hash
2557                FROM access_tokens
2558                WHERE user_id = $1
2559                ORDER BY id DESC
2560            ";
2561            Ok(sqlx::query_scalar(query)
2562                .bind(user_id.0)
2563                .fetch_all(&mut tx)
2564                .await?)
2565        })
2566        .await
2567    }
2568
2569    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2570    where
2571        F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2572        Fut: Send + Future<Output = Result<T>>,
2573    {
2574        let body = async {
2575            loop {
2576                let tx = self.begin_transaction().await?;
2577                match f(tx).await {
2578                    Ok(result) => return Ok(result),
2579                    Err(error) => match error {
2580                        Error::Database(error)
2581                            if error
2582                                .as_database_error()
2583                                .and_then(|error| error.code())
2584                                .as_deref()
2585                                == Some("40001") =>
2586                        {
2587                            // Retry (don't break the loop)
2588                        }
2589                        error @ _ => return Err(error),
2590                    },
2591                }
2592            }
2593        };
2594
2595        #[cfg(test)]
2596        {
2597            if let Some(background) = self.background.as_ref() {
2598                background.simulate_random_delay().await;
2599            }
2600
2601            let result = self.runtime.as_ref().unwrap().block_on(body);
2602
2603            // if let Some(background) = self.background.as_ref() {
2604            //     background.simulate_random_delay().await;
2605            // }
2606
2607            result
2608        }
2609
2610        #[cfg(not(test))]
2611        {
2612            body.await
2613        }
2614    }
2615}
2616
/// Declares an `i32`-backed newtype id called `$name`.
///
/// The generated type:
/// - is `Copy`, hashable and totally ordered;
/// - is transparent to both sqlx (`#[sqlx(transparent)]`) and serde
///   (`#[serde(transparent)]`), so it encodes exactly like its inner integer;
/// - displays as the bare integer;
/// - gets a `MAX` constant and `from_proto` / `to_proto` conversions to and from `u64`.
///
/// Both proto conversions are plain `as` casts. `from_proto` keeps only the low 32 bits
/// of its argument, and `to_proto` sign-extends negative ids.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[serde(transparent)]
        #[sqlx(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // The macro is instantiated for many id types and not all of them use every
            // helper; `allow(unused)` keeps the unused ones from raising warnings.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                // Truncating cast: only the low 32 bits of `value` survive.
                $name(value as i32)
            }

            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                // Sign-extending cast: a negative id becomes a very large `u64`.
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
2659
// Id newtype (see `id_type!`) used for `User::id` and other user references.
id_type!(UserId);
/// A user account, decodable from a query row via sqlx's `FromRow`.
///
/// NOTE(review): field meanings are inferred from their names. Confirm against the
/// users table migration.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Numeric GitHub user id; optional, presumably because it is not known for every
    /// account.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2672
// Id newtype (see `id_type!`) used for `Room::id`.
id_type!(RoomId);
/// A room record, decodable from a query row via sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// Name of the associated LiveKit room (NOTE(review): inferred from the field name).
    pub live_kit_room: String,
}
2679
// Id newtype (see `id_type!`) for projects.
id_type!(ProjectId);
/// A project's shared state.
///
/// It holds the collaborators, the worktrees keyed (and therefore iterated in order) by
/// `WorktreeId`, and the language servers as protocol messages.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2686
// Id newtype (see `id_type!`) for a collaborator's replica id within a project.
id_type!(ReplicaId);
/// One participant in a project, decodable from a query row via sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Raw connection id as stored in the row (NOTE(review): presumably the numeric
    /// value of an `rpc::ConnectionId`; confirm at the conversion sites).
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2696
// Id newtype (see `id_type!`) for worktrees.
id_type!(WorktreeId);
/// Worktree metadata as decoded from a query row via sqlx's `FromRow`.
///
/// Unlike `Worktree`, this has no entries or diagnostic summaries. It also stores
/// `scan_id` as a signed `i64` (presumably matching the column type) instead of `u64`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub scan_id: i64,
    pub is_complete: bool,
}
2707
/// A worktree together with its entries and diagnostic summaries, as protocol types.
///
/// NOTE(review): `scan_id` is `u64` here but `i64` in `WorktreeRow`. It is presumably
/// converted when this is assembled from rows; confirm how negative values are handled.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub scan_id: u64,
    pub is_complete: bool,
}
2718
/// One filesystem entry of a worktree, decodable from a query row via sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    // The modification time appears to be split into whole seconds plus a nanosecond
    // part. Confirm against the code that converts this into `proto::Entry`.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2731
/// Error and warning counts for one path of a worktree, attributed to one language
/// server.
///
/// Decodable from a query row via sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2740
// Id newtype (see `id_type!`) for language servers.
id_type!(LanguageServerId);
/// A language server's id and name, decodable from a query row via sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2747
/// A project affected when a connection leaves a room.
///
/// Collected per project in `LeftRoom::left_projects`.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    /// NOTE(review): presumably the connections still in the project that need to be
    /// notified. Confirm with the code that builds this value.
    pub connection_ids: Vec<ConnectionId>,
}
2754
/// Outcome of a connection leaving a room.
pub struct LeftRoom {
    /// The room as a protocol message (presumably its state after the departure).
    pub room: proto::Room,
    /// Projects affected by the departure, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose calls were canceled (presumably pending calls placed by the leaver).
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2760
/// An entry in a user's contact list.
///
/// NOTE(review): the variant descriptions below are inferred from their names.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A contact relationship both sides have accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    /// A request sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request received from `user_id` that has not been answered yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2776
2777impl Contact {
2778    pub fn user_id(&self) -> UserId {
2779        match self {
2780            Contact::Accepted { user_id, .. } => *user_id,
2781            Contact::Outgoing { user_id } => *user_id,
2782            Contact::Incoming { user_id, .. } => *user_id,
2783        }
2784    }
2785}
2786
/// A contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2792
/// Details of a waitlist signup, deserialized with serde.
///
/// NOTE(review): presumably parsed from an incoming request body; confirm at the call
/// site.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
2803
/// Aggregate waitlist counts, decodable from a query row via sqlx's `FromRow`.
///
/// Every field is marked `#[sqlx(default)]`, so a column missing from the row decodes as
/// `0` instead of causing an error.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2817
/// An email invite: the recipient's address and its confirmation code.
///
/// NOTE(review): the code is presumably produced by `random_email_confirmation_code`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2823
/// Inputs for creating a new user account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Number of invites the new user starts with (NOTE(review): inferred from the name).
    pub invite_count: i32,
}
2830
/// What creating a user produced.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this account, if any (presumably; confirm).
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded at signup time, if one was provided.
    pub signup_device_id: Option<String>,
}
2838
2839fn random_invite_code() -> String {
2840    nanoid::nanoid!(16)
2841}
2842
2843fn random_email_confirmation_code() -> String {
2844    nanoid::nanoid!(64)
2845}
2846
// In test builds, re-export the test database helpers (`SqliteTestDb`, `PostgresTestDb`)
// directly from this module.
#[cfg(test)]
pub use test::*;
2849
#[cfg(test)]
mod test {
    //! Disposable, fully migrated SQLite and Postgres databases for tests.
    //!
    //! Each helper builds a dedicated current-thread Tokio runtime and stores it on the
    //! resulting `Db`. `Db::transact` uses that runtime to `block_on` sqlx futures in
    //! test builds.

    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Builds the single-threaded Tokio runtime (I/O and timers enabled) that a test
    /// database uses to drive sqlx futures.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .expect("failed to build tokio runtime for test db")
    }

    /// A migrated in-memory SQLite database with a randomly generated name.
    pub struct SqliteTestDb {
        /// The database under test; `Some` after `SqliteTestDb::new`.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and held for as long as this value lives.
        /// NOTE(review): presumably kept open so the in-memory database is not
        /// discarded when pooled connections close. Confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A migrated Postgres database with a random name on `localhost`, torn down on drop.
    pub struct PostgresTestDb {
        /// The database under test; `Some` after `PostgresTestDb::new`.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the created database, used again for teardown.
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates and migrates a fresh in-memory SQLite database.
        ///
        /// `background` is stored on the `Db` so transactions can simulate random delays.
        ///
        /// # Panics
        ///
        /// Panics if the runtime cannot be built, or if connecting, migrating or
        /// acquiring a connection fails.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database under test.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken out of this value.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a fresh Postgres database on `localhost`.
        ///
        /// Creation is serialized across threads by a process-wide lock that is held for
        /// the whole constructor. `background` is stored on the `Db` so transactions can
        /// simulate random delays.
        ///
        /// # Panics
        ///
        /// Panics if the runtime cannot be built, or if creating, connecting to or
        /// migrating the database fails.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database under test.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken out of this value.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        fn drop(&mut self) {
            // `db` is a public field, so a test may already have taken it. Skip teardown
            // in that case rather than panicking inside `drop`, which would abort the
            // process if the test is already unwinding.
            if let Some(db) = self.db.take() {
                db.teardown(&self.url);
            }
        }
    }
}