db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use collections::HashMap;
   6use futures::StreamExt;
   7use serde::{Deserialize, Serialize};
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{
  10    migrate::{Migrate as _, Migration, MigrationSource},
  11    types::Uuid,
  12    FromRow, QueryBuilder,
  13};
  14use std::{cmp, ops::Range, path::Path, time::Duration};
  15use time::{OffsetDateTime, PrimitiveDateTime};
  16
  17#[async_trait]
  18pub trait Db: Send + Sync {
  19    async fn create_user(
  20        &self,
  21        email_address: &str,
  22        admin: bool,
  23        params: NewUserParams,
  24    ) -> Result<NewUserResult>;
  25    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  26    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  27    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  28    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
  29    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  30    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  31    async fn get_user_by_github_account(
  32        &self,
  33        github_login: &str,
  34        github_user_id: Option<i32>,
  35    ) -> Result<Option<User>>;
  36    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  37    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  38    async fn destroy_user(&self, id: UserId) -> Result<()>;
  39
  40    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
  41    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  42    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  43    async fn create_invite_from_code(
  44        &self,
  45        code: &str,
  46        email_address: &str,
  47        device_id: Option<&str>,
  48    ) -> Result<Invite>;
  49
  50    async fn create_signup(&self, signup: Signup) -> Result<()>;
  51    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
  52    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
  53    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
  54    async fn create_user_from_invite(
  55        &self,
  56        invite: &Invite,
  57        user: NewUserParams,
  58    ) -> Result<Option<NewUserResult>>;
  59
  60    /// Registers a new project for the given user.
  61    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  62
  63    /// Unregisters a project for the given project id.
  64    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  65
  66    /// Update file counts by extension for the given project and worktree.
  67    async fn update_worktree_extensions(
  68        &self,
  69        project_id: ProjectId,
  70        worktree_id: u64,
  71        extensions: HashMap<String, u32>,
  72    ) -> Result<()>;
  73
  74    /// Get the file counts on the given project keyed by their worktree and extension.
  75    async fn get_project_extensions(
  76        &self,
  77        project_id: ProjectId,
  78    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  79
  80    /// Record which users have been active in which projects during
  81    /// a given period of time.
  82    async fn record_user_activity(
  83        &self,
  84        time_period: Range<OffsetDateTime>,
  85        active_projects: &[(UserId, ProjectId)],
  86    ) -> Result<()>;
  87
  88    /// Get the number of users who have been active in the given
  89    /// time period for at least the given time duration.
  90    async fn get_active_user_count(
  91        &self,
  92        time_period: Range<OffsetDateTime>,
  93        min_duration: Duration,
  94        only_collaborative: bool,
  95    ) -> Result<usize>;
  96
  97    /// Get the users that have been most active during the given time period,
  98    /// along with the amount of time they have been active in each project.
  99    async fn get_top_users_activity_summary(
 100        &self,
 101        time_period: Range<OffsetDateTime>,
 102        max_user_count: usize,
 103    ) -> Result<Vec<UserActivitySummary>>;
 104
 105    /// Get the project activity for the given user and time period.
 106    async fn get_user_activity_timeline(
 107        &self,
 108        time_period: Range<OffsetDateTime>,
 109        user_id: UserId,
 110    ) -> Result<Vec<UserActivityPeriod>>;
 111
 112    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
 113    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
 114    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 115    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 116    async fn dismiss_contact_notification(
 117        &self,
 118        responder_id: UserId,
 119        requester_id: UserId,
 120    ) -> Result<()>;
 121    async fn respond_to_contact_request(
 122        &self,
 123        responder_id: UserId,
 124        requester_id: UserId,
 125        accept: bool,
 126    ) -> Result<()>;
 127
 128    async fn create_access_token_hash(
 129        &self,
 130        user_id: UserId,
 131        access_token_hash: &str,
 132        max_access_token_count: usize,
 133    ) -> Result<()>;
 134    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 135
 136    #[cfg(any(test, feature = "seed-support"))]
 137    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 138    #[cfg(any(test, feature = "seed-support"))]
 139    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 140    #[cfg(any(test, feature = "seed-support"))]
 141    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 142    #[cfg(any(test, feature = "seed-support"))]
 143    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 144    #[cfg(any(test, feature = "seed-support"))]
 145
 146    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 147    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 148    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 149        -> Result<bool>;
 150
 151    #[cfg(any(test, feature = "seed-support"))]
 152    async fn add_channel_member(
 153        &self,
 154        channel_id: ChannelId,
 155        user_id: UserId,
 156        is_admin: bool,
 157    ) -> Result<()>;
 158    async fn create_channel_message(
 159        &self,
 160        channel_id: ChannelId,
 161        sender_id: UserId,
 162        body: &str,
 163        timestamp: OffsetDateTime,
 164        nonce: u128,
 165    ) -> Result<MessageId>;
 166    async fn get_channel_messages(
 167        &self,
 168        channel_id: ChannelId,
 169        count: usize,
 170        before_id: Option<MessageId>,
 171    ) -> Result<Vec<ChannelMessage>>;
 172
 173    #[cfg(test)]
 174    async fn teardown(&self, url: &str);
 175
 176    #[cfg(test)]
 177    fn as_fake(&self) -> Option<&FakeDb>;
 178}
 179
 180#[cfg(any(test, debug_assertions))]
 181pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
 182    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 183
 184#[cfg(not(any(test, debug_assertions)))]
 185pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
 186
 187pub struct PostgresDb {
 188    pool: sqlx::PgPool,
 189}
 190
 191impl PostgresDb {
 192    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 193        let pool = DbOptions::new()
 194            .max_connections(max_connections)
 195            .connect(url)
 196            .await
 197            .context("failed to connect to postgres database")?;
 198        Ok(Self { pool })
 199    }
 200
 201    pub async fn migrate(
 202        &self,
 203        migrations_path: &Path,
 204        ignore_checksum_mismatch: bool,
 205    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 206        let migrations = MigrationSource::resolve(migrations_path)
 207            .await
 208            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 209
 210        let mut conn = self.pool.acquire().await?;
 211
 212        conn.ensure_migrations_table().await?;
 213        let applied_migrations: HashMap<_, _> = conn
 214            .list_applied_migrations()
 215            .await?
 216            .into_iter()
 217            .map(|m| (m.version, m))
 218            .collect();
 219
 220        let mut new_migrations = Vec::new();
 221        for migration in migrations {
 222            match applied_migrations.get(&migration.version) {
 223                Some(applied_migration) => {
 224                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 225                    {
 226                        Err(anyhow!(
 227                            "checksum mismatch for applied migration {}",
 228                            migration.description
 229                        ))?;
 230                    }
 231                }
 232                None => {
 233                    let elapsed = conn.apply(&migration).await?;
 234                    new_migrations.push((migration, elapsed));
 235                }
 236            }
 237        }
 238
 239        Ok(new_migrations)
 240    }
 241
 242    pub fn fuzzy_like_string(string: &str) -> String {
 243        let mut result = String::with_capacity(string.len() * 2 + 1);
 244        for c in string.chars() {
 245            if c.is_alphanumeric() {
 246                result.push('%');
 247                result.push(c);
 248            }
 249        }
 250        result.push('%');
 251        result
 252    }
 253}
 254
 255#[async_trait]
 256impl Db for PostgresDb {
 257    // users
 258
 259    async fn create_user(
 260        &self,
 261        email_address: &str,
 262        admin: bool,
 263        params: NewUserParams,
 264    ) -> Result<NewUserResult> {
 265        let query = "
 266            INSERT INTO users (email_address, github_login, github_user_id, admin)
 267            VALUES ($1, $2, $3, $4)
 268            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 269            RETURNING id, metrics_id::text
 270        ";
 271        let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272            .bind(email_address)
 273            .bind(params.github_login)
 274            .bind(params.github_user_id)
 275            .bind(admin)
 276            .fetch_one(&self.pool)
 277            .await?;
 278        Ok(NewUserResult {
 279            user_id,
 280            metrics_id,
 281            signup_device_id: None,
 282            inviting_user_id: None,
 283        })
 284    }
 285
 286    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 287        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 288        Ok(sqlx::query_as(query)
 289            .bind(limit as i32)
 290            .bind((page * limit) as i32)
 291            .fetch_all(&self.pool)
 292            .await?)
 293    }
 294
 295    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 296        let like_string = Self::fuzzy_like_string(name_query);
 297        let query = "
 298            SELECT users.*
 299            FROM users
 300            WHERE github_login ILIKE $1
 301            ORDER BY github_login <-> $2
 302            LIMIT $3
 303        ";
 304        Ok(sqlx::query_as(query)
 305            .bind(like_string)
 306            .bind(name_query)
 307            .bind(limit as i32)
 308            .fetch_all(&self.pool)
 309            .await?)
 310    }
 311
 312    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 313        let users = self.get_users_by_ids(vec![id]).await?;
 314        Ok(users.into_iter().next())
 315    }
 316
 317    async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 318        let query = "
 319            SELECT metrics_id::text
 320            FROM users
 321            WHERE id = $1
 322        ";
 323        Ok(sqlx::query_scalar(query)
 324            .bind(id)
 325            .fetch_one(&self.pool)
 326            .await?)
 327    }
 328
 329    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 330        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 331        let query = "
 332            SELECT users.*
 333            FROM users
 334            WHERE users.id = ANY ($1)
 335        ";
 336        Ok(sqlx::query_as(query)
 337            .bind(&ids)
 338            .fetch_all(&self.pool)
 339            .await?)
 340    }
 341
 342    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 343        let query = format!(
 344            "
 345            SELECT users.*
 346            FROM users
 347            WHERE invite_count = 0
 348            AND inviter_id IS{} NULL
 349            ",
 350            if invited_by_another_user { " NOT" } else { "" }
 351        );
 352
 353        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 354    }
 355
 356    async fn get_user_by_github_account(
 357        &self,
 358        github_login: &str,
 359        github_user_id: Option<i32>,
 360    ) -> Result<Option<User>> {
 361        if let Some(github_user_id) = github_user_id {
 362            let mut user = sqlx::query_as::<_, User>(
 363                "
 364                UPDATE users
 365                SET github_login = $1
 366                WHERE github_user_id = $2
 367                RETURNING *
 368                ",
 369            )
 370            .bind(github_login)
 371            .bind(github_user_id)
 372            .fetch_optional(&self.pool)
 373            .await?;
 374
 375            if user.is_none() {
 376                user = sqlx::query_as::<_, User>(
 377                    "
 378                    UPDATE users
 379                    SET github_user_id = $1
 380                    WHERE github_login = $2
 381                    RETURNING *
 382                    ",
 383                )
 384                .bind(github_user_id)
 385                .bind(github_login)
 386                .fetch_optional(&self.pool)
 387                .await?;
 388            }
 389
 390            Ok(user)
 391        } else {
 392            Ok(sqlx::query_as(
 393                "
 394                SELECT * FROM users
 395                WHERE github_login = $1
 396                LIMIT 1
 397                ",
 398            )
 399            .bind(github_login)
 400            .fetch_optional(&self.pool)
 401            .await?)
 402        }
 403    }
 404
 405    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 406        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 407        Ok(sqlx::query(query)
 408            .bind(is_admin)
 409            .bind(id.0)
 410            .execute(&self.pool)
 411            .await
 412            .map(drop)?)
 413    }
 414
 415    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 416        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 417        Ok(sqlx::query(query)
 418            .bind(connected_once)
 419            .bind(id.0)
 420            .execute(&self.pool)
 421            .await
 422            .map(drop)?)
 423    }
 424
 425    async fn destroy_user(&self, id: UserId) -> Result<()> {
 426        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 427        sqlx::query(query)
 428            .bind(id.0)
 429            .execute(&self.pool)
 430            .await
 431            .map(drop)?;
 432        let query = "DELETE FROM users WHERE id = $1;";
 433        Ok(sqlx::query(query)
 434            .bind(id.0)
 435            .execute(&self.pool)
 436            .await
 437            .map(drop)?)
 438    }
 439
 440    // signups
 441
 442    async fn create_signup(&self, signup: Signup) -> Result<()> {
 443        sqlx::query(
 444            "
 445            INSERT INTO signups
 446            (
 447                email_address,
 448                email_confirmation_code,
 449                email_confirmation_sent,
 450                platform_linux,
 451                platform_mac,
 452                platform_windows,
 453                platform_unknown,
 454                editor_features,
 455                programming_languages,
 456                device_id
 457            )
 458            VALUES
 459                ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
 460            RETURNING id
 461            ",
 462        )
 463        .bind(&signup.email_address)
 464        .bind(&random_email_confirmation_code())
 465        .bind(&signup.platform_linux)
 466        .bind(&signup.platform_mac)
 467        .bind(&signup.platform_windows)
 468        .bind(&signup.editor_features)
 469        .bind(&signup.programming_languages)
 470        .bind(&signup.device_id)
 471        .execute(&self.pool)
 472        .await?;
 473        Ok(())
 474    }
 475
 476    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 477        Ok(sqlx::query_as(
 478            "
 479            SELECT
 480                COUNT(*) as count,
 481                COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 482                COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 483                COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 484                COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 485            FROM (
 486                SELECT *
 487                FROM signups
 488                WHERE
 489                    NOT email_confirmation_sent
 490            ) AS unsent
 491            ",
 492        )
 493        .fetch_one(&self.pool)
 494        .await?)
 495    }
 496
 497    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 498        Ok(sqlx::query_as(
 499            "
 500            SELECT
 501                email_address, email_confirmation_code
 502            FROM signups
 503            WHERE
 504                NOT email_confirmation_sent AND
 505                (platform_mac OR platform_unknown)
 506            LIMIT $1
 507            ",
 508        )
 509        .bind(count as i32)
 510        .fetch_all(&self.pool)
 511        .await?)
 512    }
 513
 514    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 515        sqlx::query(
 516            "
 517            UPDATE signups
 518            SET email_confirmation_sent = 't'
 519            WHERE email_address = ANY ($1)
 520            ",
 521        )
 522        .bind(
 523            &invites
 524                .iter()
 525                .map(|s| s.email_address.as_str())
 526                .collect::<Vec<_>>(),
 527        )
 528        .execute(&self.pool)
 529        .await?;
 530        Ok(())
 531    }
 532
    /// Redeems an invite (a signup identified by email address and
    /// confirmation code) by creating a user account, all in one transaction.
    ///
    /// Returns `Ok(None)` without changing anything if the signup is already
    /// linked to a user. Otherwise it:
    /// 1. inserts a non-admin user with `user.invite_count` invites and a fresh
    ///    invite code;
    /// 2. links the signup to the new user;
    /// 3. if the signup has an inviting user, spends one of that user's
    ///    invites and inserts an accepted contact row between the two.
    ///
    /// # Errors
    ///
    /// - `Error::Http(NOT_FOUND, "no such invite")` if no signup matches.
    /// - `Error::Http(UNAUTHORIZED, "no invites remaining")` if the inviting
    ///   user has no invites left.
    /// - Any database error.
    ///
    /// On every early return `tx` is dropped without being committed, so none
    /// of the writes above persist.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        let mut tx = self.pool.begin().await?;

        // Find the signup this invite refers to.
        let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
            i32,
            Option<UserId>,
            Option<UserId>,
            Option<String>,
        ) = sqlx::query_as(
            "
            SELECT id, user_id, inviting_user_id, device_id
            FROM signups
            WHERE
                email_address = $1 AND
                email_confirmation_code = $2
            ",
        )
        .bind(&invite.email_address)
        .bind(&invite.email_confirmation_code)
        .fetch_optional(&mut tx)
        .await?
        .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

        // Already redeemed: report "nothing created" rather than an error.
        if existing_user_id.is_some() {
            return Ok(None);
        }

        // Create the account; new users are never admins and get their own
        // invite code.
        let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
            "
            INSERT INTO users
            (email_address, github_login, github_user_id, admin, invite_count, invite_code)
            VALUES
            ($1, $2, $3, 'f', $4, $5)
            RETURNING id, metrics_id::text
            ",
        )
        .bind(&invite.email_address)
        .bind(&user.github_login)
        .bind(&user.github_user_id)
        .bind(&user.invite_count)
        .bind(random_invite_code())
        .fetch_one(&mut tx)
        .await?;

        // Link the signup to the new user.
        sqlx::query(
            "
            UPDATE signups
            SET user_id = $1
            WHERE id = $2
            ",
        )
        .bind(&user_id)
        .bind(&signup_id)
        .execute(&mut tx)
        .await?;

        if let Some(inviting_user_id) = inviting_user_id {
            // Spend one of the inviter's invites. The `invite_count > 0` guard
            // makes the UPDATE match no row (so `RETURNING` yields nothing)
            // when none remain.
            let id: Option<UserId> = sqlx::query_scalar(
                "
                UPDATE users
                SET invite_count = invite_count - 1
                WHERE id = $1 AND invite_count > 0
                RETURNING id
                ",
            )
            .bind(&inviting_user_id)
            .fetch_optional(&mut tx)
            .await?;

            if id.is_none() {
                Err(Error::Http(
                    StatusCode::UNAUTHORIZED,
                    "no invites remaining".to_string(),
                ))?;
            }

            // Record an accepted contact between the inviter (`user_id_a`) and
            // the new user (`user_id_b`), with `should_notify` set.
            sqlx::query(
                "
                INSERT INTO contacts
                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                VALUES
                    ($1, $2, 't', 't', 't')
                ",
            )
            .bind(inviting_user_id)
            .bind(user_id)
            .execute(&mut tx)
            .await?;
        }

        tx.commit().await?;
        Ok(Some(NewUserResult {
            user_id,
            metrics_id,
            inviting_user_id,
            signup_device_id,
        }))
    }
 635
 636    // invite codes
 637
 638    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 639        let mut tx = self.pool.begin().await?;
 640        if count > 0 {
 641            sqlx::query(
 642                "
 643                UPDATE users
 644                SET invite_code = $1
 645                WHERE id = $2 AND invite_code IS NULL
 646            ",
 647            )
 648            .bind(random_invite_code())
 649            .bind(id)
 650            .execute(&mut tx)
 651            .await?;
 652        }
 653
 654        sqlx::query(
 655            "
 656            UPDATE users
 657            SET invite_count = $1
 658            WHERE id = $2
 659            ",
 660        )
 661        .bind(count as i32)
 662        .bind(id)
 663        .execute(&mut tx)
 664        .await?;
 665        tx.commit().await?;
 666        Ok(())
 667    }
 668
 669    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 670        let result: Option<(String, i32)> = sqlx::query_as(
 671            "
 672                SELECT invite_code, invite_count
 673                FROM users
 674                WHERE id = $1 AND invite_code IS NOT NULL 
 675            ",
 676        )
 677        .bind(id)
 678        .fetch_optional(&self.pool)
 679        .await?;
 680        if let Some((code, count)) = result {
 681            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 682        } else {
 683            Ok(None)
 684        }
 685    }
 686
 687    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 688        sqlx::query_as(
 689            "
 690                SELECT *
 691                FROM users
 692                WHERE invite_code = $1
 693            ",
 694        )
 695        .bind(code)
 696        .fetch_optional(&self.pool)
 697        .await?
 698        .ok_or_else(|| {
 699            Error::Http(
 700                StatusCode::NOT_FOUND,
 701                "that invite code does not exist".to_string(),
 702            )
 703        })
 704    }
 705
 706    async fn create_invite_from_code(
 707        &self,
 708        code: &str,
 709        email_address: &str,
 710        device_id: Option<&str>,
 711    ) -> Result<Invite> {
 712        let mut tx = self.pool.begin().await?;
 713
 714        let existing_user: Option<UserId> = sqlx::query_scalar(
 715            "
 716            SELECT id
 717            FROM users
 718            WHERE email_address = $1
 719            ",
 720        )
 721        .bind(email_address)
 722        .fetch_optional(&mut tx)
 723        .await?;
 724        if existing_user.is_some() {
 725            Err(anyhow!("email address is already in use"))?;
 726        }
 727
 728        let row: Option<(UserId, i32)> = sqlx::query_as(
 729            "
 730            SELECT id, invite_count
 731            FROM users
 732            WHERE invite_code = $1
 733            ",
 734        )
 735        .bind(code)
 736        .fetch_optional(&mut tx)
 737        .await?;
 738
 739        let (inviter_id, invite_count) = match row {
 740            Some(row) => row,
 741            None => Err(Error::Http(
 742                StatusCode::NOT_FOUND,
 743                "invite code not found".to_string(),
 744            ))?,
 745        };
 746
 747        if invite_count == 0 {
 748            Err(Error::Http(
 749                StatusCode::UNAUTHORIZED,
 750                "no invites remaining".to_string(),
 751            ))?;
 752        }
 753
 754        let email_confirmation_code: String = sqlx::query_scalar(
 755            "
 756            INSERT INTO signups
 757            (
 758                email_address,
 759                email_confirmation_code,
 760                email_confirmation_sent,
 761                inviting_user_id,
 762                platform_linux,
 763                platform_mac,
 764                platform_windows,
 765                platform_unknown,
 766                device_id
 767            )
 768            VALUES
 769                ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
 770            ON CONFLICT (email_address)
 771            DO UPDATE SET
 772                inviting_user_id = excluded.inviting_user_id
 773            RETURNING email_confirmation_code
 774            ",
 775        )
 776        .bind(&email_address)
 777        .bind(&random_email_confirmation_code())
 778        .bind(&inviter_id)
 779        .bind(&device_id)
 780        .fetch_one(&mut tx)
 781        .await?;
 782
 783        tx.commit().await?;
 784
 785        Ok(Invite {
 786            email_address: email_address.into(),
 787            email_confirmation_code,
 788        })
 789    }
 790
 791    // projects
 792
 793    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 794        Ok(sqlx::query_scalar(
 795            "
 796            INSERT INTO projects(host_user_id)
 797            VALUES ($1)
 798            RETURNING id
 799            ",
 800        )
 801        .bind(host_user_id)
 802        .fetch_one(&self.pool)
 803        .await
 804        .map(ProjectId)?)
 805    }
 806
 807    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 808        sqlx::query(
 809            "
 810            UPDATE projects
 811            SET unregistered = 't'
 812            WHERE id = $1
 813            ",
 814        )
 815        .bind(project_id)
 816        .execute(&self.pool)
 817        .await?;
 818        Ok(())
 819    }
 820
 821    async fn update_worktree_extensions(
 822        &self,
 823        project_id: ProjectId,
 824        worktree_id: u64,
 825        extensions: HashMap<String, u32>,
 826    ) -> Result<()> {
 827        if extensions.is_empty() {
 828            return Ok(());
 829        }
 830
 831        let mut query = QueryBuilder::new(
 832            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 833        );
 834        query.push_values(extensions, |mut query, (extension, count)| {
 835            query
 836                .push_bind(project_id)
 837                .push_bind(worktree_id as i32)
 838                .push_bind(extension)
 839                .push_bind(count as i32);
 840        });
 841        query.push(
 842            "
 843            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 844            count = excluded.count
 845            ",
 846        );
 847        query.build().execute(&self.pool).await?;
 848
 849        Ok(())
 850    }
 851
 852    async fn get_project_extensions(
 853        &self,
 854        project_id: ProjectId,
 855    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 856        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 857        struct WorktreeExtension {
 858            worktree_id: i32,
 859            extension: String,
 860            count: i32,
 861        }
 862
 863        let query = "
 864            SELECT worktree_id, extension, count
 865            FROM worktree_extensions
 866            WHERE project_id = $1
 867        ";
 868        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 869            .bind(&project_id)
 870            .fetch_all(&self.pool)
 871            .await?;
 872
 873        let mut extension_counts = HashMap::default();
 874        for count in counts {
 875            extension_counts
 876                .entry(count.worktree_id as u64)
 877                .or_insert_with(HashMap::default)
 878                .insert(count.extension, count.count as usize);
 879        }
 880        Ok(extension_counts)
 881    }
 882
 883    async fn record_user_activity(
 884        &self,
 885        time_period: Range<OffsetDateTime>,
 886        projects: &[(UserId, ProjectId)],
 887    ) -> Result<()> {
 888        let query = "
 889            INSERT INTO project_activity_periods
 890            (ended_at, duration_millis, user_id, project_id)
 891            VALUES
 892            ($1, $2, $3, $4);
 893        ";
 894
 895        let mut tx = self.pool.begin().await?;
 896        let duration_millis =
 897            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 898        for (user_id, project_id) in projects {
 899            sqlx::query(query)
 900                .bind(time_period.end)
 901                .bind(duration_millis)
 902                .bind(user_id)
 903                .bind(project_id)
 904                .execute(&mut tx)
 905                .await?;
 906        }
 907        tx.commit().await?;
 908
 909        Ok(())
 910    }
 911
 912    async fn get_active_user_count(
 913        &self,
 914        time_period: Range<OffsetDateTime>,
 915        min_duration: Duration,
 916        only_collaborative: bool,
 917    ) -> Result<usize> {
 918        let mut with_clause = String::new();
 919        with_clause.push_str("WITH\n");
 920        with_clause.push_str(
 921            "
 922            project_durations AS (
 923                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 924                FROM project_activity_periods
 925                WHERE $1 < ended_at AND ended_at <= $2
 926                GROUP BY user_id, project_id
 927            ),
 928            ",
 929        );
 930        with_clause.push_str(
 931            "
 932            project_collaborators as (
 933                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 934                FROM project_durations
 935                GROUP BY project_id
 936            ),
 937            ",
 938        );
 939
 940        if only_collaborative {
 941            with_clause.push_str(
 942                "
 943                user_durations AS (
 944                    SELECT user_id, SUM(project_duration) as total_duration
 945                    FROM project_durations, project_collaborators
 946                    WHERE
 947                        project_durations.project_id = project_collaborators.project_id AND
 948                        max_collaborators > 1
 949                    GROUP BY user_id
 950                    ORDER BY total_duration DESC
 951                    LIMIT $3
 952                )
 953                ",
 954            );
 955        } else {
 956            with_clause.push_str(
 957                "
 958                user_durations AS (
 959                    SELECT user_id, SUM(project_duration) as total_duration
 960                    FROM project_durations
 961                    GROUP BY user_id
 962                    ORDER BY total_duration DESC
 963                    LIMIT $3
 964                )
 965                ",
 966            );
 967        }
 968
 969        let query = format!(
 970            "
 971            {with_clause}
 972            SELECT count(user_durations.user_id)
 973            FROM user_durations
 974            WHERE user_durations.total_duration >= $3
 975            "
 976        );
 977
 978        let count: i64 = sqlx::query_scalar(&query)
 979            .bind(time_period.start)
 980            .bind(time_period.end)
 981            .bind(min_duration.as_millis() as i64)
 982            .fetch_one(&self.pool)
 983            .await?;
 984        Ok(count as usize)
 985    }
 986
 987    async fn get_top_users_activity_summary(
 988        &self,
 989        time_period: Range<OffsetDateTime>,
 990        max_user_count: usize,
 991    ) -> Result<Vec<UserActivitySummary>> {
 992        let query = "
 993            WITH
 994                project_durations AS (
 995                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 996                    FROM project_activity_periods
 997                    WHERE $1 < ended_at AND ended_at <= $2
 998                    GROUP BY user_id, project_id
 999                ),
1000                user_durations AS (
1001                    SELECT user_id, SUM(project_duration) as total_duration
1002                    FROM project_durations
1003                    GROUP BY user_id
1004                    ORDER BY total_duration DESC
1005                    LIMIT $3
1006                ),
1007                project_collaborators as (
1008                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1009                    FROM project_durations
1010                    GROUP BY project_id
1011                )
1012            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1013            FROM user_durations, project_durations, project_collaborators, users
1014            WHERE
1015                user_durations.user_id = project_durations.user_id AND
1016                user_durations.user_id = users.id AND
1017                project_durations.project_id = project_collaborators.project_id
1018            ORDER BY total_duration DESC, user_id ASC, project_id ASC
1019        ";
1020
1021        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1022            .bind(time_period.start)
1023            .bind(time_period.end)
1024            .bind(max_user_count as i32)
1025            .fetch(&self.pool);
1026
1027        let mut result = Vec::<UserActivitySummary>::new();
1028        while let Some(row) = rows.next().await {
1029            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
1030            let project_id = project_id;
1031            let duration = Duration::from_millis(duration_millis as u64);
1032            let project_activity = ProjectActivitySummary {
1033                id: project_id,
1034                duration,
1035                max_collaborators: project_collaborators as usize,
1036            };
1037            if let Some(last_summary) = result.last_mut() {
1038                if last_summary.id == user_id {
1039                    last_summary.project_activity.push(project_activity);
1040                    continue;
1041                }
1042            }
1043            result.push(UserActivitySummary {
1044                id: user_id,
1045                project_activity: vec![project_activity],
1046                github_login,
1047            });
1048        }
1049
1050        Ok(result)
1051    }
1052
    /// Returns the activity periods of `user_id` whose `ended_at` falls in
    /// `(time_period.start, time_period.end]`, sorted by start time.
    ///
    /// Within each project, a period is merged into the previous one when it starts no more
    /// than `COALESCE_THRESHOLD` after that period ends and does not end before it starts.
    /// Each resulting period is annotated with the extension counts stored in
    /// `worktree_extensions` for its project. That join is not filtered by time, so the counts
    /// reflect whatever the table holds when this query runs.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Gaps up to this long between periods of the same project are bridged.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN produces one row per (activity period, extension row) pair, or
        // a single row with NULL extension columns when the project has no extension rows.
        // Rows are processed in activity-period id order.
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        // Periods are grouped per project so that interleaved activity in other projects does
        // not prevent coalescing.
        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // `ended_at` is decoded without an offset and treated as UTC.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            if let Some(prev_duration) = project_time_periods.last_mut() {
                // Merge when this period begins within the threshold of the previous period's
                // end and does not end before the previous period starts. Only `end` is ever
                // extended; `start` is never moved earlier. Rows duplicated by the join for the
                // same period should always take this branch.
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // Attach this row's extension count (if any) to the period it was folded into.
            // The `unwrap` cannot fail: the branches above either found or pushed a period.
            // NOTE(review): `insert` overwrites, so when several worktrees of the project report
            // the same extension only the last row's count survives. Confirm that summing
            // isn't intended.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        // Flatten all projects' periods and order them chronologically by start.
        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
1137
1138    // contacts
1139
1140    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1141        let query = "
1142            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1143            FROM contacts
1144            WHERE user_id_a = $1 OR user_id_b = $1;
1145        ";
1146
1147        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1148            .bind(user_id)
1149            .fetch(&self.pool);
1150
1151        let mut contacts = Vec::new();
1152        while let Some(row) = rows.next().await {
1153            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1154
1155            if user_id_a == user_id {
1156                if accepted {
1157                    contacts.push(Contact::Accepted {
1158                        user_id: user_id_b,
1159                        should_notify: should_notify && a_to_b,
1160                    });
1161                } else if a_to_b {
1162                    contacts.push(Contact::Outgoing { user_id: user_id_b })
1163                } else {
1164                    contacts.push(Contact::Incoming {
1165                        user_id: user_id_b,
1166                        should_notify,
1167                    });
1168                }
1169            } else if accepted {
1170                contacts.push(Contact::Accepted {
1171                    user_id: user_id_a,
1172                    should_notify: should_notify && !a_to_b,
1173                });
1174            } else if a_to_b {
1175                contacts.push(Contact::Incoming {
1176                    user_id: user_id_a,
1177                    should_notify,
1178                });
1179            } else {
1180                contacts.push(Contact::Outgoing { user_id: user_id_a });
1181            }
1182        }
1183
1184        contacts.sort_unstable_by_key(|contact| contact.user_id());
1185
1186        Ok(contacts)
1187    }
1188
1189    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1190        let (id_a, id_b) = if user_id_1 < user_id_2 {
1191            (user_id_1, user_id_2)
1192        } else {
1193            (user_id_2, user_id_1)
1194        };
1195
1196        let query = "
1197            SELECT 1 FROM contacts
1198            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1199            LIMIT 1
1200        ";
1201        Ok(sqlx::query_scalar::<_, i32>(query)
1202            .bind(id_a.0)
1203            .bind(id_b.0)
1204            .fetch_optional(&self.pool)
1205            .await?
1206            .is_some())
1207    }
1208
1209    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1210        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1211            (sender_id, receiver_id, true)
1212        } else {
1213            (receiver_id, sender_id, false)
1214        };
1215        let query = "
1216            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1217            VALUES ($1, $2, $3, 'f', 't')
1218            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1219            SET
1220                accepted = 't',
1221                should_notify = 'f'
1222            WHERE
1223                NOT contacts.accepted AND
1224                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1225                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1226        ";
1227        let result = sqlx::query(query)
1228            .bind(id_a.0)
1229            .bind(id_b.0)
1230            .bind(a_to_b)
1231            .execute(&self.pool)
1232            .await?;
1233
1234        if result.rows_affected() == 1 {
1235            Ok(())
1236        } else {
1237            Err(anyhow!("contact already requested"))?
1238        }
1239    }
1240
1241    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1242        let (id_a, id_b) = if responder_id < requester_id {
1243            (responder_id, requester_id)
1244        } else {
1245            (requester_id, responder_id)
1246        };
1247        let query = "
1248            DELETE FROM contacts
1249            WHERE user_id_a = $1 AND user_id_b = $2;
1250        ";
1251        let result = sqlx::query(query)
1252            .bind(id_a.0)
1253            .bind(id_b.0)
1254            .execute(&self.pool)
1255            .await?;
1256
1257        if result.rows_affected() == 1 {
1258            Ok(())
1259        } else {
1260            Err(anyhow!("no such contact"))?
1261        }
1262    }
1263
1264    async fn dismiss_contact_notification(
1265        &self,
1266        user_id: UserId,
1267        contact_user_id: UserId,
1268    ) -> Result<()> {
1269        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1270            (user_id, contact_user_id, true)
1271        } else {
1272            (contact_user_id, user_id, false)
1273        };
1274
1275        let query = "
1276            UPDATE contacts
1277            SET should_notify = 'f'
1278            WHERE
1279                user_id_a = $1 AND user_id_b = $2 AND
1280                (
1281                    (a_to_b = $3 AND accepted) OR
1282                    (a_to_b != $3 AND NOT accepted)
1283                );
1284        ";
1285
1286        let result = sqlx::query(query)
1287            .bind(id_a.0)
1288            .bind(id_b.0)
1289            .bind(a_to_b)
1290            .execute(&self.pool)
1291            .await?;
1292
1293        if result.rows_affected() == 0 {
1294            Err(anyhow!("no such contact request"))?;
1295        }
1296
1297        Ok(())
1298    }
1299
1300    async fn respond_to_contact_request(
1301        &self,
1302        responder_id: UserId,
1303        requester_id: UserId,
1304        accept: bool,
1305    ) -> Result<()> {
1306        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1307            (responder_id, requester_id, false)
1308        } else {
1309            (requester_id, responder_id, true)
1310        };
1311        let result = if accept {
1312            let query = "
1313                UPDATE contacts
1314                SET accepted = 't', should_notify = 't'
1315                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1316            ";
1317            sqlx::query(query)
1318                .bind(id_a.0)
1319                .bind(id_b.0)
1320                .bind(a_to_b)
1321                .execute(&self.pool)
1322                .await?
1323        } else {
1324            let query = "
1325                DELETE FROM contacts
1326                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1327            ";
1328            sqlx::query(query)
1329                .bind(id_a.0)
1330                .bind(id_b.0)
1331                .bind(a_to_b)
1332                .execute(&self.pool)
1333                .await?
1334        };
1335        if result.rows_affected() == 1 {
1336            Ok(())
1337        } else {
1338            Err(anyhow!("no such contact request"))?
1339        }
1340    }
1341
1342    // access tokens
1343
1344    async fn create_access_token_hash(
1345        &self,
1346        user_id: UserId,
1347        access_token_hash: &str,
1348        max_access_token_count: usize,
1349    ) -> Result<()> {
1350        let insert_query = "
1351            INSERT INTO access_tokens (user_id, hash)
1352            VALUES ($1, $2);
1353        ";
1354        let cleanup_query = "
1355            DELETE FROM access_tokens
1356            WHERE id IN (
1357                SELECT id from access_tokens
1358                WHERE user_id = $1
1359                ORDER BY id DESC
1360                OFFSET $3
1361            )
1362        ";
1363
1364        let mut tx = self.pool.begin().await?;
1365        sqlx::query(insert_query)
1366            .bind(user_id.0)
1367            .bind(access_token_hash)
1368            .execute(&mut tx)
1369            .await?;
1370        sqlx::query(cleanup_query)
1371            .bind(user_id.0)
1372            .bind(access_token_hash)
1373            .bind(max_access_token_count as i32)
1374            .execute(&mut tx)
1375            .await?;
1376        Ok(tx.commit().await?)
1377    }
1378
1379    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1380        let query = "
1381            SELECT hash
1382            FROM access_tokens
1383            WHERE user_id = $1
1384            ORDER BY id DESC
1385        ";
1386        Ok(sqlx::query_scalar(query)
1387            .bind(user_id.0)
1388            .fetch_all(&self.pool)
1389            .await?)
1390    }
1391
1392    // orgs
1393
1394    #[allow(unused)] // Help rust-analyzer
1395    #[cfg(any(test, feature = "seed-support"))]
1396    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1397        let query = "
1398            SELECT *
1399            FROM orgs
1400            WHERE slug = $1
1401        ";
1402        Ok(sqlx::query_as(query)
1403            .bind(slug)
1404            .fetch_optional(&self.pool)
1405            .await?)
1406    }
1407
1408    #[cfg(any(test, feature = "seed-support"))]
1409    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1410        let query = "
1411            INSERT INTO orgs (name, slug)
1412            VALUES ($1, $2)
1413            RETURNING id
1414        ";
1415        Ok(sqlx::query_scalar(query)
1416            .bind(name)
1417            .bind(slug)
1418            .fetch_one(&self.pool)
1419            .await
1420            .map(OrgId)?)
1421    }
1422
1423    #[cfg(any(test, feature = "seed-support"))]
1424    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1425        let query = "
1426            INSERT INTO org_memberships (org_id, user_id, admin)
1427            VALUES ($1, $2, $3)
1428            ON CONFLICT DO NOTHING
1429        ";
1430        Ok(sqlx::query(query)
1431            .bind(org_id.0)
1432            .bind(user_id.0)
1433            .bind(is_admin)
1434            .execute(&self.pool)
1435            .await
1436            .map(drop)?)
1437    }
1438
1439    // channels
1440
1441    #[cfg(any(test, feature = "seed-support"))]
1442    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1443        let query = "
1444            INSERT INTO channels (owner_id, owner_is_user, name)
1445            VALUES ($1, false, $2)
1446            RETURNING id
1447        ";
1448        Ok(sqlx::query_scalar(query)
1449            .bind(org_id.0)
1450            .bind(name)
1451            .fetch_one(&self.pool)
1452            .await
1453            .map(ChannelId)?)
1454    }
1455
1456    #[allow(unused)] // Help rust-analyzer
1457    #[cfg(any(test, feature = "seed-support"))]
1458    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1459        let query = "
1460            SELECT *
1461            FROM channels
1462            WHERE
1463                channels.owner_is_user = false AND
1464                channels.owner_id = $1
1465        ";
1466        Ok(sqlx::query_as(query)
1467            .bind(org_id.0)
1468            .fetch_all(&self.pool)
1469            .await?)
1470    }
1471
1472    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1473        let query = "
1474            SELECT
1475                channels.*
1476            FROM
1477                channel_memberships, channels
1478            WHERE
1479                channel_memberships.user_id = $1 AND
1480                channel_memberships.channel_id = channels.id
1481        ";
1482        Ok(sqlx::query_as(query)
1483            .bind(user_id.0)
1484            .fetch_all(&self.pool)
1485            .await?)
1486    }
1487
1488    async fn can_user_access_channel(
1489        &self,
1490        user_id: UserId,
1491        channel_id: ChannelId,
1492    ) -> Result<bool> {
1493        let query = "
1494            SELECT id
1495            FROM channel_memberships
1496            WHERE user_id = $1 AND channel_id = $2
1497            LIMIT 1
1498        ";
1499        Ok(sqlx::query_scalar::<_, i32>(query)
1500            .bind(user_id.0)
1501            .bind(channel_id.0)
1502            .fetch_optional(&self.pool)
1503            .await
1504            .map(|e| e.is_some())?)
1505    }
1506
1507    #[cfg(any(test, feature = "seed-support"))]
1508    async fn add_channel_member(
1509        &self,
1510        channel_id: ChannelId,
1511        user_id: UserId,
1512        is_admin: bool,
1513    ) -> Result<()> {
1514        let query = "
1515            INSERT INTO channel_memberships (channel_id, user_id, admin)
1516            VALUES ($1, $2, $3)
1517            ON CONFLICT DO NOTHING
1518        ";
1519        Ok(sqlx::query(query)
1520            .bind(channel_id.0)
1521            .bind(user_id.0)
1522            .bind(is_admin)
1523            .execute(&self.pool)
1524            .await
1525            .map(drop)?)
1526    }
1527
1528    // messages
1529
1530    async fn create_channel_message(
1531        &self,
1532        channel_id: ChannelId,
1533        sender_id: UserId,
1534        body: &str,
1535        timestamp: OffsetDateTime,
1536        nonce: u128,
1537    ) -> Result<MessageId> {
1538        let query = "
1539            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1540            VALUES ($1, $2, $3, $4, $5)
1541            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1542            RETURNING id
1543        ";
1544        Ok(sqlx::query_scalar(query)
1545            .bind(channel_id.0)
1546            .bind(sender_id.0)
1547            .bind(body)
1548            .bind(timestamp)
1549            .bind(Uuid::from_u128(nonce))
1550            .fetch_one(&self.pool)
1551            .await
1552            .map(MessageId)?)
1553    }
1554
1555    async fn get_channel_messages(
1556        &self,
1557        channel_id: ChannelId,
1558        count: usize,
1559        before_id: Option<MessageId>,
1560    ) -> Result<Vec<ChannelMessage>> {
1561        let query = r#"
1562            SELECT * FROM (
1563                SELECT
1564                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1565                FROM
1566                    channel_messages
1567                WHERE
1568                    channel_id = $1 AND
1569                    id < $2
1570                ORDER BY id DESC
1571                LIMIT $3
1572            ) as recent_messages
1573            ORDER BY id ASC
1574        "#;
1575        Ok(sqlx::query_as(query)
1576            .bind(channel_id.0)
1577            .bind(before_id.unwrap_or(MessageId::MAX))
1578            .bind(count as i64)
1579            .fetch_all(&self.pool)
1580            .await?)
1581    }
1582
1583    #[cfg(test)]
1584    async fn teardown(&self, url: &str) {
1585        use util::ResultExt;
1586
1587        let query = "
1588            SELECT pg_terminate_backend(pg_stat_activity.pid)
1589            FROM pg_stat_activity
1590            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1591        ";
1592        sqlx::query(query).execute(&self.pool).await.log_err();
1593        self.pool.close().await;
1594        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1595            .await
1596            .log_err();
1597    }
1598
1599    #[cfg(test)]
1600    fn as_fake(&self) -> Option<&FakeDb> {
1601        None
1602    }
1603}
1604
// Declares an `i32`-backed id newtype named `$name`.
//
// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is stored
// and serialized exactly like the wrapped integer. It also gets a `MAX` constant, `as`-cast
// protobuf conversions, and a `Display` impl that prints the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        /// An `i32` database id, transparent to `sqlx` and `serde`.
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[serde(transparent)]
        #[sqlx(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Wraps a protobuf id.
            ///
            /// NOTE(review): this is a plain `as` cast, so values above `i32::MAX` wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Widens the id for protobuf.
            ///
            /// NOTE(review): negative ids sign-extend through the `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1647
id_type!(UserId);
/// A user account as decoded via `FromRow` (presumably from the `users` table).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// Updated through `Db::set_user_is_admin`.
    pub admin: bool,
    pub invite_code: Option<String>,
    /// Updated through `Db::set_invite_count_for_user`.
    pub invite_count: i32,
    /// Updated through `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1660
id_type!(ProjectId);
/// A project record.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// Presumably the user hosting the project (inferred from the name; confirm).
    pub host_user_id: UserId,
    /// NOTE(review): semantics not visible here. Presumably set once the project is unregistered.
    pub unregistered: bool,
}
1668
/// One user's activity, broken down per project, as built by
/// `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in during the queried period.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1675
/// A user's activity in a single project over the queried period.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Sum of the user's activity-period durations in this project.
    pub duration: Duration,
    /// Number of distinct users with activity in this project during the same period.
    pub max_collaborators: usize,
}
1682
/// A contiguous (possibly coalesced) span of activity in one project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// Extension name mapped to the `count` recorded in `worktree_extensions` for the project.
    pub extensions: HashMap<String, usize>,
}
1692
id_type!(OrgId);
/// An organization row (`SELECT * FROM orgs`).
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Looked up by `find_org_by_slug`.
    pub slug: String,
}
1700
id_type!(ChannelId);
/// A chat channel row (`channels` table).
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// An org id when `owner_is_user` is false (see `create_org_channel`); presumably a user
    /// id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1709
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Read back via `sent_at AT TIME ZONE 'UTC'` in `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied nonce; `create_channel_message` uses it to deduplicate inserts.
    pub nonce: Uuid,
}
1720
/// A contact relationship as seen by one user (the one the contacts were fetched for).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact. `get_contacts` only reports `should_notify` as true to the user
    /// who sent the original request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request sent by the viewing user to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request sent by `user_id` to the viewing user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1735
1736impl Contact {
1737    pub fn user_id(&self) -> UserId {
1738        match self {
1739            Contact::Accepted { user_id, .. } => *user_id,
1740            Contact::Outgoing { user_id } => *user_id,
1741            Contact::Incoming { user_id, .. } => *user_id,
1742        }
1743    }
1744}
1745
/// Describes an incoming contact request. Field meanings presumably mirror
/// `Contact::Incoming` (not used in this part of the file; confirm).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1751
/// Signup details, deserialized from a request (presumably the waitlist signup form).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1762
/// Aggregate signup counts, overall and per platform.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the row decodes as 0.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1776
/// An invite addressed to an email, together with its confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1782
/// GitHub identity and initial invite allowance passed to `Db::create_user`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1789
/// Returned by `Db::create_user`.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// Presumably the user whose invite led to this signup, if any (confirm with callers).
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1797
1798fn random_invite_code() -> String {
1799    nanoid::nanoid!(16)
1800}
1801
1802fn random_email_confirmation_code() -> String {
1803    nanoid::nanoid!(64)
1804}
1805
1806#[cfg(test)]
1807pub use test::*;
1808
1809#[cfg(test)]
1810mod test {
1811    use super::*;
1812    use anyhow::anyhow;
1813    use collections::BTreeMap;
1814    use gpui::executor::Background;
1815    use lazy_static::lazy_static;
1816    use parking_lot::Mutex;
1817    use rand::prelude::*;
1818    use sqlx::{migrate::MigrateDatabase, Postgres};
1819    use std::sync::Arc;
1820    use util::post_inc;
1821
1822    pub struct FakeDb {
1823        background: Arc<Background>,
1824        pub users: Mutex<BTreeMap<UserId, User>>,
1825        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
1826        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
1827        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1828        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1829        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1830        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1831        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1832        pub contacts: Mutex<Vec<FakeContact>>,
1833        next_channel_message_id: Mutex<i32>,
1834        next_user_id: Mutex<i32>,
1835        next_org_id: Mutex<i32>,
1836        next_channel_id: Mutex<i32>,
1837        next_project_id: Mutex<i32>,
1838    }
1839
1840    #[derive(Debug)]
1841    pub struct FakeContact {
1842        pub requester_id: UserId,
1843        pub responder_id: UserId,
1844        pub accepted: bool,
1845        pub should_notify: bool,
1846    }
1847
1848    impl FakeDb {
1849        pub fn new(background: Arc<Background>) -> Self {
1850            Self {
1851                background,
1852                users: Default::default(),
1853                next_user_id: Mutex::new(0),
1854                projects: Default::default(),
1855                worktree_extensions: Default::default(),
1856                next_project_id: Mutex::new(1),
1857                orgs: Default::default(),
1858                next_org_id: Mutex::new(1),
1859                org_memberships: Default::default(),
1860                channels: Default::default(),
1861                next_channel_id: Mutex::new(1),
1862                channel_memberships: Default::default(),
1863                channel_messages: Default::default(),
1864                next_channel_message_id: Mutex::new(1),
1865                contacts: Default::default(),
1866            }
1867        }
1868    }
1869
1870    #[async_trait]
1871    impl Db for FakeDb {
1872        async fn create_user(
1873            &self,
1874            email_address: &str,
1875            admin: bool,
1876            params: NewUserParams,
1877        ) -> Result<NewUserResult> {
1878            self.background.simulate_random_delay().await;
1879
1880            let mut users = self.users.lock();
1881            let user_id = if let Some(user) = users
1882                .values()
1883                .find(|user| user.github_login == params.github_login)
1884            {
1885                user.id
1886            } else {
1887                let id = post_inc(&mut *self.next_user_id.lock());
1888                let user_id = UserId(id);
1889                users.insert(
1890                    user_id,
1891                    User {
1892                        id: user_id,
1893                        github_login: params.github_login,
1894                        github_user_id: Some(params.github_user_id),
1895                        email_address: Some(email_address.to_string()),
1896                        admin,
1897                        invite_code: None,
1898                        invite_count: 0,
1899                        connected_once: false,
1900                    },
1901                );
1902                user_id
1903            };
1904            Ok(NewUserResult {
1905                user_id,
1906                metrics_id: "the-metrics-id".to_string(),
1907                inviting_user_id: None,
1908                signup_device_id: None,
1909            })
1910        }
1911
        /// Not supported by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1919
1920        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1921            self.background.simulate_random_delay().await;
1922            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1923        }
1924
1925        async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
1926            Ok("the-metrics-id".to_string())
1927        }
1928
1929        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1930            self.background.simulate_random_delay().await;
1931            let users = self.users.lock();
1932            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1933        }
1934
        /// Not supported by the fake; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
1938
1939        async fn get_user_by_github_account(
1940            &self,
1941            github_login: &str,
1942            github_user_id: Option<i32>,
1943        ) -> Result<Option<User>> {
1944            self.background.simulate_random_delay().await;
1945            if let Some(github_user_id) = github_user_id {
1946                for user in self.users.lock().values_mut() {
1947                    if user.github_user_id == Some(github_user_id) {
1948                        user.github_login = github_login.into();
1949                        return Ok(Some(user.clone()));
1950                    }
1951                    if user.github_login == github_login {
1952                        user.github_user_id = Some(github_user_id);
1953                        return Ok(Some(user.clone()));
1954                    }
1955                }
1956                Ok(None)
1957            } else {
1958                Ok(self
1959                    .users
1960                    .lock()
1961                    .values()
1962                    .find(|user| user.github_login == github_login)
1963                    .cloned())
1964            }
1965        }
1966
        /// Not supported by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1970
1971        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1972            self.background.simulate_random_delay().await;
1973            let mut users = self.users.lock();
1974            let mut user = users
1975                .get_mut(&id)
1976                .ok_or_else(|| anyhow!("user not found"))?;
1977            user.connected_once = connected_once;
1978            Ok(())
1979        }
1980
        /// Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // signups (none of these are supported by the fake)

        /// Not supported by the fake; panics if called.
        async fn create_signup(&self, _signup: Signup) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<Option<NewUserResult>> {
            unimplemented!()
        }

        // invite codes

        /// Not supported by the fake; panics if called.
        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        /// The fake keeps no invite codes: after the simulated delay this always
        /// returns `Ok(None)`.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }

        /// Not supported by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
            _device_id: Option<&str>,
        ) -> Result<Invite> {
            unimplemented!()
        }
2034
2035        // projects
2036
2037        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2038            self.background.simulate_random_delay().await;
2039            if !self.users.lock().contains_key(&host_user_id) {
2040                Err(anyhow!("no such user"))?;
2041            }
2042
2043            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2044            self.projects.lock().insert(
2045                project_id,
2046                Project {
2047                    id: project_id,
2048                    host_user_id,
2049                    unregistered: false,
2050                },
2051            );
2052            Ok(project_id)
2053        }
2054
2055        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2056            self.background.simulate_random_delay().await;
2057            self.projects
2058                .lock()
2059                .get_mut(&project_id)
2060                .ok_or_else(|| anyhow!("no such project"))?
2061                .unregistered = true;
2062            Ok(())
2063        }
2064
2065        async fn update_worktree_extensions(
2066            &self,
2067            project_id: ProjectId,
2068            worktree_id: u64,
2069            extensions: HashMap<String, u32>,
2070        ) -> Result<()> {
2071            self.background.simulate_random_delay().await;
2072            if !self.projects.lock().contains_key(&project_id) {
2073                Err(anyhow!("no such project"))?;
2074            }
2075
2076            for (extension, count) in extensions {
2077                self.worktree_extensions
2078                    .lock()
2079                    .insert((project_id, worktree_id, extension), count);
2080            }
2081
2082            Ok(())
2083        }
2084
        /// Not supported by the fake; panics if called. Data written by
        /// `update_worktree_extensions` cannot be read back through this method.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2124
2125        // contacts
2126
2127        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2128            self.background.simulate_random_delay().await;
2129            let mut contacts = Vec::new();
2130
2131            for contact in self.contacts.lock().iter() {
2132                if contact.requester_id == id {
2133                    if contact.accepted {
2134                        contacts.push(Contact::Accepted {
2135                            user_id: contact.responder_id,
2136                            should_notify: contact.should_notify,
2137                        });
2138                    } else {
2139                        contacts.push(Contact::Outgoing {
2140                            user_id: contact.responder_id,
2141                        });
2142                    }
2143                } else if contact.responder_id == id {
2144                    if contact.accepted {
2145                        contacts.push(Contact::Accepted {
2146                            user_id: contact.requester_id,
2147                            should_notify: false,
2148                        });
2149                    } else {
2150                        contacts.push(Contact::Incoming {
2151                            user_id: contact.requester_id,
2152                            should_notify: contact.should_notify,
2153                        });
2154                    }
2155                }
2156            }
2157
2158            contacts.sort_unstable_by_key(|contact| contact.user_id());
2159            Ok(contacts)
2160        }
2161
2162        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2163            self.background.simulate_random_delay().await;
2164            Ok(self.contacts.lock().iter().any(|contact| {
2165                contact.accepted
2166                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2167                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2168            }))
2169        }
2170
2171        async fn send_contact_request(
2172            &self,
2173            requester_id: UserId,
2174            responder_id: UserId,
2175        ) -> Result<()> {
2176            self.background.simulate_random_delay().await;
2177            let mut contacts = self.contacts.lock();
2178            for contact in contacts.iter_mut() {
2179                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2180                    if contact.accepted {
2181                        Err(anyhow!("contact already exists"))?;
2182                    } else {
2183                        Err(anyhow!("contact already requested"))?;
2184                    }
2185                }
2186                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2187                    if contact.accepted {
2188                        Err(anyhow!("contact already exists"))?;
2189                    } else {
2190                        contact.accepted = true;
2191                        contact.should_notify = false;
2192                        return Ok(());
2193                    }
2194                }
2195            }
2196            contacts.push(FakeContact {
2197                requester_id,
2198                responder_id,
2199                accepted: false,
2200                should_notify: true,
2201            });
2202            Ok(())
2203        }
2204
2205        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2206            self.background.simulate_random_delay().await;
2207            self.contacts.lock().retain(|contact| {
2208                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2209            });
2210            Ok(())
2211        }
2212
2213        async fn dismiss_contact_notification(
2214            &self,
2215            user_id: UserId,
2216            contact_user_id: UserId,
2217        ) -> Result<()> {
2218            self.background.simulate_random_delay().await;
2219            let mut contacts = self.contacts.lock();
2220            for contact in contacts.iter_mut() {
2221                if contact.requester_id == contact_user_id
2222                    && contact.responder_id == user_id
2223                    && !contact.accepted
2224                {
2225                    contact.should_notify = false;
2226                    return Ok(());
2227                }
2228                if contact.requester_id == user_id
2229                    && contact.responder_id == contact_user_id
2230                    && contact.accepted
2231                {
2232                    contact.should_notify = false;
2233                    return Ok(());
2234                }
2235            }
2236            Err(anyhow!("no such notification"))?
2237        }
2238
2239        async fn respond_to_contact_request(
2240            &self,
2241            responder_id: UserId,
2242            requester_id: UserId,
2243            accept: bool,
2244        ) -> Result<()> {
2245            self.background.simulate_random_delay().await;
2246            let mut contacts = self.contacts.lock();
2247            for (ix, contact) in contacts.iter_mut().enumerate() {
2248                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2249                    if contact.accepted {
2250                        Err(anyhow!("contact already confirmed"))?;
2251                    }
2252                    if accept {
2253                        contact.accepted = true;
2254                        contact.should_notify = true;
2255                    } else {
2256                        contacts.remove(ix);
2257                    }
2258                    return Ok(());
2259                }
2260            }
2261            Err(anyhow!("no such contact request"))?
2262        }
2263
        /// Not supported by the fake; panics if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2280
2281        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2282            self.background.simulate_random_delay().await;
2283            let mut orgs = self.orgs.lock();
2284            if orgs.values().any(|org| org.slug == slug) {
2285                Err(anyhow!("org already exists"))?
2286            } else {
2287                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2288                orgs.insert(
2289                    org_id,
2290                    Org {
2291                        id: org_id,
2292                        name: name.to_string(),
2293                        slug: slug.to_string(),
2294                    },
2295                );
2296                Ok(org_id)
2297            }
2298        }
2299
2300        async fn add_org_member(
2301            &self,
2302            org_id: OrgId,
2303            user_id: UserId,
2304            is_admin: bool,
2305        ) -> Result<()> {
2306            self.background.simulate_random_delay().await;
2307            if !self.orgs.lock().contains_key(&org_id) {
2308                Err(anyhow!("org does not exist"))?;
2309            }
2310            if !self.users.lock().contains_key(&user_id) {
2311                Err(anyhow!("user does not exist"))?;
2312            }
2313
2314            self.org_memberships
2315                .lock()
2316                .entry((org_id, user_id))
2317                .or_insert(is_admin);
2318            Ok(())
2319        }
2320
2321        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2322            self.background.simulate_random_delay().await;
2323            if !self.orgs.lock().contains_key(&org_id) {
2324                Err(anyhow!("org does not exist"))?;
2325            }
2326
2327            let mut channels = self.channels.lock();
2328            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2329            channels.insert(
2330                channel_id,
2331                Channel {
2332                    id: channel_id,
2333                    name: name.to_string(),
2334                    owner_id: org_id.0,
2335                    owner_is_user: false,
2336                },
2337            );
2338            Ok(channel_id)
2339        }
2340
2341        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2342            self.background.simulate_random_delay().await;
2343            Ok(self
2344                .channels
2345                .lock()
2346                .values()
2347                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2348                .cloned()
2349                .collect())
2350        }
2351
2352        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2353            self.background.simulate_random_delay().await;
2354            let channels = self.channels.lock();
2355            let memberships = self.channel_memberships.lock();
2356            Ok(channels
2357                .values()
2358                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2359                .cloned()
2360                .collect())
2361        }
2362
2363        async fn can_user_access_channel(
2364            &self,
2365            user_id: UserId,
2366            channel_id: ChannelId,
2367        ) -> Result<bool> {
2368            self.background.simulate_random_delay().await;
2369            Ok(self
2370                .channel_memberships
2371                .lock()
2372                .contains_key(&(channel_id, user_id)))
2373        }
2374
2375        async fn add_channel_member(
2376            &self,
2377            channel_id: ChannelId,
2378            user_id: UserId,
2379            is_admin: bool,
2380        ) -> Result<()> {
2381            self.background.simulate_random_delay().await;
2382            if !self.channels.lock().contains_key(&channel_id) {
2383                Err(anyhow!("channel does not exist"))?;
2384            }
2385            if !self.users.lock().contains_key(&user_id) {
2386                Err(anyhow!("user does not exist"))?;
2387            }
2388
2389            self.channel_memberships
2390                .lock()
2391                .entry((channel_id, user_id))
2392                .or_insert(is_admin);
2393            Ok(())
2394        }
2395
2396        async fn create_channel_message(
2397            &self,
2398            channel_id: ChannelId,
2399            sender_id: UserId,
2400            body: &str,
2401            timestamp: OffsetDateTime,
2402            nonce: u128,
2403        ) -> Result<MessageId> {
2404            self.background.simulate_random_delay().await;
2405            if !self.channels.lock().contains_key(&channel_id) {
2406                Err(anyhow!("channel does not exist"))?;
2407            }
2408            if !self.users.lock().contains_key(&sender_id) {
2409                Err(anyhow!("user does not exist"))?;
2410            }
2411
2412            let mut messages = self.channel_messages.lock();
2413            if let Some(message) = messages
2414                .values()
2415                .find(|message| message.nonce.as_u128() == nonce)
2416            {
2417                Ok(message.id)
2418            } else {
2419                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2420                messages.insert(
2421                    message_id,
2422                    ChannelMessage {
2423                        id: message_id,
2424                        channel_id,
2425                        sender_id,
2426                        body: body.to_string(),
2427                        sent_at: timestamp,
2428                        nonce: Uuid::from_u128(nonce),
2429                    },
2430                );
2431                Ok(message_id)
2432            }
2433        }
2434
2435        async fn get_channel_messages(
2436            &self,
2437            channel_id: ChannelId,
2438            count: usize,
2439            before_id: Option<MessageId>,
2440        ) -> Result<Vec<ChannelMessage>> {
2441            self.background.simulate_random_delay().await;
2442            let mut messages = self
2443                .channel_messages
2444                .lock()
2445                .values()
2446                .rev()
2447                .filter(|message| {
2448                    message.channel_id == channel_id
2449                        && message.id < before_id.unwrap_or(MessageId::MAX)
2450                })
2451                .take(count)
2452                .cloned()
2453                .collect::<Vec<_>>();
2454            messages.sort_unstable_by_key(|message| message.id);
2455            Ok(messages)
2456        }
2457
        /// No-op: the fake owns no external resources to clean up.
        async fn teardown(&self, _: &str) {}

        /// Returns `self` as a `FakeDb`, so holders of a `dyn Db` can reach the
        /// fake's public tables.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2464    }
2465
2466    pub struct TestDb {
2467        pub db: Option<Arc<dyn Db>>,
2468        pub url: String,
2469    }
2470
2471    impl TestDb {
2472        #[allow(clippy::await_holding_lock)]
2473        pub async fn postgres() -> Self {
2474            lazy_static! {
2475                static ref LOCK: Mutex<()> = Mutex::new(());
2476            }
2477
2478            let _guard = LOCK.lock();
2479            let mut rng = StdRng::from_entropy();
2480            let name = format!("zed-test-{}", rng.gen::<u128>());
2481            let url = format!("postgres://postgres@localhost/{}", name);
2482            Postgres::create_database(&url)
2483                .await
2484                .expect("failed to create test db");
2485            let db = PostgresDb::new(&url, 5).await.unwrap();
2486            db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
2487                .await
2488                .unwrap();
2489            Self {
2490                db: Some(Arc::new(db)),
2491                url,
2492            }
2493        }
2494
2495        pub fn fake(background: Arc<Background>) -> Self {
2496            Self {
2497                db: Some(Arc::new(FakeDb::new(background))),
2498                url: Default::default(),
2499            }
2500        }
2501
2502        pub fn db(&self) -> &Arc<dyn Db> {
2503            self.db.as_ref().unwrap()
2504        }
2505    }
2506
2507    impl Drop for TestDb {
2508        fn drop(&mut self) {
2509            if let Some(db) = self.db.take() {
2510                futures::executor::block_on(db.teardown(&self.url));
2511            }
2512        }
2513    }
2514}